### Library

In [2]:
cd /workdir/github

/workdir/github


In [3]:
import torch
import torchaudio
from torch.utils.data import DataLoader
import torch.nn.functional as F 
from tqdm.notebook import tqdm

import numpy as np
import os

os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
torch.set_num_threads(1)
device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
from speaker_verification.transforms import Audio_Transforms
from speaker_verification.transforms import Image_Transforms
from speaker_verification.models import Model
from speaker_verification.dataset import SpeakingFacesDataset
from speaker_verification.dataset import TrainDataset
from speaker_verification.dataset import ValidDataset
from speaker_verification.sampler import ProtoSampler
from speaker_verification.loss import PrototypicalLoss
from speaker_verification.train import train_model

from speaker_verification.metrics import EER_
from speaker_verification.metrics import accuracy_

### General Pipeline

In [7]:
# dataset
annotations_file = "/workdir/github/annotations_file_short_joint_cleaned.csv"
dataset_type = "SF"

if dataset_type == "SF":
    path_to_train_dataset = f"/workdir/sf_pv/data_v2"
    path_to_valid_dataset = "/workdir/sf_pv"
    path_to_valid_list = "/workdir/sf_pv/metadata/valid_list_v2.txt"
elif dataset_type == "VX2":
    path_to_train_dataset="/workdir/VoxCeleb2/dev"
    path_to_valid_list="/workdir/VoxCeleb1/metadata/test_list_vc_v2.txt"
    path_to_valid_dataset="/workdir/VoxCeleb1/test"

    
data_type = ['rgb']


# model
library = "timm"
model_name = "resnet34"
pretrained_weights=True
fine_tune=True
embedding_size=128
pool=None

# transform
audio_T = None
image_T = None

# sampler
n_batch=10
n_ways=2
n_support=1
n_query=1

# loss
dist_type='squared_euclidean'

# train
num_epochs=1
save_dir='/workdir/results'
exp_name='chern'
wandb=None

In [8]:
if 'wav' in data_type:
    # audio transform params
    sample_rate=16000
    sample_duration=2 # seconds
    n_fft=512 # from Korean code
    win_length=400
    hop_length=160
    window_fn=torch.hamming_window
    n_mels=40

    audio_T = Audio_Transforms(sample_rate=sample_rate,
                                sample_duration=sample_duration, # seconds
                                n_fft=n_fft, # from Korean code
                                win_length=win_length,
                                hop_length=hop_length,
                                window_fn=torch.hamming_window,
                                n_mels=n_mels,
                                model_name=model_name,
                                library=library)
    audio_T = audio_T.transform

if 'rgb' in data_type or 'thr' in data_type:
    image_T = Image_Transforms(model_name=model_name,
                               library=library)

    image_T = image_T.transform         

In [9]:
# Dataset
train_dataset = TrainDataset(annotations_file=annotations_file,
                            path_to_train_dataset=path_to_train_dataset,
                            train_type='train',
                            image_transform=image_T, 
                            audio_transform=audio_T,
                            data_type=data_type,
                            dataset_type=dataset_type)

train_dataloader = DataLoader(dataset=train_dataset,
                        shuffle=True,
                        batch_size=128,
                        num_workers=4)

valid_dataset = ValidDataset(path_to_valid_dataset=path_to_valid_dataset,
                             path_to_valid_list=path_to_valid_list,
                             data_type=data_type,
                             dataset_type=dataset_type,
                             image_transform=image_T, 
                             audio_transform=audio_T)

valid_dataloader = DataLoader(dataset=valid_dataset,
                            batch_size=128)

In [10]:
pretrained_model = Model(library=library, 
            pretrained_weights=pretrained_weights, 
            fine_tune=fine_tune, 
            embedding_size=embedding_size,
            model_name = model_name,
            pool=pool,
            data_type=data_type)

n_classes = len(np.unique(train_dataset.labels))
classification_layer = torch.nn.Linear(embedding_size, n_classes)

model = torch.nn.Sequential()
model.add_module('pretrained_model', pretrained_model)
model.add_module('classification_layer', classification_layer)

model = model.to(device)

rgb data type
timm model is used.


In [11]:
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3, weight_decay = 1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.95)

### train one epoch

In [12]:
def train_singe_epoch(model,
                      train_dataloader, 
                      epoch,
                      optimizer,
                      device):

    model.train()
    pbar = tqdm(train_dataloader, desc=f'Train (epoch = {epoch})', leave=False)  

    total_loss = 0
    total_acc = 0
    for batch in pbar:

        data, label = batch
        data = data.to(device)
        label = label.to(device)

        data = model(data)
        loss = F.cross_entropy(data, label)
        pred = torch.argmax(F.softmax(data,dim=1), dim=1)
        accuracy = (pred == label).sum()/len(label) * 100
        
        total_loss += loss.item()
        total_acc += accuracy.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    avg_loss = total_loss / len(train_dataloader)
    avg_acc = total_acc / len(train_dataloader)

    print()
    print(f"Average train loss: {avg_loss}")
    print(f"Average train accuracy: {avg_acc}")

    return model, avg_loss, avg_acc

In [13]:
model, avg_loss, avg_acc = train_singe_epoch(model,
                      train_dataloader, 
                      1,
                      optimizer,
                      device)

Train (epoch = 1):   0%|          | 0/73 [00:00<?, ?it/s]


Average train loss: 0.9223150528252941
Average train accuracy: 79.82862603174497


In [23]:
for batch in valid_dataloader:
    id1, id2, labels = batch
    if len(data_type) == 1:
        data_id1, _ = id1
        data_id2, _ = id2

        data_id1 = data_id1.to(device)
        data_id2 = data_id2.to(device)

    break

In [24]:
with torch.no_grad():
    id1_out = model.pretrained_model(data_id1)
    id2_out = model.pretrained_model(data_id2)

In [14]:
def evaluate_single_epoch(model,
                        val_dataloader,
                        epoch, 
                        device,
                        data_type):
    model.eval()
    total_eer = 0
    total_accuracy = 0

    pbar = tqdm(val_dataloader, desc=f'Eval (epoch = {epoch})')

    for batch in pbar:

        data_type = sorted(data_type)
        id1, id2, labels = batch

        if len(data_type) == 1:
            data_id1, _ = id1
            data_id2, _ = id2

            data_id1 = data_id1.to(device)
            data_id2 = data_id2.to(device)

        with torch.no_grad():
            # id1_out = model(data_id1)
            # id2_out = model(data_id2)

            id1_out = model.pretrained_model(data_id1)
            id2_out = model.pretrained_model(data_id2)

            cos_sim = F.cosine_similarity(id1_out, id2_out, dim=1)
            eer, scores = EER_(cos_sim, labels)
            accuracy = accuracy_(labels, scores)

            total_eer += eer
            total_accuracy += accuracy
    
    avg_eer = total_eer / len(val_dataloader)
    print("\nAverage val eer: {}".format(avg_eer))

    avg_accuracy = total_accuracy / len(val_dataloader)
    print("\nAverage val accuracy: {}".format(avg_accuracy))

    return model, avg_eer, avg_accuracy

In [15]:
model, avg_eer, avg_accuracy = evaluate_single_epoch(model,
                        valid_dataloader,
                        1, 
                        device,
                        data_type)

Eval (epoch = 1):   0%|          | 0/297 [00:00<?, ?it/s]


Average val eer: 8.192791005291005

Average val accuracy: 91.44645863395863


In [16]:
def train_model(model, 
                train_dataloader, 
                valid_dataloader):

    for epoch in tqdm(range(num_epochs)):
        
        model, train_loss, train_acc = train_singe_epoch(model, 
                                  train_dataloader,
                                  epoch,
                                  optimizer,
                                  device)
        
        model, val_eer, val_acc = evaluate_single_epoch(model, 
                                  valid_dataloader,
                                  epoch,
                                  device,
                                  data_type)

        scheduler.step()
    return model

In [17]:
model = train_model(model, 
                train_dataloader, 
                valid_dataloader)

  0%|          | 0/1 [00:00<?, ?it/s]

Train (epoch = 0):   0%|          | 0/73 [00:00<?, ?it/s]


Average train loss: 0.0807525069797284
Average train accuracy: 97.8167808219178


Eval (epoch = 0):   0%|          | 0/297 [00:00<?, ?it/s]


Average val eer: 6.672378547378548

Average val accuracy: 92.86277958152958
