In [1]:
import os
import sys
import time
import math
import random
import bisect
import shutil
import PIL
import gc

import librosa
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as T
import torchaudio

from sklearn.model_selection import train_test_split

from tqdm.notebook import tqdm
from collections import OrderedDict
from multiprocessing import Pool

import warnings
warnings.filterwarnings("ignore")

In [2]:
class func_pbar:
    def __init__(self, target_iters, pbar=None):
        from functools import wraps
        from inspect import isawaitable
        self.wr = wraps
        self.is_async = isawaitable
        self.target_iters = target_iters
        self.first_call = True
        self.pbar_func = pbar

    def init_pbar(self):
        if not self.pbar_func:
            from tqdm import tqdm
            self.pbar = tqdm(total=self.target_iters, leave=False)
        else:
            self.pbar = self.pbar_func(total=self.target_iters, leave=False)
    
    def __call__(self, fn):
        @self.wr(fn)
        async def async_wrapper(*args, **kwargs):
            if self.first_call:
                self.first_call = False
                self.init_pbar()
            result = await fn(*args, **kwargs)
            self.pbar.update(1)
            return result
        
        @self.wr(fn)
        def wrapper(*args, **kwargs):
            if self.first_call:
                self.first_call = False
                self.init_pbar()
            result = fn(*args, **kwargs)
            self.pbar.update(1)
            return result
        
        if self.is_async(fn):
            return async_wrapper
        else:
            return wrapper

In [3]:
def fix_dataset(path, new_path):
    music_downloaded = os.listdir(path)
    fixed = list(map(lambda x: x[:-3]+'wav', music_downloaded))
    shutil.copytree(path, new_path)
    [os.rename(new_path+'/'+src, new_path+'/'+dst) for src, dst in zip(music_downloaded, fixed)]

In [4]:
class BasicModule(nn.Module):
    def __init__(self):
        super(BasicModule, self).__init__()
        self.model_name = str(type(self))

    def load(self, path):
        self.load_state_dict(torch.load(path))

    def save(self, name=None):
        prefix = 'check_points/' + self.model_name +name+ '/'
        if not os.path.isdir(prefix):
            os.mkdir(prefix)
        name = time.strftime(prefix + '%m%d_%H:%M:%S.pth')
        print('model name', name.split('/')[-1] )
        torch.save(self.state_dict(), name)
        torch.save(self.state_dict(), prefix+'latest.pth')
        return name
    
    def get_optimizer(self, lr, weight_decay):
        return torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
    
    def load_latest(self, notes):
        path = 'check_points/' + self.model_name +notes+ '/latest.pth'
        self.load_state_dict(torch.load(path))

In [5]:
class CQTNet(BasicModule):
    def __init__(self, emb_size=300):
        super().__init__()
        self.emb_size = emb_size
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(1, 32, kernel_size=(12, 3), dilation=(1, 1), padding=(6, 0), bias=False)),
            ('norm0', nn.BatchNorm2d(32)), ('relu0', nn.ReLU(inplace=True)),
            ('conv1', nn.Conv2d(32, 64, kernel_size=(13, 3), dilation=(1, 2), bias=False)),
            ('norm1', nn.BatchNorm2d(64)), ('relu1', nn.ReLU(inplace=True)),
            ('pool1', nn.MaxPool2d((1, 2), stride=(1, 2), padding=(0, 1))),

            ('conv2', nn.Conv2d(64, 64, kernel_size=(13, 3), dilation=(1, 1), bias=False)),
            ('norm2', nn.BatchNorm2d(64)), ('relu2', nn.ReLU(inplace=True)),
            ('conv3', nn.Conv2d(64, 64, kernel_size=(3, 3), dilation=(1, 2), bias=False)),
            ('norm3', nn.BatchNorm2d(64)), ('relu3', nn.ReLU(inplace=True)),
            ('pool3', nn.MaxPool2d((1, 2), stride=(1, 2), padding=(0, 1))),

            ('conv4', nn.Conv2d(64, 128, kernel_size=(3, 3), dilation=(1, 1), bias=False)),
            ('norm4', nn.BatchNorm2d(128)), ('relu4', nn.ReLU(inplace=True)),
            ('conv5', nn.Conv2d(128, 128, kernel_size=(3, 3), dilation=(1, 2), bias=False)),
            ('norm5', nn.BatchNorm2d(128)), ('relu5', nn.ReLU(inplace=True)),
            ('pool5', nn.MaxPool2d((1, 2), stride=(1, 2), padding=(0, 1))),

            ('conv6', nn.Conv2d(128, 256, kernel_size=(3, 3), dilation=(1, 1), bias=False)),
            ('norm6', nn.BatchNorm2d(256)), ('relu6', nn.ReLU(inplace=True)),
            ('conv7', nn.Conv2d(256, 256, kernel_size=(3, 3), dilation=(1, 2), bias=False)),
            ('norm7', nn.BatchNorm2d(256)), ('relu7', nn.ReLU(inplace=True)),
            ('pool7', nn.MaxPool2d((1, 2), stride=(1, 2), padding=(0, 1))),

            ('conv8', nn.Conv2d(256, 512, kernel_size=(3, 3), dilation=(1, 1), bias=False)),
            ('norm8', nn.BatchNorm2d(512)), ('relu8', nn.ReLU(inplace=True)),
            ('conv9', nn.Conv2d(512, 512, kernel_size=(3, 3), dilation=(1, 2), bias=False)),
            ('norm9', nn.BatchNorm2d(512)), ('relu9', nn.ReLU(inplace=True)),
        ]))
        self.pool = nn.AdaptiveMaxPool2d((1, 1))
        self.fc0 = nn.Linear(512, emb_size)

    def forward(self, song1, samples1=None, song2=None, samples2=None):
        inputs = list(filter(lambda x: x is not None, [song1, samples1, song2, samples2]))
        outputs = []
        for x in inputs:
            shape = x.shape
            if len(shape) == 3:
                x = x.view(1, *shape)
            elif len(shape) == 5:
                x = x.view(-1, *shape[2:])
            N = x.size()[0]
            x = self.features(x)  # [N, 512, 57, 2~15]
            x = self.pool(x)
            x = x.view(N, -1)
            feature = self.fc0(x)
            if len(shape) == 5:
                feature = feature.view(*shape[:2], -1)
            outputs.append(feature)
        return tuple(outputs)

In [6]:
music_path = '/kaggle/input/last-fm-us-music/music'
data_path = '/kaggle/input/last-fm-us-music-data/music_data'
#fixed_path = '/kaggle/music'

In [7]:
#fix_dataset(music_path, fixed_path)

In [8]:
#@func_pbar(len(os.listdir(new_path)), tqdm)
def preprocess(filename: str, new_filename: str, cqt_time_reduction=20) -> None:
    song, sr = librosa.load(filename)
    cqt = np.abs(librosa.cqt(y=song, sr=sr))
    height, length = cqt.shape
    cqt_compressed = cqt[:, :(length//cqt_time_reduction)*cqt_time_reduction].reshape(height, -1, cqt_time_reduction).mean(axis=2)
    np.save(new_filename, cqt_compressed)

def preprocess_path(path: str, out_path: str, cqt_time_reduction = 20, threads: int = 1) -> None:
    file_list = os.listdir(path)
    new_file_list = list(map(lambda x: out_path + '/' * int(out_path[-1] != '/') + x[:x.rfind('.')] + '.npy', file_list))
    file_list = list(map(lambda x: path + '/' * int(path[-1] != '/') + x, file_list))
    os.makedirs(out_path, exist_ok=True)
    pbar = func_pbar(len(file_list), pbar=tqdm)
    if threads == 1:
        for file, new_file in zip(file_list, new_file_list):\
            pbar(preprocess)(file, new_file, cqt_time_reduction)
            #preprocess(file, new_file, cqt_time_reduction)
    elif threads < 1:
        raise ValueError("Threads must be a positive integer")
    else:
        pool = Pool(threads)
        #pool.starmap(pbar(preprocess), list(zip(file_list, new_file_list, [cqt_time_reduction]*len(file_list))))
        pool.starmap(preprocess, list(zip(file_list, new_file_list, [cqt_time_reduction]*len(file_list))))
        pool.close()
        pool.join()

In [9]:
def cut_data(data, out_length, pad_to=None):
    if out_length is not None:
        if data.shape[0] > out_length:
            max_offset = data.shape[0] - out_length
            offset = np.random.randint(max_offset)
            data = data[offset:(out_length+offset),:]
        else:
            offset = out_length - data.shape[0]
            data = np.pad(data, ((0,offset),(0,0)), "constant")
    if data.shape[0] < 200:
        offset = 200 - data.shape[0]
        data = np.pad(data, ((0,offset),(0,0)), "constant")
        
    if pad_to:
        data = np.pad(data, ((0,pad_to-data.shape[0]),(0,0)), "constant")
        
    return data

def cut_data_front(data, out_length):
    if out_length is not None:
        if data.shape[0] > out_length:
            max_offset = data.shape[0] - out_length
            offset = 0
            data = data[offset:(out_length+offset),:]
        else:
            offset = out_length - data.shape[0]
            data = np.pad(data, ((0,offset),(0,0)), "constant")
    if data.shape[0] < 200:
        offset = 200 - data.shape[0]
        data = np.pad(data, ((0,offset),(0,0)), "constant")
    return data

def shorter(feature, mean_size=2):
    length, height  = feature.shape
    new_f = np.zeros((int(length/mean_size),height),dtype=np.float64)
    for i in range(int(length/mean_size)):
        new_f[i,:] = feature[i*mean_size:(i+1)*mean_size,:].mean(axis=0)
    return new_f

def change_speed(data, l=0.7, r=1.5): # change data.shape[0]
    new_len = int(data.shape[0]*np.random.uniform(l,r))
    maxx = np.max(data)+1
    data0 = PIL.Image.fromarray((data*255.0/maxx).astype(np.uint8))
    transform = transforms.Compose([
        transforms.Resize(size=(new_len,data.shape[1])), 
    ])
    new_data = transform(data0)
    return np.array(new_data)/255.0*maxx

def SpecAugment(data):
    F = 24
    f = np.random.randint(F)
    f0 = np.random.randint(84-f)
    data[f0:f0+f,:]*=0
    return data

class CQT(Dataset):
    def __init__(self, file_list, mode='train', out_length=None, n_samples=1):
        #self.data_path = data_path
        self.mode = mode
        self.file_list = file_list
        #self.file_list = list(map(lambda x: data_path + '/' * int(data_path[-1] != '/') + x, os.listdir(data_path)))
        self.out_length = out_length
        self.n_samples = n_samples
        self.transforms = {
            'train_song': T.Compose([
                lambda x : x.T,
                lambda x : x.astype(np.float32) / (np.max(np.abs(x))+ 1e-6),
                lambda x : cut_data(x, self.out_length),
                lambda x : torch.Tensor(x),
                lambda x : x.permute(1,0).unsqueeze(0),
            ]),
            'train_sample': T.Compose([
                lambda x : x.T,
                lambda x : x.astype(np.float32) / (np.max(np.abs(x))+ 1e-6),
                lambda x : cut_data(x, np.random.randint(10, self.out_length), self.out_length),
                lambda x : torch.Tensor(x),
                lambda x : x.permute(1,0).unsqueeze(0),
            ])
        }
        
    def __getitem__(self, index):
        index2 = np.random.choice(np.concatenate([np.arange(index), np.arange(index + 1, len(self.file_list))]))
        songs, samples = [], []
        for i, index in enumerate((index, index2)):
            filename = self.file_list[index]
            data = np.load(filename)
            songs.append(self.transforms['train_song'](data))
            samples.append([])
            for j in range(self.n_samples):
                samples[i].append(self.transforms['train_sample'](data))
            samples[i] = torch.stack(samples[i])
        return (songs[0], samples[0], songs[1], samples[1], index)
    
    def __len__(self):
        return len(self.file_list)

In [10]:
data_dir = '/kaggle/data/'

In [11]:
#preprocess_path(music_path, data_dir, threads=1)

In [12]:
#!zip -r data.zip /kaggle/data

In [13]:
#shutil.rmtree(data_dir)

In [14]:
def train(model, dataloaders, n_epochs, criterion, metric, optimizer, scheduler=None, device='cpu', backup_freq=None, init_epoch=0):
    model.to(device)
    print('Computing datasets')
    with torch.no_grad():
        for phase in tqdm(('train', 'val')):
            datasets[phase] = torch.zeros((len(dataloaders[phase].dataset), model.module.emb_size)).to(device)
            for item in tqdm(dataloaders[phase]):
                dataloaders[item[2]] = model(item[0].to(device))[0].detach()
    print('Training')
    for epoch in tqdm(range(init_epoch, n_epochs)):
        print(f'Epoch {epoch+1}')
        for phase in ('train', 'val'):
            mean_loss = 0
            mean_loss_pos_neg = 0
            mean_accuracy = 0
            num_iters = 0
            
            if phase == 'train':
                model.train()
            else:
                model.eval()
                
            for batch in tqdm(dataloaders[phase]):
                song1, samples1, song2, samples2, index = tuple(map(lambda x: x.to(device), batch))
                
                if phase == 'train':
                    song1_emb, samples1_emb, song2_emb, samples2_emb = model(song1, samples1, song2, samples2)
                else:
                    with torch.no_grad():
                        song1_emb, samples1_emb, song2_emb, samples2_emb = model(song1, samples1, song2, samples2)
                
                loss = criterion(song1_emb, samples1_emb, song2_emb, samples2_emb)
                
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()
                    
                datasets[phase][index, :] = song1_emb
                mean_accuracy += metric(samples1_emb, datasets[phase], index)
                mean_loss += loss
                mean_loss_pos_neg += loss_pos_neg_fixed(song1_emb, samples1_emb, song2_emb, samples2_emb)
                num_iters += 1
            print(f'{phase} loss: {mean_loss/num_iters}')
            print(f'{phase} loss positive: {mean_loss_pos_neg[0]/num_iters}')
            print(f'{phase} loss negative: {mean_loss_pos_neg[1]/num_iters}')
            print(f'{phase} accuracy: {mean_accuracy/num_iters}')
        
        if scheduler:
            if type(scheduler) == torch.optim.lr_scheduler.ReduceLROnPlateau:
                scheduler.step(mean_loss/num_iters)
            else:
                scheduler.step()
        
        if backup_freq:
            if (epoch+1) % backup_freq == 0:
                serializer.backup(model, f'model_ep{epoch+1}.pth')
                
        del song1, samples1, song2, samples2, song1_emb, samples1_emb, song2_emb, samples2_emb, index, loss, mean_loss, mean_loss_pos_neg, mean_accuracy, num_iters
        gc.collect()

In [15]:
len(os.listdir(data_path))

6341

In [16]:
file_list = list(map(lambda x: data_path + '/' * int(data_path[-1] != '/') + x, os.listdir(data_path)))

In [17]:
files_train, files_test = train_test_split(
    file_list,
    test_size=0.2,
    random_state=42
)

In [18]:
def triplet_loss(song1=None, samples1=None, song2=None, samples2=None, positive_rate=0.5):
    def loss_fn(song1, samples1=None, song2=None, samples2=None):
        loss = 0
        if samples1 is not None and song2 is not None and samples2 is not None:
            norm_coef = 1 / (song1.shape[0] * samples1.shape[1] * (samples1.shape[1] + 1))
            loss += (torch.pairwise_distance(song1[:, None, :], samples1).sum() +
                     torch.pairwise_distance(samples1[:, :, None, :], samples1[:, None, :, :]).sum() / 2 +
                     torch.pairwise_distance(song2[:, None, :], samples2).sum() +
                     torch.pairwise_distance(samples2[:, :, None, :], samples2[:, None, :, :]).sum() / 2)\
                    * positive_rate * norm_coef
            norm_coef =  song1.shape[0] * (samples1.shape[1] + 1) ** 2
            loss += 1 / (torch.pairwise_distance(song1, song2).sum() +
                         torch.pairwise_distance(samples1, samples2).sum() +
                         torch.pairwise_distance(song1[:, None, :], samples2).sum() +
                         torch.pairwise_distance(song1[:, None, :], samples2).sum())\
                    * (1 - positive_rate) * norm_coef
        return loss
    if song1 is None and samples1 is None and song2 is None and samples2 is None:
        return loss_fn
    return loss_fn(song1, samples1, song2, samples2)

In [19]:
def triplet_loss_fixed(song1=None, samples1=None, song2=None, samples2=None, positive_rate=0.5):
    def loss_fn(song1, samples1=None, song2=None, samples2=None):
        loss = 0
        if samples1 is not None and song2 is not None and samples2 is not None:
            norm_coef = 1 / 2
            loss += (torch.pairwise_distance(song1[:, None, :], samples1).mean() +
                     torch.pairwise_distance(song2[:, None, :], samples2).mean())\
                    * positive_rate * norm_coef
            loss += 1 / torch.pairwise_distance(song1, song2).mean()\
                    * (1 - positive_rate)
        return loss
    if song1 is None and samples1 is None and song2 is None and samples2 is None:
        return loss_fn
    return loss_fn(song1, samples1, song2, samples2)

In [30]:
def triplet_loss_fixed(song1=None, samples1=None, song2=None, samples2=None, positive_rate=0.5):
    def loss_fn(song1, samples1=None, song2=None, samples2=None):
        loss = 0
        if samples1 is not None and song2 is not None and samples2 is not None:
            norm_coef = 1
            loss += torch.pairwise_distance(song1[:, None, :], samples1).mean()\
                    * positive_rate * norm_coef
            loss += 1 / torch.pairwise_distance(song1, song2).mean()\
                    * (1 - positive_rate)
        return loss
    if song1 is None and samples1 is None and song2 is None and samples2 is None:
        return loss_fn
    return loss_fn(song1, samples1, song2, samples2)

In [20]:
def loss_pos_neg(song1, samples1, song2, samples2):
    loss = torch.zeros((2))
    norm_coef = 1 / (song1.shape[0] * samples1.shape[1] * (samples1.shape[1] + 1))
    loss[0] = (torch.pairwise_distance(song1[:, None, :], samples1).sum() +
             torch.pairwise_distance(samples1[:, :, None, :], samples1[:, None, :, :]).sum() / 2 +
             torch.pairwise_distance(song2[:, None, :], samples2).sum() +
             torch.pairwise_distance(samples2[:, :, None, :], samples2[:, None, :, :]).sum() / 2)\
            * norm_coef
    norm_coef = song1.shape[0] * (samples1.shape[1] + 1) ** 2
    loss[1] = 1 / (torch.pairwise_distance(song1, song2).sum() +
                 torch.pairwise_distance(samples1, samples2).sum() +
                 torch.pairwise_distance(song1[:, None, :], samples2).sum() +
                 torch.pairwise_distance(song1[:, None, :], samples2).sum())\
            * norm_coef
    return loss

In [21]:
def loss_pos_neg_fixed(song1, samples1, song2, samples2):
    loss = torch.zeros((2))
    norm_coef = 1 / 2
    loss[0] = (torch.pairwise_distance(song1[:, None, :], samples1).mean() +
             torch.pairwise_distance(song2[:, None, :], samples2).mean())\
            * norm_coef
    loss[1] = 1 / torch.pairwise_distance(song1, song2).mean()
    return loss

In [22]:
def accuracy(batch, all_data, batch_indices):
    batch = batch[:, 0, :].view(batch.shape[0], -1)
    indices = torch.pairwise_distance(batch[:, None, :], all_data[None, :, :]).min(dim=-1).indices
    accuracy = torch.mean((batch_indices == indices).type(torch.FloatTensor))
    return accuracy

In [23]:
n_gpus = torch.cuda.device_count()

In [24]:
datasets = {
    'train': CQT(files_train, mode='train', out_length=394, n_samples=2),
    'val': CQT(files_test, mode='val', out_length=394, n_samples=2)
}
dataloaders = {
    'train': DataLoader(datasets['train'], batch_size=24 * n_gpus),
    'val': DataLoader(datasets['val'], batch_size=24 * n_gpus)
}

In [25]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [26]:
model = CQTNet(300).to(device)
if n_gpus > 1:
    model = torch.nn.parallel.DataParallel(model, device_ids=list(range(n_gpus)), dim=0)

In [27]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
#scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5], gamma=0.2)

In [31]:
train(model, dataloaders, 10, triplet_loss_fixed(positive_rate=0.5), accuracy, optimizer, device=device)

Computing datasets


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Training


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1


  0%|          | 0/106 [00:00<?, ?it/s]

OutOfMemoryError: Caught OutOfMemoryError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_34/3828923816.py", line 48, in forward
    x = self.features(x)  # [N, 512, 57, 2~15]
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/pooling.py", line 166, in forward
    return F.max_pool2d(input, self.kernel_size, self.stride,
  File "/opt/conda/lib/python3.10/site-packages/torch/_jit_internal.py", line 488, in fn
    return if_false(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 791, in _max_pool2d
    return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 118.00 MiB. GPU 0 has a total capacty of 14.75 GiB of which 11.06 MiB is free. Process 3042 has 14.73 GiB memory in use. Of the allocated memory 14.01 GiB is allocated by PyTorch, and 513.32 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF


In [30]:
torch.cuda.empty_cache()