## Contributions

In [1]:
import os
import json
import torch
import random
import math
import re
import torch.nn as nn
import torch.distributed as dist
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Optimizer
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.distributed import DistributedSampler
from pathlib import Path
from tqdm import tqdm
 
class myDataset(Dataset):
  """Dataset of preprocessed mel-spectrograms labelled with speaker ids.

  Reads ``mapping.json`` (speaker name -> id) and ``metadata.json`` (per-speaker
  utterance lists) from ``data_dir`` and serves ``(mel, speaker)`` pairs.
  """

  def __init__(self, data_dir, segment_len=128):
    """Index every utterance listed in the metadata.

    Args:
      data_dir: Directory holding ``mapping.json``, ``metadata.json`` and the
        feature files referenced by ``feature_path``.
      segment_len: Maximum number of frames returned for one utterance.
    """
    self.data_dir = data_dir
    self.segment_len = segment_len

    root = Path(data_dir)

    # Load the mapping from speaker name to their corresponding id.
    # Context managers make sure the file handles are closed right away.
    with (root / "mapping.json").open() as mapping_file:
      self.speaker2id = json.load(mapping_file)["speaker2id"]

    # Load metadata of training data.
    with (root / "metadata.json").open() as metadata_file:
      metadata = json.load(metadata_file)["speakers"]

    # Get the total number of speakers.
    self.speaker_num = len(metadata)

    # One [feature_path, speaker_id] entry per utterance.
    self.data = [
      [utterance["feature_path"], self.speaker2id[speaker]]
      for speaker, utterances in metadata.items()
      for utterance in utterances
    ]

  def __len__(self):
    """Return the number of indexed utterances."""
    return len(self.data)

  def __getitem__(self, index):
    """Load one utterance and return ``(mel, speaker)``.

    Utterances longer than ``segment_len`` frames are cropped to a randomly
    positioned window of exactly ``segment_len`` frames; shorter ones are
    returned whole. ``speaker`` is a one-element long tensor.
    """
    feat_path, speaker = self.data[index]
    # Load preprocessed mel-spectrogram.
    mel = torch.load(os.path.join(self.data_dir, feat_path))

    # Segment mel-spectrogram into "segment_len" frames.
    if len(mel) > self.segment_len:
      # Randomly get the starting point of the segment.
      start = random.randint(0, len(mel) - self.segment_len)
      # Get a segment with "segment_len" frames.
      mel = torch.FloatTensor(mel[start:start + self.segment_len])
    else:
      mel = torch.FloatTensor(mel)
    # Turn the speaker id into long for computing loss later.
    speaker = torch.FloatTensor([speaker]).long()
    return mel, speaker

  def get_speaker_number(self):
    """Return the number of speakers listed in the metadata."""
    return self.speaker_num


def collate_batch(batch):
  """Merge a list of ``(mel, speaker)`` samples into one padded batch.

  Returns the padded mel tensor laid out batch-first, i.e.
  (batch size, longest length in the batch, feature dim), together with a
  long tensor of speaker ids.
  """
  mels, speakers = zip(*batch)
  # Utterances in a batch may differ in length, so shorter ones are padded up
  # to the longest one. -20 stands for log 10^(-20), a very small value.
  padded_mels = pad_sequence(mels, batch_first=True, padding_value=-20)
  speaker_ids = torch.FloatTensor(speakers).long()
  return padded_mels, speaker_ids

def get_dataloader(data_dir, batch_size, n_workers):
  """Build the training and validation dataloaders.

  The dataset found in ``data_dir`` is split randomly into 90% training and
  10% validation samples. Both loaders drop their last incomplete batch.

  Returns:
    ``(train_loader, valid_loader, speaker_num)``.
  """
  dataset = myDataset(data_dir)
  speaker_num = dataset.get_speaker_number()

  # Split dataset into training dataset and validation dataset.
  n_total = len(dataset)
  n_train = int(0.9 * n_total)
  trainset, validset = random_split(dataset, [n_train, n_total - n_train])

  # Options common to both loaders; only the training loader shuffles.
  shared_options = dict(
    batch_size=batch_size,
    num_workers=n_workers,
    drop_last=True,
    pin_memory=True,
    collate_fn=collate_batch,
  )
  train_loader = DataLoader(trainset, shuffle=True, **shared_options)
  valid_loader = DataLoader(validset, **shared_options)

  return train_loader, valid_loader, speaker_num

def model_fn(batch, model, criterion, device):
  """Run one batch through ``model`` and score it.

  Args:
    batch: ``(mels, labels)`` pair of tensors.
    model: Callable mapping the mels to per-class scores.
    criterion: Loss callable taking ``(scores, labels)``.
    device: Device both tensors are moved to before the forward pass.

  Returns:
    ``(loss, accuracy)`` where accuracy is the fraction of samples whose
    highest-scoring class equals the label.
  """
  inputs, targets = (tensor.to(device) for tensor in batch)

  scores = model(inputs)
  loss = criterion(scores, targets)

  # The predicted speaker is the class with the highest score.
  predicted = torch.argmax(scores, dim=1)
  accuracy = (predicted == targets).float().mean()

  return loss, accuracy

class Classifier(nn.Module):
  """Transformer-encoder speaker classifier over 40-dimensional frames."""

  def __init__(self, d_model=80, n_spks=600, dropout=0.1):
    """Create the network.

    Args:
      d_model: Width of the encoder representation.
      n_spks: Number of output classes (speakers).
      dropout: Dropout probability used inside the prediction head.
    """
    super().__init__()

    # Project the 40-dimensional input frames to the encoder width.
    self.prenet = nn.Linear(40, d_model)

    # The template layer stays registered as an attribute, so its parameters
    # appear in the state dict next to those of the 6-layer encoder stack.
    self.encoder_layer = nn.TransformerEncoderLayer(
      d_model=d_model, nhead=2, dim_feedforward=1024
    )
    self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)

    # Prediction head: d_model -> 2*d_model -> d_model -> d_model//2 -> n_spks.
    wide = d_model * 2
    narrow = d_model // 2
    head = [
      nn.Linear(d_model, wide),
      nn.ReLU(),
      nn.Dropout(dropout),
      nn.Linear(wide, d_model),
      nn.ReLU(),
      nn.Dropout(dropout),
      nn.Linear(d_model, narrow),
      nn.ReLU(),
      nn.Dropout(dropout),
      nn.Linear(narrow, n_spks),
    ]
    self.pred_layer = nn.Sequential(*head)

  def forward(self, mels):
    """Map ``mels`` of shape (batch, length, 40) to scores of shape (batch, n_spks)."""
    hidden = self.prenet(mels)
    # The encoder is fed (length, batch, d_model), so swap the first two axes.
    encoded = self.encoder(hidden.permute(1, 0, 2))
    # Back to (batch, length, d_model), then average over the length axis.
    pooled = encoded.transpose(0, 1).mean(dim=1)
    return self.pred_layer(pooled)
  

def get_cosine_schedule_with_warmup(
  optimizer: Optimizer,
  num_warmup_steps: int,
  num_training_steps: int,
  num_cycles: float = 0.5,
  last_epoch: int = -1,
):
  """Create a ``LambdaLR`` with linear warmup followed by cosine decay.

  The learning-rate multiplier rises linearly from 0 to 1 over
  ``num_warmup_steps`` steps and then follows a cosine curve over the
  remaining steps up to ``num_training_steps``, never dropping below 0.

  Args:
    optimizer: Optimizer whose learning rate is scheduled.
    num_warmup_steps: Number of linear warmup steps.
    num_training_steps: Total number of training steps.
    num_cycles: Number of cosine cycles in the decay phase (0.5 = half a
      cycle, i.e. a decay from 1 to 0).
    last_epoch: Forwarded to ``LambdaLR``.
  """
  # Both spans are clamped to at least 1 to avoid dividing by zero.
  warmup_span = float(max(1, num_warmup_steps))
  decay_span = float(max(1, num_training_steps - num_warmup_steps))

  def lr_lambda(current_step):
    if current_step >= num_warmup_steps:
      # Cosine decay phase.
      progress = float(current_step - num_warmup_steps) / decay_span
      cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
      return max(0.0, 0.5 * (1.0 + cosine))
    # Linear warmup phase.
    return float(current_step) / warmup_span

  return LambdaLR(optimizer, lr_lambda, last_epoch)

def valid(dataloader, model, criterion, device):
  """Evaluate ``model`` on ``dataloader`` and return the mean batch accuracy.

  The model is put into eval mode for the pass and switched back to train
  mode before returning. A progress bar reports the running loss and accuracy.
  """
  model.eval()
  loss_sum = 0.0
  accuracy_sum = 0.0
  pbar = tqdm(total=len(dataloader.dataset), ncols=0, desc="Valid", unit=" uttr")

  for n_batches, batch in enumerate(dataloader, start=1):
    # No gradients are needed while evaluating.
    with torch.no_grad():
      loss, accuracy = model_fn(batch, model, criterion, device)
      loss_sum += loss.item()
      accuracy_sum += accuracy.item()

    pbar.update(dataloader.batch_size)
    pbar.set_postfix(
      loss=f"{loss_sum / n_batches:.2f}",
      accuracy=f"{accuracy_sum / n_batches:.2f}",
    )

  pbar.close()
  model.train()

  mean_accuracy = accuracy_sum / len(dataloader)
  return mean_accuracy

def parse_args():
  """Return the hard-coded training configuration as keyword arguments for ``main``."""
  return dict(
    data_dir="../Dataset",
    save_path="model.ckpt",
    batch_size=64,
    n_workers=0,
    valid_steps=2000,
    warmup_steps=10000,
    save_steps=4000,
    total_steps=350000,
  )

def find_highest_iter_file(directory):
    """Find the ``.ckpt`` file in ``directory`` with the largest iteration number.

    The iteration number is the first run of digits in the file name, e.g.
    ``model_4000_0.1554.ckpt`` -> 4000. Checkpoints whose name contains no
    digits (such as a plain ``model.ckpt``) are skipped instead of raising.

    Args:
        directory: Directory to scan (not recursive).

    Returns:
        ``(file_name, iteration)`` of the best match, where ``file_name`` is
        the bare name as listed in ``directory``, or ``("", 0)`` when no
        checkpoint with a positive iteration number exists.
    """
    max_iter = 0
    max_file = ""

    for file in os.listdir(directory):
        if not file.endswith(".ckpt"):
            continue
        match = re.search(r"\d+", file)
        if match is None:
            # No iteration number encoded in the name; nothing to compare.
            continue
        iter_no = int(match.group())
        if iter_no > max_iter:
            max_iter = iter_no
            max_file = file
    return max_file, max_iter

def main(data_dir, save_path, batch_size, n_workers, valid_steps, warmup_steps, total_steps, save_steps):
  """Train the speaker classifier with periodic validation and checkpointing.

  If a checkpoint is found in the current directory (see
  ``find_highest_iter_file``), its weights are loaded and training resumes
  from the step encoded in its file name.

  Args:
    data_dir: Dataset directory passed to ``get_dataloader``.
    save_path: Accepted for interface compatibility; not used. Checkpoints are
      written to ``model_{step}_{accuracy}.ckpt`` in the current directory.
    batch_size: Batch size of both dataloaders.
    n_workers: Number of dataloader worker processes.
    valid_steps: Run validation every this many steps.
    warmup_steps: Number of learning-rate warmup steps.
    total_steps: Step at which training stops.
    save_steps: Save the best weights seen so far every this many steps.
  """
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print(f"[Info]: Use {device} now!")

  train_loader, valid_loader, speaker_num = get_dataloader(data_dir, batch_size, n_workers)
  train_iterator = iter(train_loader)
  highest_model_name, start_step = find_highest_iter_file(".")
  model = Classifier(n_spks=speaker_num).to(device)

  if highest_model_name:
    # map_location lets a checkpoint saved on one device type be resumed on another.
    model.load_state_dict(torch.load(highest_model_name, map_location=device))
    print(f"[Info]: Loaded pre-trained model: {highest_model_name}")

  criterion = nn.CrossEntropyLoss()
  optimizer = AdamW(model.parameters(), lr=1e-3)

  if start_step > 0:
    # Resuming: the scheduler needs "initial_lr" in every param group when it
    # is created with last_epoch != -1.
    for group in optimizer.param_groups:
      group.setdefault("initial_lr", group["lr"])
  # Start the schedule at start_step, so a resumed run continues the
  # warmup/cosine curve instead of restarting the warmup from zero.
  scheduler = get_cosine_schedule_with_warmup(
    optimizer, warmup_steps, total_steps, last_epoch=start_step - 1
  )

  best_accuracy = -1.0
  best_state_dict = None

  pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit="step")

  for step in range(start_step, total_steps):
    # Restart the iterator whenever the training set is exhausted.
    try:
      batch = next(train_iterator)
    except StopIteration:
      train_iterator = iter(train_loader)
      batch = next(train_iterator)

    loss, accuracy = model_fn(batch, model, criterion, device)
    batch_loss = loss.item()
    batch_accuracy = accuracy.item()

    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

    adj_step = step + 1

    pbar.update()
    pbar.set_postfix(
      loss=f"{batch_loss:.2f}",
      accuracy=f"{batch_accuracy:.2f}",
      step=adj_step,
    )

    if adj_step % valid_steps == 0:
      pbar.close()

      valid_accuracy = valid(valid_loader, model, criterion, device)
      if valid_accuracy > best_accuracy:
        best_accuracy = valid_accuracy
        # state_dict() hands out tensors that share storage with the live
        # parameters; clone them so later optimizer steps cannot overwrite
        # the snapshot of the best weights.
        best_state_dict = {
          name: tensor.detach().clone()
          for name, tensor in model.state_dict().items()
        }

      pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit="step")

    if adj_step % save_steps == 0 and best_state_dict is not None:
      checkpoint_path = f"model_{adj_step}_{best_accuracy:.4f}.ckpt"
      torch.save(best_state_dict, checkpoint_path)
      pbar.write(f"Step {adj_step}, best model saved. (accuracy={best_accuracy:.4f})")

  pbar.close()

# Entry point: run training with the configuration returned by parse_args().
if __name__ == "__main__":
  main(**parse_args())

[Info]: Use cuda now!


Train: 100% 2000/2000 [02:37<00:00, 12.69step/s, accuracy=0.03, loss=5.44, step=2000]
Valid:  99% 4800/4836 [00:09<00:00, 508.22 uttr/s, accuracy=0.06, loss=5.09]
Train: 100% 2000/2000 [01:14<00:00, 26.93step/s, accuracy=0.11, loss=4.47, step=4000]
Valid:  99% 4800/4836 [00:02<00:00, 2193.74 uttr/s, accuracy=0.16, loss=4.18]
Train:   0% 5/2000 [00:00<01:30, 21.95step/s, accuracy=0.14, loss=4.37, step=4005]

Step 4000, best model saved. (accuracy=0.1554)


Train: 100% 2000/2000 [01:14<00:00, 26.87step/s, accuracy=0.20, loss=3.79, step=6000]
Valid:  99% 4800/4836 [00:02<00:00, 2182.60 uttr/s, accuracy=0.24, loss=3.57]
Train: 100% 2000/2000 [01:14<00:00, 26.87step/s, accuracy=0.25, loss=3.37, step=8000]
Valid:  99% 4800/4836 [00:02<00:00, 2248.11 uttr/s, accuracy=0.33, loss=3.07]
Train:   0% 5/2000 [00:00<01:25, 23.33step/s, accuracy=0.23, loss=3.57, step=8005]

Step 8000, best model saved. (accuracy=0.3315)


Train: 100% 2000/2000 [01:14<00:00, 26.86step/s, accuracy=0.28, loss=3.17, step=1e+4]
Valid:  99% 4800/4836 [00:02<00:00, 2266.30 uttr/s, accuracy=0.36, loss=2.91]
Train: 100% 2000/2000 [01:14<00:00, 26.92step/s, accuracy=0.44, loss=2.38, step=12000]
Valid:  99% 4800/4836 [00:02<00:00, 2210.01 uttr/s, accuracy=0.40, loss=2.69]
Train:   0% 5/2000 [00:00<01:31, 21.87step/s, accuracy=0.47, loss=2.70, step=12005]

Step 12000, best model saved. (accuracy=0.4027)


Train: 100% 2000/2000 [01:14<00:00, 26.85step/s, accuracy=0.41, loss=2.52, step=14000]
Valid:  99% 4800/4836 [00:02<00:00, 2187.63 uttr/s, accuracy=0.44, loss=2.51]
Train: 100% 2000/2000 [01:14<00:00, 26.89step/s, accuracy=0.44, loss=2.67, step=16000]
Valid:  99% 4800/4836 [00:02<00:00, 2276.21 uttr/s, accuracy=0.48, loss=2.30]
Train:   0% 5/2000 [00:00<01:24, 23.56step/s, accuracy=0.41, loss=2.36, step=16005]

Step 16000, best model saved. (accuracy=0.4846)


Train: 100% 2000/2000 [01:14<00:00, 26.83step/s, accuracy=0.52, loss=2.01, step=18000]
Valid:  99% 4800/4836 [00:02<00:00, 2275.77 uttr/s, accuracy=0.50, loss=2.28]
Train: 100% 2000/2000 [01:14<00:00, 26.89step/s, accuracy=0.42, loss=2.24, step=2e+4] 
Valid:  99% 4800/4836 [00:02<00:00, 2209.31 uttr/s, accuracy=0.53, loss=2.13]
Train:   0% 5/2000 [00:00<01:29, 22.24step/s, accuracy=0.59, loss=1.89, step=2e+4]

Step 20000, best model saved. (accuracy=0.5308)


Train: 100% 2000/2000 [01:14<00:00, 26.90step/s, accuracy=0.58, loss=1.66, step=22000]
Valid:  99% 4800/4836 [00:02<00:00, 2206.90 uttr/s, accuracy=0.56, loss=2.09]
Train: 100% 2000/2000 [01:14<00:00, 26.87step/s, accuracy=0.69, loss=1.61, step=24000]
Valid:  99% 4800/4836 [00:02<00:00, 2238.19 uttr/s, accuracy=0.59, loss=1.91]
Train:   0% 5/2000 [00:00<01:26, 23.15step/s, accuracy=0.50, loss=2.16, step=24005]

Step 24000, best model saved. (accuracy=0.5935)


Train: 100% 2000/2000 [01:14<00:00, 26.92step/s, accuracy=0.73, loss=1.09, step=26000]
Valid:  99% 4800/4836 [00:02<00:00, 2225.81 uttr/s, accuracy=0.58, loss=2.01]
Train: 100% 2000/2000 [01:14<00:00, 26.75step/s, accuracy=0.62, loss=1.35, step=28000]
Valid:  99% 4800/4836 [00:02<00:00, 2253.70 uttr/s, accuracy=0.61, loss=1.84]
Train:   0% 5/2000 [00:00<01:26, 23.02step/s, accuracy=0.64, loss=1.34, step=28005]

Step 28000, best model saved. (accuracy=0.6083)


Train: 100% 2000/2000 [01:14<00:00, 26.83step/s, accuracy=0.67, loss=1.25, step=3e+4] 
Valid:  99% 4800/4836 [00:02<00:00, 2168.02 uttr/s, accuracy=0.62, loss=1.75]
Train: 100% 2000/2000 [01:14<00:00, 26.84step/s, accuracy=0.62, loss=1.59, step=32000]
Valid:  99% 4800/4836 [00:02<00:00, 2274.86 uttr/s, accuracy=0.64, loss=1.72]
Train:   0% 5/2000 [00:00<01:29, 22.39step/s, accuracy=0.66, loss=1.51, step=32005]

Step 32000, best model saved. (accuracy=0.6385)


Train: 100% 2000/2000 [01:14<00:00, 26.84step/s, accuracy=0.62, loss=1.42, step=34000]
Valid:  99% 4800/4836 [00:02<00:00, 2245.14 uttr/s, accuracy=0.64, loss=1.76]
Train: 100% 2000/2000 [01:14<00:00, 26.84step/s, accuracy=0.69, loss=1.28, step=36000]
Valid:  99% 4800/4836 [00:02<00:00, 2269.52 uttr/s, accuracy=0.66, loss=1.72]
Train:   0% 5/2000 [00:00<01:25, 23.46step/s, accuracy=0.64, loss=1.56, step=36005]

Step 36000, best model saved. (accuracy=0.6567)


Train: 100% 2000/2000 [01:14<00:00, 26.78step/s, accuracy=0.70, loss=1.14, step=38000]
Valid:  99% 4800/4836 [00:02<00:00, 2212.02 uttr/s, accuracy=0.67, loss=1.58]
Train: 100% 2000/2000 [01:14<00:00, 26.76step/s, accuracy=0.69, loss=1.58, step=4e+4] 
Valid:  99% 4800/4836 [00:02<00:00, 2227.43 uttr/s, accuracy=0.66, loss=1.65]
Train:   0% 5/2000 [00:00<01:27, 22.83step/s, accuracy=0.75, loss=1.17, step=4e+4]

Step 40000, best model saved. (accuracy=0.6710)


Train: 100% 2000/2000 [01:14<00:00, 26.79step/s, accuracy=0.69, loss=1.14, step=42000]
Valid:  99% 4800/4836 [00:02<00:00, 2241.36 uttr/s, accuracy=0.67, loss=1.59]
Train: 100% 2000/2000 [01:14<00:00, 26.81step/s, accuracy=0.62, loss=1.61, step=44000]
Valid:  99% 4800/4836 [00:02<00:00, 2270.92 uttr/s, accuracy=0.68, loss=1.55]
Train:   0% 5/2000 [00:00<01:26, 23.14step/s, accuracy=0.59, loss=1.90, step=44005]

Step 44000, best model saved. (accuracy=0.6779)


Train: 100% 2000/2000 [01:14<00:00, 26.75step/s, accuracy=0.80, loss=0.99, step=46000]
Valid:  99% 4800/4836 [00:02<00:00, 2241.77 uttr/s, accuracy=0.70, loss=1.50]
Train: 100% 2000/2000 [01:14<00:00, 26.84step/s, accuracy=0.75, loss=1.31, step=48000]
Valid:  99% 4800/4836 [00:02<00:00, 2206.58 uttr/s, accuracy=0.70, loss=1.53]
Train:   0% 5/2000 [00:00<01:28, 22.66step/s, accuracy=0.73, loss=1.24, step=48005]

Step 48000, best model saved. (accuracy=0.6973)


Train: 100% 2000/2000 [01:14<00:00, 26.74step/s, accuracy=0.72, loss=1.21, step=5e+4] 
Valid:  99% 4800/4836 [00:02<00:00, 2173.03 uttr/s, accuracy=0.70, loss=1.58]
Train: 100% 2000/2000 [01:14<00:00, 26.78step/s, accuracy=0.75, loss=1.13, step=52000]
Valid:  99% 4800/4836 [00:02<00:00, 2248.31 uttr/s, accuracy=0.70, loss=1.54]
Train:   0% 5/2000 [00:00<01:25, 23.34step/s, accuracy=0.67, loss=1.39, step=52005]

Step 52000, best model saved. (accuracy=0.6973)


Train: 100% 2000/2000 [01:14<00:00, 26.68step/s, accuracy=0.70, loss=1.21, step=54000]
Valid:  99% 4800/4836 [00:02<00:00, 2244.61 uttr/s, accuracy=0.71, loss=1.58]
Train: 100% 2000/2000 [01:14<00:00, 26.82step/s, accuracy=0.66, loss=1.26, step=56000]
Valid:  99% 4800/4836 [00:02<00:00, 2303.36 uttr/s, accuracy=0.71, loss=1.53]
Train:   0% 5/2000 [00:00<01:25, 23.45step/s, accuracy=0.73, loss=1.22, step=56005]

Step 56000, best model saved. (accuracy=0.7121)


Train: 100% 2000/2000 [01:14<00:00, 26.81step/s, accuracy=0.66, loss=1.79, step=58000]
Valid:  99% 4800/4836 [00:02<00:00, 2199.23 uttr/s, accuracy=0.70, loss=1.60]
Train: 100% 2000/2000 [01:14<00:00, 26.81step/s, accuracy=0.69, loss=1.46, step=6e+4] 
Valid:  99% 4800/4836 [00:02<00:00, 2216.57 uttr/s, accuracy=0.72, loss=1.51]
Train:   0% 5/2000 [00:00<01:25, 23.29step/s, accuracy=0.81, loss=0.84, step=6e+4]

Step 60000, best model saved. (accuracy=0.7160)


Train: 100% 2000/2000 [01:14<00:00, 26.72step/s, accuracy=0.78, loss=1.29, step=62000]
Valid:  99% 4800/4836 [00:02<00:00, 2233.19 uttr/s, accuracy=0.73, loss=1.40]
Train: 100% 2000/2000 [01:14<00:00, 26.76step/s, accuracy=0.73, loss=1.14, step=64000]
Valid:  99% 4800/4836 [00:02<00:00, 2269.17 uttr/s, accuracy=0.75, loss=1.33]
Train:   0% 5/2000 [00:00<01:24, 23.56step/s, accuracy=0.80, loss=0.90, step=64005]

Step 64000, best model saved. (accuracy=0.7465)


Train: 100% 2000/2000 [01:14<00:00, 26.75step/s, accuracy=0.77, loss=1.10, step=66000]
Valid:  99% 4800/4836 [00:02<00:00, 2233.20 uttr/s, accuracy=0.74, loss=1.36]
Train: 100% 2000/2000 [01:14<00:00, 26.77step/s, accuracy=0.80, loss=1.25, step=68000]
Valid:  99% 4800/4836 [00:02<00:00, 2241.29 uttr/s, accuracy=0.73, loss=1.39]
Train:   0% 5/2000 [00:00<01:29, 22.36step/s, accuracy=0.78, loss=1.07, step=68005]

Step 68000, best model saved. (accuracy=0.7465)


Train: 100% 2000/2000 [01:14<00:00, 26.70step/s, accuracy=0.75, loss=1.08, step=7e+4] 
Valid:  99% 4800/4836 [00:02<00:00, 2243.99 uttr/s, accuracy=0.74, loss=1.38]
Train: 100% 2000/2000 [01:14<00:00, 26.76step/s, accuracy=0.73, loss=1.25, step=72000]
Valid:  99% 4800/4836 [00:02<00:00, 2239.93 uttr/s, accuracy=0.74, loss=1.39]
Train:   0% 5/2000 [00:00<01:25, 23.39step/s, accuracy=0.77, loss=1.17, step=72005]

Step 72000, best model saved. (accuracy=0.7465)


Train: 100% 2000/2000 [01:15<00:00, 26.50step/s, accuracy=0.75, loss=1.10, step=74000]
Valid:  99% 4800/4836 [00:02<00:00, 2168.29 uttr/s, accuracy=0.76, loss=1.30]
Train: 100% 2000/2000 [01:16<00:00, 26.09step/s, accuracy=0.73, loss=1.16, step=76000]
Valid:  99% 4800/4836 [00:02<00:00, 2027.77 uttr/s, accuracy=0.74, loss=1.34]
Train:   0% 5/2000 [00:00<01:31, 21.88step/s, accuracy=0.75, loss=1.08, step=76005]

Step 76000, best model saved. (accuracy=0.7579)


Train: 100% 2000/2000 [01:15<00:00, 26.32step/s, accuracy=0.75, loss=1.10, step=78000]
Valid:  99% 4800/4836 [00:02<00:00, 2131.80 uttr/s, accuracy=0.75, loss=1.34]
Train: 100% 2000/2000 [01:16<00:00, 26.31step/s, accuracy=0.81, loss=0.92, step=8e+4] 
Valid:  99% 4800/4836 [00:02<00:00, 2180.30 uttr/s, accuracy=0.75, loss=1.38]
Train:   0% 5/2000 [00:00<01:28, 22.64step/s, accuracy=0.78, loss=0.96, step=8e+4]

Step 80000, best model saved. (accuracy=0.7579)


Train: 100% 2000/2000 [01:15<00:00, 26.36step/s, accuracy=0.88, loss=0.76, step=82000]
Valid:  99% 4800/4836 [00:02<00:00, 2192.21 uttr/s, accuracy=0.76, loss=1.29]
Train: 100% 2000/2000 [01:16<00:00, 26.29step/s, accuracy=0.73, loss=1.34, step=84000]
Valid:  99% 4800/4836 [00:02<00:00, 2189.96 uttr/s, accuracy=0.77, loss=1.29]
Train:   0% 5/2000 [00:00<01:27, 22.81step/s, accuracy=0.75, loss=0.96, step=84005]

Step 84000, best model saved. (accuracy=0.7692)


Train: 100% 2000/2000 [01:16<00:00, 26.28step/s, accuracy=0.81, loss=0.96, step=86000]
Valid:  99% 4800/4836 [00:02<00:00, 2168.60 uttr/s, accuracy=0.76, loss=1.28]
Train: 100% 2000/2000 [01:15<00:00, 26.33step/s, accuracy=0.83, loss=0.65, step=88000]
Valid:  99% 4800/4836 [00:02<00:00, 2216.29 uttr/s, accuracy=0.76, loss=1.28]
Train:   0% 5/2000 [00:00<01:29, 22.32step/s, accuracy=0.83, loss=0.94, step=88005]

Step 88000, best model saved. (accuracy=0.7692)


Train: 100% 2000/2000 [01:15<00:00, 26.34step/s, accuracy=0.86, loss=0.69, step=9e+4] 
Valid:  99% 4800/4836 [00:02<00:00, 2236.33 uttr/s, accuracy=0.77, loss=1.26]
Train: 100% 2000/2000 [01:16<00:00, 26.08step/s, accuracy=0.78, loss=0.89, step=92000]
Valid:  99% 4800/4836 [00:02<00:00, 2160.96 uttr/s, accuracy=0.77, loss=1.33]
Train:   0% 4/2000 [00:00<01:56, 17.09step/s, accuracy=0.83, loss=0.79, step=92004]

Step 92000, best model saved. (accuracy=0.7706)


Train: 100% 2000/2000 [01:21<00:00, 24.41step/s, accuracy=0.86, loss=0.64, step=94000]
Valid:  99% 4800/4836 [00:02<00:00, 2266.27 uttr/s, accuracy=0.77, loss=1.29]
Train: 100% 2000/2000 [01:15<00:00, 26.55step/s, accuracy=0.77, loss=1.17, step=96000]
Valid:  99% 4800/4836 [00:02<00:00, 1894.74 uttr/s, accuracy=0.78, loss=1.25]
Train:   0% 5/2000 [00:00<01:26, 23.19step/s, accuracy=0.81, loss=0.76, step=96005]

Step 96000, best model saved. (accuracy=0.7781)


Train: 100% 2000/2000 [01:13<00:00, 27.15step/s, accuracy=0.88, loss=0.61, step=102000]
Valid:  99% 4800/4836 [00:02<00:00, 2275.81 uttr/s, accuracy=0.77, loss=1.33]
Train:  12% 240/2000 [00:08<01:03, 27.69step/s, accuracy=0.78, loss=1.10, step=102240]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train: 100% 2000/2000 [01:13<00:00, 27.19step/s, accuracy=0.89, loss=0.47, step=114000]
Valid:  99% 4800/4836 [00:02<00:00, 2255.61 uttr/s, accuracy=0.78, loss=1.27]
Train:   2% 47/2000 [00:01<01:14, 26.36step/s, accuracy=0.88, loss=0.72, step=114047]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config vari

Step 140000, best model saved. (accuracy=0.8094)


Train:  18% 363/2000 [00:13<00:58, 28.11step/s, accuracy=0.92, loss=0.62, step=140363]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train:  89% 1781/2000 [01:05<00:08, 27.11step/s, accuracy=0.81, loss=0.81, step=141781]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train: 100% 2000/2000 [01:13<00:00, 27.36step/s, accuracy=0.88, loss=0.49, step=152000]
Valid:  99% 4800/4836 [00:02<00:00, 2297.11 uttr/s, accuracy=0.81, loss=1.09]
Train:   0%

Step 152000, best model saved. (accuracy=0.8200)


Train:  78% 1563/2000 [00:57<00:16, 26.97step/s, accuracy=0.94, loss=0.38, step=153563]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train:  41% 820/2000 [00:30<00:42, 27.92step/s, accuracy=0.94, loss=0.16, step=154820]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train: 100% 2000/2000 [01:15<00:00, 26.66step/s, accuracy=0.91, loss=0.35, step=166000]
Valid:  99% 4800/4836 [00:02<00:00, 2335.87 uttr/s, accuracy=0.82, loss=1.14]
Train:  51%

Step 204000, best model saved. (accuracy=0.8492)


Train:  13% 266/2000 [00:09<01:03, 27.52step/s, accuracy=0.97, loss=0.15, step=204266]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train:  85% 1709/2000 [01:02<00:11, 26.44step/s, accuracy=0.92, loss=0.37, step=205709]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train: 100% 2000/2000 [01:15<00:00, 26.46step/s, accuracy=0.95, loss=0.15, step=216000]
Valid:  99% 4800/4836 [00:02<00:00, 2294.93 uttr/s, accuracy=0.85, loss=1.05]
Train:   0%

Step 216000, best model saved. (accuracy=0.8562)


Train:  55% 1103/2000 [00:40<00:32, 27.98step/s, accuracy=0.94, loss=0.38, step=217103]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train:  30% 605/2000 [00:21<00:50, 27.70step/s, accuracy=0.94, loss=0.39, step=218605]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train: 100% 2000/2000 [01:14<00:00, 26.94step/s, accuracy=0.98, loss=0.06, step=230000]
Valid:  58% 2816/4836 [00:01<00:00, 2254.60 uttr/s, accuracy=0.85, loss=1.12]IOPub messag

Step 256000, best model saved. (accuracy=0.8733)


Train:  21% 423/2000 [00:15<00:57, 27.23step/s, accuracy=0.97, loss=0.13, step=256423]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train: 100% 2000/2000 [01:12<00:00, 27.51step/s, accuracy=0.98, loss=0.06, step=268000]
Valid:  99% 4800/4836 [00:02<00:00, 2338.17 uttr/s, accuracy=0.87, loss=1.20]
Train:   0% 5/2000 [00:00<01:23, 23.76step/s, accuracy=0.95, loss=0.19, step=268005]

Step 268000, best model saved. (accuracy=0.8777)


Train:  70% 1391/2000 [00:50<00:22, 27.65step/s, accuracy=1.00, loss=0.02, step=269391]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train: 100% 2000/2000 [01:12<00:00, 27.58step/s, accuracy=1.00, loss=0.01, step=282000]
Valid:  99% 4800/4836 [00:02<00:00, 2306.16 uttr/s, accuracy=0.88, loss=1.12]
Train:  15% 298/2000 [00:10<01:01, 27.80step/s, accuracy=1.00, loss=0.03, step=282298]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train: 100%

Step 308000, best model saved. (accuracy=0.8892)


Train:   6% 116/2000 [00:04<01:07, 28.00step/s, accuracy=1.00, loss=0.01, step=308116]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train: 100% 2000/2000 [01:15<00:00, 26.55step/s, accuracy=1.00, loss=0.00, step=320000]
Valid:  99% 4800/4836 [00:02<00:00, 2255.69 uttr/s, accuracy=0.89, loss=1.15]
Train:   0% 5/2000 [00:00<01:26, 23.17step/s, accuracy=1.00, loss=0.01, step=320005]

Step 320000, best model saved. (accuracy=0.8940)


Train:  51% 1014/2000 [00:36<00:35, 27.47step/s, accuracy=0.98, loss=0.03, step=321014]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train: 100% 2000/2000 [01:14<00:00, 26.94step/s, accuracy=1.00, loss=0.01, step=332000]
Valid:  99% 4800/4836 [00:02<00:00, 2117.36 uttr/s, accuracy=0.89, loss=1.08]
Train:   0% 5/2000 [00:00<01:29, 22.37step/s, accuracy=1.00, loss=0.02, step=332005]

Step 332000, best model saved. (accuracy=0.8940)


Train: 100% 2000/2000 [01:14<00:00, 26.73step/s, accuracy=1.00, loss=0.00, step=334000]
Valid:  99% 4800/4836 [00:02<00:00, 2221.95 uttr/s, accuracy=0.89, loss=1.12]
Train:   3% 51/2000 [00:01<01:14, 26.10step/s, accuracy=1.00, loss=0.00, step=334051]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Train: 100% 2000/2000 [01:13<00:00, 27.31step/s, accuracy=0.98, loss=0.04, step=346000]
Valid:  99% 4800/4836 [00:02<00:00, 2160.09 uttr/s, accuracy=0.90, loss=1.02]
Train:  60% 1209/2000 [00:44<00:29, 26.80step/s, accuracy=1.00, loss=0.00, step=347208]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config var

## Dataset of inference

In [2]:
import os
import json
import torch
from pathlib import Path
from torch.utils.data import Dataset


class InferenceDataset(Dataset):
  """Dataset of the test utterances listed in ``testdata.json``."""

  def __init__(self, data_dir):
    """Read the utterance list from ``<data_dir>/testdata.json``.

    Args:
      data_dir: Directory holding ``testdata.json`` and the feature files
        referenced by each utterance's ``feature_path``.
    """
    testdata_path = Path(data_dir) / "testdata.json"
    # Use a context manager so the file handle is closed right after parsing.
    with testdata_path.open() as testdata_file:
      metadata = json.load(testdata_file)
    self.data_dir = data_dir
    self.data = metadata["utterances"]

  def __len__(self):
    """Return the number of test utterances."""
    return len(self.data)

  def __getitem__(self, index):
    """Return ``(feature_path, mel)`` for the utterance at ``index``.

    The whole feature file is returned; no cropping is applied.
    """
    utterance = self.data[index]
    feat_path = utterance["feature_path"]
    mel = torch.load(os.path.join(self.data_dir, feat_path))

    return feat_path, mel


def inference_collate_batch(batch):
  """Combine ``(feature_path, mel)`` samples into one batch.

  Returns a tuple of the feature paths and the mels stacked along a new
  leading batch axis. No padding is applied, so every mel in the batch must
  have the same shape.
  """
  paths, mel_list = zip(*batch)
  stacked_mels = torch.stack(mel_list)
  return paths, stacked_mels


## Main function of Inference

In [3]:
import json
import csv
import glob
from pathlib import Path
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

def parse_args():
  """Return the hard-coded inference configuration as keyword arguments for ``main``."""
  return dict(
    data_dir="../Dataset",
    model_path="./model.ckpt",
    output_path="./output.csv",
  )


def _select_best_checkpoint(candidates, fallback):
  """Return the checkpoint path whose file name encodes the highest accuracy.

  Only names of the form ``model_<step>_<accuracy>.ckpt`` are ranked; any
  other path is ignored. On ties the first candidate wins. ``fallback`` is
  returned when no candidate matches.
  """
  import re  # Local import: this cell's import list does not include ``re``.

  pattern = re.compile(r"model_\d+_(\d+(?:\.\d+)?)\.ckpt$")
  best_path, best_score = fallback, None
  for path in candidates:
    match = pattern.search(path)
    if match is None:
      continue
    score = float(match.group(1))
    if best_score is None or score > best_score:
      best_path, best_score = path, score
  return best_path


def main(
  data_dir,
  model_path,
  output_path,
):
  """Predict the speaker of every test utterance and write the result as CSV.

  Args:
    data_dir: Directory holding ``mapping.json`` and the inference dataset.
    model_path: Checkpoint used when the current directory contains no
      ``model_<step>_<accuracy>.ckpt`` file; otherwise the checkpoint with the
      highest accuracy in its file name is loaded.
    output_path: Destination CSV with the columns ``Id`` and ``Category``.
  """
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print(f"[Info]: Use {device} now!")

  mapping_path = Path(data_dir) / "mapping.json"
  # Close the mapping file as soon as it has been parsed.
  with mapping_path.open() as mapping_file:
    mapping = json.load(mapping_file)

  dataset = InferenceDataset(data_dir)
  dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    collate_fn=inference_collate_batch,
  )
  print("[Info]: Finish loading data!", flush=True)

  speaker_num = len(mapping["id2speaker"])
  model = Classifier(n_spks=speaker_num).to(device)

  # Automatically find and load the best model among the saved checkpoints.
  list_of_files = glob.glob("./*.ckpt")  # get list of all model files
  print(list_of_files)
  # Rank by the accuracy embedded in the file name. Unrelated files such as
  # "./model.ckpt" are ignored, and model_path is used if nothing matches.
  latest_file = _select_best_checkpoint(list_of_files, model_path)
  print(f"Loading from file {latest_file}")
  # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
  model.load_state_dict(torch.load(latest_file, map_location=device))
  model.eval()
  print("[Info]: Finish creating model!", flush=True)

  results = [["Id", "Category"]]
  for feat_paths, mels in tqdm(dataloader):
    with torch.no_grad():
      mels = mels.to(device)
      outs = model(mels)
      preds = outs.argmax(1).cpu().numpy()
      for feat_path, pred in zip(feat_paths, preds):
        # id2speaker is keyed by the string form of the class index.
        results.append([feat_path, mapping["id2speaker"][str(pred)]])

  with open(output_path, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerows(results)


# Entry point: run inference with the configuration returned by parse_args().
if __name__ == "__main__":
  main(**parse_args())

[Info]: Use cuda now!
[Info]: Finish loading data!
['./model_188000_0.8404.ckpt', './model_312000_0.8938.ckpt', './model_140000_0.8094.ckpt', './model_280000_0.8831.ckpt', './model_192000_0.8425.ckpt', './model_236000_0.8608.ckpt', './model_340000_0.8950.ckpt', './model_268000_0.8777.ckpt', './model_56000_0.7121.ckpt', './model_304000_0.8877.ckpt', './model_292000_0.8867.ckpt', './model_328000_0.8940.ckpt', './model_132000_0.8046.ckpt', './model_52000_0.6973.ckpt', './model_256000_0.8733.ckpt', './model_212000_0.8562.ckpt', './model_80000_0.7579.ckpt', './model_224000_0.8602.ckpt', './model_152000_0.8200.ckpt', './model_64000_0.7465.ckpt', './model_172000_0.8313.ckpt', './model_228000_0.8602.ckpt', './model_60000_0.7160.ckpt', './model_40000_0.6710.ckpt', './model_48000_0.6973.ckpt', './model_88000_0.7692.ckpt', './model_232000_0.8602.ckpt', './model_216000_0.8562.ckpt', './model_108000_0.7867.ckpt', './model_220000_0.8579.ckpt', './model_120000_0.7929.ckpt', './model_240000_0.8631.ckp

100%|███████████████████████████████████████████████████████████████████████████████████| 6657/6657 [00:32<00:00, 205.44it/s]
