In [None]:
from __future__ import annotations

import os
# Set before the numpy/torch imports below. setdefault() keeps any value the
# user already exported. NOTE(review): presumably works around the
# "duplicate OpenMP runtime" abort seen in some numpy+torch installs - confirm.
os.environ.setdefault('KMP_DUPLICATE_LIB_OK', 'TRUE')

import csv
import json
import math
import random
import subprocess
import sys
from collections import Counter
from pathlib import Path

import numpy as np
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Locate the project root: the working directory itself, or its parent when
# the kernel was started from inside a folder named "notebook" (any case).
PROJECT_ROOT = Path.cwd().resolve()
if Path.cwd().name.lower() == 'notebook':
    PROJECT_ROOT = PROJECT_ROOT.parent

# Make the project importable exactly once, ahead of every other entry.
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(PROJECT_ROOT))

from Main.config import ensure_output_dirs, get_default_config
from Main.dataset import build_dataloaders
from Main.evaluate import evaluate_model, save_confusion_matrix_plot
from Main.interpret import run_band_importance
from Main.model import build_model
from Main.train import build_class_weights, train_one_epoch

# Echo the resolved root so a wrong working directory is obvious immediately.
print(f'Project root: {PROJECT_ROOT}')


In [None]:
def set_seed(seed: int = 42):
    """Seed every random number generator this notebook draws from.

    Covers Python's ``random`` module, NumPy's global RNG, and PyTorch
    (the default generator plus all CUDA devices).

    Args:
        seed: Integer seed handed unchanged to each library.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Fix all RNGs before any synthetic data is drawn further down in this cell.
set_seed(42)

# Everything lives under <project root>/Data; only the processed folder is
# created here (raw recordings are optional and are merely looked up later).
raw_dir = PROJECT_ROOT.joinpath('Data', 'raw')
processed_dir = PROJECT_ROOT.joinpath('Data', 'processed')
metadata_path = PROJECT_ROOT.joinpath('Data', 'metadata.csv')
processed_dir.mkdir(parents=True, exist_ok=True)

def _make_epoch_signal(label: int, samples: int = 3000, fs: int = 100) -> np.ndarray:
    t = np.arange(samples) / fs
    # 
    if label == 0:      # N1 theta
        sig = 0.8 * np.sin(2 * np.pi * 6 * t)
    elif label == 1:    # N2 sigma + theta
        sig = 0.7 * np.sin(2 * np.pi * 13 * t) + 0.4 * np.sin(2 * np.pi * 7 * t)
    elif label == 2:    # N3 delta
        sig = 1.2 * np.sin(2 * np.pi * 2 * t)
    else:               # REM alpha + beta
        sig = 0.6 * np.sin(2 * np.pi * 10 * t) + 0.3 * np.sin(2 * np.pi * 20 * t)
    noise = 0.25 * np.random.randn(samples)
    return (sig + noise).astype(np.float32)

def generate_synthetic_dataset():
    """Write a small synthetic dataset to disk and index it in the metadata CSV.

    Twelve subjects (S01..S12) get two records each, 80 epochs per record.
    The first eight subjects go to 'train', the next two to 'val' and the
    last two to 'test', so a subject never appears in more than one split.
    For every record an ``*_eeg.npy`` (float32, per-epoch z-scored, with a
    singleton channel axis) and an ``*_label.npy`` (int64) file are saved in
    ``processed_dir``; ``metadata_path`` is overwritten with one row per record.

    Returns:
        ``(rows, class_counter)``: the list of metadata row dicts and a
        ``Counter`` of how often each label was drawn across all records.
    """
    subject_ids = [f'S{number:02d}' for number in range(1, 13)]
    split_of = {
        sid: 'train' if position < 8 else ('val' if position < 10 else 'test')
        for position, sid in enumerate(subject_ids)
    }

    column_names = ['subject_id', 'record_id', 'eeg_path', 'label_path', 'split']
    epochs_per_record = 80
    stage_probs = np.array([0.2, 0.35, 0.25, 0.2], dtype=np.float64)

    manifest = []
    label_tally = Counter()

    for sid in subject_ids:
        for record_number in (1, 2):
            rec_id = f'{sid}_R{record_number}'

            # Labels are drawn first, then one signal per label, so the
            # global NumPy RNG is consumed in a fixed order.
            stage_labels = np.random.choice(
                np.arange(4), size=epochs_per_record, p=stage_probs
            ).astype(np.int64)
            signals = np.stack(
                [_make_epoch_signal(int(stage)) for stage in stage_labels], axis=0
            )
            signals = signals[:, None, :]  # insert the channel axis

            # Per-epoch standardisation; the epsilon guards a zero std.
            centre = signals.mean(axis=-1, keepdims=True)
            spread = signals.std(axis=-1, keepdims=True)
            signals = (signals - centre) / (spread + 1e-6)

            eeg_file = processed_dir / f'{rec_id}_eeg.npy'
            label_file = processed_dir / f'{rec_id}_label.npy'
            np.save(eeg_file, signals.astype(np.float32))
            np.save(label_file, stage_labels.astype(np.int64))

            label_tally.update(stage_labels.tolist())
            manifest.append({
                'subject_id': sid,
                'record_id': rec_id,
                'eeg_path': str(eeg_file),
                'label_path': str(label_file),
                'split': split_of[sid],
            })

    with metadata_path.open('w', encoding='utf-8', newline='') as handle:
        table = csv.DictWriter(handle, fieldnames=column_names)
        table.writeheader()
        table.writerows(manifest)

    return manifest, label_tally

# Look for Sleep-EDF recordings (PSG + hypnogram) anywhere under raw_dir.
if raw_dir.exists():
    psg_files = list(raw_dir.rglob('*PSG.edf'))
    hyp_files = list(raw_dir.rglob('*Hypnogram.edf'))
else:
    psg_files, hyp_files = [], []

if not (psg_files and hyp_files):
    # Nothing usable on disk: fall back to generated data so the remaining
    # cells can still be exercised end to end.
    rows, class_counter = generate_synthetic_dataset()
    print('No raw EDF found. Generated synthetic dataset for full-pipeline smoke test.')
    print('Records:', len(rows), 'Class distribution:', dict(class_counter))
else:
    print(f'Found raw Sleep-EDF files in {raw_dir}, running preprocess script...')
    preprocess_script = PROJECT_ROOT / 'Main' / 'preprocess.py'
    cmd = [sys.executable, str(preprocess_script)]
    cmd += ['--raw-dir', str(raw_dir)]
    cmd += ['--processed-dir', str(processed_dir)]
    cmd += ['--metadata-path', str(metadata_path)]
    cmd += ['--channel', 'Fpz-Cz']
    # check=True: a failing preprocess run raises CalledProcessError here.
    subprocess.run(cmd, check=True)
    print('Preprocess finished:', metadata_path)


In [None]:
cfg = get_default_config()
data_cfg, model_cfg, train_cfg = cfg.data, cfg.model, cfg.train

# Data: point the loaders at the files prepared by the previous cell.
data_cfg.metadata_path = metadata_path
data_cfg.processed_root = processed_dir
data_cfg.context_window = 5

# Model: sequence length follows the context window; one input channel.
model_cfg.max_seq_len = data_cfg.context_window
model_cfg.in_channels = 1

# Training: a short 4-epoch run. Patience equals the epoch count here.
train_cfg.epochs = 4
train_cfg.batch_size = 64
train_cfg.num_workers = 0
train_cfg.lr = 1e-3
train_cfg.early_stop_patience = 4

ensure_output_dirs(cfg)
(
    train_ds, val_ds, test_ds,
    train_loader, val_loader, test_loader,
) = build_dataloaders(cfg)

print(f'Train/Val/Test samples: {len(train_ds)} {len(val_ds)} {len(test_ds)}')
print(f'Train class counts: {dict(train_ds.class_counts)}')


In [None]:
# Train on GPU when one is visible, otherwise on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = build_model(cfg).to(device)

# Optional per-class loss weights derived from the training-set class counts.
class_weights = None
if cfg.train.use_class_weights:
    class_weights = build_class_weights(train_ds.class_counts, cfg.model.num_classes, device)

criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
optimizer = AdamW(model.parameters(), lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
# T_max must be >= 1 even if epochs is configured as 0.
scheduler = CosineAnnealingLR(optimizer, T_max=max(cfg.train.epochs, 1))

# Early stopping: stop once validation macro-F1 has not improved for
# `patience` consecutive epochs. A missing, None or non-positive value disables it.
patience = getattr(cfg.train, 'early_stop_patience', None)
epochs_without_improvement = 0

best_val_f1 = -1.0
history = []

for epoch in range(1, cfg.train.epochs + 1):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_metrics = evaluate_model(model, val_loader, device, cfg.model.num_classes, criterion=criterion)
    scheduler.step()

    history.append({
        'epoch': epoch,
        'train_loss': train_loss,
        'val_loss': val_metrics['loss'],
        'val_accuracy': val_metrics['accuracy'],
        'val_macro_f1': val_metrics['macro_f1'],
    })

    # Keep only the checkpoint with the best validation macro-F1 so far.
    val_f1 = float(val_metrics['macro_f1'])
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        epochs_without_improvement = 0
        torch.save({'model_state_dict': model.state_dict()}, cfg.result.best_ckpt_path)
    else:
        # Also reached when val_f1 is NaN (every comparison with NaN is False).
        epochs_without_improvement += 1

    print(f"[Epoch {epoch}] train_loss={train_loss:.4f} val_macro_f1={val_metrics['macro_f1']:.4f}")

    if patience and patience > 0 and epochs_without_improvement >= patience:
        print(f'Early stopping at epoch {epoch}: no val macro-F1 improvement for {patience} epoch(s).')
        break

if best_val_f1 < 0:
    # No epoch improved on the initial -1.0, so nothing was saved in this run;
    # whatever file already sits at best_ckpt_path is NOT from this model.
    print('WARNING: no checkpoint was written during this run:', cfg.result.best_ckpt_path)

print('Best val macro-F1:', best_val_f1)


In [None]:
# Restore the best checkpoint written by the training cell. It contains only a
# state dict, so weights_only=True restricts unpickling to tensors and plain
# containers instead of arbitrary pickled objects.
ckpt = torch.load(cfg.result.best_ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(ckpt['model_state_dict'])

test_metrics = evaluate_model(model, test_loader, device, cfg.model.num_classes, criterion=criterion)
# Class names ordered by their integer label so they line up with matrix rows.
class_names = [k for k, _ in sorted(cfg.data.label_map.items(), key=lambda kv: kv[1])]
save_confusion_matrix_plot(test_metrics['confusion_matrix'], class_names, cfg.result.confusion_matrix_path)


def _json_default(value):
    """Convert values the json encoder rejects into plain Python objects.

    Only called by ``json`` for objects it cannot serialise itself, so
    already-serialisable metrics are written exactly as before.

    Raises:
        TypeError: for any type other than NumPy scalars/arrays, torch
            tensors and ``Path`` objects (same error the encoder would raise).
    """
    if isinstance(value, (np.generic, np.ndarray, torch.Tensor)):
        return value.tolist()
    if isinstance(value, Path):
        return str(value)
    raise TypeError(f'Object of type {type(value).__name__} is not JSON serializable')


summary = {
    'history': history,
    'test': {
        'loss': test_metrics['loss'],
        'accuracy': test_metrics['accuracy'],
        'macro_f1': test_metrics['macro_f1'],
        'cohen_kappa': test_metrics['cohen_kappa'],
        'per_class': test_metrics['per_class'],
        'confusion_matrix': test_metrics['confusion_matrix'].tolist(),
    }
}

cfg.result.metrics_path.parent.mkdir(parents=True, exist_ok=True)
with cfg.result.metrics_path.open('w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2, default=_json_default)

print(json.dumps(summary['test'], indent=2, default=_json_default))
print('Saved metrics to', cfg.result.metrics_path)
print('Saved confusion matrix to', cfg.result.confusion_matrix_path)


In [None]:
# Band-importance analysis of the current model on the test split; the
# helper receives the CSV destination and returns the rows printed below.
rows = run_band_importance(
    output_csv=cfg.result.band_importance_path,
    sampling_rate=cfg.data.sampling_rate,
    num_classes=cfg.model.num_classes,
    device=device,
    data_loader=test_loader,
    model=model,
)

print(f'Band importance saved to {cfg.result.band_importance_path}')
for row in rows:
    print(row)
