In [1]:
import os
import math
from pathlib import Path
from pprint import pprint


import numpy as np
import torch


from datasets import load_dataset, Audio, ClassLabel
from transformers import (
ASTFeatureExtractor,
ASTConfig,
ASTForAudioClassification,
TrainingArguments,
Trainer,
)


import evaluate
from audiomentations import Compose, AddGaussianSNR, GainTransition, Gain, ClippingDistortion, TimeStretch, PitchShift


# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


# Where to save outputs
OUT_DIR = Path("./runs/ast_esc50")
OUT_DIR.mkdir(parents=True, exist_ok=True)

  from .autonotebook import tqdm as notebook_tqdm


Device: cuda


In [2]:
print("Loading dataset ashraq/esc50 from the Hub...")
dataset = load_dataset("ashraq/esc50", split="train")
print(dataset)

Loading dataset ashraq/esc50 from the Hub...


Repo card metadata block was not found. Setting CardData to empty.


Dataset({
    features: ['filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take', 'audio'],
    num_rows: 2000
})


In [3]:
import numpy as np


# fetch mapping from the dataset
# create class names ordered by label id
df = dataset.select_columns(["target", "category"]).to_pandas()
unique_idx = np.unique(df["target"], return_index=True)[1]
class_names = df.iloc[unique_idx]["category"].tolist()

In [4]:
num_labels = len(class_names)
print(f"Found {num_labels} classes")


from datasets import Features, Value


# rename & cast
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
dataset = dataset.rename_column("target", "labels")


# Convert labels to ClassLabel so Trainer can understand
label_feature = ClassLabel(names=class_names)
dataset = dataset.cast_column("labels", label_feature)


# Build label2id and id2label
label2id = {name: i for i, name in enumerate(class_names)}
id2label = {i: name for name, i in label2id.items()} # note: label_feature already has mapping


print(label2id)

Found 50 classes
{'dog': 0, 'rooster': 1, 'pig': 2, 'cow': 3, 'frog': 4, 'cat': 5, 'hen': 6, 'insects': 7, 'sheep': 8, 'crow': 9, 'rain': 10, 'sea_waves': 11, 'crackling_fire': 12, 'crickets': 13, 'chirping_birds': 14, 'water_drops': 15, 'wind': 16, 'pouring_water': 17, 'toilet_flush': 18, 'thunderstorm': 19, 'crying_baby': 20, 'sneezing': 21, 'clapping': 22, 'breathing': 23, 'coughing': 24, 'footsteps': 25, 'laughing': 26, 'brushing_teeth': 27, 'snoring': 28, 'drinking_sipping': 29, 'door_wood_knock': 30, 'mouse_click': 31, 'keyboard_typing': 32, 'door_wood_creaks': 33, 'can_opening': 34, 'washing_machine': 35, 'vacuum_cleaner': 36, 'clock_alarm': 37, 'clock_tick': 38, 'glass_breaking': 39, 'helicopter': 40, 'chainsaw': 41, 'siren': 42, 'car_horn': 43, 'engine': 44, 'train': 45, 'church_bells': 46, 'airplane': 47, 'fireworks': 48, 'hand_saw': 49}


In [5]:
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)
model_input_name = feature_extractor.model_input_names[0] # usually 'input_values'
SAMPLING_RATE = feature_extractor.sampling_rate
print("Feature extractor sampling rate:", SAMPLING_RATE)

Feature extractor sampling rate: 16000


In [6]:
# 5) Optional: compute dataset mean/std for normalization
# This can be slow. We sample or process full train set depending on resource availability.


def estimate_mean_std(ds, num_samples=None):
    """Estimate mean and std of spectrogram inputs produced by feature_extractor.
    We process raw waveforms through the feature_extractor WITHOUT normalization and compute mean/std of resulting tensors.
    """
    print("Estimating mean/std on dataset...")
    # Temporarily disable normalization in the extractor
    original_do_normalize = feature_extractor.do_normalize
    feature_extractor.do_normalize = False


    indices = list(range(len(ds))) if num_samples is None else np.linspace(0, len(ds)-1, num=num_samples, dtype=int).tolist()


    means = []
    stds = []
    for idx in indices:
        item = ds[int(idx)]
        wav = item["audio"]["array"]
        # feature_extractor accepts python lists of arrays
        inputs = feature_extractor([wav], sampling_rate=SAMPLING_RATE, return_tensors="pt")
        # model input tensor: shape (1, n_mels, time) or similar; flatten to compute global mean/std
        arr = inputs.get(model_input_name).squeeze().numpy()
        means.append(np.mean(arr))
        stds.append(np.std(arr))


    feature_extractor.do_normalize = original_do_normalize
    mean = float(np.mean(means))
    std = float(np.mean(stds))
    print(f"Estimated mean={mean:.6f}, std={std:.6f}")
    return mean, std


# Estimate on a subset for speed (e.g., 200 samples). Increase for better estimate.
mean, std = estimate_mean_std(dataset, num_samples=200)
feature_extractor.mean = mean
feature_extractor.std = std
feature_extractor.do_normalize = True

Estimating mean/std on dataset...
Estimated mean=-3.365088, std=4.383699


In [7]:
import torch


def preprocess_audio(batch):
    # batch is a dictionary with 'audio' and 'labels'
    wavs = [a["array"] for a in batch["audio"]]
    inputs = feature_extractor(wavs, sampling_rate=SAMPLING_RATE, return_tensors="pt")
    # return as numpy to keep datasets happy; Trainer/Collator will convert to tensors
    return {model_input_name: inputs.get(model_input_name), "labels": batch["labels"]}

In [37]:
audio_augmentations = Compose([
    AddGaussianSNR(min_snr_db=10, max_snr_db=20, p=0.5),
    Gain(min_gain_db=-6, max_gain_db=6, p=0.5) if 'Gain' in globals() else Gain(min_gain_db=-6, max_gain_db=6),
    GainTransition(min_gain_db=-6, max_gain_db=6, min_duration=0.01, max_duration=0.3, duration_unit="fraction", p=0.2),
    ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=30, p=0.3),
    TimeStretch(min_rate=0.8, max_rate=1.2, p=0.3),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.3),
], p=0.8, shuffle=True)




def preprocess_audio_with_transforms(batch):
    print(batch)
    wavs = [audio_augmentations(w["array"], sample_rate=SAMPLING_RATE) for w in batch["audio"]]
    inputs = feature_extractor(wavs, sampling_rate=SAMPLING_RATE, return_tensors="pt")
    return {model_input_name: inputs.get(model_input_name), "labels": batch["labels"]}

In [9]:
# split if 'test' not present
if "test" not in dataset.features:
    dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=42, stratify_by_column="labels")


    print(dataset)

DatasetDict({
    train: Dataset({
        features: ['filename', 'fold', 'labels', 'category', 'esc10', 'src_file', 'take', 'audio'],
        num_rows: 1600
    })
    test: Dataset({
        features: ['filename', 'fold', 'labels', 'category', 'esc10', 'src_file', 'take', 'audio'],
        num_rows: 400
    })
})


In [32]:
dataset["train"].set_transform(preprocess_audio_with_transforms, output_all_columns=False)
dataset["test"].set_transform(preprocess_audio, output_all_columns=False)

In [33]:
print(dataset["train"][0]["input_values"])

tensor([[-0.3840, -0.5803, -0.1875,  ..., -0.1889, -0.1511, -0.1185],
        [-0.8791, -0.4846, -0.0918,  ..., -0.0967, -0.1159, -0.2865],
        [-0.4386, -0.4456, -0.0529,  ..., -0.2288, -0.1986, -0.2741],
        ...,
        [ 0.3838,  0.3838,  0.3838,  ...,  0.3838,  0.3838,  0.3838],
        [ 0.3838,  0.3838,  0.3838,  ...,  0.3838,  0.3838,  0.3838],
        [ 0.3838,  0.3838,  0.3838,  ...,  0.3838,  0.3838,  0.3838]])


In [34]:
config = ASTConfig.from_pretrained(pretrained_model)
config.num_labels = num_labels
config.label2id = label2id
config.id2label = id2label


model = ASTForAudioClassification.from_pretrained(pretrained_model, config=config, ignore_mismatched_sizes=True)
# Initialize classifier weights (optional, some weights already re-initialized by ignore_mismatched_sizes)
model.init_weights()
model.to(device)

Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([50]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([50, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ASTForAudioClassification(
  (audio_spectrogram_transformer): ASTModel(
    (embeddings): ASTEmbeddings(
      (patch_embeddings): ASTPatchEmbeddings(
        (projection): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ASTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ASTLayer(
          (attention): ASTAttention(
            (attention): ASTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ASTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ASTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=T

In [35]:
training_args = TrainingArguments(
    output_dir=str(OUT_DIR),
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=5e-5,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="no",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    fp16=torch.cuda.is_available(),
)


# prepare metrics
accuracy = evaluate.load("accuracy")
precision = evaluate.load("precision")
recall = evaluate.load("recall")
f1 = evaluate.load("f1")
AVERAGE = "macro" if num_labels > 2 else "binary"


import numpy as np


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    metrics = accuracy.compute(predictions=preds, references=labels)
    metrics.update(precision.compute(predictions=preds, references=labels, average=AVERAGE))
    metrics.update(recall.compute(predictions=preds, references=labels, average=AVERAGE))
    metrics.update(f1.compute(predictions=preds, references=labels, average=AVERAGE))
    return metrics


# The default data collator won't accept torch tensors returned in transforms; use a custom collator
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class DataCollatorWithPadding:
    feature_extractor: ASTFeatureExtractor


    def __call__(self, features: Any) -> Dict[str, torch.Tensor]:
        print("HERRRE")
        # features is a list of dicts with keys: model_input_name and 'labels'
        input_values = [f[model_input_name].squeeze(0) if isinstance(f[model_input_name], torch.Tensor) else torch.tensor(f[model_input_name]) for f in features]
        # batch using the extractor's pad function
        batch = self.feature_extractor.pad({model_input_name: input_values}, return_tensors="pt")
        labels = torch.tensor([f["labels"] for f in features], dtype=torch.long)
        batch["labels"] = labels
        return batch


collator = DataCollatorWithPadding(feature_extractor=feature_extractor)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_metrics,
    data_collator=collator,
)

In [38]:
train_result = trainer.train()
metrics = train_result.metrics
trainer.save_state()


# final evaluation
eval_metrics = trainer.evaluate(eval_dataset=dataset["test"])
print("Eval metrics:")
print(eval_metrics)

KeyError: 'audio'