<a href="https://colab.research.google.com/github/Zig302/KWS-Mamba-Project/blob/main/notebooks/KWS_RetNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets==2.16.0
!pip install huggingface-hub==0.20.0
!apt-get install -y libsox-dev
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
!pip install yet-another-retnet==0.5.1

Collecting datasets==2.16.0
  Downloading datasets-2.16.0-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow-hotfix (from datasets==2.16.0)
  Downloading pyarrow_hotfix-0.7-py3-none-any.whl.metadata (3.6 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets==2.16.0)
  Downloading dill-0.3.7-py3-none-any.whl.metadata (9.9 kB)
Collecting fsspec<=2023.10.0,>=2023.1.0 (from fsspec[http]<=2023.10.0,>=2023.1.0->datasets==2.16.0)
  Downloading fsspec-2023.10.0-py3-none-any.whl.metadata (6.8 kB)
INFO: pip is looking at multiple versions of multiprocess to determine which version is compatible with other requirements. This could take a while.
Collecting multiprocess (from datasets==2.16.0)
  Downloading multiprocess-0.70.18-py312-none-any.whl.metadata (7.5 kB)
  Downloading multiprocess-0.70.17-py312-none-any.whl.metadata (7.2 kB)
  Downloading multiprocess-0.70.15-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-2.16.0-py3-none-any.whl (507 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Keyword‑Spotting with RetNet
# =========================================================

from __future__ import annotations
import json, os, random
from pathlib import Path
from typing import Tuple, Dict
import torch, torchaudio
import torch.nn as nn
import os
import math
import torch.nn.functional as F
import torchvision.models as tvm
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tqdm.notebook import tqdm
from yet_another_retnet.retention import MultiScaleRetention
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# ---------------------------------------------------------------------
# 2. Feature front-end (log-mel / MFCC) + waveform-level augmentation (shift + noise)
# ---------------------------------------------------------------------
class WaveToSpec:
    """Waveform -> spectral features front-end (log-mel spectrogram or MFCC).

    Args:
        feature_type: ``"mel"`` (log-mel spectrogram in dB) or ``"mfcc"``
            (case-insensitive).
        sample_rate: Sample rate of the incoming audio in Hz.
        n_fft: FFT size.
        hop_length: Hop between successive STFT frames, in samples.
        n_mels: Number of mel filterbank channels.
        n_mfcc: Number of MFCC coefficients (``"mfcc"`` only).
        top_db: Dynamic-range cut-off passed to ``AmplitudeToDB`` (``"mel"`` only).
        apply_mask: Enable SpecAugment masking. Only honoured for ``"mel"``.
        freq_mask_param: Maximum width of a frequency mask.
        time_mask_param: Maximum width of a time mask.

    Note:
        The mel transform is built with keyword arguments. The positional
        order of ``MelSpectrogram`` is ``(sample_rate, n_fft, win_length,
        hop_length, ...)``, so the earlier positional call silently used
        ``hop_length`` as the window length and ``n_mels`` as the hop, and
        ignored ``n_mels``. Features now have the requested hop/mel layout;
        normalisation statistics and checkpoints produced with the old
        layout must be regenerated.
    """

    def __init__(self,
                 feature_type: str = "mel",
                 sample_rate: int = 16_000,
                 n_fft: int = 2048,
                 hop_length: int = 256,
                 n_mels: int = 128,
                 n_mfcc: int = 40,
                 top_db: int | None = 80,
                 apply_mask: bool = True,
                 freq_mask_param: int = 15,
                 time_mask_param: int = 10):
        self.feature_type = feature_type.lower()
        assert self.feature_type in {"mel", "mfcc"}, \
            f"feature_type must be 'mel' or 'mfcc', got {feature_type!r}"
        # SpecAugment is only wired up for the mel path.
        self.apply_mask = apply_mask and self.feature_type == "mel"

        # Defaults so these attributes exist on every configuration.
        self.to_db = None
        self.freq_mask = self.time_mask = None

        if self.feature_type == "mel":
            self.spec = T.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=n_fft,
                hop_length=hop_length,
                n_mels=n_mels,
                power=2,
            )
            self.to_db = T.AmplitudeToDB(stype="power", top_db=top_db)
            if self.apply_mask:
                self.freq_mask = T.FrequencyMasking(freq_mask_param)
                self.time_mask = T.TimeMasking(time_mask_param)
        else:
            self.spec = T.MFCC(
                sample_rate=sample_rate,
                n_mfcc=n_mfcc,
                melkwargs=dict(n_fft=n_fft, hop_length=hop_length, n_mels=n_mels),
            )

    def __call__(self, wav: torch.Tensor) -> torch.Tensor:
        """Convert a waveform to features.

        A 1-D waveform is given a leading channel axis first. The result is
        whatever the underlying torchaudio transform returns for that input,
        optionally masked and (for ``"mel"``) converted to dB.
        """
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)
        feats = self.spec(wav)
        if self.apply_mask:
            # Stronger SpecAugment: two frequency masks and two time masks,
            # applied on the power spectrogram (before the dB conversion).
            feats = self.freq_mask(feats); feats = self.time_mask(feats)
            feats = self.freq_mask(feats); feats = self.time_mask(feats)
        if self.to_db is not None:
            # Clamp so masked (zero) bins do not produce -inf in the log.
            feats = self.to_db(feats.clamp(min=1e-10))
        return feats

class Augment:
    """Random waveform augmentation: optional tempo stretch, time shift, additive noise.

    Args:
        stretch: ``(low, high)`` range for the sox ``tempo`` factor; the
            default ``(1.0, 1.0)`` disables stretching.
        shift_ms: Maximum absolute time shift in milliseconds.
        noise: ``(low, high)`` range for the Gaussian noise standard deviation.
        sr: Sample rate in Hz, used to convert ``shift_ms`` to samples.
    """

    def __init__(self, stretch: Tuple[float,float]=(1.0,1.0),
                 shift_ms: int = 100,
                 noise: Tuple[float,float]=(0.,0.005),
                 sr: int = 16_000):
        self.sr = sr
        self.stretch = stretch
        self.noise = noise
        # Maximum shift expressed in samples.
        self.shift = int(shift_ms * sr / 1000)

    def _shift(self, x: torch.Tensor):
        """Shift ``x`` (2-D, time last) by a random offset, zero-filling the gap."""
        limit = self.shift
        if limit == 0:
            return x
        offset = int(torch.randint(-limit, limit + 1, ()).item())
        if offset > 0:
            # Delay: zeros enter on the left, the tail is dropped.
            return F.pad(x, (offset, 0))[:, :-offset]
        if offset < 0:
            # Advance: zeros enter on the right, the head is dropped.
            return F.pad(x, (0, -offset))[:, -offset:]
        return x

    def __call__(self, wav: torch.Tensor):
        """Augment ``wav``; a 1-D input is returned 1-D again."""
        was_1d = wav.dim() == 1
        if was_1d:
            wav = wav.unsqueeze(0)

        if self.stretch != (1.0, 1.0):
            factor = float(torch.empty(()).uniform_(*self.stretch))
            # Skip the sox round-trip for factors that are effectively 1.
            if abs(factor - 1.0) > 1e-3:
                wav, _ = torchaudio.sox_effects.apply_effects_tensor(
                    wav, self.sr, [["tempo", f"{factor}"]]
                )

        wav = self._shift(wav)

        noise_range = self.noise
        if noise_range[1] > 0:
            sigma = float(torch.empty(()).uniform_(*noise_range))
            if sigma > 0:
                wav = wav + sigma * torch.randn_like(wav)

        if was_1d:
            return wav.squeeze(0)
        return wav

In [None]:
# ---------------------------------------------------------------------
# 3. Dataset wrapper with dataset-level normalization
# ---------------------------------------------------------------------
class SpeechCommands(Dataset):
    """Torch ``Dataset`` view over a Hugging Face split of audio samples.

    Each item is read from ``hf_split[idx]["audio"]["array"]``, padded or
    cropped to ``wav_len`` samples, optionally augmented, converted to
    features by ``frontend`` and normalised with the supplied statistics.

    Args:
        hf_split: Indexable collection of samples with ``"audio"`` and ``"label"``.
        aug: Waveform augmentation callable, or ``None`` to disable.
        frontend: Callable mapping a waveform to a feature tensor.
        wav_len: Fixed waveform length in samples.
        mean: Value subtracted from the features.
        std: Value the centred features are divided by (plus ``1e-6``).
    """

    def __init__(self, hf_split, aug: Augment | None, frontend: WaveToSpec,
                 wav_len: int = 16_000, mean: float = 0.0, std: float = 1.0):
        self.ds = hf_split
        self.aug = aug
        self.front = frontend
        self.wav_len = wav_len
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return ``(features, label)`` with features laid out as [T, F]."""
        item = self.ds[idx]
        wav = torch.from_numpy(item["audio"]["array"]).float()

        # Force a fixed length: zero-pad short clips on the right, crop long ones.
        missing = self.wav_len - wav.numel()
        if missing > 0:
            wav = F.pad(wav, (0, missing))
        else:
            wav = wav[: self.wav_len]

        if self.aug:
            wav = self.aug(wav)

        feats = self.front(wav)                              # [1, F, T]
        # Normalise with the precomputed dataset-level statistics.
        feats = (feats - self.mean) / (self.std + 1e-6)
        feats = feats.squeeze(0).transpose(0, 1)             # [T, F]
        return feats, item["label"]

In [None]:
# ---------------------------------------------------------------------
# 4. Helper funcs = collate function + data mean/std + LR decay func
# ---------------------------------------------------------------------
from torch.nn.utils.rnn import pad_sequence

# Add padding to spectrograms if needed
def collate_fn(batch):
    """Collate ``(features, label)`` pairs into a zero-padded batch.

    Returns:
        A tuple ``(padded, labels, lengths)`` where ``padded`` is
        ``[B, T_max, F]``, ``labels`` is a tensor built from the labels and
        ``lengths`` holds each item's original number of frames (int64), for
        mask-aware pooling downstream.
    """
    feats, lbls = zip(*batch)
    frame_counts = [f.shape[0] for f in feats]
    lens = torch.tensor(frame_counts, dtype=torch.long)
    labels = torch.tensor(lbls)
    # Right-pad every sequence with zeros up to the longest one in the batch.
    feats_padded = pad_sequence(feats, batch_first=True, padding_value=0.0)
    return feats_padded, labels, lens


# Learning-rate schedule multiplier: linear warmup followed by cosine decay with a floor
def lr_lambda(step, warmup_steps=None, total_steps=None, min_factor=0.005):
    """LR multiplier: linear warmup from 0 to 1, then cosine decay down to a floor.

    Args:
        step: Current scheduler step (0-based).
        warmup_steps: Number of warmup steps. When ``None``, the module-level
            ``warmup_steps`` is used (the original implicit behaviour).
        total_steps: Step at which the cosine reaches its minimum. When
            ``None``, the module-level ``total_steps`` is used.
        min_factor: Lower bound of the returned multiplier after warmup.

    Returns:
        A float multiplier to apply to the base learning rate.

    Raises:
        NameError: If an argument is omitted and the matching module-level
            variable has not been defined yet.
    """
    if warmup_steps is None or total_steps is None:
        scope = globals()
        try:
            if warmup_steps is None:
                warmup_steps = scope["warmup_steps"]
            if total_steps is None:
                total_steps = scope["total_steps"]
        except KeyError as exc:
            raise NameError(
                f"lr_lambda: {exc.args[0]!r} was neither passed nor defined at module level"
            ) from None

    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    # Clamp so the cosine cannot swing back up once step exceeds total_steps.
    progress = min(1.0, progress)
    return max(min_factor, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Precompute the dataset-level mean and std of the front-end features (log-mel or MFCC)
def compute_dataset_stats(ds, frontend, wav_len=16_000):
    """Compute the global mean and (unbiased) std of the front-end features.

    Statistics are accumulated in a streaming fashion (count, sum and sum of
    squares in float64), so memory use does not grow with the dataset size;
    the previous approach kept every feature tensor and concatenated them.

    Args:
        ds: Iterable of samples exposing ``sample["audio"]["array"]`` (NumPy array).
        frontend: Callable mapping a 1-D waveform tensor to a feature tensor.
        wav_len: Each waveform is zero-padded or cropped to this many samples.

    Returns:
        ``(mean, std)`` as Python floats, computed over every feature element.
        ``std`` is ``nan`` when fewer than two elements were seen.

    Raises:
        ValueError: If ``ds`` yields no feature elements at all.
    """
    count = 0
    total = 0.0
    total_sq = 0.0
    for sample in ds:
        wav = torch.from_numpy(sample["audio"]["array"]).float()
        if wav.numel() < wav_len:
            wav = F.pad(wav, (0, wav_len - wav.numel()))
        else:
            wav = wav[: wav_len]
        # float64 keeps the running sums accurate over many elements.
        feats = frontend(wav).double()
        count += feats.numel()
        total += feats.sum().item()
        total_sq += feats.pow(2).sum().item()

    if count == 0:
        raise ValueError("compute_dataset_stats: dataset produced no features")

    mean = total / count
    if count < 2:
        return mean, float("nan")
    # Unbiased (N-1) variance; clamp tiny negative values caused by rounding.
    variance = max(0.0, (total_sq - count * mean * mean) / (count - 1))
    return mean, math.sqrt(variance)


# GroupNorm over the channel (last) axis of sequence tensors shaped [B, T, D]
class ChannelGroupNorm(nn.Module):
    """``nn.GroupNorm`` wrapper for channel-last sequence tensors.

    ``nn.GroupNorm`` expects channels on axis 1, so the input's axes 1 and 2
    are swapped before normalisation and swapped back afterwards.

    Args:
        num_groups: Number of channel groups.
        num_channels: Size of the channel axis (``D``).
        eps: Numerical-stability constant forwarded to ``nn.GroupNorm``.
        affine: Whether ``nn.GroupNorm`` learns a per-channel scale and bias.
    """

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups, num_channels, eps=eps, affine=affine)

    def forward(self, x):
        """Normalise ``x`` of shape [B, T, D]; the output has the same shape."""
        channels_first = x.swapaxes(1, 2)      # [B, D, T]
        normed = self.gn(channels_first)
        return normed.swapaxes(1, 2)           # back to [B, T, D]


In [None]:
# ---------------------------------------------------------------------
# RetNet KWS model (Conv embed -> Retention blocks -> mask-aware pooling)
# ---------------------------------------------------------------------
class RetNetBlock(nn.Module):
    """Pre-norm residual block: retention sublayer followed by a feed-forward sublayer.

    The retention sublayer projects the normalised input to query/key/value,
    L2-normalises query and key along the feature axis, and calls
    ``MultiScaleRetention.forward_parallel``. The feed-forward sublayer is a
    two-layer MLP (hidden width ``2 * d_model``) with SiLU and dropout.

    Args:
        d_model: Feature width of the block input and output.
        n_heads: Number of retention heads; also used as the GroupNorm group count.
        drop: Dropout probability used throughout the block.
    """

    def __init__(self, d_model: int, n_heads: int, drop: float):
        super().__init__()
        hidden = 2 * d_model

        # --- retention sublayer ---
        self.norm = ChannelGroupNorm(num_groups=n_heads, num_channels=d_model)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.retention = MultiScaleRetention(
            embed_dim=d_model,
            num_heads=n_heads,
            relative_position=False,
        )
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(drop)

        # --- feed-forward sublayer ---
        self.ffn_norm = ChannelGroupNorm(num_groups=n_heads, num_channels=d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.SiLU(),
            nn.Dropout(drop),
            nn.Linear(hidden, d_model),
            nn.Dropout(drop),
        )

    def _retention_sublayer(self, x):
        """Return ``x`` plus the dropout-regularised retention output."""
        h = self.norm(x)
        query = F.normalize(self.q_proj(h), dim=-1)
        key = F.normalize(self.k_proj(h), dim=-1)
        value = self.v_proj(h)
        # Second return value of forward_parallel is not used here.
        retained, _ = self.retention.forward_parallel(query, key, value)
        return x + self.dropout(self.out_proj(retained))

    def _ffn_sublayer(self, x):
        """Return ``x`` plus the feed-forward output of its normalised copy."""
        return x + self.ffn(self.ffn_norm(x))

    def forward(self, x):                    # x: [B, T, d]
        """Apply both residual sublayers; the output shape equals the input shape."""
        x = self._retention_sublayer(x)
        return self._ffn_sublayer(x)

class RetNetKWS(nn.Module):
    """Keyword-spotting classifier: conv embedding -> RetNet blocks -> pooled classifier.

    Args:
        num_classes: Number of output classes.
        d_model: Width of the sequence model.
        n_layers: Number of ``RetNetBlock`` layers.
        n_heads: Heads per ``RetNetBlock``.
        in_ch: Input channels of the first convolution.
        feature_dim: Size of the feature (frequency) axis of the input; the
            conv stack divides it by 4 before flattening.
    """

    @staticmethod
    def _conv_bn_silu(c_in: int, c_out: int):
        """3x3 same-padding convolution followed by BatchNorm and SiLU."""
        return [nn.Conv2d(c_in, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.SiLU()]

    def __init__(self, num_classes: int, d_model=256, n_layers=8, n_heads=8, in_ch=1, feature_dim=128):
        super().__init__()
        # Convolutional embedding. First pool halves frequency and time,
        # second pool halves frequency only.
        self.conv_embed = nn.Sequential(
            *self._conv_bn_silu(in_ch, 32),
            *self._conv_bn_silu(32, 32),
            nn.MaxPool2d(2),
            *self._conv_bn_silu(32, 64),
            *self._conv_bn_silu(64, 64),
            nn.MaxPool2d((2, 1)),
        )

        # Projection of the flattened (channels x frequency) axis to model width.
        flattened_dim = 64 * (feature_dim // 4)
        self.proj = nn.Sequential(
            nn.Linear(flattened_dim, d_model),
            nn.LayerNorm(d_model),
            nn.SiLU(),
            nn.Dropout(0.1)
        )

        # RetNet blocks; dropout shrinks with depth but never below 0.02.
        self.blocks = nn.ModuleList()
        for i in range(n_layers):
            layer_drop = max(0.02, 0.03 - (i * 0.003))
            self.blocks.append(RetNetBlock(d_model=d_model, n_heads=n_heads, drop=layer_drop))

        self.pre_classifier_norm = nn.LayerNorm(d_model)
        self.classifier_dropout = nn.Dropout(0.1)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.SiLU(),
            nn.Dropout(0.05),
            nn.Linear(d_model // 2, num_classes)
        )

    @staticmethod
    def _masked_mean(seq: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Mean over time of ``seq`` [B, T', d], counting only valid frames.

        ``lengths`` are input-frame counts; they are halved (floor, minimum 1)
        to match the time pooling of the conv stack.
        """
        valid = torch.div(lengths, 2, rounding_mode='floor').clamp(min=1).to(seq.device)
        positions = torch.arange(seq.size(1), device=seq.device)
        mask = (positions[None, :] < valid[:, None]).float().unsqueeze(-1)
        summed = (seq * mask).sum(dim=1)
        return summed / mask.sum(dim=1).clamp(min=1.0)

    def forward(self, x, lengths: torch.Tensor | None = None):   # x: [B, T, F]
        """Return class logits [B, num_classes].

        Args:
            x: Input features shaped [B, T, F].
            lengths: Optional per-item frame counts; when given, pooling
                ignores frames beyond each item's length.
        """
        # Conv2d layout: [B, T, F] -> [B, 1, F, T]
        grid = x.permute(0, 2, 1).unsqueeze(1)
        grid = self.conv_embed(grid)                              # [B, 64, F', T']
        seq = grid.permute(0, 3, 1, 2).contiguous().flatten(2)    # [B, T', 64*F']
        seq = self.proj(seq)                                      # [B, T', d_model]

        for blk in self.blocks:
            seq = blk(seq)

        seq = self.pre_classifier_norm(seq)

        if lengths is None:
            pooled = seq.mean(dim=1)
        else:
            pooled = self._masked_mean(seq, lengths)

        return self.classifier(self.classifier_dropout(pooled))

# ---------------------------------------------------------------------
@torch.no_grad()
def evaluate(model, loader, device, criterion):
    """Evaluate ``model`` over ``loader`` with gradients disabled.

    The model is switched to eval mode. Each batch is expected to unpack as
    ``(features, labels, lengths)``.

    Returns:
        ``(mean_loss, accuracy_percent)``, both weighted by batch size.
    """
    model.eval()
    n_seen = 0
    n_correct = 0
    loss_total = 0
    for feats, labels, lens in loader:
        feats = feats.to(device)
        labels = labels.to(device)
        lens = lens.to(device)

        logits = model(feats, lengths=lens)
        batch_size = feats.size(0)

        # criterion returns a batch mean; re-weight it by the batch size.
        loss_total += criterion(logits, labels).item() * batch_size
        n_correct += (logits.argmax(1) == labels).sum().item()
        n_seen += batch_size
    return loss_total / n_seen, 100.0 * n_correct / n_seen

In [None]:
# ---------------------------------------------------------------------
# 6. Main script
# ---------------------------------------------------------------------
if __name__ == "__main__":
    # Seed the RNGs so augmentation, weight init and shuffling are repeatable.
    SEED = 42
    random.seed(SEED)
    torch.manual_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = (device.type == "cuda")

    # dataset
    ds = load_dataset("google/speech_commands", "v0.02")
    n_classes = len(ds["train"].features["label"].names)

    # log-mel + SpecAugment (train only)
    feature_type = "mel"        # "mel"/"mfcc"
    Epochs = 100
    base_lr = 5e-4
    warmup_frac = 0.12

    frontend_train = WaveToSpec(
        feature_type=feature_type,
        n_mfcc=40, n_mels=128,
        apply_mask=True,         # SpecAugment on train
        freq_mask_param=12,
        time_mask_param=20
    )
    frontend_eval = WaveToSpec(
        feature_type=feature_type,
        n_mfcc=40, n_mels=128,
        apply_mask=False
    )

    # Unmasked front-end used only to compute normalization statistics.
    frontend_stats = WaveToSpec(feature_type=feature_type, n_mfcc=40, n_mels=128, apply_mask=False)

    # Waveform augs shift + a bit of noise
    aug = Augment(shift_ms=120, noise=(0., 0.007))

    # Normalization stats (train split only, reused for val/test)
    train_mean, train_std = compute_dataset_stats(ds["train"], frontend_stats)

    # Datasets
    train_ds = SpeechCommands(ds["train"], aug, frontend_train, mean=train_mean, std=train_std)
    val_ds   = SpeechCommands(ds["validation"], None, frontend_eval, mean=train_mean, std=train_std)
    test_ds  = SpeechCommands(ds["test"], None, frontend_eval, mean=train_mean, std=train_std)

    # Loaders
    dl_kwargs = dict(
        batch_size=64,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
        collate_fn=collate_fn
    )
    train_dl = DataLoader(train_ds, shuffle=True, **dl_kwargs)
    val_dl   = DataLoader(val_ds, shuffle=False, **dl_kwargs)

    # Model
    model = RetNetKWS(n_classes, d_model=128, n_layers=6, n_heads=8).to(device)

    # Loss/opt/sched (per-batch schedule; short warmup; cosine with floor)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
    opt = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=1.8e-4, betas=(0.9, 0.999))

    # lr_lambda reads warmup_steps / total_steps from module scope.
    steps_per_epoch = len(train_dl)
    true_steps      = steps_per_epoch * Epochs
    total_steps     = steps_per_epoch * int(Epochs * 1.5)  # longer cosine tail
    warmup_steps    = int(true_steps * warmup_frac)

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

    # Plateau after warmup. `verbose` is deprecated in recent PyTorch, so
    # reductions are reported explicitly below instead.
    plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode="min", factor=0.5, patience=3, threshold=1e-3, cooldown=0, min_lr=1e-6
    )

    best_val_acc = 0.0
    prev_val_acc = 0.0
    BEST_PATH = Path("/content/best_kws_retnet.pt")
    BEST_PATH.parent.mkdir(parents=True, exist_ok=True)

    # torch.cuda.amp.GradScaler is deprecated in favour of torch.amp.GradScaler.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    global_step = 0

    for epoch in range(1, Epochs + 1):
        model.train()
        running_loss = correct = total = 0.0

        pbar = tqdm(train_dl, desc=f"Epoch {epoch:02d}")
        for xb, yb, lb in pbar:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            lb = lb.to(device, non_blocking=True)

            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                # Scrub NaNs in the input so one bad sample cannot poison the batch.
                if torch.isnan(xb).any():
                    xb = torch.nan_to_num(xb, nan=0.0)

                logits = model(xb, lengths=lb)
                loss = criterion(logits, yb)

                # Skip the whole update when the loss is not finite.
                if not torch.isfinite(loss):
                    continue

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.15)
            scaler.step(opt)
            scaler.update()
            sched.step()  # per-batch warmup+cosine
            global_step += 1

            pred = logits.argmax(1)
            correct += (pred == yb).sum().item()
            total += yb.size(0)
            running_loss += loss.item() * yb.size(0)

            pbar.set_postfix(
                train_loss=f"{running_loss / max(1,total):.3f}",
                train_acc=f"{100 * correct / max(1,total):.1f}%",
                lr=f"{opt.param_groups[0]['lr']:.2e}"
            )

        tr_acc = 100.0 * correct / max(1, total)
        val_loss, val_acc = evaluate(model, val_dl, device, criterion)
        print(f"Epoch {epoch:02d} ➜ train {tr_acc:.1f}% | val {val_acc:.1f}% (loss {val_loss:.3f}) | lr {opt.param_groups[0]['lr']:.2e}")

        # After warmup, allow Plateau to adjust LR (epoch-level).
        # LambdaLR recomputes lr = base_lr * lr_lambda(step) on every batch, so a
        # reduction written only to opt.param_groups would be overwritten on the
        # next sched.step(). Fold the same ratio into sched.base_lrs to keep it.
        if global_step >= warmup_steps:
            lrs_before = [g["lr"] for g in opt.param_groups]
            plateau.step(val_loss)
            for i, (g, lr_before) in enumerate(zip(opt.param_groups, lrs_before)):
                if lr_before > 0 and g["lr"] < lr_before:
                    sched.base_lrs[i] *= g["lr"] / lr_before
                    print(f"Plateau: base LR of group {i} scaled to {sched.base_lrs[i]:.2e}")

        # collapse guard
        if epoch > 1 and prev_val_acc > 50.0 and val_acc < 0.5 * prev_val_acc:
            print(f"WARNING: accuracy collapse ({prev_val_acc:.2f}% → {val_acc:.2f}%). Restoring best and reducing LR ×5.")
            if BEST_PATH.exists():
                # The checkpoint is a plain state_dict, so weights_only loading suffices.
                model.load_state_dict(torch.load(BEST_PATH, map_location=device, weights_only=True))
            for i, g in enumerate(opt.param_groups):
                g['lr'] = max(g['lr'] / 5.0, 1e-6)
                # Persist the cut across subsequent per-batch scheduler steps.
                sched.base_lrs[i] = max(sched.base_lrs[i] / 5.0, 1e-6)
            print(f"New LR: {opt.param_groups[0]['lr']:.2e}")

        # best checkpoint (accuracy)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), BEST_PATH)
            print(f"** Saved new best RetNet model ** @ {best_val_acc:.1f}%")

        prev_val_acc = val_acc



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading readme: 0.00B [00:00, ?B/s]

Downloading data:   0%|          | 0.00/1.94G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/229M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/112M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/84848 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/9982 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/4890 [00:00<?, ? examples/s]

  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


Epoch 01:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 01 ➜ train 9.0% | val 22.3% (loss 2.884) | lr 4.17e-05
** Saved new best RetNet model ** @ 22.3%


Epoch 02:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 02 ➜ train 39.5% | val 64.3% (loss 1.542) | lr 8.33e-05
** Saved new best RetNet model ** @ 64.3%


Epoch 03:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 03 ➜ train 66.7% | val 81.4% (loss 0.976) | lr 1.25e-04
** Saved new best RetNet model ** @ 81.4%


Epoch 04:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 04 ➜ train 77.4% | val 88.3% (loss 0.761) | lr 1.67e-04
** Saved new best RetNet model ** @ 88.3%


Epoch 05:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 05 ➜ train 82.3% | val 91.0% (loss 0.675) | lr 2.08e-04
** Saved new best RetNet model ** @ 91.0%


Epoch 06:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 06 ➜ train 84.8% | val 91.2% (loss 0.659) | lr 2.50e-04
** Saved new best RetNet model ** @ 91.2%


Epoch 07:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 07 ➜ train 86.3% | val 93.1% (loss 0.605) | lr 2.92e-04
** Saved new best RetNet model ** @ 93.1%


Epoch 08:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 08 ➜ train 87.3% | val 93.4% (loss 0.589) | lr 3.33e-04
** Saved new best RetNet model ** @ 93.4%


Epoch 09:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 09 ➜ train 88.1% | val 92.5% (loss 0.619) | lr 3.75e-04


Epoch 10:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 10 ➜ train 88.6% | val 94.1% (loss 0.566) | lr 4.17e-04
** Saved new best RetNet model ** @ 94.1%


Epoch 11:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 11 ➜ train 89.3% | val 92.0% (loss 0.625) | lr 4.58e-04


Epoch 12:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 12 ➜ train 89.4% | val 94.0% (loss 0.565) | lr 5.00e-04


Epoch 13:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 13 ➜ train 90.0% | val 95.0% (loss 0.537) | lr 5.00e-04
** Saved new best RetNet model ** @ 95.0%


Epoch 14:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 14 ➜ train 90.7% | val 94.8% (loss 0.534) | lr 5.00e-04


Epoch 15:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 15 ➜ train 91.0% | val 95.2% (loss 0.529) | lr 4.99e-04
** Saved new best RetNet model ** @ 95.2%


Epoch 16:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 16 ➜ train 91.5% | val 95.4% (loss 0.519) | lr 4.99e-04
** Saved new best RetNet model ** @ 95.4%


Epoch 17:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 17 ➜ train 91.7% | val 95.4% (loss 0.519) | lr 4.98e-04


Epoch 18:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 18 ➜ train 92.0% | val 95.7% (loss 0.508) | lr 4.98e-04
** Saved new best RetNet model ** @ 95.7%


Epoch 19:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 19 ➜ train 92.2% | val 95.7% (loss 0.511) | lr 4.97e-04


Epoch 20:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 20 ➜ train 92.4% | val 96.1% (loss 0.498) | lr 4.96e-04
** Saved new best RetNet model ** @ 96.1%


Epoch 21:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 21 ➜ train 92.7% | val 95.5% (loss 0.512) | lr 4.95e-04


Epoch 22:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 22 ➜ train 92.9% | val 95.8% (loss 0.508) | lr 4.94e-04


Epoch 23:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 23 ➜ train 93.0% | val 96.0% (loss 0.499) | lr 4.92e-04


Epoch 24:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 24 ➜ train 93.2% | val 96.1% (loss 0.494) | lr 4.91e-04


Epoch 25:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 25 ➜ train 93.4% | val 96.1% (loss 0.495) | lr 4.89e-04


Epoch 26:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 26 ➜ train 93.5% | val 96.0% (loss 0.497) | lr 4.87e-04


Epoch 27:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 27 ➜ train 93.5% | val 96.0% (loss 0.498) | lr 4.86e-04


Epoch 28:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 28 ➜ train 93.8% | val 96.4% (loss 0.484) | lr 4.84e-04
** Saved new best RetNet model ** @ 96.4%


Epoch 29:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 29 ➜ train 93.7% | val 96.5% (loss 0.487) | lr 4.82e-04
** Saved new best RetNet model ** @ 96.5%


Epoch 30:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 30 ➜ train 93.9% | val 96.4% (loss 0.485) | lr 4.79e-04


Epoch 31:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 31 ➜ train 94.1% | val 96.4% (loss 0.483) | lr 4.77e-04


Epoch 32:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 32 ➜ train 94.1% | val 96.4% (loss 0.483) | lr 4.75e-04


Epoch 33:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 33 ➜ train 94.4% | val 96.5% (loss 0.486) | lr 4.72e-04
** Saved new best RetNet model ** @ 96.5%


Epoch 34:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 34 ➜ train 94.4% | val 96.4% (loss 0.482) | lr 4.69e-04


Epoch 35:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 35 ➜ train 94.5% | val 96.5% (loss 0.481) | lr 4.67e-04
** Saved new best RetNet model ** @ 96.5%


Epoch 36:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 36 ➜ train 94.6% | val 96.8% (loss 0.472) | lr 4.64e-04
** Saved new best RetNet model ** @ 96.8%


Epoch 37:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 37 ➜ train 94.7% | val 96.6% (loss 0.479) | lr 4.61e-04


Epoch 38:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 38 ➜ train 94.6% | val 96.3% (loss 0.482) | lr 4.57e-04


Epoch 39:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 39 ➜ train 94.8% | val 96.4% (loss 0.487) | lr 4.54e-04


Epoch 40:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 40 ➜ train 94.9% | val 96.7% (loss 0.476) | lr 4.51e-04


Epoch 41:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 41 ➜ train 95.0% | val 96.7% (loss 0.478) | lr 4.47e-04


Epoch 42:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 42 ➜ train 94.9% | val 96.5% (loss 0.486) | lr 4.44e-04


Epoch 43:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 43 ➜ train 95.0% | val 96.5% (loss 0.483) | lr 4.40e-04


Epoch 44:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 44 ➜ train 95.1% | val 96.7% (loss 0.478) | lr 4.37e-04


Epoch 45:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 45 ➜ train 95.1% | val 96.4% (loss 0.486) | lr 4.33e-04


Epoch 46:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 46 ➜ train 95.3% | val 96.7% (loss 0.476) | lr 4.29e-04


Epoch 47:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 47 ➜ train 95.3% | val 96.6% (loss 0.485) | lr 4.25e-04


Epoch 48:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 48 ➜ train 95.3% | val 96.7% (loss 0.478) | lr 4.21e-04


Epoch 49:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 49 ➜ train 95.4% | val 96.8% (loss 0.473) | lr 4.16e-04
** Saved new best RetNet model ** @ 96.8%


Epoch 50:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 50 ➜ train 95.5% | val 96.9% (loss 0.473) | lr 4.12e-04
** Saved new best RetNet model ** @ 96.9%


Epoch 51:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 51 ➜ train 95.5% | val 96.6% (loss 0.481) | lr 4.08e-04


Epoch 52:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 52 ➜ train 95.6% | val 96.8% (loss 0.473) | lr 4.03e-04


Epoch 53:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 53 ➜ train 95.6% | val 97.0% (loss 0.470) | lr 3.99e-04
** Saved new best RetNet model ** @ 97.0%


Epoch 54:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 54 ➜ train 95.7% | val 96.9% (loss 0.472) | lr 3.94e-04


Epoch 55:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 55 ➜ train 95.8% | val 96.9% (loss 0.468) | lr 3.89e-04


Epoch 56:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 56 ➜ train 95.8% | val 96.6% (loss 0.480) | lr 3.85e-04


Epoch 57:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 57 ➜ train 95.8% | val 96.8% (loss 0.473) | lr 3.80e-04


Epoch 58:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 58 ➜ train 95.9% | val 96.7% (loss 0.476) | lr 3.75e-04


Epoch 59:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 59 ➜ train 95.9% | val 96.9% (loss 0.472) | lr 3.70e-04


Epoch 60:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 60 ➜ train 96.1% | val 96.8% (loss 0.477) | lr 3.65e-04


Epoch 61:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 61 ➜ train 96.1% | val 96.8% (loss 0.476) | lr 3.60e-04


Epoch 62:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 62 ➜ train 96.0% | val 96.8% (loss 0.471) | lr 3.55e-04


Epoch 63:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 63 ➜ train 96.2% | val 96.8% (loss 0.481) | lr 3.50e-04


Epoch 64:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 64 ➜ train 96.3% | val 96.7% (loss 0.480) | lr 3.44e-04


Epoch 65:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 65 ➜ train 96.2% | val 96.9% (loss 0.472) | lr 3.39e-04


Epoch 66:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 66 ➜ train 96.3% | val 96.7% (loss 0.476) | lr 3.34e-04


Epoch 67:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 67 ➜ train 96.4% | val 96.9% (loss 0.474) | lr 3.28e-04


Epoch 68:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 68 ➜ train 96.3% | val 96.7% (loss 0.477) | lr 3.23e-04


Epoch 69:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 69 ➜ train 96.5% | val 96.9% (loss 0.477) | lr 3.17e-04


Epoch 70:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 70 ➜ train 96.5% | val 97.0% (loss 0.472) | lr 3.12e-04
** Saved new best RetNet model ** @ 97.0%


Epoch 71:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 71 ➜ train 96.5% | val 96.8% (loss 0.478) | lr 3.06e-04


Epoch 72:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 72 ➜ train 96.6% | val 96.9% (loss 0.473) | lr 3.01e-04


Epoch 73:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 73 ➜ train 96.6% | val 96.5% (loss 0.486) | lr 2.95e-04


Epoch 74:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 74 ➜ train 96.5% | val 96.8% (loss 0.475) | lr 2.90e-04


Epoch 75:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 75 ➜ train 96.7% | val 97.0% (loss 0.476) | lr 2.84e-04


Epoch 76:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 76 ➜ train 96.7% | val 96.9% (loss 0.477) | lr 2.78e-04


Epoch 77:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 77 ➜ train 96.8% | val 96.8% (loss 0.477) | lr 2.73e-04


Epoch 78:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 78 ➜ train 96.8% | val 97.1% (loss 0.472) | lr 2.67e-04
** Saved new best RetNet model ** @ 97.1%


Epoch 79:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 79 ➜ train 96.8% | val 97.0% (loss 0.476) | lr 2.61e-04


Epoch 80:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 80 ➜ train 96.9% | val 96.9% (loss 0.474) | lr 2.56e-04


Epoch 81:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 81 ➜ train 97.0% | val 97.0% (loss 0.470) | lr 2.50e-04


Epoch 82:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 82 ➜ train 96.9% | val 97.2% (loss 0.467) | lr 2.44e-04
** Saved new best RetNet model ** @ 97.2%


Epoch 83:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 83 ➜ train 97.1% | val 97.0% (loss 0.476) | lr 2.39e-04


Epoch 84:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 84 ➜ train 97.1% | val 97.0% (loss 0.476) | lr 2.33e-04


Epoch 85:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 85 ➜ train 97.2% | val 97.0% (loss 0.473) | lr 2.27e-04


Epoch 86:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 86 ➜ train 97.1% | val 97.1% (loss 0.472) | lr 2.22e-04


Epoch 87:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 87 ➜ train 97.2% | val 96.8% (loss 0.484) | lr 2.16e-04


Epoch 88:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 88 ➜ train 97.2% | val 96.9% (loss 0.478) | lr 2.10e-04


Epoch 89:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 89 ➜ train 97.3% | val 97.0% (loss 0.473) | lr 2.05e-04


Epoch 90:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 90 ➜ train 97.3% | val 97.2% (loss 0.470) | lr 1.99e-04


Epoch 91:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 91 ➜ train 97.3% | val 97.0% (loss 0.476) | lr 1.94e-04


Epoch 92:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 92 ➜ train 97.4% | val 97.0% (loss 0.474) | lr 1.88e-04


Epoch 93:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 93 ➜ train 97.4% | val 97.0% (loss 0.477) | lr 1.83e-04


Epoch 94:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 94 ➜ train 97.4% | val 97.1% (loss 0.473) | lr 1.77e-04


Epoch 95:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 95 ➜ train 97.5% | val 97.0% (loss 0.478) | lr 1.72e-04


Epoch 96:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 96 ➜ train 97.4% | val 97.0% (loss 0.476) | lr 1.66e-04


Epoch 97:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 97 ➜ train 97.5% | val 97.0% (loss 0.478) | lr 1.61e-04


Epoch 98:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 98 ➜ train 97.5% | val 97.0% (loss 0.480) | lr 1.56e-04


Epoch 99:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 99 ➜ train 97.6% | val 97.0% (loss 0.478) | lr 1.50e-04


Epoch 100:   0%|          | 0/1326 [00:00<?, ?it/s]

Epoch 100 ➜ train 97.6% | val 97.0% (loss 0.479) | lr 1.45e-04


In [None]:
# --- Persist the final-epoch ("LAST") weights: local disk first, then Drive ---
last_weights = model.state_dict()
torch.save(last_weights, "/content/last_kws_retnet.pt")

CKPT_DIR = "/content/drive/MyDrive/kws_models"
os.makedirs(CKPT_DIR, exist_ok=True)
torch.save(last_weights, os.path.join(CKPT_DIR, "last_kws_retnet.pt"))
print("Saved LAST model to Drive")

# --- Mirror the BEST checkpoint (selected by val_acc during training) to Drive ---
best_local = "/content/best_kws_retnet.pt"
if not os.path.exists(best_local):
    print("WARNING: no best checkpoint was found to copy.")
else:
    import shutil
    dst = os.path.join(CKPT_DIR, "best_kws_retnet-small.pt")
    # copy2 duplicates the file written during training (metadata included).
    shutil.copy2(best_local, dst)
    print(f"Copied BEST model (val_acc={best_val_acc:.2f}%) to Drive: {dst}")

Saved LAST model to Drive
Copied BEST model (val_acc=97.24%) to Drive: /content/drive/MyDrive/kws_models/best_kws_retnet-small.pt
