## imports

In [1]:
# dataset
import torch
from torch.utils.data import Dataset
from torch.utils.data import random_split

from torchaudio import datasets
import torchaudio.transforms

from torch.utils.data import DataLoader

from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

import pickle

# audio processing
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt

# neural network
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.transforms as transforms
import torchvision.models as models

# choose the compute device: prefer CUDA, then Apple MPS, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


## importing dataset

In [2]:
# custom VCTK-based dataset definition
class STFTC_Dataset(Dataset):
    """Dataset of pre-computed STFT samples stored as one pickle file per item.

    Item ``i`` is read from ``<root>/stftc<i>``; each file holds a pickled
    ``(input_array, label)`` pair.

    Args:
        root: directory containing the ``stftc<i>`` files.
        length: number of items in the directory. Defaults to 43873, the size
            of the pre-processed dataset, so the original data never has to be
            scanned just to count items.
    """

    def __init__(self, root: str, length: int = 43873):
        self.root_path = root + "/"
        self.length = length

    def __len__(self):
        return self.length

    # load items from directory
    def __getitem__(self, idx):
        """Return ``(input, label)`` for item ``idx``.

        The input is converted from a numpy array to a float32 tensor; the
        label is returned exactly as it was pickled.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = int(idx)  # also accepts numpy integers and 0-dim tensors
        if idx < 0:
            idx += self.length  # support negative indexing like a sequence
        if not 0 <= idx < self.length:
            # IndexError (rather than FileNotFoundError from open) lets plain
            # iteration over the dataset terminate cleanly
            raise IndexError(f"index {idx} is out of range for dataset of length {self.length}")

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this dataset at files you generated yourself.
        with open(self.root_path + "stftc" + str(idx), "rb") as file:
            pickled_list = pickle.load(file)

        target_input, target_label = pickled_list
        return torch.from_numpy(target_input).float(), target_label
    
# initialise dataset
stftc_data = STFTC_Dataset("STFTcomp")

# split to test/train
split_ratio = 0.8
train_size = int(split_ratio * len(stftc_data))
test_size = len(stftc_data) - train_size

# Seed the split so that the same samples land in train/test on every kernel
# restart; without this, resuming training from a saved checkpoint would
# evaluate on samples the model has already been trained on.
# A default (CPU) generator is used for the split.
split_seed = 0
train_dataset, test_dataset = random_split(
    stftc_data,
    [train_size, test_size],
    generator = torch.Generator().manual_seed(split_seed),
)

### dataset parameters

*VCTK Structure:*
(0: waveform; 1: sample rate; 2: text transcript; 3: person identifier; 4: text identifier)

*STFT VCTK structure:*
(0: STFT of the waveform, with the complex values stored as two channels; 1: person identifier)

### creating labels

In [3]:
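# one-off code used to generate the label dictionary from the original VCTK dataset
# (left commented out: the resulting dictionary is hard-coded in the next cell;
#  keys are whatever the DataLoader yields for field 3, which for these hard-coded keys is a 1-tuple)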
# label_temp_vctk = datasets.VCTK_092(root = "VCTK")

# label_loader = iter(DataLoader(label_temp_vctk, shuffle = False))
# dict_label = 0
# label_dict = {}
# while True:
#    try:
#        item = next(label_loader)
#        if item[3] not in label_dict.keys():
#            label_dict[item[3]] = dict_label
#            dict_label = dict_label + 1
#    except StopIteration:
#        break

In [4]:
# hard-coded result of the generation code above: speaker IDs listed in class-index order,
# so a speaker's position in this tuple is its integer label
speaker_ids = (
    'p225', 'p226', 'p227', 'p228', 'p229', 'p230', 'p231', 'p232',
    'p233', 'p234', 'p236', 'p237', 'p238', 'p239', 'p240', 'p241',
    'p243', 'p244', 'p245', 'p246', 'p247', 'p248', 'p249', 'p250',
    'p251', 'p252', 'p253', 'p254', 'p255', 'p256', 'p257', 'p258',
    'p259', 'p260', 'p261', 'p262', 'p263', 'p264', 'p265', 'p266',
    'p267', 'p268', 'p269', 'p270', 'p271', 'p272', 'p273', 'p274',
    'p275', 'p276', 'p277', 'p278', 'p279', 'p281', 'p282', 'p283',
    'p284', 'p285', 'p286', 'p287', 'p288', 'p292', 'p293', 'p294',
    'p295', 'p297', 'p298', 'p299', 'p300', 'p301', 'p302', 'p303',
    'p304', 'p305', 'p306', 'p307', 'p308', 'p310', 'p311', 'p312',
    'p313', 'p314', 'p316', 'p317', 'p318', 'p323', 'p326', 'p329',
    'p330', 'p333', 'p334', 'p335', 'p336', 'p339', 'p340', 'p341',
    'p343', 'p345', 'p347', 'p351', 'p360', 'p361', 'p362', 'p363',
    'p364', 'p374', 'p376', 's5',
)

# keys are 1-tuples of the speaker ID, values are the class indices 0..107
label_dict = {(speaker_id,): class_index for class_index, speaker_id in enumerate(speaker_ids)}

## defining model

In [5]:
# model to classify between the different speakers
# we will treat the STFT matrix as if it is a 2d image: remember the columns are the audio signal windows with frequency components

class VoiceCNN(nn.Module):
    """CNN speaker classifier that treats a 2-channel STFT as an image.

    Four stride-2 convolutions, each followed by ReLU and 2x2 max-pooling,
    feed an adaptive average pool that fixes the feature-map size at 2x2, so
    inputs of different widths all produce the same number of features for the
    fully connected head.

    Args:
        num_classes: number of output logits. Defaults to 108, the number of
            entries in the notebook's label dictionary.

    Input:  ``(N, 2, H, W)`` float tensor. H and W must be large enough to
            survive the eight successive stride-2 reductions.
    Output: ``(N, num_classes)`` unnormalised logits (no softmax applied).
    """

    def __init__(self, num_classes: int = 108):
        super().__init__()

        # activation (stateless, so a single instance is reused after every layer)
        self.activation = nn.ReLU()

        # CONV LAYERS
        # Layers are created and initialised in a fixed order so that, for a
        # given RNG seed, the initial weights are reproducible.

        # interpret complex numbers as two channels
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=8, kernel_size=(5, 5), stride=2, padding=1)
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_in', nonlinearity='relu')  # initialise weights

        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), stride=2, padding=1)
        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_in', nonlinearity='relu')

        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), stride=2, padding=1)
        nn.init.kaiming_normal_(self.conv3.weight, mode='fan_in', nonlinearity='relu')

        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), stride=2, padding=1)
        nn.init.kaiming_normal_(self.conv4.weight, mode='fan_in', nonlinearity='relu')

        # POOLING

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        pool_h, pool_w = 2, 2
        self.adaptive_pooling = nn.AdaptiveAvgPool2d((pool_h, pool_w))

        # define conv block: (conv -> ReLU -> max-pool) x 4
        # The conv layers stay registered both as attributes and inside this
        # Sequential, which keeps the state_dict keys of saved checkpoints valid.
        self.conv_layers = nn.Sequential(
            self.conv1,
            self.activation,
            self.pool,
            self.conv2,
            self.activation,
            self.pool,
            self.conv3,
            self.activation,
            self.pool,
            self.conv4,
            self.activation,
            self.pool
        )

        # FULLY CONNECTED LAYERS

        self.fc1 = nn.Linear(pool_h * pool_w * self.conv4.out_channels, 256)
        nn.init.kaiming_normal_(self.fc1.weight, mode='fan_in', nonlinearity='relu')
        self.fc2 = nn.Linear(256, num_classes)
        nn.init.kaiming_normal_(self.fc2.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        """Map a batch of ``(N, 2, H, W)`` inputs to ``(N, num_classes)`` logits."""
        # through conv layers
        out = self.conv_layers(x)

        # adaptive pooling to standardise feature map size
        out = self.adaptive_pooling(out)

        # flatten feature maps, keeping the batch dimension
        out = torch.flatten(out, start_dim=1)

        # through connected layers
        out = self.activation(self.fc1(out))
        out = self.fc2(out)
        return out

In [6]:
# sanity test model with forward pass
model = VoiceCNN().to(device)

def test_forward(x):
    """Run one (input, label) dataset item through the global model and return its output."""
    # add a leading batch dimension, since the model expects batched input
    batched_input = x[0].unsqueeze(0).to(device)
    print(batched_input.size())
    return model(batched_input)

output = test_forward(test_dataset[0])
# softmax over the class dimension keeps the shape, so this prints (batch, classes)
print(F.softmax(output, dim = 1).shape)

torch.Size([1, 2, 257, 1663])
torch.Size([1, 108])


## loading data

In [7]:
# Peek at one test sample. Axis 2 of the input is the one whose size differs
# between samples and is padded by the collate function below.
# (The variable is deliberately not called `test`: that name is used for the
# evaluation function later on, and re-running this cell would clobber it.)
sample_input, _ = test_dataset[0]
print(sample_input.shape[2])

1663


In [8]:
# hyperparam
set_batch_size = 32

# custom collate function as input STFTs are of differing lengths
def collate_pad(batch, label_map=None):
    """Collate variable-length samples into one padded batch.

    Args:
        batch: non-empty list of ``(tensor, label)`` pairs. Tensors must agree
            in every dimension except axis 2, which may differ in length.
        label_map: mapping from ``(label,)`` 1-tuples to class indices.
            Defaults to the notebook-level ``label_dict``.

    Returns:
        ``(X, y)`` where ``X`` stacks the inputs, zero-padded at the end of the
        last axis to the longest sample and ordered longest-first, and ``y`` is
        the matching float one-hot label matrix of shape
        ``(len(batch), len(label_map))``.
    """
    if label_map is None:
        label_map = label_dict

    # sort by length in descending order; sorted() leaves the caller's list untouched
    ordered = sorted(batch, key=lambda item: item[0].shape[2], reverse=True)
    max_length = ordered[0][0].shape[2]

    # pad all inputs to same length (F.pad with a 2-tuple pads only the last axis)
    X = torch.stack([
        F.pad(tensor, (0, max_length - tensor.shape[2]))
        for tensor, _ in ordered
    ])

    # convert labels to one-hot stack
    dict_labels = torch.tensor([label_map[(label,)] for _, label in ordered])
    y = F.one_hot(dict_labels, len(label_map)).float()

    return X, y

# dataloaders: both pad variable-length samples via collate_pad; only the training set is shuffled
loader_options = dict(batch_size = set_batch_size, collate_fn = collate_pad)
train_loader = DataLoader(train_dataset, shuffle = True, **loader_options)
test_loader = DataLoader(test_dataset, shuffle = False, **loader_options)

## training model

In [12]:
# def training func
def train(dataloader, model, loss_fn, optimiser, device=None):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        dataloader: yields ``(X, y)`` batches; ``dataloader.dataset`` must
            support ``len()`` (used only for the progress printout).
        model: module to optimise; it is switched to training mode.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar loss.
        optimiser: optimiser already bound to ``model``'s parameters.
        device: where to move each batch. Defaults to the device of the
            model's first parameter, so the function does not depend on a
            notebook-level ``device`` variable.

    Prints the loss of every 100th batch together with a running sample count.
    """
    if device is None:
        device = next(model.parameters()).device

    length = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # clear gradients *before* backprop so that anything left over from an
        # earlier backward() call cannot leak into this update
        optimiser.zero_grad()

        # pred error
        hx = model(X)
        loss = loss_fn(hx, y)

        # backprop
        loss.backward()
        optimiser.step()

        if batch % 100 == 0:
            loss_value, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{length:>5d}]")

# def test func
def test(dataloader, model, loss_fn):
    length = len(dataloader.dataset)
    batch_length = len(dataloader)
    model.eval()
    
    test_loss, total_correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            
            hx = model(X)
            test_loss += loss_fn(hx, y).item()
            
            total_correct += (hx.argmax(1) == y.argmax(1)).type(torch.float).sum().item()
        test_loss /= batch_length
        total_correct /= length
        print(f"Test Error: \n Accuracy: {(100 * total_correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

# sanity test: pull just the first batch from the test loader and report its shapes

first_batch = next(iter(test_loader), None)
if first_batch is not None:
    X, y = first_batch
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")

Shape of X [N, C, H, W]: torch.Size([32, 2, 257, 3786])
Shape of y: torch.Size([32, 108]) torch.float32


In [19]:
# hyperparams
learning_rate = 0.002
weight_decay = 1e-6
epochs = 10

# loss and optimiser
loss_fn = nn.CrossEntropyLoss()
optimiser = optim.Adam(
    model.parameters(),
    lr = learning_rate,
    weight_decay = weight_decay,
)

# multiply the learning rate by 0.8 after every epoch
scheduler = optim.lr_scheduler.ExponentialLR(optimiser, gamma = 0.8)

# one epoch = train pass, learning-rate decay, evaluation, then a per-epoch checkpoint
for epoch in range(epochs):
    print(f"\nEpoch {epoch}:")
    train(train_loader, model, loss_fn, optimiser)
    scheduler.step()
    test(test_loader, model, loss_fn)
    checkpoint_path = f"comp_model_params_ep_{epoch}.pth"
    torch.save(model.state_dict(), checkpoint_path)


Epoch 0:
loss: 0.171932  [   32/35098]
loss: 0.864804  [ 3232/35098]
loss: 0.767665  [ 6432/35098]
loss: 0.341065  [ 9632/35098]
loss: 0.273837  [12832/35098]
loss: 0.245040  [16032/35098]
loss: 0.255948  [19232/35098]
loss: 0.447391  [22432/35098]
loss: 0.387948  [25632/35098]
loss: 0.349256  [28832/35098]
loss: 0.376904  [32032/35098]
Test Error: 
 Accuracy: 86.6%, Avg loss: 0.462848 


Epoch 1:
loss: 0.190572  [   32/35098]
loss: 0.247172  [ 3232/35098]
loss: 0.213811  [ 6432/35098]
loss: 0.289042  [ 9632/35098]
loss: 0.163893  [12832/35098]
loss: 0.269561  [16032/35098]
loss: 0.278807  [19232/35098]
loss: 0.225892  [22432/35098]
loss: 0.442743  [25632/35098]
loss: 0.155530  [28832/35098]
loss: 0.405546  [32032/35098]
Test Error: 
 Accuracy: 87.0%, Avg loss: 0.441864 


Epoch 2:
loss: 0.260344  [   32/35098]
loss: 0.147010  [ 3232/35098]
loss: 0.043669  [ 6432/35098]
loss: 0.079303  [ 9632/35098]
loss: 0.293366  [12832/35098]
loss: 0.123277  [16032/35098]
loss: 0.040464  [19232/350

In [16]:
#torch.save(model.state_dict(), "comp_model_params.pth")

In [20]:
# continue training with the existing model, optimiser and scheduler state;
# checkpoints from this run are written under a separate "further" file prefix
for epoch in range(epochs):
    print(f"\nEpoch {epoch}:")
    train(train_loader, model, loss_fn, optimiser)
    scheduler.step()
    test(test_loader, model, loss_fn)
    checkpoint_path = f"comp_model_params_further_ep_{epoch}.pth"
    torch.save(model.state_dict(), checkpoint_path)


Epoch 0:
loss: 0.105944  [   32/35098]
loss: 0.019909  [ 3232/35098]
loss: 0.052351  [ 6432/35098]
loss: 0.003501  [ 9632/35098]
loss: 0.074927  [12832/35098]
loss: 0.027165  [16032/35098]
loss: 0.115724  [19232/35098]
loss: 0.048907  [22432/35098]
loss: 0.053381  [25632/35098]
loss: 0.013843  [28832/35098]
loss: 0.001674  [32032/35098]
Test Error: 
 Accuracy: 92.9%, Avg loss: 0.293395 


Epoch 1:
loss: 0.052659  [   32/35098]
loss: 0.022348  [ 3232/35098]
loss: 0.152860  [ 6432/35098]
loss: 0.082621  [ 9632/35098]
loss: 0.048076  [12832/35098]
loss: 0.062924  [16032/35098]
loss: 0.176050  [19232/35098]
loss: 0.050535  [22432/35098]
loss: 0.013892  [25632/35098]
loss: 0.022687  [28832/35098]
loss: 0.125794  [32032/35098]
Test Error: 
 Accuracy: 92.8%, Avg loss: 0.291956 


Epoch 2:
loss: 0.029710  [   32/35098]
loss: 0.031510  [ 3232/35098]
loss: 0.008850  [ 6432/35098]
loss: 0.004483  [ 9632/35098]
loss: 0.027906  [12832/35098]
loss: 0.012686  [16032/35098]
loss: 0.028374  [19232/350