In [13]:
import torch
import torch.nn as nn

from torch.utils import data
from torch.utils.data import DataLoader
from torch.optim import Adam

import numpy as np
import pickle 
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

# Dataset and dataloader

In [14]:
class SpkrEmbDataset(data.Dataset):
    """Dataset class for the Utterances dataset."""

    def __init__(self, pkl_path="/home/yrb/data/ID-DEID_data/voxceleb_spmel/train.pkl"):
        """Initialize and preprocess the Utterances dataset."""
        self.pkl_path = pkl_path
        
        """Load data"""
        self.load_data()
        
        
    def load_data(self):  
        # load train.pkl
        with open(self.pkl_path, "rb") as f:
            meta = pickle.load(f)
        self.dataset = [(sbmt[0], sbmt[1]) for sbmt in meta] # (spkr_id, spkr_emb)
        self.num_spkr = len(self.dataset)
        print('Finished loading the dataset...')

                   
        
    def __getitem__(self, index):
        spkr_id, spkr_emb = self.dataset[index]
        return spkr_id, spkr_emb
    

    def __len__(self):
        """Return the number of spkrs."""
        return self.num_spkr

def get_loader(pkl_path="/home/yrb/data/ID-DEID_data/voxceleb_spmel/train.pkl", batch_size=16, num_workers=0, shuffle=True, drop_last=True):
    """Build and return a data loader."""
    
    dataset = SpkrEmbDataset(pkl_path)
    
    worker_init_fn = lambda x: np.random.seed((torch.initial_seed()) % (2**32))
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=shuffle,
                                  num_workers=num_workers,
                                  drop_last=drop_last,
                                  worker_init_fn=worker_init_fn)
    return data_loader


# VAE model

In [15]:
"""
    A simple implementation of Gaussian MLP Encoder and Decoder
"""


class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.FC_input = nn.Linear(input_dim, hidden_dim)
        self.FC_input2 = nn.Linear(hidden_dim, hidden_dim)

        self.FC_mean = nn.Linear(hidden_dim, latent_dim)
        self.FC_var = nn.Linear(hidden_dim, latent_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

        self.training = True

    def forward(self, x):
        h_ = self.LeakyReLU(self.FC_input(x))
        h_ = self.LeakyReLU(self.FC_input2(h_))

        mean = self.FC_mean(h_)
        # encoder produces mean and log of variance
        # (i.e., parateters of simple tractable normal distribution "q"
        log_var = self.FC_var(h_)
        return mean, log_var

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.FC_input = nn.Linear(input_dim, hidden_dim)
        # self.FC_input2 = nn.Linear(hidden_dim, hidden_dim)

        self.FC_mean = nn.Linear(hidden_dim, latent_dim)
        self.FC_var = nn.Linear(hidden_dim, latent_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

        self.training = True

    def forward(self, x):
        h_ = self.LeakyReLU(self.FC_input(x))
        # h_ = self.LeakyReLU(self.FC_input2(h_))

        mean = self.FC_mean(h_)
        # encoder produces mean and log of variance
        # (i.e., parateters of simple tractable normal distribution "q"
        log_var = self.FC_var(h_)
        return mean, log_var

In [16]:
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.FC_hidden = nn.Linear(latent_dim, hidden_dim)
        self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_output = nn.Linear(hidden_dim, output_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        h = self.LeakyReLU(self.FC_hidden(x))
        h = self.LeakyReLU(self.FC_hidden2(h))

        x_hat = torch.tanh(self.FC_output(h))
        return x_hat
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.FC_hidden = nn.Linear(latent_dim, hidden_dim)
        # self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_output = nn.Linear(hidden_dim, output_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        h = self.LeakyReLU(self.FC_hidden(x))
        # h = self.LeakyReLU(self.FC_hidden2(h))

        x_hat = torch.tanh(self.FC_output(h))
        return x_hat

In [17]:
class VAE(nn.Module):
    def __init__(self, x_dim, hidden_dim, latent_dim, DEVICE):
        super(VAE, self).__init__()
        self.Encoder = Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
        self.Decoder = Decoder(latent_dim=latent_dim, hidden_dim = hidden_dim, output_dim = x_dim)
        self.DEVICE = DEVICE

    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var).to(self.DEVICE)  # sampling epsilon
        z = mean + var*epsilon  # reparameterization trick
        return z

    def forward(self, x):
        mean, log_var = self.Encoder(x)
        # takes exponential function (log var -> var)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        x_hat = self.Decoder(z)

        return x_hat, mean, log_var

# Settings

In [18]:
# ENV settings
cuda = True
DEVICE = torch.device("cuda" if cuda else "cpu")

# Model Hyperparameters
batch_size = 32

x_dim = 256
hidden_dim = 384
latent_dim = 64

lr = 1e-3

epochs = 60

# Training

In [19]:
model = VAE(x_dim, hidden_dim, latent_dim, DEVICE).to(DEVICE)

In [20]:
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
mse = nn.MSELoss(reduction='sum')
l1 = nn.L1Loss(reduction='sum')
def loss_function2(x, x_hat, mean, log_var):
    # cosine similarity loss
    cos_distance_loss = 100*(1-cos(x, x_hat)).sum()
    mse_loss = mse(x, x_hat)
    reconstruction_loss = cos_distance_loss + mse_loss
    KLD = - 0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return reconstruction_loss, KLD

def loss_function1(x, x_hat, mean, log_var):
    # cosine similarity loss
    cos_distance_loss = 200*(1-cos(x, x_hat)).sum()
    l1_loss = l1(x, x_hat)
    reconstruction_loss = cos_distance_loss + l1_loss
    KLD = - 0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return reconstruction_loss, KLD

def loss_function(x, x_hat, mean, log_var):
    mse_loss = mse(x, x_hat)
    reconstruction_loss = 10 * mse_loss
    KLD = - 0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return reconstruction_loss, KLD

optimizer = Adam(model.parameters(), lr=lr)

In [21]:
print("Preparing the data loader...")
vox_dataloader = get_loader(pkl_path="/home/yrb/data/ID-DEID_data/voxceleb_spmel/train.pkl", num_workers=2, batch_size=batch_size)
wsj_dataloader = get_loader(pkl_path="/home/yrb/data/ID-DEID_data/wsj_spmel/train.pkl", num_workers=2, batch_size=batch_size)
vctk_dataloader = get_loader(pkl_path="/home/yrb/data/ID-DEID_data/vctk_spmel/train.pkl", num_workers=2, batch_size=batch_size)

dataloader = vox_dataloader

print("Start training VAE...")
model.train()

for epoch in range(epochs):
    overall_loss, overall_reconst, overall_KLD = 0, 0, 0
    for batch_idx, (spkrids, x) in enumerate(dataloader):
        x = x.to(DEVICE)

        optimizer.zero_grad()

        x_hat, mean, log_var = model(x)
        reconst_loss, KLD = loss_function(x, x_hat, mean, log_var)
        loss = reconst_loss + KLD
        overall_loss += loss.item()
        overall_reconst += reconst_loss.item()
        overall_KLD += KLD.item()

        loss.backward()
        optimizer.step()

    print("\tEpoch", epoch + 1, "complete!",
        "\tAverage Loss: ", overall_loss / (batch_idx*batch_size),
        "\tAverage reconst Loss: ", overall_reconst / (batch_idx*batch_size),
        "\tAverage KLD Loss: ", overall_KLD / (batch_idx*batch_size),)

print("Finish!!")

Preparing the data loader...
Finished loading the dataset...
Finished loading the dataset...
Finished loading the dataset...
Start training VAE...
	Epoch 1 complete! 	Average Loss:  4.989629334637097 	Average reconst Loss:  4.956616018499647 	Average KLD Loss:  0.033013316170711605
	Epoch 2 complete! 	Average Loss:  4.236941071493285 	Average reconst Loss:  4.228967549545424 	Average KLD Loss:  0.007973513274919242
	Epoch 3 complete! 	Average Loss:  4.1526430036340445 	Average reconst Loss:  4.1484784207173755 	Average KLD Loss:  0.0041645761604221275
	Epoch 4 complete! 	Average Loss:  4.114597948534148 	Average reconst Loss:  4.111635301794324 	Average KLD Loss:  0.002962652219658984
	Epoch 5 complete! 	Average Loss:  4.098131466124739 	Average reconst Loss:  4.0965201152222495 	Average KLD Loss:  0.0016113486739673785
	Epoch 6 complete! 	Average Loss:  4.081492188785758 	Average reconst Loss:  4.079987690917084 	Average KLD Loss:  0.0015044958438790803
	Epoch 7 complete! 	Average Los

# Eval

In [None]:
vox_dataloader = get_loader(pkl_path="/home/yrb/data/ID-DEID_data/voxceleb_spmel/train.pkl", num_workers=2, batch_size=batch_size, shuffle=False, drop_last=False)
wsj_dataloader = get_loader(pkl_path="/home/yrb/data/ID-DEID_data/wsj_spmel/train.pkl", num_workers=2, batch_size=batch_size, shuffle=False, drop_last=False)
vctk_dataloader = get_loader(pkl_path="/home/yrb/data/ID-DEID_data/vctk_spmel/train.pkl", num_workers=2, batch_size=batch_size, shuffle=False, drop_last=False)


def eval(model, dataloader):
    model.eval()
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    cos_sim = 0.0
    mse = 0.0
    n_sample = 0.0
    with torch.no_grad():
        for batch_idx, (spkrids, x) in enumerate(dataloader):
            x = x.to(DEVICE)
            n_sample += x.shape[0]
            x_hat, _, _ = model(x)
            cos_sim += cos(x_hat, x).sum()
            mse += ((x-x_hat)**2).sum()
    return {"cos_sim": cos_sim/n_sample, "mse": mse/n_sample}

print("wsj:")
print("model:", eval(model, wsj_dataloader))

print("vox:")
print("model:", eval(model, vox_dataloader))

print("vctk:")
print("model:", eval(model, vctk_dataloader))


In [23]:
vox_dataloader = get_loader(pkl_path="/home/yrb/data/ID-DEID_data/voxceleb_spmel/train.pkl", num_workers=2, batch_size=batch_size, shuffle=False, drop_last=False)
wsj_dataloader = get_loader(pkl_path="/home/yrb/data/ID-DEID_data/wsj_spmel/train.pkl", num_workers=2, batch_size=batch_size, shuffle=False, drop_last=False)
vctk_dataloader = get_loader(pkl_path="/home/yrb/data/ID-DEID_data/vctk_spmel/train.pkl", num_workers=2, batch_size=batch_size, shuffle=False, drop_last=False)

model1 = VAE(256, 384, 64, DEVICE).to(DEVICE)
checkpoint = torch.load("/home/yrb/code/ID-DEID/data/vae_model/l1_vae_on_voxceleb.ckpt", map_location=DEVICE)
model1.load_state_dict(checkpoint['model'])

model2 = VAE(256, 384, 64, DEVICE).to(DEVICE)
checkpoint = torch.load("/home/yrb/code/ID-DEID/data/vae_model/vae_on_voxceleb.ckpt", map_location=DEVICE)
model2.load_state_dict(checkpoint['model'])


def eval(model, dataloader):
    model.eval()
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    cos_sim = 0.0
    mse = 0.0
    n_sample = 0.0
    with torch.no_grad():
        for batch_idx, (spkrids, x) in enumerate(dataloader):
            x = x.to(DEVICE)
            n_sample += x.shape[0]
            x_hat, _, _ = model(x)
            cos_sim += cos(x_hat, x).sum()
            mse += ((x-x_hat)**2).sum()
    return {"cos_sim": cos_sim/n_sample, "mse": mse/n_sample}

print("wsj:")
print("model1:", eval(model1, wsj_dataloader))
print("model2:", eval(model2, wsj_dataloader))
print("model:", eval(model, wsj_dataloader))

print("vox:")
print("model1:", eval(model1, vox_dataloader))
print("model2:", eval(model2, vox_dataloader))
print("model:", eval(model, vox_dataloader))

print("vctk:")
print("model1:", eval(model1, vctk_dataloader))
print("model2:", eval(model2, vctk_dataloader))
print("model:", eval(model, vctk_dataloader))


Finished loading the dataset...
Finished loading the dataset...
Finished loading the dataset...
wsj:
model1: {'cos_sim': tensor(0.5309, device='cuda:0'), 'mse': tensor(0.5145, device='cuda:0')}
model2: {'cos_sim': tensor(0.4445, device='cuda:0'), 'mse': tensor(0.8494, device='cuda:0')}
model: {'cos_sim': tensor(0.0331, device='cuda:0'), 'mse': tensor(0.6336, device='cuda:0')}
vox:
model1: {'cos_sim': tensor(0.9164, device='cuda:0'), 'mse': tensor(0.0834, device='cuda:0')}
model2: {'cos_sim': tensor(0.8696, device='cuda:0'), 'mse': tensor(0.4103, device='cuda:0')}
model: {'cos_sim': tensor(0.1941, device='cuda:0'), 'mse': tensor(0.4037, device='cuda:0')}
vctk:
model1: {'cos_sim': tensor(0.5026, device='cuda:0'), 'mse': tensor(0.5190, device='cuda:0')}
model2: {'cos_sim': tensor(0.4478, device='cuda:0'), 'mse': tensor(0.9651, device='cuda:0')}
model: {'cos_sim': tensor(0.0831, device='cuda:0'), 'mse': tensor(0.6302, device='cuda:0')}


# Save model

In [22]:
# torch.save({'model': model.state_dict()}, "/home/yrb/code/ID-DEID/data/vae_model/vae_on_voxceleb.ckpt")
torch.save({'model': model.state_dict()}, "/home/yrb/code/ID-DEID/data/vae_model/vae_wo_cos_on_voxceleb.ckpt")

# Finetune

In [95]:
model = VAE(x_dim, hidden_dim, latent_dim, DEVICE).to(DEVICE)
checkpoint = torch.load("/home/yrb/code/ID-DEID/data/vae_model/l1_vae_on_voxceleb.ckpt", map_location=DEVICE)
model.load_state_dict(checkpoint['model'])

cos = nn.CosineSimilarity(dim=1, eps=1e-6)
mse = nn.MSELoss(reduction='sum')
l1 = nn.L1Loss(reduction='sum')
def loss_function(x, x_hat, mean, log_var):
    # cosine similarity loss
    cos_distance_loss = 100*(1-cos(x, x_hat)).sum()
    mse_loss = mse(x, x_hat)
    reconstruction_loss = cos_distance_loss + mse_loss
    KLD = - 0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return reconstruction_loss, KLD

def loss_function2(x, x_hat, mean, log_var):
    # cosine similarity loss
    cos_distance_loss = 200*(1-cos(x, x_hat)).sum()
    l1_loss = l1(x, x_hat)
    reconstruction_loss = cos_distance_loss + l1_loss
    KLD = - 0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return reconstruction_loss, KLD

optimizer = Adam(model.parameters(), lr=lr*0.1)

print("Preparing the data loader...")
vox_dataloader = get_loader(pkl_path="/home/yrb/data/ID-DEID_data/voxceleb_spmel/train.pkl", num_workers=2, batch_size=batch_size)
wsj_dataloader = get_loader(pkl_path="/home/yrb/data/ID-DEID_data/wsj_spmel/train.pkl", num_workers=2, batch_size=batch_size)
vctk_dataloader = get_loader(pkl_path="/home/yrb/data/ID-DEID_data/vctk_spmel/train.pkl", num_workers=2, batch_size=batch_size)

dataloader = wsj_dataloader

print("Start training VAE...")
model.train()

finetune_epochs = 5
for epoch in range(finetune_epochs):
    overall_loss, overall_reconst, overall_KLD = 0, 0, 0
    for batch_idx, (spkrids, x) in enumerate(dataloader):
        x = x.to(DEVICE)

        optimizer.zero_grad()

        x_hat, mean, log_var = model(x)
        reconst_loss, KLD = loss_function2(x, x_hat, mean, log_var)
        loss = reconst_loss + KLD
        overall_loss += loss.item()
        overall_reconst += reconst_loss.item()
        overall_KLD += KLD.item()

        loss.backward()
        optimizer.step()

    print("\tEpoch", epoch + 1, "complete!",
        "\tAverage Loss: ", overall_loss / (batch_idx*batch_size),
        "\tAverage reconst Loss: ", overall_reconst / (batch_idx*batch_size),
        "\tAverage KLD Loss: ", overall_KLD / (batch_idx*batch_size),)

print("Finish!!")

Preparing the data loader...
Finished loading the dataset...
Finished loading the dataset...
Finished loading the dataset...
Start training VAE...
	Epoch 1 complete! 	Average Loss:  117.84901721660907 	Average reconst Loss:  99.29551227276141 	Average KLD Loss:  18.55350538400503
	Epoch 2 complete! 	Average Loss:  98.447020310622 	Average reconst Loss:  79.60963146503155 	Average KLD Loss:  18.837389139028694
	Epoch 3 complete! 	Average Loss:  87.93128556471605 	Average reconst Loss:  69.55970294658954 	Average KLD Loss:  18.371582177969124
	Epoch 4 complete! 	Average Loss:  82.50406411977914 	Average reconst Loss:  64.76671336247371 	Average KLD Loss:  17.737350463867188
	Epoch 5 complete! 	Average Loss:  78.49950643686148 	Average reconst Loss:  61.14552483191857 	Average KLD Loss:  17.35398167830247
Finish!!


# Generate emb from noise

In [128]:
with torch.no_grad():
    noise = torch.randn(batch_size, 64).to(DEVICE)
    generated_emb = model1.Decoder(noise)

cos_sim = cos(generated_emb, x[4:5, :])
cos_sim, cos_sim.max(), cos_sim.mean()

(tensor([-0.0102,  0.0700,  0.0442, -0.0299, -0.0095, -0.1873, -0.2445,  0.0364,
         -0.1090,  0.0463,  0.2157,  0.0664,  0.0020,  0.1663, -0.1197, -0.2575,
         -0.0822, -0.1179,  0.0825, -0.1582, -0.0775,  0.0583, -0.1722,  0.0544,
          0.0935,  0.1655, -0.0128,  0.2362, -0.0468,  0.0649, -0.1068, -0.1868],
        device='cuda:0'),
 tensor(0.2362, device='cuda:0'),
 tensor(-0.0164, device='cuda:0'))