In [1]:
import sys
import os

sys.path.insert(0, os.path.abspath("../src"))

from models.base_model import AutoAudioBaseModel
import evaluate
import torch
from transformers import (
    AutoModelForAudioClassification,
    TrainingArguments,
    Trainer,
    AutoFeatureExtractor,
)
import pandas as pd
import numpy as np
import uuid

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
import subprocess
import pandas as pd

dataset_path = "data/gtzan-dataset-music-genre-classification"

if not os.path.exists(dataset_path):
    print("Dataset not found. Downloading...")
    os.makedirs(dataset_path, exist_ok=True)
    subprocess.run(
        [
            "kaggle",
            "datasets",
            "download",
            "-d",
            "andradaolteanu/gtzan-dataset-music-genre-classification",
            "-p",
            dataset_path,
            "--unzip",
        ]
    )
    print("Download complete.")
else:
    print("Dataset already exists.")

genres_path = os.path.join(dataset_path, "Data/genres_original")
paths = []
labels = []
for genre in os.listdir(genres_path):
    folder_path = os.path.join(genres_path, genre)
    for filename in os.listdir(folder_path):
        paths.append(os.path.join(folder_path, filename))
        labels.append(genre)
df = pd.DataFrame({"file_path": paths, "label": labels})


from models.transformer import AudioTransformer
import preprocessing as pre
from sklearn.model_selection import train_test_split
import numpy as np

df_train = df.sample(100, random_state=42)
df_test = df.sample(100, random_state=42)
data = df_train

data.reset_index(drop=True, inplace=True)
features, audios = pre.aggregate_audio_features(data)
features.reset_index(drop=True, inplace=True)
audios.reset_index(drop=True, inplace=True)

labels = data["label"]
unique = np.unique(labels)
n_unique = len(unique)

label2id = {}
id2label = {}
for i, label in enumerate(unique):
    label2id[label] = str(i)
    id2label[str(i)] = label

test_size = 0.2
indices = labels.index
train_indices, test_indices = train_test_split(
    indices, test_size=test_size, random_state=42, shuffle=True
)
labels_train = labels.loc[train_indices].values.reshape(-1)
audios_train = audios.loc[train_indices]
labels_test = labels.loc[test_indices].values.reshape(-1)
audios_test = audios.loc[test_indices]


Dataset already exists.


In [8]:
class AudioTransformer(AutoAudioBaseModel):
    def __init__(self, num_labels: int, label2id: dict, id2label: dict):
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            "facebook/wav2vec2-base"
        )
        self.model = AutoModelForAudioClassification.from_pretrained(
            "facebook/wav2vec2-base",
            num_labels=num_labels,
            label2id=label2id,
            id2label=id2label,
        )
        self.id = str(uuid.uuid4())
        self.path = "outputs/transformer" + self.id

    def fit(self, train_dataset, test_dataset):
        def preprocess_function(examples):
            audio_arrays = [x["array"] for x in examples["audio"]]
            inputs = self.feature_extractor(
                audio_arrays, sampling_rate=self.feature_extractor.sampling_rate, max_length=16000, truncation=True
            )
            return inputs
        encoded_train_data = train_dataset.map(preprocess_function, remove_columns="audio", batched=True)
        encoded_train_data = encoded_train_data.rename_column("intent_class", "label")

        training_args = TrainingArguments(
            output_dir=self.path,
            eval_strategy="epoch",
            save_strategy="epoch",
            learning_rate=3e-5,
            per_device_train_batch_size=32,
            gradient_accumulation_steps=4,
            per_device_eval_batch_size=32,
            num_train_epochs=10,
            warmup_ratio=0.1,
            logging_steps=10,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            push_to_hub=False,
        )

        accuracy = evaluate.load("accuracy")

        def compute_metrics(eval_pred):
            predictions = np.argmax(eval_pred.predictions, axis=1)
            return accuracy.compute(
                predictions=predictions, references=eval_pred.label_ids
            )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=encoded_train_data,
            eval_dataset=test_dataset,
            processing_class=self.feature_extractor,
            compute_metrics=compute_metrics,
        )

        trainer.train()

    def predict(self, features: pd.DataFrame) -> np.ndarray:
        # TODO: get features into correct format
        with torch.no_grad():
            logits = self.model(features["file_path"]).logits
            predicted_class_ids = torch.argmax(
                logits, dim=1
            ).item()  # TODO: check if dim is correct
            predicted_labels = self.model.config.id2label[predicted_class_ids]
        return predicted_labels

    def __str__(self) -> str:
        return "Transformer"

In [7]:
train_dataset.head()

Unnamed: 0,audio,label
0,"[0.051511027, 0.084782496, 0.06722592, 0.07209...",blues
1,"[-0.16309336, -0.1845084, -0.12476219, -0.1213...",pop
2,"[-0.009800588, 0.0051853997, 0.018586956, 0.00...",classical
3,"[0.11665485, 0.3679105, 0.47211415, 0.5051253,...",disco
4,"[-0.024236003, -0.088823035, -0.12571168, -0.1...",jazz


In [9]:
model = AudioTransformer(n_unique, label2id, id2label)
train_dataset = pd.DataFrame({"audio": audios_train["audio"].values, "label": labels_train})
test_dataset = pd.DataFrame({"audio": audios_test["audio"].values, "label": labels_test})
model.fit(train_dataset, test_dataset)


Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 0/10 [14:36<?, ?it/s]


TypeError: AudioTransformer.fit.<locals>.preprocess_function() got an unexpected keyword argument 'remove_columns'