In [2]:
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, Trainer, TrainingArguments
from datasets import load_dataset, Audio
import numpy as np
import torch
import json
import evaluate

In [3]:
# Audio settings shared by the resampling, chunking and feature-extraction steps.
with open("../config.json", mode = "r") as f:
    data = json.load(f)

SAMPLING_RATE = data["sampling_rate"]   # samples per second used for resampling and extraction
SEGMENT_LEN = data["segment_length"]    # chunk length in seconds (multiplied by SAMPLING_RATE when chunking)
OVERLAP_LEN = data["overlap_length"]    # overlap between consecutive chunks, in seconds

In [None]:
# Feature extractor from the same checkpoint the classification model is loaded from.
extractor = AutoFeatureExtractor.from_pretrained(
    "facebook/wav2vec2-large-960h"
)

In [None]:
# Load the recordings, encode the species name as an integer ClassLabel,
# and shuffle deterministically.
dataset = (
    load_dataset("Saads/xecanto_birds", split = "train")
    .class_encode_column("common_name")
    .shuffle(seed = 42)
)

In [None]:
# Inspect a single raw example before any preprocessing.
dataset[10]

In [None]:
# Class-name <-> class-id lookup tables for the model config.
# Ids are stored as strings (str(index) of the ClassLabel name list).
labels = dataset.features["common_name"].names
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}

In [None]:
# Hold out 20% of the recordings for evaluation. The split happens before chunking,
# so all chunks of one recording end up on the same side (no chunk-level leakage).
# A fixed seed keeps the split identical across kernel restarts, so evaluation
# numbers and checkpoints from different runs are comparable.
dataset = dataset.train_test_split(test_size = 0.2, seed = 42)
dataset

In [None]:
# Drop metadata columns the training pipeline does not use
# ("audio" and "common_name" are kept for the preprocessing below).
dataset = dataset.remove_columns(
    ["primary_label", "secondary_labels", "scientific_name", "author", "license",
     "rating", "type", "latitude", "longitude", "url"]
)
dataset

In [None]:
# Peek at the first three training rows after the metadata columns were removed.
dataset["train"][0:3]

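### Chunking

Each recording is split into fixed-length windows of `SEGMENT_LEN` seconds that overlap by `OVERLAP_LEN` seconds. A trailing partial window is zero-padded, so every training example has the same length.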

In [None]:
def chunk_audio(audio_array, chunk_length = SEGMENT_LEN, overlap = OVERLAP_LEN, sampling_rate = SAMPLING_RATE):
    """Split an audio signal into fixed-length, possibly overlapping windows.

    Parameters
    ----------
    audio_array : array-like
        Audio samples. Assumed 1-D (mono); TODO confirm for multi-channel input.
    chunk_length : int | float
        Window length in seconds.
    overlap : int | float
        Overlap between consecutive windows in seconds. Must be shorter than
        ``chunk_length``.
    sampling_rate : int
        Samples per second, used to convert seconds to sample counts.

    Returns
    -------
    list
        Windows of exactly ``round(chunk_length * sampling_rate)`` samples each.
        If some trailing audio is not covered by a full window, it is emitted as a
        final window that is zero-padded to full length. An empty input yields an
        empty list.

    Raises
    ------
    ValueError
        If the window length is not positive, or the overlap is negative or not
        shorter than the window. A non-positive stride would never advance and
        would loop forever.
    """
    # Seconds -> samples. Integer conversion keeps slice bounds and pad widths
    # valid when the config holds fractional durations (e.g. 2.5 s).
    chunk_samples = int(round(chunk_length * sampling_rate))
    overlap_samples = int(round(overlap * sampling_rate))
    if chunk_samples <= 0:
        raise ValueError(f"chunk length must be positive, got {chunk_samples} samples")
    if not 0 <= overlap_samples < chunk_samples:
        raise ValueError(
            f"overlap ({overlap_samples} samples) must be >= 0 and shorter than "
            f"the chunk ({chunk_samples} samples)"
        )

    stride = chunk_samples - overlap_samples
    n_samples = len(audio_array)

    chunks = []
    start = 0
    while start + chunk_samples <= n_samples:
        chunks.append(audio_array[start : start + chunk_samples])
        start += stride

    # Sample index up to which full windows already cover the audio (0 if none).
    # Only emit a padded tail when it contains audio no previous window included.
    # Otherwise, with overlap > 0, the tail would just repeat the last overlap.
    covered_until = start + overlap_samples if chunks else 0
    if covered_until < n_samples:
        last_chunk = audio_array[start:]
        chunks.append(np.pad(last_chunk, (0, chunk_samples - len(last_chunk))))

    return chunks

In [None]:
def preprocess(row):
    """Chunk one example's audio and run the feature extractor on every chunk.

    Args:
        row: a single (unbatched) dataset example with a decoded ``"audio"`` dict
            (read via its ``"array"`` key) and a ``"common_name"`` class id.

    Returns:
        dict with ``"input_values"`` (extractor output, one entry per chunk) and
        ``"common_name"`` (the row's label repeated once per chunk). The two lists
        line up one-to-one.
    """
    audio_chunks = chunk_audio(row["audio"]["array"])
    features = extractor(audio_chunks, sampling_rate = SAMPLING_RATE)
    # Every chunk inherits the class of the recording it was cut from.
    repeated_labels = [row["common_name"] for _ in audio_chunks]
    return {
        "input_values": features["input_values"],
        "common_name": repeated_labels,
    }

In [None]:
def concate(batch):
    """Flatten per-recording chunk lists into one entry per chunk.

    Intended for ``Dataset.map(..., batched=True)``. ``batch["input_values"]`` and
    ``batch["common_name"]`` each hold one list per recording. The returned columns
    hold one item per chunk, so the mapped dataset can gain rows.
    """
    flat_inputs = []
    for recording_chunks in batch["input_values"]:
        flat_inputs.extend(recording_chunks)

    flat_labels = []
    for recording_labels in batch["common_name"]:
        flat_labels.extend(recording_labels)

    return {
        "concate_input_values": flat_inputs,
        "chunked_common_name": flat_labels,
    }

In [None]:
# Decode audio at SAMPLING_RATE, then chunk + extract features per recording,
# then flatten so each chunk becomes its own row.
dataset = dataset.cast_column("audio", Audio(sampling_rate = SAMPLING_RATE))

dataset = dataset.map(
    preprocess,
    batched = False,
    remove_columns = "audio",
    num_proc = 16,
    writer_batch_size = 200,
)

dataset = dataset.map(
    concate,
    batched = True,
    batch_size = 16,
    remove_columns = ["input_values", "common_name"],
    num_proc = 16,
    writer_batch_size = 100,
)

# Sanity check: length of the first training chunk's input_values
# (presumably SEGMENT_LEN * SAMPLING_RATE samples; TODO confirm).
len(dataset["train"][0]["concate_input_values"])

In [None]:
# Rename the flattened columns to the names used for training
# ("input_values" for the model input, "label" for the target), then reshuffle.
dataset = dataset.rename_columns({
    "concate_input_values": "input_values",
    "chunked_common_name": "label",
})
dataset = dataset.shuffle(seed = 42)

In [None]:
# wav2vec2 encoder with an audio-classification head sized to the number of species.
# The head is presumably newly initialised (not part of this checkpoint);
# confirm via the load warning.
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-large-960h",
    num_labels = len(id2label),
    id2label = id2label,
    label2id = label2id,
)

In [None]:
# Accuracy metric used by compute_metrics during evaluation.
accuracy = evaluate.load("accuracy")

In [None]:
def compute_metrics(eval_pred):
    """Accuracy of the arg-max class prediction.

    ``eval_pred.predictions`` holds the classifier scores (assumed shape
    (n_examples, n_classes); the arg-max runs over axis 1).
    ``eval_pred.label_ids`` holds the true class ids.
    """
    logits = eval_pred.predictions
    predicted_classes = np.argmax(logits, axis = 1)
    references = eval_pred.label_ids
    return accuracy.compute(predictions = predicted_classes, references = references)

In [None]:
# Fine-tuning configuration. Evaluation and checkpointing both run once per epoch,
# so load_best_model_at_end can restore the checkpoint with the lowest eval loss.
training_args = TrainingArguments(
    output_dir = "checkpoints-10-2",
    eval_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate = 3e-5,
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    num_train_epochs = 7,
    logging_steps = 10,
    load_best_model_at_end = True,
    metric_for_best_model = "eval_loss",
    # fp16 mixed precision is rejected without a supported accelerator. Enable it only
    # when CUDA is available, so the notebook still runs (in fp32) on CPU-only machines.
    # On non-CUDA accelerators, set this explicitly.
    fp16 = torch.cuda.is_available(),
    # Optional knobs previously left commented out here:
    # gradient_accumulation_steps = 4, warmup_ratio = 0.1
)

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = dataset["train"],
    eval_dataset = dataset["test"],
    processing_class = extractor,
    compute_metrics = compute_metrics
)

trainer.train()