# Task description
- Classify the speaker of each utterance from the given features.
- Main goal: Learn how to use a Transformer.
- Baselines:
  - Easy: Run the sample code and learn how to use a Transformer.
  - Medium: Learn how to adjust the parameters of the Transformer.
  - Strong: Construct a [Conformer](https://arxiv.org/abs/2005.08100), which is a variant of the Transformer.
  - Boss: Implement [Self-Attention Pooling](https://arxiv.org/pdf/2008.01077v1.pdf) and [Additive Margin Softmax](https://arxiv.org/pdf/1801.05599.pdf) to further boost performance.

- Other links
  - Kaggle: [link](https://www.kaggle.com/t/ac77388c90204a4c8daebeddd40ff916)
  - Slide: [link](https://docs.google.com/presentation/d/1HLAj7UUIjZOycDe7DaVLSwJfXVd3bXPOyzSb6Zk3hYU/edit?usp=sharing)
  - Data: [link](https://drive.google.com/drive/folders/1vI1kuLB-q1VilIftiwnPOCAeOOFfBZge?usp=sharing)

# Download dataset
- Data is [here](https://drive.google.com/drive/folders/1vI1kuLB-q1VilIftiwnPOCAeOOFfBZge?usp=sharing)

## Fix Random Seed

In [1]:
import numpy as np
import torch
import random

def set_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(2023)
kfold = 8

# Data

## Dataset
- The original dataset is [VoxCeleb2](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html).
- VoxCeleb2 is released under the [CC BY 4.0 license](https://creativecommons.org/licenses/by/4.0/); see the [complete license text](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/files/license.txt).
- We randomly select 600 speakers from VoxCeleb2.
- The raw waveforms are then preprocessed into mel-spectrograms.

- Args:
  - data_dir: The path to the data directory.
  - metadata_path: The path to the metadata (the class below derives it from `data_dir`).
  - segment_len: The length, in frames, of each audio segment used for training.
- The architecture of data directory  
  - data directory  
  |---- metadata.json  
  |---- testdata.json  
  |---- mapping.json  
  |---- uttr-{random string}.pt

- The information in metadata
  - "n_mels": The dimension of the mel-spectrogram.
  - "speakers": A dictionary.
    - Key: speaker ids.
    - Value: a list of utterances, each with "feature_path" and "mel_len".
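
To make the layout above concrete, here is a small sketch of how such metadata can be flattened into (feature path, label) pairs, which is roughly what the dataset class below does. The speaker ids, file names and lengths are made up for illustration; only the key names come from the description above and from the dataset code.

```python
# Hypothetical miniature of metadata.json and mapping.json.
# All speaker ids, paths and lengths below are invented for illustration.
metadata = {
    "n_mels": 40,
    "speakers": {
        "spk_a": [
            {"feature_path": "uttr-0001.pt", "mel_len": 435},
            {"feature_path": "uttr-0002.pt", "mel_len": 612},
        ],
        "spk_b": [
            {"feature_path": "uttr-0003.pt", "mel_len": 520},
        ],
    },
}
mapping = {"speaker2id": {"spk_a": 0, "spk_b": 1}}

# Flatten into one (feature_path, label) pair per utterance.
pairs = [
    (utt["feature_path"], mapping["speaker2id"][spk])
    for spk, utts in metadata["speakers"].items()
    for utt in utts
]
print(pairs)
```

Each pair is then enough to load one mel-spectrogram file and attach its integer speaker label.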


For efficiency, we cut the mel-spectrograms into fixed-length segments during training.
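
As a minimal sketch of that cropping step, the snippet below picks a random contiguous window from a sequence of frames. It uses a plain Python list as a stand-in for a mel-spectrogram, and `random_segment` is a name chosen here for illustration, not part of the notebook.

```python
import random

def random_segment(frames, segment_len, rng=random):
    """Return a random contiguous window of at most `segment_len` frames."""
    if len(frames) <= segment_len:
        return frames  # short utterances are kept whole and padded later
    # randint is inclusive at both ends, so the last valid start is len - segment_len.
    start = rng.randint(0, len(frames) - segment_len)
    return frames[start:start + segment_len]

frames = list(range(1000))  # stand-in for 1000 mel frames
segment = random_segment(frames, 128, random.Random(0))
short = random_segment(list(range(50)), 128)
print(len(segment), len(short))
```

Because a new window is drawn every time an item is fetched, the model sees a different part of a long utterance in each epoch.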

In [2]:
import os
import json
import torch
import random
from pathlib import Path
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

class VoxDataset(Dataset):
    def __init__(self, data_dir, segment_len=512):
        self.training = False
        self.data_dir = data_dir
        self.segment_len = segment_len
        self.p = 0.95
        print(f"dataset configs: segment_len={segment_len}, p={self.p}")
    
        # Load the mapping from speaker name to the corresponding id.
        mapping_path = Path(data_dir) / "mapping.json"
        mapping = json.load(mapping_path.open())
        self.speaker2id = mapping["speaker2id"]
    
        # Load metadata of training data.
        metadata_path = Path(data_dir) / "metadata.json"
        metadata = json.load(open(metadata_path))["speakers"]
    
        # Get the total number of speaker.
        self.speaker_num = len(metadata.keys())
        self.data = []
        for speaker in metadata.keys():
            for utterances in metadata[speaker]:
                self.data.append([utterances["feature_path"], self.speaker2id[speaker]])
 
    def __len__(self):
        return len(self.data)
 
    def __getitem__(self, index):
        feat_path, speaker = self.data[index]
        # Load preprocessed mel-spectrogram.
        mel = torch.load(os.path.join(self.data_dir, feat_path))

        if len(mel) > self.segment_len:
            start = random.randint(0, len(mel) - self.segment_len)
            mel = torch.FloatTensor(mel[start:start + self.segment_len])
        else:
            mel = torch.FloatTensor(mel)
        
        if self.training is True:
            seglen, speclen = mel.shape
            ptime = random.uniform(0, self.p)
            pspec = random.uniform(0, self.p)
            time_mask_len = int(seglen * ptime)
            time_mask_beg = random.randint(0, seglen - time_mask_len)
            spec_mask_len = int(speclen * pspec)
            spec_mask_beg = random.randint(0, speclen)
            mel[time_mask_beg:time_mask_beg + time_mask_len, :] = 0
            if spec_mask_beg + spec_mask_len < speclen:
                mel[:, spec_mask_beg:spec_mask_beg + spec_mask_len] = 0
            else:
                mel[:, spec_mask_beg:speclen] = 0
                mel[:, 0:spec_mask_len - (speclen - spec_mask_beg)] = 0
        
        # Turn the speaker id into long for computing loss later.
        speaker = torch.FloatTensor([speaker]).long()
        return mel, speaker
 
    def get_speaker_number(self):
        return self.speaker_num

## Dataloader
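- Split the dataset into `kfold` (8) folds of nearly equal size. Model `i` trains on every fold except fold `i` and validates on fold `i`, so each model uses about 87.5% of the data for training and 12.5% for validation.
- Create dataloaders to iterate over the data.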
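
The fold sizes can be worked out by hand. The sketch below mirrors the length computation in `get_dataloader`; `fold_lengths` is a helper name introduced here, and the dataset size 56666 is inferred from the validation logs further down (seven folds of 7083 utterances plus one of 7085), not stated anywhere explicitly.

```python
def fold_lengths(n_items, kfold):
    """Sizes of `kfold` folds; the last fold absorbs the remainder."""
    lengths = [int(n_items / kfold)] * kfold
    lengths[-1] = n_items - sum(lengths[:-1])
    return lengths

# Assumption: 56666 utterances in total, inferred from the validation logs.
lengths = fold_lengths(56666, 8)
print(lengths)

# Model 0 validates on fold 0 and trains on the remaining folds.
train_size_for_model_0 = sum(lengths) - lengths[0]
print(train_size_for_model_0)
```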

In [3]:
import torch
from torch.utils.data import DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence


def collate_batch(batch):
    """Collate a batch of data."""
    mel, speaker = zip(*batch)
    # Utterances in a batch can differ in length, so pad them to the longest one.
    # The padding value -20 corresponds to log 10^(-20), a very small value.
    mel = pad_sequence(mel, batch_first=True, padding_value=-20)
    # mel: (batch size, length, 40)
    return mel, torch.FloatTensor(speaker).long()


def get_dataloader(data_dir, batch_size, n_workers):
    """Generate dataloader"""
    dataset = VoxDataset(data_dir)
    speaker_num = dataset.get_speaker_number()
    # Split the dataset into kfold folds of nearly equal size.
    proportions = [1. / kfold for i in range(kfold)]
    lengths = [int(p * len(dataset)) for p in proportions]
    lengths[-1] = len(dataset) - sum(lengths[:-1])
    sets = random_split(dataset, lengths, generator=torch.Generator().manual_seed(2023))

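    # Note: `sets` holds torch.utils.data.Subset wrappers, so this assignment only sets
    # an attribute on the wrapper. VoxDataset.training stays False, which means the
    # masking augmentation in __getitem__ is not applied.
    for trainset in sets:
        trainset.training = True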
    train_loader = [DataLoader(
            trainset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=n_workers,
            pin_memory=True,
            collate_fn=collate_batch,
        ) for trainset in sets]

    for validset in sets:
        validset.training = False
    valid_loader = [DataLoader(
            validset,
            batch_size=batch_size,
            num_workers=n_workers,
            drop_last=True,
            pin_memory=True,
            collate_fn=collate_batch,
        ) for validset in sets]

    return train_loader, valid_loader, speaker_num

# Model
- TransformerEncoderLayer:
  - The basic Transformer encoder layer from [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
  - Parameters:
    - d_model: the number of expected features of the input (required); it corresponds to `embed` in the code below.
    - nhead: the number of heads in the multi-head attention module (required).
    - dim_feedforward: the dimension of the feedforward network model (default=2048).
    - dropout: the dropout value (default=0.1).
    - activation: the activation function of the intermediate layer, relu or gelu (default=relu).

- TransformerEncoder:
  - TransformerEncoder is a stack of N transformer encoder layers
  - Parameters:
    - encoder_layer: an instance of the TransformerEncoderLayer() class (required).
    - num_layers: the number of sub-encoder-layers in the encoder (required).
    - norm: the layer normalization component (optional).
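
For reference, a plain Transformer baseline built from these two classes might look like the sketch below. This is not the model trained in this notebook (the next cell builds a Conformer instead); the class name `TransformerClassifier` and the sizes `d_model=80`, `nhead=2`, `dim_feedforward=256`, `num_layers=2` are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

class TransformerClassifier(nn.Module):
    """Hypothetical baseline: prenet -> Transformer encoder -> mean pooling -> linear head."""

    def __init__(self, feat=40, d_model=80, n_spks=600, dropout=0.1):
        super().__init__()
        # Project the mel features to the encoder width.
        self.prenet = nn.Linear(feat, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=2, dim_feedforward=256,
            dropout=dropout, batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=2)
        self.head = nn.Linear(d_model, n_spks)

    def forward(self, mels):
        out = self.prenet(mels)   # (batch, seq, d_model)
        out = self.encoder(out)   # (batch, seq, d_model)
        out = out.mean(dim=1)     # (batch, d_model), average over time
        return self.head(out)     # (batch, n_spks)

model = TransformerClassifier().eval()
with torch.no_grad():
    logits = model(torch.randn(4, 128, 40))
print(logits.shape)  # torch.Size([4, 600])
```

Raising `d_model`, `nhead` or `num_layers` is the kind of parameter adjustment the Medium baseline refers to.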

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SubSampling(nn.Module):
    def __init__(self, in_channel=1, out_channel=3):
        super(SubSampling, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
        )
    def forward(self, x):
        outputs = self.model(x.unsqueeze(1))
        batch_size, c, seqlen, width = outputs.size()
        outputs = outputs.permute(0, 2, 1, 3)
        outputs = outputs.contiguous().view(batch_size, seqlen, c * width)
        return outputs

class Perm(nn.Module):
    def __init__(self, order):
        super(Perm, self).__init__()
        self.order = order
    def forward(self, x):
        return x.permute(*self.order)

class FFN(nn.Module):
    def __init__(self, embed, dropout):
        super(FFN, self).__init__()
        # in/out: (batch_size, seqlen, embed)
        self.model = nn.Sequential(
            nn.LayerNorm(embed),
            nn.Linear(embed, embed * 4),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(embed * 4, embed),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        out = self.model(x)
        return out + x

class PointwiseConv(nn.Module):
    def __init__(self, in_channel, out_channel):
        # (batch, c, seq)
        super(PointwiseConv, self).__init__()
        self.model = nn.Conv1d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)
    def forward(self, x):
        return self.model(x)

class DepthwiseConv(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size):
        # (batch, c, seq)
        print(f"depwise_conv config: kernel_size={kernel_size}")
        super(DepthwiseConv, self).__init__()
        assert out_channel % in_channel == 0
        padding = (kernel_size - 1) // 2
        self.model = nn.Conv1d(in_channel, out_channel, kernel_size=kernel_size,
                               stride=1, padding=padding, groups=in_channel)
    def forward(self, x):
        return self.model(x)

class ConvBlock(nn.Module):
    def __init__(self, embed, dropout):
        super(ConvBlock, self).__init__()
        # in/out: (batch, seq, embed)
        self.model = nn.Sequential(
            nn.LayerNorm(embed),                                    # out: (batch, seq, embed)
            Perm((0, 2, 1)),                                        # out: (batch, embed, seq)
            PointwiseConv(in_channel=embed, out_channel=embed * 2), # out: (batch, embed * 2, seq)
            nn.GLU(dim=1),                                          # out: (batch, embed, seq)
            DepthwiseConv(in_channel=embed, out_channel=embed, kernel_size=9),     # in/out: (batch, embed, seq)
            nn.BatchNorm1d(embed),
            nn.SiLU(),
            PointwiseConv(in_channel=embed, out_channel=embed),
            nn.Dropout(dropout),
            Perm((0, 2, 1))                                         # out: (batch, seq, embed)
        )
    def forward(self, x):
        out = self.model(x)
        return out + x

class MultiHeadAttention(nn.Module):
    def __init__(self, embed, num_heads, dropout):
        super(MultiHeadAttention, self).__init__()
        self.layernorm = nn.LayerNorm(embed)
        self.attention = nn.MultiheadAttention(embed, num_heads, dropout=dropout, batch_first=True)
    def forward(self, x, mask=None):
        xx = self.layernorm(x)
        out, _ = self.attention(xx, xx, xx, need_weights=False)
        return out + x

class ConformerBlock(nn.Module):
    def __init__(self, embed, dropout):
        super(ConformerBlock, self).__init__()
        self.model = nn.Sequential(
            FFN(embed=embed, dropout=dropout),
            MultiHeadAttention(embed=embed, num_heads=8, dropout=dropout),
            ConvBlock(embed=embed, dropout=dropout),
            FFN(embed=embed, dropout=dropout),
            nn.LayerNorm(embed)
        )
    def forward(self, x):
        return self.model(x)

class AdditiveMarginSoftmax(nn.Module):
    def __init__(self, embed, n_spks):
        super(AdditiveMarginSoftmax, self).__init__()
        self.margin = 0.35
        self.factor = 30
        self.n_spks = n_spks
        self.W = torch.nn.Parameter(torch.randn(embed, n_spks), requires_grad=True)
        nn.init.xavier_normal_(self.W, gain=1)
    def forward(self, x, y=None):
        norm_x = torch.norm(x, dim=1, p=2, keepdim=True).clamp(min=1e-12)
        x_norm = torch.div(x, norm_x)
        norm_W = torch.norm(self.W, dim=0, p=2, keepdim=True).clamp(min=1e-12)
        W_norm = torch.div(self.W, norm_W)
        out = torch.mm(x_norm, W_norm)
        sub = torch.zeros_like(out)
        if y is not None:
            sub = sub.scatter_(1, y.view(-1, 1), self.margin)
        out = self.factor * (out - sub)
        return out

class SelfAttentionPooling(nn.Module):
    def __init__(self, input_dim):
        super(SelfAttentionPooling, self).__init__()
        self.W = nn.Linear(input_dim, 1)
    def forward(self, batch_rep):
        att_w = F.softmax(self.W(batch_rep).squeeze(-1), dim=-1).unsqueeze(-1)
        utter_rep = torch.sum(batch_rep * att_w, dim=1)
        return utter_rep

class Classifier(nn.Module):
    def __init__(self, feat=40, embed=256, n_spks=600, dropout=0.3):
        print(f"running config: embed={embed}, dropout={dropout}")
        super().__init__()
        subsample_expand = 4
        subsample_feat = (feat + 1) // 2
        
        self.model = nn.Sequential(
            SubSampling(in_channel=1, out_channel=subsample_expand), # out: (batch, ceil(seq / 2), subsample_feat * subsample_expand)
            nn.Linear(subsample_feat * subsample_expand, embed),
            #nn.Linear(feat, embed),
            nn.Dropout(dropout),
            ConformerBlock(embed=embed, dropout=dropout),
            SelfAttentionPooling(embed),                   # out: (batch, embed)
            nn.LayerNorm(embed),
            nn.Linear(embed, embed)
        )
        self.output = AdditiveMarginSoftmax(embed, n_spks)
        print(f"AM-Softmax config: margin={self.output.margin}, factor={self.output.factor}")
    def forward(self, mels, y=None):
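                                        # in:  (batch, seq, feat)
        out = self.model(mels)          # out: (batch, embed)
        spaced = self.output(out, y)    # logits with the margin applied to the target class: (batch, n_spks)
        raw = self.output(out, None)    # logits without the margin, used for prediction: (batch, n_spks)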
        return spaced, raw
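
In equation form, and as far as can be read from the two modules above, `SelfAttentionPooling` and `AdditiveMarginSoftmax` compute the following. The symbols are notation introduced here: $h_t$ is the frame at time $t$, $w$ and $b$ are the weight and bias of the pooling layer, $x$ is the utterance embedding fed to the output layer, $W_j$ is column $j$ of `W`, $y$ is the target speaker, and $s = 30$, $m = 0.35$ are `factor` and `margin`.

$$
\alpha_t = \frac{\exp(w^\top h_t + b)}{\sum_{\tau} \exp(w^\top h_\tau + b)},
\qquad
u = \sum_t \alpha_t h_t
$$

$$
\cos\theta_j = \frac{x^\top W_j}{\lVert x \rVert \, \lVert W_j \rVert},
\qquad
z_j =
\begin{cases}
s\,(\cos\theta_j - m) & \text{if } j = y \\
s\,\cos\theta_j & \text{otherwise}
\end{cases}
$$

The cross-entropy loss is then applied to the logits $z$. When no label is passed, the margin term is dropped and $z_j = s \cos\theta_j$ for every class, which is the `raw` output used for prediction.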

# Learning rate schedule
- For Transformer architectures, the learning rate schedule is designed differently from that of CNNs.
- Previous work shows that learning rate warmup is useful when training Transformer-based models.
- The warmup schedule
  - Set the learning rate to 0 at the beginning.
  - The learning rate increases linearly from 0 to the initial learning rate during the warmup period.
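
The scheduler in the next cell follows this warmup with a half-cosine decay to 0. The sketch below reproduces only the multiplier applied to the base learning rate, using the standard library; `lr_multiplier` is a name introduced here, and the step counts and base rate are taken from the configuration used later in this notebook (1000 warmup steps, 70000 total steps, `lr=1e-3`).

```python
import math

def lr_multiplier(step, warmup_steps, total_steps):
    """Factor applied to the base learning rate at a given step."""
    if step < warmup_steps:
        # Linear warmup from 0 to 1.
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    # Half-cosine decay from 1 to 0.
    return 0.5 * (1.0 + math.cos(math.pi * progress))

base_lr = 1e-3
for step in (0, 500, 1000, 35500, 70000):
    print(step, base_lr * lr_multiplier(step, 1000, 70000))
```

The rate is half of its peak midway through warmup (step 500) and again midway through the decay (step 35500).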

In [5]:
import math

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


def get_cosine_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        num_cycles (:obj:`float`, `optional`, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    def lr_lambda(current_step):
        # Warmup
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))

    return LambdaLR(optimizer, lr_lambda, last_epoch)

# Model Function
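- Forward one batch through the model and return its loss and accuracy. During training the loss uses the margin-adjusted logits; otherwise it uses the raw logits. Accuracy always uses the raw logits.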

In [6]:
import torch


def model_fn(batch, model, criterion, device, training=False):
    """Forward a batch through the model."""

    mels, labels = batch
    mels = mels.to(device)
    labels = labels.to(device)

    outs, raw = model(mels, labels)

    if training:
        loss = criterion(outs, labels)
    else:
        loss = criterion(raw, labels)

    # Get the speaker id with highest probability.
    preds = raw.argmax(1)
    # Compute accuracy.
    accuracy = torch.mean((preds == labels).float())

    return loss, accuracy

# Validate
- Calculate the accuracy on the validation set.

In [7]:
from tqdm import tqdm
import torch


def valid(dataloader, model, criterion, device): 
    """Validate on validation set."""

    model.eval()
    running_loss = 0.0
    running_accuracy = 0.0
    pbar = tqdm(total=len(dataloader.dataset), ncols=0, desc="Valid", unit=" uttr")

    for i, batch in enumerate(dataloader):
        with torch.no_grad():
            loss, accuracy = model_fn(batch, model, criterion, device)
            running_loss += loss.item()
            running_accuracy += accuracy.item()

        pbar.update(dataloader.batch_size)
        pbar.set_postfix(
            loss=f"{running_loss / (i+1):.4f}",
            accuracy=f"{running_accuracy / (i+1):.4f}",
        )

    pbar.close()
    model.train()

    return running_accuracy / len(dataloader)

# Main function

In [8]:
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split


def parse_args():
    """arguments"""
    config = {
        "data_dir": "/kaggle/input/ml2022spring-hw4/Dataset",
        "save_path": "model",
        "batch_size": 32,
        "n_workers": 0,
        "valid_steps": 2000,
        "warmup_steps": 1000,
        "total_steps": 70000,
    }

    return config


def main(
    data_dir,
    save_path,
    batch_size,
    n_workers,
    valid_steps,
    warmup_steps,
    total_steps,
):
    """Main function."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Info]: Use {device} now!")

    train_loader, valid_loader, speaker_num = get_dataloader(data_dir, batch_size, n_workers)
    running_patch = 0
    train_iterator = iter(train_loader[running_patch])
    print("[Info]: Finish loading data!", flush=True)

    model = [Classifier(n_spks=speaker_num).to(device) for i in range(kfold)]
    criterion = nn.CrossEntropyLoss()
    optimizer = [AdamW(model[i].parameters(), lr=1e-3) for i in range(kfold)]
    scheduler = [get_cosine_schedule_with_warmup(optimizer[i], warmup_steps, total_steps) for i in range(kfold)]
    print("[Info]: Finish creating model!", flush=True)

    best_accuracy = [-1.0 for i in range(kfold)]

    pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step")

    for step in range(total_steps):
        # Get data
        try:
            batch = next(train_iterator)
        except StopIteration:
            running_patch = (running_patch + 1) % kfold
            train_iterator = iter(train_loader[running_patch])
            batch = next(train_iterator)
        # Model i never trains on fold i; that fold is held out as its validation set.
        for i in range(kfold):
            if i == running_patch:
                continue
            loss, accuracy = model_fn(batch, model[i], criterion, device, training=True)
            batch_loss = loss.item()
            batch_accuracy = accuracy.item()
            loss.backward()
            optimizer[i].step()
            scheduler[i].step()
            optimizer[i].zero_grad()

        # Log
        pbar.update()
        pbar.set_postfix(
            loss=f"{batch_loss:.2f}",
            accuracy=f"{batch_accuracy:.2f}",
            step=step + 1,
            running_patch=running_patch
        )

        # Do validation
        if (step + 1) % valid_steps == 0:
            pbar.close()

            for i in range(kfold):
                valid_accuracy = valid(valid_loader[i], model[i], criterion, device)

                # keep the best model
                if valid_accuracy > best_accuracy[i]:
                    best_accuracy[i] = valid_accuracy
                    torch.save(model[i].state_dict(), f"{save_path}_{i}.ckpt")
                    
            print(best_accuracy)
            pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step")

    pbar.close()


if __name__ == "__main__":
    main(**parse_args())

[Info]: Use cuda now!
dataset configs: segment_len=512, p=0.95
[Info]: Finish loading data!
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: mar

Train: 100% 2000/2000 [11:06<00:00,  3.00 step/s, accuracy=0.53, loss=12.14, running_patch=1, step=2000]
Valid: 100% 7072/7083 [00:07<00:00, 898.71 uttr/s, accuracy=0.4337, loss=2.5714]
Valid: 100% 7072/7083 [00:08<00:00, 876.41 uttr/s, accuracy=0.4907, loss=2.2871]
Valid: 100% 7072/7083 [00:08<00:00, 871.94 uttr/s, accuracy=0.4104, loss=2.6835]
Valid: 100% 7072/7083 [00:08<00:00, 853.57 uttr/s, accuracy=0.4714, loss=2.4406]
Valid: 100% 7072/7083 [00:07<00:00, 894.08 uttr/s, accuracy=0.4941, loss=2.2681]
Valid: 100% 7072/7083 [00:08<00:00, 871.91 uttr/s, accuracy=0.4420, loss=2.5751]
Valid: 100% 7072/7083 [00:08<00:00, 871.21 uttr/s, accuracy=0.4726, loss=2.3596]
Valid: 100% 7072/7085 [00:08<00:00, 875.83 uttr/s, accuracy=0.4641, loss=2.4363]


[0.4336821266968326, 0.4906674208144796, 0.41035067873303166, 0.4714366515837104, 0.4940610859728507, 0.44202488687782804, 0.4725678733031674, 0.46408371040723984]


Train: 100% 2000/2000 [07:37<00:00,  4.38 step/s, accuracy=0.78, loss=8.51, running_patch=2, step=4000] 
Valid: 100% 7072/7083 [00:08<00:00, 858.34 uttr/s, accuracy=0.7040, loss=1.3173]
Valid: 100% 7072/7083 [00:07<00:00, 893.59 uttr/s, accuracy=0.7223, loss=1.2312]
Valid: 100% 7072/7083 [00:07<00:00, 889.08 uttr/s, accuracy=0.6739, loss=1.4515]
Valid: 100% 7072/7083 [00:08<00:00, 861.94 uttr/s, accuracy=0.7188, loss=1.2732]
Valid: 100% 7072/7083 [00:07<00:00, 893.99 uttr/s, accuracy=0.7288, loss=1.2325]
Valid: 100% 7072/7083 [00:08<00:00, 853.81 uttr/s, accuracy=0.6973, loss=1.3969]
Valid: 100% 7072/7083 [00:08<00:00, 880.65 uttr/s, accuracy=0.7149, loss=1.2646]
Valid: 100% 7072/7085 [00:08<00:00, 830.83 uttr/s, accuracy=0.7052, loss=1.3325]


[0.7040441176470589, 0.7222850678733032, 0.6739253393665159, 0.71875, 0.728789592760181, 0.6972567873303167, 0.7149321266968326, 0.7051753393665159]


Train: 100% 2000/2000 [07:37<00:00,  4.37 step/s, accuracy=0.69, loss=8.82, running_patch=3, step=6000] 
Valid: 100% 7072/7083 [00:09<00:00, 726.79 uttr/s, accuracy=0.7896, loss=0.9562]
Valid: 100% 7072/7083 [00:09<00:00, 721.34 uttr/s, accuracy=0.8042, loss=0.8702]
Valid: 100% 7072/7083 [00:10<00:00, 651.39 uttr/s, accuracy=0.7723, loss=1.0274]
Valid: 100% 7072/7083 [00:08<00:00, 882.56 uttr/s, accuracy=0.7834, loss=0.9683]
Valid: 100% 7072/7083 [00:10<00:00, 656.31 uttr/s, accuracy=0.7957, loss=0.9168]
Valid: 100% 7072/7083 [00:09<00:00, 766.61 uttr/s, accuracy=0.7791, loss=1.0088]
Valid: 100% 7072/7083 [00:10<00:00, 676.50 uttr/s, accuracy=0.7999, loss=0.8838]
Valid: 100% 7072/7085 [00:12<00:00, 571.92 uttr/s, accuracy=0.7787, loss=1.0140]


[0.7895927601809954, 0.8041572398190046, 0.772341628959276, 0.7833710407239819, 0.7956730769230769, 0.7791289592760181, 0.7999151583710408, 0.7787047511312217]


Train: 100% 2000/2000 [07:47<00:00,  4.28 step/s, accuracy=0.84, loss=5.98, running_patch=4, step=8000]
Valid: 100% 7072/7083 [00:08<00:00, 794.85 uttr/s, accuracy=0.8363, loss=0.7347]
Valid: 100% 7072/7083 [00:08<00:00, 872.71 uttr/s, accuracy=0.8531, loss=0.6647]
Valid: 100% 7072/7083 [00:08<00:00, 881.90 uttr/s, accuracy=0.8326, loss=0.7724]
Valid: 100% 7072/7083 [00:07<00:00, 899.98 uttr/s, accuracy=0.8320, loss=0.7900]
Valid: 100% 7072/7083 [00:08<00:00, 811.00 uttr/s, accuracy=0.8396, loss=0.7463]
Valid: 100% 7072/7083 [00:08<00:00, 880.11 uttr/s, accuracy=0.8241, loss=0.8221]
Valid: 100% 7072/7083 [00:08<00:00, 877.07 uttr/s, accuracy=0.8507, loss=0.6644]
Valid: 100% 7072/7085 [00:07<00:00, 892.96 uttr/s, accuracy=0.8347, loss=0.7659]


[0.8362556561085973, 0.8530825791855203, 0.832579185520362, 0.8320135746606335, 0.8396493212669683, 0.8240950226244343, 0.8506787330316742, 0.8347002262443439]


Train: 100% 2000/2000 [07:41<00:00,  4.33 step/s, accuracy=0.91, loss=4.35, running_patch=5, step=1e+4]
Valid: 100% 7072/7083 [00:08<00:00, 872.26 uttr/s, accuracy=0.8676, loss=0.5824]
Valid: 100% 7072/7083 [00:07<00:00, 909.58 uttr/s, accuracy=0.8818, loss=0.5546]
Valid: 100% 7072/7083 [00:08<00:00, 851.93 uttr/s, accuracy=0.8559, loss=0.6688]
Valid: 100% 7072/7083 [00:08<00:00, 879.09 uttr/s, accuracy=0.8616, loss=0.6456]
Valid: 100% 7072/7083 [00:07<00:00, 908.26 uttr/s, accuracy=0.8524, loss=0.6970]
Valid: 100% 7072/7083 [00:07<00:00, 903.91 uttr/s, accuracy=0.8558, loss=0.6906]
Valid: 100% 7072/7083 [00:08<00:00, 883.04 uttr/s, accuracy=0.8743, loss=0.5884]
Valid: 100% 7072/7085 [00:07<00:00, 888.82 uttr/s, accuracy=0.8641, loss=0.6402]


[0.8676470588235294, 0.8817873303167421, 0.8559106334841629, 0.861566742081448, 0.8523755656108597, 0.8557692307692307, 0.8742929864253394, 0.8641119909502263]


Train: 100% 2000/2000 [07:37<00:00,  4.37 step/s, accuracy=0.88, loss=5.57, running_patch=6, step=12000]
Valid: 100% 7072/7083 [00:08<00:00, 873.63 uttr/s, accuracy=0.8839, loss=0.5303]
Valid: 100% 7072/7083 [00:08<00:00, 855.69 uttr/s, accuracy=0.8959, loss=0.4952]
Valid: 100% 7072/7083 [00:08<00:00, 864.60 uttr/s, accuracy=0.8790, loss=0.5680]
Valid: 100% 7072/7083 [00:07<00:00, 891.04 uttr/s, accuracy=0.8763, loss=0.5968]
Valid: 100% 7072/7083 [00:08<00:00, 849.51 uttr/s, accuracy=0.8794, loss=0.5677]
Valid: 100% 7072/7083 [00:08<00:00, 822.77 uttr/s, accuracy=0.8703, loss=0.6267]
Valid: 100% 7072/7083 [00:07<00:00, 901.88 uttr/s, accuracy=0.8901, loss=0.5089]
Valid: 100% 7072/7085 [00:07<00:00, 894.55 uttr/s, accuracy=0.8792, loss=0.5622]


[0.883908371040724, 0.8959276018099548, 0.8789592760180995, 0.8762726244343891, 0.8793834841628959, 0.8703337104072398, 0.8901300904977375, 0.8792420814479638]


Train: 100% 2000/2000 [07:37<00:00,  4.38 step/s, accuracy=0.94, loss=2.99, running_patch=7, step=14000]
Valid: 100% 7072/7083 [00:07<00:00, 902.34 uttr/s, accuracy=0.8901, loss=0.5309]
Valid: 100% 7072/7083 [00:08<00:00, 879.87 uttr/s, accuracy=0.9048, loss=0.4639]
Valid: 100% 7072/7083 [00:08<00:00, 855.59 uttr/s, accuracy=0.8884, loss=0.5367]
Valid: 100% 7072/7083 [00:08<00:00, 881.49 uttr/s, accuracy=0.9027, loss=0.4697]
Valid: 100% 7072/7085 [00:07<00:00, 884.94 uttr/s, accuracy=0.8894, loss=0.5285]


[0.8901300904977375, 0.9048359728506787, 0.888433257918552, 0.8889988687782805, 0.8881504524886877, 0.8737273755656109, 0.9027149321266968, 0.8894230769230769]


Train: 100% 2000/2000 [07:49<00:00,  4.26 step/s, accuracy=1.00, loss=1.98, running_patch=0, step=16000]
Valid: 100% 7072/7083 [00:12<00:00, 588.59 uttr/s, accuracy=0.9005, loss=0.4911]
Valid: 100% 7072/7083 [00:09<00:00, 771.11 uttr/s, accuracy=0.9047, loss=0.4570]
Valid: 100% 7072/7083 [00:07<00:00, 896.30 uttr/s, accuracy=0.8956, loss=0.4961]
Valid: 100% 7072/7083 [00:08<00:00, 870.56 uttr/s, accuracy=0.8928, loss=0.5535]
Valid: 100% 7072/7083 [00:12<00:00, 576.98 uttr/s, accuracy=0.9054, loss=0.4801]
Valid: 100% 7072/7083 [00:08<00:00, 817.29 uttr/s, accuracy=0.8863, loss=0.5699]
Valid: 100% 7072/7083 [00:08<00:00, 875.10 uttr/s, accuracy=0.9118, loss=0.4279]
Valid: 100% 7072/7085 [00:08<00:00, 876.87 uttr/s, accuracy=0.9021, loss=0.4879]


[0.9004524886877828, 0.9048359728506787, 0.8956447963800905, 0.892816742081448, 0.9054015837104072, 0.8863122171945701, 0.9117647058823529, 0.9021493212669683]


Train: 100% 2000/2000 [07:36<00:00,  4.38 step/s, accuracy=1.00, loss=2.36, running_patch=1, step=18000]
Valid: 100% 7072/7083 [00:08<00:00, 862.18 uttr/s, accuracy=0.9094, loss=0.4407]
Valid: 100% 7072/7083 [00:07<00:00, 891.51 uttr/s, accuracy=0.9156, loss=0.4326]
Valid: 100% 7072/7083 [00:07<00:00, 884.21 uttr/s, accuracy=0.9062, loss=0.4476]
Valid: 100% 7072/7083 [00:07<00:00, 911.97 uttr/s, accuracy=0.9023, loss=0.5008]
Valid: 100% 7072/7083 [00:07<00:00, 890.08 uttr/s, accuracy=0.9071, loss=0.4789]
Valid: 100% 7072/7083 [00:07<00:00, 906.39 uttr/s, accuracy=0.8979, loss=0.5256]
Valid: 100% 7072/7083 [00:07<00:00, 888.60 uttr/s, accuracy=0.9133, loss=0.4276]
Valid: 100% 7072/7085 [00:07<00:00, 904.23 uttr/s, accuracy=0.9081, loss=0.4610]


[0.9093608597285068, 0.9155825791855203, 0.90625, 0.9022907239819005, 0.9070984162895928, 0.8979072398190046, 0.9133201357466063, 0.9080882352941176]


Train: 100% 2000/2000 [07:36<00:00,  4.38 step/s, accuracy=0.97, loss=1.86, running_patch=2, step=2e+4] 
Valid: 100% 7072/7083 [00:07<00:00, 925.25 uttr/s, accuracy=0.9123, loss=0.4440]
Valid: 100% 7072/7083 [00:08<00:00, 880.07 uttr/s, accuracy=0.9156, loss=0.4137]
Valid: 100% 7072/7083 [00:08<00:00, 843.58 uttr/s, accuracy=0.9000, loss=0.4653]
Valid: 100% 7072/7083 [00:08<00:00, 863.58 uttr/s, accuracy=0.9089, loss=0.4643]
Valid: 100% 7072/7083 [00:07<00:00, 902.70 uttr/s, accuracy=0.9125, loss=0.4618]
Valid: 100% 7072/7083 [00:08<00:00, 879.75 uttr/s, accuracy=0.9096, loss=0.4682]
Valid: 100% 7072/7083 [00:07<00:00, 893.99 uttr/s, accuracy=0.9211, loss=0.4017]
Valid: 100% 7072/7085 [00:08<00:00, 868.39 uttr/s, accuracy=0.9115, loss=0.4519]


[0.9123303167420814, 0.9155825791855203, 0.90625, 0.9089366515837104, 0.9124717194570136, 0.909643665158371, 0.9210972850678733, 0.9114819004524887]


Train: 100% 2000/2000 [07:36<00:00,  4.38 step/s, accuracy=0.94, loss=2.20, running_patch=3, step=22000]
Valid: 100% 7072/7083 [00:07<00:00, 898.98 uttr/s, accuracy=0.9200, loss=0.4094]
Valid: 100% 7072/7083 [00:08<00:00, 873.41 uttr/s, accuracy=0.9208, loss=0.3971]
Valid: 100% 7072/7083 [00:07<00:00, 900.12 uttr/s, accuracy=0.9149, loss=0.4272]
Valid: 100% 7072/7083 [00:07<00:00, 902.09 uttr/s, accuracy=0.9023, loss=0.5056]
Valid: 100% 7072/7083 [00:08<00:00, 800.18 uttr/s, accuracy=0.9156, loss=0.4514]
Valid: 100% 7072/7083 [00:08<00:00, 882.91 uttr/s, accuracy=0.9089, loss=0.4613]
Valid: 100% 7072/7083 [00:07<00:00, 886.81 uttr/s, accuracy=0.9265, loss=0.3768]
Valid: 100% 7072/7085 [00:08<00:00, 878.31 uttr/s, accuracy=0.9181, loss=0.4357]


[0.9199660633484162, 0.920814479638009, 0.9148755656108597, 0.9089366515837104, 0.9155825791855203, 0.909643665158371, 0.9264705882352942, 0.9181278280542986]


Train: 100% 2000/2000 [07:44<00:00,  4.31 step/s, accuracy=1.00, loss=2.30, running_patch=4, step=24000]
Valid: 100% 7072/7083 [00:08<00:00, 865.33 uttr/s, accuracy=0.9186, loss=0.4281]
Valid: 100% 7072/7083 [00:09<00:00, 748.16 uttr/s, accuracy=0.9225, loss=0.3906]
Valid: 100% 7072/7083 [00:09<00:00, 749.83 uttr/s, accuracy=0.9136, loss=0.4250]
Valid: 100% 7072/7083 [00:08<00:00, 864.93 uttr/s, accuracy=0.9068, loss=0.4694]
Valid: 100% 7072/7083 [00:08<00:00, 857.27 uttr/s, accuracy=0.9142, loss=0.4577]
Valid: 100% 7072/7083 [00:08<00:00, 831.15 uttr/s, accuracy=0.9123, loss=0.4585]
Valid: 100% 7072/7083 [00:09<00:00, 754.40 uttr/s, accuracy=0.9280, loss=0.3795]
Valid: 100% 7072/7085 [00:10<00:00, 680.57 uttr/s, accuracy=0.9195, loss=0.4202]


[0.9199660633484162, 0.9225113122171946, 0.9148755656108597, 0.9089366515837104, 0.9155825791855203, 0.9123303167420814, 0.9280260180995475, 0.9195418552036199]


Train: 100% 2000/2000 [07:40<00:00,  4.34 step/s, accuracy=1.00, loss=1.69, running_patch=5, step=26000]
Valid: 100% 7072/7083 [00:08<00:00, 865.43 uttr/s, accuracy=0.9229, loss=0.4051]
Valid: 100% 7072/7083 [00:08<00:00, 802.90 uttr/s, accuracy=0.9299, loss=0.3639]
Valid: 100% 7072/7083 [00:07<00:00, 886.91 uttr/s, accuracy=0.9169, loss=0.4102]
Valid: 100% 7072/7083 [00:07<00:00, 899.68 uttr/s, accuracy=0.9214, loss=0.4175]
Valid: 100% 7072/7083 [00:07<00:00, 895.74 uttr/s, accuracy=0.9210, loss=0.4372]
Valid: 100% 7072/7083 [00:07<00:00, 902.31 uttr/s, accuracy=0.9170, loss=0.4454]
Valid: 100% 7072/7083 [00:08<00:00, 878.17 uttr/s, accuracy=0.9340, loss=0.3545]
Valid: 100% 7072/7085 [00:07<00:00, 913.33 uttr/s, accuracy=0.9294, loss=0.3861]


[0.922935520361991, 0.9298642533936652, 0.9168552036199095, 0.9213800904977375, 0.9209558823529411, 0.9169966063348416, 0.9339649321266968, 0.9294400452488688]


Valid: 100% 7072/7085 [00:08<00:00, 875.62 uttr/s, accuracy=0.9255, loss=0.4093]


[0.928591628959276, 0.9360859728506787, 0.9199660633484162, 0.9213800904977375, 0.9209558823529411, 0.9169966063348416, 0.9339649321266968, 0.9294400452488688]


Train: 100% 2000/2000 [07:37<00:00,  4.37 step/s, accuracy=0.97, loss=1.85, running_patch=7, step=3e+4] 
Valid: 100% 7072/7083 [00:08<00:00, 872.64 uttr/s, accuracy=0.9296, loss=0.3828]
Valid: 100% 7072/7083 [00:08<00:00, 869.57 uttr/s, accuracy=0.9314, loss=0.3880]
Valid: 100% 7072/7083 [00:07<00:00, 901.72 uttr/s, accuracy=0.9290, loss=0.3795]
Valid: 100% 7072/7083 [00:08<00:00, 870.69 uttr/s, accuracy=0.9163, loss=0.4635]
Valid: 100% 7072/7083 [00:08<00:00, 856.36 uttr/s, accuracy=0.9253, loss=0.4236]
Valid: 100% 7072/7083 [00:08<00:00, 880.45 uttr/s, accuracy=0.9183, loss=0.4478]
Valid: 100% 7072/7083 [00:08<00:00, 877.27 uttr/s, accuracy=0.9347, loss=0.3562]
Valid: 100% 7072/7085 [00:08<00:00, 819.66 uttr/s, accuracy=0.9306, loss=0.3992]


[0.9295814479638009, 0.9360859728506787, 0.9290158371040724, 0.9213800904977375, 0.9253393665158371, 0.9182692307692307, 0.9346719457013575, 0.9305712669683258]


Train: 100% 2000/2000 [07:37<00:00,  4.37 step/s, accuracy=1.00, loss=0.91, running_patch=0, step=32000]
Valid: 100% 7072/7083 [00:07<00:00, 905.91 uttr/s, accuracy=0.9316, loss=0.3732]
Valid: 100% 7072/7083 [00:07<00:00, 884.27 uttr/s, accuracy=0.9352, loss=0.3869]
Valid: 100% 7072/7083 [00:07<00:00, 897.06 uttr/s, accuracy=0.9286, loss=0.3834]
Valid: 100% 7072/7083 [00:08<00:00, 877.20 uttr/s, accuracy=0.9227, loss=0.4298]
Valid: 100% 7072/7083 [00:07<00:00, 909.02 uttr/s, accuracy=0.9260, loss=0.4256]
Valid: 100% 7072/7083 [00:07<00:00, 902.31 uttr/s, accuracy=0.9272, loss=0.4225]
Valid: 100% 7072/7083 [00:08<00:00, 826.13 uttr/s, accuracy=0.9389, loss=0.3455]
Valid: 100% 7072/7085 [00:08<00:00, 857.44 uttr/s, accuracy=0.9253, loss=0.4233]


[0.9315610859728507, 0.9360859728506787, 0.9290158371040724, 0.9226527149321267, 0.9260463800904978, 0.9271776018099548, 0.9389140271493213, 0.9305712669683258]


Train: 100% 2000/2000 [07:43<00:00,  4.31 step/s, accuracy=0.97, loss=1.17, running_patch=1, step=34000]
Valid: 100% 7072/7083 [00:08<00:00, 866.28 uttr/s, accuracy=0.9361, loss=0.3672]
Valid: 100% 7072/7083 [00:07<00:00, 897.77 uttr/s, accuracy=0.9345, loss=0.3734]
Valid: 100% 7072/7083 [00:07<00:00, 890.08 uttr/s, accuracy=0.9303, loss=0.3822]
Valid: 100% 7072/7083 [00:08<00:00, 876.78 uttr/s, accuracy=0.9280, loss=0.4129]
Valid: 100% 7072/7083 [00:08<00:00, 882.50 uttr/s, accuracy=0.9260, loss=0.4228]
Valid: 100% 7072/7083 [00:09<00:00, 779.02 uttr/s, accuracy=0.9333, loss=0.4083]
Valid: 100% 7072/7083 [00:07<00:00, 902.93 uttr/s, accuracy=0.9385, loss=0.3586]
Valid: 100% 7072/7085 [00:08<00:00, 871.61 uttr/s, accuracy=0.9307, loss=0.3953]


[0.9360859728506787, 0.9360859728506787, 0.9302884615384616, 0.9280260180995475, 0.9260463800904978, 0.9332579185520362, 0.9389140271493213, 0.9307126696832579]


Train: 100% 2000/2000 [07:37<00:00,  4.37 step/s, accuracy=1.00, loss=0.58, running_patch=2, step=36000]
Valid: 100% 7072/7083 [00:08<00:00, 883.66 uttr/s, accuracy=0.9362, loss=0.3668]
Valid: 100% 7072/7083 [00:07<00:00, 885.34 uttr/s, accuracy=0.9359, loss=0.3742]
Valid: 100% 7072/7083 [00:08<00:00, 846.97 uttr/s, accuracy=0.9324, loss=0.3619]
Valid: 100% 7072/7083 [00:08<00:00, 836.63 uttr/s, accuracy=0.9292, loss=0.3978]
Valid: 100% 7072/7083 [00:08<00:00, 850.31 uttr/s, accuracy=0.9263, loss=0.4294]
Valid: 100% 7072/7083 [00:08<00:00, 830.36 uttr/s, accuracy=0.9310, loss=0.3955]
Valid: 100% 7072/7083 [00:08<00:00, 823.80 uttr/s, accuracy=0.9400, loss=0.3616]
Valid: 100% 7072/7085 [00:09<00:00, 712.87 uttr/s, accuracy=0.9340, loss=0.3820]


[0.9362273755656109, 0.9360859728506787, 0.9324095022624435, 0.9291572398190046, 0.926329185520362, 0.9332579185520362, 0.9400452488687783, 0.9339649321266968]


Train: 100% 2000/2000 [07:48<00:00,  4.27 step/s, accuracy=1.00, loss=0.47, running_patch=3, step=38000]
Valid: 100% 7072/7083 [00:08<00:00, 869.60 uttr/s, accuracy=0.9365, loss=0.3580]
Valid: 100% 7072/7083 [00:08<00:00, 879.54 uttr/s, accuracy=0.9350, loss=0.3949]
Valid: 100% 7072/7083 [00:08<00:00, 845.47 uttr/s, accuracy=0.9376, loss=0.3500]
Valid: 100% 7072/7083 [00:08<00:00, 807.30 uttr/s, accuracy=0.9330, loss=0.4012]
Valid: 100% 7072/7083 [00:11<00:00, 609.50 uttr/s, accuracy=0.9352, loss=0.3902]
Valid: 100% 7072/7083 [00:08<00:00, 807.97 uttr/s, accuracy=0.9344, loss=0.3936]
Valid: 100% 7072/7083 [00:08<00:00, 828.87 uttr/s, accuracy=0.9419, loss=0.3542]
Valid: 100% 7072/7085 [00:08<00:00, 833.28 uttr/s, accuracy=0.9310, loss=0.4042]


[0.9365101809954751, 0.9360859728506787, 0.9376414027149321, 0.932975113122172, 0.935237556561086, 0.9343891402714932, 0.9418834841628959, 0.9339649321266968]


Train: 100% 2000/2000 [08:09<00:00,  4.09 step/s, accuracy=1.00, loss=0.68, running_patch=4, step=4e+4] 
Valid: 100% 7072/7083 [00:08<00:00, 834.11 uttr/s, accuracy=0.9345, loss=0.3853]
Valid: 100% 7072/7083 [00:08<00:00, 870.31 uttr/s, accuracy=0.9416, loss=0.3626]
Valid: 100% 7072/7083 [00:08<00:00, 877.61 uttr/s, accuracy=0.9364, loss=0.3761]
Valid: 100% 7072/7083 [00:08<00:00, 810.61 uttr/s, accuracy=0.9317, loss=0.3824]
Valid: 100% 7072/7083 [00:08<00:00, 859.02 uttr/s, accuracy=0.9331, loss=0.4092]
Valid: 100% 7072/7083 [00:15<00:00, 444.18 uttr/s, accuracy=0.9293, loss=0.4187]
Valid: 100% 7072/7083 [00:09<00:00, 708.03 uttr/s, accuracy=0.9432, loss=0.3638]
Valid: 100% 7072/7085 [00:08<00:00, 840.01 uttr/s, accuracy=0.9348, loss=0.3735]


[0.9365101809954751, 0.9416006787330317, 0.9376414027149321, 0.932975113122172, 0.935237556561086, 0.9343891402714932, 0.943156108597285, 0.9348133484162896]


Train: 100% 2000/2000 [07:37<00:00,  4.37 step/s, accuracy=0.97, loss=1.08, running_patch=6, step=42000]
Valid: 100% 7072/7083 [00:08<00:00, 823.14 uttr/s, accuracy=0.9409, loss=0.3684]
Valid: 100% 7072/7083 [00:09<00:00, 762.06 uttr/s, accuracy=0.9417, loss=0.3615]
Valid: 100% 7072/7083 [00:08<00:00, 800.79 uttr/s, accuracy=0.9379, loss=0.3430]
Valid: 100% 7072/7083 [00:08<00:00, 792.63 uttr/s, accuracy=0.9344, loss=0.3902]
Valid: 100% 7072/7083 [00:08<00:00, 855.87 uttr/s, accuracy=0.9320, loss=0.4204]
Valid: 100% 7072/7083 [00:07<00:00, 892.68 uttr/s, accuracy=0.9331, loss=0.4147]
Valid: 100% 7072/7083 [00:08<00:00, 879.22 uttr/s, accuracy=0.9446, loss=0.3529]
Valid: 100% 7072/7085 [00:08<00:00, 854.65 uttr/s, accuracy=0.9409, loss=0.3506]


[0.940893665158371, 0.9417420814479638, 0.9379242081447964, 0.9343891402714932, 0.935237556561086, 0.9343891402714932, 0.9445701357466063, 0.940893665158371]


Train: 100% 2000/2000 [07:35<00:00,  4.39 step/s, accuracy=0.97, loss=0.76, running_patch=7, step=44000]
Valid: 100% 7072/7083 [00:07<00:00, 899.85 uttr/s, accuracy=0.9422, loss=0.3544]
Valid: 100% 7072/7083 [00:08<00:00, 823.17 uttr/s, accuracy=0.9456, loss=0.3456]
Valid: 100% 7072/7083 [00:07<00:00, 887.92 uttr/s, accuracy=0.9423, loss=0.3436]
Valid: 100% 7072/7083 [00:08<00:00, 872.99 uttr/s, accuracy=0.9344, loss=0.4095]
Valid: 100% 7072/7083 [00:07<00:00, 905.92 uttr/s, accuracy=0.9365, loss=0.4155]
Valid: 100% 7072/7083 [00:08<00:00, 880.84 uttr/s, accuracy=0.9361, loss=0.4158]
Valid: 100% 7072/7083 [00:07<00:00, 898.36 uttr/s, accuracy=0.9444, loss=0.3631]
Valid: 100% 7072/7085 [00:08<00:00, 878.13 uttr/s, accuracy=0.9389, loss=0.3716]


[0.9421662895927602, 0.9455599547511312, 0.9423076923076923, 0.9343891402714932, 0.9365101809954751, 0.9360859728506787, 0.9445701357466063, 0.940893665158371]


Train: 100% 2000/2000 [07:37<00:00,  4.37 step/s, accuracy=1.00, loss=0.23, running_patch=0, step=46000]
Valid: 100% 7072/7083 [00:08<00:00, 874.93 uttr/s, accuracy=0.9447, loss=0.3406]
Valid: 100% 7072/7083 [00:08<00:00, 857.39 uttr/s, accuracy=0.9419, loss=0.3661]
Valid: 100% 7072/7083 [00:07<00:00, 893.19 uttr/s, accuracy=0.9439, loss=0.3384]
Valid: 100% 7072/7083 [00:08<00:00, 872.63 uttr/s, accuracy=0.9365, loss=0.4026]
Valid: 100% 7072/7083 [00:07<00:00, 900.23 uttr/s, accuracy=0.9388, loss=0.4154]
Valid: 100% 7072/7083 [00:08<00:00, 832.78 uttr/s, accuracy=0.9385, loss=0.4031]
Valid: 100% 7072/7083 [00:08<00:00, 855.25 uttr/s, accuracy=0.9454, loss=0.3556]
Valid: 100% 7072/7085 [00:08<00:00, 827.71 uttr/s, accuracy=0.9400, loss=0.3899]


[0.9447115384615384, 0.9455599547511312, 0.9438631221719457, 0.9365101809954751, 0.9387726244343891, 0.9384898190045249, 0.9454185520361991, 0.940893665158371]


Train: 100% 2000/2000 [07:40<00:00,  4.35 step/s, accuracy=1.00, loss=0.52, running_patch=1, step=48000]
Valid: 100% 7072/7083 [00:07<00:00, 911.72 uttr/s, accuracy=0.9454, loss=0.3464]
Valid: 100% 7072/7083 [00:07<00:00, 884.55 uttr/s, accuracy=0.9426, loss=0.3792]
Valid: 100% 7072/7083 [00:08<00:00, 865.96 uttr/s, accuracy=0.9444, loss=0.3374]
Valid: 100% 7072/7083 [00:08<00:00, 871.80 uttr/s, accuracy=0.9395, loss=0.3901]
Valid: 100% 7072/7083 [00:07<00:00, 906.80 uttr/s, accuracy=0.9399, loss=0.3856]
Valid: 100% 7072/7083 [00:07<00:00, 894.65 uttr/s, accuracy=0.9359, loss=0.3967]
Valid: 100% 7072/7083 [00:07<00:00, 894.30 uttr/s, accuracy=0.9463, loss=0.3468]
Valid: 100% 7072/7085 [00:08<00:00, 868.21 uttr/s, accuracy=0.9430, loss=0.3694]


[0.9454185520361991, 0.9455599547511312, 0.9444287330316742, 0.9394796380090498, 0.9399038461538461, 0.9384898190045249, 0.9462669683257918, 0.9430147058823529]


Train: 100% 2000/2000 [07:36<00:00,  4.38 step/s, accuracy=1.00, loss=0.56, running_patch=2, step=5e+4] 
Valid: 100% 7072/7083 [00:07<00:00, 909.39 uttr/s, accuracy=0.9453, loss=0.3531]
Valid: 100% 7072/7083 [00:08<00:00, 880.80 uttr/s, accuracy=0.9440, loss=0.3489]
Valid: 100% 7072/7083 [00:08<00:00, 865.91 uttr/s, accuracy=0.9427, loss=0.3497]
Valid: 100% 7072/7083 [00:07<00:00, 900.81 uttr/s, accuracy=0.9415, loss=0.3842]
Valid: 100% 7072/7083 [00:07<00:00, 891.48 uttr/s, accuracy=0.9427, loss=0.3857]
Valid: 100% 7072/7083 [00:08<00:00, 881.23 uttr/s, accuracy=0.9419, loss=0.3912]
Valid: 100% 7072/7083 [00:08<00:00, 882.00 uttr/s, accuracy=0.9501, loss=0.3327]
Valid: 100% 7072/7085 [00:07<00:00, 901.21 uttr/s, accuracy=0.9430, loss=0.3666]


[0.9454185520361991, 0.9455599547511312, 0.9444287330316742, 0.9414592760180995, 0.9427319004524887, 0.9418834841628959, 0.9500848416289592, 0.9430147058823529]


Train: 100% 2000/2000 [07:35<00:00,  4.39 step/s, accuracy=1.00, loss=0.10, running_patch=3, step=52000]
Valid: 100% 7072/7083 [00:08<00:00, 862.52 uttr/s, accuracy=0.9443, loss=0.3435]
Valid: 100% 7072/7083 [00:07<00:00, 906.28 uttr/s, accuracy=0.9487, loss=0.3485]
Valid: 100% 7072/7083 [00:07<00:00, 899.40 uttr/s, accuracy=0.9456, loss=0.3256]
Valid: 100% 7072/7083 [00:07<00:00, 898.37 uttr/s, accuracy=0.9367, loss=0.3902]
Valid: 100% 7072/7083 [00:08<00:00, 847.07 uttr/s, accuracy=0.9408, loss=0.3818]
Valid: 100% 7072/7083 [00:07<00:00, 898.26 uttr/s, accuracy=0.9420, loss=0.3732]
Valid: 100% 7072/7083 [00:07<00:00, 902.76 uttr/s, accuracy=0.9463, loss=0.3441]
Valid: 100% 7072/7085 [00:08<00:00, 883.58 uttr/s, accuracy=0.9424, loss=0.3770]


[0.9454185520361991, 0.948670814479638, 0.9455599547511312, 0.9414592760180995, 0.9427319004524887, 0.942024886877828, 0.9500848416289592, 0.9430147058823529]


Train: 100% 2000/2000 [07:37<00:00,  4.37 step/s, accuracy=1.00, loss=0.66, running_patch=4, step=54000]
Valid: 100% 7072/7083 [00:07<00:00, 905.43 uttr/s, accuracy=0.9475, loss=0.3274]
Valid: 100% 7072/7083 [00:07<00:00, 890.94 uttr/s, accuracy=0.9504, loss=0.3374]
Valid: 100% 7072/7083 [00:08<00:00, 849.69 uttr/s, accuracy=0.9474, loss=0.3332]
Valid: 100% 7072/7083 [00:07<00:00, 909.68 uttr/s, accuracy=0.9430, loss=0.3670]
Valid: 100% 7072/7083 [00:07<00:00, 890.36 uttr/s, accuracy=0.9441, loss=0.3727]
Valid: 100% 7072/7083 [00:08<00:00, 870.99 uttr/s, accuracy=0.9412, loss=0.4050]
Valid: 100% 7072/7083 [00:08<00:00, 862.55 uttr/s, accuracy=0.9492, loss=0.3561]
Valid: 100% 7072/7085 [00:08<00:00, 860.35 uttr/s, accuracy=0.9457, loss=0.3507]


[0.947539592760181, 0.9503676470588235, 0.9473981900452488, 0.9430147058823529, 0.9441459276018099, 0.942024886877828, 0.9500848416289592, 0.9457013574660633]


Train: 100% 2000/2000 [07:35<00:00,  4.39 step/s, accuracy=1.00, loss=0.04, running_patch=5, step=56000]
Valid: 100% 7072/7083 [00:07<00:00, 923.72 uttr/s, accuracy=0.9446, loss=0.3565]
Valid: 100% 7072/7083 [00:08<00:00, 868.68 uttr/s, accuracy=0.9538, loss=0.3247]
Valid: 100% 7072/7083 [00:07<00:00, 889.56 uttr/s, accuracy=0.9474, loss=0.3295]
Valid: 100% 7072/7083 [00:07<00:00, 901.58 uttr/s, accuracy=0.9439, loss=0.3676]
Valid: 100% 7072/7083 [00:07<00:00, 906.83 uttr/s, accuracy=0.9461, loss=0.3807]
Valid: 100% 7072/7083 [00:08<00:00, 864.21 uttr/s, accuracy=0.9412, loss=0.3849]
Valid: 100% 7072/7083 [00:07<00:00, 900.84 uttr/s, accuracy=0.9504, loss=0.3264]
Valid: 100% 7072/7085 [00:07<00:00, 905.57 uttr/s, accuracy=0.9456, loss=0.3651]


[0.947539592760181, 0.9537613122171946, 0.9473981900452488, 0.9438631221719457, 0.9461255656108597, 0.942024886877828, 0.9503676470588235, 0.9457013574660633]


Train: 100% 2000/2000 [07:36<00:00,  4.38 step/s, accuracy=1.00, loss=0.04, running_patch=6, step=58000]
Valid: 100% 7072/7083 [00:07<00:00, 885.86 uttr/s, accuracy=0.9488, loss=0.3343]
Valid: 100% 7072/7083 [00:07<00:00, 906.96 uttr/s, accuracy=0.9501, loss=0.3431]
Valid: 100% 7072/7083 [00:07<00:00, 895.11 uttr/s, accuracy=0.9458, loss=0.3244]
Valid: 100% 7072/7083 [00:08<00:00, 870.60 uttr/s, accuracy=0.9457, loss=0.3681]
Valid: 100% 7072/7083 [00:08<00:00, 807.58 uttr/s, accuracy=0.9429, loss=0.3888]
Valid: 100% 7072/7083 [00:09<00:00, 785.61 uttr/s, accuracy=0.9426, loss=0.3963]
Valid: 100% 7072/7083 [00:09<00:00, 753.79 uttr/s, accuracy=0.9543, loss=0.3090]
Valid: 100% 7072/7085 [00:08<00:00, 824.40 uttr/s, accuracy=0.9491, loss=0.3504]


[0.9488122171945701, 0.9537613122171946, 0.9473981900452488, 0.9457013574660633, 0.9461255656108597, 0.9425904977375565, 0.9543269230769231, 0.9490950226244343]


Train: 100% 2000/2000 [07:37<00:00,  4.37 step/s, accuracy=1.00, loss=0.02, running_patch=7, step=6e+4] 
Valid: 100% 7072/7083 [00:07<00:00, 930.61 uttr/s, accuracy=0.9485, loss=0.3347]
Valid: 100% 7072/7083 [00:08<00:00, 796.29 uttr/s, accuracy=0.9519, loss=0.3310]
Valid: 100% 7072/7083 [00:07<00:00, 891.93 uttr/s, accuracy=0.9491, loss=0.3188]
Valid: 100% 7072/7083 [00:07<00:00, 911.92 uttr/s, accuracy=0.9490, loss=0.3561]
Valid: 100% 7072/7083 [00:07<00:00, 915.91 uttr/s, accuracy=0.9481, loss=0.3564]
Valid: 100% 7072/7083 [00:08<00:00, 881.82 uttr/s, accuracy=0.9416, loss=0.3971]
Valid: 100% 7072/7083 [00:08<00:00, 826.78 uttr/s, accuracy=0.9526, loss=0.3330]
Valid: 100% 7072/7085 [00:07<00:00, 904.35 uttr/s, accuracy=0.9464, loss=0.3619]


[0.9488122171945701, 0.9537613122171946, 0.9490950226244343, 0.9489536199095022, 0.9481052036199095, 0.9425904977375565, 0.9543269230769231, 0.9490950226244343]


Train: 100% 2000/2000 [07:42<00:00,  4.33 step/s, accuracy=1.00, loss=0.17, running_patch=0, step=62000]
Valid: 100% 7072/7083 [00:11<00:00, 614.34 uttr/s, accuracy=0.9514, loss=0.3280]
Valid: 100% 7072/7083 [00:10<00:00, 682.64 uttr/s, accuracy=0.9522, loss=0.3403]
Valid: 100% 7072/7083 [00:08<00:00, 844.00 uttr/s, accuracy=0.9508, loss=0.3274]
Valid: 100% 7072/7083 [00:09<00:00, 731.65 uttr/s, accuracy=0.9474, loss=0.3637]
Valid: 100% 7072/7083 [00:09<00:00, 720.06 uttr/s, accuracy=0.9492, loss=0.3685]
Valid: 100% 7072/7083 [00:10<00:00, 655.03 uttr/s, accuracy=0.9440, loss=0.3905]
Valid: 100% 7072/7083 [00:08<00:00, 849.37 uttr/s, accuracy=0.9555, loss=0.3228]
Valid: 100% 7072/7085 [00:10<00:00, 695.13 uttr/s, accuracy=0.9471, loss=0.3521]


[0.9513574660633484, 0.9537613122171946, 0.9507918552036199, 0.9489536199095022, 0.9492364253393665, 0.9440045248868778, 0.9554581447963801, 0.9490950226244343]


Train: 100% 2000/2000 [07:47<00:00,  4.28 step/s, accuracy=1.00, loss=0.11, running_patch=1, step=64000]
Valid: 100% 7072/7083 [00:07<00:00, 892.28 uttr/s, accuracy=0.9542, loss=0.3096]
Valid: 100% 7072/7083 [00:08<00:00, 871.94 uttr/s, accuracy=0.9538, loss=0.3360]
Valid: 100% 7072/7083 [00:11<00:00, 638.85 uttr/s, accuracy=0.9484, loss=0.3179]
Valid: 100% 7072/7083 [00:07<00:00, 907.01 uttr/s, accuracy=0.9463, loss=0.3706]
Valid: 100% 7072/7083 [00:07<00:00, 902.54 uttr/s, accuracy=0.9477, loss=0.3739]
Valid: 100% 7072/7083 [00:08<00:00, 828.31 uttr/s, accuracy=0.9432, loss=0.3830]
Valid: 100% 7072/7083 [00:10<00:00, 706.20 uttr/s, accuracy=0.9552, loss=0.3271]
Valid: 100% 7072/7085 [00:08<00:00, 882.65 uttr/s, accuracy=0.9484, loss=0.3494]


[0.954185520361991, 0.9537613122171946, 0.9507918552036199, 0.9489536199095022, 0.9492364253393665, 0.9440045248868778, 0.9554581447963801, 0.9490950226244343]


Train: 100% 2000/2000 [07:38<00:00,  4.37 step/s, accuracy=1.00, loss=0.13, running_patch=2, step=66000]
Valid: 100% 7072/7083 [00:07<00:00, 892.98 uttr/s, accuracy=0.9512, loss=0.3285]
Valid: 100% 7072/7083 [00:08<00:00, 823.31 uttr/s, accuracy=0.9529, loss=0.3326]
Valid: 100% 7072/7083 [00:07<00:00, 916.00 uttr/s, accuracy=0.9509, loss=0.3212]
Valid: 100% 7072/7083 [00:08<00:00, 852.29 uttr/s, accuracy=0.9492, loss=0.3518]
Valid: 100% 7072/7083 [00:08<00:00, 869.01 uttr/s, accuracy=0.9492, loss=0.3506]
Valid: 100% 7072/7083 [00:07<00:00, 915.81 uttr/s, accuracy=0.9468, loss=0.3740]
Valid: 100% 7072/7083 [00:07<00:00, 909.12 uttr/s, accuracy=0.9569, loss=0.3085]
Valid: 100% 7072/7085 [00:08<00:00, 857.02 uttr/s, accuracy=0.9485, loss=0.3489]


[0.954185520361991, 0.9537613122171946, 0.950933257918552, 0.9492364253393665, 0.9492364253393665, 0.9468325791855203, 0.9568721719457014, 0.9490950226244343]


Train: 100% 2000/2000 [07:35<00:00,  4.39 step/s, accuracy=1.00, loss=0.02, running_patch=3, step=68000]
Valid: 100% 7072/7083 [00:07<00:00, 926.35 uttr/s, accuracy=0.9519, loss=0.3182]
Valid: 100% 7072/7083 [00:07<00:00, 921.38 uttr/s, accuracy=0.9522, loss=0.3337]
Valid: 100% 7072/7083 [00:07<00:00, 907.92 uttr/s, accuracy=0.9490, loss=0.3260]
Valid: 100% 7072/7083 [00:07<00:00, 896.48 uttr/s, accuracy=0.9497, loss=0.3451]
Valid: 100% 7072/7083 [00:07<00:00, 920.02 uttr/s, accuracy=0.9488, loss=0.3666]
Valid: 100% 7072/7083 [00:07<00:00, 920.22 uttr/s, accuracy=0.9468, loss=0.3724]
Valid: 100% 7072/7083 [00:07<00:00, 919.66 uttr/s, accuracy=0.9546, loss=0.3224]
Valid: 100% 7072/7085 [00:08<00:00, 869.12 uttr/s, accuracy=0.9502, loss=0.3476]


[0.954185520361991, 0.9537613122171946, 0.950933257918552, 0.9496606334841629, 0.9492364253393665, 0.9468325791855203, 0.9568721719457014, 0.9502262443438914]


Train: 100% 2000/2000 [07:44<00:00,  4.31 step/s, accuracy=1.00, loss=0.08, running_patch=4, step=7e+4] 
Valid: 100% 7072/7083 [00:15<00:00, 449.81 uttr/s, accuracy=0.9491, loss=0.3314]
Valid: 100% 7072/7083 [00:14<00:00, 476.20 uttr/s, accuracy=0.9548, loss=0.3296]
Valid: 100% 7072/7083 [00:16<00:00, 434.84 uttr/s, accuracy=0.9490, loss=0.3336]
Valid: 100% 7072/7083 [00:11<00:00, 593.51 uttr/s, accuracy=0.9482, loss=0.3511]
Valid: 100% 7072/7083 [00:11<00:00, 618.81 uttr/s, accuracy=0.9512, loss=0.3515]
Valid: 100% 7072/7083 [00:16<00:00, 427.09 uttr/s, accuracy=0.9471, loss=0.3707]
Valid: 100% 7072/7083 [00:09<00:00, 756.37 uttr/s, accuracy=0.9559, loss=0.3168]
Valid: 100% 7072/7085 [00:08<00:00, 803.57 uttr/s, accuracy=0.9481, loss=0.3544]


[0.954185520361991, 0.9547511312217195, 0.950933257918552, 0.9496606334841629, 0.9512160633484162, 0.9471153846153846, 0.9568721719457014, 0.9502262443438914]


Train:   0% 0/2000 [00:00<?, ? step/s]
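
The bracketed list printed after each validation round appears to hold, for each of the 8 folds, the best validation accuracy seen so far: its entries never decrease from one round to the next, even when a fold's latest accuracy drops. The training loop itself is not shown in this section, so the following is only a minimal sketch of such an update. The helper name `update_best` is hypothetical, and the numbers are the four-decimal accuracies displayed in the step=22000 and step=24000 rounds of the log.

```python
def update_best(best, current):
    """Keep, per fold, the higher of the stored and the newly measured accuracy."""
    return [max(b, c) for b, c in zip(best, current)]


# Best-so-far accuracies after the step=22000 round (rounded to four decimals).
best = [0.9200, 0.9208, 0.9149, 0.9089, 0.9156, 0.9096, 0.9265, 0.9181]
# Validation accuracies measured in the step=24000 round.
current = [0.9186, 0.9225, 0.9136, 0.9068, 0.9142, 0.9123, 0.9280, 0.9195]

best = update_best(best, current)
print(best)
print(f"mean best accuracy: {sum(best) / len(best):.4f}")
```

Only folds whose new accuracy beats the stored value change, which matches the pattern in the logged lists.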


# Inference

## Dataset of inference

In [9]:
import os
import json
import torch
from pathlib import Path
from torch.utils.data import Dataset


class InferenceDataset(Dataset):
    def __init__(self, data_dir):
        # testdata.json lists the utterances to classify.
        testdata_path = Path(data_dir) / "testdata.json"
        with testdata_path.open() as f:
            metadata = json.load(f)
        self.data_dir = data_dir
        self.data = metadata["utterances"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        utterance = self.data[index]
        feat_path = utterance["feature_path"]
        mel = torch.load(os.path.join(self.data_dir, feat_path))

        return feat_path, mel


def inference_collate_batch(batch):
    """Collate a batch of data."""
    feat_paths, mels = zip(*batch)

    return feat_paths, torch.stack(mels)
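
`torch.stack` in `inference_collate_batch` needs every mel in a batch to have the same shape, which always holds when the loader uses `batch_size=1`. If larger inference batches were wanted, one option is to pad along the time axis. The sketch below is an illustration, not part of the original notebook: the name `padded_inference_collate`, the pad value of `-20.0`, and the 40-dimensional dummy features are all assumptions.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def padded_inference_collate(batch, pad_value=-20.0):
    """Hypothetical variant: pad variable-length mels so they can be batched."""
    feat_paths, mels = zip(*batch)
    # Remember the true lengths so padded frames can be masked or ignored later.
    lengths = torch.tensor([mel.size(0) for mel in mels])
    # pad_sequence pads along the first (time) dimension: (batch, max_len, n_mels).
    padded = pad_sequence(mels, batch_first=True, padding_value=pad_value)
    return feat_paths, padded, lengths


# Two dummy utterances with 5 and 3 frames of 40-dimensional features.
batch = [("a.pt", torch.zeros(5, 40)), ("b.pt", torch.zeros(3, 40))]
paths, padded, lengths = padded_inference_collate(batch)
print(padded.shape, lengths.tolist())
```

Padded frames can change a model's output unless they are masked, so predictions from padded batches may differ slightly from the `batch_size=1` results.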

## Main function of inference

In [10]:
import json
import csv
from pathlib import Path
from tqdm.notebook import tqdm

import torch
from torch.utils.data import DataLoader

def parse_args():
    """Return the inference configuration (data directory, checkpoint prefix, output CSV path)."""
    config = {
        "data_dir": "/kaggle/input/ml2022spring-hw4/Dataset",
        "model_path": "./model",
        "output_path": "./output.csv",
    }

    return config


def main(
    data_dir,
    model_path,
    output_path,
):
    """Main function."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Info]: Use {device} now!")

    # mapping.json maps predicted class indices back to speaker IDs.
    mapping_path = Path(data_dir) / "mapping.json"
    with mapping_path.open() as f:
        mapping = json.load(f)

    dataset = InferenceDataset(data_dir)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=0,
        collate_fn=inference_collate_batch,
    )
    print("[Info]: Finish loading data!", flush=True)

    speaker_num = len(mapping["id2speaker"])
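    # One classifier per fold; fold i loads its checkpoint from f"{model_path}_{i}.ckpt".
    model = [Classifier(n_spks=speaker_num) for _ in range(kfold)]
    for i in range(kfold):
        # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
        model[i].load_state_dict(torch.load(f"{model_path}_{i}.ckpt", map_location=device))
        model[i] = model[i].to(device)
        model[i].eval()
    print("[Info]: Finish creating model!", flush=True)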

    results = [["Id", "Category"]]
    for feat_paths, mels in tqdm(dataloader):
        with torch.no_grad():
            mels = mels.to(device)
            summation = None
            for md in model:
                outs, _ = md(mels)
                if summation is None:
                    summation = outs
                else:
                    summation += outs
            preds = summation.argmax(1).cpu().numpy()
            for feat_path, pred in zip(feat_paths, preds):
                results.append([feat_path, mapping["id2speaker"][str(pred)]])

    with open(output_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(results)


if __name__ == "__main__":
    main(**parse_args())

[Info]: Use cuda now!
[Info]: Finish loading data!
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
running config: embed=256, dropout=0.3
depwise_conv config: kernel_size=9
AM-Softmax config: margin=0.35, factor=30
[Info]: Finish creating model!

  0%|          | 0/8000 [00:00<?, ?it/s]
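
The loop in `main` ensembles the fold models by adding up their raw outputs and taking the `argmax` of the sum. The sketch below shows that step in isolation, next to a common alternative that averages softmax probabilities instead. It is only an illustration under stated assumptions: the `nn.Linear` layers are stand-ins for the notebook's `Classifier`, and the input is random data rather than mel-spectrograms.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in models (hypothetical): three linear layers instead of the fold classifiers.
models = [nn.Linear(40, 600).eval() for _ in range(3)]
x = torch.randn(4, 40)  # 4 fake inputs with 40 features each

with torch.no_grad():
    logits = torch.stack([m(x) for m in models])    # (n_models, batch, n_spks)
    summed = logits.sum(dim=0)                      # what the loop in main accumulates
    mean_prob = logits.softmax(dim=-1).mean(dim=0)  # alternative: average probabilities

pred_sum = summed.argmax(dim=1)
pred_prob = mean_prob.argmax(dim=1)
print(pred_sum.tolist(), pred_prob.tolist())
```

Summing raw outputs gives more weight to models that produce larger-magnitude scores, while averaging probabilities normalizes each model first; the two can disagree on borderline utterances.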