In [18]:
import os
import json
import torch
import random
from pathlib import Path
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

class myDataset(Dataset):
    def __init__(self, data_dir, segment_len = 128):
        self.data_dir = data_dir
        self.segment_len = segment_len
        
        mapping_path = Path(data_dir) / 'mapping.json'
        mapping = json.load(mapping_path.open())
        self.speaker2id = mapping['speaker2id']
        
        meta_path = Path(data_dir) / 'metadata.json'
        metadata = json.load(open(meta_path))
        speakers = metadata['speakers']
        self.speaker_num = len(speakers.keys())
        
        self.data = []
        
        for speaker in speakers.keys():
            for utterances in speakers[speaker]:
                feature_path = utterances['feature_path']
                self.data.append([feature_path, self.speaker2id[speaker]])
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        feat_path, speaker = self.data[index]
        mel = torch.load(os.path.join(self.data_dir, feat_path))
        
        if len(mel) >= self.segment_len:
            start = random.randint(0, len(mel) - self.segment_len + 1)
            mel = torch.FloatTensor(mel[start:start+self.segment_len])
        else:
            mel = torch.FloatTensor(mel)
        
        speaker = torch.FloatTensor([speaker]).long()
        
        return mel, speaker
    
    def get_speaker_number(self):
        return self.speaker_num

In [19]:
import torch
from torch.utils.data import DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence

def collate_batch(batch):
    mel, speaker = zip(*batch)
    mel = pad_sequence(mel, batch_first = True, padding_value = -20)
    return mel, torch.FloatTensor(speaker).long()

def get_dataloader(data_dir, batch_size, n_workers):
    dataset = myDataset(data_dir)
    speaker_num = dataset.get_speaker_number()
    trainlen = int(0.9 * len(dataset))
    length = [trainlen, len(dataset) - trainlen]
    trainset, validset = random_split(dataset, length)
    
    train_loader = DataLoader(
        trainset,
        batch_size = batch_size,
        shuffle = True,
        drop_last = True,
        num_workers = n_workers,
        pin_memory = True,
        collate_fn = collate_batch
    )
    
    valid_loader = DataLoader(
        validset,
        batch_size = batch_size,
        shuffle = False,
        drop_last = True,
        num_workers = n_workers,
        pin_memory = True,
        collate_fn = collate_batch
    )
    
    return train_loader, valid_loader, speaker_num

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AMSoftmax(nn.Module):
    def __init__(self, in_features, n_classes, s = 30, m = 0.4):
        super(AMSoftmax, self).__init__()
        self.linear = nn.Linear(in_features, n_classes, bias = False)
        self.s = s
        self.m = m
    
    def forward(self, *inputs):
        x_vector = F.normalize(inputs[0], p = 2, dim = -1)
        self.linear.weight.data = F.normalize(self.linear.weight.data, p = 2, dim = -1)
        logits = self.linear(x_vector)
        scaled_logits = (logits - self.m) * self.s
        return scaled_logits - self._am_logsumexp(logits)
    
    def _am_logsumexp(self, logits):
        max_x = torch.max(logits, dim = -1)[0].unsqueeze(-1)
        term1 = (self.s * (logits - (max_x + self.m))).exp()
        term2 = (self.s * (logits - max_x)).exp().sum(-1).unsqueeze(-1) - (self.s * (logits - max_x)).exp()
        return self.s * max_x + (term2 + term1).log()

In [21]:
class Self_Attentive_Pooling(nn.Module):
    def __init__(self, dim):
        super(Self_Attentive_Pooling, self).__init__()
        self.sap_linear = nn.Linear(dim, dim)
        self.attention = nn.Parameter(torch.FloatTensor(dim, 1))
    
    def forward(self, x):
        x = x.permute(0, 2, 1)
        h = torch.tanh(self.sap_linear(x))
        w = torch.matmul(h, self.attention).squeeze(dim = 2)
        w = F.softmax(w, dim = 1).view(x.size(0), x.size(1), 1)
        x = torch.sum(x * w, dim = 1)
        return x

In [22]:
class FocalSoftmax(nn.Module):
    def __init__(self, gamma = 2):
        super(FocalSoftmax, self).__init__()
        self.gamma = gamma
        self.ce = nn.CrossEntropyLoss()
    
    def forward(self, inputs, target):
        logp = self.ce(inputs, target)
        p = torch.exp(-logp)
        loss = (1 - p) ** self.gamma * logp
        return loss.mean()

In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from conformer import ConformerBlock

class Classifier(nn.Module):
    def __init__(self, d_model = 512, n_spks = 600, dropout = 0.1):
        super(Classifier, self).__init__()
        
        self.prenet = nn.Linear(40, d_model)
        # self.encoder_layer = nn.TransformerEncoderLayer(
        #     d_model = d_model, dim_feedforward = 256, nhead = 2
        # )
        self.conformer_block = ConformerBlock(
            dim = d_model,
            dim_head = 64,
            heads = 8,
            ff_mult = 4,
            conv_expansion_factor = 2,
            conv_kernel_size = 31,
            attn_dropout = dropout,
            ff_dropout = dropout,
            conv_dropout = dropout
        )
        
        self.pooling = Self_Attentive_Pooling(d_model)
        # self.pred_layer = nn.Sequential(
        #     nn.Linear(d_model, d_model),
        #     nn.ReLU(),
        #     nn.Linear(d_model, n_spks)
        # )
        self.pred_layer = AMSoftmax(in_features = d_model, n_classes = n_spks)
    
    def forward(self, mels):
        out = self.prenet(mels)
        out = out.permute(1, 0, 2)
        out = self.conformer_block(out)
        out = out.permute(1, 2, 0)
        stats = self.pooling(out)
        out = self.pred_layer(stats)
        return out

In [24]:
import math
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1
):
    
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        else:
            progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * 2.0 * float(num_cycles) * progress)))
    
    return LambdaLR(optimizer, lr_lambda, last_epoch)

In [25]:
import torch

def model_fn(batch, model, criterion, device):
    mels, labels = batch
    mels, labels = mels.to(device), labels.to(device)
    outs = model(mels)
    
    loss = criterion(outs, labels)
    preds = outs.argmax(1)
    
    accuracy = (preds.cpu() == labels.cpu()).float().mean()
    
    return loss, accuracy

In [26]:
import torch
from tqdm import tqdm

def valid(dataloader, model, criterion, device):
    model.eval()
    pbar = tqdm(total = len(dataloader.dataset), ncols = 0, desc = 'Valid', unit = 'uttr')
    running_loss = 0.0
    running_acc = 0.0
    
    for i, batch in enumerate(dataloader):
        with torch.no_grad():
            loss, accuracy = model_fn(batch, model, criterion, device)
        running_loss += loss.item()
        running_acc += accuracy.item()
        
        pbar.update(dataloader.batch_size)
        pbar.set_postfix(
            loss = f"{running_loss / (i + 1):.2f}",
            accuracy = f"{running_acc / (i + 1):.2f}"
        )
    
    pbar.close()
    model.train()
    return running_acc / len(dataloader)

In [27]:
def parse_args():
    config = {
        'data_dir': './Dataset',
        'model_path': 'model.ckpt',
        'batch_size': 32,
        'n_workers': 0,
        'valid_steps': 2000,
        'warmup_steps': 1000,
        'save_steps': 10000,
        'total_steps': 100000,
    }
    
    return config

In [28]:
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split

def main(data_dir, model_path, batch_size, n_workers, valid_steps, warmup_steps, save_steps, total_steps):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"[Info]: Use {device} now!")
    
    pbar = tqdm(total = valid_steps, ncols = 0, desc = 'Train', unit = 'step')
    
    train_loader, valid_loader, speaker_num = get_dataloader(data_dir, batch_size, n_workers)
    train_iterator = iter(train_loader)
    print(f"[Info]: Finish loading data!",flush = True)
    
    model = Classifier(n_spks = speaker_num).to(device)
    criterion = FocalSoftmax()
    optimizer = AdamW(model.parameters(), lr = 1e-3, weight_decay = 1e-4)
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
    print(f"[Info]: Finish creating model!",flush = True)
    
    best_accuracy = 0.0
    best_state_dict = None
    
    model.train()
    for step in range(total_steps):
        try:
            batch = next(train_iterator)
        except StopIteration:
            train_iterator = iter(train_loader)
            batch = next(train_iterator)
        
        optimizer.zero_grad()
        loss, accuracy = model_fn(batch, model, criterion, device)
        
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        pbar.update()
        pbar.set_postfix(
            loss = f"{loss.item():.2f}",
            accuracy = f"{accuracy.item():.2f}",
            step = step + 1,
        )
        
        if (step + 1) % valid_steps == 0:
            pbar.close()
            valid_accuracy = valid(valid_loader, model, criterion, device)
            
            if valid_accuracy > best_accuracy:
                best_accuracy = valid_accuracy
                best_state_dict = model.state_dict()
            
            pbar = tqdm(total = valid_steps, ncols = 0, desc = 'Train', unit = 'step')
        
        if (step + 1) % save_steps == 0 and best_state_dict is not None:
            torch.save(best_state_dict, model_path)
            pbar.write(f"Step {step + 1}, best model saved. (accuracy={best_accuracy:.4f})")
    
    pbar.close()

In [29]:
if __name__ == '__main__':
    main(**parse_args())

[Info]: Use cuda now!


Train:   0% 0/2000 [00:00<?, ?step/s]

[Info]: Finish loading data!
[Info]: Finish creating model!


Train: 100% 2000/2000 [02:52<00:00, 11.57step/s, accuracy=0.31, loss=2.52, step=2000]
Valid: 100% 6944/6944 [00:09<00:00, 770.04uttr/s, accuracy=0.31, loss=3.01]
Train: 100% 2000/2000 [02:50<00:00, 11.76step/s, accuracy=0.38, loss=2.61, step=4000]
Valid: 100% 6944/6944 [00:09<00:00, 771.37uttr/s, accuracy=0.45, loss=2.05]
Train: 100% 2000/2000 [02:46<00:00, 12.04step/s, accuracy=0.75, loss=0.47, step=6000]
Valid: 100% 6944/6944 [00:09<00:00, 768.71uttr/s, accuracy=0.53, loss=1.56]
Train: 100% 2000/2000 [02:42<00:00, 12.29step/s, accuracy=0.53, loss=1.14, step=8000]
Valid: 100% 6944/6944 [00:08<00:00, 776.81uttr/s, accuracy=0.59, loss=1.28]
Train: 100% 2000/2000 [02:42<00:00, 12.29step/s, accuracy=0.72, loss=0.60, step=1e+4]
Valid: 100% 6944/6944 [00:08<00:00, 778.85uttr/s, accuracy=0.63, loss=1.03]
Train:   0% 2/2000 [00:00<03:20,  9.99step/s, accuracy=0.56, loss=1.42, step=1e+4]

Step 10000, best model saved. (accuracy=0.6300)


Train: 100% 2000/2000 [02:43<00:00, 12.27step/s, accuracy=0.91, loss=0.09, step=12000]
Valid: 100% 6944/6944 [00:08<00:00, 774.09uttr/s, accuracy=0.65, loss=0.94]
Train: 100% 2000/2000 [02:44<00:00, 12.15step/s, accuracy=0.69, loss=0.64, step=14000]
Valid: 100% 6944/6944 [00:09<00:00, 695.85uttr/s, accuracy=0.66, loss=0.90]
Train: 100% 2000/2000 [02:47<00:00, 11.94step/s, accuracy=0.72, loss=0.50, step=16000]
Valid: 100% 6944/6944 [00:06<00:00, 1048.88uttr/s, accuracy=0.68, loss=0.77]
Train: 100% 2000/2000 [02:42<00:00, 12.28step/s, accuracy=0.81, loss=0.41, step=18000]
Valid: 100% 6944/6944 [00:06<00:00, 1061.71uttr/s, accuracy=0.71, loss=0.64]
Train: 100% 2000/2000 [02:42<00:00, 12.27step/s, accuracy=0.84, loss=0.09, step=2e+4] 
Valid: 100% 6944/6944 [00:06<00:00, 1059.02uttr/s, accuracy=0.72, loss=0.59]
Train:   0% 2/2000 [00:00<03:39,  9.12step/s, accuracy=0.81, loss=0.27, step=2e+4]

Step 20000, best model saved. (accuracy=0.7209)


Train: 100% 2000/2000 [02:43<00:00, 12.24step/s, accuracy=0.78, loss=0.17, step=22000]
Valid: 100% 6944/6944 [00:06<00:00, 1056.37uttr/s, accuracy=0.74, loss=0.54]
Train: 100% 2000/2000 [02:43<00:00, 12.24step/s, accuracy=0.81, loss=0.12, step=24000]
Valid: 100% 6944/6944 [00:06<00:00, 1057.54uttr/s, accuracy=0.74, loss=0.52]
Train: 100% 2000/2000 [02:43<00:00, 12.23step/s, accuracy=0.88, loss=0.13, step=26000]
Valid: 100% 6944/6944 [00:06<00:00, 1061.52uttr/s, accuracy=0.75, loss=0.49]
Train: 100% 2000/2000 [02:43<00:00, 12.26step/s, accuracy=0.88, loss=0.09, step=28000]
Valid: 100% 6944/6944 [00:06<00:00, 1058.91uttr/s, accuracy=0.75, loss=0.49]
Train: 100% 2000/2000 [02:43<00:00, 12.21step/s, accuracy=0.91, loss=0.15, step=3e+4] 
Valid: 100% 6944/6944 [00:06<00:00, 1034.14uttr/s, accuracy=0.77, loss=0.42]
Train:   0% 2/2000 [00:00<02:52, 11.56step/s, accuracy=0.81, loss=0.40, step=3e+4]

Step 30000, best model saved. (accuracy=0.7668)


Train: 100% 2000/2000 [02:45<00:00, 12.11step/s, accuracy=0.88, loss=0.11, step=32000]
Valid: 100% 6944/6944 [00:06<00:00, 1015.11uttr/s, accuracy=0.78, loss=0.39]
Train: 100% 2000/2000 [02:44<00:00, 12.16step/s, accuracy=0.84, loss=0.05, step=34000]
Valid: 100% 6944/6944 [00:06<00:00, 1051.63uttr/s, accuracy=0.78, loss=0.37]
Train: 100% 2000/2000 [02:44<00:00, 12.19step/s, accuracy=0.84, loss=0.09, step=36000]
Valid: 100% 6944/6944 [00:06<00:00, 1052.01uttr/s, accuracy=0.79, loss=0.37]
Train: 100% 2000/2000 [02:43<00:00, 12.21step/s, accuracy=0.88, loss=0.07, step=38000]
Valid: 100% 6944/6944 [00:06<00:00, 1058.96uttr/s, accuracy=0.80, loss=0.31]
Train: 100% 2000/2000 [02:43<00:00, 12.26step/s, accuracy=0.97, loss=0.00, step=4e+4] 
Valid: 100% 6944/6944 [00:06<00:00, 1056.55uttr/s, accuracy=0.78, loss=0.37]
Train:   0% 2/2000 [00:00<03:40,  9.06step/s, accuracy=0.91, loss=0.03, step=4e+4]

Step 40000, best model saved. (accuracy=0.7980)


Train: 100% 2000/2000 [02:44<00:00, 12.19step/s, accuracy=0.97, loss=0.01, step=42000]
Valid: 100% 6944/6944 [00:06<00:00, 1051.53uttr/s, accuracy=0.81, loss=0.28]
Train: 100% 2000/2000 [02:43<00:00, 12.24step/s, accuracy=0.84, loss=0.05, step=44000]
Valid: 100% 6944/6944 [00:06<00:00, 1058.96uttr/s, accuracy=0.81, loss=0.27]
Train: 100% 2000/2000 [02:43<00:00, 12.26step/s, accuracy=0.91, loss=0.05, step=46000]
Valid: 100% 6944/6944 [00:06<00:00, 1058.17uttr/s, accuracy=0.82, loss=0.25]
Train: 100% 2000/2000 [02:42<00:00, 12.28step/s, accuracy=0.91, loss=0.02, step=48000]
Valid: 100% 6944/6944 [00:06<00:00, 1058.78uttr/s, accuracy=0.82, loss=0.25]
Train: 100% 2000/2000 [02:42<00:00, 12.28step/s, accuracy=0.91, loss=0.02, step=5e+4] 
Valid: 100% 6944/6944 [00:06<00:00, 1057.63uttr/s, accuracy=0.82, loss=0.24]
Train:   0% 2/2000 [00:00<02:59, 11.15step/s, accuracy=0.97, loss=0.01, step=5e+4]

Step 50000, best model saved. (accuracy=0.8240)


Train: 100% 2000/2000 [02:43<00:00, 12.22step/s, accuracy=0.94, loss=0.00, step=52000]
Valid: 100% 6944/6944 [00:06<00:00, 1053.97uttr/s, accuracy=0.83, loss=0.22]
Train: 100% 2000/2000 [02:43<00:00, 12.22step/s, accuracy=0.91, loss=0.02, step=54000]
Valid: 100% 6944/6944 [00:06<00:00, 1058.60uttr/s, accuracy=0.83, loss=0.22]
Train: 100% 2000/2000 [02:43<00:00, 12.21step/s, accuracy=0.91, loss=0.04, step=56000]
Valid: 100% 6944/6944 [00:06<00:00, 1057.48uttr/s, accuracy=0.84, loss=0.20]
Train: 100% 2000/2000 [02:44<00:00, 12.19step/s, accuracy=0.94, loss=0.02, step=58000]
Valid: 100% 6944/6944 [00:06<00:00, 1054.65uttr/s, accuracy=0.85, loss=0.18]
Train: 100% 2000/2000 [02:43<00:00, 12.20step/s, accuracy=0.97, loss=0.02, step=6e+4] 
Valid: 100% 6944/6944 [00:06<00:00, 1054.39uttr/s, accuracy=0.85, loss=0.17]
Train:   0% 2/2000 [00:00<03:08, 10.60step/s, accuracy=0.97, loss=0.00, step=6e+4]

Step 60000, best model saved. (accuracy=0.8476)


Train: 100% 2000/2000 [02:44<00:00, 12.18step/s, accuracy=0.94, loss=0.01, step=62000]
Valid: 100% 6944/6944 [00:06<00:00, 1051.47uttr/s, accuracy=0.85, loss=0.18]
Train: 100% 2000/2000 [02:44<00:00, 12.18step/s, accuracy=1.00, loss=0.00, step=64000]
Valid: 100% 6944/6944 [00:06<00:00, 1054.00uttr/s, accuracy=0.86, loss=0.15]
Train: 100% 2000/2000 [02:42<00:00, 12.34step/s, accuracy=0.94, loss=0.01, step=66000]
Valid: 100% 6944/6944 [00:08<00:00, 857.23uttr/s, accuracy=0.86, loss=0.14]
Train: 100% 2000/2000 [02:39<00:00, 12.57step/s, accuracy=0.91, loss=0.03, step=68000]
Valid: 100% 6944/6944 [00:08<00:00, 794.93uttr/s, accuracy=0.86, loss=0.13]
Train: 100% 2000/2000 [02:39<00:00, 12.54step/s, accuracy=1.00, loss=0.00, step=7e+4] 
Valid: 100% 6944/6944 [00:08<00:00, 782.31uttr/s, accuracy=0.87, loss=0.12]
Train:   0% 2/2000 [00:00<03:08, 10.62step/s, accuracy=0.97, loss=0.00, step=7e+4]

Step 70000, best model saved. (accuracy=0.8692)


Train: 100% 2000/2000 [02:39<00:00, 12.51step/s, accuracy=1.00, loss=0.00, step=72000]
Valid: 100% 6944/6944 [00:08<00:00, 774.06uttr/s, accuracy=0.87, loss=0.12]
Train: 100% 2000/2000 [02:39<00:00, 12.52step/s, accuracy=1.00, loss=0.00, step=74000]
Valid: 100% 6944/6944 [00:09<00:00, 765.52uttr/s, accuracy=0.87, loss=0.11]
Train: 100% 2000/2000 [02:39<00:00, 12.50step/s, accuracy=0.97, loss=0.00, step=76000]
Valid: 100% 6944/6944 [00:09<00:00, 755.57uttr/s, accuracy=0.88, loss=0.10]
Train: 100% 2000/2000 [02:39<00:00, 12.51step/s, accuracy=1.00, loss=0.00, step=78000]
Valid: 100% 6944/6944 [00:08<00:00, 776.80uttr/s, accuracy=0.88, loss=0.11]
Train: 100% 2000/2000 [02:39<00:00, 12.55step/s, accuracy=1.00, loss=0.00, step=8e+4] 
Valid: 100% 6944/6944 [00:09<00:00, 768.67uttr/s, accuracy=0.88, loss=0.11]
Train:   0% 2/2000 [00:00<03:39,  9.11step/s, accuracy=0.97, loss=0.00, step=8e+4]

Step 80000, best model saved. (accuracy=0.8821)


Train: 100% 2000/2000 [02:38<00:00, 12.58step/s, accuracy=1.00, loss=0.00, step=82000]
Valid: 100% 6944/6944 [00:08<00:00, 775.07uttr/s, accuracy=0.89, loss=0.10]
Train: 100% 2000/2000 [02:38<00:00, 12.62step/s, accuracy=1.00, loss=0.00, step=84000]
Valid: 100% 6944/6944 [00:08<00:00, 779.52uttr/s, accuracy=0.89, loss=0.10]
Train: 100% 2000/2000 [02:38<00:00, 12.60step/s, accuracy=1.00, loss=0.00, step=86000]
Valid: 100% 6944/6944 [00:08<00:00, 776.80uttr/s, accuracy=0.89, loss=0.08]
Train: 100% 2000/2000 [02:38<00:00, 12.58step/s, accuracy=1.00, loss=0.00, step=88000]
Valid: 100% 6944/6944 [00:08<00:00, 776.80uttr/s, accuracy=0.89, loss=0.09]
Train: 100% 2000/2000 [02:38<00:00, 12.59step/s, accuracy=0.97, loss=0.01, step=9e+4] 
Valid: 100% 6944/6944 [00:08<00:00, 774.06uttr/s, accuracy=0.89, loss=0.09]
Train:   0% 2/2000 [00:00<02:52, 11.58step/s, accuracy=1.00, loss=0.00, step=9e+4]

Step 90000, best model saved. (accuracy=0.8949)


Train: 100% 2000/2000 [02:39<00:00, 12.55step/s, accuracy=1.00, loss=0.00, step=92000]
Valid: 100% 6944/6944 [00:08<00:00, 772.71uttr/s, accuracy=0.90, loss=0.08]
Train: 100% 2000/2000 [02:39<00:00, 12.55step/s, accuracy=0.97, loss=0.00, step=94000]
Valid: 100% 6944/6944 [00:08<00:00, 775.41uttr/s, accuracy=0.90, loss=0.08]
Train: 100% 2000/2000 [02:39<00:00, 12.55step/s, accuracy=1.00, loss=0.00, step=96000]
Valid: 100% 6944/6944 [00:08<00:00, 778.19uttr/s, accuracy=0.89, loss=0.08]
Train: 100% 2000/2000 [02:38<00:00, 12.59step/s, accuracy=0.97, loss=0.00, step=98000]
Valid: 100% 6944/6944 [00:06<00:00, 1064.30uttr/s, accuracy=0.89, loss=0.08]
Train: 100% 2000/2000 [02:38<00:00, 12.62step/s, accuracy=1.00, loss=0.00, step=1e+5] 
Valid: 100% 6944/6944 [00:06<00:00, 1061.84uttr/s, accuracy=0.90, loss=0.08]
Train:   0% 0/2000 [00:00<?, ?step/s]

Step 100000, best model saved. (accuracy=0.8992)





In [30]:
import os
import torch
import json
from pathlib import Path
from torch.utils.data import Dataset

class InferenceDataset(Dataset):
    def __init__(self, data_dir):
        test_path = Path(data_dir) / "testdata.json"
        testdata = json.load(test_path.open())
        self.data_dir = data_dir
        self.data = testdata['utterances']
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        feat_path = self.data[index]['feature_path']
        mel = torch.load(os.path.join(self.data_dir, feat_path))
        return feat_path, mel

In [31]:
def inference_collate_batch(batch):
    feat_paths, mels = zip(*batch)
    return feat_paths, torch.stack(mels)

In [32]:
def parse_inference_args():
    config = {
        'data_dir': './Dataset',
        'model_path': './model.ckpt',
        'output_path': './output.csv',
    }
    return config

In [33]:
import torch
import csv
import json
from tqdm.notebook import tqdm
from pathlib import Path
from torch.utils.data import DataLoader

def inference_main(data_dir, model_path, output_path):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"[Info]: Use {device} now!")
    
    dataset = InferenceDataset(data_dir)
    test_loader = DataLoader(
        dataset,
        batch_size = 1,
        num_workers = 0,
        shuffle = False,
        drop_last = False,
        collate_fn = inference_collate_batch,
    )
    
    mapping_path = Path(data_dir) / 'mapping.json'
    mapping = json.load(mapping_path.open())
    
    id2speaker = mapping['id2speaker']
    speaker_num = len(id2speaker)
    
    model = Classifier(n_spks = speaker_num).to(device)
    model.load_state_dict(torch.load(model_path))
    
    results = [['Id','Category']]
    
    model.eval()
    for feat_paths, mels in test_loader:
        mels = mels.to(device)
        with torch.no_grad():
            outs = model(mels)
        preds = outs.argmax(1).cpu().numpy()
        
        for feat_path, pred in zip(feat_paths, preds):
            results.append([feat_path, id2speaker[str(pred)]])
        
    with open(output_path, 'w', newline = '') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(results)

In [34]:
if __name__ == "__main__":
    inference_main(**parse_inference_args())

[Info]: Use cuda now!


In [40]:
!pip install conformer

Collecting conformer
  Downloading conformer-0.2.5-py3-none-any.whl (4.1 kB)
Collecting einops
  Downloading einops-0.3.2-py3-none-any.whl (25 kB)
Installing collected packages: einops, conformer
Successfully installed conformer-0.2.5 einops-0.3.2


In [41]:
from conformer import ConformerBlock