In [1]:
import os
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Input, GlobalAveragePooling2D, Dropout, Reshape
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
import sklearn.metrics

In [3]:
# Encoder: a randomly initialised ResNet50 trunk (weights=None, include_top=False)
# over single-channel 128x130 inputs, followed by global average pooling and a
# 256-unit ReLU projection. `model` maps an input batch to 256-d embeddings.
base_model = keras.applications.ResNet50(include_top=False, weights=None, input_shape=(128, 130, 1))

pooling = GlobalAveragePooling2D()
projection = Dense(256, activation='relu')
x = projection(pooling(base_model.output))

model = keras.models.Model(inputs=base_model.input, outputs=x)


In [4]:
# Show the Keras summary of the encoder built above (ResNet50 trunk + pooling + Dense head).
model.summary()

In [5]:
# Load precomputed spectrograms + label vectors and split them 80/10/10 into
# train/val/test loaders.
batch_size = 256
SPLIT_SEED = 0  # fixed so the train/val/test partition is identical across runs
# NOTE(review): assumes spectrograms are dB values in [-80, 0]; `/ DB_RANGE + 1`
# then maps them to [0, 1] — confirm against the preprocessing that produced the file.
DB_RANGE = 80

spectrograms_array = np.load('spectrograms.npy')
labels_array = np.load('labels.npy')

# float32 halves memory versus float64 and matches the model's weight dtype.
spectrograms_array = (spectrograms_array / DB_RANGE + 1).astype(np.float32)
# Add a trailing channel axis: the encoder expects channels-last (N, 128, 130, 1).
spectrograms_array = np.expand_dims(spectrograms_array, axis=3)
assert spectrograms_array.shape[1:] == (128, 130, 1), spectrograms_array.shape
assert len(spectrograms_array) == len(labels_array), (len(spectrograms_array), len(labels_array))

dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(spectrograms_array), torch.from_numpy(labels_array)
)

train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size  # remainder goes to test

# A seeded generator makes the split reproducible (otherwise the test set changes
# every time the notebook is re-run).
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)


In [6]:
def _label_to_numpy(x):
    """Return a label vector (torch tensor, array or sequence) as a NumPy array."""
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    return np.asarray(x)


def _set_similarity(a, b):
    """Intersection-over-union of two collections of hashable values.

    Returns 1.0 when both are empty (two empty sets are identical), instead of
    dividing by zero.
    """
    a, b = set(a), set(b)
    union = a | b
    if not union:
        return 1.0
    return len(a & b) / len(union)


def jaccard_distance(x1, x2, pitch_weight=0.5, instrument_weight=0.5, pair_weight=0.0):
    """Weighted Jaccard distance between two flat label vectors.

    Each vector is read as consecutive triples; after zeros are removed, element
    0 of each triple is treated as an instrument and element 2 as a pitch (the
    middle element is not used).

    NOTE(review): zeros are assumed to be padding. A real label value of 0 would
    be dropped too and shift the triple alignment — confirm 0 never occurs in
    real labels.

    Args:
        x1, x2: 1-D label vectors (torch tensors or array-likes).
        pitch_weight: weight of the pitch-set IoU.
        instrument_weight: weight of the instrument-set IoU.
        pair_weight: weight of the IoU over (instrument, pitch) pairs. This term
            is only computed when non-zero.

    Returns:
        float: 1 - weighted IoU. It is 0.0 for identical labels (including two
        all-zero vectors) and 1.0 for fully disjoint ones when the weights sum to 1.
    """
    a = _label_to_numpy(x1)
    b = _label_to_numpy(x2)
    a = a[a != 0]
    b = b[b != 0]

    # .tolist() yields plain Python scalars, so set membership compares values
    # (0-d torch tensors would hash by identity and never match).
    a_instruments, b_instruments = a[0::3].tolist(), b[0::3].tolist()
    a_pitches, b_pitches = a[2::3].tolist(), b[2::3].tolist()

    instrument_sim = _set_similarity(a_instruments, b_instruments)
    pitch_sim = _set_similarity(a_pitches, b_pitches)

    pair_sim = 0.0
    if pair_weight:
        pair_sim = _set_similarity(zip(a_instruments, a_pitches), zip(b_instruments, b_pitches))

    return 1 - (pitch_weight * pitch_sim + instrument_weight * instrument_sim + pair_weight * pair_sim)

In [7]:
class LabelDifference(nn.Module):
    """Pairwise label-distance matrix consumed by RnCLoss.

    Args:
        distance_type: only 'jaccard' is supported (uses `jaccard_distance`).
        device: device of the returned matrix. None means 'cuda' when CUDA is
            available, otherwise 'cpu'.
    """

    def __init__(self, distance_type='jaccard', device=None):
        super(LabelDifference, self).__init__()
        self.distance_type = distance_type
        self.device = device

    def forward(self, labels):
        """Return a [bs, bs] float64 tensor with entry (i, j) = distance(labels[i], labels[j]).

        Args:
            labels: [bs, label_dim] tensor.

        Raises:
            ValueError: if `distance_type` is not 'jaccard'.
        """
        if self.distance_type != 'jaccard':
            raise ValueError(self.distance_type)

        # jaccard_distance is a per-pair Python call, so runtime scales with the
        # number of calls. Identical rows (RnCLoss repeats every label twice) give
        # identical distances, and the distance is symmetric. So compute it only
        # for distinct rows with i <= j, then expand back to the full matrix.
        unique_labels, inverse = torch.unique(labels.detach().cpu(), dim=0, return_inverse=True)
        n_unique = unique_labels.shape[0]
        unique_matrix = np.zeros((n_unique, n_unique))
        for i in range(n_unique):
            for j in range(i, n_unique):
                d = jaccard_distance(unique_labels[i], unique_labels[j])
                unique_matrix[i, j] = d
                unique_matrix[j, i] = d

        idx = inverse.numpy()
        matrix = unique_matrix[np.ix_(idx, idx)]  # [bs, bs]

        device = self.device
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return torch.from_numpy(matrix).to(device)

class FeatureSimilarity(nn.Module):
    """Pairwise feature similarity for RnCLoss.

    Only 'l2' is supported: similarity is the negative Euclidean distance, so
    larger values mean closer features and the diagonal is zero.
    """

    def __init__(self, similarity_type='l2'):
        super().__init__()
        self.similarity_type = similarity_type

    def forward(self, features):
        """Map [bs, feat_dim] features to a [bs, bs] similarity matrix.

        Raises:
            ValueError: for any similarity_type other than 'l2'.
        """
        if self.similarity_type != 'l2':
            raise ValueError(self.similarity_type)

        rows = features.unsqueeze(1)  # [bs, 1, feat_dim]
        cols = features.unsqueeze(0)  # [1, bs, feat_dim]
        pairwise_diff = rows - cols   # [bs, bs, feat_dim] via broadcasting
        return -pairwise_diff.norm(2, dim=-1)


class RnCLoss(nn.Module):
    """Rank-N-Contrast style loss over two views per sample.

    Both views are pooled into n = 2*bs rows. With s = feature similarity /
    temperature and d = label difference, every ordered pair (i, j), i != j,
    contributes

        -log( exp(s_ij) / sum_{k != i, d_ik >= d_ij} exp(s_ik) ) / (n * (n - 1)),

    i.e. j is contrasted against every sample whose label is at least as far
    from i's label as j's is.
    """

    def __init__(self, temperature=2, label_diff='jaccard', feature_sim='l2'):
        super(RnCLoss, self).__init__()
        self.t = temperature
        self.label_diff_fn = LabelDifference(label_diff)
        self.feature_sim_fn = FeatureSimilarity(feature_sim)

    def forward(self, features, labels):
        """Compute the loss.

        Args:
            features: [bs, 2, feat_dim] embeddings of two views of each sample.
            labels: [bs, label_dim] label vectors (shared by both views).

        Returns:
            Scalar loss tensor.
        """
        # Pool both views: rows 0..bs-1 are view 0, rows bs..2bs-1 are view 1.
        features = torch.cat([features[:, 0], features[:, 1]], dim=0)  # [2bs, feat_dim]
        labels = labels.repeat(2, 1)  # [2bs, label_dim]

        logits = self.feature_sim_fn(features).div(self.t)  # [2bs, 2bs]
        # Per-row max shift for a stable exp(); it cancels inside each log-ratio below.
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits = logits - logits_max.detach()
        exp_logits = logits.exp()

        # The label-difference module builds its matrix on its own device choice;
        # align it with the features so masked_select/comparisons don't mix devices.
        label_diffs = self.label_diff_fn(labels).to(logits.device)  # [2bs, 2bs]

        n = logits.shape[0]  # n = 2bs

        # Remove self-pairs (the diagonal) with one shared boolean mask.
        off_diagonal = ~torch.eye(n, dtype=torch.bool, device=logits.device)
        logits = logits.masked_select(off_diagonal).view(n, n - 1)
        exp_logits = exp_logits.masked_select(off_diagonal).view(n, n - 1)
        label_diffs = label_diffs.masked_select(off_diagonal).view(n, n - 1)

        loss = 0.
        for k in range(n - 1):
            pos_logits = logits[:, k]  # [2bs]
            pos_label_diffs = label_diffs[:, k]  # [2bs]
            # Denominator set for anchor i: every j whose label difference from i is >= d_ik.
            neg_mask = (label_diffs >= pos_label_diffs.view(-1, 1)).float()  # [2bs, 2bs - 1]
            pos_log_probs = pos_logits - torch.log((neg_mask * exp_logits).sum(dim=-1))  # [2bs]
            loss += - (pos_log_probs / (n * (n - 1))).sum()

        return loss

In [8]:
# SGD with momentum and weight decay over every encoder parameter, plus an
# RnCLoss with its defaults (temperature=2, 'jaccard' labels, 'l2' features).
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
)
loss_fn = RnCLoss()

In [None]:
# RnC training: each batch is masked twice into two "views", both are embedded
# by `model`, and RnCLoss ranks feature distances against label distances.
# The best-validation checkpoint and a final model are written to CHECKPOINT_DIR.
epochs = 1000
# SpecAugment-style mask widths (max masked frames / bins). 0 disables masking,
# in which case both views are identical copies of the input.
TIME_MASK_PARAM = 0
FREQ_MASK_PARAM = 0
LOG_EVERY = 50
CHECKPOINT_DIR = "models"

os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # model.save() below writes into this directory

# Build the transforms once, not per step. iid_masks=True draws an independent
# mask per example (applies to 4-D input, which `augment` provides).
time_masking = T.TimeMasking(time_mask_param=TIME_MASK_PARAM, iid_masks=True)
freq_masking = T.FrequencyMasking(freq_mask_param=FREQ_MASK_PARAM, iid_masks=True)


def augment(batch):
    """Return a randomly time/frequency-masked version of a channels-last batch.

    torchaudio masks along the last two axes (..., freq, time), so the
    (batch, 128, 130, 1) input is moved to (batch, 1, 128, 130) and back.
    NOTE(review): assumes axis 1 (128) is frequency and axis 2 (130) is time — confirm.
    """
    channels_first = batch.permute(0, 3, 1, 2)
    masked = freq_masking(time_masking(channels_first))
    return masked.permute(0, 2, 3, 1).contiguous()


def evaluate(dataloader):
    """Mean per-batch RnC loss over `dataloader`, using the un-augmented input as both views.

    Returns inf for an empty loader so it never counts as a new best.
    """
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for vinputs, vlabels in dataloader:
            voutputs = model(vinputs, training=False)
            val_features = torch.stack((voutputs, voutputs), dim=1)  # [bs, 2, feat_dim]
            total += float(loss_fn(val_features, vlabels))
            n_batches += 1
    return total / n_batches if n_batches else float("inf")


best_vloss = float("inf")
for epoch in range(epochs):
    model.train(True)
    running_loss = 0.0
    n_steps = 0

    for step, (inputs, targets) in enumerate(train_dataloader):
        aug_inputs1 = augment(inputs)
        aug_inputs2 = augment(inputs)

        # Keras layers (e.g. ResNet50's BatchNormalization) take their mode from the
        # `training` call argument, so pass it explicitly in this custom torch loop.
        logits1 = model(aug_inputs1, training=True)
        logits2 = model(aug_inputs2, training=True)
        features = torch.stack((logits1, logits2), dim=1)  # [bs, 2, feat_dim]

        loss = loss_fn(features, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += float(loss)
        n_steps += 1
        if step % LOG_EVERY == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Mean over validation batches (the old code divided the sum by 1).
    avg_vloss = evaluate(val_dataloader)
    print(f"epoch {epoch + 1}: mean train loss {running_loss / max(n_steps, 1):.4f}")
    print("vloss: ", avg_vloss)

    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model.save(os.path.join(CHECKPOINT_DIR, f"epoch_{epoch + 1}_val_loss_{best_vloss:.4f}.keras"))

model.save(os.path.join(CHECKPOINT_DIR, "final.keras"))