In [1]:
!pip install evaluate transformers[torch] torchaudio wandb

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting wandb
  Downloading wandb-0.19.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting transformers[torch]
  Downloading transformers-4.48.0-py3-none-any.whl.metadata (44 kB)
Collecting datasets>=2.0.0 (from evaluate)
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting xxhash (from evaluate)
  Downloading xxhash-3.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.17-py312-none-any.whl.metadata (7.2 kB)
Collecting huggingface-hub>=0.7.0 (from evaluate)
  Downloading huggingface_hub-0.27.1-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from transformers[torch])
  Downloading regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers[torch])
 

In [1]:
import json
import torch
import wandb
import random
import evaluate
import numpy as np
import matplotlib.pyplot as plt

from collections import Counter 
from datasets import DatasetDict
from torch import tensor, Tensor
from torchaudio import transforms
from torch.nn import CrossEntropyLoss
from torch.nn.functional import interpolate, pad
from transformers.trainer_utils import EvalPrediction
from sklearn.utils.class_weight import compute_class_weight
from transformers import ASTFeatureExtractor, ASTConfig, ASTForAudioClassification, TrainingArguments, Trainer

In [2]:
%env WANDB_PROJECT=genre_classification

env: WANDB_PROJECT=genre_classification


In [3]:
wandb.login(key="...", host="https://wandb.justinkonratt.com")

[34m[1mwandb[0m: Currently logged in as: [33mcodesdowork[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for wandb.justinkonratt.com to your netrc file: /home/jovyan/.netrc


True

In [3]:
with open("./dataset/genres.json", "r", encoding="utf-8") as f:
    genres = json.load(f)
genres

['Bad',
 'Bassy',
 'Big Room',
 'Bounce',
 'Chill',
 'Chillstep',
 'Classic',
 'Coding',
 'Country',
 'Cro',
 'Deep House',
 'Drum and Bass',
 'Dubstep',
 'EDM',
 'Electro',
 'Electro House',
 'Emotional',
 'Epic',
 'Folk',
 'Frenchcore',
 'Glitch Hop',
 'God',
 'Groove',
 'Hands Up',
 'Hardcore',
 'Hardstyle',
 'Harp',
 'Hip Hop & Rap',
 'Historic',
 'Latino',
 'Lounge',
 'Malle',
 'Minimal',
 'Motivation',
 'Orchestra Pop',
 'Orchestral Electro',
 'OVERWERK',
 'Pop',
 'Pop mit Beat',
 'Psy',
 'Psytrance',
 'RnB',
 'Rock',
 'Synthpop',
 'Techno',
 'Tekk',
 'Trance',
 'Weihnachten']

In [None]:
preprocessed_dataset = DatasetDict.load_from_disk("./dataset/music_lib")
preprocessed_dataset

DatasetDict({
    train: Dataset({
        features: ['input_values', 'labels', 'paths'],
        num_rows: 16378
    })
    validate: Dataset({
        features: ['input_values', 'labels', 'paths'],
        num_rows: 2342
    })
    test: Dataset({
        features: ['input_values', 'labels', 'paths'],
        num_rows: 4675
    })
})

In [5]:
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)

model_input_name = feature_extractor.model_input_names[0]
labels_name = "labels"
paths_name = "paths"

In [6]:
config = ASTConfig.from_pretrained(pretrained_model)
config.num_labels = len(genres)
config.label2id = { genre: idx for idx, genre in enumerate(genres) }
config.id2label = { idx: genre for idx, genre in enumerate(genres) }
config.max_length = 498
config.hidden_dropout_prob = 0.05
config.attention_probs_dropout_prob = 0.05

model = ASTForAudioClassification.from_pretrained(pretrained_model, config=config, ignore_mismatched_sizes=True)

Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- audio_spectrogram_transformer.embeddings.position_embeddings: found shape torch.Size([1, 1214, 768]) in the checkpoint and torch.Size([1, 590, 768]) in the model instantiated
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([48]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([48, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
def calc_aggregated_accuracy(predictions: Tensor, labels: Tensor):
    v_score = 0
    w_score = 0
    a_score = 0
    m_score = 0

    songs_per_genre = [0 for _ in range(len(genres))]
    correctly_predicted_per_genre = [0 for _ in range(len(genres))]

    eval_subset = None
    for subset in preprocessed_dataset.values():
        if len(subset) == predictions.shape[0]:
            eval_subset = subset

    entries_per_song = list(Counter(eval_subset[paths_name]).values())
    song_count = len(entries_per_song)
    start_idx = 0
    for song_entries in entries_per_song:
        label = labels[start_idx]
        songs_per_genre[label] += 1
        song_logits = predictions[start_idx:start_idx + song_entries,:]
        start_idx += song_entries

        # voting
        v_pred = song_logits.argmax(dim=1).mode().values.item()
        if v_pred == label:
            v_score += 1

        # weighting
        confidence = song_logits.softmax(dim=1).max(dim=1).values
        weighted_logits = (song_logits.T * confidence).T
        w_pred = weighted_logits.mean(dim=0).argmax().item()
        if w_pred == label:
            w_score += 1

        # average
        a_pred = song_logits.mean(dim=0).argmax().item()
        if a_pred == label:
            a_score += 1
            correctly_predicted_per_genre[label] += 1

        # max
        m_pred = song_logits.max(dim=0).values.argmax().item()
        if m_pred == label:
            m_score += 1

    return {
        "voting_score": v_score / song_count,
        "weighting_score": w_score / song_count,
        "mean_pooling_score": a_score / song_count,
        "max_pooling_score": m_score / song_count,
        **{ f"{genre}_a_accuracy": correctly_predicted_per_genre[idx] / songs_per_genre[idx] for idx, genre in enumerate(genres) }
    }

accuracy = evaluate.load("accuracy")
recall = evaluate.load("recall")
precision = evaluate.load("precision")
f1 = evaluate.load("f1")

AVERAGE = "macro" if config.num_labels > 2 else "binary"

def compute_metrics(eval_pred: EvalPrediction):
    logits = eval_pred.predictions
    labels = eval_pred.label_ids
    predictions = np.argmax(logits, axis=1)
    metrics = accuracy.compute(predictions=predictions, references=labels)
    metrics.update(precision.compute(predictions=predictions, references=labels, average=AVERAGE))
    metrics.update(recall.compute(predictions=predictions, references=labels, average=AVERAGE))
    metrics.update(f1.compute(predictions=predictions, references=labels, average=AVERAGE))
    metrics.update(calc_aggregated_accuracy(tensor(logits), tensor(labels)))
    return metrics

In [8]:
def visualize_spectrum(specs, size=(10,6), cols=1, rows=1):
    plt.figure(figsize=size)
    for idx, spec in enumerate(specs):
        plt.subplot(rows, cols, idx + 1)
        plt.imshow(spec.T, aspect='auto', origin='lower', cmap='viridis')
        plt.colorbar(label="Amplitude")
        plt.xlabel("Time Frames")
        plt.ylabel("Frequency Bins")
        plt.tight_layout()
    plt.tight_layout()
    plt.show()

In [9]:
class SpecAugmentPipeline:
    def __init__(
            self,
            p=0.5,
            effects_p=0.5,
            time_mask_param=30,
            freq_mask_param=20,
            noise_level=0.05,
            stretch_range=(0.95, 1.05),
            shift_range=3,
            amplitude_range=(0.8, 1.2)
    ):
        self.p = p
        self.effects_p = effects_p
        self.time_mask = transforms.TimeMasking(time_mask_param)
        self.freq_mask = transforms.FrequencyMasking(freq_mask_param)
        self.noise_level = noise_level
        self.stretch_range = stretch_range
        self.shift_range = shift_range
        self.amplitude_range = amplitude_range

    def add_noise(self, spec):
        return spec + torch.randn_like(spec) * self.noise_level
    
    def time_stretch(self, spec):
        spec = spec.unsqueeze(0)
        factor = random.uniform(*self.stretch_range)
        new_steps = int(spec.size(-1) * factor)
        new_spec = interpolate(spec, (spec.size(-2), new_steps), mode="bilinear", align_corners=False).squeeze(0)
        return new_spec.resize_(spec.size(-3), spec.size(-2), spec.size(-1)) if factor >= 1 else pad(new_spec, (0, spec.size(-1) - new_steps))

    def frequency_shift(self, spec):
        shift = random.randint(-self.shift_range, self.shift_range)
        return torch.roll(spec, shifts=shift, dims=-2)

    def amplitude_scaling(self, spec):
        return spec * random.uniform(*self.amplitude_range)

    def __call__(self, spec):
        if random.random() >= self.p:
            return spec
        
        spec = spec.transpose(-1, -2)
        if random.random() < self.effects_p:
            spec = self.time_mask(spec)
        
        if random.random() < self.effects_p:
            spec = self.freq_mask(spec)

        if random.random() < self.effects_p:
            spec = self.add_noise(spec)

        if random.random() < self.effects_p:
            spec = self.time_stretch(spec)

        if random.random() < self.effects_p:
            spec = self.frequency_shift(spec)

        if random.random() < self.effects_p:
            spec = self.amplitude_scaling(spec)

        return spec.transpose(-1, -2)

In [10]:
aug_pipe = SpecAugmentPipeline(p=0.5)
def augmentation(sample):
    if model_input_name in sample:
        sample[model_input_name] = aug_pipe(tensor(sample[model_input_name]))
    return sample

preprocessed_dataset["train"].set_transform(augmentation)

In [11]:
labels = preprocessed_dataset["train"][labels_name]
class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(labels), y=labels)
class_weights = torch.tensor(class_weights, dtype=torch.float, device="cuda")
class_weights

tensor([ 0.3256,  4.0142,  1.6644,  0.6226,  2.5275,  3.4465,  1.2057,  8.5302,
         1.0832,  1.8149,  2.6247,  2.0679,  2.7297,  1.9062,  0.5557,  0.8817,
         4.2651,  8.5302, 17.0604,  1.5509,  3.1593,  4.8744,  0.9612,  0.0710,
         0.6261,  0.4414, 17.0604,  0.8794,  6.0930,  2.3532,  0.9833,  1.0832,
         2.5275,  0.8530,  4.2651,  3.2496,  8.5302, 13.6483,  5.2494,  1.4582,
        11.3736,  2.3532,  0.5603, 17.0604,  0.7898,  3.4121,  0.3243,  4.5494],
       device='cuda:0')

In [12]:
def weighted_loss_func(outputs, labels, num_items_in_batch):
    logits = outputs.get("logits")
    loss_fct = CrossEntropyLoss(weight=class_weights.to(logits.device))
    return loss_fct(logits, labels)

In [13]:
batch_size = 128
devices = 2
warmup_epochs = 1

training_args = TrainingArguments(
    output_dir="./runs/ast_classifier",
    report_to="wandb",
    run_name="better_stretch",
    learning_rate=5e-5,
    weight_decay=0.01,
    warmup_steps=round(len(preprocessed_dataset["train"]) / batch_size * warmup_epochs),
    push_to_hub=False,
    num_train_epochs=15,
    per_device_train_batch_size=batch_size // devices,
    per_device_eval_batch_size=batch_size // devices,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="mean_pooling_score",
    greater_is_better=True,
    logging_strategy="steps",
    logging_steps=1,
    fp16=True,
    save_total_limit=3
)

In [14]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=preprocessed_dataset["train"],
    eval_dataset=preprocessed_dataset["validate"],
    compute_metrics=compute_metrics,
    compute_loss_func=weighted_loss_func
)

In [None]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mcodesdowork[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Voting Score,Weighting Score,Mean Pooling Score,Max Pooling Score,Bad A Accuracy,Bassy A Accuracy,Big room A Accuracy,Bounce A Accuracy,Chill A Accuracy,Chillstep A Accuracy,Classic A Accuracy,Coding A Accuracy,Country A Accuracy,Cro A Accuracy,Deep house A Accuracy,Drum and bass A Accuracy,Dubstep A Accuracy,Edm A Accuracy,Electro A Accuracy,Electro house A Accuracy,Emotional A Accuracy,Epic A Accuracy,Folk A Accuracy,Frenchcore A Accuracy,Glitch hop A Accuracy,God A Accuracy,Groove A Accuracy,Hands up A Accuracy,Hardcore A Accuracy,Hardstyle A Accuracy,Harp A Accuracy,Hip hop & rap A Accuracy,Historic A Accuracy,Latino A Accuracy,Lounge A Accuracy,Malle A Accuracy,Minimal A Accuracy,Motivation A Accuracy,Orchestra pop A Accuracy,Orchestral electro A Accuracy,Overwerk A Accuracy,Pop A Accuracy,Pop mit beat A Accuracy,Psy A Accuracy,Psytrance A Accuracy,Rnb A Accuracy,Rock A Accuracy,Synthpop A Accuracy,Techno A Accuracy,Tekk A Accuracy,Trance A Accuracy,Weihnachten A Accuracy
1,4.0475,3.3806,0.103757,0.079766,0.14974,0.076865,0.129787,0.129787,0.144681,0.104255,0.0,0.5,0.833333,0.0,0.0,0.666667,0.875,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.090909,0.0,1.0,1.0,0.0,0.0,0.0,0.5,0.0,0.0,0.272727,1.0,0.272727,0.0,0.0,0.454545,0.0,1.0,0.454545,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5,0.0,0.5,0.0
2,2.8324,2.854285,0.270282,0.278681,0.252046,0.173308,0.376596,0.365957,0.353191,0.351064,0.033333,0.0,0.0,0.0625,0.0,0.0,0.875,0.0,0.0,0.0,0.0,0.2,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.142857,0.0,0.5,0.2,0.540146,0.1875,0.363636,1.0,0.909091,1.0,0.75,0.545455,0.555556,1.0,0.909091,0.5,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.416667,0.333333,0.5,0.0
3,3.8754,2.703272,0.297182,0.318177,0.322385,0.254543,0.421277,0.421277,0.408511,0.376596,0.366667,1.0,0.5,0.0,0.75,0.666667,0.875,0.0,0.222222,0.666667,0.0,0.0,0.333333,0.0,0.0,0.090909,1.0,1.0,0.0,0.571429,0.0,1.0,0.6,0.335766,0.75,0.454545,1.0,0.818182,1.0,0.5,0.636364,0.444444,0.5,0.818182,0.5,0.0,1.0,0.0,0.0,0.571429,1.0,0.25,0.058824,0.0,0.583333,0.333333,0.666667,0.0
4,1.3965,2.552171,0.32152,0.319745,0.348444,0.27654,0.412766,0.506383,0.489362,0.434043,0.5,1.0,0.666667,0.1875,0.25,0.0,1.0,0.0,0.333333,0.666667,0.25,0.4,0.666667,0.0,0.529412,0.0,1.0,1.0,1.0,0.285714,0.0,0.5,0.6,0.452555,0.5625,0.409091,1.0,0.818182,0.5,0.5,0.818182,0.444444,0.75,0.727273,1.0,0.666667,0.0,0.0,0.5,0.571429,0.0,0.25,0.588235,0.0,0.5,0.666667,0.6,0.0


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [16]:
trainer._load_best_model()

In [17]:
model.eval()
with torch.no_grad():
    results = trainer.predict(preprocessed_dataset["test"])
results

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


PredictionOutput(predictions=array([[ 2.3468895 , -1.9537987 ,  4.9273396 , ...,  0.22495458,
         2.7202008 , -0.7991523 ],
       [-1.4988972 ,  0.8217171 , -0.79373574, ..., -0.9814781 ,
        -1.1862347 , -2.207212  ],
       [ 4.4040856 , -2.074986  ,  4.2714896 , ..., -3.085356  ,
        -0.21997654, -1.6877873 ],
       ...,
       [ 6.7366014 , -1.6003168 , -0.6002155 , ..., -3.960124  ,
        -0.43739069, -1.7418805 ],
       [ 9.027047  , -1.838432  ,  1.3532044 , ..., -3.1699862 ,
        -1.8196579 , -0.30163497],
       [ 5.187873  ,  1.8964833 , -2.5847745 , ..., -1.297473  ,
         0.1496332 ,  0.34114793]], dtype=float32), label_ids=array([35, 35, 35, ...,  0,  0,  0]), metrics={'test_loss': 2.74991774559021, 'test_accuracy': 0.533903743315508, 'test_precision': 0.39847213890473193, 'test_recall': 0.3760192702529492, 'test_f1': 0.3755052133777681, 'test_voting_score': 0.6971307120085016, 'test_weighting_score': 0.7449521785334751, 'test_mean_pooling_score': 0

In [18]:
from sklearn.metrics import confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Create confusion matrix
cm = confusion_matrix(true_labels, predicted_labels)

# Normalize the confusion matrix (optional)
cm_normalized_row = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Plot the confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(cm_normalized_row, xticklabels=genres, yticklabels=genres)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()


NameError: name 'true_labels' is not defined

In [16]:
wandb.finish()

0,1
eval/Bad_a_accuracy,▁▁▅▆▇███
eval/Bassy_a_accuracy,█▁██████
eval/Big Room_a_accuracy,▁▅▄▅▅█▅█
eval/Bounce_a_accuracy,▁▂▁▇▆▇█▂
eval/Chill_a_accuracy,▁▁██████
eval/Chillstep_a_accuracy,▃▁██▃▆▁▃
eval/Classic_a_accuracy,▁▅▁▁▁▁█▁
eval/Coding_a_accuracy,▁▁▁▁▁▁▁▁
eval/Country_a_accuracy,▂▁▆▄▆█▆▆
eval/Cro_a_accuracy,▁▁▇█▆██▇

0,1
eval/Bad_a_accuracy,0.93333
eval/Bassy_a_accuracy,1.0
eval/Big Room_a_accuracy,1.0
eval/Bounce_a_accuracy,0.25
eval/Chill_a_accuracy,0.75
eval/Chillstep_a_accuracy,0.33333
eval/Classic_a_accuracy,0.75
eval/Coding_a_accuracy,0.0
eval/Country_a_accuracy,0.66667
eval/Cro_a_accuracy,0.83333
