In [None]:
!pip install -q torch torchaudio timm

In [None]:
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
# Pretrained Audio Spectrogram Transformer checkpoint used for both model and feature extractor.
AST_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"

# Ask for a 10-way classifier; ignore_mismatched_sizes lets the checkpoint load even though
# its own classification head has a different output size (that head is re-initialised).
model = AutoModelForAudioClassification.from_pretrained(
    AST_CHECKPOINT,
    num_labels=10,
    ignore_mismatched_sizes=True,
).to("cuda")

# Matching preprocessing (waveform -> model inputs) for the same checkpoint.
feat_extractor = AutoFeatureExtractor.from_pretrained(AST_CHECKPOINT)

In [None]:
import torchaudio, torch, pandas as pd, os
ROOT = "/kaggle/input/urbansound8k"

# Clip metadata: each row names a clip via its fold and slice_file_name, plus its class name.
meta = pd.read_csv(f"{ROOT}/UrbanSound8K.csv")

# Integer class ids are assigned in alphabetical order of the class names.
class_names = sorted(meta["class"].unique())
id2label = dict(enumerate(class_names))
label2id = {name: idx for idx, name in id2label.items()}
meta["label_id"] = meta["class"].map(label2id)

def load_resample(path):
    """Load an audio file and return it as a 1-D mono waveform at 16 kHz.

    All channels are averaged into one; audio recorded at any rate other than
    16 kHz is resampled to 16 kHz.
    """
    waveform, orig_sr = torchaudio.load(path)
    mono = waveform.mean(dim=0)
    if orig_sr == 16000:
        return mono
    return torchaudio.functional.resample(mono, orig_sr, 16000)

In [None]:
class UrbanSoundHF(torch.utils.data.Dataset):
    """UrbanSound8K clips as fixed-length 16 kHz mono waveforms.

    Each item is ``(waveform, label_id)``: ``waveform`` is a 1-D float tensor of
    exactly ``target_len`` samples (zero-padded at the end or truncated) and
    ``label_id`` is the integer class id taken from ``df["label_id"]``.

    Args:
        df: metadata rows with ``fold``, ``slice_file_name`` and ``label_id``
            columns. The index is reset so ``idx`` is purely positional.
    """

    def __init__(self, df):
        self.df = df.reset_index(drop=True)
        # 10.24 s at 16 kHz. Kept as an int so it can be used directly for slicing and padding.
        self.target_len = 163_840

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        file_path = f"{ROOT}/fold{row.fold}/{row.slice_file_name}"
        # Reuse the shared loader (mono + resample to 16 kHz) instead of duplicating it.
        waveform = load_resample(file_path)
        return self._pad_or_trim(waveform, self.target_len), int(row.label_id)

    @staticmethod
    def _pad_or_trim(waveform, target_len):
        """Zero-pad at the end, or truncate, a 1-D waveform to exactly ``target_len`` samples."""
        missing = target_len - waveform.shape[0]
        if missing > 0:
            return torch.nn.functional.pad(waveform, (0, missing))
        return waveform[:target_len]

In [None]:
dataset = UrbanSoundHF(meta)

# Spot-check the first few items: every waveform must come out at the dataset's fixed length,
# since that length is what gets handed to the feature extractor.
expected_len = int(dataset.target_len)
for i in range(5):
    wav, label = dataset[i]
    print(f"Sample {i}: shape = {wav.shape}, label = {label} ({id2label[label]})")
    assert wav.shape == (expected_len,), (
        f"sample {i} has shape {tuple(wav.shape)}, expected ({expected_len},)"
    )

In [None]:
def collate(batch, feature_extractor=None, sampling_rate=16000):
    """Collate ``(waveform, label)`` pairs into a model-ready batch dict.

    Args:
        batch: sequence of ``(waveform, label)`` pairs, where ``waveform`` is a
            1-D ``torch.Tensor`` (as produced by ``UrbanSoundHF``) and ``label``
            is an integer class id.
        feature_extractor: callable turning the list of waveforms into model
            inputs. Defaults to the notebook-level ``feat_extractor``.
        sampling_rate: sample rate of the waveforms, forwarded to the extractor.

    Returns:
        dict with the extractor's outputs plus ``"labels"`` (int64 tensor).

    Raises:
        TypeError: if any waveform is not a ``torch.Tensor``.
    """
    import warnings

    if feature_extractor is None:
        feature_extractor = feat_extractor

    wavs, labels = zip(*batch)

    for i, w in enumerate(wavs):
        # Validate the type before anything else touches the item; a non-tensor would
        # otherwise fail later at .numpy() with an unhelpful AttributeError.
        if not isinstance(w, torch.Tensor):
            raise TypeError(
                f"Waveform at batch index {i} is {type(w).__name__}, expected torch.Tensor"
            )
        # NOTE(review): 400 samples is one 25 ms window at 16 kHz, presumably the
        # shortest input the extractor can frame — confirm against the extractor.
        if len(w) < 400:
            warnings.warn(f"Short waveform at batch index {i}: len = {len(w)}")

    # The HF feature extractor is fed float32 numpy arrays rather than tensors.
    arrays = [w.numpy().astype("float32") for w in wavs]

    inputs = feature_extractor(
        arrays, sampling_rate=sampling_rate, return_tensors="pt", padding=True
    )
    return {**inputs, "labels": torch.tensor(labels, dtype=torch.long)}

In [None]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn.functional import cross_entropy
from tqdm import tqdm
from torch import nn

# Cross-validation settings; folds are numbered 1..N_FOLDS in meta["fold"].
N_FOLDS = 10
N_EPOCHS = 8
BATCH_SIZE = 4
DEVICE = "cuda"


def _build_fold_model():
    """Load a fresh pretrained AST and attach a new MLP classification head.

    The backbone is reloaded for every fold so folds stay independent. Reusing
    the previous fold's fine-tuned weights would mean validating on clips the
    backbone has already been trained on.
    """
    num_classes = len(id2label)
    fold_model = AutoModelForAudioClassification.from_pretrained(
        "MIT/ast-finetuned-audioset-10-10-0.4593",
        ignore_mismatched_sizes=True,
        num_labels=num_classes,
    )
    hidden = fold_model.config.hidden_size
    fold_model.classifier = nn.Sequential(
        nn.Linear(hidden, 512),
        nn.Tanh(),
        nn.Dropout(0.1),
        nn.Linear(512, num_classes),
    )
    return fold_model.to(DEVICE)


def _train_one_epoch(net, loader, optimizer, desc):
    """Run one optimisation pass over ``loader`` (the model computes the loss from ``labels``)."""
    net.train()
    for batch in tqdm(loader, desc=desc, leave=False):
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        loss = net(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


@torch.no_grad()
def _evaluate(net, loader):
    """Return the fraction of samples in ``loader`` whose argmax prediction matches the label."""
    net.eval()
    correct = 0
    for batch in loader:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        preds = net(**batch).logits.argmax(dim=1)
        correct += (preds == batch["labels"]).sum().item()
    return correct / len(loader.dataset)


all_acc = []  # all_acc[k] = per-epoch validation accuracies of fold k + 1
fold_accuracies = []

for fold in range(1, N_FOLDS + 1):
    torch.manual_seed(fold)  # reproducible head init, dropout and shuffling per fold

    train_df = meta[meta.fold != fold]
    val_df = meta[meta.fold == fold]
    train_dl = DataLoader(UrbanSoundHF(train_df), batch_size=BATCH_SIZE,
                          shuffle=True, collate_fn=collate, num_workers=4)
    val_dl = DataLoader(UrbanSoundHF(val_df), batch_size=BATCH_SIZE,
                        shuffle=False, collate_fn=collate, num_workers=4)

    # Drop the previous fold's model/optimizer state before loading a fresh backbone.
    model = optim = None
    torch.cuda.empty_cache()
    model = _build_fold_model()
    optim = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

    fold_accuracies = []  # reset per fold so each fold gets its own curve
    for epoch in range(N_EPOCHS):
        _train_one_epoch(model, train_dl, optim, desc=f"Fold {fold} Epoch {epoch}")
        acc = _evaluate(model, val_dl)
        fold_accuracies.append(acc)
        print(f"Fold {fold} Epoch {epoch} Accuracy: {acc:.4f}")
    all_acc.append(fold_accuracies)

# Average each fold's final-epoch accuracy. Picking the best epoch per fold would select
# on the validation fold itself and overstate performance.
final_accs = [fold_curve[-1] for fold_curve in all_acc]
print(f"\n{len(final_accs)}-fold mean accuracy: {sum(final_accs) / len(final_accs):.4f}")


In [None]:
import matplotlib.pyplot as plt

# One validation-accuracy curve per fold; all_acc holds one list of per-epoch accuracies per fold.
fig, ax = plt.subplots(figsize=(10, 6))
for fold_no, fold_curve in enumerate(all_acc, start=1):
    ax.plot(fold_curve, label=f"Fold {fold_no}")

ax.set_title("Validation Accuracy vs. Epochs (Per Fold)")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# NOTE(review): `model` is whatever the last CV fold left behind (its head was trained without
# fold 10); confirm this is the checkpoint you want rather than one retrained on all folds.
CHECKPOINT_PATH = "/kaggle/working/ast_urbansound8k_finetuned.pth"
weights = model.state_dict()
torch.save(weights, CHECKPOINT_PATH)