# Preprocessing Data

In [None]:
import torch
import os
import torch.nn.functional as F
import torch.nn as nn
import torchaudio
import torchaudio.functional as AF
import torchaudio.transforms as AT
import torchvision.transforms as VT

import numpy as np
import pandas as pd

# Spectrogram front-end shared by the pre-processing cells below.
n_fft = 1024            # FFT size in samples
win_length = None       # None: torchaudio uses a window of n_fft samples
hop_length = 512        # samples between successive frames
n_mels = 64             # mel filterbank channels (spectrogram height)
sample_rate = 22050     # sample rate (Hz) the mel filterbank is designed for
top_db = 80             # cut-off (dB) below the peak used by AmplitudeToDB

# Power (|X|^2) mel spectrogram: HTK mel scale, Slaney area normalisation.
mel_spectrogram = AT.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    n_mels=n_mels,
    power=2.0,
    onesided=True,
    center=True,
    pad_mode="reflect",
    norm='slaney',
    mel_scale="htk",
)
# Power spectrogram -> decibels, clamped to top_db below the maximum.
ampl2db = AT.AmplitudeToDB(top_db=top_db)

from tqdm.notebook import tqdm

# Genre name -> integer class id; the id is the name's position in this tuple.
genres_dict = {
    name: class_id
    for class_id, name in enumerate(
        ('reggae', 'pop', 'rock', 'hiphop', 'metal',
         'country', 'disco', 'classical', 'blues', 'jazz')
    )
}

fma_path = 'fma_small'
def getListOfFiles(dirName):
    """Return the paths of every non-directory entry under ``dirName``, recursively.

    Entries are visited in ``os.listdir`` order; the contents of a
    subdirectory are spliced in at the point where that subdirectory is met.
    Each returned path is ``os.path.join``-ed onto ``dirName``.
    """
    collected = []
    for name in os.listdir(dirName):
        candidate = os.path.join(dirName, name)
        if not os.path.isdir(candidate):
            collected.append(candidate)
            continue
        collected.extend(getListOfFiles(candidate))
    return collected

# Notebook display: number of files found under fma_path (the value is not stored).
len(getListOfFiles(fma_path))

In [None]:
from os import path
from pydub import AudioSegment

# Convert every MTAT mp3 under 'mtat' to 'mtat_wav/<index>.wav', where <index>
# is the file's position in the getListOfFiles listing.
mtat_path = getListOfFiles('mtat')
# Exporting into a directory that does not exist fails; with the best-effort
# except below, that failure would otherwise skip every file.
os.makedirs('mtat_wav', exist_ok=True)
for i, mp3_path in tqdm(enumerate(mtat_path), total=len(mtat_path)):
    try:
        sound = AudioSegment.from_mp3(mp3_path)
        dst = os.path.join('mtat_wav', '%d.wav' % i)
        sound.export(dst, format="wav")
    except Exception as err:
        # Best effort: report the failing file and the reason, then keep going.
        print(mp3_path, err)
        continue

In [None]:
# Log-mel spectrograms of the converted MTAT clips, one tensor per readable file.
mel_list = []
file_path = getListOfFiles('mtat_wav')
for wav_path in tqdm(file_path):
    try:
        waveform, sr = torchaudio.load(wav_path)
    except RuntimeError:
        print(wav_path)
        continue
    # mel_spectrogram's filterbank was built for `sample_rate` Hz; audio at any
    # other rate is resampled so each mel bin covers the intended frequencies.
    if sr != sample_rate:
        waveform = AF.resample(waveform, sr, sample_rate)
    melspect = mel_spectrogram(waveform)
    # Same 1/100 scaling as the downstream cells. It brings the dB values near
    # the [-1, 1] output range of the decoder's Tanh, and keeps pre-training
    # and fine-tuning inputs on one scale.
    melspect = ampl2db(melspect).detach() / 100
    mel_list.append(melspect)

In [None]:
import torchvision.transforms as VT

# Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader
import torchaudio.transforms as AT
import torch.nn as nn
import random

class MuseData(Dataset):
    """Pre-training dataset: one augmented clip cut into ``max_len`` pieces.

    Each element of ``data`` is a spectrogram tensor whose last two dims are
    (height, time). ``__getitem__`` does four things:

    - stretches the clip in time by a random factor in [0.9, 1.2];
    - takes a random crop ``max_len * height // 2`` frames wide;
    - scales the crop by a random gain in [0.9, 1.1];
    - returns the crop split into ``max_len`` pieces of width ``height // 2``,
      stacked on a new leading dim.

    ``transform`` and ``mode`` are stored but not used by ``__getitem__``.
    """

    def __init__(self, data, transform, mode='train', max_len=25):
        # __getitem__ reads shape[-2] and shape[-1]; reject anything smaller up front.
        for idx, item in enumerate(data):
            if item.ndim < 2:
                raise ValueError('item %d has shape %s; expected (..., height, time)'
                                 % (idx, tuple(item.shape)))
        self.max_len = max_len

        self.data = data
        self.transform = transform
        self.mode = mode

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        mel = self.data[idx]
        # Size the augmentation from this clip rather than from data[0]:
        # clips do not all have the same length.
        ori_height, ori_width = mel.shape[-2], mel.shape[-1]

        width = random.randint(int(ori_width * 0.9), int(ori_width * 1.2))
        scaler = VT.Resize((ori_height, width))
        cropper = VT.RandomCrop((ori_height, self.max_len * ori_height // 2))
        brightness = random.uniform(0.9, 1.1)
        mel = cropper(scaler(mel)) * brightness

        # The crop is exactly max_len * height // 2 frames wide, so (for an even
        # height) every piece has the same width and the pieces can be stacked.
        pieces = [mel[..., i * ori_height // 2:(i + 1) * ori_height // 2]
                  for i in range(self.max_len)]
        return torch.stack(pieces, dim=0)


class VAEDataset(Dataset):
    """Autoencoder pre-training dataset: one random patch per clip.

    Each element of ``data`` is a spectrogram whose last two dims are
    (height, time). ``__getitem__`` does three things:

    - stretches the clip in time by a random factor in [0.9, 1.2];
    - scales it by a random gain in [0.9, 1.1];
    - returns a randomly placed window ``height // 2`` frames wide.
    """

    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        mel = self.data[idx]
        # Size from this clip, not data[0]: clip lengths can differ.
        ori_height, ori_width = mel.shape[-2], mel.shape[-1]
        piece_width = ori_height // 2

        width = random.randint(int(ori_width * 0.9), int(ori_width * 1.2))
        scaler = VT.Resize((ori_height, width))
        brightness = random.uniform(0.9, 1.1)

        mel = scaler(mel) * brightness

        # Last start index that still leaves a full window (randint's upper
        # bound is inclusive).
        max_pos = mel.shape[-1] - piece_width
        sampled_pos = random.randint(0, max_pos)
        return mel[..., sampled_pos:sampled_pos + piece_width]

    def __len__(self):
        return len(self.data)


# Frequency- then time-masking augmentation. The datasets in this notebook
# store a `transform` but do not apply it.
# NOTE(review): freq_mask_param=128 exceeds n_mels=64, so a single mask can
# cover every mel bin — confirm this is intended.
mel_transform = nn.Sequential(
    AT.FrequencyMasking(freq_mask_param=128),
    AT.TimeMasking(time_mask_param=128),
)

In [None]:
from torch.nn import functional as F
import random
batch_size = 32
random.shuffle(mel_list)

# 80 % / 20 % split of the shuffled pre-training clips.
split = int(0.8 * len(mel_list))
train_set = MuseData(mel_list[:split], None, 'train')
val_set = MuseData(mel_list[split:], None, 'val')

train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    num_workers=0,
    drop_last=True,
    shuffle=True,
)

# Evaluation order does not matter, so no shuffling.
val_loader = DataLoader(
    val_set,
    batch_size=batch_size,
    num_workers=0,
    drop_last=False,
    shuffle=False,
)

# Patch-level autoencoder data. It is built from the training split only, so
# the validation clips stay unseen.
vae_set = VAEDataset(mel_list[:split])
vae_loader = DataLoader(
    vae_set,
    batch_size=batch_size * 5,
    num_workers=0,
    drop_last=False,
    shuffle=True,
)

# Network

In [1]:
from transformers import BertTokenizer, BertModel, BertConfig
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder
import torch

# Stand-alone 6-layer, 6-head BERT encoder (other settings are BertConfig
# defaults); the globals `config`/`encoder` are not read elsewhere in this notebook.
config = BertConfig(num_attention_heads=6, num_hidden_layers=6)
encoder = BertEncoder(config)

In [None]:
import torch
from torch.nn import functional as F
import torch.nn as nn
import numpy as np
import random
from torch.autograd import Variable
# from pytorch_metric_learning.losses import NTXentLoss
from transformers import BertTokenizer, BertModel, BertConfig
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder
from info_nce import InfoNCE

In [None]:
class FeatureExtractor(nn.Module):
    """Convolutional encoder turning a batch of (1, H, W) patches into flat vectors.

    Seven Conv-BN-ReLU stages, four of them with stride 2. A (1, 64, 32)
    patch therefore becomes (nf * 8, 4, 2), which is flattened to
    ``nf * 8 * 8`` features per sample.
    """

    def __init__(self, nf=32, num_res=0):
        super(FeatureExtractor, self).__init__()
        # (in_channels, out_channels, kernel, stride, padding) for each stage.
        stages = [
            (1, nf, 7, 2, 3),
            (nf, nf * 2, 3, 2, 1),
            (nf * 2, nf * 4, 3, 1, 1),
            (nf * 4, nf * 4, 3, 2, 1),
            (nf * 4, nf * 8, 3, 1, 1),
            (nf * 8, nf * 8, 3, 2, 1),
            (nf * 8, nf * 8, 3, 1, 1),
        ]
        layers = []
        for c_in, c_out, kernel, stride, pad in stages:
            layers += [
                nn.Conv2d(c_in, c_out, kernel, stride, pad),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        self.models = nn.Sequential(*layers)
        # Not used by forward(); kept so saved state_dicts keep the same keys.
        self.out_layer = nn.Linear(6 * nf * 4, 512)

    def forward(self, x, pretrain=False):
        # `pretrain` is accepted for call compatibility and ignored.
        features = self.models(x)
        return features.view(features.shape[0], -1)


class FeatureExtractorDecoder(nn.Sequential):
    """Decoder mirroring FeatureExtractor: flat vector -> (1, 64, 32) patch in [-1, 1].

    The input is viewed as (batch, -1, 4, 2) and upsampled 2x four times; the
    final Tanh bounds the output to [-1, 1].
    """

    def __init__(self, nf=32, num_res=2):
        super(FeatureExtractorDecoder, self).__init__()
        layers = []
        # Three Upsample-Conv-BN-ReLU stages halving the channel count...
        for c_in, c_out in ((nf * 8, nf * 4), (nf * 4, nf * 2), (nf * 2, nf)):
            layers.extend([
                nn.Upsample(scale_factor=2, mode='bilinear'),
                nn.Conv2d(c_in, c_out, 3, 1, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        # ...then a last upsample and a 7x7 projection to a single channel.
        layers.extend([
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(nf, 1, 7, 1, 3),
            nn.Tanh(),
        ])
        self.de_convs = nn.Sequential(*layers)

    def forward(self, input):
        grid = input.view(input.shape[0], -1, 4, 2)
        return self.de_convs(grid)


class InputRepresentation(nn.Module):
    """Project per-step features to ``hidden_dim`` and add learned position embeddings.

    ``forward(input, pos)``: ``input`` has ``input_dim`` features in its last
    dim and ``pos`` holds integer positions below ``max_len``. The sum is
    layer-normalised and passed through dropout with rate ``do_rate``.
    """

    def __init__(self, max_len, input_dim, hidden_dim, do_rate=0.1):
        super(InputRepresentation, self).__init__()
        self.rep_embedding = nn.Linear(input_dim, hidden_dim)
        self.pos_embedding = nn.Embedding(max_len, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.drop_out = nn.Dropout(do_rate)

    def forward(self, input, pos):
        summed = self.rep_embedding(input) + self.pos_embedding(pos)
        return self.drop_out(self.layer_norm(summed))

# dim, num_vars, groups, combine_groups, vq_dim

class UniSML(nn.Module):
    """CNN front-end plus BERT encoder over a sequence of spectrogram pieces.

    Every public method indexes its input as ``x[:, idx]`` (one piece per
    step, fed to FeatureExtractor). The input is presumably
    (batch, steps, 1, 64, 32), as built by the MuseData datasets — TODO
    confirm for other callers. A learned [CLS] vector is prepended, so encoder
    outputs have ``steps + 1`` positions.
    """

    def __init__(self, fe_config, bert_config, bs=64, proj_dim=64):
        super(UniSML, self).__init__()
        self.fe_config = fe_config
        self.bert_config = bert_config
        self.proj_dim = proj_dim
        # Flattened FeatureExtractor output size ((nf * 8) x 4 x 2 for a 64x32 piece).
        feat_dim = fe_config.nf * 8 * 8
        self.feat_extr = FeatureExtractor(fe_config.nf)
        self.feat_extr_decoder = FeatureExtractorDecoder(fe_config.nf)
        self.embedder = InputRepresentation(
                                            fe_config.max_len,
                                            feat_dim,
                                            bert_config.hidden_size,
                                            bert_config.hidden_dropout_prob)
        self.encoder = BertEncoder(bert_config)

        self.inv_embedder = nn.Linear(bert_config.hidden_size, feat_dim)
        self.diversity_weight = 0.1
        self.cl_loss_fn = InfoNCE(temperature=0.5)

        # Kept for compatibility. The methods below use the device the
        # parameters actually live on, so the model works after `.to(...)`.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # These are created alongside every other parameter; `module.to(device)`
        # moves them all together.
        self.cls_token = nn.Parameter(torch.ones([1, feat_dim], dtype=torch.float))
        self.mask_token = nn.Parameter(torch.randn([feat_dim], dtype=torch.float))

    def _param_device(self):
        """Device holding the model's parameters."""
        return self.cls_token.device

    # Checkpoints are loaded to CPU first; load_state_dict then copies them
    # onto whatever device each parameter is on, so GPU checkpoints also load
    # on CPU-only machines.
    def load_fe(self, path):
        self.feat_extr.load_state_dict(torch.load(path, map_location='cpu'))

    def load_bert(self, path):
        self.encoder.load_state_dict(torch.load(path, map_location='cpu'))

    def load_fe_dec(self, path):
        self.feat_extr_decoder.load_state_dict(torch.load(path, map_location='cpu'))

    def save_fe(self, path):
        torch.save(self.feat_extr.state_dict(), path)

    def save_bert(self, path):
        torch.save(self.encoder.state_dict(), path)

    def save_fe_dec(self, path):
        torch.save(self.feat_extr_decoder.state_dict(), path)

    def _extract_features(self, x):
        """[CLS] vector followed by one CNN feature vector per piece: (bs, steps + 1, feat_dim).

        Pieces go through the CNN one step at a time. In train mode each
        BatchNorm therefore sees batch statistics of a single step.
        """
        bs, length = x.shape[0], x.shape[1]
        feats = [self.cls_token.repeat(bs, 1)]
        feats.extend(self.feat_extr(x[:, idx]) for idx in range(length))
        return torch.stack(feats, dim=1)

    def _encode(self, feats):
        """Add position embeddings and run the BERT encoder; returns last hidden states."""
        bs, steps = feats.shape[0], feats.shape[1]
        pos = torch.arange(0, steps, device=feats.device).repeat(bs, 1)
        rep = self.embedder(feats, pos)
        return self.encoder(rep).last_hidden_state

    def mlm_pretrain(self, x, mask_rate=0.5):
        """Masked-piece reconstruction loss.

        Each non-[CLS] position is replaced by ``mask_token`` independently with
        probability ``mask_rate``. The encoder output is mapped back to feature
        space and decoded per piece. Returns the MSE between the decoded pieces
        and ``x`` over all pieces, masked or not.
        """
        x = x.to(self._param_device())
        feats = self._extract_features(x)
        bs, steps = feats.shape[0], feats.shape[1]

        masked = torch.rand(bs, steps, device=feats.device) < mask_rate
        masked[:, 0] = False  # never mask the [CLS] slot
        feats = torch.where(masked.unsqueeze(-1), self.mask_token.expand_as(feats), feats)

        rep = self._encode(feats)
        de_embd = self.inv_embedder(rep[:, 1:])
        recons = torch.stack(
            [self.feat_extr_decoder(de_embd[:, idx]) for idx in range(steps - 1)],
            dim=1,
        )
        return F.mse_loss(recons, x)

    def straight_forward(self, x, mode='vq'):
        """Encoder hidden states for ``x`` without masking: (bs, steps + 1, hidden_size).

        ``mode`` is accepted for call compatibility and ignored.
        """
        x = x.to(self._param_device())
        return self._encode(self._extract_features(x))

    def ae_forward(self, x):
        """MSE between ``x`` and its FeatureExtractor -> decoder reconstruction."""
        x = x.to(self._param_device())
        recon_x = self.feat_extr_decoder(self.feat_extr(x))
        return F.mse_loss(recon_x, x)

    def forward(self, x, mode='mlm', mask_rate=0.5):
        """Dispatch on ``mode``.

        - 'mlm' with ``mask_rate > 0`` returns the masked-reconstruction loss.
        - 'ae' returns the autoencoder loss.
        - Any other mode returns the encoder hidden states.
        - 'vq' with ``mask_rate > 0`` is not implemented.
        """
        if mode == 'mlm' and mask_rate > 0.:
            return self.mlm_pretrain(x, mask_rate=mask_rate)
        if mode == 'vq' and mask_rate > 0.:
            raise NotImplementedError("UniSML has no vq-wav2vec objective (mode='vq')")
        if mode == 'ae':
            return self.ae_forward(x)
        return self.straight_forward(x)

# dim, num_vars, groups, combine_groups, vq_dim
class ConfigFE:
    """Plain container for the feature-extractor hyper-parameters.

    UniSML reads ``nf`` and ``max_len``. ``num_vars``, ``groups``,
    ``combine_groups`` and ``vq_dim`` are stored for the commented-out
    vector quantizer.
    """

    def __init__(self, nf, num_vars, groups, combine_groups, vq_dim, max_len):
        settings = {
            'nf': nf,
            'num_vars': num_vars,
            'groups': groups,
            'combine_groups': combine_groups,
            'vq_dim': vq_dim,
            'max_len': max_len,
        }
        for key, value in settings.items():
            setattr(self, key, value)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

lr = 3e-4
epochs = 1000
fe_config = ConfigFE(16, 320, 2, True, 256, 40)
bert_config = BertConfig(hidden_size=256, num_attention_heads=4, num_hidden_layers=8,
                         intermediate_size=1024, hidden_dropout_prob=0.3)
uni_sml = UniSML(fe_config, bert_config).to(device)

opt = torch.optim.AdamW(uni_sml.parameters(), lr=lr)

# Train only the CNN autoencoder (FeatureExtractor -> FeatureExtractorDecoder)
# on single patches from vae_loader.
for e in tqdm(range(epochs)):
    epoch_losses = []
    uni_sml.train()

    for x in vae_loader:
        x = x.to(device)
        loss = uni_sml.ae_forward(x)

        opt.zero_grad()
        loss.backward()
        opt.step()

        epoch_losses.append(loss.item())

    mean_loss = np.mean(np.array(epoch_losses))
    if (e+1) % 50 == 0:
        print('Loss at %d epoch: %.5f' % (e + 1, mean_loss))

# torch.save does not create directories; make sure the target exists so a
# long run is not lost at the very end.
os.makedirs('ckpt', exist_ok=True)
uni_sml.save_fe('ckpt/ckpt_fe.pkl')
uni_sml.save_fe_dec('ckpt/ckpt_fe_dec.pkl')
torch.save(opt.state_dict(), 'ckpt/autoencoder_opt.pkl')

# Downstream

In [None]:
import os
# Root directory holding one sub-directory of audio files per genre name below.
ori_path = 'Data/genres_original'

genres = 'reggae pop rock hiphop metal country disco classical blues jazz'.split()

import torch
import torchaudio
import torchaudio.functional as AF
import torchaudio.transforms as AT

# Spectrogram front-end for the genre-classification clips.
n_fft = 1024            # FFT size in samples
win_length = None       # None: window length defaults to n_fft
hop_length = 512        # frame hop in samples
n_mels = 64             # number of mel channels
sample_rate = 22050     # rate (Hz) the mel filterbank is designed for
top_db = 80             # AmplitudeToDB cut-off below the peak, in dB

mel_spectrogram = AT.MelSpectrogram(
    sample_rate=sample_rate,
    n_mels=n_mels,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    power=2.0,              # power spectrogram
    center=True,
    pad_mode="reflect",
    onesided=True,
    mel_scale="htk",
    norm='slaney',
)
ampl2db = AT.AmplitudeToDB(top_db=top_db)


from tqdm.notebook import tqdm

# Genre name -> class id, numbered in the order listed.
genres_dict = dict(zip(
    ('reggae', 'pop', 'rock', 'hiphop', 'metal',
     'country', 'disco', 'classical', 'blues', 'jazz'),
    range(10),
))
# Load every clip, compute its log-mel spectrogram and pair it with its genre id.
mel_list = []
for g in tqdm(genres):
    for wav in os.listdir(os.path.join(ori_path, g)):
        wav_path = os.path.join(ori_path, g, wav)
        try:
            waveform, sr = torchaudio.load(wav_path)
        except RuntimeError:
            # Unreadable files are skipped.
            continue
        # The filterbank assumes `sample_rate` Hz; resample anything else.
        if sr != sample_rate:
            waveform = AF.resample(waveform, sr, sample_rate)
        melspect = ampl2db(mel_spectrogram(waveform))
        mel_list.append([melspect, torch.tensor(genres_dict[g])])

# Keep clips with at least n_frames frames, truncated to exactly n_frames so
# they can be stacked. Each label travels with its own spectrogram: pairing the
# filtered tensor with the unfiltered mel_list would shift every label after
# the first short clip.
n_frames = 1293
kept = [(elem[..., :n_frames], label)
        for elem, label in mel_list
        if elem.shape[-1] >= n_frames]

mels = torch.stack([elem for elem, _ in kept], dim=0)

# Scale the dB values by 1/100.
mels = mels.detach() / 100

new_mel_list = [[mel, label] for mel, (_, label) in zip(mels, kept)]

In [None]:
class MuseData(Dataset):
    """Labelled dataset that cuts each spectrogram into ``max_len`` pieces.

    ``data`` holds ``(spectrogram, label)`` pairs; the spectrogram's last two
    dims are (height, time). ``__getitem__`` returns ``(pieces, label)``,
    where ``pieces`` stacks ``max_len`` slices of width ``height // 2`` on a
    new leading dim.

    - In 'train' mode: random time-stretch by a factor in [0.9, 1.1], random
      crop, then a random gain in [0.9, 1.1].
    - In any other mode: the first ``max_len * height // 2`` frames.

    ``transform`` is stored but not applied.
    """

    def __init__(self, data, transform, mode='train', max_len=25):
        x_shape = []
        for position, (spec, _label) in enumerate(data):
            try:
                x_shape.append(spec.shape[-1])
            except ValueError:
                print(position)
        self.max_len = max_len
        self.data = data
        self.transform = transform
        self.mode = mode

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        spec, label = self.data[idx]
        height, full_width = spec.shape[-2], spec.shape[-1]
        span = self.max_len * height // 2

        if self.mode != 'train':
            mel = spec[..., :span]
        else:
            stretched_width = random.randint(int(full_width * 0.9), int(full_width * 1.1))
            stretched = VT.Resize((height, stretched_width))(spec)
            gain = random.uniform(0.9, 1.1)
            mel = (VT.RandomCrop((height, span))(stretched) * gain).detach()

        pieces = [mel[:, :, k * height // 2:(k + 1) * height // 2]
                  for k in range(self.max_len)]
        for piece in pieces:
            if piece.shape[-1] != height // 2:
                print('Error piece')
        return torch.stack(pieces, dim=0), label

import random
batch_size = 32
random.shuffle(new_mel_list)

# First 80 % of the shuffled clips train (with augmentation), the rest validate.
train_set = MuseData(new_mel_list[:int(0.8 * len(new_mel_list))], None, 'train')
val_set = MuseData(new_mel_list[int(0.8 * len(new_mel_list)):], None, 'val')

train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=0,
                          drop_last=True, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=0,
                        drop_last=False, shuffle=True)

In [None]:

class CLSF(nn.Module):
    """Linear 10-way classification head on ``config.hidden_size`` features."""

    def __init__(self, config):
        super(CLSF, self).__init__()
        head = nn.Linear(config.hidden_size, 10)
        self.model = nn.Sequential(head)

    def forward(self, x):
        return self.model(x)
lr = 1e-3
epochs = 2000

fe_config = ConfigFE(16, 320, 2, True, 256, 40)
bert_config = BertConfig(
    hidden_size=256,
    num_attention_heads=4,
    num_hidden_layers=8,
    intermediate_size=1024,
    hidden_dropout_prob=0.1,
)

# Pre-trained backbone: load the checkpoint, then freeze every parameter.
uni_sml = UniSML(fe_config, bert_config).to('cuda')
uni_sml.load_state_dict(torch.load('ckpt/uni_sml.pkl'))
uni_sml.requires_grad_(False)

class Pooler(nn.Module):
    """Dense layer plus tanh applied to whatever hidden states it is given.

    Token selection (e.g. taking position 0) is done by the caller.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.dense(hidden_states)
        return self.activation(projected)

pooler = Pooler(bert_config).to('cuda')
clsf = CLSF(bert_config).to('cuda')
crit = nn.CrossEntropyLoss()
# Backbone parameters are included too; while frozen they receive no gradients.
trainable = [*uni_sml.parameters(), *pooler.parameters(), *clsf.parameters()]
opt = torch.optim.AdamW(trainable, lr=lr, weight_decay=0.1)



In [None]:
def multi_acc(y_pred, y_test):
    """Return batch accuracy in percent, rounded to a whole number, as a 0-dim tensor.

    ``y_pred`` holds per-class scores with classes on dim 1; ``y_test`` holds
    the integer class ids.
    """
    predicted = torch.log_softmax(y_pred, dim=1).argmax(dim=1)
    hits = (predicted == y_test).float()
    return torch.round(hits.sum() / len(hits) * 100)


In [None]:
train_losses = []
acc_script = []
# When every backbone parameter is frozen, keep the backbone in eval mode.
# Train mode would still update its BatchNorm running statistics (and apply
# dropout) even though its weights never change.
backbone_frozen = not any(p.requires_grad for p in uni_sml.parameters())

for epoch in tqdm(range(epochs)):

    loss_script = []

    uni_sml.train(not backbone_frozen)
    pooler.train()
    clsf.train()

    total, correct = 0, 0
    for x, label in train_loader:
        x = x.to('cuda')
        label = label.to('cuda')

        rep = uni_sml(x, mode='straight')

        # The head already returns (batch, 10). No squeeze: it would drop the
        # batch dim for a batch of one.
        logit = clsf(pooler(rep[:, 0]))
        loss = crit(logit, label)

        opt.zero_grad()
        loss.backward()
        opt.step()

        loss_script.append(loss.item())
        pred = logit.detach().argmax(dim=1)
        total += label.size(0)
        correct += (pred == label).sum().item()

    accuracy = (correct / total) * 100
    print('train accuracy in %d epoch: %.4f' % ((epoch+1), accuracy))
    mean_loss = np.mean(np.array(loss_script))
    train_losses.append(mean_loss)

    with torch.no_grad():

        total, correct = 0, 0

        uni_sml.eval()
        pooler.eval()
        clsf.eval()

        for x, label in val_loader:
            x = x.to('cuda')
            label = label.to('cuda')

            rep = uni_sml(x, mode='straight')
            logit = clsf(pooler(rep[:, 0]))

            pred = logit.argmax(dim=1)
            total += label.size(0)
            correct += (pred == label).sum().item()

        accuracy = (correct / total) * 100
        print('valid accuracy in %d epoch: %.4f' % ((epoch+1), accuracy))
        print('-' * 30)

In [None]:
import json
def parse_annotation_file(json_file):
    """Load an annotation JSON file and index it by its top-level keys.

    For each key, returns a dict with:

    - ``track_id``: ``extra.songs_info.song_id``
    - ``split``: ``split``
    - ``labels``: ``y``
    """
    with open(json_file) as handle:
        examples = json.load(handle)
    return {
        song_id: {
            'track_id': examples[song_id]['extra']['songs_info']['song_id'],
            'split': examples[song_id]['split'],
            'labels': examples[song_id]['y'],
        }
        for song_id in examples
    }

# Annotation index keyed by the JSON's top-level song ids (built by parse_annotation_file).
infors = parse_annotation_file('clips_45seconds/emomusic.json')

In [None]:
# Spectrogram front-end for the emotion-regression clips (note the higher rate).
n_fft = 1024            # FFT size in samples
win_length = None       # None: window length defaults to n_fft
hop_length = 512        # frame hop in samples
n_mels = 64             # mel channels (spectrogram height)
sample_rate = 44100     # rate (Hz) the mel filterbank is designed for
top_db = 80             # AmplitudeToDB cut-off below the peak, in dB

mel_spectrogram = AT.MelSpectrogram(
    n_mels=n_mels,
    sample_rate=sample_rate,
    hop_length=hop_length,
    win_length=win_length,
    n_fft=n_fft,
    mel_scale="htk",
    norm='slaney',
    power=2.0,
    onesided=True,
    pad_mode="reflect",
    center=True,
)
ampl2db = AT.AmplitudeToDB(top_db=top_db)

# Log-mel spectrogram (scaled by 0.01) and annotation labels for every
# annotated, readable clip under 'emo music'.
mel_list = []
file_path = getListOfFiles('emo music')
for wav_path in tqdm(file_path):
    # The annotation key is the file's stem zero-padded to 4 characters
    # (e.g. '12.wav' -> '0012'). The stem is taken from the basename rather
    # than a fixed-offset slice, so the directory name does not matter.
    wav_id = os.path.splitext(os.path.basename(wav_path))[0].zfill(4)

    # Look up the label first so unannotated clips are never decoded.
    try:
        emotion = infors[wav_id]['labels']
    except KeyError:
        print(wav_id)
        continue
    try:
        waveform, sr = torchaudio.load(wav_path)
    except RuntimeError:
        print(wav_path)
        continue

    # The filterbank assumes `sample_rate` Hz; resample anything else.
    if sr != sample_rate:
        waveform = AF.resample(waveform, sr, sample_rate)
    melspect = ampl2db(mel_spectrogram(waveform)).detach() * 0.01
    mel_list.append([melspect, emotion])

class MuseData(Dataset):
    """Two-target regression dataset that cuts each spectrogram into ``max_len`` pieces.

    ``data`` holds ``(spectrogram, label)`` pairs, where ``label[0]`` and
    ``label[1]`` are the two regression targets. The spectrogram's last two
    dims are (height, time). ``__getitem__`` returns
    ``(pieces, tensor(label[0]), tensor(label[1]))``, where ``pieces`` stacks
    ``max_len`` slices of width ``height // 2`` on a new leading dim.

    - In 'train' mode: random time-stretch by a factor in [0.9, 1.1], random
      crop, then a random gain in [0.9, 1.1].
    - In any other mode: the first ``max_len * height // 2`` frames.

    ``transform`` is stored but not applied.

    Raises:
        ValueError: at construction if an item is not a pair; on access if the
            clip is too short to fill ``max_len`` pieces.
    """

    def __init__(self, data, transform=None, mode='train', max_len=25):
        for idx, item in enumerate(data):
            if len(item) != 2:
                raise ValueError('item %d is not a (spectrogram, label) pair' % idx)
        self.max_len = max_len
        self.data = data
        self.transform = transform
        self.mode = mode

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data, label = self.data[idx]
        ori_height, ori_width = data.shape[-2], data.shape[-1]
        piece_width = ori_height // 2
        span = (self.max_len * ori_height) // 2

        if self.mode == 'train':
            width = random.randint(int(ori_width * 0.9), int(ori_width * 1.1))
            scaler = VT.Resize((ori_height, width))
            cropper = VT.RandomCrop((ori_height, span))
            brightness = random.uniform(0.9, 1.1)
            mel = (cropper(scaler(data)) * brightness).detach()
        else:
            mel = data[..., :span]

        pieces = [mel[..., i * ori_height // 2:(i + 1) * ori_height // 2]
                  for i in range(self.max_len)]
        short = [i for i, piece in enumerate(pieces) if piece.shape[-1] != piece_width]
        if short:
            raise ValueError('item %d: clip has %d frames, need %d for %d pieces'
                             % (idx, mel.shape[-1], span, self.max_len))
        return torch.stack(pieces, dim=0), torch.tensor(label[0]), torch.tensor(label[1])

import random
batch_size = 32
random.shuffle(mel_list)

# First 80 % of the shuffled clips train (with augmentation), the rest validate.
train_set = MuseData(mel_list[:int(0.8 * len(mel_list))], None, 'train')
val_set = MuseData(mel_list[int(0.8 * len(mel_list)):], None, 'val')

train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=0,
                          drop_last=True, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=0,
                        drop_last=False, shuffle=True)



In [None]:
class CLSF(nn.Module):
    """Two independent scalar linear heads (valence, arousal) on the same input."""

    def __init__(self, config):
        super().__init__()
        self.prob_valence = nn.Linear(config.hidden_size, 1)
        self.prob_arousal = nn.Linear(config.hidden_size, 1)

    def forward(self, x):
        valence = self.prob_valence(x)
        arousal = self.prob_arousal(x)
        return valence, arousal

lr = 1e-4
epochs = 2000

fe_config = ConfigFE(16, 320, 2, True, 256, 40)
bert_config = BertConfig(
    hidden_size=256,
    num_attention_heads=4,
    num_hidden_layers=8,
    intermediate_size=1024,
    hidden_dropout_prob=0.1,
)

# Pre-trained backbone: load the checkpoint, then freeze every parameter.
uni_sml = UniSML(fe_config, bert_config).to('cuda')
uni_sml.load_state_dict(torch.load('ckpt/uni_sml.pkl'))
uni_sml.requires_grad_(False)

class Pooler(nn.Module):
    """tanh(Linear(hidden_states)) with a square ``hidden_size`` projection.

    Which tokens are pooled (e.g. a mean over positions) is decided by the caller.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.activation(self.dense(hidden_states))

pooler = Pooler(bert_config).to('cuda')
clsf = CLSF(bert_config).to('cuda')
crit = nn.MSELoss()
# Backbone parameters are included too; while frozen they receive no gradients.
trainable = [*uni_sml.parameters(), *pooler.parameters(), *clsf.parameters()]
opt = torch.optim.AdamW(trainable, lr=lr, weight_decay=0.1)

In [None]:
import torchmetrics
from torchmetrics import R2Score

# R² metric on the GPU; called below as scorer(predictions, targets) on each
# epoch's concatenated valence/arousal predictions.
scorer = R2Score().to('cuda')

In [None]:
train_losses = []
acc_script = []
# When every backbone parameter is frozen, keep the backbone in eval mode.
# Train mode would still update its BatchNorm running statistics (and apply
# dropout) even though its weights never change.
backbone_frozen = not any(p.requires_grad for p in uni_sml.parameters())


def _run_epoch(loader, train):
    """One pass over ``loader``, collecting losses, predictions and targets.

    With ``train`` true, pooler/clsf (and the backbone, if not frozen) are
    updated with ``opt``. Otherwise every module is in eval mode and autograd
    is off. Returns ``(losses, logit_v, label_v, logit_a, label_a)``. The last
    four are 1-D tensors concatenated over this pass only, so train and
    validation totals never mix.
    """
    uni_sml.train(train and not backbone_frozen)
    pooler.train(train)
    clsf.train(train)

    losses = []
    logits_v, labels_v, logits_a, labels_a = [], [], [], []
    with torch.set_grad_enabled(train):
        for x, label_v, label_a in loader:
            x = x.to('cuda')
            # MSELoss needs floating targets (integer JSON labels would be int64).
            label_v = label_v.to('cuda').float()
            label_a = label_a.to('cuda').float()

            rep = uni_sml(x, mode='straight')
            # Mean over the piece positions; position 0 is the [CLS] slot.
            logit_v, logit_a = clsf(pooler(rep[:, 1:].mean(1)))
            logit_v, logit_a = logit_v.squeeze(-1), logit_a.squeeze(-1)
            loss = crit(logit_v, label_v) + crit(logit_a, label_a)

            if train:
                opt.zero_grad()
                loss.backward()
                opt.step()

            losses.append(loss.item())
            logits_v.append(logit_v.detach())
            labels_v.append(label_v)
            logits_a.append(logit_a.detach())
            labels_a.append(label_a)

    return (losses, torch.cat(logits_v), torch.cat(labels_v),
            torch.cat(logits_a), torch.cat(labels_a))


for epoch in tqdm(range(epochs)):
    loss_script, total_logit_v, total_label_v, total_logit_a, total_label_a = \
        _run_epoch(train_loader, train=True)

    r2_v = scorer(total_logit_v, total_label_v)
    r2_a = scorer(total_logit_a, total_label_a)

    mean_loss = np.mean(np.array(loss_script))
    train_losses.append(mean_loss)

    if (epoch+1) % 5 == 0:
        print('train loss in %d epoch: %.4f' % ((epoch+1), mean_loss))
        print('r2 score valence: %.4f' % r2_v)
        print('r2 score arousal: %.4f' % r2_a)

    # Validation: a fresh pass over the held-out loader.
    _, total_logit_v, total_label_v, total_logit_a, total_label_a = \
        _run_epoch(val_loader, train=False)

    r2_v = scorer(total_logit_v, total_label_v)
    r2_a = scorer(total_logit_a, total_label_a)

    if (epoch+1) % 5 == 0:
        print('r2 score in %d epoch: %.4f, %.4f' % ((epoch+1), r2_v, r2_a))
        print('-' * 30)