# preparing mtat data

In [1]:
!pip install transformers
!pip install torchmetric

from IPython.display import clear_output
clear_output(wait=False)

In [6]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchaudio
import torchaudio.functional as AF
import torchaudio.transforms as AT
import torchvision.transforms as VT

import numpy as np
import pandas as pd

import os

from tqdm.notebook import tqdm

n_fft = 1024
win_length = None
hop_length = 512
n_mels = 64
sample_rate = 22050
top_db = 80

mel_spectrogram = AT.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    center=True,
    pad_mode="reflect",

    power=2.0,
    norm='slaney',
    onesided=True,
    n_mels=n_mels,
    mel_scale="htk",
)
ampl2db = AT.AmplitudeToDB(top_db=top_db)

In [7]:
fma_path = 'mtat_wav'
def getListOfFiles(dirName):
    # create a list of file and sub directories
    # names in the given directory
    listOfFile = os.listdir(dirName)
    allFiles = list()
    # Iterate over all the entries
    for entry in listOfFile:
        # Create full path
        fullPath = os.path.join(dirName, entry)
        # If entry is a directory then get the list of files in this directory
        if os.path.isdir(fullPath):
            allFiles = allFiles + getListOfFiles(fullPath)
        else:
            allFiles.append(fullPath)

    return allFiles


In [13]:

mel_list = []
'''
Here to add the mtat data
'''
file_path = getListOfFiles('MTAT HERE')
for wav_path in tqdm(file_path):
    try:
        waveform, sample_rate = torchaudio.load(wav_path)
    except RuntimeError:
        print(wav_path)
        continue
    melspect = mel_spectrogram(waveform)
    melspect = ampl2db(melspect) * 0.01
    # print(melspect.shape)
    mel_list.append(melspect)

  0%|          | 0/16283 [00:00<?, ?it/s]

# mtat dataset

In [14]:
from torch.utils.data import Dataset, DataLoader
import torchaudio.transforms as AT
import torch.nn as nn
import random

class MuseData(Dataset):
    def __init__(self, data, transform, mode='train', max_len=25):
        # processed_data = []

        x_shape = []
        for idx, l in enumerate(data):
            try:
                x_shape.append(l.shape[-1])
            except ValueError:
                print(idx)

        self.max_len = max_len

        self.data = data
        self.transform = transform
        self.mode = mode
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ori_width = self.data[0].shape[-1]
        ori_height = self.data[0].shape[-2]

        width = random.randint(int(ori_width * 0.9), int(ori_width * 1.2))
        scaler = VT.Resize((ori_height, width))
        cropper = VT.RandomCrop((ori_height, self.max_len * ori_height // 2))
        # cropper(scaler(mel_list[0])).shape
        brightness = random.uniform(0.9, 1.1)
        mel = scaler(self.data[idx])
        mel = cropper(mel) * brightness

        split_mel = []

        for idx in range(self.max_len):

            piece = mel[:, :, idx*ori_height//2: (idx+1)*ori_height//2]
            if piece.shape[-1] != ori_height//2:
                print('Error piece')

            split_mel.append(piece)
        split_mel = torch.stack(split_mel, dim=0)
        return split_mel


class VAEDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):

        ori_width = self.data[0].shape[-1]
        ori_height = self.data[0].shape[-2]
        width = random.randint(int(ori_width * 0.9), int(ori_width * 1.2))
        scaler = VT.Resize((ori_height, width))
        brightness = random.uniform(0.9, 1.1)

        mel = scaler(self.data[idx]) * brightness

        max_pos = mel.shape[2] -( ori_height // 2 + 1)
        sampled_pos = random.randint(0, max_pos)
        sample = mel[:, :, sampled_pos:sampled_pos+ori_height//2]
        return sample

    def __len__(self):
        return len(self.data)


mel_transform = nn.Sequential(
    AT.FrequencyMasking(128),
    AT.TimeMasking(128)
)

In [15]:
vae_set = VAEDataset(mel_list)
vae_set[0]

tensor([[[-0.2985, -0.2758, -0.2557,  ..., -0.2686, -0.2635, -0.3103],
         [-0.2013, -0.1682, -0.1752,  ..., -0.0987, -0.1230, -0.1446],
         [-0.1107, -0.0960, -0.1197,  ..., -0.0247, -0.0536, -0.0883],
         ...,
         [-0.4735, -0.4841, -0.4905,  ..., -0.5111, -0.5162, -0.5406],
         [-0.5138, -0.5438, -0.5489,  ..., -0.5805, -0.5787, -0.5838],
         [-0.6193, -0.6193, -0.6193,  ..., -0.6193, -0.6193, -0.6193]]])

In [16]:
import random
batch_size = 32
random.shuffle(mel_list)


# train_set, val_set = torch.utils.data.random_split(dataset, [int(len(dataset) * 0.8),len(dataset) - int(len(dataset) * 0.8)])
train_set, val_set = MuseData(mel_list[:int(0.8 * len(mel_list))], None, 'val'), MuseData(mel_list[int(0.8*len(mel_list)):], None, 'val')
# train_set = MuseData(train_data)
train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=0,
        drop_last=True,
        shuffle=True,

    )

# val_set = MuseData(val_data)
val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        num_workers=0,
        drop_last=False,
        shuffle=True,

    )

vae_loader = DataLoader(
    vae_set,
    batch_size=batch_size * 5,
    num_workers=0,
    drop_last=False,
    shuffle=True
)

# Models

In [28]:
import torch
from torch.nn import functional as F
import torch.nn as nn
import numpy as np
import random
from torch.autograd import Variable
# from pytorch_metric_learning.losses import NTXentLoss
from transformers import BertTokenizer, BertModel, BertConfig
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder
from info_nce import InfoNCE
import torchvision.models as models

class FeatureExtractor(nn.Module):
    def __init__(self, nf=32, num_res=0):
        super(FeatureExtractor, self).__init__()
        self.models = nn.Sequential(
            nn.Conv2d(1, nf, 7, 2, 3),
            nn.BatchNorm2d(nf),
            nn.ReLU(True),

            nn.Conv2d(nf, nf*2, 3, 2, 1),
            nn.BatchNorm2d(nf*2),
            nn.ReLU(True),

            nn.Conv2d(nf*2, nf*4, 3, 1, 1),
            nn.BatchNorm2d(nf*4),
            nn.ReLU(True),

            nn.Conv2d(nf*4, nf*4, 3, 2, 1),
            nn.BatchNorm2d(nf*4),
            nn.ReLU(True),

            nn.Conv2d(nf*4, nf*8, 3, 1, 1),
            nn.BatchNorm2d(nf*8),
            nn.ReLU(True),

            nn.Conv2d(nf*8, nf*8, 3, 2, 1),
            nn.BatchNorm2d(nf*8),
            nn.ReLU(True),

            nn.Conv2d(nf*8, nf*8, 3, 1, 1),
            nn.BatchNorm2d(nf*8),
            nn.ReLU(True),

        )

        self.out_layer = nn.Linear(6*nf*4, 512)

    def forward(self, x, pretrain=False):
        x = self.models(x)

        return x.view(x.shape[0], -1)


class FeatureExtractorDecoder(nn.Sequential):
    def __init__(self, nf=32, num_res=2):
        super(FeatureExtractorDecoder, self).__init__()


        self.de_convs = nn.Sequential(

            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(nf*8, nf*4, 3, 1, 1),
            nn.BatchNorm2d(nf*4),
            nn.ReLU(True),

            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(nf*4, nf*2, 3, 1, 1),
            nn.BatchNorm2d(nf*2),
            nn.ReLU(True),

            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(nf*2, nf, 3, 1, 1),
            nn.BatchNorm2d(nf),
            nn.ReLU(True),

            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(nf, 1, 7, 1, 3),
            nn.Tanh()
        )

    def forward(self, input):
        input = input.view(input.shape[0], -1, 4, 2)
        return self.de_convs(input)


class InputRepresentation(nn.Module):
    def __init__(self, max_len, input_dim, hidden_dim, do_rate=0.1):
        super(InputRepresentation, self).__init__()
        self.rep_embedding = nn.Linear(input_dim, hidden_dim)
        self.pos_embedding = nn.Embedding(max_len, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.drop_out = nn.Dropout(do_rate)

    def forward(self, input, pos):
        rep = self.rep_embedding(input)
        pos = self.pos_embedding(pos)
        input_rep = rep + pos
        # print(input_rep.dtype)
        input_rep = self.layer_norm(input_rep)
        input_rep = self.drop_out(input_rep)

        return input_rep


class Patchifier(nn.Module):
    def __init__(self, fe_config, bert_config, bs=64, proj_dim=64):
        super(Patchifier, self).__init__()
        self.fe_config = fe_config
        self.bert_config = bert_config
        self.proj_dim = proj_dim
        # num_embeddings, embedding_dim, commitment_c
        self.feat_extr = FeatureExtractor(fe_config.nf)
        self.feat_extr_decoder = FeatureExtractorDecoder(fe_config.nf)

        self.embedder = InputRepresentation(
                                            fe_config.max_len,
                                            fe_config.nf*8 * 8,
                                            bert_config.hidden_size,
                                            bert_config.hidden_dropout_prob)
        self.encoder = BertEncoder(bert_config)


        self.inv_embedder = nn.Linear(bert_config.hidden_size, fe_config.nf*8 * 8)
        self.diversity_weight = 0.1
        # self.cl_loss_fn = InfoNCE(temperature=0.5)

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.cls_token = nn.Parameter((torch.ones([1, fe_config.nf*8 * 8], dtype=torch.float)).to(self.device))
        self.mask_token =  nn.Parameter((torch.randn([fe_config.nf*8 * 8], dtype=torch.float)).to(self.device))

    def load_fe(self, path):
        self.feat_extr.load_state_dict(torch.load(path))

    def load_bert(self, path):
        self.encoder.load_state_dict(torch.load(path))

    def load_fe_dec(self, path):
        self.feat_extr_decoder.load_state_dict(torch.load(path))

    def save_fe(self, path):
        torch.save(self.feat_extr.state_dict(), path)

    def save_bert(self, path):
        torch.save(self.encoder.state_dict(), path)

    def save_fe_dec(self, path):
        torch.save(self.feat_extr_decoder.state_dict(), path)

    def straight_forward(self, x, mode='vq'):
        x = x.to(self.device)
        bs = x.shape[0]
        length = x.shape[1]
        height = x.shape[3]
        width = x.shape[4]

        # feature extractor embeddings

        feats = [self.cls_token.repeat(bs, 1)]
        for idx in range(length):
            mel = x[:, idx]
            feat = self.feat_extr(mel)
            feats.append(feat)

        feats = torch.stack(feats, dim=1)


        pos = torch.arange(0, length+1).repeat(bs, 1).to(self.device)
        rep = self.embedder(feats, pos)
        rep = self.encoder(rep).last_hidden_state

        return rep

    def mlm_pretrain(self, x, mask_rate=0.5):
        x = x.to(self.device)
        bs = x.shape[0]
        length = x.shape[1]
        height = x.shape[3]
        width = x.shape[4]

        # feature extractor embeddings

        feats = [self.cls_token.repeat(bs, 1)]
        for idx in range(length):
            mel = x[:, idx]
            feat = self.feat_extr(mel)
            feats.append(feat)

        feats = torch.stack(feats, dim=1)

        '''

            MASKING

        '''

        for i in range(feats.shape[0]):
            for j in range(1, feats.shape[1]):
                if random.random() < mask_rate:
                    feats[i,j] = self.mask_token

        pos = torch.arange(0, length+1).repeat(bs, 1).to(self.device)
        rep = self.embedder(feats, pos)
        rep = self.encoder(rep).last_hidden_state

        # MLM loss

        de_embd = self.inv_embedder(rep[:, 1:])

        recons_feats = []
        for idx in range(length):
            rec_feat = de_embd[:, idx]
            rec_feat = self.feat_extr_decoder(rec_feat)
            recons_feats.append(rec_feat)
        recons = torch.stack(recons_feats, dim=1)

        loss = F.mse_loss(recons, x)
        return loss

    def ae_forward(self, x):
        x = x.to(self.device)
        recon_x = self.feat_extr_decoder(self.feat_extr(x))
        loss = F.mse_loss(recon_x, x)
        return loss

    def forward(self, x, mode='mlm', mask_rate=0.5):
        if mode == 'mlm' and mask_rate > 0.:
            loss = self.mlm_forward(x, mask_rate=mask_rate)
            return loss
        elif mode == 'ae':
            loss = self.ae_forward(x),
            return loss
        else:
            return self.straight_forward(x)

In [29]:
class ConfigFE:
    def __init__(self, nf, num_vars, groups, combine_groups, vq_dim, max_len):
        self.nf = nf
        self.num_vars = num_vars
        self.groups = groups
        self.combine_groups = combine_groups
        self.vq_dim = vq_dim
        self.max_len = max_len


In [22]:
lr = 1e-4
epochs = 5000
fe_config = ConfigFE(16, 320, 2, True, 256, 40)
bert_config = BertConfig(hidden_size=256, num_attention_heads=4, num_hidden_layers=8, intermediate_size=1024, hidden_dropout_prob=0.3)
patchifier = Patchifier(fe_config, bert_config).to('cuda')

opt = torch.optim.AdamW(patchifier.parameters(), lr=lr)

for e in tqdm(range(epochs)):
    epoch_losses = []
    patchifier.train()

    for x in vae_loader:
        x = x.to('cuda')
        loss = patchifier.ae_forward(x)

        opt.zero_grad()
        loss.backward()
        opt.step()

        epoch_losses.append(loss.cpu().data.item())

    mean_loss = np.mean(np.array(epoch_losses))
    if (e+1) % 50 == 0:
        print('Loss at %d epoch: %.5f' % (e, mean_loss))

# patchifier.save_fe('ckpt/ckpt_fe.pkl')
# patchifier.save_fe_dec('ckpt/ckpt_fe_dec.pkl')
# torch.save(opt.state_dict(), 'ckpt/autoencoder_opt.pkl')

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

# Downstream Classification

In [30]:
'''
preparing data
'''

import os

'''
Here to add gtzan path
'''
ori_path = 'GTZAN HERE'

genres = ['reggae', 'pop', 'rock', 'hiphop', 'metal', 'country', 'disco', 'classical', 'blues', 'jazz']

import torch
import torchaudio
import torchaudio.functional as AF
import torchaudio.transforms as AT

n_fft = 1024
win_length = None
hop_length = 512
n_mels = 64
sample_rate = 22050
top_db = 80

mel_spectrogram = AT.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    center=True,
    pad_mode="reflect",
    power=2.0,
    norm='slaney',
    onesided=True,
    n_mels=n_mels,
    mel_scale="htk",
)
ampl2db = AT.AmplitudeToDB(top_db=top_db)

from tqdm.notebook import tqdm

genres_dict = {
    'reggae': 0,
    'pop': 1,
    'rock': 2,
    'hiphop': 3,
    'metal': 4,
    'country': 5,
    'disco': 6,
    'classical': 7,
    'blues': 8,
    'jazz': 9
}

mel_list = []
for g in tqdm(genres):
    for wav in os.listdir(os.path.join(ori_path, g)):
        wav_path = os.path.join(ori_path, g, wav)
        try:
            waveform, sample_rate = torchaudio.load(wav_path)
        except RuntimeError:
            continue
        melspect = mel_spectrogram(waveform)
        melspect = ampl2db(melspect)

        assert genres_dict[g] != None
        mel_list.append([melspect, torch.tensor(genres_dict[g])])

mels = []
for i, (elem, label) in enumerate(mel_list):
    if elem.shape[-1] >= 1293:
        elem = elem[..., :1293]
        mels.append(elem)

mels = torch.stack(mels, dim=0)
mels = mels.detach() / 100

new_mel_list = []
for mel, (_, label) in zip(mels, mel_list):
    new_mel_list.append([mel, label])

  0%|          | 0/10 [00:00<?, ?it/s]

In [31]:
'''
Dataset
'''

class MuseData(Dataset):
    def __init__(self, data, transform, mode='train', max_len=25):

        x_shape = []
        for idx, (x, l) in enumerate(data):
            try:
                x_shape.append(x.shape[-1])
            except ValueError:
                print(idx)

        self.max_len = max_len

        self.data = data
        self.transform = transform
        self.mode = mode
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data, label = self.data[idx]
        ori_width = data.shape[-1]
        ori_height = data.shape[-2]

        if self.mode  == 'train':
            width = random.randint(int(ori_width * 0.9), int(ori_width * 1.1))
            scaler = VT.Resize((ori_height, width))
            cropper = VT.RandomCrop((ori_height, self.max_len * ori_height // 2 ))
            # cropper(scaler(mel_list[0])).shape
            brightness = random.uniform(0.9, 1.1)
            mel = scaler(data)
            mel = (cropper(mel) * brightness).detach()

        else:
            mel = data[..., :(self.max_len * ori_height) // 2 ]

        split_mel = []

        for idx in range(self.max_len):

            piece = mel[:, :, idx*ori_height//2: (idx+1)*ori_height//2]
            if piece.shape[-1] != ori_height//2:
                print('Error piece')

            split_mel.append(piece)
        split_mel = torch.stack(split_mel, dim=0)
        return split_mel, label


import random
batch_size = 32
random.shuffle(new_mel_list)

train_set, val_set = MuseData(new_mel_list[:int(0.8 * len(new_mel_list))], None, 'train'), MuseData(new_mel_list[int(0.8*len(new_mel_list)):], None, 'val')
train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=0,
        drop_last=True,
        shuffle=True,

    )

val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        num_workers=0,
        drop_last=False,
        shuffle=True,

    )



In [37]:
'''
Linear Prob
'''

class CLSF(nn.Module):
    def __init__(self, config):
        super(CLSF, self).__init__()
        self.model = nn.Sequential(
            # nn.Linear(config.hidden_size, config.hidden_size),
            # nn.ReLU(True),
            # nn.Dropout(0.2),
            nn.Linear(config.hidden_size, 10),
        )

    def forward(self, x):
        return self.model(x)
lr = 1e-3
epochs = 1000

fe_config = ConfigFE(16, 320, 2, True, 256, 40)
bert_config = BertConfig(hidden_size=256, num_attention_heads=4, num_hidden_layers=8, intermediate_size=1024, hidden_dropout_prob=0.1)

model = Patchifier(fe_config, bert_config).to('cuda')

'''
Here to add model path
'''
model.load_state_dict(torch.load('MODEL HERE'))

# freeze model
for param in model.parameters():
    param.requires_grad = False

class Pooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

        pooled_output = self.dense(hidden_states)
        pooled_output = self.activation(pooled_output)
        return pooled_output

pooler = Pooler(bert_config).to('cuda')
clsf = CLSF(bert_config).to('cuda')
crit = nn.CrossEntropyLoss()
opt = torch.optim.AdamW(list(model.parameters()) + list(pooler.parameters()) + list(clsf.parameters()), lr=lr,
                        weight_decay=0.1)

def multi_acc(y_pred, y_test):
    # print(y_pred.shape)
    y_pred_softmax = torch.log_softmax(y_pred, dim = 1)
    _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
    correct_pred = (y_pred_tags == y_test).float()
    acc = correct_pred.sum() / len(correct_pred)
    acc = torch.round(acc * 100)
    return acc


train_losses = []
acc_script = []
for epoch in tqdm(range(epochs)):

    loss_script = []

    model.train()
    pooler.train()
    clsf.train()

    total, correct = 0, 0
    for x, label in train_loader:
        x = x.to('cuda')
        label = label.to('cuda')

        rep = model(x, mode='straight', mask_rate=0.5)

        logit = clsf(pooler(rep[:, 0])).squeeze()
        loss = crit(logit, label)

        opt.zero_grad()
        loss.backward()
        opt.step()

        loss_script.append(loss.cpu().data.item())
        _, pred = torch.max(logit.data, 1)
        total += label.size(0)
        correct += (pred == label).sum().item()

    accuracy = (correct / total) * 100
    if (epoch+1) % 10 == 0:
        print('train accuracy in %d epoch: %.4f' % ((epoch+1), accuracy))
    mean_loss = np.mean(np.array(loss_script))
    train_losses.append(mean_loss)

    with torch.no_grad():

        total, correct = 0, 0

        model.eval()
        pooler.eval()
        clsf.eval()

        for x, label in val_loader:
            x = x.to('cuda')
            label = label.to('cuda')

            rep = model(x, mode='straight', mask_rate=0.5)
            # print(rep.shape)
            logit = clsf(pooler(rep[:, 0])).squeeze()

            _, pred = torch.max(logit.data, 1)
            total += label.size(0)
            correct += (pred == label).sum().item()

        accuracy = (correct / total) * 100
        if (epoch+1) % 10 == 0:
            print('valid accuracy in %d epoch: %.4f' % ((epoch+1), accuracy))
            print('-' * 30)

  0%|          | 0/1000 [00:00<?, ?it/s]

train accuracy in 10 epoch: 57.0312
valid accuracy in 10 epoch: 55.5556
------------------------------
train accuracy in 20 epoch: 64.0625
valid accuracy in 20 epoch: 62.1212
------------------------------
train accuracy in 30 epoch: 65.7552
valid accuracy in 30 epoch: 62.1212
------------------------------
train accuracy in 40 epoch: 68.4896
valid accuracy in 40 epoch: 64.1414
------------------------------
train accuracy in 50 epoch: 69.0104
valid accuracy in 50 epoch: 64.6465
------------------------------
train accuracy in 60 epoch: 68.2292
valid accuracy in 60 epoch: 66.6667
------------------------------
train accuracy in 70 epoch: 70.3125
valid accuracy in 70 epoch: 67.1717
------------------------------
train accuracy in 80 epoch: 66.9271
valid accuracy in 80 epoch: 63.6364
------------------------------
train accuracy in 90 epoch: 70.1823
valid accuracy in 90 epoch: 68.6869
------------------------------
train accuracy in 100 epoch: 68.4896
valid accuracy in 100 epoch: 69.1919

# Downstream regression

In [36]:
'''
preparing data
'''
import json
def parse_annotation_file(json_file):
    tracks = {}
    with open(json_file) as f:
        examples = json.load(f)
        for song_id in examples:
            tracks[song_id] = {
                'track_id': examples[song_id]['extra']['songs_info']['song_id'],
                'split': examples[song_id]['split'],
                'labels': examples[song_id]['y']
            }
    return tracks

'''
here to add emo music annotation
'''

infors = parse_annotation_file('ANNOTATION HERE')

n_fft = 1024
win_length = None
hop_length = 512
n_mels = 64
sample_rate = 44100
top_db = 80

mel_spectrogram = AT.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    center=True,
    pad_mode="reflect",
    power=2.0,
    norm='slaney',
    onesided=True,
    n_mels=n_mels,
    mel_scale="htk",
)
ampl2db = AT.AmplitudeToDB(top_db=top_db)

mel_list = []

'''
Here to add emomusic path
'''

file_path = getListOfFiles('emo music')
for wav_path in tqdm(file_path):
    wav_id = wav_path[10: -4].zfill(4)

    try:
        waveform, sample_rate = torchaudio.load(wav_path)
        emotion = infors[wav_id]['labels']
    except RuntimeError:
        print(wav_path)
        continue
    except KeyError:
        print(wav_id)
        continue
    # scaler = VT.Resize((64, 1938))
    melspect = mel_spectrogram(waveform)
    melspect = ampl2db(melspect)
    melspect = (melspect).detach() * 0.01
    # print(melspect.shape)
    mel_list.append([melspect, emotion])


  0%|          | 0/1000 [00:00<?, ?it/s]

0000
0001
0100
0011
0132
0139
0014
0141
0147
0015
0016
0165
0183
0200
0023
0230
0249
0251
0252
0255
0256
0259
0026
0261
0263
0267
0268
0027
0271
0277
0028
0283
0287
0029
0291
0295
0030
0033
0331
0337
0034
0351
0036
0363
0373
0377
0038
0382
0385
0388
0389
0394
0396
0398
0409
0411
0412
0413
0414
0417
0418
0421
0424
0433
0434
0438
0439
0443
0446
0447
0457
0465
0470
0471
0474
0476
0483
0491
0495
0505
0508
0509
0510
0511
0516
0517
0526
0528
0531
0532
0533
0534
0538
0539
0541
0542
0543
0545
0546
0548
0549
0552
0553
0557
0559
0562
0563
0566
0567
0569
0057
0570
0571
0572
0573
0575
0576
0578
0583
0587
0588
0589
0590
0593
0595
0596
0599
0006
0601
0602
0603
0604
0061
0618
0619
0624
0626
0627
0063
0630
0633
0636
0641
0642
0655
0659
0066
0669
0670
0678
0679
0680
0683
0694
0701
0705
0716
0720
0075
0751
0752
0753
0754
0755
0760
0761
0762
0765
0766
0768
0770
0771
0772
0774
0778
0783
0785
0786
0788
0792
0793
0802
0803
0809
0812
0816
0817
0821
0822
0827
0828
0832
0835
0837
0838
0840
0842
0843
0847
0849


In [None]:
'''
Dataset
'''

class MuseData(Dataset):
    def __init__(self, data, transform='None', mode='train', max_len=25):
        # processed_data = []

        x_shape = []
        for idx, (x, l) in enumerate(data):
            try:
                x_shape.append(x.shape[-1])
            except ValueError:
                print(idx)
        # x_shape = [l.shape[-1] for (l, _) in data]
        self.max_len = max_len

        self.data = data
        self.transform = transform
        self.mode = mode
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data, label = self.data[idx]
        ori_width = data.shape[-1]
        ori_height = data.shape[-2]

        if self.mode  == 'train':
            width = random.randint(int(ori_width * 0.9), int(ori_width * 1.1))
            scaler = VT.Resize((ori_height, width))
            cropper = VT.RandomCrop((ori_height, self.max_len * ori_height // 2 ))
            # cropper(scaler(mel_list[0])).shape
            brightness = random.uniform(0.9, 1.1)
            mel = scaler(data)
            mel = (cropper(mel) * brightness).detach()

        else:

            mel = data[..., :(self.max_len * ori_height) // 2 ]

        split_mel = []

        for idx in range(self.max_len):

            piece = mel[:, :, idx*ori_height//2: (idx+1)*ori_height//2]
            if piece.shape[-1] != ori_height//2:
                print('Error piece')

            split_mel.append(piece)
        split_mel = torch.stack(split_mel, dim=0)
        return split_mel, torch.tensor(label[0]), torch.tensor(label[1])


import random
batch_size = 32
random.shuffle(mel_list)


# train_set, val_set = torch.utils.data.random_split(dataset, [int(len(dataset) * 0.8),len(dataset) - int(len(dataset) * 0.8)])
train_set, val_set = MuseData(mel_list[:int(0.8 * len(mel_list))], None, 'train'), MuseData(mel_list[int(0.8*len(mel_list)):], None, 'val')
# train_set = MuseData(train_data)
train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=0,
        drop_last=True,
        shuffle=True,

    )

# val_set = MuseData(val_data)
val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        num_workers=0,
        drop_last=False,
        shuffle=True,

    )

In [None]:
'''
Linear prob
'''

class CLSF(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.prob_valence = nn.Linear(config.hidden_size, 1)
        self.prob_arousal = nn.Linear(config.hidden_size, 1)

    def forward(self, x):
        return self.prob_valence(x), self.prob_arousal(x)

lr = 1e-4
epochs = 2000

fe_config = ConfigFE(16, 320, 2, True, 256, 40)
bert_config = BertConfig(hidden_size=256, num_attention_heads=4, num_hidden_layers=8, intermediate_size=1024, hidden_dropout_prob=0.1)

model = Patchifier(fe_config, bert_config).to('cuda')

'''
Add Model
'''
model.load_state_dict(torch.load('MODEL HERE'))

for param in model.parameters():
      param.requires_grad = False

class Pooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        # mean_hidden = hidden_states.mean(1)
        pooled_output = self.dense(hidden_states)
        pooled_output = self.activation(pooled_output)
        return pooled_output

pooler = Pooler(bert_config).to('cuda')
clsf = CLSF(bert_config).to('cuda')
crit = nn.MSELoss()
opt = torch.optim.AdamW(list(model.parameters()) + list(pooler.parameters()) + list(clsf.parameters()), lr=lr,
                        weight_decay=0.1)

In [None]:
from torchmetrics import R2Score
import torchmetrics
from torchmetrics import R2Score

scorer = R2Score().to('cuda')
train_losses = []
acc_script = []
for epoch in tqdm(range(epochs)):

    loss_script = []

    model.train()
    pooler.train()
    clsf.train()

    flag = 0
    for x, label_v, label_a in train_loader:
        x = x.to('cuda')
        label_v = label_v.to('cuda')
        label_a = label_a.to('cuda')

        rep = model(x, mode='straight')

        logit_v, logit_a = clsf(pooler(rep[:, 1:].mean(1)))
        logit_v, logit_a = logit_v.squeeze(-1), logit_a.squeeze(-1)

        loss = crit(logit_v, label_v) + crit(logit_a, label_a)

        opt.zero_grad()
        loss.backward()
        opt.step()

        loss_script.append(loss.cpu().data.item())
        if flag == 0:
            total_label_v, total_label_a = label_v.data, label_a.data
            total_logit_v, total_logit_a = logit_v.data, logit_a.data
            flag = 1
        else:
            total_label_v = torch.cat([total_label_v, label_v.data], dim=0)
            total_label_a = torch.cat([total_label_a, label_a.data], dim=0)

            total_logit_v = torch.cat([total_logit_v, logit_v.data], dim=0)
            total_logit_a = torch.cat([total_logit_a, logit_a.data], dim=0)

    r2_v = scorer(total_logit_v, total_label_v)
    r2_a = scorer(total_logit_a, total_label_a)

    mean_loss = np.mean(np.array(loss_script))
    train_losses.append(mean_loss)

    if (epoch+1) % 5 == 0:
        print('train loss in %d epoch: %.4f' % ((epoch+1), mean_loss))
        print('r2 score valence: %.4f' % r2_v)
        print('r2 score arousal: %.4f' % r2_a)

    with torch.no_grad():

        total, correct = 0, 0

        model.eval()
        pooler.eval()
        clsf.eval()

        for x, label_v, label_a in train_loader:
            x = x.to('cuda')
            label_v = label_v.to('cuda')
            label_a = label_a.to('cuda')

            rep = model(x, mode='straight', mask_rate=0.5)
            # print(rep.shape)
            logit_v, logit_a = clsf(pooler(rep[:, 1:].mean(1)))
            logit_v, logit_a = logit_v.squeeze(-1), logit_a.squeeze(-1)

            if flag == 0:
                total_label_v, total_label_a = label_v.data, label_a.data
                total_logit_v, total_logit_a = logit_v.data, logit_a.data
                flag = 1
            else:
                total_label_v = torch.cat([total_label_v, label_v.data], dim=0)
                total_label_a = torch.cat([total_label_a, label_a.data], dim=0)

                total_logit_v = torch.cat([total_logit_v, logit_v.data], dim=0)
                total_logit_a = torch.cat([total_logit_a, logit_a.data], dim=0)


        r2_v = scorer(total_logit_v, total_label_v)
        r2_a = scorer(total_logit_a, total_label_a)

        if (epoch+1) % 5 == 0:
            print('r2 score in %d epoch: %.4f, %.4f' % ((epoch+1), r2_v, r2_a))
            print('-' * 30)