In [None]:
!pip install datasets==2.16.0
!pip install huggingface-hub==0.20.0
!apt-get install -y libsox-dev
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
!pip install causal-conv1d==1.4.0 && pip install mamba-ssm==2.2.2

Collecting datasets==2.16.0
  Downloading datasets-2.16.0-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow-hotfix (from datasets==2.16.0)
  Downloading pyarrow_hotfix-0.7-py3-none-any.whl.metadata (3.6 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets==2.16.0)
  Downloading dill-0.3.7-py3-none-any.whl.metadata (9.9 kB)
Collecting fsspec<=2023.10.0,>=2023.1.0 (from fsspec[http]<=2023.10.0,>=2023.1.0->datasets==2.16.0)
  Downloading fsspec-2023.10.0-py3-none-any.whl.metadata (6.8 kB)
INFO: pip is looking at multiple versions of multiprocess to determine which version is compatible with other requirements. This could take a while.
Collecting multiprocess (from datasets==2.16.0)
  Downloading multiprocess-0.70.18-py312-none-any.whl.metadata (7.5 kB)
  Downloading multiprocess-0.70.17-py312-none-any.whl.metadata (7.2 kB)
  Downloading multiprocess-0.70.15-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-2.16.0-py3-none-any.whl (507 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Keyword‑Spotting with Mamba SSM
# =========================================================
# Key differences vs. the previous MobileNet script
# ------------------------------------------------
# 1. Uses Mamba state‑space blocks instead of a CNN.
# 2. Maintains separate front‑ends for train vs. eval (no masks in eval).
# 3. Handles variable-length sequences with a custom `collate_fn` (zero-pads the normalized features and returns true lengths for mask-aware pooling).
# 4. Linear-warmup + cosine-decay LR schedule (LambdaLR), stepped once per batch.
# 5. Saves best + final checkpoints to Google Drive.


from __future__ import annotations
import json, os, random
from pathlib import Path
from typing import Tuple, Dict
import torch, torchaudio
import torch.nn as nn
import os
import math
import torch.nn.functional as F
import torchvision.models as tvm
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tqdm.notebook import tqdm
from mamba_ssm import Mamba

from google.colab import drive
drive.mount('/content/drive')

  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd


Mounted at /content/drive


In [None]:
# ---------------------------------------------------------------------
# 2. Feature front-end (log-mel / MFCC + SpecAugment) and waveform-level augmentation (shift + noise)
# ---------------------------------------------------------------------
class WaveToSpec:
    """Turn a mono waveform into log-mel (dB) features or MFCCs, with optional SpecAugment.

    For a 1-D ``[N]`` input the result is ``[1, n_mels, T]`` (mel) or
    ``[1, n_mfcc, T]`` (mfcc); a 2-D ``[C, N]`` input keeps its leading dim.

    Args:
        feature_type: ``"mel"`` or ``"mfcc"`` (case-insensitive).
        sample_rate: audio sample rate in Hz.
        n_fft: FFT size.
        hop_length: hop between successive frames, in samples.
        n_mels: number of mel bands.
        n_mfcc: number of cepstral coefficients (mfcc mode only).
        top_db: dynamic-range floor passed to ``AmplitudeToDB`` (mel mode only).
        apply_mask: enable SpecAugment; only honoured for ``"mel"``.
        freq_mask_param: maximum frequency-mask width for SpecAugment.
        time_mask_param: maximum time-mask width for SpecAugment.
        win_length: STFT window length; ``None`` means ``n_fft`` (torchaudio default).

    Note:
        Earlier versions passed ``hop_length`` and ``n_mels`` *positionally* to
        ``MelSpectrogram``, whose 3rd/4th positional parameters are
        ``win_length`` and ``hop_length``. The effective settings were therefore
        win_length=256, hop_length=128, and ``n_mels`` was silently ignored. To
        reproduce features for checkpoints trained with that code, construct
        with ``hop_length=128, win_length=256, n_mels=128``.
    """

    def __init__(self,
                 feature_type: str = "mel",
                 sample_rate: int = 16_000,
                 n_fft: int = 2048,
                 hop_length: int = 256,
                 n_mels: int = 128,
                 n_mfcc: int = 40,
                 top_db: int | None = 80,
                 apply_mask: bool = True,
                 freq_mask_param: int = 15,
                 time_mask_param: int = 10,
                 win_length: int | None = None):
        self.feature_type = feature_type.lower()
        assert self.feature_type in {"mel", "mfcc"}, \
            f"feature_type must be 'mel' or 'mfcc', got {feature_type!r}"
        # SpecAugment is only wired up for the mel path.
        self.apply_mask = apply_mask and self.feature_type == "mel"
        self.freq_mask = self.time_mask = None

        # Shared STFT / mel settings, always passed by keyword.
        stft_kwargs = dict(n_fft=n_fft, win_length=win_length,
                           hop_length=hop_length, n_mels=n_mels)

        if self.feature_type == "mel":
            self.spec = T.MelSpectrogram(sample_rate=sample_rate, power=2.0, **stft_kwargs)
            self.to_db = T.AmplitudeToDB(stype="power", top_db=top_db)
            if self.apply_mask:
                self.freq_mask = T.FrequencyMasking(freq_mask_param)
                self.time_mask = T.TimeMasking(time_mask_param)
        else:
            self.spec = T.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc, melkwargs=stft_kwargs)
            self.to_db = None

    def __call__(self, wav: torch.Tensor) -> torch.Tensor:
        """Compute features for ``wav`` (1-D ``[N]`` or 2-D ``[C, N]``)."""
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)
        feats = self.spec(wav)
        if self.to_db is not None:
            # Clamp guards log(0) on silent frames before converting power to dB.
            feats = self.to_db(feats.clamp(min=1e-10))
        if self.apply_mask:
            # stronger SpecAugment — apply two freq & two time masks
            feats = self.freq_mask(feats); feats = self.time_mask(feats)
            feats = self.freq_mask(feats); feats = self.time_mask(feats)
        return feats

class Augment:
    """Random waveform-domain augmentation, applied in this order:

    1. an optional sox tempo stretch (only when ``stretch`` differs from ``(1.0, 1.0)``),
    2. a random time shift of up to ``shift_ms`` milliseconds, zero-filling the gap,
    3. additive Gaussian noise whose std is drawn uniformly from ``noise``
       (skipped entirely when ``noise[1] <= 0``).

    Accepts a 1-D ``[N]`` or 2-D ``[C, N]`` tensor and returns the same rank.
    Shift and noise keep the sample count; a tempo change alters it.
    """

    def __init__(self, stretch: Tuple[float,float]=(1.0,1.0),
                 shift_ms: int = 100,
                 noise: Tuple[float,float]=(0.,0.005),
                 sr: int = 16_000):
        self.sr = sr
        self.stretch = stretch
        self.noise = noise
        # Shift budget converted from milliseconds to samples.
        self.shift = int(shift_ms * sr / 1000)

    def _stretch(self, wav: torch.Tensor) -> torch.Tensor:
        """Apply a random sox tempo change; no-op when the drawn factor is within 1e-3 of 1."""
        factor = float(torch.empty(()).uniform_(*self.stretch))
        if abs(factor - 1.0) <= 1e-3:
            return wav
        stretched, _ = torchaudio.sox_effects.apply_effects_tensor(
            wav, self.sr, [["tempo", f"{factor}"]]
        )
        return stretched

    def _shift(self, x: torch.Tensor):
        """Shift ``x`` ([C, N]) along time by a random offset in [-shift, shift], zero-filling."""
        max_offset = self.shift
        if max_offset == 0:
            return x
        offset = int(torch.randint(-max_offset, max_offset + 1, ()).item())
        if offset > 0:
            # Move right: prepend zeros, drop the tail.
            return F.pad(x, (offset, 0))[:, :-offset]
        if offset < 0:
            # Move left: append zeros, drop the head.
            return F.pad(x, (0, -offset))[:, -offset:]
        return x

    def _add_noise(self, wav: torch.Tensor) -> torch.Tensor:
        """Add N(0, sigma^2) noise with sigma ~ U(noise[0], noise[1]); skipped when sigma is 0."""
        sigma = float(torch.empty(()).uniform_(*self.noise))
        return wav + sigma * torch.randn_like(wav) if sigma > 0 else wav

    def __call__(self, wav: torch.Tensor):
        added_channel_dim = wav.dim() == 1
        if added_channel_dim:
            wav = wav[None]
        if self.stretch != (1.0, 1.0):
            wav = self._stretch(wav)
        wav = self._shift(wav)
        if self.noise[1] > 0:
            wav = self._add_noise(wav)
        return wav.squeeze(0) if added_channel_dim else wav

In [None]:
# ---------------------------------------------------------------------
# 3. Dataset wrapper with dataset-level normalization
# ---------------------------------------------------------------------
class SpeechCommands(Dataset):
    """Map-style dataset over a Hugging Face speech_commands split.

    Each item is ``(features, label)``:
      * the waveform is zero-padded or truncated to ``wav_len`` samples,
      * ``aug`` (if given) is applied to the waveform,
      * ``frontend`` produces ``[1, n_feat, T]`` features,
      * features are standardized with the supplied dataset-level ``mean``/``std``,
      * the result is returned time-major as ``[T, n_feat]``.
    """

    def __init__(self, hf_split, aug: Augment | None, frontend: WaveToSpec,
                 wav_len: int = 16_000, mean: float = 0.0, std: float = 1.0):
        self.ds = hf_split
        self.aug = aug
        self.front = frontend
        self.wav_len = wav_len
        self.mean, self.std = mean, std

    def __len__(self):
        return len(self.ds)

    def _fix_length(self, wav):
        """Zero-pad on the right, or truncate, so ``wav`` has exactly ``wav_len`` samples."""
        missing = self.wav_len - wav.numel()
        return F.pad(wav, (0, missing)) if missing > 0 else wav[: self.wav_len]

    def __getitem__(self, idx):
        record = self.ds[idx]
        wav = self._fix_length(torch.from_numpy(record["audio"]["array"]).float())
        if self.aug:
            wav = self.aug(wav)
        spec = self.front(wav)                              # [1, n_feat, T]
        spec = (spec - self.mean) / (self.std + 1e-6)       # dataset-level standardization
        return spec.squeeze(0).transpose(0, 1), record["label"]  # [T, n_feat]

In [None]:
# ---------------------------------------------------------------------
# 4. Helper functions: batch collation, dataset mean/std, LR schedule
# ---------------------------------------------------------------------
from torch.nn.utils.rnn import pad_sequence

# Add padding to spectrograms if needed
def collate_fn(batch):
    feats, lbls = zip(*batch)
    # also return true lengths for mask-aware pooling
    lens = torch.tensor([f.size(0) for f in feats], dtype=torch.long)
    feats_padded = pad_sequence(feats, batch_first=True, padding_value=0.0)  # [B, T_max, F]
    return feats_padded, torch.tensor(lbls), lens


# Learning-rate multiplier for LambdaLR: linear warmup, then cosine decay with a floor.
def lr_lambda(step, warmup=None, total=None, floor=0.003):
    """Return the LR multiplier (relative to the optimizer's base LR) for ``step``.

    - ``step < warmup``: linear ramp ``step / warmup`` from 0 up to 1.
    - afterwards: half-cosine decay from 1 towards 0 over the remaining
      ``total - warmup`` steps, never below ``floor``.
    - past ``total`` the multiplier holds its end-of-schedule value instead of
      following the cosine back up.

    Args:
        step: scheduler step index (LambdaLR passes only this argument).
        warmup: warmup length in steps; defaults to the module-level ``warmup_steps``.
        total: schedule length in steps; defaults to the module-level ``total_steps``.
        floor: minimum multiplier once warmup is over.
    """
    if warmup is None:
        warmup = warmup_steps  # module-level global set by the main training cell
    if total is None:
        total = total_steps
    if step < warmup:
        return float(step) / float(max(1, warmup))
    progress = float(step - warmup) / float(max(1, total - warmup))
    # Clamp so cos(pi * progress) cannot wrap around and raise the LR again.
    progress = min(progress, 1.0)
    return max(floor, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Precompute dataset-level mean and std of the front-end features (log-mel or MFCC).
def compute_dataset_stats(ds, frontend, wav_len=16_000):
    """Return the global ``(mean, std)`` over every feature value ``frontend`` produces on ``ds``.

    Each waveform is zero-padded or truncated to ``wav_len`` samples, matching
    ``SpeechCommands``. Statistics are accumulated in float64 with a streaming
    merge of per-clip (count, mean, M2) (Chan et al.), so memory use does not
    grow with dataset size. The previous version concatenated every spectrogram
    into one tensor first, which is several GB for the full train split.

    ``std`` uses the unbiased (N - 1) estimator, like ``torch.Tensor.std()``;
    it is NaN when only one value was seen.

    Raises:
        ValueError: if ``ds`` yields no feature values.
    """
    count = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from `mean`
    for sample in ds:
        wav = torch.from_numpy(sample["audio"]["array"]).float()
        if wav.numel() < wav_len:
            wav = F.pad(wav, (0, wav_len - wav.numel()))
        else:
            wav = wav[: wav_len]
        feats = frontend(wav).double()
        n_b = feats.numel()
        if n_b == 0:
            continue
        mean_b = feats.mean().item()
        m2_b = ((feats - mean_b) ** 2).sum().item()
        # Merge this clip's statistics into the running totals.
        new_count = count + n_b
        delta = mean_b - mean
        mean += delta * n_b / new_count
        m2 += m2_b + delta * delta * count * n_b / new_count
        count = new_count
    if count == 0:
        raise ValueError("compute_dataset_stats: dataset produced no feature values")
    std = math.sqrt(m2 / (count - 1)) if count > 1 else float("nan")
    return mean, std


In [None]:
# ---------------------------------------------------------------------
# 5. Mamba KWS model with normalization and residuals
# ---------------------------------------------------------------------
class MambaKWS(nn.Module):
    """Keyword-spotting classifier.

    Pipeline: 2-D conv stem -> per-frame linear projection to ``d_model`` ->
    ``n_layers`` pre-norm residual Mamba blocks -> final LayerNorm ->
    (optionally length-masked) mean pooling over time -> 2-layer MLP head.

    Args:
        num_classes: number of output logits.
        d_model: width of the sequence passed through the Mamba blocks.
        d_state: forwarded to ``Mamba(d_state=...)``.
        expand: forwarded to ``Mamba(expand=...)``.
        n_layers: number of residual Mamba blocks.
        in_ch: input channels of the first Conv2d. NOTE(review): ``forward``
            always builds a single-channel tensor via ``unsqueeze(1)``, so values
            other than 1 will not match at runtime.
        feature_dim: feature bins per frame (e.g. mel bands). Must equal the last
            dim of the input, because the projection's input size is derived from it.

    Attribute / ModuleDict key names define the ``state_dict`` layout used by
    saved checkpoints; renaming them breaks loading.
    """
    def __init__(self, num_classes: int, d_model=256, d_state=32, expand=2, n_layers=8, in_ch=1, feature_dim=128):
        super().__init__()

        # Convolutional embedding layer for feature extraction.
        # Tensor layout here is [B, C, F, T]: MaxPool2d(2) halves both F and T,
        # MaxPool2d((2, 1)) halves F only -> overall F -> F // 4, T -> T // 2.
        self.conv_embed = nn.Sequential(
            nn.Conv2d(in_ch, 32, 3, padding=1),
            nn.BatchNorm2d(32), nn.SiLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32), nn.SiLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64), nn.SiLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64), nn.SiLU(),
            nn.MaxPool2d((2, 1)),
        )

        # Calculate the flattened dimension after convolutions to project to d_model
        # (64 channels x F//4 frequency bins per time step).
        freq_dim_after_conv = feature_dim // 4
        flattened_dim = 64 * freq_dim_after_conv

        # Projection layer to map flattened conv features to Mamba's dimension
        self.proj = nn.Sequential(
            nn.Linear(flattened_dim, d_model),
            nn.LayerNorm(d_model),
            nn.SiLU(),
            nn.Dropout(0.1)
        )

        # Add Mamba blocks with layer norm and residuals.
        # Per-block dropout shrinks with depth: 0.05 - 0.005 * i, floored at 0.02.
        self.blocks = nn.ModuleList([
            nn.ModuleDict({
                "norm": nn.LayerNorm(d_model),
                "mamba": Mamba(d_model=d_model, d_state=d_state, expand=expand),
                "dropout": nn.Dropout(max(0.02, 0.05 - (i * 0.005)))
            }) for i in range(n_layers)
        ])
        self.pre_classifier_norm = nn.LayerNorm(d_model)

        # Classifier head
        self.classifier_dropout = nn.Dropout(0.1)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.SiLU(),
            nn.Dropout(0.05),
            nn.Linear(d_model // 2, num_classes)
        )

    # accept lengths for mask-aware pooling
    def forward(self, x, lengths: torch.Tensor | None = None):  # x: [B, T, F]
        """Return class logits ``[B, num_classes]``.

        Args:
            x: features ``[B, T, F]`` with ``F == feature_dim``.
            lengths: optional ``[B]`` count of valid (un-padded) input frames. When
                given, padded frames are excluded from the time average; otherwise
                all frames are averaged.
        """
        # reshape for Conv2d: [B, T, F] -> [B, 1, F, T]
        x = x.permute(0, 2, 1).unsqueeze(1)

        # conv front-end
        x = self.conv_embed(x)                  # [B, 64, F', T']

        # flatten per time-step and project
        x = x.permute(0, 3, 1, 2).contiguous().flatten(2)  # [B, T', 64*F']
        x = self.proj(x)                                   # [B, T', d_model]

        # Pre-norm residual blocks: x + Dropout(Mamba(LayerNorm(x))). (`i` is unused.)
        for i, blk in enumerate(self.blocks):
            residual = x
            x = blk["norm"](x)
            x = blk["mamba"](x)
            x = blk["dropout"](x)
            x = residual + x

        x = self.pre_classifier_norm(x)

        # mask-aware mean pooling over time
        if lengths is not None:
            # Only the first MaxPool2d(2) shrinks time, so valid frames become lengths // 2 (min 1).
            t_lens = torch.div(lengths, 2, rounding_mode='floor').clamp(min=1).to(x.device)  # first pool halves time
            Tprime = x.size(1)
            mask = (torch.arange(Tprime, device=x.device)[None, :] < t_lens[:, None]).float()  # [B, T']
            mask = mask.unsqueeze(-1)  # [B, T', 1]
            x_sum = (x * mask).sum(dim=1)                            # [B, d_model]
            denom = mask.sum(dim=1).clamp(min=1.0)                   # [B, 1]
            pooled = x_sum / denom
        else:
            pooled = x.mean(dim=1)

        main_output = self.classifier(self.classifier_dropout(pooled))
        return main_output

# ---------------------------------------------------------------------
@torch.no_grad()
def evaluate(model, loader, device, criterion):
    """Compute the mean loss and top-1 accuracy (percent) of ``model`` over ``loader``.

    ``loader`` must yield ``(features, labels, lengths)`` triples (the
    ``collate_fn`` format); ``lengths`` is forwarded to the model for
    mask-aware pooling. Both metrics are weighted by batch size. The model is
    switched to eval mode and left there.

    Returns:
        ``(mean_loss, accuracy_percent)``
    """
    model.eval()
    n_seen, n_right, weighted_loss = 0, 0, 0
    for feats, labels, lengths in loader:
        feats = feats.to(device)
        labels = labels.to(device)
        lengths = lengths.to(device)
        logits = model(feats, lengths=lengths)
        batch_n = feats.size(0)
        weighted_loss += criterion(logits, labels).item() * batch_n
        n_right += (logits.argmax(1) == labels).sum().item()
        n_seen += batch_n
    return weighted_loss / n_seen, 100 * n_right / n_seen

In [None]:
# ---------------------------------------------------------------------
# 6. Main script
# ---------------------------------------------------------------------
if __name__ == "__main__":
    # ---- reproducibility: seed the RNGs behind weight init, shuffling and augmentation
    SEED = 0
    random.seed(SEED)
    torch.manual_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = (device.type == "cuda")

    # ---- dataset
    ds = load_dataset("google/speech_commands", "v0.02")
    n_classes = len(ds["train"].features["label"].names)

    # log-mel + SpecAugment (train only)
    feature_type = "mel"        # "mel"/"mfcc"
    Epochs = 100
    base_lr = 5e-4
    warmup_frac = 0.12          # fraction of all training steps spent in linear warmup

    frontend_train = WaveToSpec(
        feature_type=feature_type,
        n_mfcc=40, n_mels=128,
        apply_mask=True,         # SpecAugment on train
        freq_mask_param=15,
        time_mask_param=25
    )
    frontend_eval = WaveToSpec(
        feature_type=feature_type,
        n_mfcc=40, n_mels=128,
        apply_mask=False
    )

    # Mask-free front-end used only to measure normalization statistics.
    frontend_stats = WaveToSpec(feature_type=feature_type, n_mfcc=40, n_mels=128, apply_mask=False)

    # Waveform augs shift + a bit of noise (train split only)
    aug = Augment(shift_ms=100, noise=(0., 0.01))

    # Normalization stats from the train split only; reused unchanged for val/test.
    train_mean, train_std = compute_dataset_stats(ds["train"], frontend_stats)

    # Datasets
    train_ds = SpeechCommands(ds["train"], aug, frontend_train, mean=train_mean, std=train_std)
    val_ds   = SpeechCommands(ds["validation"], None, frontend_eval, mean=train_mean, std=train_std)
    test_ds  = SpeechCommands(ds["test"], None, frontend_eval, mean=train_mean, std=train_std)

    # Loaders
    dl_kwargs = dict(
        batch_size=64,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
        collate_fn=collate_fn
    )
    train_dl = DataLoader(train_ds, shuffle=True, **dl_kwargs)
    val_dl   = DataLoader(val_ds, shuffle=False, **dl_kwargs)
    test_dl  = DataLoader(test_ds, shuffle=False, **dl_kwargs)

    # Model
    model = MambaKWS(n_classes, d_model=192, d_state=16, n_layers=12).to(device)

    # Loss/opt/sched (per-batch schedule; short warmup; cosine with floor)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.07)
    opt = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=5e-4, betas=(0.9, 0.999))

    # lr_lambda reads the module-level globals `warmup_steps` / `total_steps`
    # when LambdaLR calls it with only the step index, so they must be set here.
    steps_per_epoch = len(train_dl)
    total_steps     = steps_per_epoch * Epochs
    warmup_steps    = int(total_steps * warmup_frac)
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

    # Training loop with AMP, clipping, collapse guard, and checkpoints
    best_val_acc = 0.0
    prev_val_acc = 0.0
    BEST_PATH = Path("/content/best_kws.pt")
    BEST_PATH.parent.mkdir(parents=True, exist_ok=True)

    # torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    global_step = 0  # number of optimizer steps actually taken

    for epoch in range(1, Epochs + 1):
        model.train()
        running_loss = correct = total = 0.0

        pbar = tqdm(train_dl, desc=f"Epoch {epoch:02d}")
        for batch in pbar:
            # unpack lengths and pass to model
            xb, yb, lb = batch
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            lb = lb.to(device, non_blocking=True)

            with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16,
                                    enabled=use_amp and torch.cuda.is_bf16_supported()):
                if torch.isnan(xb).any():
                    xb = torch.nan_to_num(xb, nan=0.0)

                logits = model(xb, lengths=lb)
                loss = criterion(logits, yb)

                # Skip non-finite batches entirely (no backward, no optimizer/scheduler step).
                if not torch.isfinite(loss):
                    continue

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)   # unscale before clipping so the norm is in true units
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.3)
            scaler.step(opt)
            scaler.update()
            sched.step()  # per-batch warmup + cosine
            global_step += 1

            pred = logits.argmax(1)
            correct += (pred == yb).sum().item()
            total += yb.size(0)
            running_loss += loss.item() * yb.size(0)

            pbar.set_postfix(
                train_loss=f"{running_loss / max(1,total):.3f}",
                train_acc=f"{100 * correct / max(1,total):.1f}%",
                lr=f"{opt.param_groups[0]['lr']:.2e}"
            )

        tr_acc = 100.0 * correct / max(1, total)
        val_loss, val_acc = evaluate(model, val_dl, device, criterion)
        print(f"Epoch {epoch:02d} ➜ train {tr_acc:.1f}% | val {val_acc:.1f}% (loss {val_loss:.3f}) | lr {opt.param_groups[0]['lr']:.2e}")

        # Collapse guard: big sudden drop -> reload best + shrink LR
        if epoch > 1 and prev_val_acc > 50.0 and val_acc < 0.5 * prev_val_acc:
            print(f"WARNING: accuracy collapse ({prev_val_acc:.2f}% → {val_acc:.2f}%). Restoring best and reducing LR ×5.")
            if BEST_PATH.exists():
                # weights_only: the checkpoint is a plain state_dict, no pickled objects needed.
                model.load_state_dict(torch.load(BEST_PATH, map_location=device, weights_only=True))
            # LambdaLR recomputes lr = base_lr * lr_lambda(step) on every sched.step(),
            # so editing param_groups alone would be undone after the next batch.
            # Shrink the scheduler's base LRs too so the reduction persists.
            sched.base_lrs = [max(b / 5.0, 1e-6) for b in sched.base_lrs]
            for g in opt.param_groups:
                g['lr'] = max(g['lr'] / 5.0, 1e-6)
            print(f"New LR: {opt.param_groups[0]['lr']:.2e}")

        # Best-by-accuracy checkpoint
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), BEST_PATH)
            print(f"** Saved new best model params ** @ {best_val_acc:.1f}%")

        prev_val_acc = val_acc

    # Held-out test score of the final (last-epoch) weights; the model itself is
    # not modified, so the following save cell still stores the last weights.
    test_loss, test_acc = evaluate(model, test_dl, device, criterion)
    print(f"Final model ➜ test {test_acc:.1f}% (loss {test_loss:.3f})")



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading readme: 0.00B [00:00, ?B/s]

Downloading data:   0%|          | 0.00/1.94G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/229M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/112M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/84848 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/9982 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/4890 [00:00<?, ? examples/s]

  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


Epoch 01:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 01 ➜ train 11.1% | val 31.2% (loss 2.795) | lr 4.17e-05
** Saved new best model params ** @ 31.2%


Epoch 02:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 02 ➜ train 50.3% | val 80.9% (loss 1.212) | lr 8.33e-05
** Saved new best model params ** @ 80.9%


Epoch 03:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 03 ➜ train 75.2% | val 91.2% (loss 0.818) | lr 1.25e-04
** Saved new best model params ** @ 91.2%


Epoch 04:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 04 ➜ train 82.4% | val 93.6% (loss 0.725) | lr 1.67e-04
** Saved new best model params ** @ 93.6%


Epoch 05:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 05 ➜ train 85.1% | val 93.7% (loss 0.708) | lr 2.08e-04
** Saved new best model params ** @ 93.7%


Epoch 06:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 06 ➜ train 86.6% | val 94.6% (loss 0.674) | lr 2.50e-04
** Saved new best model params ** @ 94.6%


Epoch 07:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 07 ➜ train 87.8% | val 94.6% (loss 0.668) | lr 2.92e-04
** Saved new best model params ** @ 94.6%


Epoch 08:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 08 ➜ train 88.3% | val 95.1% (loss 0.651) | lr 3.33e-04
** Saved new best model params ** @ 95.1%


Epoch 09:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 09 ➜ train 89.1% | val 95.9% (loss 0.628) | lr 3.75e-04
** Saved new best model params ** @ 95.9%


Epoch 10:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 10 ➜ train 89.5% | val 95.8% (loss 0.631) | lr 4.17e-04


Epoch 11:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 11 ➜ train 89.9% | val 96.0% (loss 0.621) | lr 4.58e-04
** Saved new best model params ** @ 96.0%


Epoch 12:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 12 ➜ train 89.9% | val 95.5% (loss 0.634) | lr 5.00e-04


Epoch 13:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 13 ➜ train 90.7% | val 95.8% (loss 0.623) | lr 5.00e-04


Epoch 14:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 14 ➜ train 91.1% | val 96.2% (loss 0.610) | lr 4.99e-04
** Saved new best model params ** @ 96.2%


Epoch 15:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 15 ➜ train 91.3% | val 96.6% (loss 0.603) | lr 4.99e-04
** Saved new best model params ** @ 96.6%


Epoch 16:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 16 ➜ train 91.7% | val 96.4% (loss 0.607) | lr 4.97e-04


Epoch 17:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 17 ➜ train 92.1% | val 96.5% (loss 0.602) | lr 4.96e-04


Epoch 18:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 18 ➜ train 92.4% | val 96.9% (loss 0.590) | lr 4.94e-04
** Saved new best model params ** @ 96.9%


Epoch 19:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 19 ➜ train 92.7% | val 96.8% (loss 0.592) | lr 4.92e-04


Epoch 20:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 20 ➜ train 92.9% | val 96.7% (loss 0.596) | lr 4.90e-04


Epoch 21:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 21 ➜ train 93.0% | val 96.7% (loss 0.593) | lr 4.87e-04


Epoch 22:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 22 ➜ train 93.2% | val 97.0% (loss 0.586) | lr 4.84e-04
** Saved new best model params ** @ 97.0%


Epoch 23:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 23 ➜ train 93.3% | val 96.9% (loss 0.589) | lr 4.81e-04


Epoch 24:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 24 ➜ train 93.4% | val 96.7% (loss 0.598) | lr 4.77e-04


Epoch 25:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 25 ➜ train 93.6% | val 96.8% (loss 0.593) | lr 4.74e-04


Epoch 26:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 26 ➜ train 93.8% | val 96.8% (loss 0.592) | lr 4.69e-04


Epoch 27:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 27 ➜ train 93.9% | val 96.8% (loss 0.592) | lr 4.65e-04


Epoch 28:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 28 ➜ train 94.1% | val 96.7% (loss 0.597) | lr 4.60e-04


Epoch 29:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 29 ➜ train 94.2% | val 97.2% (loss 0.581) | lr 4.55e-04
** Saved new best model params ** @ 97.2%


Epoch 30:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 30 ➜ train 94.3% | val 97.1% (loss 0.585) | lr 4.50e-04


Epoch 31:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 31 ➜ train 94.6% | val 97.0% (loss 0.589) | lr 4.45e-04


Epoch 32:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 32 ➜ train 94.6% | val 97.2% (loss 0.584) | lr 4.39e-04


Epoch 33:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 33 ➜ train 94.7% | val 97.1% (loss 0.585) | lr 4.33e-04


Epoch 34:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 34 ➜ train 94.7% | val 97.2% (loss 0.584) | lr 4.27e-04


Epoch 35:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 35 ➜ train 95.0% | val 97.2% (loss 0.580) | lr 4.20e-04
** Saved new best model params ** @ 97.2%


Epoch 36:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 36 ➜ train 94.9% | val 97.1% (loss 0.585) | lr 4.14e-04


Epoch 37:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 37 ➜ train 95.1% | val 97.2% (loss 0.583) | lr 4.07e-04


Epoch 38:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 38 ➜ train 95.0% | val 97.2% (loss 0.585) | lr 4.00e-04


Epoch 39:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 39 ➜ train 95.2% | val 97.4% (loss 0.579) | lr 3.93e-04
** Saved new best model params ** @ 97.4%


Epoch 40:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 40 ➜ train 95.4% | val 97.1% (loss 0.589) | lr 3.85e-04


Epoch 41:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 41 ➜ train 95.4% | val 97.4% (loss 0.582) | lr 3.78e-04
** Saved new best model params ** @ 97.4%


Epoch 42:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 42 ➜ train 95.6% | val 97.2% (loss 0.580) | lr 3.70e-04


Epoch 43:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 43 ➜ train 95.5% | val 97.3% (loss 0.586) | lr 3.62e-04


Epoch 44:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 44 ➜ train 95.6% | val 97.5% (loss 0.579) | lr 3.54e-04
** Saved new best model params ** @ 97.5%


Epoch 45:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 45 ➜ train 95.9% | val 97.6% (loss 0.577) | lr 3.46e-04
** Saved new best model params ** @ 97.6%


Epoch 46:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 46 ➜ train 95.9% | val 97.4% (loss 0.579) | lr 3.37e-04


Epoch 47:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 47 ➜ train 95.9% | val 97.1% (loss 0.589) | lr 3.29e-04


Epoch 48:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 48 ➜ train 96.0% | val 97.4% (loss 0.582) | lr 3.20e-04


Epoch 49:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 49 ➜ train 96.1% | val 97.4% (loss 0.582) | lr 3.12e-04


Epoch 50:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 50 ➜ train 96.0% | val 97.3% (loss 0.585) | lr 3.03e-04


Epoch 51:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 51 ➜ train 96.2% | val 97.5% (loss 0.576) | lr 2.94e-04


Epoch 52:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 52 ➜ train 96.3% | val 97.6% (loss 0.578) | lr 2.86e-04


Epoch 53:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 53 ➜ train 96.3% | val 97.5% (loss 0.579) | lr 2.77e-04


Epoch 54:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 54 ➜ train 96.5% | val 97.5% (loss 0.584) | lr 2.68e-04


Epoch 55:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 55 ➜ train 96.5% | val 97.3% (loss 0.589) | lr 2.59e-04


Epoch 56:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 56 ➜ train 96.5% | val 97.4% (loss 0.589) | lr 2.50e-04


Epoch 57:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 57 ➜ train 96.6% | val 97.4% (loss 0.581) | lr 2.41e-04


Epoch 58:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 58 ➜ train 96.6% | val 97.5% (loss 0.581) | lr 2.32e-04


Epoch 59:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 59 ➜ train 96.8% | val 97.4% (loss 0.587) | lr 2.23e-04


Epoch 60:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 60 ➜ train 96.7% | val 97.4% (loss 0.585) | lr 2.14e-04


Epoch 61:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 61 ➜ train 96.9% | val 97.4% (loss 0.583) | lr 2.06e-04


Epoch 62:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 62 ➜ train 96.9% | val 97.5% (loss 0.585) | lr 1.97e-04


Epoch 63:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 63 ➜ train 96.9% | val 97.5% (loss 0.584) | lr 1.88e-04


Epoch 64:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 64 ➜ train 97.0% | val 97.5% (loss 0.581) | lr 1.80e-04


Epoch 65:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 65 ➜ train 97.0% | val 97.5% (loss 0.585) | lr 1.71e-04


Epoch 66:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 66 ➜ train 97.1% | val 97.4% (loss 0.585) | lr 1.63e-04


Epoch 67:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 67 ➜ train 97.0% | val 97.4% (loss 0.586) | lr 1.54e-04


Epoch 68:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 68 ➜ train 97.2% | val 97.6% (loss 0.582) | lr 1.46e-04


Epoch 69:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 69 ➜ train 97.3% | val 97.6% (loss 0.583) | lr 1.38e-04
** Saved new best model params ** @ 97.6%


Epoch 70:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 70 ➜ train 97.4% | val 97.6% (loss 0.584) | lr 1.30e-04


Epoch 71:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 71 ➜ train 97.4% | val 97.5% (loss 0.588) | lr 1.22e-04


Epoch 72:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 72 ➜ train 97.4% | val 97.6% (loss 0.583) | lr 1.15e-04


Epoch 73:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 73 ➜ train 97.5% | val 97.7% (loss 0.580) | lr 1.07e-04
** Saved new best model params ** @ 97.7%


Epoch 74:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 74 ➜ train 97.5% | val 97.6% (loss 0.583) | lr 1.00e-04


Epoch 75:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 75 ➜ train 97.4% | val 97.7% (loss 0.582) | lr 9.31e-05


Epoch 76:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 76 ➜ train 97.6% | val 97.7% (loss 0.579) | lr 8.63e-05
** Saved new best model params ** @ 97.7%


Epoch 77:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 77 ➜ train 97.6% | val 97.6% (loss 0.582) | lr 7.96e-05


Epoch 78:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 78 ➜ train 97.6% | val 97.6% (loss 0.585) | lr 7.32e-05


Epoch 79:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 79 ➜ train 97.6% | val 97.6% (loss 0.584) | lr 6.70e-05


Epoch 80:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 80 ➜ train 97.7% | val 97.6% (loss 0.582) | lr 6.11e-05


Epoch 81:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 81 ➜ train 97.7% | val 97.7% (loss 0.585) | lr 5.53e-05


Epoch 82:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 82 ➜ train 97.8% | val 97.7% (loss 0.584) | lr 4.99e-05


Epoch 83:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 83 ➜ train 97.8% | val 97.7% (loss 0.584) | lr 4.46e-05


Epoch 84:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 84 ➜ train 97.8% | val 97.6% (loss 0.583) | lr 3.97e-05


Epoch 85:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 85 ➜ train 97.9% | val 97.7% (loss 0.585) | lr 3.50e-05


Epoch 86:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 86 ➜ train 97.8% | val 97.6% (loss 0.585) | lr 3.06e-05


Epoch 87:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 87 ➜ train 97.8% | val 97.7% (loss 0.585) | lr 2.64e-05


Epoch 88:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 88 ➜ train 97.9% | val 97.7% (loss 0.586) | lr 2.26e-05


Epoch 89:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 89 ➜ train 97.9% | val 97.7% (loss 0.586) | lr 1.90e-05


Epoch 90:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 90 ➜ train 98.0% | val 97.6% (loss 0.586) | lr 1.58e-05


Epoch 91:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 91 ➜ train 98.0% | val 97.7% (loss 0.586) | lr 1.28e-05


Epoch 92:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 92 ➜ train 98.0% | val 97.7% (loss 0.586) | lr 1.01e-05


Epoch 93:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 93 ➜ train 97.9% | val 97.7% (loss 0.586) | lr 7.77e-06


Epoch 94:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 94 ➜ train 97.9% | val 97.7% (loss 0.585) | lr 5.71e-06


Epoch 95:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 95 ➜ train 98.0% | val 97.7% (loss 0.586) | lr 3.97e-06


Epoch 96:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 96 ➜ train 98.0% | val 97.6% (loss 0.586) | lr 2.54e-06


Epoch 97:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 97 ➜ train 98.0% | val 97.7% (loss 0.586) | lr 1.50e-06


Epoch 98:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 98 ➜ train 97.9% | val 97.7% (loss 0.586) | lr 1.50e-06


Epoch 99:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 99 ➜ train 97.9% | val 97.7% (loss 0.586) | lr 1.50e-06


Epoch 100:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 100 ➜ train 97.9% | val 97.6% (loss 0.585) | lr 1.50e-06


In [None]:
import shutil

# Drive destination folder and the local best-checkpoint written by the training loop.
CKPT_DIR = "/content/drive/MyDrive/kws_models_noAux_192_12"
best_local = "/content/best_kws.pt"

# --- Save LAST params locally + Drive ---
# Final-epoch weights: local copy first, then the Drive folder (created on demand).
torch.save(model.state_dict(), "/content/last_kws_mamba_noAux.pt")
os.makedirs(CKPT_DIR, exist_ok=True)
torch.save(model.state_dict(), f"{CKPT_DIR}/last_kws_mamba_noAux.pt")
print("Saved LAST model to Drive")

# --- Copy BEST (by val_acc) to Drive if it exists ---
# copy2 copies the already-saved best checkpoint file along with its metadata.
if not os.path.exists(best_local):
    print("WARNING: no best checkpoint was found to copy.")
else:
    dst = f"{CKPT_DIR}/best_kws.pt"
    shutil.copy2(best_local, dst)
    print(f"Copied BEST model (val_acc={best_val_acc:.2f}%) to Drive: {dst}")

Saved LAST model to Drive
Copied BEST model (val_acc=97.75%) to Drive: /content/drive/MyDrive/kws_models_noAux_192_12/best_kws.pt
