In [None]:
# Colab only: mount Google Drive and switch into the project folder so that the
# relative paths used throughout the notebook (config.yaml, data/, log/, model/)
# resolve against the project root.
from google.colab import drive
drive.mount('/content/gdrive')
%cd /content/gdrive/MyDrive/si699-music-tagging/

In [None]:
# Run the project's preprocessing script (preprocessing/convert_npy.py, not shown here).
# NOTE(review): presumably this produces the .npy waveforms under data/waveform that
# get_tags / MyDataset read below — confirm against the script before relying on it.
!python3 preprocessing/convert_npy.py

In [None]:
import os
import random
import torch
import librosa
from librosa import display
import numpy as np
import glob
import torchaudio
from sklearn.preprocessing import LabelBinarizer
import csv
from transformers import AutoConfig, AutoFeatureExtractor, Wav2Vec2FeatureExtractor
from sklearn.metrics import *
# from run.models import *
from run.attention_modules import *
import collections
import torch
import yaml
import json
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from transformers import BertModel, BertTokenizer
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2PreTrainedModel,
    Wav2Vec2Model
)
import logging
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.functional.classification import multilabel_auroc
from torchmetrics.classification import MultilabelPrecision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Seed every RNG the notebook relies on. torch.manual_seed covers CPU and all
# CUDA devices; Python's `random` drives file shuffling; NumPy was previously
# left unseeded.
SEED = config['seed']
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
print("Run on:", device)

In [None]:
# Inspect one raw clip: list the mood/theme mp3s and plot the first waveform.
data_root = "data/autotagging_moodtheme/0*/*.mp3"
files = sorted(glob.glob(data_root))
print("Size:", len(files))
# Fail with an actionable message instead of an IndexError on files[0].
assert files, "No .mp3 files matched {!r}; check the Drive mount / working directory.".format(data_root)

# sr=None keeps the file's native sampling rate; mono=True downmixes to one channel.
waveform, sr = librosa.load(files[0], sr=None, mono=True)

fig, ax = plt.subplots(figsize=(15, 3))
display.waveshow(y=waveform, sr=sr, ax=ax)
ax.set(xlabel="Time (s)", ylabel="Amplitude", title=os.path.basename(files[0]))
plt.show()

In [None]:
# Mel spectrogram of the clip loaded above. specshow only *labels* the mel axis,
# so the fmax it is given must equal the fmax used to build the filterbank;
# before, the spectrogram used the default (sr / 2) while the plot claimed 8000 Hz.
MEL_FMAX = 8000
S = librosa.feature.melspectrogram(y=waveform, sr=sr, fmax=MEL_FMAX)
S_dB = librosa.power_to_db(S, ref=np.max)

fig, ax = plt.subplots(figsize=(18, 3))
img = librosa.display.specshow(S_dB, x_axis='time', y_axis='mel', sr=sr,
                               fmax=MEL_FMAX, ax=ax)
fig.colorbar(img, ax=ax, format='%+2.0f dB')
ax.set(title='Mel-frequency spectrogram');

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Waveform + title dataset for multi-label tag classification.

    Each item is ``(features, input_ids, attention_mask, target)``:
    ``features`` is built from a ``.npy`` waveform under ``npy_root`` (raw, or
    passed through a Hugging Face feature extractor), ``input_ids`` /
    ``attention_mask`` are the BERT-tokenized track title, and ``target`` is a
    multi-hot vector over ``tags`` (columns in sorted tag order, as produced by
    ``LabelBinarizer``).

    Args:
        tracks_dict: maps ``"<subdir>/<track>"`` ids to ``(tag_list, title)``.
        npy_root: directory containing ``<subdir>/<track>.npy`` files.
        config: dict; ``'seed'`` fixes the train/valid split, ``'sample_rate'``
            and ``'n_mels'`` are used by the feature extractors.
        tags: full tag vocabulary.
        data_type: ``'train'`` (first 80% of the shuffled files) or
            ``'valid'`` (the remaining 20%); anything else yields no items.
        feature_extractor_type: ``'raw'``, ``'ast'`` or ``'wav2vec'``.
    """

    TRAIN_FRACTION = 0.8

    def __init__(self, tracks_dict, npy_root, config, tags, data_type, feature_extractor_type):
        self.npy_root = npy_root
        self.config = config
        self.tracks_dict = tracks_dict
        self.tags = tags
        self.data_type = data_type
        self.feature_extractor_type = feature_extractor_type
        # LabelBinarizer sorts its classes, so label columns follow sorted(tags).
        self.mlb = LabelBinarizer().fit(self.tags)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self._feature_extractor = None  # loaded lazily, once, by _get_feature_extractor
        self.title_dict = {}
        self.prepare_title()
        self.data = []
        self.input_ids = []
        self.attention_mask = []
        self.labels = []
        self.prepare_data()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        assert 0 <= index < len(self)
        waveform = self.data[index]
        target = self.labels[index]
        if self.feature_extractor_type == 'raw':
            features = torch.Tensor(waveform)
        elif self.feature_extractor_type == 'ast':
            # NOTE(review): `annotations=` is forwarded as in the original; confirm the
            # AST extractor actually uses it.
            encoding = self._get_feature_extractor()(
                waveform, sampling_rate=self.config['sample_rate'],
                annotations=target, return_tensors="pt")
            features = torch.transpose(encoding['input_values'].squeeze(), 0, 1)
        elif self.feature_extractor_type == 'wav2vec':
            encoding = self._get_feature_extractor()(
                waveform, sampling_rate=self.config['sample_rate'], return_tensors="pt")
            features = encoding['input_values'].squeeze()
        else:
            raise ValueError("Unknown feature_extractor_type: {!r}".format(self.feature_extractor_type))
        return features, self.input_ids[index], self.attention_mask[index], target

    def _get_feature_extractor(self):
        """Load the Hugging Face feature extractor on first use and reuse it.

        Previously it was re-created with ``from_pretrained`` for every item.
        """
        extractor = getattr(self, '_feature_extractor', None)
        if extractor is None:
            if self.feature_extractor_type == 'ast':
                extractor = AutoFeatureExtractor.from_pretrained(
                    "MIT/ast-finetuned-audioset-10-10-0.4593",
                    sampling_rate=self.config['sample_rate'],
                    num_mel_bins=self.config['n_mels']
                )
            else:
                extractor = Wav2Vec2FeatureExtractor.from_pretrained(
                    "facebook/wav2vec2-base-960h"
                )
            self._feature_extractor = extractor
        return extractor

    @staticmethod
    def _file_id(filename):
        """'<root>/<subdir>/<track>.npy' -> '<subdir>/<track>' (the tracks_dict key format)."""
        parts = filename.split('/')
        return os.path.join(parts[-2], parts[-1].split('.')[0])

    def prepare_title(self):
        """Tokenize all known titles in one padded batch; cache (input_ids, attention_mask) per file id.

        ``.npy`` files with no ``tracks_dict`` entry are skipped (prepare_data
        drops them as well) instead of raising KeyError.
        """
        whole_filenames = sorted(glob.glob(os.path.join(self.npy_root, "*/*.npy")))
        file_ids = [self._file_id(name) for name in whole_filenames]
        file_ids = [fid for fid in file_ids if fid in self.tracks_dict]
        if not file_ids:
            return
        titles = [self.tracks_dict[fid][1] for fid in file_ids]
        encoding = self.tokenizer(titles, return_tensors='pt', padding=True, truncation=True)
        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']
        for idx, fid in enumerate(file_ids):
            self.title_dict[fid] = (input_ids[idx], attention_mask[idx])

    def _split_filenames(self, whole_filenames):
        """Return this split's share of ``whole_filenames``, deterministically.

        A dedicated RNG seeded from ``config['seed']`` gives the train and the
        valid instance the same permutation, so their slices are disjoint and
        together cover every file. Shuffling with the global ``random`` module
        gave each instance a different permutation (the global state advances
        between constructions), which leaked validation files into training.
        """
        shuffled = list(whole_filenames)
        random.Random(self.config.get('seed', 0)).shuffle(shuffled)
        train_size = int(len(shuffled) * self.TRAIN_FRACTION)
        if self.data_type == 'train':
            return shuffled[:train_size]
        if self.data_type == 'valid':
            return shuffled[train_size:]
        return []

    def prepare_data(self):
        """Load this split's waveforms, cached title tokens and multi-hot labels into memory."""
        whole_filenames = sorted(glob.glob(os.path.join(self.npy_root, "*/*.npy")))
        for filename in tqdm(self._split_filenames(whole_filenames)):
            file_id = self._file_id(filename)
            if file_id not in self.tracks_dict:
                print(file_id)
                continue
            self.data.append(np.load(filename))
            self.input_ids.append(self.title_dict[file_id][0])
            self.attention_mask.append(self.title_dict[file_id][1])
            # One-hot per tag, summed over the track's tags -> multi-hot vector.
            self.labels.append(np.sum(self.mlb.transform(self.tracks_dict[file_id][0]), axis=0))

In [None]:
def get_tags(tag_file, npy_root, isMap, meta_file='data/raw.meta.tsv',
             categorize_file='tag_categorize.json', plot=True):
    """Read per-track tags and titles for every track that has a ``.npy`` file.

    Args:
        tag_file: tab-separated tag file with a header row; column 0 is the
            track id, column 3 the ``.mp3`` path relative to the dataset root,
            columns 5+ are ``"<prefix>---<tag>"`` entries.
        npy_root: root containing the ``.npy`` counterparts of those paths;
            rows without one are printed and skipped.
        isMap: if True, collapse each raw tag to its category from
            ``categorize_file`` (KeyError for a tag not listed there).
        meta_file: tab-separated metadata with a header row; column 0 is the
            track id and column 3 the title.
        categorize_file: JSON mapping ``category -> "[tag1, tag2, ...]"``.
        plot: if True, save a histogram of tag counts to ``dist.png``.

    Returns:
        ``(tracks, tags)`` where ``tracks`` maps ``"<subdir>/<track>"`` to
        ``(unique_tag_list, title)`` and ``tags`` is the sorted list of
        distinct tags. Sorting makes the order reproducible across runs
        (``set`` order depends on string hash randomisation) and matches the
        sorted column order of a ``LabelBinarizer`` fitted on it.
    """
    id2title_dict = {}
    with open(meta_file) as fp:
        reader = csv.reader(fp, delimiter='\t')
        next(reader, None)  # skip header
        for row in reader:
            id2title_dict[row[0]] = row[3]

    categorize = {}
    if isMap:
        with open(categorize_file) as f:
            data = json.load(f)
        for k, v in data.items():
            # Each value is a bracketed, comma-separated string: "[a, b, c]".
            for i in v[1:-1].split(', '):
                categorize[i] = k

    tracks = {}
    total_tags = []
    with open(tag_file) as fp:
        reader = csv.reader(fp, delimiter='\t')
        next(reader, None)  # skip header
        for row in reader:
            npy_path = os.path.join(npy_root, row[3].replace('.mp3', '.npy'))
            if not os.path.exists(npy_path):
                print(npy_path)
                continue
            track_id = row[3].split('.')[0]
            raw_tags = [tag.split('---')[-1] for tag in row[5:]]
            if isMap:
                track_tags = list({categorize[t] for t in raw_tags})
            else:
                track_tags = list(set(raw_tags))
            tracks[track_id] = (track_tags, id2title_dict[row[0]])
            total_tags += track_tags

    print("Distribution of tags:", collections.Counter(total_tags))
    if plot:
        plt.figure(figsize=(10, 3))
        plt.xticks(rotation=90)
        plt.hist(total_tags)
        plt.savefig('dist.png')
    return tracks, sorted(set(total_tags))

In [None]:
print("Preparing dataset...")
npy_root = 'data/waveform'
tag_file = 'data/autotagging_moodtheme.tsv'
# isMap=True collapses raw mood/theme tags into the categories listed in
# tag_categorize.json before building the label vocabulary.
tracks_dict, tags = get_tags(tag_file, npy_root, isMap=True)
N_CLASSES = len(tags)
print(N_CLASSES)

In [None]:
transform = 'raw'
batch_size = 4
train_dataset = MyDataset(tracks_dict, npy_root, config, tags, "train", transform)
val_dataset = MyDataset(tracks_dict, npy_root, config, tags, "valid", transform)
# Reshuffle training batches every epoch (the sampler draws from torch's RNG,
# which is seeded in the config cell); keep the validation order fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Peek at one validation batch (does nothing if the loader is empty).
first_batch = next(iter(val_loader), None)
if first_batch is not None:
    waveform, input_ids, attention_mask, label = first_batch
    for part in (waveform, input_ids, attention_mask, label):
        print(part)

In [None]:
def train(model, epoch, criterion, optimizer, train_loader, is_title=False):
    """Run one training epoch, then log epoch metrics via get_eval_metrics.

    Args:
        model: network whose output is fed straight to ``criterion``.
        epoch: epoch index used for logging.
        criterion: loss taking ``(output, float_labels)``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: yields ``(waveform, input_ids, attention_mask, label)``.
        is_title: if True, also pass the title tokens to the model.
    """
    losses = []
    ground_truth = []
    prediction = []
    model.train()
    for waveform, input_ids, attention_mask, label in tqdm(train_loader):
        waveform, label = waveform.to(device), label.to(device)
        input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
        if is_title:
            output = model(waveform, input_ids, attention_mask)
        else:
            output = model(waveform)
        loss = criterion(output, label.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().cpu())
        ground_truth.append(label)
        # Store a detached copy: keeping `output` itself kept each batch's whole
        # autograd graph alive until the end of the epoch, so memory grew per batch.
        prediction.append(output.detach())
    get_eval_metrics(prediction, ground_truth, 'train', epoch, losses)


@torch.no_grad()
def validate(model, epoch, criterion, val_loader, is_title=False):
    """Evaluate ``model`` on ``val_loader`` without gradients and log metrics.

    Returns whatever get_eval_metrics returns for the 'val' run, which the
    training loop uses for checkpoint selection.
    """
    model.eval()
    batch_losses, targets, scores = [], [], []
    for waveform, input_ids, attention_mask, label in tqdm(val_loader):
        waveform = waveform.to(device)
        label = label.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        output = model(waveform, input_ids, attention_mask) if is_title else model(waveform)
        batch_losses.append(criterion(output, label.float()).cpu().detach())
        targets.append(label)
        scores.append(output)
    return get_eval_metrics(scores, targets, 'val', epoch, batch_losses)


def get_eval_metrics(outputs, labels, run_type, epoch, losses):
    """Concatenate per-batch results, compute metrics, and log them.

    Metrics:
      * avg percent: for each sample with k true tags, mark its k highest
        scores as predicted; report matched tags / total true tags.
      * auroc: macro multilabel AUROC.
      * pre: macro multilabel precision at a 0.5 decision threshold
        (precision, not average precision).

    Args:
        outputs: list of per-batch score tensors, each (batch, N_CLASSES).
        labels: list of matching multi-hot label tensors.
        run_type: key suffix for TensorBoard / log lines ('train' or 'val' here).
        epoch: step for the TensorBoard scalars.
        losses: per-batch loss values; their mean is logged.

    Returns:
        avg percent as a 0-dim tensor (NaN when there are no true tags).
    """
    outputs = torch.cat(outputs, dim=0).detach()
    labels = torch.cat(labels, dim=0)
    assert outputs.shape == labels.shape

    # 1. number of correctly predicted tags divided by the total number of tags
    prob_classes = torch.zeros_like(outputs)
    for i in range(labels.size(0)):
        k = int(labels[i].sum())  # topk wants a Python int
        _, idx = outputs[i].topk(k=k)
        prob_classes[i, idx] = 1
    matched_1s = torch.mul(prob_classes, labels)
    correct_tag_percentage = matched_1s.sum() / labels.sum()

    # 2. Auroc
    auroc = multilabel_auroc(outputs, labels, num_labels=N_CLASSES, average="macro", thresholds=None).item()

    # 3. Macro precision at threshold 0.5. MultilabelPrecision's parameter is
    # `threshold`; an unknown `thresholds=` kwarg is rejected by recent
    # torchmetrics. Put the metric on the inputs' device so CPU callers work too.
    metric = MultilabelPrecision(num_labels=N_CLASSES, threshold=0.5, average='macro').to(outputs.device)
    pre = metric(outputs, labels).item()

    # write tensorboard and logging file
    mean_loss = np.mean(losses)
    writer.add_scalar("Loss/{}".format(run_type), mean_loss, epoch)
    writer.add_scalar("Auroc/{}".format(run_type), auroc, epoch)
    writer.add_scalar("Pre/{}".format(run_type), pre, epoch)
    writer.add_scalar("Avg_percent/{}".format(run_type), correct_tag_percentage, epoch)
    message = "{} - epoch: {}, loss: {}, auroc: {}, pre: {}, avg percent: {}".format(
        run_type, epoch, mean_loss, auroc, pre, correct_tag_percentage)
    print(message)
    logging.info(message)
    return correct_tag_percentage


def get_model(model_name, tags):
    """Build the model named ``model_name`` with N_CLASSES outputs and move it to ``device``.

    ``tags`` is only used by 'wav2vec' (label <-> id maps). Any unrecognised
    name falls back to SampleCNN. Builders are lambdas so a model class is
    looked up only when it is actually requested.
    """
    def build_wav2vec():
        model_config = AutoConfig.from_pretrained(
            "facebook/wav2vec2-base-960h",
            num_labels=N_CLASSES,
            label2id={label: i for i, label in enumerate(tags)},
            id2label={i: label for i, label in enumerate(tags)},
            finetuning_task="wav2vec2_clf",
        )
        return Wav2Vec2ForSpeechClassification(model_config)

    builders = {
        'samplecnn': lambda: SampleCNN(N_CLASSES, config),
        'crnn': lambda: CRNN(N_CLASSES, config),
        'fcn': lambda: FCN(N_CLASSES, config),
        'musicnn': lambda: Musicnn(N_CLASSES, config),
        'musicnn_title': lambda: MusicnnwithTitle(N_CLASSES, config),
        'shortchunkcnn_res': lambda: ShortChunkCNN_Res(N_CLASSES, config),
        'cnnsa': lambda: CNNSA(N_CLASSES, config),
        'baseline2': lambda: Baseline2(N_CLASSES, config),
        'wav2vec': build_wav2vec,
    }
    build = builders.get(model_name, lambda: SampleCNN(N_CLASSES, config))
    return build().to(device)

In [None]:
class MusicnnwithTitle(nn.Module):
    """Musicnn-style audio tagger fused with a BERT encoding of the track title.

    Audio branch: log-mel spectrogram -> five parallel front-end convs
    (Conv_V / Conv_H) -> three residual Conv_1d back-end layers -> global
    max+avg pooling -> dense layers -> ``num_classes`` scores.
    Title branch: ``bert-base-uncased`` pooled output -> dense3 -> dropout ->
    dense4 -> ``num_classes`` scores.
    Both score vectors are concatenated, mixed by ``dense5`` and passed
    through a sigmoid, giving per-tag probabilities.
    """
    def __init__(self, num_classes, config=None):
        # Was `super(Musicnn, self)`: this class does not subclass Musicnn, so
        # that call raised TypeError before any layer was created.
        super(MusicnnwithTitle, self).__init__()
        self.spec = torchaudio.transforms.MelSpectrogram(sample_rate=config['sample_rate'],
                                                  n_fft=config['n_fft'],
                                                  f_min=config['fmin'],
                                                  f_max=config['fmax'],
                                                  n_mels=config['n_mels'])

        # Spectrogram
        self.to_db = torchaudio.transforms.AmplitudeToDB()
        self.spec_bn = nn.BatchNorm2d(1)

        # Pons front-end
        # NOTE(review): kernel heights are derived from 96 mel bins; presumably
        # config['n_mels'] == 96 — confirm.
        m1 = Conv_V(1, 204, (int(0.7 * 96), 7))
        m2 = Conv_V(1, 204, (int(0.4 * 96), 7))
        m3 = Conv_H(1, 51, 129)
        m4 = Conv_H(1, 51, 65)
        m5 = Conv_H(1, 51, 33)
        self.layers = nn.ModuleList([m1, m2, m3, m4, m5])

        # Pons back-end (561 = 204 + 204 + 51 * 3 concatenated front-end channels)
        backend_channel = 512
        self.layer1 = Conv_1d(561, backend_channel, kernel_size=7, stride=1, padding=3, pooling=1)
        self.layer2 = Conv_1d(backend_channel, backend_channel, kernel_size=7, stride=1, padding=3, pooling=1)
        self.layer3 = Conv_1d(backend_channel, backend_channel, kernel_size=7, stride=1, padding=3, pooling=1)

        # Dense: input is [front-end, res1, res2, res3] channels, max- and avg-pooled (x2)
        dense_channel = 500
        self.dense1 = nn.Linear((561 + (backend_channel * 3)) * 2, dense_channel)
        self.bn = nn.BatchNorm1d(dense_channel)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.dense2 = nn.Linear(dense_channel, num_classes)

        # Title branch
        self.bert = BertModel.from_pretrained('bert-base-uncased', return_dict=True)
        self.dense3 = nn.Linear(self.bert.config.hidden_size, 256)
        self.dense4 = nn.Linear(256, num_classes)

        # Fusion of audio and title scores
        self.dense5 = nn.Linear(2*num_classes, num_classes)

    def forward(self, x, input_ids=None, attention_mask=None):
        """Return per-tag probabilities of shape (batch, num_classes).

        ``x`` is the raw waveform batch fed to MelSpectrogram. ``input_ids`` and
        ``attention_mask`` are the tokenized titles; despite the None defaults
        they are required, because the title branch always runs.
        """
        # Spectrogram
        x = self.spec(x)
        x = self.to_db(x)
        x = x.unsqueeze(1)
        x = self.spec_bn(x)

        # Pons front-end
        out = torch.cat([layer(x) for layer in self.layers], dim=1)

        # Pons back-end
        length = out.size(2)
        res1 = self.layer1(out)
        res2 = self.layer2(res1) + res1
        res3 = self.layer3(res2) + res2
        out = torch.cat([out, res1, res2, res3], 1)

        # Global max and average pooling over the whole time axis
        mp = F.max_pool1d(out, length)
        avgp = F.avg_pool1d(out, length)

        out = torch.cat([mp, avgp], dim=1)
        out = out.squeeze(2)

        out = self.relu(self.bn(self.dense1(out)))
        out = self.dropout(out)
        out = self.dense2(out)

        out_title = self.bert(input_ids, attention_mask=attention_mask)
        out_title = self.dense3(out_title.pooler_output)
        out_title = self.dropout(out_title)
        out_title = self.dense4(out_title)

        out = torch.cat((out_title, out), dim=1)
        out = self.dense5(out)
        return torch.sigmoid(out)

class Res_2d(nn.Module):
    """Residual 2-D conv block: conv-BN-ReLU-conv-BN plus a shortcut, then ReLU.

    If the block changes resolution (``stride != 1``) or width
    (``input_channels != output_channels``), the shortcut is a strided
    conv + BN projection (``conv_3`` / ``bn_3``) and ``diff`` is True;
    otherwise the shortcut is the identity and ``diff`` is False.
    """

    def __init__(self, input_channels, output_channels, shape=3, stride=2):
        super(Res_2d, self).__init__()
        pad = shape // 2
        # Main path. Creation order (conv_1, bn_1, conv_2, bn_2, then the
        # projection) fixes parameter-init order and state_dict layout.
        self.conv_1 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=pad)
        self.bn_1 = nn.BatchNorm2d(output_channels)
        self.conv_2 = nn.Conv2d(output_channels, output_channels, shape, padding=pad)
        self.bn_2 = nn.BatchNorm2d(output_channels)

        needs_projection = stride != 1 or input_channels != output_channels
        self.diff = needs_projection
        if needs_projection:
            self.conv_3 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=pad)
            self.bn_3 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.bn_1(self.conv_1(x)))
        hidden = self.bn_2(self.conv_2(hidden))
        shortcut = self.bn_3(self.conv_3(x)) if self.diff else x
        return self.relu(shortcut + hidden)

class CNNSA(nn.Module):
    '''
    Won et al. 2019
    Toward interpretable music tagging with self-attention.
    Feature extraction with CNN + temporal summary with Transformer encoder.

    Pipeline: waveform -> log-mel spectrogram -> seven Res_2d blocks (the
    last four stride only along frequency) -> time sequence with a fixed,
    non-trainable [CLS] vector prepended -> BERT-style encoder -> pooler ->
    dropout -> linear -> sigmoid per-tag probabilities.
    '''
    def __init__(self, n_class, config=None):
        super(CNNSA, self).__init__()
        self.spec = torchaudio.transforms.MelSpectrogram(sample_rate=config['sample_rate'],
                                                  n_fft=config['n_fft'],
                                                  f_min=config['fmin'],
                                                  f_max=config['fmax'],
                                                  n_mels=config['n_mels'])

        # Spectrogram
        self.to_db = torchaudio.transforms.AmplitudeToDB()
        self.spec_bn = nn.BatchNorm2d(1)

        # CNN
        n_channels = 128
        self.layer1 = Res_2d(1, n_channels, stride=2)
        self.layer2 = Res_2d(n_channels, n_channels, stride=2)
        self.layer3 = Res_2d(n_channels, n_channels * 2, stride=2)
        self.layer4 = Res_2d(n_channels * 2, n_channels * 2, stride=(2, 1))
        self.layer5 = Res_2d(n_channels * 2, n_channels * 2, stride=(2, 1))
        self.layer6 = Res_2d(n_channels * 2, n_channels * 2, stride=(2, 1))
        self.layer7 = Res_2d(n_channels * 2, n_channels * 2, stride=(2, 1))

        # Transformer encoder (hidden_size matches the 256 CNN output channels)
        bert_config = BertConfig(vocab_size=256,
                                 hidden_size=256,
                                 num_hidden_layers=2,
                                 num_attention_heads=8,
                                 intermediate_size=1024,
                                 hidden_act="gelu",
                                 hidden_dropout_prob=0.4,
                                 max_position_embeddings=700,
                                 attention_probs_dropout_prob=0.5)
        self.encoder = BertEncoder(bert_config)
        self.pooler = BertPooler(bert_config)
        # Plain tensor (not a Parameter/buffer): not trained, not in state_dict.
        self.vec_cls = self.get_cls(256)

        # Dense
        self.dropout = nn.Dropout(0.5)
        self.dense = nn.Linear(256, n_class)

    def get_cls(self, channel):
        """Build the fixed [CLS] vectors: 64 identical rows, shape (64, 1, channel).

        Values come from a private ``RandomState(0)``. They are identical to
        what seeding the global NumPy RNG with 0 produces, but constructing the
        model no longer resets global NumPy random state for the rest of the
        notebook.
        """
        single_cls = torch.Tensor(np.random.RandomState(0).random_sample((1, channel)))
        vec_cls = torch.cat([single_cls for _ in range(64)], dim=0)
        vec_cls = vec_cls.unsqueeze(1)
        return vec_cls

    def append_cls(self, x):
        """Prepend the [CLS] vector to every sequence: (batch, T, C) -> (batch, T + 1, C).

        All stored rows are identical, so the first one is broadcast to the
        batch. Slicing ``vec_cls[:batch]`` failed for batches larger than 64.
        """
        batch = x.size(0)
        part_vec_cls = self.vec_cls[:1].expand(batch, -1, -1).to(x.device)
        return torch.cat([part_vec_cls, x], dim=1)

    def forward(self, x):
        # Spectrogram
        x = self.spec(x)
        x = self.to_db(x)
        x = x.unsqueeze(1)
        x = self.spec_bn(x)

        # CNN
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.layer7(x)
        # NOTE(review): assumes the seven stride-2 frequency steps shrink the
        # frequency axis to size 1 (true e.g. for n_mels = 96) — confirm for config.
        x = x.squeeze(2)

        # Get [CLS] token: (batch, channels, time) -> (batch, time + 1, channels)
        x = x.permute(0, 2, 1)
        x = self.append_cls(x)

        # Transformer encoder
        x = self.encoder(x)
        # NOTE(review): presumably the encoder returns per-layer outputs and [-1]
        # is the last layer — confirm against run.attention_modules.
        x = x[-1]
        x = self.pooler(x)

        # Dense
        x = self.dropout(x)
        x = self.dense(x)
        return torch.sigmoid(x)

In [None]:
model_name = 'cnnsa'
learning_rate = 1e-4
num_epochs = 15
# Only the title-aware model consumes (input_ids, attention_mask).
use_title = model_name == 'musicnn_title'
run_tag = '{}_{}_{}_{}'.format(model_name, learning_rate, batch_size, len(tags))

model = get_model(model_name, tags)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

print("Training and validating model...")
# Output folders may not exist on a fresh checkout / Drive folder.
os.makedirs('log', exist_ok=True)
os.makedirs('model', exist_ok=True)
writer = SummaryWriter('runs/' + run_tag)
# force=True: without it basicConfig is a no-op once the root logger already
# has handlers (as in Jupyter/Colab, or on re-running this cell), and nothing
# would be written to this run's log file.
logging.basicConfig(filename="log/log_" + run_tag,
                    format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                    datefmt='%H:%M:%S',
                    level=logging.INFO,
                    force=True)

# validate() returns the "avg percent" top-k hit rate; keep the best checkpoint.
best_pre = float('-inf')
for epoch in range(num_epochs):
    train(model, epoch, criterion, optimizer, train_loader, use_title)
    pre = validate(model, epoch, criterion, val_loader, use_title)
    if pre > best_pre:
        print("Best avg percent:", pre)
        best_pre = pre
        torch.save(model.state_dict(), 'model/{}_best_score_{}_{}.pt'.format(model_name, learning_rate, len(tags)))
writer.flush()

In [None]:
def baseline1(tag_file, npy_root, batch_size, isMap, val_loader, tags, seed=None):
    """Tag-frequency prior baseline, evaluated on ``val_loader``.

    Every validation sample gets the same score vector: each tag's share of all
    tag occurrences among the training files (the first 80% of a seeded
    shuffle). The function then prints the same three metrics as
    get_eval_metrics (avg percent, macro AUROC, macro precision at 0.5).

    Args:
        tag_file: tab-separated tag file (path in column 3, tags in 5+).
        npy_root: root of the ``<subdir>/<track>.npy`` files.
        batch_size: unused; kept for call compatibility.
        isMap: collapse raw tags into categories from tag_categorize.json.
            Must match the ``isMap`` used to build ``tags``.
        val_loader: yields ``(waveform, input_ids, attention_mask, label)``.
        tags: tag vocabulary the labels were built from.
        seed: shuffle seed; defaults to ``config['seed']``.
    """
    if seed is None:
        seed = config['seed']
    whole_filenames = sorted(glob.glob(os.path.join(npy_root, "*/*.npy")))
    # Dedicated seeded RNG: reproducible split that does not depend on (or
    # advance) the global `random` state.
    random.Random(seed).shuffle(whole_filenames)
    train_size = int(len(whole_filenames) * 0.8)
    # A set makes the per-row membership test below O(1) instead of a list scan.
    train_ids = {
        filename.split('/')[-2] + '/' + filename.split('/')[-1]
        for filename in whole_filenames[:train_size]
    }

    categorize = {}
    if isMap:
        with open('tag_categorize.json') as f:
            data = json.load(f)
        for k, v in data.items():
            for i in v[1:-1].split(', '):
                categorize[i] = k

    train_total_tags = []
    with open(tag_file) as fp:
        reader = csv.reader(fp, delimiter='\t')
        next(reader, None)  # skip header
        for row in reader:
            npy_rel = row[3].replace('.mp3', '.npy')
            if npy_rel not in train_ids:
                continue  # not in the train split
            npy_path = os.path.join(npy_root, npy_rel)
            if not os.path.exists(npy_path):
                print(npy_path)
                continue
            if isMap:
                row_tags = {categorize[tag.split('---')[-1]] for tag in row[5:]}
            else:
                row_tags = {tag.split('---')[-1] for tag in row[5:]}
            train_total_tags += list(row_tags)

    train_dist_tags = collections.Counter(train_total_tags)
    print(train_dist_tags)
    total = max(sum(train_dist_tags.values()), 1)  # avoid ZeroDivisionError on an empty split
    # Column order must match the labels. MyDataset builds them with a
    # LabelBinarizer, whose classes_ are sorted, so order the priors by sorted
    # tag rather than by whatever order `tags` happens to be in.
    probs = [train_dist_tags[t] / total for t in sorted(tags)]

    labels = torch.cat([label for _, _, _, label in val_loader], dim=0)
    outputs = torch.tensor(probs, dtype=torch.float32).repeat(labels.size(0), 1)
    assert outputs.shape == labels.shape, "{}, {}".format(outputs.shape, labels.shape)

    # 1. number of correctly predicted tags divided by the total number of tags
    prob_classes = torch.zeros_like(outputs)
    for i in range(labels.size(0)):
        k = int(labels[i].sum())  # topk wants a Python int
        _, idx = outputs[i].topk(k=k)
        prob_classes[i, idx] = 1
    matched_1s = torch.mul(prob_classes, labels)
    correct_tag_percentage = matched_1s.sum() / labels.sum()

    # 2. Auroc
    auroc = multilabel_auroc(outputs, labels, num_labels=N_CLASSES, average="macro", thresholds=None).item()

    # 3. Macro precision at threshold 0.5. Inputs are on CPU here, so the metric
    # stays on CPU too (moving only the metric to `device` caused a device mismatch).
    metric = MultilabelPrecision(num_labels=N_CLASSES, threshold=0.5, average='macro')
    pre = metric(outputs, labels).item()

    print("auroc: {}, pre: {}, avg percent: {}".format(auroc, pre, correct_tag_percentage))
# isMap must match the get_tags call that produced `tags` (isMap=True); with
# False, the priors were counted over raw tags and were all zero for the category names.
baseline1(tag_file, npy_root, batch_size, True, val_loader, tags)