In [None]:
# Stretch the notebook container to the full browser width so wide outputs
# (model printouts, long log lines) are not clipped.
# `IPython.display` is the public import location for `display` / `HTML`;
# importing `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
################################
# Imports
################################

import pandas as pd
import h5py
import numpy as np

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

import glob
import re
from multiprocessing import cpu_count
from multiprocessing import Pool

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


import os
import sys
import subprocess
import random
from random import sample
from tqdm import tqdm
from functools import partial

from sklearn.metrics import confusion_matrix


import librosa

In [None]:
################################
# configs
################################

#################
# spectrogram
#################

# Summary statistics of the flute and piano spectrogram values.
# Hard-coded here; the normalisation tables below are derived from them.
spectrogramStats = {
    'flute': {
        'mean': -6.438913345336914,
        'median': -6.118330955505371,
        'stdDev': 4.377377986907959,
        'max': 1.8442230224609375,
        'min': -39.0754280090332,
    },
    'piano': {
        'mean': -6.015857219696045,
        'median': -5.299488544464111,
        'stdDev': 4.420877456665039,
        'max': 1.5825170278549194,
        'min': -40.520179748535156,
    },
}

# standardizationStyleOptions = normal, logNormal, uniform

# value to SUBTRACT from the data, per instrument and per style:
#   normal / logNormal -> the mean, uniform -> the midpoint of [min, max]
centerOffset = {
    instrument: {
        'normal': stats['mean'],
        'logNormal': stats['mean'],
        'uniform': (stats['min'] + stats['max'])/2,
    }
    for instrument, stats in spectrogramStats.items()
}

# value to DIVIDE the centred data by, per instrument and per style:
#   normal -> 3 standard deviations, logNormal -> 1.1*max - mean,
#   uniform -> half of the [min, max] range
divFactor = {
    instrument: {
        'normal': 3*stats['stdDev'],
        'logNormal': (1.1*stats['max'] - stats['mean']),
        'uniform': 1*(stats['max'] - stats['min'])/2,
    }
    for instrument, stats in spectrogramStats.items()
}


####################
# training params
####################
STANDARDIZATION_STYLE = 'uniform'   # one of: normal, logNormal, uniform
BATCH_SIZE = 4
NUM_WORKERS = 2
NUM_EPOCHS = 200
LEARNING_RATE = 0.0001
LAMBDA_CYCLE = 10                   # weight of the cycle-consistency loss
LAMBDA_IDENTITY = 5                 # weight of the identity loss
GRADIENT_PENALTY = 0
ADAM_BETA1 = 0
ADAM_BETA2 = 0.9
NUM_RESIDUALS = 6                   # residual blocks per generator

####################
# checkpoint options
####################
SAVE_CHECKPOINTS = True

####################
# find GPU device
####################
# first CUDA device when available, CPU otherwise
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print (f'Device = {DEVICE}, # of CUDA devices = {torch.cuda.device_count()}')


###################
# add fileTag
###################
# suffix encoding the hyper-parameters; embedded in every checkpoint file name
fileTag = (
    f'lr_{LEARNING_RATE}_cyc_{LAMBDA_CYCLE}_id_{LAMBDA_IDENTITY}'
    f'_b1_{ADAM_BETA1}_b2_{ADAM_BETA2}_numRes_{NUM_RESIDUALS}'
    f'_gp_{GRADIENT_PENALTY}_stdStyle_{STANDARDIZATION_STYLE}'
)


In [None]:
################################
# discriminator
################################

class Block(nn.Module):
    """Discriminator stage: 4x4 reflect-padded convolution -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        stride: stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        normalisation = nn.InstanceNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)

        # kept under the attribute name `conv` so state-dict keys stay `conv.0.*`
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply convolution, normalisation and activation to `x`."""
        out = self.conv(x)
        return out


class Discriminator(nn.Module):
    """Convolutional discriminator that emits one sigmoid score per input image.

    Pipeline: a stride-2 stem convolution, one stride-2 `Block` per entry of
    `features[1:]`, then a head made of a stride-1 convolution to one channel,
    LeakyReLU, flatten and a linear layer.

    Args:
        in_channels: channels of the input image.
        features: channel widths; `features[0]` is the stem width and each
            following entry adds one `Block`.
        numFlatFeatures: size of the flattened map fed to the linear layer.
            It depends on the input resolution; the default of 300 matches
            the 20 x 15 map produced for a 336 x 256 input with the default
            `features`.
    """

    def __init__(self, in_channels=1, features=[16, 32, 64, 128], numFlatFeatures = 300):
        super().__init__()

        # stem: halves the spatial size (336 x 256 -> 168 x 128), no normalisation
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size = 4, stride = 2, padding = 1, bias=True, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        # one stride-2 Block for every consecutive pair of channel widths
        stages = [
            Block(width_in, width_out, stride=2)
            for width_in, width_out in zip(features[:-1], features[1:])
        ]

        # head: reduce to a single channel, flatten, map to one score
        head = nn.Sequential(
            nn.Conv2d(features[-1], 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(in_features = numFlatFeatures, out_features = 1),
        )

        self.model = nn.Sequential(*stages, head)

    def forward(self, x):
        """Return the sigmoid of the head output for the batch `x`."""
        hidden = self.initial(x)
        score = self.model(hidden)
        return torch.sigmoid(score)
        
    
def test():
    """Smoke test: print the discriminator and the shape of its output for a random batch."""
    dummy_batch = torch.randn((5,1,336,256))
    net = Discriminator(in_channels=1)
    print (net)
    scores = net(dummy_batch)
    print(scores.shape)

test()

In [None]:
################################
# generator
################################

class EncoderConvBlock(nn.Module):
    """Reflect-padded convolution -> InstanceNorm -> optional activation.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        use_act: when False the block ends with `nn.Identity` (no activation).
        use_tanh: when an activation is used, pick `nn.Tanh` instead of
            in-place `nn.ReLU`.
        **kwargs: forwarded to `nn.Conv2d` (e.g. kernel_size, stride, padding).
    """

    def __init__(self, in_channels, out_channels, use_act=True, use_tanh = False, **kwargs):
        super().__init__()

        # choose the trailing non-linearity
        if not use_act:
            activation = nn.Identity()
        elif use_tanh:
            activation = nn.Tanh()
        else:
            activation = nn.ReLU(inplace=True)

        self.encoderConv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs),
            nn.InstanceNorm2d(out_channels),
            activation,
        )

    def forward(self, x):
        """Run `x` through the convolution, normalisation and activation."""
        out = self.encoderConv(x)
        return out


class DecoderConvBlock(nn.Module):
    """Upsampling stage: 2x nearest-neighbour upsample followed by a size-preserving 3x3 convolution.

    The reflection pad of 1 combined with an unpadded 3x3 convolution keeps
    the upsampled spatial size, so the block doubles height and width.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        use_act: when False the block ends with `nn.Identity` (no activation).
        use_tanh: when an activation is used, pick `nn.Tanh` instead of
            in-place `nn.ReLU`.
    """

    def __init__(self, in_channels, out_channels, use_act=True, use_tanh = False):
        super().__init__()

        # choose the trailing non-linearity
        if not use_act:
            activation = nn.Identity()
        elif use_tanh:
            activation = nn.Tanh()
        else:
            activation = nn.ReLU(inplace=True)

        stages = [
            nn.Upsample(scale_factor = 2, mode='nearest'),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0),
            nn.InstanceNorm2d(out_channels),
            activation,
        ]
        self.decoderConv = nn.Sequential(*stages)

    def forward(self, x):
        """Upsample `x` and apply the convolution, normalisation and activation."""
        out = self.decoderConv(x)
        return out
    

class ResidualBlock(nn.Module):
    """Residual unit: two size-preserving 3x3 `EncoderConvBlock`s plus a skip connection.

    The first inner block has an activation, the second does not; the input
    is added to the result of the pair.

    Args:
        channels: channel count, identical for input and output.
        use_tanh: forwarded to the inner blocks to select Tanh over ReLU.
    """

    def __init__(self, channels, use_tanh = False):
        super().__init__()

        # both convolutions are 3x3, stride 1, padding 1 -> spatial size preserved
        same_conv = dict(kernel_size=3, stride = 1, padding=1)
        activated = EncoderConvBlock(channels, channels, use_act = True, use_tanh = use_tanh, **same_conv)
        linear = EncoderConvBlock(channels, channels, use_act = False, use_tanh = use_tanh, **same_conv)

        self.block = nn.Sequential(activated, linear)

    def forward(self, x):
        """Return `x` plus the output of the two-convolution branch."""
        residual = self.block(x)
        return x + residual


class Generator(nn.Module):
    """Encoder / residual / decoder generator with a tanh output.

    Layout: 7x7 size-preserving stem -> two stride-2 encoder stages ->
    `num_residuals` residual blocks -> two 2x upsampling decoder stages ->
    7x7 size-preserving output convolution -> tanh.

    Args:
        img_channels: channels of the input and of the generated image.
        use_tanh: forwarded to the encoder, residual and decoder blocks to
            select Tanh over ReLU as their activation.
        num_features: channel width after the stem; the encoder doubles it
            twice and the decoder halves it twice.
        num_residuals: number of residual blocks at the bottleneck.
    """

    def __init__(self, img_channels, use_tanh = False, num_features = 16, num_residuals=6):
        super().__init__()

        # channel widths at full, half and quarter resolution
        base, mid, deep = num_features, num_features*2, num_features*4

        # stem: size-preserving 7x7 convolution, img_channels -> num_features
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, base, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.ReLU(inplace=True),
        )

        # encoder: output size per stage = floor((W - F + 2P) / S) + 1
        # for a 336 x 256 (H x W) input: 336 x 256 -> 168 x 128 -> 84 x 64
        encoder_stages = [
            EncoderConvBlock(width_in, width_out, use_act = True, use_tanh = use_tanh, kernel_size=3, stride=2, padding=1)
            for width_in, width_out in ((base, mid), (mid, deep))
        ]
        self.down_blocks = nn.ModuleList(encoder_stages)

        # bottleneck: size-preserving residual blocks, shape stays 84 x 64 with 4*num_features channels
        bottleneck = [ResidualBlock(deep, use_tanh = use_tanh) for _ in range(num_residuals)]
        self.residual_blocks = nn.Sequential(*bottleneck)

        # decoder: each stage doubles the spatial size, 84 x 64 -> 168 x 128 -> 336 x 256
        decoder_stages = [
            DecoderConvBlock(width_in, width_out, use_act = True, use_tanh = use_tanh)
            for width_in, width_out in ((deep, mid), (mid, base))
        ]
        self.up_blocks = nn.ModuleList(decoder_stages)

        # output: size-preserving 7x7 convolution back to img_channels
        self.last = nn.Conv2d(base, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        """Translate the batch `x`; the result is squashed into (-1, 1) by tanh."""
        feats = self.initial(x)

        for down in self.down_blocks:
            feats = down(feats)

        feats = self.residual_blocks(feats)

        for up in self.up_blocks:
            feats = up(feats)

        return torch.tanh(self.last(feats))

    
def test():
    """Smoke test: print the generator and the shape of its output for a random batch."""
    dummy_batch = torch.randn((5,1,336,256))
    net = Generator(img_channels=1, num_features = 16, num_residuals=9)
    print (net)
    generated = net(dummy_batch)
    print(generated.shape)


test()

In [None]:
# data loader

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that serves every file of one directory through `np.load`."""

    def __init__(self, rootDir):
        """Store the directory and take a snapshot of its listing.

        Args:
            rootDir: directory whose entries are the samples; the order is
                whatever `os.listdir` returns.
        """
        self.rootDir = rootDir
        self.fileList = os.listdir(rootDir)

    def __len__(self):
        """Number of entries found in `rootDir`."""
        return len(self.fileList)

    def __getitem__(self, index):
        """Load and return the array stored in the `index`-th listed file."""
        sample_name = self.fileList[index]
        sample_path = os.path.join(self.rootDir, sample_name)
        return np.load(sample_path)

In [None]:
##########################
# dataset class
##########################

class PianoFluteDataset(Dataset):
    """Unpaired flute / piano dataset of `.npy` arrays.

    Each item is a `(flute, piano)` pair. The dataset length is the larger
    of the two directory sizes and the shorter list is cycled with a modulo,
    so the pairing is arbitrary. The flute listing is shuffled once with a
    fixed seed.

    Every array gets a leading channel axis and is standardised as
    `(x - offset) / div`, with the constants taken from the module-level
    `centerOffset` / `divFactor` tables for the requested style.

    Args:
        root_piano: directory holding the piano `.npy` files.
        root_flute: directory holding the flute `.npy` files.
        standardizationStyle: 'normal', 'logNormal' or 'uniform'. Any other
            value falls back to 'normal'.
    """

    def __init__(self, root_piano, root_flute, standardizationStyle):
        self.root_piano = root_piano
        self.root_flute = root_flute

        # NOTE: this re-seeds the global `random` generator so the flute
        # shuffle is the same on every run.
        random.seed(2)
        self.piano_images = os.listdir(root_piano)
        # list the directory once so the shuffled list and its length agree
        flute_files = os.listdir(root_flute)
        self.flute_images = sample(flute_files, len(flute_files))

        self.length_dataset = max(len(self.piano_images), len(self.flute_images))

        self.piano_dataset_length = len(self.piano_images)
        self.flute_dataset_length = len(self.flute_images)

        self.standardizationStyle = standardizationStyle

        # Resolve the style by table lookup. The previous if/else chain sent
        # every non-'uniform' style, including 'normal', to the logNormal
        # constants, so 'normal' could never be selected.
        if standardizationStyle in centerOffset['flute']:
            style = standardizationStyle
        else:
            # default to normal for standardizing the data
            print (f'Unknown standardization style {standardizationStyle!r}, defaulting to normal...')
            style = 'normal'

        print (f'Using {style} assumption...')

        self.offsetFlute = centerOffset['flute'][style]
        self.offsetPiano = centerOffset['piano'][style]

        self.divFlute = divFactor['flute'][style]
        self.divPiano = divFactor['piano'][style]

    def __len__(self):
        """Length of the longer of the two file lists."""
        return self.length_dataset

    def __getitem__(self, index):
        """Return the standardised `(flute, piano)` pair for `index`.

        The index wraps around each file list independently.
        """
        flute_image = self.flute_images[index % self.flute_dataset_length]
        piano_image = self.piano_images[index % self.piano_dataset_length]

        flute_path = os.path.join(self.root_flute, flute_image)
        piano_path = os.path.join(self.root_piano, piano_image)

        flute_img = np.load(flute_path)
        piano_img = np.load(piano_path)

        # add extra dimension for the single channel
        flute_img = flute_img[None, :]
        piano_img = piano_img[None, :]

        # standardize the scale
        flute_img = (flute_img - self.offsetFlute) / self.divFlute
        piano_img = (piano_img - self.offsetPiano) / self.divPiano

        return flute_img, piano_img

In [None]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the state dicts of `model` and `optimizer` to `filename`.

    The file holds a dict with the keys "model" and "optimizer" and is
    written with `torch.save`.
    """
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    torch.save({"model": model_state, "optimizer": optimizer_state}, filename)

In [None]:
################################
# paths to directories
################################

# training data lives relative to the notebook's working directory
curDir = os.getcwd()
trainSetDir = f'{curDir}/../../dataSuperSet/processedData/trainSet'
fluteTrainSetDir = f'{trainSetDir}/flute/cqtChunks'
pianoTrainSetDir = f'{trainSetDir}/piano/cqtChunks'

# create directories to store checkpoint outputs
checkPointDir = f'{curDir}/checkPoints'
checkPointModelDir = f'{checkPointDir}/models'
checkPointImageDir = f'{checkPointDir}/images'
checkPointLossTrackingDir = f'{checkPointDir}/lossTracking'

# os.makedirs replaces the shell `mkdir -p` calls: same "create parents, ignore
# existing" behaviour, but it also works when the path contains spaces or shell
# metacharacters and on platforms without `mkdir -p`.
for _checkPointSubDir in (checkPointDir, checkPointModelDir, checkPointImageDir, checkPointLossTrackingDir):
    os.makedirs(_checkPointSubDir, exist_ok=True)

In [None]:
################################
# datasets
################################

# unpaired piano / flute dataset, standardised with the configured style
dataset = PianoFluteDataset(
    root_piano=pianoTrainSetDir,
    root_flute=fluteTrainSetDir,
    standardizationStyle=STANDARDIZATION_STYLE,
)

# shuffled mini-batch loader over that dataset
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

In [None]:
#########################################
# instantiate networks and optimizers
#########################################

# discriminators: one per domain
disc_P = Discriminator(in_channels=1).to(DEVICE)   # piano
disc_F = Discriminator(in_channels=1).to(DEVICE)   # flute

# generators: gen_P produces piano, gen_F produces flute
gen_P = Generator(img_channels=1, num_residuals = NUM_RESIDUALS).to(DEVICE)
gen_F = Generator(img_channels=1, num_residuals = NUM_RESIDUALS).to(DEVICE)

# a single Adam optimizer updates both discriminators together
opt_disc = optim.Adam(
    [*disc_P.parameters(), *disc_F.parameters()],
    lr=LEARNING_RATE,
    betas=(ADAM_BETA1, ADAM_BETA2),
)

# a single Adam optimizer updates both generators together
opt_gen = optim.Adam(
    [*gen_P.parameters(), *gen_F.parameters()],
    lr=LEARNING_RATE,
    betas=(ADAM_BETA1, ADAM_BETA2),
)


# Losses

# L1 distance, used for the cycle-consistency and identity terms
L1 = nn.L1Loss()

# mean squared error, used for the adversarial terms
mse = nn.MSELoss()


In [None]:
################################
# training
################################

# Per-epoch loss curves. During an epoch each entry accumulates the
# sample-weighted batch losses; after the epoch it holds the per-sample average.
generatorLossProgression = np.zeros(NUM_EPOCHS)
discriminatorLossProgression = np.zeros(NUM_EPOCHS)
identityLossProgression = np.zeros(NUM_EPOCHS)
cycleLossProgression = np.zeros(NUM_EPOCHS)

for epoch in range(NUM_EPOCHS):

    for idx, (flute, piano) in enumerate(loader):

        # move data to device
        piano = piano.to(DEVICE)
        flute = flute.to(DEVICE)

        # number of samples in this batch (the last batch of an epoch can be smaller)
        batchSamples = flute.size(0)


        ##############################
        # Discriminator training
        ##############################

        # piano generator output
        fake_piano = gen_P (flute)

        # piano discriminator; the fake is detached so this step does not
        # back-propagate into the generator
        D_P_real = disc_P(piano)
        D_P_fake = disc_P(fake_piano.detach())

        # real = 1, fake = 0
        D_P_real_loss = mse(D_P_real, torch.ones_like(D_P_real))
        D_P_fake_loss = mse(D_P_fake, torch.zeros_like(D_P_fake))


        # total piano discriminator loss
        Disc_piano_loss = D_P_real_loss + D_P_fake_loss

        # flute generator output
        fake_flute = gen_F(piano)

        # flute discriminator
        D_F_real = disc_F(flute)
        D_F_fake = disc_F(fake_flute.detach())

        # real = 1, fake = 0
        D_F_real_loss = mse(D_F_real, torch.ones_like(D_F_real))
        D_F_fake_loss = mse(D_F_fake, torch.zeros_like(D_F_fake))

        # total flute discriminator loss
        Disc_flute_loss = D_F_real_loss + D_F_fake_loss

        # Overall discriminator loss
        D_loss = (Disc_flute_loss + Disc_piano_loss)/2

        # zero out the gradients
        opt_disc.zero_grad()

        # backprop
        D_loss.backward()

        # update discriminator params
        opt_disc.step()

        ##############################
        # Generator training
        ##############################

        # Adversarial Loss for both generators (fakes are NOT detached here,
        # so gradients reach the generators)
        D_P_fake = disc_P(fake_piano)
        D_F_fake = disc_F(fake_flute)

        # generator wants fake to be detected real
        loss_G_F = mse(D_F_fake, torch.ones_like(D_F_fake))
        loss_G_P = mse(D_P_fake, torch.ones_like(D_P_fake))

        # Cycle Loss: piano -> fake flute -> piano, and flute -> fake piano -> flute
        cycle_piano = gen_P(fake_flute)
        cycle_flute = gen_F(fake_piano)
        cycle_piano_loss = L1(piano, cycle_piano)
        cycle_flute_loss = L1(flute, cycle_flute)

        cycle_loss = cycle_flute_loss + cycle_piano_loss

        # Identity Loss: a generator fed its own target domain should change nothing
        identity_flute = gen_F(flute)
        identity_piano = gen_P(piano)
        identity_piano_loss = L1(piano, identity_piano)
        identity_flute_loss = L1(flute, identity_flute)

        identity_loss = identity_flute_loss + identity_piano_loss

        # Overall Generator Loss
        G_loss = loss_G_F + loss_G_P + cycle_loss*LAMBDA_CYCLE + identity_loss*LAMBDA_IDENTITY

        # zero out the gradients
        opt_gen.zero_grad()

        # backprop
        G_loss.backward()

        # update generator params
        opt_gen.step()

        if idx % 100 == 0:

            # print the current losses
            print (f'[epoch = {epoch}, idx = {idx}]:   D_loss = {D_loss.item()}, \t G_loss = {G_loss.item()}, \t identity_loss = {identity_loss.item()}, \t cycle_loss = {cycle_loss.item()}')

        if idx % 1250 == 0:

            # save the first sample of the batch: originals and their translations
            # NOTE(review): the file names do not include the epoch, so every
            # epoch overwrites the samples written by the previous one.
            np.save(f'{checkPointImageDir}/originalPiano__idx_{idx}__{fileTag}.npy', np.squeeze(piano.detach().cpu().numpy()[0,:]))
            np.save(f'{checkPointImageDir}/fakeFlute__idx_{idx}__{fileTag}.npy', np.squeeze(fake_flute.detach().cpu().numpy()[0,:]))
            np.save(f'{checkPointImageDir}/originalFlute__idx_{idx}__{fileTag}.npy', np.squeeze(flute.detach().cpu().numpy()[0,:]))
            np.save(f'{checkPointImageDir}/fakePiano__idx_{idx}__{fileTag}.npy', np.squeeze(fake_piano.detach().cpu().numpy()[0,:]))


        # aggregate loss over all batches. MSELoss / L1Loss return the MEAN over
        # the batch, so each value is weighted by the batch's sample count;
        # dividing the sum by len(dataset) below then gives a true per-sample average.
        discriminatorLossProgression[epoch] += D_loss.item() * batchSamples
        generatorLossProgression[epoch] += G_loss.item() * batchSamples
        identityLossProgression[epoch] += identity_loss.item() * batchSamples
        cycleLossProgression[epoch] += cycle_loss.item() * batchSamples

    # average loss over the length of dataset
    discriminatorLossProgression[epoch] = discriminatorLossProgression[epoch] / len(dataset)
    generatorLossProgression[epoch] = generatorLossProgression[epoch] / len(dataset)
    identityLossProgression[epoch] = identityLossProgression[epoch] / len(dataset)
    cycleLossProgression[epoch] = cycleLossProgression[epoch] / len(dataset)

    # save the model and current losses (files are overwritten every epoch)
    if SAVE_CHECKPOINTS:
        save_checkpoint(gen_P, opt_gen, filename = f'{checkPointModelDir}/genp__{fileTag}.pth.tar')
        save_checkpoint(gen_F, opt_gen, filename = f'{checkPointModelDir}/genf__{fileTag}.pth.tar')
        save_checkpoint(disc_P, opt_disc, filename = f'{checkPointModelDir}/criticp__{fileTag}.pth.tar')
        save_checkpoint(disc_F, opt_disc, filename = f'{checkPointModelDir}/criticf__{fileTag}.pth.tar')

        # save the lossProgressions
        np.save(f'{checkPointLossTrackingDir}/discriminatorLossProgression__{fileTag}.npy', discriminatorLossProgression)
        np.save(f'{checkPointLossTrackingDir}/generatorLossProgression__{fileTag}.npy', generatorLossProgression)
        np.save(f'{checkPointLossTrackingDir}/identityLossProgression__{fileTag}.npy', identityLossProgression)
        np.save(f'{checkPointLossTrackingDir}/cycleLossProgression__{fileTag}.npy', cycleLossProgression)
    
    


