In [None]:
# Third-party packages used below (versions are not pinned): pl_bolts provides the
# ResNet-18 encoder/decoder, gdown fetches the dataset, and openpyxl is presumably
# needed for the .xlsx export at the end of the notebook.
!pip install pytorch-lightning-bolts;
!pip install gdown;
!pip install openpyxl;

In [4]:
# Download the spectrogram dataset archive from Google Drive by file id
# (presumably saved as spectogram_dataset.zip, which is unzipped below).
!gdown 1Wj7Hl5d94iWFzmnoMdofnWq6XQJvTZor

In [None]:
# Extract the archive quietly (-q); presumably this creates ./spectogram_dataset (base_path below).
!unzip -q spectogram_dataset.zip

In [6]:
import os
import shutil
import gdown
from functools import partial
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
from PIL import Image
import random
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torch.optim import Adam

import torchvision
from torchvision import transforms
import torchvision.datasets as datasets

from pl_bolts.models.autoencoders.components import (
    resnet18_decoder,
    resnet18_encoder,
)
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger

# Fix the global random seeds (Python, NumPy, torch) so training and sampled latents are reproducible.
pl.seed_everything(1234)

## Data Preprocessing

In [7]:
class SpectogramDataset(Dataset):
    '''
    A PyTorch dataset of grayscale spectrogram images stored in one directory.

    Args:
        path_to_images: directory containing the image files.
        images_names: optional list of file names (relative to ``path_to_images``)
            to load. When ``None`` (the default) every entry of the directory is
            used, in sorted order so indices are reproducible across runs.
            An explicitly passed empty list yields an empty dataset.
        transforms_func: optional callable applied to each PIL image
            (e.g. a ``torchvision.transforms.Compose``).

    Each item is the image converted to single-channel ('L') mode and passed
    through ``transforms_func`` when one is given.
    '''
    def __init__(self, path_to_images, images_names=None, transforms_func=None):
        self.path_to_images = path_to_images
        # Test against None (not truthiness) so an explicit empty list is honoured
        # instead of silently falling back to the whole directory.
        if images_names is None:
            # os.listdir order is arbitrary; sort for deterministic indexing.
            images_names = sorted(os.listdir(path_to_images))
        self.images_names = images_names
        self.transforms = transforms_func

    def __len__(self):
        return len(self.images_names)

    def __getitem__(self, idx):
        full_path_to_image = os.path.join(self.path_to_images, self.images_names[idx])
        # Read a grayscale image; the context manager closes the file handle,
        # while convert() returns an independent in-memory image.
        with Image.open(full_path_to_image) as raw_image:
            image = raw_image.convert('L')

        if self.transforms is not None:
            image = self.transforms(image)

        return image

In [8]:
base_path = './spectogram_dataset'
train_path, valid_path, test_path = (
    os.path.join(base_path, split) for split in ('train', 'valid', 'test')
)

# ToTensor gives a float tensor; every spectrogram is then resized to 216x216
# (this size must match the VAE's input_height).
base_image_transforms = [
    transforms.ToTensor(),
    transforms.Resize((216, 216)),
]

train_ds, valid_ds, test_ds = (
    SpectogramDataset(split_path, transforms_func=transforms.Compose(base_image_transforms))
    for split_path in (train_path, valid_path, test_path)
)

In [9]:
# Number of spectrograms in each split.
for split_label, split_ds in (('Train', train_ds), ('Valid', valid_ds), ('Test', test_ds)):
    print(f'{split_label} len: {len(split_ds)}')

In [10]:
class VAE(pl.LightningModule):
    '''
    Variational autoencoder with a ResNet-18 encoder/decoder from pl_bolts,
    adapted to images with ``in_channels`` channels.

    Args:
        enc_out_dim: size of the encoder's output feature vector (input to the
            mu / log-variance heads).
        latent_dim: dimensionality of the latent code z.
        input_height: spatial size the decoder reconstructs to.
        in_channels: number of image channels (1 for grayscale spectrograms).
        lr: Adam learning rate.

    Logged metrics (per epoch) are ``elbo_*``, ``kl_*`` and ``recon_loss_*`` with
    suffix ``train`` / ``valid``. Note that ``elbo_*`` is mean(KL - log p(x|z)),
    i.e. the *negative* ELBO that is minimised.
    '''
    def __init__(self, enc_out_dim=512, latent_dim=128, input_height=28, in_channels=1, lr=1e-3):
        super().__init__()

        self.save_hyperparameters()

        self.lr = lr
        # encoder, decoder
        self.encoder = resnet18_encoder(False, False)
        self.decoder = resnet18_decoder(
            latent_dim=latent_dim,
            input_height=input_height,
            first_conv=False,
            maxpool1=False
        )
        # Swap the encoder's conv1 (in_channels -> 64) and the decoder's conv1
        # (64 -> in_channels) so the model consumes and produces images with
        # `in_channels` channels.
        self.encoder.conv1 = nn.Conv2d(in_channels, out_channels=64,
                                       kernel_size=(3, 3), stride=(1, 1),
                                       padding=(1, 1), bias=False)
        self.decoder.conv1 = nn.Conv2d(in_channels=64, out_channels=in_channels,
                                       kernel_size=(3, 3), stride=(1, 1),
                                       padding=(1, 1), bias=False)

        # distribution parameters of q(z|x)
        self.fc_mu = nn.Linear(enc_out_dim, latent_dim)
        self.fc_var = nn.Linear(enc_out_dim, latent_dim)

        # learned log standard deviation of the Gaussian likelihood p(x|z)
        self.log_scale = nn.Parameter(torch.tensor([0.0]))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def gaussian_likelihood(self, x_hat, logscale, x):
        '''Log-likelihood of ``x`` under Normal(x_hat, exp(logscale)), summed over dims 1-3 (one value per sample).'''
        scale = torch.exp(logscale)
        dist = torch.distributions.Normal(x_hat, scale)
        # measure prob of seeing image under p(x|z)
        log_pxz = dist.log_prob(x)
        return log_pxz.sum(dim=(1, 2, 3))

    def kl_divergence(self, z, mu, std):
        '''
        Single-sample Monte Carlo estimate of KL(q(z|x) || p(z)) with p = N(0, 1),
        evaluated at the sample ``z`` and summed over the last (latent) dimension.
        '''
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        log_qzx = q.log_prob(z)
        log_pz = p.log_prob(z)
        return (log_qzx - log_pz).sum(-1)

    def _posterior_params(self, x):
        '''Run the encoder and return (mu, std) of q(z|x).'''
        x_encoded = self.encoder(x)
        mu, log_var = self.fc_mu(x_encoded), self.fc_var(x_encoded)
        std = torch.exp(log_var / 2)
        return mu, std

    def encode(self, x, sample=True):
        '''
        Map a batch of images to latent codes without tracking gradients.

        Args:
            x: input batch; expected on the model's device (callers move it).
            sample: if True (default) draw z ~ q(z|x), so repeated calls differ;
                if False return the posterior mean, which is deterministic.

        The module runs in eval mode for this call and its previous train/eval
        mode is restored afterwards, so calling encode mid-training is safe.
        '''
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                mu, std = self._posterior_params(x)
                z = self.sample_z_from_q(mu, std) if sample else mu
        finally:
            self.train(was_training)
        return z

    def sample_z_from_q(self, mu, std):
        '''Draw a reparameterised (differentiable) sample z ~ N(mu, std).'''
        q = torch.distributions.Normal(mu, std)
        return q.rsample()

    def _shared_step(self, batch, stage):
        '''
        Compute the loss for one batch and log its parts under ``*_{stage}`` keys.

        Returns:
            (x_hat, loss) where loss = mean(KL - log p(x|z)).
        '''
        x = batch
        mu, std = self._posterior_params(x)
        z = self.sample_z_from_q(mu, std)
        x_hat = self.decoder(z)

        # "recon_loss" is the log-likelihood log p(x|z) (higher is better).
        recon_loss = self.gaussian_likelihood(x_hat, self.log_scale, x)
        kl = self.kl_divergence(z, mu, std)
        elbo = (kl - recon_loss).mean()

        self.log_dict({
            f'elbo_{stage}': elbo,
            f'kl_{stage}': kl.mean(),
            f'recon_loss_{stage}': recon_loss.mean(),
        },
            on_step=False, on_epoch=True)
        return x_hat, elbo

    def training_step(self, batch, batch_idx):
        _, elbo = self._shared_step(batch, 'train')
        return elbo

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'valid')

In [11]:
# Training hyper-parameters
LR = 1e-3        # Adam learning rate
EPOCHS = 100     # max_epochs for the Lightning trainer
BS = 16          # batch size for the DataLoaders
num_workers = 2  # DataLoader worker processes

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

In [12]:
def _make_loader(dataset, shuffle):
    '''Wrap ``dataset`` in a DataLoader using the notebook-wide batch size and worker count.'''
    return DataLoader(dataset, batch_size=BS, shuffle=shuffle, num_workers=num_workers)


train_loader = _make_loader(train_ds, shuffle=True)   # reshuffled every epoch
valid_loader = _make_loader(valid_ds, shuffle=False)  # fixed order for evaluation
test_loader = _make_loader(test_ds, shuffle=False)

In [13]:
def build_mnist_loaders(batch_size, root='/files/', image_size=24, train_fraction=0.8):
    '''
    Optional sanity-check data: MNIST train/valid/test DataLoaders.

    Kept as an uncalled function, not a bare string literal that Jupyter echoes
    as the cell output, so the notebook's spectrogram loaders are never
    overwritten. Call it explicitly to try the VAE on MNIST, e.g.
    ``train_loader, valid_loader, test_loader = build_mnist_loaders(BS)``.

    Args:
        batch_size: batch size for all three loaders.
        root: download/cache directory for torchvision's MNIST.
        image_size: side length the digits are resized to.
        train_fraction: share of the MNIST training set kept for training;
            the remainder becomes the validation split.

    Returns:
        (train_loader, valid_loader, test_loader)
    '''
    mnist_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # standard MNIST mean / std normalisation
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        torchvision.transforms.Resize((image_size, image_size)),
    ])
    full_train_set = torchvision.datasets.MNIST(root, train=True, download=True,
                                                transform=mnist_transform)
    test_set = torchvision.datasets.MNIST(root, train=False, download=True,
                                          transform=mnist_transform)

    # Random split
    train_set_size = int(len(full_train_set) * train_fraction)
    valid_set_size = len(full_train_set) - train_set_size
    train_set, valid_set = random_split(full_train_set, [train_set_size, valid_set_size])

    print(f'Train set: {len(train_set)}')
    print(f'Valid set: {len(valid_set)}')
    print(f'Test set: {len(test_set)}')

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    # Evaluation order does not need shuffling (matches the spectrogram loaders).
    valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader, test_loader

## Train

In [14]:
# CSVLogger writes the logged metrics under logs/my_exp_name/.
logger = CSVLogger("logs", name="my_exp_name")

# input_height must match the Resize((216, 216)) in base_image_transforms.
vae = VAE(input_height=216, in_channels=1, latent_dim=128, lr=LR)

# Lightning calls the CUDA accelerator 'gpu'; 'cpu' is passed through as-is.
accelerator = 'gpu' if device == 'cuda' else device
trainer = pl.Trainer(
    accelerator=accelerator,
    max_epochs=EPOCHS,
    logger=logger,
)

In [15]:
# Train for EPOCHS epochs; the validation loader is evaluated by Lightning
# (by default once per epoch) and all metrics go to the CSVLogger above.
trainer.fit(vae, train_loader,valid_loader)

## Training Metrics

In [16]:
def merge_logs(logs):
    '''
    Collapse Lightning CSVLogger metrics into one row per epoch.

    Validation and training epoch metrics appear in separate rows of
    metrics.csv, each with NaN in the other split's columns. Grouping by
    'epoch' and taking the first non-null value per column pairs them
    regardless of row order. It also does not fail when an epoch is
    incomplete (e.g. an interrupted run); missing values stay NaN.

    Args:
        logs: DataFrame read from metrics.csv.

    Returns:
        DataFrame with columns elbo_valid, kl_valid, recon_loss_valid, epoch,
        step, elbo_train, kl_train, recon_loss_train, one row per epoch in
        ascending epoch order.
    '''
    columns = ['elbo_valid', 'kl_valid', 'recon_loss_valid', 'epoch', 'step',
               'elbo_train', 'kl_train', 'recon_loss_train']
    merged_logs = (
        logs.groupby('epoch', as_index=False, sort=True)
        .first()
        .reindex(columns=columns)
    )
    return merged_logs

In [17]:
def plot_stats(merged_logs):
    '''
    Plot train (solid blue) vs. valid (dashed red) curves per epoch in a 2x2 grid:
    ELBO loss (top-left), KL (bottom-left) and reconstruction term (bottom-right).
    The unused top-right panel is hidden instead of being drawn as an empty frame.

    Args:
        merged_logs: one row per epoch with an 'epoch' column and
            '{elbo,kl,recon_loss}_{train,valid}' columns (as built by merge_logs).
    '''
    fig, axs = plt.subplots(2, 2, figsize=(8, 6), dpi=150)

    # (axes, metric prefix, title, y-label)
    panels = [
        (axs[0][0], 'elbo', 'Elbo', 'Elbo loss'),
        (axs[1][0], 'kl', 'KL', 'KL loss'),
        (axs[1][1], 'recon_loss', 'Recon loss', 'Recon loss'),
    ]
    for ax, metric, title, ylabel in panels:
        ax.plot(merged_logs['epoch'], merged_logs[f'{metric}_train'], 'b-', label=f'{metric}_train')
        ax.plot(merged_logs['epoch'], merged_logs[f'{metric}_valid'], 'r--', label=f'{metric}_valid')
        ax.set_title(title)
        ax.set(xlabel='epoch', ylabel=ylabel)
        ax.legend()

    # Left column: keep x tick labels only on the bottom panel.
    axs[0][0].label_outer()
    axs[1][0].label_outer()
    # The recon term is a log-likelihood (higher is better); the axis is inverted,
    # presumably so that "down" means improvement as in the other panels.
    axs[1][1].invert_yaxis()
    # Nothing is plotted top-right; hide the empty axes.
    axs[0][1].axis('off')

    plt.show()



In [18]:
# Read the metrics written by *this* run's logger. A hard-coded version_0 path
# would silently plot an older run once CSVLogger has created version_1, version_2, ...
logs = pd.read_csv(os.path.join(logger.log_dir, 'metrics.csv'))
merged_logs = merge_logs(logs)
plot_stats(merged_logs)

## Testing: ranking voxes by latent distance to beats

In [27]:
def plot_distances(beats_names, base_path, similarity_metric ):
    '''
    Rank every original vox in ``base_path`` by latent distance to each beat.

    Args:
        beats_names: beat spectrogram file names inside ``base_path``. The song
            name is everything before the last occurrence of 'beat', minus the
            trailing separator ('baba_yaga_beat_section_1.png' -> 'baba_yaga').
        base_path: directory holding both the beats and the voxes.
        similarity_metric: callable(encoded_voxes, encoded_beat) returning one
            value per vox; rows are sorted ascending, so lower means closer.

    Returns:
        One pandas Styler per beat (empty list if ``beats_names`` is empty).
        Each has ``file_name`` and ``distance`` columns sorted by distance,
        rows whose file name contains the beat's song name highlighted green,
        and the caption ``Distances to <song>``.

    Raises:
        ValueError: if ``base_path`` contains no original vox files.

    Only "original" voxes are used: file names containing 'vox' and none of the
    augmentation markers 'pitch_scale', 'white_noise', 'time_stretch'.
    Note: ``vae.encode`` samples z, so distances vary slightly between calls.
    '''
    if not beats_names:
        return []

    image_transform = transforms.Compose(base_image_transforms)
    vae.eval()

    def _encode_all(dataset, batch_size):
        # Encode every item of `dataset` with the global VAE, batch by batch.
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        return torch.cat([vae.encode(batch.to(device)) for batch in tqdm(loader)])

    # Step 1. Encode the beats (all in a single batch)
    beats_ds = SpectogramDataset(base_path, beats_names, image_transform)
    encoded_beats = _encode_all(beats_ds, len(beats_names))

    # Step 2. Encode the original (non-augmented) voxes in the base directory,
    # in sorted order so the tables are reproducible.
    augmentation_markers = ('pitch_scale', 'white_noise', 'time_stretch')
    vox_names = sorted(
        file_name for file_name in os.listdir(base_path)
        if 'vox' in file_name and not any(marker in file_name for marker in augmentation_markers)
    )
    if not vox_names:
        raise ValueError(f'No original vox spectrograms found in {base_path}')
    voxes_ds = SpectogramDataset(base_path, vox_names, image_transform)
    encoded_voxes = _encode_all(voxes_ds, BS)

    # Step 3, 4. Calculate distances, prettify the final dataframe
    def highlight_same_song(row, song_name):
        # Applied row-wise (axis=1); `in` on the file-name string is a substring test.
        colour = 'green' if song_name in row['file_name'] else 'white'
        return [f'background-color: {colour}'] * len(row)

    stylers_list = []
    for beat_name, encoded_beat in zip(beats_names, encoded_beats):
        # Distance between this beat section and every vox
        distances = similarity_metric(encoded_voxes, encoded_beat[None, :])
        distances_df = (
            pd.DataFrame({'file_name': vox_names, 'distance': distances.cpu().numpy()},
                         columns=['file_name', 'distance'])
            .sort_values(by=['distance'])
            .reset_index(drop=True)
        )
        # Split on the LAST 'beat' so song names that themselves contain 'beat'
        # stay intact (splitting on the first one yields '', which matches every row).
        song_name = beat_name.rsplit('beat', 1)[0][:-1]
        styler = (
            distances_df.style
            .apply(partial(highlight_same_song, song_name=song_name), axis=1)
            .set_caption(f'Distances to {song_name}')
        )
        stylers_list.append(styler)

    return stylers_list

#### On Train

In [20]:
cos_similiarity = nn.CosineSimilarity(eps=1e-6)


def cos_distance(input1, input2):
    '''Cosine distance ``1 - cos(input1, input2)`` along dim 1 (one value per row; broadcasts).'''
    similarity = cos_similiarity(input1, input2)
    return 1 - similarity

In [21]:
# Query beat spectrograms; each file must exist inside train_path (see the call below).
beats_names_from_train = ['baba_yaga_beat_section_1.png',
                         'be_afraid_my_enemy_beat_section_3.png',
                         'body_minor_beat_section_4.png',
                         'dance_alone_beat_section_1.png']

In [28]:
# One styled distance table per query beat, ranking the training-set voxes by cosine distance.
result_train = plot_distances(beats_names_from_train, train_path, cos_distance)

In [44]:
# Export each distance table to '<caption>.xlsx' in the working directory.
for distance_table in result_train:
    excel_path = f'{distance_table.caption}.xlsx'
    distance_table.to_excel(excel_path)

#### On Test + Valid

In [None]:
# Merge the content of two directories for convenient access
def copy_dir_content(src, trg):
    '''
    Copy every entry of directory ``src`` into the existing directory ``trg``
    with ``shutil.copy2`` (file metadata preserved). A file whose name already
    exists in ``trg`` is overwritten.
    '''
    for file_name in tqdm(os.listdir(src)):
        shutil.copy2(os.path.join(src, file_name), trg)


valid_test_path = os.path.join(base_path, 'test_valid')
if not os.path.exists(valid_test_path):
    # Build the merged directory under a temporary name and rename it only after
    # both copies finished, so an interrupted run is not mistaken for a complete one.
    partial_path = valid_test_path + '_partial'
    shutil.rmtree(partial_path, ignore_errors=True)
    os.mkdir(partial_path)
    copy_dir_content(valid_path, partial_path)
    copy_dir_content(test_path, partial_path)
    os.rename(partial_path, valid_test_path)

print(len(os.listdir(valid_test_path)))

In [None]:
# Query beat spectrograms; each file must exist inside valid_test_path (see the call below).
beats_names_from_test_valid = ['body_minor_beat_section_2.png',
                         'i_got_everything_beat_section_2.png',
                         'tail_about_bogatir_beat_section_2.png',
                         'Game_Over_beat_section_1.png']

In [None]:
# Rank the merged test+valid voxes against each query beat by cosine distance in latent space.
result_test_valid = plot_distances(beats_names_from_test_valid, valid_test_path, cos_distance)