In [None]:

!pip -q install torch torchaudio --upgrade
!pip -q install matplotlib ipywidgets

import os, math, random, numpy as np, torch, torchaudio
from torch import nn
from torch.utils.data import DataLoader
from torchaudio.datasets import SPEECHCOMMANDS
import torch.nn.functional as F
import matplotlib.pyplot as plt
from IPython.display import Audio, display

# Fix the random seed for Python, NumPy and PyTorch so runs are repeatable.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Prefer the GPU when one is visible to PyTorch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("✅ Ready on", device)


In [None]:

# Where Speech Commands is downloaded, and the keywords we classify
# (a small subset keeps training fast).
DATA_ROOT = "./data"
TARGET_LABELS = ["yes", "no", "up", "down"]

# Fixed clip length used by the collate function (resample, then pad/crop).
SAMPLE_RATE: int = 16_000
DURATION: float = 1.0  # seconds
NUM_SAMPLES: int = int(SAMPLE_RATE * DURATION)

class SubsetSC(SPEECHCOMMANDS):
    """One Speech Commands split, restricted to the words in ``TARGET_LABELS``.

    Args:
        subset: ``"training"``, ``"validation"``, ``"testing"``, or ``None`` to
            keep every clip of the downloaded archive.

    Items are ``(waveform, sample_rate, label)`` tuples. Clips whose label is
    not in ``TARGET_LABELS`` are removed from the file list in ``__init__``, so
    they are never decoded from disk and ``len(dataset)`` / ``len(loader)``
    count only usable clips.
    """

    def __init__(self, subset: str = None):
        super().__init__(DATA_ROOT, download=True)

        def load_list(filename):
            # The split files list clip paths relative to the dataset root, one per line.
            filepath = os.path.join(self._path, filename)
            with open(filepath) as f:
                return [os.path.normpath(os.path.join(self._path, ln.strip())) for ln in f]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = set(load_list("validation_list.txt") + load_list("testing_list.txt"))
            self._walker = [w for w in self._walker if w not in excludes]

        # A clip's label is its parent directory name (e.g. ".../yes/<id>.wav").
        # Filtering here, rather than returning None from __getitem__, avoids
        # loading the vast majority of the archive's audio only to discard it,
        # and keeps every DataLoader batch full-size.
        targets = set(TARGET_LABELS)
        self._walker = [
            w for w in self._walker
            if os.path.basename(os.path.dirname(w)) in targets
        ]

    def __getitem__(self, n):
        # The parent returns (waveform, sample_rate, label, speaker_id, utterance_number).
        waveform, sr, label, *_ = super().__getitem__(n)
        # Defensive: __init__ already filtered the file list, but a non-target
        # label still yields None so collate_fn can drop it.
        if label not in TARGET_LABELS:
            return None
        return (waveform, sr, label)

def collate_fn(batch):
    """Turn a list of ``(waveform, sample_rate, label)`` items into tensors.

    ``None`` items are discarded; if nothing remains, ``None`` is returned.
    Each waveform is resampled to ``SAMPLE_RATE`` when its rate differs, then
    right-padded with zeros or cropped to exactly ``NUM_SAMPLES`` along axis 1.

    Returns:
        ``(x, y)``: ``x`` stacks the fixed-length waveforms (``[B, 1, T]`` for
        mono clips) and ``y`` holds each label's index in ``TARGET_LABELS`` as
        ``torch.long``; or ``None`` when the batch had no usable items.
    """
    kept = [item for item in batch if item is not None]
    if not kept:
        return None

    fixed = []
    for waveform, rate, _ in kept:
        if rate != SAMPLE_RATE:
            waveform = torchaudio.functional.resample(waveform, rate, SAMPLE_RATE)
        length = waveform.shape[1]
        if length < NUM_SAMPLES:
            waveform = F.pad(waveform, (0, NUM_SAMPLES - length))
        else:
            waveform = waveform[:, :NUM_SAMPLES]
        fixed.append(waveform)

    x = torch.stack(fixed, dim=0)
    y = torch.tensor([TARGET_LABELS.index(label) for _, _, label in kept], dtype=torch.long)
    return x, y

# Build the three splits in order (the first construction downloads the archive if needed).
train_set, val_set, test_set = (SubsetSC(split) for split in ("training", "validation", "testing"))

# Only the training loader shuffles; all three share batch size and collate function.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=collate_fn)
val_loader, test_loader = (
    DataLoader(ds, batch_size=64, shuffle=False, collate_fn=collate_fn) for ds in (val_set, test_set)
)

print("Batches:", *(len(loader) for loader in (train_loader, val_loader, test_loader)))


In [None]:
# @title 🧠 CNN audio classifier (log-mel inside forward)
N_MELS = 64


class AudioCNN(nn.Module):
    """Small 2-D CNN that classifies raw waveforms via an internal log-mel front-end.

    The mel front-end (``MelSpectrogram`` followed by ``AmplitudeToDB``) lives
    inside the model so it can be fed waveforms directly. Gradients flow from
    the logits back to the input waveform, which the FGSM attack relies on.

    Args:
        n_classes: number of output logits (defaults to ``len(TARGET_LABELS)``).
    """

    def __init__(self, n_classes=len(TARGET_LABELS)):
        super().__init__()
        self.melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=SAMPLE_RATE, n_fft=1024, hop_length=256, n_mels=N_MELS
        )
        self.amplog = torchaudio.transforms.AmplitudeToDB(stype='power')

        # Three conv stages; the adaptive pool fixes the feature map at 8x8 so the
        # head's input size does not depend on the spectrogram's time length.
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d((2, 2)),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d((2, 2)),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.AdaptiveMaxPool2d((8, 8)),
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 128), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(128, n_classes)
        )

    def forward(self, waveform):
        """Map a waveform batch (``[B, 1, T]`` from ``collate_fn``) to ``[B, n_classes]`` logits."""
        # Do NOT wrap the front-end in torch.no_grad(): that detaches the
        # spectrogram from the waveform, leaving waveform.grad as None and
        # making gradient-based attacks (FGSM) impossible. The mel transforms
        # have no trainable parameters, so during normal training (inputs do
        # not require grad) this builds no extra graph.
        m = self.amplog(self.melspec(waveform))  # [B, 1, n_mels, frames]
        z = self.net(m)
        return self.head(z)

# Instantiate the classifier, move its parameters/buffers to the chosen device, and show its layers.
model = AudioCNN()
model = model.to(device)
print(model)


In [None]:
# @title 🚂 Train for a few epochs
# Cross-entropy on raw logits; AdamW over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-3)

def run_epoch(loader, train=True):
    """Run one pass of the notebook-level ``model`` over ``loader``.

    Uses the globals ``model``, ``optimizer``, ``criterion`` and ``device``.

    Args:
        loader: iterable of ``(x, y)`` batches; ``None`` batches (produced by
            ``collate_fn`` when every clip was filtered out) are skipped.
        train: if True, use training mode and take one optimizer step per
            batch; otherwise evaluate in eval mode without tracking gradients.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, or ``(nan, nan)`` if
        the loader produced no usable batch.
    """
    model.train(train)
    total, correct, loss_sum = 0, 0, 0.0
    # Only build autograd graphs when backward() will actually be called;
    # evaluation would otherwise keep every activation alive for nothing.
    with torch.set_grad_enabled(bool(train)):
        for batch in loader:
            if batch is None:
                continue
            x, y = batch
            x, y = x.to(device), y.to(device)
            if train:
                optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            if train:
                loss.backward()
                optimizer.step()
            loss_sum += loss.item() * x.size(0)
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += x.size(0)
    if total == 0:
        # No usable samples: avoid ZeroDivisionError; nan still formats with :.4f.
        return float("nan"), float("nan")
    return loss_sum / total, correct / total

# Short training run: one training pass then one validation pass per epoch.
EPOCHS = 3
for ep in range(1, EPOCHS + 1):
    tr_loss, tr_acc = run_epoch(train_loader, train=True)
    va_loss, va_acc = run_epoch(val_loader, train=False)
    print(
        f"Epoch {ep}: train loss {tr_loss:.4f} acc {tr_acc:.3f} | "
        f"val loss {va_loss:.4f} acc {va_acc:.3f}"
    )


In [None]:

# Held-out test split, evaluated without any attack.
test_loss, test_acc = run_epoch(test_loader, train=False)
print(f"Clean Test: loss {test_loss:.4f} | acc {test_acc:.3f}")


In [None]:

def fgsm_attack(model, waveforms, labels, epsilon=0.002):
    """Craft Fast Gradient Sign Method (FGSM) adversarial waveforms.

    Computes ``clamp(x + epsilon * sign(dL/dx), -1, 1)``, where ``L`` is the
    cross-entropy of ``model`` on ``(waveforms, labels)``.

    Args:
        model: differentiable classifier whose forward pass lets gradients reach
            the input. This function does not change the model's train/eval
            mode; put it in eval mode first if it has BatchNorm/Dropout.
        waveforms: input batch (``[B, 1, T]`` in this notebook); moved to ``device``.
        labels: class indices for the loss; moved to ``device``.
        epsilon: L-infinity step size. ``0`` returns the clamped input without
            running the model.

    Returns:
        Detached adversarial batch on ``device`` with the shape of ``waveforms``.

    Raises:
        RuntimeError: if the loss does not depend differentiably on the input
            (e.g. part of ``forward`` runs under ``torch.no_grad()``).
    """
    wave = waveforms.clone().detach().to(device)
    if epsilon == 0:
        return torch.clamp(wave, -1.0, 1.0)
    wave.requires_grad_(True)

    loss = F.cross_entropy(model(wave), labels.to(device))
    # autograd.grad returns just the input gradient and leaves the model's
    # parameter .grad buffers untouched (unlike zero_grad() + backward()).
    (grad,) = torch.autograd.grad(loss, wave, allow_unused=True)
    if grad is None:
        raise RuntimeError(
            "fgsm_attack: the loss does not depend on the input waveform; make sure "
            "the model's forward pass is not wrapped in torch.no_grad()."
        )

    adv_wave = wave.detach() + epsilon * grad.sign()
    return torch.clamp(adv_wave, -1.0, 1.0).detach()

def eval_under_attack(model, loader, epsilon=0.002, max_batches=30):
    """Accuracy of ``model`` on FGSM-perturbed batches from ``loader``.

    Args:
        model: classifier; switched to eval mode here.
        loader: iterable of ``(x, y)`` batches (``None`` batches are skipped).
        epsilon: FGSM step size. ``0`` evaluates the clean inputs directly,
            without computing any gradients.
        max_batches: stop after this many loader items (``None`` items count).

    Returns:
        Fraction of samples classified correctly, or ``nan`` if no usable
        batch was seen.
    """
    model.eval()
    total, correct = 0, 0
    for b_idx, batch in enumerate(loader):
        if b_idx >= max_batches:
            break
        if batch is None:
            continue
        x, y = batch
        x, y = x.to(device), y.to(device)

        # epsilon == 0 is the clean baseline: skip the forward/backward of the attack.
        adv = x if epsilon == 0 else fgsm_attack(model, x, y, epsilon=epsilon)
        with torch.no_grad():
            pred = model(adv).argmax(1)
        correct += (pred == y).sum().item()
        total += x.size(0)
    if total == 0:
        # Avoid ZeroDivisionError when the loader had no usable samples.
        return float("nan")
    return correct / total

# Sweep perturbation sizes; epsilon 0.0 is the clean (unattacked) baseline.
for eps in [0.0, 0.0005, 0.001, 0.002, 0.004]:
    acc = eval_under_attack(model, test_loader, epsilon=eps)
    print(f"Epsilon {eps:.4f} → accuracy {acc:.3f}")


In [None]:
# @title 🎧 Inspect one sample (audio + spectrograms)
def show_mel(wave, sr=SAMPLE_RATE, title="Mel", n_mels=N_MELS, n_fft=1024, hop_length=256):
    """Plot the log-mel (dB) spectrogram of one waveform.

    Args:
        wave: single waveform with a leading channel axis (``[1, T]`` at the call
            sites below); that axis is squeezed away before plotting.
        sr: sample rate of ``wave``.
        title: figure title.
        n_mels, n_fft, hop_length: mel front-end settings. The defaults match
            ``AudioCNN``'s front-end, so the plot shows what the model sees.
    """
    wave = wave.detach()
    # Build the transforms on the waveform's device so CUDA tensors work as well.
    melspec = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
    ).to(wave.device)
    amplog = torchaudio.transforms.AmplitudeToDB(stype='power').to(wave.device)
    mlog = amplog(melspec(wave))

    fig, ax = plt.subplots()
    im = ax.imshow(mlog.squeeze(0).cpu().numpy(), origin='lower', aspect='auto')
    ax.set(title=title, xlabel="Frames", ylabel="Mel bins")
    fig.colorbar(im, ax=ax)
    plt.show()

# Take the first non-empty test batch (collate_fn yields None when every clip in a
# batch was filtered out) and keep its first example.
batch = next((b for b in test_loader if b is not None), None)
if batch is None:
    raise RuntimeError("test_loader produced no batch containing a target label")
x, y = batch
x0 = x[0:1].to(device)
y0 = y[0:1].to(device)

# Set eval mode explicitly rather than relying on whichever cell ran last,
# so BatchNorm/Dropout behave identically for the clean and adversarial passes.
model.eval()
with torch.no_grad():
    pred_clean = model(x0).argmax(1).item()

x_adv = fgsm_attack(model, x0, y0, epsilon=0.002)
with torch.no_grad():
    pred_adv = model(x_adv).argmax(1).item()

print("True label:", TARGET_LABELS[y0.item()])
print("Pred (clean):", TARGET_LABELS[pred_clean], " | Pred (FGSM):", TARGET_LABELS[pred_adv])

# 1-D NumPy waveforms for playback and plotting.
w_clean = x0[0, 0].detach().cpu().numpy()
w_adv = x_adv[0, 0].detach().cpu().numpy()

# Play clean, then adversarial audio (the difference may be subtle; headphones recommended).
display(Audio(w_clean, rate=SAMPLE_RATE))
display(Audio(w_adv, rate=SAMPLE_RATE))

# Waveform plots.
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(w_clean)
ax.set(title="Waveform (clean)", xlabel="Sample", ylabel="Amplitude")
plt.show()

fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(w_adv)
ax.set(title="Waveform (adversarial)", xlabel="Sample", ylabel="Amplitude")
plt.show()

# Log-mel spectrograms.
show_mel(x0[0].cpu())
show_mel(x_adv[0].cpu(), title="Mel (adversarial)")
