In [7]:
# -----------------------------
# 1. Imports
# -----------------------------
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from datasets import load_dataset, load_from_disk  #
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification #
from functools import partial #
import torchaudio.transforms as T
import torchmetrics
from tqdm import tqdm
import numpy as np
import random
import pandas as pd
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
import os

# -----------------------------
# 2. Configuration & Setup
# -----------------------------
SEED = 42
BATCH_SIZE = 32
NUM_WORKERS = 24
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 1e-4
MODEL_SAVE_PATH = "./best_finetuned_model.pt"
EPOCHS = 10

MODEL_CHECKPOINT = "MIT/ast-finetuned-speech-commands-v2"

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

print(f"Using device: {DEVICE}")
print(f"Number of workers: {NUM_WORKERS}")

Using device: cuda
Number of workers: 24


In [8]:
print(f"Loading feature extractor for: {MODEL_CHECKPOINT}")
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)

print(f"Loading model for: {MODEL_CHECKPOINT}")
# We set num_labels=1 to match your original BCEWithLogitsLoss setup
model = AutoModelForAudioClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=1,
    ignore_mismatched_sizes=True # Allow replacing the head
).to(DEVICE)

Loading feature extractor for: MIT/ast-finetuned-speech-commands-v2
Loading model for: MIT/ast-finetuned-speech-commands-v2


Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-speech-commands-v2 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([35]) in the checkpoint and torch.Size([1]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([35, 768]) in the checkpoint and torch.Size([1, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# -----------------------------
# 3. Load Dataset
# -----------------------------
# Assuming the dataset is in a directory relative to the notebook
# Using os.path.join for better cross-platform compatibility
# DATASET_PATH = "../dataset/ds_2_noaugment_test.hf"
DATASET_PATH = "../../dataset/ds_3_raw_chunked.hf"
ONLINE_PATH = None

if ONLINE_PATH is None:
    if not os.path.exists(DATASET_PATH):
        print(f"Error: Dataset path not found at {DATASET_PATH}")
        print("Please update the DATASET_PATH variable to point to your dataset directory.")
    else:
        dataset = load_from_disk(DATASET_PATH)
        print("Dataset loaded successfully.")

        # Set format to PyTorch
        dataset = dataset.with_format("torch", columns=["audio", "label"])

        print("\nDataset splits:")
        print({k: v.shape for k, v in dataset.items()})

        print("\nDataset features:")
        print(dataset["train"].features)

Dataset loaded successfully.

Dataset splits:
{'train': (617500, 2), 'val': (78726, 2), 'test': (78483, 2)}

Dataset features:
{'audio': List(Value('float32')), 'label': ClassLabel(names=['other', 'drone'])}


In [15]:
spec_augmentations = nn.Sequential(
    T.FrequencyMasking(freq_mask_param=5),
    T.TimeMasking(time_mask_param=5)
).to(DEVICE)

def collate_fn_finetune(batch, augment=False):
    # 1. Extract audio arrays and labels from the batch
    xs = [b["audio"] for b in batch]
    ys = [b["label"] for b in batch]

    inputs = feature_extractor(
        xs,
        sampling_rate=feature_extractor.sampling_rate,
        padding="max_length", # Pad to model's max (e.g., 10s)
        truncation=True,
        # max_length=16000 * 10, # Explicitly set to 10s
        return_tensors="pt"
    )

    # 'input_values' are the spectrograms: (B, N_MELS, N_FRAMES)
    # e.g., (32, 128, 1024) for AST
    input_values = inputs["input_values"]

    labels = torch.tensor(ys, dtype=torch.float32).unsqueeze(1)

    return input_values, labels

# --- Create partial functions for train/validation ---
train_collate_fn = partial(collate_fn_finetune, augment=True)
val_collate_fn = partial(collate_fn_finetune, augment=False)

In [16]:
train_loader = DataLoader(dataset["train"], batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, collate_fn=train_collate_fn, pin_memory=True)
valid_loader = DataLoader(dataset["val"], batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, collate_fn=val_collate_fn, pin_memory=True)
test_loader = DataLoader(dataset["test"], batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, collate_fn=val_collate_fn, pin_memory=True)

print(f"\nCreated DataLoaders with Batch Size: {BATCH_SIZE}")

# Check a sample batch
try:
    sample_x, sample_y = next(iter(train_loader))
    print(f"Sample batch shape - X (Spectrograms): {sample_x.shape}, Y (Labels): {sample_y.shape}")
except Exception as e:
    print(f"Could not load a sample batch: {e}")
    print("Check your NUM_WORKERS setting or collate_fn.")


Created DataLoaders with Batch Size: 32


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7541d30d60>Exception ignored in: 
Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f7541d30d60>
  File "/home/pierre/Documents/Projects/PST4/AI/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
Traceback (most recent call last):
  File "/home/pierre/Documents/Projects/PST4/AI/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/pierre/Documents/Projects/PST4/AI/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
  File "/home/pierre/Documents/Projects/PST4/AI/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
        if w.is_alive():if w.is_alive():

  File "/home/pierre/Applications/miniconda3/lib/python3.13/multiprocessing/pr

Could not load a sample batch: Caught AssertionError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/pierre/Documents/Projects/PST4/AI/.venv/lib/python3.13/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/pierre/Documents/Projects/PST4/AI/.venv/lib/python3.13/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ~~~~~~~~~~~~~~~^^^^^^
  File "/tmp/ipykernel_52595/1621439837.py", line 11, in collate_fn_finetune
    inputs = feature_extractor(
        xs,
    ...<4 lines>...
        return_tensors="pt"
    )
  File "/home/pierre/Documents/Projects/PST4/AI/.venv/lib/python3.13/site-packages/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py", line 219, in __call__
    features = [self._extract_fbank_features(waveform, max_length=self.m

    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7541d30d60>
Traceback (most recent call last):
  File "/home/pierre/Documents/Projects/PST4/AI/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/home/pierre/Documents/Projects/PST4/AI/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
  File "/home/pierre/Applications/miniconda3/lib/python3.13/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7541d30d60>
Traceback (most recent call last):
  File "/home/pierre/Documents/Projects/PST4/AI/.venv/lib/p