# Processing for local training

In [1]:
import os
#!tar -xf data/for-rerec.tar.gz -C data/

In [2]:
#ls data/for-rerecorded/

In [3]:
# !pip3 install torch torchvision librosa matplotlib tqdm pandas

In [2]:
import torch
print(torch.cuda.is_available())

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import torchvision
import torchvision.transforms as transforms
from torchvision import models

import numpy as np
import matplotlib.pyplot as plt
import librosa
import os
from tqdm import tqdm
from sklearn.metrics import roc_curve
import pandas as pd
import shutil
import zipfile

True


# Data Preprocessing

In [51]:
SAMPLE_RATE = 16000  # Sampling rate
N_MELS = 128

In [52]:
# TODO: Make it so each output is "513-dimensional" as with the reference paper
#
# https://arxiv.org/pdf/2203.16263

def compute_spectrograms(path):
    y, sr = librosa.load(path, sr=SAMPLE_RATE)
    fixed_length = 2 * SAMPLE_RATE
    if len(y) < fixed_length:
        y = np.pad(y, (0, fixed_length - len(y)))
    else:
        y = y[:fixed_length]

    cqt = librosa.cqt(y, sr=sr)
    cqt_spec = librosa.amplitude_to_db(np.abs(cqt), ref=np.max)

    stft = librosa.stft(y)
    log_spec = librosa.amplitude_to_db(np.abs(stft), ref=np.max)

    mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=N_MELS)
    mel_spec = librosa.power_to_db(mel, ref=np.max)

    return cqt_spec, log_spec, mel_spec

In [53]:
data_dirs = {
    'training_fake': 'data/for-rerecorded/training/fake/',
    'testing_fake': 'data/for-rerecorded/testing/fake/',
    'validation_fake': 'data/for-rerecorded/validation/fake/',
    'training_real': 'data/for-rerecorded/training/real/',
    'testing_real': 'data/for-rerecorded/testing/real/',
    'validation_real': 'data/for-rerecorded/validation/real/',
}

In [54]:
def process_directory(directory, output_dir):
  os.makedirs(output_dir, exist_ok=True)

  for filename in tqdm(os.listdir(directory)):
    if filename.endswith('.wav'):
      audio_path = os.path.join(directory, filename)
      cqt, log, mel = compute_spectrograms(audio_path)

      base_name = os.path.splitext(filename)[0]
      # Save spectrograms as numpy arrays
      np.save(f"{output_dir}/{base_name}_cqt.npy", cqt)
      np.save(f"{output_dir}/{base_name}_log.npy", log)
      np.save(f"{output_dir}/{base_name}_mel.npy", mel)

compute_specs = False
if compute_specs:
    for set_name, directory in data_dirs.items():
      output_dir = f'data/spectrograms/{set_name}_spectrograms'
      process_directory(directory, output_dir)
      print(f"Processed {set_name} set.")


In [55]:
# lets look at some data
display_cqt = "data/spectrograms/training_fake_spectrograms/recording1.wav_norm_mono_cqt.npy"
display_log = "data/spectrograms/training_fake_spectrograms/recording1.wav_norm_mono_log.npy"
display_mel = "data/spectrograms/training_fake_spectrograms/recording1.wav_norm_mono_mel.npy"

cqt_test = np.load(display_cqt)
log_test = np.load(display_log)
mel_test = np.load(display_mel)
print(cqt_test.shape)
print(log_test.shape)
print(mel_test.shape)

# for reference
cqt_size = 84
log_size = 1025
mel_size = 128


(84, 63)
(1025, 63)
(128, 63)


# Define models

In [56]:
class ResNet50Spectrogram(nn.Module):
    def __init__(self):
        super(ResNet50Spectrogram, self).__init__()

        self.model = models.resnet50(weights=None)

        original_conv = self.model.conv1
        self.model.conv1 = nn.Conv2d(in_channels=1,
                            out_channels=original_conv.out_channels,
                            kernel_size = original_conv.kernel_size,
                            stride = original_conv.stride,
                            padding = original_conv.padding,
                            bias = False)

        self.model.fc = nn.Linear(self.model.fc.in_features, 2)

    def forward(self, x):
        return self.model(x)

In [57]:
class EfficientNetSpectrogram(nn.Module):
    def __init__(self, model_type):
        super(EfficientNetSpectrogram, self).__init__()
        
        self.enet = None
        
        if model_type == "b0":
            self.enet = models.efficientnet_b0(weights=None, num_classes=2)
            
        
        # We need to change the network to accept 1 channel instead of
        # 3 because of our data.
        original_conv = self.enet.features[0][0]
        new_conv = nn.Conv2d(in_channels=1,
                            out_channels=original_conv.out_channels,
                            kernel_size = original_conv.kernel_size,
                            stride = original_conv.stride,
                            padding = original_conv.padding,
                            bias = False)
        self.enet.features[0][0] = new_conv
        
    def forward(self, x):
        return self.enet(x)

In [58]:
class LSTMSpectrogram(nn.Module):
    def __init__(self):
        super(LSTMSpectrogram, self).__init__()
        
        self.nlayer = 2
        self.nhiddens = 256
        
        if feature_type == "cqt":
            self.lstm = nn.LSTM(input_size=cqt_size, hidden_size=self.nhiddens, num_layers=self.nlayer, 
                                batch_first=True, dropout=0.3)
        elif feature_type == "log":
            self.lstm = nn.LSTM(input_size=log_size, hidden_size=self.nhiddens, num_layers=self.nlayer, 
                                batch_first=True, dropout=0.3)
        elif feature_type == "mel":
            self.lstm = nn.LSTM(input_size=mel_size, hidden_size=self.nhiddens, num_layers=self.nlayer, 
                                batch_first=True, dropout=0.3)
            
        self.fc = nn.Linear(self.nhiddens, 1)
        
    def forward(self, x):
        x = x.squeeze(1)
        # features are in wrong order for lstm
        x = x.transpose(1,2)
        
        x, (h_o, c_o) = self.lstm(x)
        
        h_o = h_o.squeeze(0)
        if self.nlayer > 1:
            h_o = h_o[-1]
        x = self.fc(h_o)
        x = x.squeeze(-1)
        
        #x = self.fc(x)
        return x

# Training procedures

In [59]:
class SpecDataset(Dataset):
    # data_type is one of 'cqt', 'log', 'mel'
    #
    # loader_type is one of 'train', 'validation', 'test'
    def __init__(self, data_type, loader_type):
        
        root = os.getcwd()
        data_root = os.path.join(root, 'data/spectrograms')
        
        self.data = []
        
        real_folder = None
        fake_folder = None

        # get the folder
        if loader_type == "train":
            real_folder = os.path.join(data_root, 'training_real_spectrograms')
            fake_folder = os.path.join(data_root, 'training_fake_spectrograms')
        elif loader_type == "validation":
            real_folder = os.path.join(data_root, 'validation_real_spectrograms')
            fake_folder = os.path.join(data_root, 'validation_fake_spectrograms')
        elif loader_type == "test":
            real_folder = os.path.join(data_root, 'testing_real_spectrograms')
            fake_folder = os.path.join(data_root, 'testing_fake_spectrograms')
        elif loader_type == "ITWFull":
            real_folder = os.path.join(data_root, 'ITWfull_real_spectrograms')
            fake_folder = os.path.join(data_root, 'ITWfull_fake_spectrograms')
        else:
            # Should never occur.
            pass
        
        real_files = []
        fake_files = []
        
        # now we have the folder given the loader type, collect
        # the data required for the loader.
        
        # get real example filenames
        suffix = f"{data_type}.npy"
        for filename in os.listdir(real_folder):
            # check if correct suffix and exists as a file
            if filename.endswith(suffix) and os.path.isfile(os.path.join(real_folder, filename)):
                this_filepath = os.path.join(real_folder, filename)
                real_files.append(this_filepath)
                
        print(f"Real examples for {data_type} {loader_type}: {len(real_files)}")
        
        # get fake example filenames
        suffix = f"{data_type}.npy"
        for filename in os.listdir(fake_folder):
            # check if correct suffix and exists as a file
            if filename.endswith(suffix) and os.path.isfile(os.path.join(fake_folder, filename)):
                this_filepath = os.path.join(fake_folder, filename)
                fake_files.append(this_filepath)
                
        print(f"Fake examples for {data_type} {loader_type}: {len(fake_files)}")
        
        label_val_false = 0
        label_val_true = 1
        if model_type == "LSTM":
            label_val_false = float(0)
            label_val_true = float(1)
        
        # load the data into memory
        #
        # if we need to work with a larger dataset, you might need to
        # alter this to be lazy loading instead, but it fits in my main memory
        # because of how much I currently have.
        for real_file in real_files:
            rf_data = torch.tensor(np.load(real_file))
            rf_data = rf_data.unsqueeze(0)
            if resizing == True:
                rf_data = rf_data.unsqueeze(0)
                rf_data = F.interpolate(rf_data, size=dims_resize, mode='bilinear', align_corners = False)
                rf_data = rf_data.squeeze(0)
            self.data.append((rf_data,label_val_true))
            
        for fake_file in fake_files:
            ff_data = torch.tensor(np.load(fake_file))
            ff_data = ff_data.unsqueeze(0)
            if resizing == True:
                ff_data = ff_data.unsqueeze(0)
                ff_data = F.interpolate(ff_data, size=dims_resize, mode='bilinear', align_corners = False)
                ff_data = ff_data.squeeze(0)
            self.data.append((ff_data, label_val_false))
        
    def __len__(self):
        return len(self.data)
            
    def __getitem__(self, idx):
        # return the data and the label
        return self.data[idx]

In [60]:
def dynamic_collate(batch):
    data, labels = zip(*batch)
    data = [d for d in data]
    
    max_length = max(d.shape[2] for d in data)
    
    padded = []
    if resizing == False:
        for d in data:
            # total padding needed, >= 0
            padding = max_length - d.shape[2]
        
            padded_d = None
            if padding > 0:
                # add zero's (silence) to match rest of batch
                padded_data = F.pad(d, (0,padding))
            
            else:
                # already max length
                padded_data = d
            padded.append(padded_data)
    else:
        # if resizing was true, we don't need to pad, everything is of the same shape
        padded = data
    
    '''
    for p in padded:
        r = p.unsqueeze(0)
        r = F.interpolate(r, size=dims_resize, mode='bilinear', align_corners = False)
        r = r.squeeze(0)
        resized.append(r)
    '''
    
    # stack properly now that everything is padded
    padded = torch.stack(padded, dim=0)

    labels = torch.tensor(labels)
    
    return padded, labels

In [61]:
#mean = [0]
#std = [1]

# deal with this later
#
# we should also probably compute the mean and std manually instead of assuming they correctly
# normalized it, since this is the re-recorded dataset
train_transform = transforms.Compose([
    transforms.ToTensor()
    #transforms.Normalize(mean, std)
  ])
test_transform = transforms.Compose([transforms.ToTensor()])

####################################################
# <CHANGE ME> if you want to use different features!
####################################################
feature_type = "log"

####################################################
# <CHANGE ME> if you want to use resizing!
#
# We need to resize to, for example, (224, 224)
####################################################
resizing = True
dims_resize = (224, 224)

#model_type = "enet"
model_type = "res"
#model_type = "LSTM"

device = torch.device("cuda")

model = None

if model_type == "LSTM":
    model = LSTMSpectrogram()
elif model_type == "enet":
    model =  EfficientNetSpectrogram("b0")
elif model_type == "res":
    model = ResNet50Spectrogram()
    
model = model.to(device)

#epochs = 100
epochs = 30
batch_size = 32
weight_decay = 5e-4
learning_rate = 0.0001

criterion = nn.CrossEntropyLoss()
if model_type == "LSTM":
    criterion = nn.BCEWithLogitsLoss()
    
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay = weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs)

FoR_train_loader = None
FoR_val_loader = None
FoR_test_loader = None

# data loaders
FoR_train_dataset = SpecDataset(feature_type, "train")
FoR_val_dataset = SpecDataset(feature_type, "validation")
FoR_test_dataset = SpecDataset(feature_type, "test")

FoR_train_loader = DataLoader(FoR_train_dataset, batch_size=batch_size, shuffle=True, collate_fn=dynamic_collate)
FoR_val_loader = DataLoader(FoR_val_dataset, batch_size=batch_size, shuffle=True, collate_fn=dynamic_collate)
FoR_test_loader = DataLoader(FoR_test_dataset, batch_size=batch_size, shuffle=True, collate_fn=dynamic_collate)

Real examples for log train: 5104
Fake examples for log train: 5104
Real examples for log validation: 1101
Fake examples for log validation: 1143
Real examples for log test: 408
Fake examples for log test: 408


In [62]:
# we need to compute the equal error rate as one of our metrics.
def compute_EER(model, loader):
    model.eval()
    all_scores = []
    all_labels = []
    
    with torch.no_grad():
        for data in loader:
            waveform, labels = data
            
            waveform = waveform.to(device)
            labels = labels.to(device)
            
            out = model(waveform)
            if model_type == "LSTM":
                out = torch.sigmoid(out)
            else:
                out = torch.softmax(out, dim=1)
                # take the positive class labels
                out = out[:,1]
            
            all_scores.extend(out.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # use sklearn to compute this for us
    fpr, tpr, thresholds = roc_curve(all_labels, all_scores)
    
    # definition
    fnr = 1 - tpr

    # find closest threshold
    eer_thresh = np.nanargmin(np.abs(fpr-fnr))
    EER = (fpr[eer_thresh] + fnr[eer_thresh])/2
    
    return EER


In [63]:
def train(loader):
    model.train()
    training_loss = 0.0
    
    for data in loader:
        waveform, labels = data
        waveform = waveform.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        # basic pytorch boilerplate
        out = model(waveform)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        
        training_loss += loss.item()
        
    training_loss = training_loss / len(loader)
    return training_loss

In [64]:
def validate(loader):
    model.eval()
    validation_loss = 0.0
    
    n_correct = 0
    n_total = 0
    
    with torch.no_grad():
        for data in loader:
            waveform, labels = data
            
            waveform = waveform.to(device)
            labels = labels.to(device)
            
            out = model(waveform)
            loss = criterion(out, labels)
            
            validation_loss += loss.item()
            
            # count correct predictions
            preds = None
            if model_type == "LSTM":
                preds = (out > 0).long()
            else:
                preds = out.argmax(dim=1)
            
            n_correct = n_correct + (preds == labels).sum().item()
            n_total = n_total + labels.size(0)
            
    validation_loss = validation_loss / len(loader)
    accuracy = n_correct / n_total
    
    return validation_loss, accuracy
            

In [65]:
# reference paper uses patience = 5
patience = 5
best_validation_loss = 10000.0
fail_count = 0
epochs = 30

training_losses = []
val_losses = []
test_losses = []

for epoch in tqdm(range(epochs)):
    training_loss = train(FoR_train_loader)
    print(f"[Epoch {epoch}] Training Loss: {training_loss}")
    
    training_losses.append(training_loss)
    
    validation_loss, val_accuracy = validate(FoR_val_loader)    
    print(f"[Epoch {epoch}] Validation Loss: {validation_loss} Accuracy: {val_accuracy}")
    
    val_losses.append(validation_loss)
    
    test_loss, test_accuracy = validate(FoR_test_loader)
    print(f"[DEBUG Epoch {epoch}] Test Loss: {test_loss} Accuracy: {test_accuracy}")
    
    test_losses.append(test_loss)
    
    if validation_loss < best_validation_loss:
        best_validation_loss = validation_loss
        fail_count = 0
    else:
        # increment number of epochs of no improvement
        fail_count = fail_count + 1
        
    if fail_count >= patience:
        print(f"Triggering early breaking on epoch {epoch}")
        break
    
    scheduler.step()

  0%|                                                                                            | 0/30 [00:00<?, ?it/s]

[Epoch 0] Training Loss: 0.37124145485250554
[Epoch 0] Validation Loss: 0.2722762826252991 Accuracy: 0.8863636363636364


  3%|██▋                                                                              | 1/30 [02:32<1:13:51, 152.81s/it]

[DEBUG Epoch 0] Test Loss: 0.7090642612714034 Accuracy: 0.7022058823529411
[Epoch 1] Training Loss: 0.1612746174385929
[Epoch 1] Validation Loss: 0.15724595247859685 Accuracy: 0.9407308377896613


  7%|█████▍                                                                           | 2/30 [09:07<2:17:44, 295.16s/it]

[DEBUG Epoch 1] Test Loss: 1.041413746201075 Accuracy: 0.6593137254901961
[Epoch 2] Training Loss: 0.09736738857910983
[Epoch 2] Validation Loss: 0.12890241352487092 Accuracy: 0.9527629233511586


 10%|████████                                                                         | 3/30 [16:07<2:38:21, 351.91s/it]

[DEBUG Epoch 2] Test Loss: 1.3922355610590715 Accuracy: 0.6176470588235294
[Epoch 3] Training Loss: 0.06983439575991707
[Epoch 3] Validation Loss: 0.09815628299305976 Accuracy: 0.966131907308378


 13%|██████████▊                                                                      | 4/30 [23:01<2:43:10, 376.56s/it]

[DEBUG Epoch 3] Test Loss: 0.8845844601209347 Accuracy: 0.7034313725490197
[Epoch 4] Training Loss: 0.055230470140301115
[Epoch 4] Validation Loss: 0.1065909654118488 Accuracy: 0.9710338680926917


 17%|█████████████▌                                                                   | 5/30 [30:34<2:48:22, 404.09s/it]

[DEBUG Epoch 4] Test Loss: 1.5850116381278405 Accuracy: 0.6740196078431373
[Epoch 5] Training Loss: 0.03968578453710723
[Epoch 5] Validation Loss: 0.2723827622527532 Accuracy: 0.9327094474153298


 20%|████████████████▏                                                                | 6/30 [38:41<2:52:56, 432.34s/it]

[DEBUG Epoch 5] Test Loss: 1.5275486478438745 Accuracy: 0.6495098039215687
[Epoch 6] Training Loss: 0.04241336406955604
[Epoch 6] Validation Loss: 0.07857141210238489 Accuracy: 0.9754901960784313


 23%|██████████████████▉                                                              | 7/30 [45:58<2:46:16, 433.76s/it]

[DEBUG Epoch 6] Test Loss: 2.046506418631627 Accuracy: 0.6066176470588235
[Epoch 7] Training Loss: 0.033735197724044236
[Epoch 7] Validation Loss: 0.38053533442082327 Accuracy: 0.8832442067736186


 27%|█████████████████████▌                                                           | 8/30 [55:17<2:53:42, 473.73s/it]

[DEBUG Epoch 7] Test Loss: 4.5555591858350315 Accuracy: 0.5147058823529411
[Epoch 8] Training Loss: 0.030094338735152434
[Epoch 8] Validation Loss: 0.10988308090190271 Accuracy: 0.9674688057040999


 30%|███████████████████████▋                                                       | 9/30 [1:03:28<2:47:41, 479.10s/it]

[DEBUG Epoch 8] Test Loss: 2.5107709169387817 Accuracy: 0.5882352941176471
[Epoch 9] Training Loss: 0.020042851171606816
[Epoch 9] Validation Loss: 0.08330403118400002 Accuracy: 0.9781639928698752


 33%|██████████████████████████                                                    | 10/30 [1:10:31<2:33:58, 461.92s/it]

[DEBUG Epoch 9] Test Loss: 1.203284339262889 Accuracy: 0.7279411764705882
[Epoch 10] Training Loss: 0.020739207701395077
[Epoch 10] Validation Loss: 0.08406051221211616 Accuracy: 0.9737076648841355


 37%|████████████████████████████▌                                                 | 11/30 [1:17:37<2:22:47, 450.91s/it]

[DEBUG Epoch 10] Test Loss: 1.7863023716669817 Accuracy: 0.6482843137254902
[Epoch 11] Training Loss: 0.00913654400017629
[Epoch 11] Validation Loss: 0.0915693430822141 Accuracy: 0.9763814616755794


 37%|████████████████████████████▌                                                 | 11/30 [1:25:47<2:28:10, 467.95s/it]

[DEBUG Epoch 11] Test Loss: 2.7673317331534166 Accuracy: 0.633578431372549
Triggering early breaking on epoch 11





In [66]:
print(f"Test EER: {compute_EER(model, FoR_test_loader)}")
test_loss, test_accuracy = validate(FoR_test_loader)
print(f"Testing loss: {test_loss} Accuracy: {test_accuracy}")

# expected to be quite low, though obvious overfitting at current settings
print(f"Train EER: {compute_EER(model, FoR_train_loader)}")

Test EER: 0.29289215686274506
Testing loss: 2.7867565155029297 Accuracy: 0.633578431372549
Train EER: 0.0007836990595611167


In [None]:
plt.plot(range(len(training_losses)), training_losses, label="Training Loss", marker='o')
plt.plot(range(len(val_losses)), val_losses, label="Validation Loss", marker='s')
plt.plot(range(len(test_losses)), test_losses, label="Test Loss", marker='x')

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss vs epochs')
plt.legend()
plt.show()


In [68]:
torch.save(model.state_dict(), 'models/ResNet_log.pth')

In [39]:
device = torch.device("cuda")
model = ResNet50Spectrogram()
model.load_state_dict(torch.load('models/ResNet_cqt.pth'))
model.eval()
model = model.to(device)
model_type = "res"
criterion = nn.CrossEntropyLoss()

# Process entire ITW dataset

In [None]:
if not os.path.isdir("data/release_in_the_wild"):
    with zipfile.ZipFile('data/release_in_the_wild.zip') as zip_ref:
        zip_ref.extractall('data')

In [None]:
src_csv = 'data/release_in_the_wild/meta.csv'

df = pd.read_csv(src_csv)

for _, row in df.iterrows():
    name = row['file']
    label = str(row['label'])

    src_path = os.path.join('data/release_in_the_wild', name)
    dst_dir = os.path.join('data/release_in_the_wild', label)
    dst_path = os.path.join(dst_dir, name)

    os.makedirs(dst_dir, exist_ok=True)

    if os.path.exists(src_path):
        shutil.move(src_path, dst_path)

os.rename('data/release_in_the_wild/bona-fide', 'data/release_in_the_wild/real')
os.rename('data/release_in_the_wild/spoof', 'data/release_in_the_wild/fake')

In [None]:
def process_directory(directory, output_dir):
  os.makedirs(output_dir, exist_ok=True)

  for filename in os.listdir(directory):
    if filename.endswith('.wav'):
      audio_path = os.path.join(directory, filename)
      cqt, log, mel = compute_spectrograms(audio_path)

      base_name = os.path.splitext(filename)[0]
      # Save spectrograms as numpy arrays
      np.save(f"{output_dir}/{base_name}_cqt.npy", cqt)
      np.save(f"{output_dir}/{base_name}_log.npy", log)
      np.save(f"{output_dir}/{base_name}_mel.npy", mel)

data_dirs = {
    'ITWfull_real': 'data/release_in_the_wild/real',
    'ITWfull_fake': 'data/release_in_the_wild/fake/'
}
if compute_specs:
    for set_name, directory in data_dirs.items():
        output_dir = f'data/spectrograms/{set_name}_spectrograms'
        process_directory(directory, output_dir)
        print(f"Processed {set_name} set.")


In [17]:
# viewing some of the data
ITW_display_cqt = "data/spectrograms/ITWfull_real_spectrograms/5_cqt.npy"
ITW_display_log = "data/spectrograms/ITWfull_real_spectrograms/5_log.npy"
ITW_display_mel = "data/spectrograms/ITWfull_real_spectrograms/5_mel.npy"

ITW_cqt_test = np.load(display_cqt)
ITW_log_test = np.load(display_log)
ITW_mel_test = np.load(display_mel)
print(ITW_cqt_test.shape)
print(ITW_log_test.shape)
print(ITW_mel_test.shape)

# for reference, should be the same as before
ITW_cqt_size = 84
ITW_log_size = 1025
ITW_mel_size = 128

(84, 63)
(1025, 63)
(128, 63)


# Evaluate the FoR trained model on the ITW dataset

In [67]:
# already defined above
feature_type = "log"
resizing = True
dims_resize = (224, 224)
model_type = "res"

batch_size = 128

ITW_full_dataset = SpecDataset(feature_type, "ITWFull")
ITW_full_loader = DataLoader(ITW_full_dataset, batch_size=batch_size, shuffle=True, collate_fn=dynamic_collate)

Real examples for log ITWFull: 19963
Fake examples for log ITWFull: 11816


In [69]:
criterion = nn.CrossEntropyLoss()
print(f"ITW Test EER: {compute_EER(model, ITW_full_loader)}")
test_loss, test_accuracy = validate(ITW_full_loader)
print(f"ITW Testing loss: {test_loss} Accuracy: {test_accuracy}")

ITW Test EER: 0.29202589660540246
ITW Testing loss: 2.3529806414761216 Accuracy: 0.6387236854526575


# Transfer learning from FoR to ITW

In [70]:
# Most of the settings should be kept the same from previous training, because we are
# using the same model.

from torch.utils.data import random_split

train_size = 0.8 * len(ITW_full_dataset)
val_size = 0.1 * len(ITW_full_dataset)
train_size = int(train_size)
val_size = int(val_size)

test_size = len(ITW_full_dataset) - val_size - train_size

# now they should all sum to ITW_full_dataset, do the split

ITW_train, ITW_val, ITW_test = random_split(ITW_full_dataset, [train_size, val_size, test_size])

batch_size = 32
epochs = 30
weight_decay = 5e-4
learning_rate = 0.0001

optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay = weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs)

ITW_train_loader = DataLoader(ITW_train, batch_size=batch_size, shuffle=True, collate_fn=dynamic_collate)
ITW_val_loader = DataLoader(ITW_val, batch_size=batch_size, shuffle=True, collate_fn=dynamic_collate)
ITW_test_loader = DataLoader(ITW_test, batch_size=batch_size, shuffle=True, collate_fn=dynamic_collate)

print(len(ITW_train_loader.dataset))
print(len(ITW_val_loader.dataset))
print(len(ITW_test_loader.dataset))

25423
3177
3179


In [71]:
# Same loop, but for our ITW transfer learning.

patience = 5
best_validation_loss = 10000.0
fail_count = 0

TL_training_losses = []
TL_val_losses = []
# TL_test_losses = []

print("Starting transfer learning from FoR dataset model to ITW")

for epoch in tqdm(range(epochs)):
    training_loss = train(ITW_train_loader)
    print(f"[Epoch {epoch}] Training Loss: {training_loss}")
    
    TL_training_losses.append(training_loss)
    
    validation_loss, val_accuracy = validate(ITW_val_loader)    
    print(f"[Epoch {epoch}] Validation Loss: {validation_loss} Accuracy: {val_accuracy}")
    
    TL_val_losses.append(validation_loss)
    
    # test_loss, test_accuracy = validate(ITW_test_loader)
    # print(f"[DEBUG Epoch {epoch}] Test Loss: {test_loss} Accuracy: {test_accuracy}")
    
    # TL_test_losses.append(test_loss)
    
    if validation_loss < best_validation_loss:
        best_validation_loss = validation_loss
        fail_count = 0
    else:
        # increment number of epochs of no improvement
        fail_count = fail_count + 1
        
    if fail_count >= patience:
        print(f"Triggering early breaking on epoch {epoch}")
        break
    
    scheduler.step()

Starting transfer learning from FoR dataset model to ITW


  0%|                                                                                            | 0/30 [00:00<?, ?it/s]

[Epoch 0] Training Loss: 0.05553503832637966


  3%|██▋                                                                              | 1/30 [16:13<7:50:30, 973.47s/it]

[Epoch 0] Validation Loss: 0.07384771597862709 Accuracy: 0.9719861504564055
[Epoch 1] Training Loss: 0.017394461650883673


  7%|█████▎                                                                          | 2/30 [34:06<8:01:29, 1031.75s/it]

[Epoch 1] Validation Loss: 0.03014316294043965 Accuracy: 0.9911866540761725
[Epoch 2] Training Loss: 0.012930405923978457


 10%|████████                                                                        | 3/30 [52:12<7:55:32, 1056.77s/it]

[Epoch 2] Validation Loss: 0.008364373966687709 Accuracy: 0.9971671388101983
[Epoch 3] Training Loss: 0.015296023972846828


 13%|██████████▍                                                                   | 4/30 [1:10:16<7:42:38, 1067.64s/it]

[Epoch 3] Validation Loss: 0.010723048777144868 Accuracy: 0.996222851746931
[Epoch 4] Training Loss: 0.009140315692710479


 17%|█████████████                                                                 | 5/30 [1:27:14<7:17:21, 1049.66s/it]

[Epoch 4] Validation Loss: 0.007654492402562028 Accuracy: 0.9971671388101983
[Epoch 5] Training Loss: 0.008191908756438604


 20%|███████████████▌                                                              | 6/30 [1:44:51<7:00:47, 1051.97s/it]

[Epoch 5] Validation Loss: 0.0064674300309980025 Accuracy: 0.9971671388101983
[Epoch 6] Training Loss: 0.005799241441025847


 23%|██████████████████▏                                                           | 7/30 [2:01:12<6:34:27, 1029.00s/it]

[Epoch 6] Validation Loss: 0.009235581871926114 Accuracy: 0.9965376141013534
[Epoch 7] Training Loss: 0.005593754292278888


 27%|████████████████████▊                                                         | 8/30 [2:17:29<6:11:12, 1012.37s/it]

[Epoch 7] Validation Loss: 0.004259973008056477 Accuracy: 0.9981114258734656
[Epoch 8] Training Loss: 0.0035830006815761655


 30%|███████████████████████▍                                                      | 9/30 [2:35:09<5:59:31, 1027.23s/it]

[Epoch 8] Validation Loss: 0.0043986171832189665 Accuracy: 0.998426188227888
[Epoch 9] Training Loss: 0.00430395858285956


 33%|█████████████████████████▋                                                   | 10/30 [2:52:18<5:42:37, 1027.87s/it]

[Epoch 9] Validation Loss: 0.009180667863711278 Accuracy: 0.9971671388101983
[Epoch 10] Training Loss: 0.0022808275932298503


 37%|████████████████████████████▏                                                | 11/30 [3:08:41<5:21:07, 1014.06s/it]

[Epoch 10] Validation Loss: 0.08432432807225268 Accuracy: 0.9748190116462071
[Epoch 11] Training Loss: 0.0029864266350486356


 40%|███████████████████████████████▏                                              | 12/30 [3:24:41<4:59:15, 997.53s/it]

[Epoch 11] Validation Loss: 0.005726329073756915 Accuracy: 0.998426188227888
[Epoch 12] Training Loss: 0.002697717660196529


 40%|██████████████████████████████▊                                              | 12/30 [3:40:43<5:31:05, 1103.66s/it]

[Epoch 12] Validation Loss: 0.007115433502077622 Accuracy: 0.9971671388101983
Triggering early breaking on epoch 12





In [72]:
print(f"Transfer Learning ITW Test EER: {compute_EER(model, ITW_test_loader)}")
test_loss, test_accuracy = validate(ITW_test_loader)
print(f"Testing loss: {test_loss} Accuracy: {test_accuracy}")

Transfer Learning ITW Test EER: 0.004360804863950296
Testing loss: 0.009879228580221024 Accuracy: 0.9977980497011639


In [73]:
torch.save(model.state_dict(), 'models/ResNet_log_TL.pth')

# Pure ITW training

In [31]:
# We already have data loaders, simply setup the same model and training procedures.

train_transform = transforms.Compose([
    transforms.ToTensor()
    #transforms.Normalize(mean, std)
  ])
test_transform = transforms.Compose([transforms.ToTensor()])

# setup the model, exact same one as used prior with empty weights

model = None

torch.cuda.empty_cache()

if model_type == "LSTM":
    model = LSTMSpectrogram()
elif model_type == "enet":
    model =  EfficientNetSpectrogram("b0")
elif model_type == "res":
    model = ResNet50Spectrogram()
    
model = model.to(device)

#epochs = 100
epochs = 30
weight_decay = 5e-4
learning_rate = 0.0001

criterion = nn.CrossEntropyLoss()
if model_type == "LSTM":
    criterion = nn.BCEWithLogitsLoss()
    
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay = weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = epochs)

In [32]:
# Train on the ITW dataset with an empty model

# Same loop, but for our ITW transfer learning.

patience = 5
best_validation_loss = 10000.0
fail_count = 0

ITW_training_losses = []
ITW_val_losses = []
# ITW_test_losses = []

print("Starting pure ITW training using the same model type (with weights cleared)")

for epoch in tqdm(range(epochs)):
    training_loss = train(ITW_train_loader)
    print(f"[Epoch {epoch}] Training Loss: {training_loss}")
    
    ITW_training_losses.append(training_loss)
    
    validation_loss, val_accuracy = validate(ITW_val_loader)    
    print(f"[Epoch {epoch}] Validation Loss: {validation_loss} Accuracy: {val_accuracy}")
    
    ITW_val_losses.append(validation_loss)
    
    test_loss, test_accuracy = validate(ITW_test_loader)
    # print(f"[DEBUG Epoch {epoch}] Test Loss: {test_loss} Accuracy: {test_accuracy}")
    
    # ITW_test_losses.append(test_loss)
    
    if validation_loss < best_validation_loss:
        best_validation_loss = validation_loss
        fail_count = 0
    else:
        # increment number of epochs of no improvement
        fail_count = fail_count + 1
        
    if fail_count >= patience:
        print(f"Triggering early breaking on epoch {epoch}")
        break
    
    scheduler.step()

Starting pure ITW training using the same model type (with weights cleared)


  0%|                                                                                            | 0/30 [00:00<?, ?it/s]

[Epoch 0] Training Loss: 0.13709653658139279
[Epoch 0] Validation Loss: 0.05969329889398068 Accuracy: 0.9798552093169657


  3%|██▋                                                                              | 1/30 [09:57<4:48:41, 597.29s/it]

[Epoch 1] Training Loss: 0.048527227357211496
[Epoch 1] Validation Loss: 0.047622000323026444 Accuracy: 0.9886685552407932


  7%|█████▍                                                                           | 2/30 [25:49<6:16:02, 805.82s/it]

[Epoch 2] Training Loss: 0.03595966197987403
[Epoch 2] Validation Loss: 0.042623908452806065 Accuracy: 0.9852061693421467


 10%|████████                                                                         | 3/30 [41:44<6:33:18, 874.04s/it]

[Epoch 3] Training Loss: 0.025100845364209794
[Epoch 3] Validation Loss: 0.021617065165628446 Accuracy: 0.9924457034938622


 13%|██████████▊                                                                      | 4/30 [57:37<6:32:21, 905.43s/it]

[Epoch 4] Training Loss: 0.01808495312562061
[Epoch 4] Validation Loss: 0.038589535914488805 Accuracy: 0.9905571293673276


 17%|█████████████▏                                                                 | 5/30 [1:13:33<6:24:46, 923.45s/it]

[Epoch 5] Training Loss: 0.01925506910306267
[Epoch 5] Validation Loss: 0.01179993757934426 Accuracy: 0.9959080893925086


 20%|███████████████▊                                                               | 6/30 [1:29:26<6:13:25, 933.58s/it]

[Epoch 6] Training Loss: 0.014157977252735755
[Epoch 6] Validation Loss: 0.023532813157471535 Accuracy: 0.9940195152659742


 23%|██████████████████▍                                                            | 7/30 [1:45:22<6:00:39, 940.86s/it]

[Epoch 7] Training Loss: 0.013864283821969492
[Epoch 7] Validation Loss: 0.019728451837909233 Accuracy: 0.9933899905571294


 27%|█████████████████████                                                          | 8/30 [2:01:16<5:46:30, 945.01s/it]

[Epoch 8] Training Loss: 0.010197309929393463
[Epoch 8] Validation Loss: 0.014410113999083478 Accuracy: 0.9959080893925086


 30%|███████████████████████▋                                                       | 9/30 [2:17:09<5:31:38, 947.53s/it]

[Epoch 9] Training Loss: 0.010863139021807768
[Epoch 9] Validation Loss: 0.016034580965015265 Accuracy: 0.9937047529115518


 33%|██████████████████████████                                                    | 10/30 [2:33:04<5:16:40, 950.01s/it]

[Epoch 10] Training Loss: 0.0070711572845911005
[Epoch 10] Validation Loss: 0.006737054672780687 Accuracy: 0.9981114258734656


 37%|████████████████████████████▌                                                 | 11/30 [2:48:59<5:01:16, 951.39s/it]

[Epoch 11] Training Loss: 0.005095998747590944
[Epoch 11] Validation Loss: 0.02038345941118223 Accuracy: 0.9927604658482846


 40%|███████████████████████████████▏                                              | 12/30 [3:04:55<4:45:51, 952.85s/it]

[Epoch 12] Training Loss: 0.00656310174268932
[Epoch 12] Validation Loss: 0.006841832745883494 Accuracy: 0.9974819011646208


 43%|█████████████████████████████████▊                                            | 13/30 [3:20:52<4:30:18, 954.05s/it]

[Epoch 13] Training Loss: 0.002597256415372401
[Epoch 13] Validation Loss: 0.00817999015959245 Accuracy: 0.9981114258734656


 47%|████████████████████████████████████▍                                         | 14/30 [3:36:48<4:14:35, 954.74s/it]

[Epoch 14] Training Loss: 0.0034443782434343847
[Epoch 14] Validation Loss: 0.01158743194782801 Accuracy: 0.9965376141013534


 50%|███████████████████████████████████████                                       | 15/30 [3:52:45<3:58:51, 955.44s/it]

[Epoch 15] Training Loss: 0.0031726516899655637
[Epoch 15] Validation Loss: 0.012208493231513557 Accuracy: 0.9968523764557758


 50%|███████████████████████████████████████                                       | 15/30 [4:08:39<4:08:39, 994.62s/it]

Triggering early breaking on epoch 15





In [36]:
print(f"Pure ITW Test EER: {compute_EER(model, ITW_test_loader)}")
test_loss, test_accuracy = validate(ITW_test_loader)
print(f"Testing loss: {test_loss} Accuracy: {test_accuracy}")

Pure ITW Test EER: 0.005125068724934024
Testing loss: 0.011843823992901435 Accuracy: 0.9968543567159485


In [37]:
torch.save(model.state_dict(), 'models/ResNet50_ITW.pth')

# Benchmark the inference time of the model (should be same across either training method)

In [40]:
# get sample data
input_data, _ = next(iter(ITW_test_loader))
input_data = input_data.to(device)

# make sure model is in fastest cache
with torch.no_grad():
    for _ in range(5):
        _ = model(input_data)

n_bench_runs = 1000
run_times = []

with torch.no_grad():
    for _ in range(n_bench_runs):
        # important to make sure each run is done sequentially
        torch.cuda.synchronize()
        
        start = torch.cuda.Event(enable_timing = True)
        end = torch.cuda.Event(enable_timing = True)
        
        start.record()
        _ = model(input_data)
        end.record()
        
        # important to make sure each run is done sequentially
        torch.cuda.synchronize()
        
        run_times.append(start.elapsed_time(end))
    
average_rtime = sum(run_times) / n_bench_runs
print(f'Average run time for batch of size {batch_size} on model {model_type} with features {feature_type}')
print(f'{average_rtime} ms')
print(f'Averages to {average_rtime / batch_size} ms per input')

Average run time for batch of size 32 on model res with features cqt
99.91057424545288 ms
Averages to 3.1222054451704024 ms per input
