In [1]:
# Reload edited project modules automatically before each cell is executed,
# so changes under src/ are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Setup

In [2]:
import os
import time

import datasets
import pandas as pd
import wandb
from transformers import AutoFeatureExtractor
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer

from src.dataset import *
from src.train import *

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# --- Model / task configuration -------------------------------------------
MODEL_NAME = "facebook/wav2vec2-base"  # Hugging Face checkpoint identifier
TARGET_FEATURE = "genre"               # column used as the classification label
TEST_SIZE = 0.2                        # fraction passed to train_test_split

# --- Resource layout: everything lives below RES_DIR_PATH -------------------
RES_DIR_PATH = "res"
CSV_PATH = os.path.join(RES_DIR_PATH, "samples.csv")
AUDIOS_DIR_PATH, MODELS_DIR_PATH, DATASETS_DIR_PATH = (
    os.path.join(RES_DIR_PATH, sub_dir)
    for sub_dir in ("mp3_data", "models", "datasets")
)

In [4]:
# Filesystem-safe variant of the checkpoint name ("org/model" -> "org-model").
model_id = "-".join(MODEL_NAME.split("/"))

# Unique identifier for this run: model id plus the local start timestamp.
run_timestamp = time.strftime("%Y%m%d-%H%M%S")
run_name = f"{model_id}-{run_timestamp}"

# Analysis

In [5]:
# Filtering options forwarded to filter_df.
# NOTE(review): top_n presumably keeps only the N most frequent values of each
# listed feature — confirm against filter_df in src.dataset.
top_n = {
    "genre": 5,
}
keep_features = ["genre", "category"]

# Derive the cache file name from the source CSV, e.g. "res/samples_genre5.csv".
# os.path.splitext strips only the final extension, so paths that contain other
# dots (e.g. "./res/samples.csv" or "data.v2/samples.csv") are handled correctly.
csv_base_path = os.path.splitext(CSV_PATH)[0]
filtered_csv_path = "_".join([csv_base_path] + [f"{f}{n}" for f, n in top_n.items()]) + ".csv"

# Reuse the filtered CSV when it exists; otherwise filter the raw CSV and cache it.
# NOTE: the cache name encodes top_n only — delete the file after changing
# keep_features, otherwise a stale cache is loaded.
if os.path.exists(filtered_csv_path):
    df = pd.read_csv(filtered_csv_path)
else:
    df = pd.read_csv(CSV_PATH)
    df = filter_df(df, audios_dir_path=AUDIOS_DIR_PATH, keep_features=keep_features, top_n=top_n)
    df.to_csv(filtered_csv_path, index=False)

df.head()

Unnamed: 0,genre,category,mp3_path,id
0,Hip Hop,Trumpet,res/mp3_data/01 Hip Hop/Abandoned Brass Stabs.mp3,01_Hip_Hop_Abandoned_Brass_Stabs
1,Hip Hop,Timpani,res/mp3_data/01 Hip Hop/Abandoned Orchestral L...,01_Hip_Hop_Abandoned_Orchestral_Layers
2,Hip Hop,Electronic Beats,res/mp3_data/01 Hip Hop/Afloat Beat.mp3,01_Hip_Hop_Afloat_Beat
3,Hip Hop,Synthesizer,res/mp3_data/01 Hip Hop/Afloat Pad.mp3,01_Hip_Hop_Afloat_Pad
4,Hip Hop,Synthetic Bass,res/mp3_data/01 Hip Hop/Afloat Sub Bass.mp3,01_Hip_Hop_Afloat_Sub_Bass


# Dataset

In [6]:
# On-disk location of the preprocessed dataset; one cache directory per model.
encoded_dataset_name = f"encoded-{model_id}"
encoded_dataset_path = os.path.join(DATASETS_DIR_PATH, encoded_dataset_name)
encoded_dataset_path

'res/datasets/encoded-facebook-wav2vec2-base'

In [7]:
# Fixed seed so the train/test split is reproducible when the cache is rebuilt.
SPLIT_SEED = 42

# Always create the feature extractor: the Trainer cell needs it even when the
# encoded dataset is loaded from the on-disk cache (previously it was only
# defined in the rebuild branch, causing a NameError on a fresh kernel).
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)

if os.path.exists(encoded_dataset_path):
    # Cache hit: load the preprocessed splits and re-attach the audio column.
    ds = datasets.load_from_disk(encoded_dataset_path)
    encoded_ds = add_audio_column(ds)
else:
    # Cache miss: build the dataset from the dataframe, split, preprocess, save.
    ds = get_dataset(df)
    print("Splitting dataset into train and test")
    ds = ds.train_test_split(test_size=TEST_SIZE, seed=SPLIT_SEED)
    ds = add_audio_column(ds)
    print("Applying preprocessing to dataset")
    encoded_ds = ds.map(get_preprocess_func(feature_extractor), batched=True)
    encoded_ds.save_to_disk(encoded_dataset_path)

encoded_ds

DatasetDict({
    train: Dataset({
        features: ['id', 'audio_path', 'genre', 'category', 'input_values', 'audio'],
        num_rows: 21988
    })
    test: Dataset({
        features: ['id', 'audio_path', 'genre', 'category', 'input_values', 'audio'],
        num_rows: 5497
    })
})

In [8]:
from src.utils import play_random_audios

# Sanity check: play a few random training clips and print their id and genre.
train_split = encoded_ds["train"]
play_random_audios(train_split, 3, print_features=["id", "genre"])

id: 15_Reggaeton_Pop_Latin_Pop_FX_02 - genre: Reggaeton Pop


id: Jam_Pack_Symphony_Orchestra_Laureate_All - genre: Orchestral


id: Boys_Noize_Circuit_Pressure_Beat_02 - genre: Electro House


# Training

In [9]:
# Build the training-ready dataset from the encoded splits, with TARGET_FEATURE
# as the classification target.
# NOTE(review): judging by the cell output, prepare_ds keeps only 'id', 'label'
# and 'input_values' and returns fewer rows than encoded_ds — presumably rows
# are matched against `df`; confirm in the src package.
prepared_ds = prepare_ds(encoded_ds, df, TARGET_FEATURE)
prepared_ds

Casting the dataset: 100%|██████████| 13/13 [00:14<00:00,  1.10s/ba]
Casting the dataset: 100%|██████████| 4/4 [00:05<00:00,  1.47s/ba]


DatasetDict({
    train: Dataset({
        features: ['id', 'label', 'input_values'],
        num_rows: 12450
    })
    test: Dataset({
        features: ['id', 'label', 'input_values'],
        num_rows: 3129
    })
})

In [10]:
# The "label" feature of the train split describes the target classes.
train_features = prepared_ds["train"].features
class_feature = train_features["label"]

# Label-name -> id and id -> label-name mappings, later handed to the model config.
l2i, i2l = create_label_maps(class_feature)

l2i, i2l

({'Hip Hop': 0,
  'Electronic/Dance': 1,
  'Rock/Blues': 2,
  'World/Ethnic': 3,
  'Orchestral': 4},
 {0: 'Hip Hop',
  1: 'Electronic/Dance',
  2: 'Rock/Blues',
  3: 'World/Ethnic',
  4: 'Orchestral'})

In [None]:
# Load the pretrained checkpoint with a classification head sized for our labels.
# Use MODEL_NAME (not a hardcoded string) so the model always matches the
# feature extractor and the encoded-dataset cache, which are keyed on MODEL_NAME.
model = AutoModelForAudioClassification.from_pretrained(
    MODEL_NAME,
    num_labels=class_feature.num_classes,
    label2id=l2i,
    id2label=i2l,
)

In [None]:
# Hyper-parameters and bookkeeping options for the Hugging Face Trainer.
training_kwargs = dict(
    # Bookkeeping: where checkpoints go and how the run is named.
    output_dir="out",
    run_name=run_name,
    # Schedule: evaluate and checkpoint once per epoch, log every 50 steps.
    num_train_epochs=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
    # Per-device batch sizes for training and evaluation.
    per_device_train_batch_size=256,
    per_device_eval_batch_size=512,
)
training_args = TrainingArguments(**training_kwargs)

In [None]:
# Assemble the Trainer; the feature extractor is passed as `tokenizer` so it is
# saved alongside the model checkpoints.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["test"],
    tokenizer=feature_extractor,
    compute_metrics=get_metrics_func(),
)

# Always close the Weights & Biases run, even if training fails or is
# interrupted; otherwise the run stays open in the kernel and later training
# runs would log into it. Failed runs are marked with a non-zero exit code.
try:
    trainer.train()
except BaseException:
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

In [None]:
# Persist the final model under res/models/<run_name>.
final_model_dir = os.path.join(MODELS_DIR_PATH, run_name)
trainer.save_model(final_model_dir)