### Library

In [1]:
cd /workdir/github

/workdir/github


In [23]:
import torch
import torchaudio
from torch.utils.data import DataLoader
import torch.nn.functional as F 
from tqdm.notebook import tqdm

import numpy as np
import os

os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
torch.set_num_threads(1)
device = torch.device(f"cuda:2" if torch.cuda.is_available() else "cpu")

In [24]:
from speaker_verification.transforms import Audio_Transforms
from speaker_verification.transforms import Image_Transforms
from speaker_verification.models import Model
from speaker_verification.dataset import SpeakingFacesDataset
from speaker_verification.dataset import ValidDataset
from speaker_verification.sampler import ProtoSampler
from speaker_verification.loss import PrototypicalLoss
from speaker_verification.train import train_model

from speaker_verification.metrics import EER_
from speaker_verification.metrics import accuracy_

### General Pipeline

In [25]:
# dataset
annotations_file = "/workdir/github/annotations_file_short_SF.csv"
path2datasets = "/workdir/sf_pv"
dataset_dir = f"/workdir/sf_pv/data_v2"
data_type = ['rgb']

# model
library = "timm"
model_name = "resnet34"
pretrained_weights=True
fine_tune=True
embedding_size=128
pool=None

# transform
audio_T = None
image_T = None

# sampler
n_batch=10
n_ways=2
n_support=1
n_query=1

# loss
dist_type='squared_euclidean'

# train
num_epochs=20
save_dir='/workdir/results'
exp_name='chern'
wandb=None

In [26]:
if 'wav' in data_type:
    # audio transform params
    sample_rate=16000
    sample_duration=2 # seconds
    n_fft=512 # from Korean code
    win_length=400
    hop_length=160
    window_fn=torch.hamming_window
    n_mels=40

    audio_T = Audio_Transforms(sample_rate=sample_rate,
                                sample_duration=sample_duration, # seconds
                                n_fft=n_fft, # from Korean code
                                win_length=win_length,
                                hop_length=hop_length,
                                window_fn=torch.hamming_window,
                                n_mels=n_mels,
                                model_name=model_name,
                                library=library)
    audio_T = audio_T.transform

if 'rgb' in data_type or 'thr' in data_type:
    image_T = Image_Transforms(model_name=model_name,
                               library=library)

    image_T = image_T.transform         

model = Model(library=library, 
            pretrained_weights=pretrained_weights, 
            fine_tune=fine_tune, 
            embedding_size=embedding_size,
            model_name = model_name,
            pool=pool,
            data_type=data_type)

model = model.to(device)         

rgb data type
timm model is used.


In [27]:
import torchvision
weights = torchvision.models.ResNet34_Weights.DEFAULT
model_resnet = torchvision.models.resnet34(weights=weights)
model_resnet.fc = torch.nn.Linear(model_resnet.fc.in_features, 128)
model_resnet = model_resnet.to(device)  

linear = torch.nn.Linear(128, 100)
linear = linear.to(device) 

model = torch.nn.Sequential(
            model_resnet,
            linear
)
model = model.to(device)

In [28]:
# Dataset
train_dataset = SpeakingFacesDataset(annotations_file,dataset_dir,'train',
                                image_transform=image_T, 
                                audio_transform=audio_T,
                                data_type=data_type)

train_dataloader = DataLoader(dataset=train_dataset,
                        shuffle=True,
                        batch_size=64,
                        num_workers=4)

valid_dataset = ValidDataset(path2datasets,'valid',
                                image_transform=image_T, 
                                audio_transform=audio_T,
                                data_type=data_type)

valid_dataloader = DataLoader(dataset=valid_dataset,
                            batch_size=64)

In [29]:
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3, weight_decay = 1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.95)

### train one epoch

In [30]:
def train_singe_epoch(model,
                      train_dataloader, 
                      epoch,
                      optimizer,
                      device):

    model.train()
    pbar = tqdm(train_dataloader, desc=f'Train (epoch = {epoch})', leave=False)  

    total_loss = 0
    total_acc = 0
    for batch in pbar:

        data, label = batch
        data = data.to(device)
        label = label.to(device)

        data = model(data)
        loss = F.cross_entropy(data, label)
        pred = torch.argmax(F.softmax(data,dim=1), dim=1)
        accuracy = (pred == label).sum()/len(label) * 100
        
        total_loss += loss.item()
        total_acc += accuracy.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    avg_loss = total_loss / len(train_dataloader)
    avg_acc = total_acc / len(train_dataloader)

    print()
    print(f"Average train loss: {avg_loss}")
    print(f"Average train accuracy: {avg_acc}")

    return model, avg_loss, avg_acc

In [31]:
model, avg_loss, avg_acc = train_singe_epoch(model,
                      train_dataloader, 
                      1,
                      optimizer,
                      device)

Train (epoch = 1):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.8382896775354262
Average train accuracy: 80.42475269265371


In [32]:
def evaluate_single_epoch(model,
                        val_dataloader,
                        epoch, 
                        device,
                        data_type):
    model.eval()
    total_eer = 0
    total_accuracy = 0

    pbar = tqdm(val_dataloader, desc=f'Eval (epoch = {epoch})')

    for batch in pbar:

        data_type = sorted(data_type)
        id1, id2, labels = batch

        if len(data_type) == 1:
            data_id1, _ = id1
            data_id2, _ = id2

            data_id1 = data_id1.to(device)
            data_id2 = data_id2.to(device)

        with torch.no_grad():
            id1_out = model(data_id1)
            id2_out = model(data_id2)

            cos_sim = F.cosine_similarity(id1_out, id2_out, dim=1)
            eer, scores = EER_(cos_sim, labels)
            accuracy = accuracy_(labels, scores)

            total_eer += eer
            total_accuracy += accuracy
    
    avg_eer = total_eer / len(val_dataloader)
    print("\nAverage val eer: {}".format(avg_eer))

    avg_accuracy = total_accuracy / len(val_dataloader)
    print("\nAverage val accuracy: {}".format(avg_accuracy))

    return model, avg_eer, avg_accuracy

In [33]:
def train_model(model, 
                train_dataloader, 
                valid_dataloader):

    for epoch in tqdm(range(num_epochs)):
        
        model, train_loss, train_acc = train_singe_epoch(model, 
                                  train_dataloader,
                                  epoch,
                                  optimizer,
                                  device)
        
        model, val_eer, val_acc = evaluate_single_epoch(model, 
                                  valid_dataloader,
                                  epoch,
                                  device,
                                  data_type)

        scheduler.step()
    return model

In [13]:
model = train_model(model, 
                train_dataloader, 
                valid_dataloader)

  0%|          | 0/20 [00:00<?, ?it/s]

Train (epoch = 0):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.09241721228926048
Average train accuracy: 97.8167808219178


Eval (epoch = 0):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 7.584369282840795

Average val accuracy: 92.25413860830528


Train (epoch = 1):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.058324468982316656
Average train accuracy: 98.67968352853435


Eval (epoch = 1):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 8.766921808950595

Average val accuracy: 90.900322671156


Train (epoch = 2):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.05622494854479518
Average train accuracy: 98.5512588710001


Eval (epoch = 2):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 9.328884409034284

Average val accuracy: 90.4136854657688


Train (epoch = 3):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.06344400461897101
Average train accuracy: 98.43353626826038


Eval (epoch = 3):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 7.20572950137433

Average val accuracy: 92.4636994949495


Train (epoch = 4):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.00869437371825557
Average train accuracy: 99.80736301369863


Eval (epoch = 4):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 5.863461271469695

Average val accuracy: 93.83680555555556


Train (epoch = 5):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.0006209122620432076
Average train accuracy: 100.0


Eval (epoch = 5):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 5.412804283542606

Average val accuracy: 94.16035353535354


Train (epoch = 6):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.00023227830263762098
Average train accuracy: 100.0


Eval (epoch = 6):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 5.3527086017503365

Average val accuracy: 94.30853675645342


Train (epoch = 7):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.000410251483416475
Average train accuracy: 99.97859589041096


Eval (epoch = 7):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 6.03525226119099

Average val accuracy: 93.68248456790124


Train (epoch = 8):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.12507632274644317
Average train accuracy: 96.4576198630137


Eval (epoch = 8):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 10.513774944940682

Average val accuracy: 89.29836560044893


Train (epoch = 9):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.00858360527590402
Average train accuracy: 99.77525684931507


Eval (epoch = 9):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 5.850548726426873

Average val accuracy: 93.81663860830528


Train (epoch = 10):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.0005734593217573703
Average train accuracy: 100.0


Eval (epoch = 10):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 5.664898297634714

Average val accuracy: 93.90958193041526


Train (epoch = 11):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.00021292578388318063
Average train accuracy: 100.0


Eval (epoch = 11):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 5.470000632816554

Average val accuracy: 94.2191007295174


Train (epoch = 12):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.00015004996616693934
Average train accuracy: 100.0


Eval (epoch = 12):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 5.5365098463397135

Average val accuracy: 94.23926767676768


Train (epoch = 13):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.0001329525487450803
Average train accuracy: 100.0


Eval (epoch = 13):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 5.504656620981037

Average val accuracy: 94.12352693602693


Train (epoch = 14):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.00010353646128550443
Average train accuracy: 100.0


Eval (epoch = 14):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 5.216038176569477

Average val accuracy: 94.51897446689114


Train (epoch = 15):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 8.765138191187336e-05
Average train accuracy: 100.0


Eval (epoch = 15):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 5.312792734591935

Average val accuracy: 94.4760101010101


Train (epoch = 16):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 7.850109793489786e-05
Average train accuracy: 100.0


Eval (epoch = 16):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 5.232823364413601

Average val accuracy: 94.48214786756454


Train (epoch = 17):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 7.138657936792859e-05
Average train accuracy: 100.0


Eval (epoch = 17):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 5.180951200262925

Average val accuracy: 94.54177188552188


Train (epoch = 18):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 5.7631371331908294e-05
Average train accuracy: 100.0


Eval (epoch = 18):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 5.180316440344332

Average val accuracy: 94.48653198653199


Train (epoch = 19):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 6.424354278476118e-05
Average train accuracy: 100.0


Eval (epoch = 19):   0%|          | 0/594 [00:00<?, ?it/s]


Average val eer: 4.904011297023658

Average val accuracy: 94.66978815937149


In [14]:
92.25413860830528 + 7.584369282840795

99.83850789114608

In [34]:
model = train_model(model, 
                train_dataloader, 
                valid_dataloader)

  0%|          | 0/20 [00:00<?, ?it/s]

Train (epoch = 0):   0%|          | 0/146 [00:00<?, ?it/s]


Average train loss: 0.13045802355220873
Average train accuracy: 96.48576229565764


Eval (epoch = 0):   0%|          | 0/594 [00:00<?, ?it/s]

KeyboardInterrupt: 