In [16]:
import os
import json
import torch
import random
from pathlib import Path
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

class myDataset(Dataset):
    """Speaker-classification training set backed by pre-extracted mel features.

    Reads ``mapping.json`` (for its ``speaker2id`` table) and ``metadata.json``
    (for the per-speaker utterance lists under ``speakers``) from ``data_dir``.
    Each item is a mel segment of at most ``segment_len`` rows together with
    the speaker's integer id.

    Args:
        data_dir: Directory holding the two JSON files and the feature files
            referenced by each utterance's ``feature_path``.
        segment_len: Maximum number of rows (frames) returned per item.
            Longer features are randomly cropped to exactly this length;
            shorter ones are returned whole.
    """

    def __init__(self, data_dir, segment_len = 128):
        self.data_dir = data_dir
        self.segment_len = segment_len

        # Context managers close the JSON handles as soon as they are parsed.
        mapping_path = Path(data_dir) / 'mapping.json'
        with mapping_path.open() as mapping_file:
            mapping = json.load(mapping_file)
        self.speaker2id = mapping['speaker2id']

        meta_path = Path(data_dir) / 'metadata.json'
        with meta_path.open() as meta_file:
            metadata = json.load(meta_file)
        speakers = metadata['speakers']
        self.speaker_num = len(speakers)

        # Flat list of [feature_path, speaker_id] pairs, one per utterance.
        self.data = []
        for speaker, utterances in speakers.items():
            for utterance in utterances:
                self.data.append([utterance['feature_path'], self.speaker2id[speaker]])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(mel, speaker)`` for utterance ``index``.

        ``mel`` is a float tensor cropped along its first dimension
        (presumably time frames -- TODO confirm) to ``segment_len`` rows when
        long enough. ``speaker`` is a 1-element int64 tensor.
        """
        feat_path, speaker = self.data[index]
        # NOTE(review): torch.load unpickles; only load trusted feature files.
        mel = torch.load(os.path.join(self.data_dir, feat_path))

        if len(mel) >= self.segment_len:
            # random.randint is inclusive on both ends, so the largest valid
            # start is len(mel) - segment_len; anything larger would yield a
            # segment shorter than segment_len.
            start = random.randint(0, len(mel) - self.segment_len)
            mel = torch.FloatTensor(mel[start:start + self.segment_len])
        else:
            mel = torch.FloatTensor(mel)

        speaker = torch.FloatTensor([speaker]).long()

        return mel, speaker

    def get_speaker_number(self):
        """Number of speakers listed in ``metadata.json``."""
        return self.speaker_num

In [27]:
import torch
from torch.utils.data import DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence

def collate_batch(batch):
    """Merge ``(mel, speaker)`` items into one padded training batch.

    Mels are right-padded along their first dimension to the longest item in
    the batch using the value -20, and stacked batch-first. Speaker ids are
    gathered into an int64 tensor (assumes each speaker entry is a scalar or a
    one-element tensor, as produced by the training dataset).
    """
    mel_list, speaker_list = zip(*batch)
    padded_mels = pad_sequence(mel_list, batch_first = True, padding_value = -20)
    speaker_ids = torch.FloatTensor(speaker_list).long()
    return padded_mels, speaker_ids

def get_dataloader(data_dir, batch_size, n_workers, train_ratio = 0.9, seed = None):
    """Build train/valid DataLoaders from a random split of ``myDataset``.

    Args:
        data_dir: Dataset directory passed to ``myDataset``.
        batch_size: Batch size for both loaders.
        n_workers: Number of DataLoader worker processes.
        train_ratio: Fraction of utterances used for training (default 0.9,
            the previously hard-coded value). The remainder is the
            validation set.
        seed: Optional integer seed for the train/valid split. ``None`` keeps
            the original behaviour of drawing from torch's global RNG. Pass an
            int to get a reproducible split.

    Returns:
        ``(train_loader, valid_loader, speaker_num)``.
    """
    dataset = myDataset(data_dir)
    speaker_num = dataset.get_speaker_number()
    trainlen = int(train_ratio * len(dataset))
    length = [trainlen, len(dataset) - trainlen]
    # Only pass a generator when a seed is requested, so the default path is
    # unchanged.
    split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
    trainset, validset = random_split(dataset, length, **split_kwargs)

    train_loader = DataLoader(
        trainset,
        batch_size = batch_size,
        shuffle = True,
        drop_last = True,
        num_workers = n_workers,
        pin_memory = True,
        collate_fn = collate_batch
    )

    # Keep every validation utterance: drop_last=True would silently exclude
    # the final partial batch from the validation score.
    valid_loader = DataLoader(
        validset,
        batch_size = batch_size,
        shuffle = False,
        drop_last = False,
        num_workers = n_workers,
        pin_memory = True,
        collate_fn = collate_batch
    )

    return train_loader, valid_loader, speaker_num

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Classifier(nn.Module):
    """Single Transformer-encoder-layer speaker classifier over mel frames.

    Args:
        d_model: Hidden size of the prenet output and the encoder layer.
        n_spks: Number of output classes (speakers).
        dropout: Dropout probability used inside the Transformer encoder
            layer. Previously this argument was accepted but ignored. The
            default of 0.1 matches the encoder layer's own default, so default
            construction behaves as before.
        n_mels: Size of the last input dimension (feature bins), default 40.
    """

    def __init__(self, d_model = 80, n_spks = 600, dropout = 0.1, n_mels = 40):
        super(Classifier, self).__init__()

        # Project each n_mels-dim input frame up to d_model.
        self.prenet = nn.Linear(n_mels, d_model)
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model = d_model, dim_feedforward = 256, nhead = 2, dropout = dropout
        )
        self.pred_layer = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, n_spks)
        )

    def forward(self, mels):
        """Map ``mels`` of shape (batch, length, n_mels) to logits (batch, n_spks)."""
        out = self.prenet(mels)                 # (batch, length, d_model)
        # The encoder layer is built without batch_first, so it expects
        # (length, batch, d_model).
        out = out.permute(1, 0, 2)
        out = self.encoder_layer(out)
        out = out.transpose(0, 1)               # back to (batch, length, d_model)
        # Mean-pool over the length axis.
        # NOTE(review): no padding mask is applied, so padded frames in a
        # batch also contribute to this average.
        stats = out.mean(dim = 1)
        out = self.pred_layer(stats)
        return out

In [19]:
import math
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1
):
    """Build a ``LambdaLR`` with linear warm-up followed by cosine decay.

    For steps below ``num_warmup_steps`` the LR multiplier rises linearly from
    0 toward 1. After that it is
    ``max(0, 0.5 * (1 + cos(2 * pi * num_cycles * progress)))``, where
    ``progress`` runs from 0 to 1 across the remaining
    ``num_training_steps - num_warmup_steps`` steps. With the default
    ``num_cycles=0.5`` this is one half-cosine from 1 down to 0.

    Returns:
        A ``LambdaLR`` scheduler wrapping ``optimizer``.
    """
    # These only depend on the (immutable) arguments, so compute them once.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))
    angular_rate = math.pi * 2.0 * float(num_cycles)

    def multiplier(step):
        if step >= num_warmup_steps:
            fraction_done = float(step - num_warmup_steps) / decay_span
            return max(0.0, 0.5 * (1.0 + math.cos(angular_rate * fraction_done)))
        return float(step) / warmup_span

    return LambdaLR(optimizer, multiplier, last_epoch)

In [20]:
import torch

def model_fn(batch, model, criterion, device):
    """Run one forward pass on ``batch`` and score it.

    Args:
        batch: ``(mels, labels)`` pair.
        model: Callable mapping mels to per-class logits.
        criterion: Loss taking ``(logits, labels)``.
        device: Device both tensors are moved to before the forward pass.

    Returns:
        ``(loss, accuracy)``: the criterion output, and the fraction of argmax
        predictions equal to ``labels`` as a 0-dim float tensor on the CPU.
    """
    raw_mels, raw_labels = batch
    mels_dev = raw_mels.to(device)
    labels_dev = raw_labels.to(device)
    logits = model(mels_dev)

    loss = criterion(logits, labels_dev)
    hits = logits.argmax(1).cpu().eq(labels_dev.cpu())
    accuracy = hits.float().mean()

    return loss, accuracy

In [21]:
import torch
from tqdm import tqdm

def valid(dataloader, model, criterion, device):
    """Evaluate ``model`` on ``dataloader`` and return its accuracy.

    The model is switched to eval mode. Every batch is scored with
    ``model_fn`` under ``torch.no_grad()``, and the model is put back into
    train mode afterwards (also if evaluation raises).

    Loss and accuracy are averaged per sample: each batch is weighted by its
    real size, so a smaller final batch does not skew the result. The
    progress bar also advances by that real size.

    Returns:
        float: sample-weighted mean accuracy, or 0.0 if no samples were seen.
    """
    model.eval()
    pbar = tqdm(total = len(dataloader.dataset), ncols = 0, desc = 'Valid', unit = 'uttr')
    running_loss = 0.0
    running_acc = 0.0
    seen = 0

    try:
        with torch.no_grad():
            for batch in dataloader:
                loss, accuracy = model_fn(batch, model, criterion, device)
                # batch is (mels, labels); assumes labels is 1-D so its length
                # is the true batch size (the last batch may be smaller than
                # dataloader.batch_size).
                n = len(batch[1])
                # loss/accuracy are per-batch means; scale back to sums.
                running_loss += loss.item() * n
                running_acc += accuracy.item() * n
                seen += n

                pbar.update(n)
                pbar.set_postfix(
                    loss = f"{running_loss / max(seen, 1):.2f}",
                    accuracy = f"{running_acc / max(seen, 1):.2f}"
                )
    finally:
        pbar.close()
        model.train()

    return running_acc / seen if seen else 0.0

In [22]:
def parse_args():
    """Return the training hyper-parameters as a new plain dict.

    The keys match the keyword arguments of ``main``:
    - data and checkpoint paths
    - batch size
    - DataLoader worker count
    - step counts for validation, LR warm-up, checkpointing and total training
    """
    return dict(
        data_dir='./Dataset',
        model_path='model.ckpt',
        batch_size=32,
        n_workers=0,
        valid_steps=2000,
        warmup_steps=1000,
        save_steps=10000,
        total_steps=70000,
    )

In [23]:
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split

def main(data_dir, model_path, batch_size, n_workers, valid_steps, warmup_steps, save_steps, total_steps):
    """Train a ``Classifier`` for ``total_steps`` steps and checkpoint the best model.

    The loop works as follows:
    - Training is step-based: the training loader is re-iterated whenever it
      is exhausted.
    - Every ``valid_steps`` steps the model is scored with ``valid``, and a
      snapshot is kept whenever validation accuracy improves.
    - Every ``save_steps`` steps that best snapshot is written to
      ``model_path`` with ``torch.save``.
    - At the end, it is also written if the last step did not land on a save
      boundary.
    """
    import copy  # local import: only needed to snapshot the best weights

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"[Info]: Use {device} now!")

    train_loader, valid_loader, speaker_num = get_dataloader(data_dir, batch_size, n_workers)
    train_iterator = iter(train_loader)
    print(f"[Info]: Finish loading data!",flush = True)

    model = Classifier(n_spks = speaker_num).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr = 1e-3)
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
    print(f"[Info]: Finish creating model!",flush = True)

    best_accuracy = 0.0
    best_state_dict = None

    # Create the bar only after the slow data/model setup so its timer
    # measures training time.
    pbar = tqdm(total = valid_steps, ncols = 0, desc = 'Train', unit = 'step')

    model.train()
    for step in range(total_steps):
        # Restart the loader when it is exhausted.
        try:
            batch = next(train_iterator)
        except StopIteration:
            train_iterator = iter(train_loader)
            batch = next(train_iterator)

        optimizer.zero_grad()
        loss, accuracy = model_fn(batch, model, criterion, device)

        loss.backward()
        optimizer.step()
        scheduler.step()

        pbar.update()
        pbar.set_postfix(
            loss = f"{loss.item():.2f}",
            accuracy = f"{accuracy.item():.2f}",
            step = step + 1,
        )

        if (step + 1) % valid_steps == 0:
            pbar.close()
            valid_accuracy = valid(valid_loader, model, criterion, device)

            if valid_accuracy > best_accuracy:
                best_accuracy = valid_accuracy
                # state_dict() returns references to the live parameter
                # tensors, which later optimizer steps update in place. Copy
                # them so the snapshot really holds the best weights.
                best_state_dict = copy.deepcopy(model.state_dict())

            pbar = tqdm(total = valid_steps, ncols = 0, desc = 'Train', unit = 'step')

        if (step + 1) % save_steps == 0 and best_state_dict is not None:
            torch.save(best_state_dict, model_path)
            pbar.write(f"Step {step + 1}, best model saved. (accuracy={best_accuracy:.4f})")

    pbar.close()

    # Without this, the final best snapshot would never be written when
    # total_steps is not a multiple of save_steps.
    if total_steps % save_steps != 0 and best_state_dict is not None:
        torch.save(best_state_dict, model_path)
        print(f"Step {total_steps}, best model saved. (accuracy={best_accuracy:.4f})")

In [28]:
if __name__ == '__main__':
    # Entry point: parse_args() returns the hyper-parameter dict, whose keys
    # are expanded into main()'s keyword arguments. Inside a notebook kernel
    # __name__ is '__main__', so executing this cell starts training.
    main(**parse_args())

[Info]: Use cuda now!


Train:   0% 0/2000 [01:47<?, ?step/s]

[Info]: Finish loading data!
[Info]: Finish creating model!



Train: 100% 2000/2000 [01:06<00:00, 30.08step/s, accuracy=0.31, loss=3.46, step=2000]
Valid: 100% 6944/6944 [00:05<00:00, 1265.58uttr/s, accuracy=0.20, loss=3.95]
Train: 100% 2000/2000 [00:41<00:00, 47.86step/s, accuracy=0.25, loss=3.26, step=4000]
Valid: 100% 6944/6944 [00:04<00:00, 1389.65uttr/s, accuracy=0.30, loss=3.25]
Train: 100% 2000/2000 [00:41<00:00, 48.24step/s, accuracy=0.47, loss=3.36, step=6000]
Valid: 100% 6944/6944 [00:03<00:00, 1757.07uttr/s, accuracy=0.36, loss=2.92]
Train: 100% 2000/2000 [00:35<00:00, 55.99step/s, accuracy=0.28, loss=3.06, step=8000]
Valid: 100% 6944/6944 [00:02<00:00, 2681.94uttr/s, accuracy=0.40, loss=2.69]
Train: 100% 2000/2000 [00:34<00:00, 57.80step/s, accuracy=0.62, loss=1.82, step=1e+4]
Valid: 100% 6944/6944 [00:02<00:00, 2653.05uttr/s, accuracy=0.44, loss=2.54]
                                     

Train:   0% 0/2000 [00:00<?, ?step/s]

Train:   1% 11/2000 [00:00<00:36, 54.69step/s, accuracy=0.34, loss=2.88, step=1e+4]

Step 10000, best model saved. (accuracy=0.4359)


Train: 100% 2000/2000 [00:35<00:00, 56.99step/s, accuracy=0.38, loss=2.73, step=12000]
Valid: 100% 6944/6944 [00:02<00:00, 2634.35uttr/s, accuracy=0.45, loss=2.46]
Train: 100% 2000/2000 [00:34<00:00, 57.57step/s, accuracy=0.50, loss=2.10, step=14000]
Valid: 100% 6944/6944 [00:02<00:00, 2651.18uttr/s, accuracy=0.48, loss=2.31]
Train: 100% 2000/2000 [00:35<00:00, 56.90step/s, accuracy=0.50, loss=2.47, step=16000]
Valid: 100% 6944/6944 [00:02<00:00, 2666.68uttr/s, accuracy=0.50, loss=2.21]
Train: 100% 2000/2000 [00:34<00:00, 57.42step/s, accuracy=0.53, loss=2.39, step=18000]
Valid: 100% 6944/6944 [00:02<00:00, 2548.42uttr/s, accuracy=0.51, loss=2.19]
Train: 100% 2000/2000 [00:35<00:00, 55.67step/s, accuracy=0.47, loss=2.54, step=2e+4] 
Valid: 100% 6944/6944 [00:02<00:00, 2623.94uttr/s, accuracy=0.53, loss=2.09]
                                     

Train:   0% 0/2000 [00:00<?, ?step/s]

Train:   0% 10/2000 [00:00<00:36, 54.66step/s, accuracy=0.53, loss=2.46, step=2e+4]

Step 20000, best model saved. (accuracy=0.5252)


Train: 100% 2000/2000 [00:35<00:00, 56.32step/s, accuracy=0.66, loss=1.36, step=22000]
Valid: 100% 6944/6944 [00:02<00:00, 2634.59uttr/s, accuracy=0.52, loss=2.08]
Train: 100% 2000/2000 [00:35<00:00, 55.71step/s, accuracy=0.59, loss=1.67, step=24000]
Valid: 100% 6944/6944 [00:02<00:00, 2564.57uttr/s, accuracy=0.55, loss=1.96]
Train: 100% 2000/2000 [00:35<00:00, 56.95step/s, accuracy=0.59, loss=1.71, step=26000]
Valid: 100% 6944/6944 [00:02<00:00, 2651.62uttr/s, accuracy=0.55, loss=1.93]
Train: 100% 2000/2000 [00:35<00:00, 56.72step/s, accuracy=0.53, loss=2.23, step=28000]
Valid: 100% 6944/6944 [00:02<00:00, 2619.18uttr/s, accuracy=0.57, loss=1.88]
Train: 100% 2000/2000 [00:34<00:00, 57.20step/s, accuracy=0.50, loss=2.05, step=3e+4] 
Valid: 100% 6944/6944 [00:02<00:00, 2558.59uttr/s, accuracy=0.57, loss=1.85]
                                     

Train:   0% 0/2000 [00:00<?, ?step/s]

Train:   0% 10/2000 [00:00<00:36, 54.66step/s, accuracy=0.56, loss=1.88, step=3e+4]

Step 30000, best model saved. (accuracy=0.5730)


Train: 100% 2000/2000 [00:35<00:00, 56.55step/s, accuracy=0.53, loss=1.55, step=32000]
Valid: 100% 6944/6944 [00:02<00:00, 2655.08uttr/s, accuracy=0.58, loss=1.82]
Train: 100% 2000/2000 [00:35<00:00, 56.87step/s, accuracy=0.59, loss=1.52, step=34000]
Valid: 100% 6944/6944 [00:02<00:00, 2611.17uttr/s, accuracy=0.60, loss=1.76]
Train: 100% 2000/2000 [00:35<00:00, 55.70step/s, accuracy=0.59, loss=2.00, step=36000]
Valid: 100% 6944/6944 [00:02<00:00, 2625.93uttr/s, accuracy=0.59, loss=1.73]
Train: 100% 2000/2000 [00:35<00:00, 56.06step/s, accuracy=0.56, loss=1.68, step=38000]
Valid: 100% 6944/6944 [00:02<00:00, 2618.89uttr/s, accuracy=0.62, loss=1.63]
Train: 100% 2000/2000 [00:36<00:00, 55.20step/s, accuracy=0.50, loss=2.15, step=4e+4] 
Valid: 100% 6944/6944 [00:02<00:00, 2526.36uttr/s, accuracy=0.62, loss=1.63]
                                     

Train:   0% 0/2000 [00:00<?, ?step/s]

Train:   0% 10/2000 [00:00<00:41, 48.07step/s, accuracy=0.66, loss=0.95, step=4e+4]

Step 40000, best model saved. (accuracy=0.6192)


Train: 100% 2000/2000 [00:35<00:00, 56.16step/s, accuracy=0.72, loss=1.00, step=42000]
Valid: 100% 6944/6944 [00:02<00:00, 2575.53uttr/s, accuracy=0.62, loss=1.60]
Train: 100% 2000/2000 [00:35<00:00, 55.66step/s, accuracy=0.69, loss=1.22, step=44000]
Valid: 100% 6944/6944 [00:02<00:00, 2589.27uttr/s, accuracy=0.64, loss=1.53]
Train: 100% 2000/2000 [00:35<00:00, 56.16step/s, accuracy=0.53, loss=1.82, step=46000]
Valid: 100% 6944/6944 [00:02<00:00, 2602.79uttr/s, accuracy=0.64, loss=1.54]
Train: 100% 2000/2000 [00:36<00:00, 55.32step/s, accuracy=0.62, loss=1.58, step=48000]
Valid: 100% 6944/6944 [00:02<00:00, 2593.12uttr/s, accuracy=0.65, loss=1.49]
Train: 100% 2000/2000 [00:35<00:00, 55.99step/s, accuracy=0.59, loss=1.57, step=5e+4] 
Valid: 100% 6944/6944 [00:02<00:00, 2587.60uttr/s, accuracy=0.65, loss=1.49]
                                     

Train:   0% 0/2000 [00:00<?, ?step/s]

Train:   1% 11/2000 [00:00<00:36, 54.66step/s, accuracy=0.72, loss=1.22, step=5e+4]

Step 50000, best model saved. (accuracy=0.6547)


Train: 100% 2000/2000 [00:35<00:00, 56.66step/s, accuracy=0.72, loss=0.99, step=52000]
Valid: 100% 6944/6944 [00:02<00:00, 2600.92uttr/s, accuracy=0.66, loss=1.46]
Train: 100% 2000/2000 [00:35<00:00, 56.22step/s, accuracy=0.66, loss=1.41, step=54000]
Valid: 100% 6944/6944 [00:02<00:00, 2588.81uttr/s, accuracy=0.67, loss=1.43]
Train: 100% 2000/2000 [00:35<00:00, 55.87step/s, accuracy=0.75, loss=1.18, step=56000]
Valid: 100% 6944/6944 [00:02<00:00, 2598.69uttr/s, accuracy=0.66, loss=1.43]
Train: 100% 2000/2000 [00:35<00:00, 56.21step/s, accuracy=0.78, loss=0.73, step=58000]
Valid: 100% 6944/6944 [00:02<00:00, 2586.69uttr/s, accuracy=0.67, loss=1.43]
Train: 100% 2000/2000 [00:36<00:00, 54.39step/s, accuracy=0.69, loss=1.15, step=6e+4] 
Valid: 100% 6944/6944 [00:02<00:00, 2508.63uttr/s, accuracy=0.68, loss=1.38]
                                     

Train:   0% 0/2000 [00:00<?, ?step/s]

Train:   0% 10/2000 [00:00<00:43, 45.55step/s, accuracy=0.75, loss=0.88, step=6e+4]

Step 60000, best model saved. (accuracy=0.6779)


Train: 100% 2000/2000 [00:36<00:00, 54.08step/s, accuracy=0.59, loss=1.38, step=62000]
Valid: 100% 6944/6944 [00:02<00:00, 2503.40uttr/s, accuracy=0.67, loss=1.39]
Train: 100% 2000/2000 [00:36<00:00, 54.43step/s, accuracy=0.66, loss=1.42, step=64000]
Valid: 100% 6944/6944 [00:02<00:00, 2514.68uttr/s, accuracy=0.68, loss=1.37]
Train: 100% 2000/2000 [00:36<00:00, 54.50step/s, accuracy=0.81, loss=0.86, step=66000]
Valid: 100% 6944/6944 [00:02<00:00, 2529.61uttr/s, accuracy=0.68, loss=1.39]
Train: 100% 2000/2000 [00:36<00:00, 54.83step/s, accuracy=0.81, loss=0.89, step=68000]
Valid: 100% 6944/6944 [00:02<00:00, 2486.00uttr/s, accuracy=0.68, loss=1.37]
Train: 100% 2000/2000 [00:36<00:00, 54.67step/s, accuracy=0.66, loss=1.34, step=7e+4] 
Valid: 100% 6944/6944 [00:02<00:00, 2514.71uttr/s, accuracy=0.68, loss=1.37]
                                     

Train:   0% 0/2000 [00:00<?, ?step/s]

Train:   0% 0/2000 [00:00<?, ?step/s][A[A


Step 70000, best model saved. (accuracy=0.6835)


In [30]:
import os
import torch
import json
from pathlib import Path
from torch.utils.data import Dataset

class InferenceDataset(Dataset):
    """Test-set dataset yielding ``(feature_path, mel)`` pairs.

    The utterance list is read from ``<data_dir>/testdata.json`` (key
    ``'utterances'``; each entry has a ``'feature_path'``). Each feature file
    is loaded with ``torch.load`` when accessed and returned whole, without
    cropping.
    """

    def __init__(self, data_dir):
        test_path = Path(data_dir) / "testdata.json"
        # Context manager closes the JSON handle right after parsing.
        with test_path.open() as test_file:
            testdata = json.load(test_file)
        self.data_dir = data_dir
        self.data = testdata['utterances']

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(feature_path, mel)`` for utterance ``index``."""
        feat_path = self.data[index]['feature_path']
        # NOTE(review): torch.load unpickles; only use with trusted feature files.
        mel = torch.load(os.path.join(self.data_dir, feat_path))
        return feat_path, mel

In [31]:
def inference_collate_batch(batch):
    """Group ``(feature_path, mel)`` items into one inference batch.

    Returns the feature paths as a tuple, and the mels stacked along a new
    leading batch dimension. ``torch.stack`` requires every mel in the batch
    to have the same shape.
    """
    feature_paths, mel_tensors = zip(*batch)
    stacked_mels = torch.stack(mel_tensors)
    return feature_paths, stacked_mels

In [32]:
def parse_inference_args():
    """Return the inference settings as a new plain dict.

    Holds the dataset directory, the checkpoint path to load, and an output
    CSV path (presumably where predictions are written -- the consumer is not
    shown here).
    """
    return dict(
        data_dir='./Dataset',
        model_path='./model.ckpt',
        output_path='./output.csv',
    )

In [33]:
import torch
import csv
import json
from tqdm.notebook import tqdm
from pathlib import Path
from torch.utils.data import DataLoader

def inference_main(data_dir, model_path, output_path):
    