In [None]:
import argparse
import os
import math
import random
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import librosa
import soundfile as sf

import timm
import torchaudio

from sklearn.metrics import confusion_matrix, classification_report, balanced_accuracy_score

In [None]:
@dataclass
class Config:
    """Hyperparameters for audio preprocessing, the ViT model, and training.

    Instances are created with defaults and then selectively overridden by
    ``train_main`` from command-line style arguments.
    """

    # --- audio loading / log-mel spectrogram ---
    sample_rate: int = 16000  # target rate audio is resampled to on load; also passed as `sr` to the mel transform
    epoch_sec: int = 30  # every clip is cropped / padded to this many seconds
    n_mels: int = 128  # number of mel bands
    fmin: int = 20  # lower edge passed to the mel filterbank
    fmax: int = 7600  # upper edge passed to the mel filterbank
    hop_length: int = 160  # STFT hop, in samples
    win_length: int = 400  # STFT window length, in samples
    n_fft: int = 1024  # FFT size
    top_db: Optional[int] = 80  # dynamic-range limit applied to the dB spectrogram; None skips the explicit clip
    spec_size: int = 224  # spectrograms are resized to [spec_size, spec_size] before entering the model

    # --- model ---
    num_classes: int = 4  # size of the classification head and of the confusion matrix
    vit_name: str = "vit_base_patch16_224"  # model name handed to timm.create_model
    vit_drop_rate: float = 0.1  # timm `drop_rate`
    vit_drop_path_rate: float = 0.1  # timm `drop_path_rate`

    # --- optimisation ---
    train_batch_size: int = 32
    val_batch_size: int = 48
    lr: float = 2e-4  # AdamW base learning rate (scaled by the LR schedule)
    weight_decay: float = 1e-4  # AdamW weight decay
    warmup_epochs: int = 2  # epochs of linear LR warmup before cosine decay
    epochs: int = 25  # total training epochs

    # --- regularisation / augmentation ---
    mixup_alpha: float = 0.0  # Beta(alpha, alpha) parameter for mixup; mixup is used only when > 0
    spec_aug_freq_masks: int = 2  # number of frequency masks per sample
    spec_aug_time_masks: int = 2  # number of time masks per sample
    spec_aug_freq_width: int = 16  # upper bound (inclusive) on a frequency mask's random width
    spec_aug_time_width: int = 32  # upper bound (inclusive) on a time mask's random width
    label_smoothing: float = 0.05  # smoothing factor used by the training loss

    # --- misc ---
    num_workers: int = 4  # DataLoader worker processes
    seed: int = 42  # passed to set_seed
    class_names: Optional[List[str]] = None  # filled in by train_main; used as report target names


# Human-readable class names for the 4-class and 5-class label schemes.
# train_main assigns one of these to Config.class_names based on num_classes,
# and the list is used as `target_names` in the classification report.
DEFAULT_CLASS_NAMES_4 = ["Wake", "Light", "Deep", "REM"]
DEFAULT_CLASS_NAMES_5 = ["Wake", "N1", "N2", "N3", "REM"]


def set_seed(seed: int) -> None:
    """Seed all random number generators used by the pipeline.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (CPU and all
    CUDA devices), then switches cuDNN to deterministic mode with
    auto-tuning (benchmark) disabled.

    Args:
        seed: Value handed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
def load_audio(path: str, target_sr: int) -> np.ndarray:
    """Load an audio file as a mono float32 waveform at ``target_sr``.

    Decoding is attempted with torchaudio first; if that raises for any
    reason, the file is read with soundfile instead. Multi-channel audio is
    downmixed by averaging the channels. If the file's sample rate differs
    from ``target_sr`` the waveform is resampled with librosa
    (``kaiser_fast``).

    Args:
        path: Path to the audio file.
        target_sr: Desired sample rate of the returned waveform.

    Returns:
        1-D ``np.float32`` array of samples at ``target_sr``.

    Raises:
        Whatever soundfile raises if the fallback decoder also fails.
    """
    # The previous version gated this on a TORCHAUDIO_AVAILABLE flag that is
    # never defined in this notebook (NameError on a fresh kernel). torchaudio
    # is imported unconditionally at the top, so simply try it first.
    try:
        wav, sr = torchaudio.load(path)
        # Average over dim 0 (torchaudio's channel axis) to get mono.
        wav = wav.mean(0, keepdim=False).numpy()
    except Exception:
        # Deliberate best-effort fallback: torchaudio backends fail in many
        # different ways (missing codec, unsupported container, ...), so any
        # failure is retried with soundfile rather than aborting the sample.
        wav, sr = sf.read(path, always_2d=False)
        if wav.ndim == 2:
            # soundfile returns [frames, channels]; average channels to mono.
            wav = wav.mean(axis=1)

    if sr != target_sr:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr, res_type="kaiser_fast")
        sr = target_sr

    return wav.astype(np.float32)

In [None]:
def ensure_length_seconds(wav: np.ndarray, sr: int, sec: int) -> np.ndarray:
    """Crop or pad a 1-D waveform to exactly ``sec * sr`` samples.

    Args:
        wav: 1-D waveform.
        sr: Sample rate in samples per second.
        sec: Desired duration in seconds.

    Returns:
        ``wav`` itself if it already has the target length, its leading
        ``sec * sr`` samples if it is longer, or a copy padded symmetrically
        (extra sample on the right when the padding is odd) if it is shorter.
        An empty waveform yields all zeros of the target length, since there
        is nothing to reflect.
    """
    target_len = sec * sr
    n_samples = len(wav)

    if n_samples == target_len:
        return wav
    if n_samples > target_len:
        return wav[:target_len]

    if n_samples == 0:
        # np.pad cannot reflect an empty axis (it raises ValueError), so an
        # empty clip becomes silence instead of crashing the data loader.
        return np.zeros(target_len, dtype=wav.dtype)

    # Pad by reflection to avoid hard edges at the clip boundaries.
    pad_len = target_len - n_samples
    left = pad_len // 2
    right = pad_len - left
    return np.pad(wav, (left, right), mode="reflect")

In [None]:
def wav_to_logmel(
    wav: np.ndarray,
    sr: int,
    n_mels: int,
    n_fft: int,
    hop_length: int,
    win_length: int,
    fmin: int,
    fmax: int,
    top_db: Optional[int]
) -> np.ndarray:
    """Convert a waveform to a min-max normalised log-mel spectrogram.

    Args:
        wav: 1-D waveform.
        sr: Sample rate of ``wav``.
        n_mels: Number of mel bands.
        n_fft: FFT size.
        hop_length: STFT hop in samples.
        win_length: STFT window length in samples.
        fmin: Lowest mel filterbank frequency.
        fmax: Highest mel filterbank frequency.
        top_db: Dynamic range kept below the spectrogram peak, in dB.
            ``None`` disables the range limit entirely.

    Returns:
        ``np.float32`` array of shape ``[n_mels, T]`` scaled to 0..1 using
        this spectrogram's own minimum and maximum.
    """
    S = librosa.feature.melspectrogram(
        y=wav,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
        power=2.0,
        center=True,
        window="hann"
    )
    # Forward top_db to librosa instead of clipping afterwards: power_to_db
    # applies its own default top_db=80, so without this a top_db of None (or
    # anything above 80) was silently ignored. For top_db <= 80 the result is
    # the same as the former post-hoc clip at (max - top_db).
    S_db = librosa.power_to_db(S, ref=np.max, top_db=top_db)

    # Normalise to 0..1; the epsilon guards against a constant spectrogram
    # (e.g. digital silence) where max == min.
    S_min = S_db.min()
    S_max = S_db.max()
    S_norm = (S_db - S_min) / max(S_max - S_min, 1e-6)
    return S_norm.astype(np.float32)


def resize_spec(spec: np.ndarray, size: int) -> np.ndarray:
    """Bilinearly resize a 2-D spectrogram to a square ``[size, size]`` array.

    Args:
        spec: 2-D array, ``[n_mels, T]``.
        size: Edge length of the square output.

    Returns:
        ``np.float32`` array of shape ``[size, size]``.
    """
    with torch.no_grad():
        # F.interpolate expects [N, C, H, W], so add two leading unit axes.
        batched = torch.from_numpy(spec)[None, None, :, :]
        resized = F.interpolate(
            batched,
            size=(size, size),
            mode="bilinear",
            align_corners=False,
        )
        image = resized.squeeze(0).squeeze(0)
    return image.numpy().astype(np.float32)


def spec_augment(spec: torch.Tensor, freq_masks: int, time_masks: int, freq_width: int, time_width: int) -> torch.Tensor:
    """Zero out random frequency bands and time spans, independently per sample.

    For every sample, ``freq_masks`` horizontal bands and then ``time_masks``
    vertical bands are set to 0.0 across all channels. Each band's width is
    drawn uniformly from ``0..freq_width`` (or ``0..time_width``) inclusive,
    and its start position uniformly from the range that keeps it inside the
    image. Randomness comes from Python's ``random`` module.

    Args:
        spec: Batch of shape ``[B, C, H, W]``.
        freq_masks: Number of frequency masks per sample.
        time_masks: Number of time masks per sample.
        freq_width: Maximum frequency-mask height (inclusive).
        time_width: Maximum time-mask width (inclusive).

    Returns:
        A masked copy of ``spec``; the input tensor is not modified.
    """
    batch_size, _, height, width = spec.shape
    masked = spec.clone()

    for sample in range(batch_size):
        # Frequency masks: blank rows [row0, row0 + span).
        for _ in range(freq_masks):
            span = random.randint(0, freq_width)
            row0 = random.randint(0, max(0, height - span))
            masked[sample, :, row0:row0 + span, :] = 0.0

        # Time masks: blank columns [col0, col0 + span).
        for _ in range(time_masks):
            span = random.randint(0, time_width)
            col0 = random.randint(0, max(0, width - span))
            masked[sample, :, :, col0:col0 + span] = 0.0

    return masked

In [None]:
class SleepEpochDataset(Dataset):
    """Dataset of audio clips listed in a CSV with ``path`` and ``label`` columns.

    Each item is loaded from disk, fixed to ``cfg.epoch_sec`` seconds,
    optionally augmented in the waveform domain, converted to a normalised
    log-mel spectrogram, resized to ``cfg.spec_size`` squared and replicated
    to three channels.
    """

    def __init__(self, csv_path: str, cfg: Config, augment: bool = False):
        """Read the CSV index.

        Args:
            csv_path: CSV file with at least ``path`` and ``label`` columns.
            cfg: Preprocessing configuration.
            augment: Whether to apply random waveform augmentation.

        Raises:
            ValueError: If either required column is missing.
        """
        table = pd.read_csv(csv_path)
        has_required = "path" in table.columns and "label" in table.columns
        if not has_required:
            raise ValueError("CSV must have columns path and label")

        self.paths = table["path"].astype(str).tolist()
        self.labels = table["label"].astype(int).tolist()
        self.cfg = cfg
        self.augment = augment

    def __len__(self) -> int:
        """Number of clips listed in the CSV."""
        return len(self.paths)

    @staticmethod
    def _augment_waveform(wav: np.ndarray) -> np.ndarray:
        """Apply three independent augmentations, each with probability 0.25.

        In order: additive Gaussian noise (std 0.005), a random gain of
        ``exp(U(-0.2, 0.2))``, and a circular shift of up to 10% of the clip
        length in either direction.
        """
        n_samples = len(wav)
        if random.random() < 0.25:
            wav = wav + np.random.randn(n_samples).astype(np.float32) * 0.005
        if random.random() < 0.25:
            wav = wav * np.exp(np.random.uniform(-0.2, 0.2))
        if random.random() < 0.25:
            offset = int(np.random.uniform(-0.1, 0.1) * len(wav))
            wav = np.roll(wav, offset)
        return wav

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return ``{"x": float image [3, H, W], "y": long scalar label}``."""
        cfg = self.cfg
        clip_path = self.paths[idx]
        label = self.labels[idx]

        wav = load_audio(clip_path, target_sr=cfg.sample_rate)
        wav = ensure_length_seconds(wav, cfg.sample_rate, cfg.epoch_sec)

        # Simple waveform augmentation, training only.
        if self.augment:
            wav = self._augment_waveform(wav)

        spec = wav_to_logmel(
            wav=wav,
            sr=cfg.sample_rate,
            n_mels=cfg.n_mels,
            n_fft=cfg.n_fft,
            hop_length=cfg.hop_length,
            win_length=cfg.win_length,
            fmin=cfg.fmin,
            fmax=cfg.fmax,
            top_db=cfg.top_db,
        )
        spec = resize_spec(spec, cfg.spec_size)  # [H, W]

        # Replicate the single-channel spectrogram to the 3 channels the
        # backbone is created with (in_chans=3).
        image = np.stack((spec, spec, spec), axis=0)  # [3, H, W]
        return {
            "x": torch.from_numpy(image).float(),
            "y": torch.tensor(label).long(),
        }

In [None]:
class SleepViTClassifier(nn.Module):
    """timm backbone with its classifier removed, plus a small custom head.

    The head is ``LayerNorm -> Dropout(0.1) -> Linear`` applied to features
    returned by the backbone's ``forward_features``.
    """

    def __init__(self, cfg: Config):
        """Build the pretrained backbone and the classification head.

        Args:
            cfg: Supplies the timm model name, drop rates and class count.
        """
        super().__init__()
        self.cfg = cfg

        backbone = timm.create_model(
            cfg.vit_name,
            pretrained=True,
            in_chans=3,
            num_classes=cfg.num_classes,
            drop_rate=cfg.vit_drop_rate,
            drop_path_rate=cfg.vit_drop_path_rate,
        )
        # Read the feature width off timm's own classifier before discarding it.
        feature_dim = backbone.get_classifier().in_features
        backbone.reset_classifier(num_classes=0, global_pool="avg")
        self.backbone = backbone

        self.head = nn.Sequential(
            nn.LayerNorm(feature_dim),
            nn.Dropout(0.1),
            nn.Linear(feature_dim, cfg.num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits.

        Args:
            x: Image batch, ``[B, 3, 224, 224]`` with the default config.

        Returns:
            Logits of shape ``[B, num_classes]``.
        """
        features = self.backbone.forward_features(x)
        # A 3-D output is averaged over dim 1 (presumably the token axis of a
        # ViT, class token included -- confirm against the timm model used).
        pooled = features.mean(dim=1) if features.ndim == 3 else features
        return self.head(pooled)


class LabelSmoothingCE(nn.Module):
    """Cross-entropy against a smoothed one-hot target distribution.

    The true class receives probability ``1 - smoothing`` and the remaining
    mass is split evenly over the other ``n_classes - 1`` classes.
    """

    def __init__(self, smoothing: float = 0.0):
        """Store the smoothing factor, which must lie in ``[0, 1)``."""
        super().__init__()
        assert 0.0 <= smoothing < 1.0
        self.smoothing = smoothing

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean smoothed cross-entropy over the batch.

        Args:
            logits: ``[B, n_classes]`` unnormalised scores.
            target: ``[B]`` integer class indices.
        """
        num_classes = logits.size(-1)
        log_probs = F.log_softmax(logits, dim=-1)

        # The target distribution is a constant; keep it out of the graph.
        with torch.no_grad():
            off_value = self.smoothing / (num_classes - 1)
            soft_targets = torch.full_like(log_probs, off_value)
            soft_targets.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)

        per_sample = (-soft_targets * log_probs).sum(dim=-1)
        return per_sample.mean()

def do_mixup(x: torch.Tensor, y: torch.Tensor, alpha: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """Mix each sample with a randomly chosen partner from the same batch.

    Args:
        x: Input batch; the first dimension is the batch dimension.
        y: Labels aligned with ``x``.
        alpha: Beta(alpha, alpha) concentration. Mixup is disabled when
            ``alpha <= 0``.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a = y`` and
        ``y_b = y[perm]``. When mixup is disabled the batch is returned
        unchanged as ``(x, y, y, 1.0)``, so feeding the result to
        ``mixup_criterion`` yields the plain loss.
    """
    if alpha <= 0.0:
        # Previously returned a 3-tuple here but a 4-tuple below, so callers
        # unpacking four values crashed on the disabled path. Keep the arity
        # constant.
        return x, y, y, 1.0

    # Plain Python float keeps the tensor arithmetic independent of NumPy
    # scalar types.
    lam = float(np.random.beta(alpha, alpha))
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)
    mixed_x = lam * x + (1.0 - lam) * x[index, :]
    y_a = y
    y_b = y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against both mixup label sets.

    Returns ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1.0 - lam) * loss_b


In [None]:
def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    device: torch.device,
    criterion,
    cfg: Config
) -> Tuple[float, float]:
    """Run one pass over ``loader``, updating the model after every batch.

    Per batch: optional spectrogram masking, optional mixup, forward pass,
    backward pass with gradient-norm clipping at 2.0, optimizer step, and
    (if a scheduler is given) one scheduler step.

    Args:
        model: Network to train; put into train mode here.
        loader: Yields dicts with ``"x"`` inputs and ``"y"`` labels.
        optimizer: Optimizer over the model parameters.
        scheduler: Per-batch LR scheduler, or ``None``.
        device: Device batches are moved to.
        criterion: Callable ``(logits, targets) -> scalar loss``.
        cfg: Supplies the augmentation and mixup settings.

    Returns:
        ``(mean loss per sample, accuracy)`` over the epoch. With mixup the
        accuracy is approximate: predictions are compared to the un-permuted
        labels.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    use_spec_aug = cfg.spec_aug_freq_masks > 0 or cfg.spec_aug_time_masks > 0

    for batch in loader:
        x = batch["x"].to(device, non_blocking=True)
        y = batch["y"].to(device, non_blocking=True)

        if use_spec_aug:
            x = spec_augment(
                x,
                cfg.spec_aug_freq_masks,
                cfg.spec_aug_time_masks,
                cfg.spec_aug_freq_width,
                cfg.spec_aug_time_width,
            )

        if cfg.mixup_alpha > 0.0:
            x, y_a, y_b, lam = do_mixup(x, y, cfg.mixup_alpha)
            logits = model(x)
            loss = mixup_criterion(criterion, logits, y_a, y_b, lam)
            reference = y_a  # hard labels of the un-permuted samples
        else:
            logits = model(x)
            loss = criterion(logits, y)
            reference = y
        hits = (torch.argmax(logits, dim=1) == reference).sum().item()

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        n_in_batch = x.size(0)
        loss_sum += loss.item() * n_in_batch
        n_correct += hits
        n_seen += n_in_batch

    denom = max(n_seen, 1)  # avoid dividing by zero on an empty loader
    return loss_sum / denom, n_correct / denom


@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    cfg: Config
) -> Tuple[float, float, np.ndarray, List[int], List[int]]:
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Loss is plain (unweighted, unsmoothed) cross-entropy.

    Args:
        model: Network to evaluate; put into eval mode here.
        loader: Yields dicts with ``"x"`` inputs and ``"y"`` labels.
        device: Device batches are moved to.
        cfg: Supplies ``num_classes`` for the confusion-matrix label set.

    Returns:
        ``(mean loss per sample, accuracy, confusion matrix, predictions,
        targets)``. The confusion matrix has true classes on rows and
        predicted classes on columns, over labels ``0..num_classes-1``.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    loss_sum = 0.0
    n_seen = 0
    predictions = []
    references = []

    for batch in loader:
        inputs = batch["x"].to(device, non_blocking=True)
        labels = batch["y"].to(device, non_blocking=True)

        logits = model(inputs)
        batch_loss = loss_fn(logits, labels)
        batch_preds = torch.argmax(logits, dim=1)

        n_in_batch = inputs.size(0)
        loss_sum += batch_loss.item() * n_in_batch
        n_seen += n_in_batch
        predictions.extend(batch_preds.cpu().tolist())
        references.extend(labels.cpu().tolist())

    mean_loss = loss_sum / max(n_seen, 1)
    accuracy = np.mean(np.array(predictions) == np.array(references))
    label_ids = list(range(cfg.num_classes))
    matrix = confusion_matrix(references, predictions, labels=label_ids)
    return mean_loss, accuracy, matrix, predictions, references


def compute_class_weights(labels: List[int], num_classes: int) -> torch.Tensor:
    """Inverse-frequency class weights, normalised to sum to ``num_classes``.

    Classes with no samples are treated as if they had one, which avoids a
    division by zero.

    Args:
        labels: Integer class index of every training sample.
        num_classes: Total number of classes.

    Returns:
        ``float32`` tensor of per-class weights.
    """
    per_class = np.bincount(labels, minlength=num_classes).astype(np.float32)
    floored = np.maximum(per_class, 1.0)
    inverse = 1.0 / floored
    scaled = inverse / inverse.sum() * num_classes
    return torch.tensor(scaled, dtype=torch.float32)


def make_optimizer(model: nn.Module, cfg: Config) -> torch.optim.Optimizer:
    """Create an AdamW optimizer over every parameter of ``model``.

    Learning rate and weight decay come from ``cfg.lr`` and
    ``cfg.weight_decay``.
    """
    params = model.parameters()
    optimizer = torch.optim.AdamW(
        params,
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
    )
    return optimizer


def make_scheduler(optimizer: torch.optim.Optimizer, cfg: Config, steps_per_epoch: int):
    """Build a per-step LR schedule: linear warmup, then cosine decay.

    The multiplier rises linearly to 1.0 over ``cfg.warmup_epochs`` epochs,
    then follows half a cosine over the remaining steps, floored at 0.1.

    Args:
        optimizer: Optimizer whose learning rate is scaled.
        cfg: Supplies ``epochs`` and ``warmup_epochs``.
        steps_per_epoch: Scheduler steps (batches) per epoch.

    Returns:
        A ``LambdaLR`` scheduler meant to be stepped once per batch.
    """
    steps = max(1, steps_per_epoch)
    total_steps = max(1, cfg.epochs) * steps
    warmup_steps = cfg.warmup_epochs * steps

    def lr_lambda(step):
        in_warmup = step < warmup_steps
        if in_warmup:
            return float(step + 1) / float(max(1, warmup_steps))
        decay_span = float(max(1, total_steps - warmup_steps))
        progress = float(step - warmup_steps) / decay_span
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return max(0.1, cosine)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)


def save_checkpoint(model: nn.Module, path: str) -> None:
    """Save ``model``'s weights to ``path`` as ``{"state_dict": ...}``.

    Any missing parent directories are created first.

    Args:
        model: Module whose ``state_dict`` is saved.
        path: Destination file. May be a bare file name, in which case the
            file is written to the current working directory.
    """
    parent = os.path.dirname(path)
    # dirname is "" for a bare file name, and os.makedirs("") raises
    # FileNotFoundError -- only create directories when there is one to create.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save({"state_dict": model.state_dict()}, path)


def load_checkpoint(model: nn.Module, path: str, device: torch.device) -> None:
    """Load weights saved under the ``"state_dict"`` key into ``model`` in place.

    Args:
        model: Module to receive the weights; keys must match exactly
            (``strict=True``).
        path: Checkpoint file to read.
        device: Device the stored tensors are mapped to while loading.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(path, map_location=device)
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights, strict=True)

In [None]:
def train_main(args):
    """Train the classifier, keep the best checkpoint, and report on it.

    Args:
        args: Namespace-like object providing ``epochs``, ``train_bs``,
            ``val_bs``, ``vit``, ``num_classes``, ``train_csv``, ``val_csv``,
            ``out_dir``, ``freeze_backbone_epochs`` and ``early_stop_acc``
            (``None`` disables early stopping).

    Side effects:
        Writes ``<out_dir>/best.pt`` and prints per-epoch metrics plus a
        final validation report for the best checkpoint.
    """
    cfg = Config()
    cfg.epochs = args.epochs
    cfg.train_batch_size = args.train_bs
    cfg.val_batch_size = args.val_bs
    cfg.vit_name = args.vit
    cfg.num_classes = args.num_classes
    # The name list must have exactly num_classes entries or the final
    # classification report fails; fall back to generic names for any class
    # count other than 4 or 5.
    if cfg.num_classes == 4:
        cfg.class_names = DEFAULT_CLASS_NAMES_4
    elif cfg.num_classes == 5:
        cfg.class_names = DEFAULT_CLASS_NAMES_5
    else:
        cfg.class_names = [f"Class {i}" for i in range(cfg.num_classes)]

    set_seed(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_ds = SleepEpochDataset(args.train_csv, cfg, augment=True)
    val_ds = SleepEpochDataset(args.val_csv, cfg, augment=False)

    # Inverse-frequency class weights for the non-mixup loss.
    cw = compute_class_weights(train_ds.labels, num_classes=cfg.num_classes).to(device)

    # Build the loss once instead of on every batch. With mixup the plain
    # label-smoothed loss is used; otherwise a blend of class-weighted CE and
    # label-smoothed CE, weighted by (1 - s) and s respectively.
    if cfg.mixup_alpha > 0.0:
        criterion = LabelSmoothingCE(smoothing=cfg.label_smoothing)
    else:
        weighted_ce = nn.CrossEntropyLoss(weight=cw)
        smoothed_ce = LabelSmoothingCE(cfg.label_smoothing)

        def criterion(logits, target):
            return weighted_ce(logits, target) * (1.0 - cfg.label_smoothing) + \
                   smoothed_ce(logits, target) * cfg.label_smoothing

    # Pinned host memory only helps host-to-GPU copies (and warns on CPU).
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_ds,
        batch_size=cfg.train_batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=cfg.val_batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )

    model = SleepViTClassifier(cfg).to(device)

    # Optionally train only the head for the first few epochs.
    freeze_epochs = args.freeze_backbone_epochs
    if freeze_epochs > 0:
        for p in model.backbone.parameters():
            p.requires_grad = False

    optimizer = make_optimizer(model, cfg)
    scheduler = make_scheduler(optimizer, cfg, steps_per_epoch=len(train_loader))

    # Start below any reachable accuracy so the first finite validation
    # result is always checkpointed (an initial 0.0 with a strict ">" could
    # leave best.pt unwritten and crash the final load).
    best_acc = float("-inf")
    saved_best = False
    best_path = os.path.join(args.out_dir, "best.pt")
    os.makedirs(args.out_dir, exist_ok=True)

    for epoch in range(cfg.epochs):
        if freeze_epochs > 0 and epoch == freeze_epochs:
            for p in model.backbone.parameters():
                p.requires_grad = True
            print("Backbone unfrozen")

        # Shared training loop (augmentation, mixup, clipping, LR stepping).
        train_loss, train_acc = train_one_epoch(
            model, train_loader, optimizer, scheduler, device, criterion, cfg
        )

        val_loss, val_acc, cm, preds, targets = evaluate(model, val_loader, device, cfg)

        print(f"Epoch {epoch + 1}/{cfg.epochs} | "
              f"train_loss {train_loss:.4f} acc {train_acc:.4f} | "
              f"val_loss {val_loss:.4f} acc {val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            save_checkpoint(model, best_path)
            saved_best = True
            print(f"Saved new best to {best_path} with acc {best_acc:.4f}")

        # Optional early stop.
        if args.early_stop_acc is not None and val_acc >= args.early_stop_acc:
            print("Early stopping because target val acc reached")
            break

    if not saved_best:
        # No epoch produced a comparable accuracy (zero epochs, or NaN from an
        # empty validation set); save the current weights so the final
        # evaluation has a checkpoint from this run to load.
        save_checkpoint(model, best_path)
        print(f"No best checkpoint recorded; saved current model to {best_path}")

    # Final evaluation on the best model.
    load_checkpoint(model, best_path, device)
    val_loss, val_acc, cm, preds, targets = evaluate(model, val_loader, device, cfg)
    print("Best model validation results:")
    print(f"val_loss {val_loss:.4f} val_acc {val_acc:.4f}")
    print("Confusion matrix rows true, cols pred:")
    print(cm)
    print("Balanced accuracy:", balanced_accuracy_score(targets, preds))
    # Pass the full label set so the report still lines up with class_names
    # when a class is absent from both targets and predictions.
    print(classification_report(
        targets,
        preds,
        labels=list(range(cfg.num_classes)),
        target_names=cfg.class_names
    ))