In [8]:
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch import nn, optim
import time

# ========== Paths ==========
# Input splits are read from data_dir; checkpoints are written under model_dir.
data_dir, model_dir = "data_processed", "models"
os.makedirs(model_dir, exist_ok=True)  # create the checkpoint folder up front

# ========== Model ==========
class EEGNetHybridNorm(nn.Module):
    """EEGNet variant with "hybrid" normalization.

    The convolutional blocks keep BatchNorm and additionally stack
    GroupNorm(num_groups=1), which normalizes each sample over all of its
    (channel, height, time) positions -- a LayerNorm-like normalization that
    does not depend on batch statistics. The author's intent is to cope with
    cross-subject distribution shift more stably than BatchNorm alone.

    Input is expected as (batch, 1, num_channels, sample_length) -- the shape
    of the dummy tensor used to size the classifier. ``forward`` returns
    (batch, num_classes) logits.
    """
    def __init__(self, num_classes=4, num_channels=22, sample_length=1000, dropout_rate=0.35):
        super().__init__()
        # Block 1: temporal conv -> BN -> grouped spatial conv -> BN -> GN -> act -> pool -> dropout
        self.temporal = nn.Conv2d(1, 16, (1, 64), padding=(0, 32), bias=False)
        self.bn1 = nn.BatchNorm2d(16)

        # groups=16 with 32 outputs: two spatial filters per temporal feature map.
        self.spatial = nn.Conv2d(16, 32, (num_channels, 1), groups=16, bias=False)
        self.bn2 = nn.BatchNorm2d(32)
        self.gn2 = nn.GroupNorm(1, 32)  # single group: per-sample norm over (C, H, W)
        self.act1 = nn.LeakyReLU(0.1)
        self.pool1 = nn.AvgPool2d((1, 4))
        self.drop1 = nn.Dropout(dropout_rate)

        # Block 2: depthwise temporal conv -> pointwise conv -> BN -> GN -> act -> pool -> dropout
        self.dw_time = nn.Conv2d(32, 32, (1, 16), groups=32, padding=(0, 8), bias=False)
        self.pw = nn.Conv2d(32, 32, (1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(32)
        self.gn3 = nn.GroupNorm(1, 32)
        self.act2 = nn.LeakyReLU(0.1)
        self.pool2 = nn.AvgPool2d((1, 8))
        self.drop2 = nn.Dropout(dropout_rate)

        # Size the classifier input with a dummy forward pass. Run it in eval
        # mode: torch.no_grad() alone does NOT stop BatchNorm (in training mode)
        # from folding this random tensor's statistics into running_mean /
        # running_var and bumping num_batches_tracked.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            dummy = torch.randn(1, 1, num_channels, sample_length)
            fc_in = self._forward_features(dummy).flatten(1).size(1)
        self.train(was_training)

        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(fc_in, 64),
            nn.LeakyReLU(0.1),
            nn.Dropout(dropout_rate),
            nn.Linear(64, num_classes)
        )

    def _forward_features(self, x):
        """Run both convolutional blocks; returns the unflattened feature map."""
        x = self.temporal(x)
        x = self.bn1(x)

        x = self.spatial(x)
        x = self.bn2(x)
        x = self.gn2(x)
        x = self.act1(x)
        x = self.pool1(x)
        x = self.drop1(x)

        x = self.dw_time(x)
        x = self.pw(x)
        x = self.bn3(x)
        x = self.gn3(x)
        x = self.act2(x)
        x = self.pool2(x)
        x = self.drop2(x)
        return x

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        x = self._forward_features(x)
        return self.classifier(x.flatten(1))

    def extract_features(self, x):
        """Return flattened conv features (batch, fc_in) without gradients.

        Side effect: switches the module to eval mode and leaves it there.
        """
        self.eval()
        with torch.no_grad():
            return self._forward_features(x).flatten(1)

def augment_batch(x, noise_std=0.03, p_shift=0.3, max_shift=10, p_scale=0.3, scale_lo=0.9, scale_hi=1.1):
    """Apply light, batch-wide random augmentation and return a new tensor.

    Steps, in order (all draws come from torch's global RNG):
      1. always add Gaussian noise scaled by ``noise_std``;
      2. with probability ``p_shift``, circularly roll the last axis by an
         integer offset round(U(-max_shift, max_shift)) (skipped if it is 0);
      3. with probability ``p_scale``, multiply by a factor U(scale_lo, scale_hi).
    One shift and one scale factor are shared by every sample in the batch.
    """
    augmented = x + torch.randn_like(x) * noise_std

    if torch.rand(1).item() < p_shift:
        offset = int(torch.empty(1).uniform_(-max_shift, max_shift).round().item())
        if offset:
            augmented = torch.roll(augmented, shifts=offset, dims=-1)

    if torch.rand(1).item() < p_scale:
        gain = float(torch.empty(1).uniform_(scale_lo, scale_hi).item())
        augmented = augmented * gain

    return augmented


# ========== Data loading ==========
def load_dataset(name, directory=None):
    """Load an ``.npz`` split and convert it to model-ready tensors.

    Args:
        name: file name of the ``.npz`` archive; it must contain arrays
            ``"X"`` and ``"y"``.
        directory: folder holding the archive. Defaults to the module-level
            ``data_dir`` when omitted (the original behavior).

    Returns:
        ``(X, y)``: ``X`` as float32 with a singleton axis inserted at dim 1,
        ``y`` as int64 with 1 subtracted from every label (the raw labels are
        presumably 1-based -- confirm against the preprocessing script).
    """
    folder = data_dir if directory is None else directory
    # NpzFile keeps the underlying file open; the context manager closes it.
    # Indexing reads each array fully into memory, so X and y stay valid.
    with np.load(os.path.join(folder, name)) as data:
        X, y = data["X"], data["y"]
    print(f"Loaded {name}: {X.shape}")
    features = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
    labels = torch.tensor(y - 1, dtype=torch.long)
    return features, labels

# Build the training and in-distribution validation splits.
X_train, y_train = load_dataset("train_5subj.npz")
X_val, y_val = load_dataset("val_id_5subj.npz")

# Only the training stream is shuffled; validation order stays fixed.
train_loader = DataLoader(TensorDataset(X_train, y_train), shuffle=True, batch_size=64)
val_loader = DataLoader(TensorDataset(X_val, y_val), shuffle=False, batch_size=64)

# ========== Training ==========
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one augmented, gradient-clipped training pass over ``loader``.

    Returns ``(mean per-sample loss, accuracy in percent)``. Assumes
    ``criterion`` averages over the batch (reduction="mean", the
    nn.CrossEntropyLoss default), so each batch loss is re-weighted by size.
    """
    model.train()
    loss_sum, correct, total = 0.0, 0, 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        data = augment_batch(data)  # light on-the-fly augmentation (training only)

        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        n = target.size(0)
        # Weight by batch size so a short final batch does not count as much
        # as a full one in the epoch average.
        loss_sum += loss.item() * n
        total += n
        correct += (out.argmax(1) == target).sum().item()
    total = max(total, 1)  # guard against an empty loader
    return loss_sum / total, 100. * correct / total


def _evaluate(model, loader, criterion, device):
    """Evaluate without gradients; returns ``(mean per-sample loss, accuracy %)``."""
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            out = model(data)
            n = target.size(0)
            loss_sum += criterion(out, target).item() * n
            total += n
            correct += (out.argmax(1) == target).sum().item()
    total = max(total, 1)  # guard against an empty loader
    return loss_sum / total, 100. * correct / total


def train_eegnet(model, train_loader, val_loader, epochs=300, lr=3e-4, patience=40):
    """Train ``model`` with AdamW, ReduceLROnPlateau and early stopping.

    Each epoch runs an augmented training pass, evaluates on ``val_loader``,
    steps the scheduler on the (per-sample averaged) validation loss and
    prints one log row. The state_dict is saved to
    ``<model_dir>/pooled_eegnet_best.pth`` when validation accuracy improves,
    with ties broken by lower validation loss; the first completed epoch
    always writes a checkpoint so the file exists for callers that reload it.
    Training stops after ``patience`` consecutive epochs without improvement.

    On return ``model`` holds the LAST epoch's weights, not the best ones;
    reload the checkpoint file to get the best weights.

    Returns:
        Best validation accuracy in percent (0.0 if ``epochs`` is 0).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=12, factor=0.7, min_lr=1e-5
    )

    best_val_acc, best_val_loss = 0.0, float("inf")
    has_checkpoint = False
    patience_cnt = 0
    os.makedirs(model_dir, exist_ok=True)
    best_path = os.path.join(model_dir, "pooled_eegnet_best.pth")

    print("\n Train HybridNorm EEGNet")
    print("Epoch | TrainLoss | ValLoss | TrainAcc | ValAcc | LR")

    for epoch in range(epochs):
        tr_loss, tr_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        va_loss, va_acc = _evaluate(model, val_loader, criterion, device)

        scheduler.step(va_loss)
        lr_now = optimizer.param_groups[0]['lr']
        print(f"{epoch+1:03d} | {tr_loss:.4f} | {va_loss:.4f} | {tr_acc:.2f}% | {va_acc:.2f}% | {lr_now:.1e}")

        # Save the best model: higher accuracy wins; on an accuracy tie
        # (common with small validation sets) lower loss wins.
        improved = (
            not has_checkpoint
            or va_acc > best_val_acc
            or (va_acc == best_val_acc and va_loss < best_val_loss)
        )
        if improved:
            best_val_acc, best_val_loss = va_acc, va_loss
            torch.save(model.state_dict(), best_path)
            has_checkpoint = True
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= patience:
                print(f" Early stopping at epoch {epoch+1}")
                break

    print(f"\n Done. Best Val Acc = {best_val_acc:.2f}%  (saved: {best_path})")
    return best_val_acc

# ========== Main ==========
if __name__ == "__main__":
    # Train a fresh model; train_eegnet writes its best checkpoint to disk.
    model = EEGNetHybridNorm(dropout_rate=0.35)
    best_val = train_eegnet(
        model, train_loader, val_loader,
        epochs=300, lr=3e-4, patience=40,
    )

    # The in-memory model holds the last epoch's weights, so reload the best
    # checkpoint before any downstream feature / probability extraction.
    best_ckpt_path = os.path.join(model_dir, "pooled_eegnet_best.pth")
    model.load_state_dict(torch.load(best_ckpt_path, map_location="cpu"))



Loaded train_5subj.npz: (1005, 22, 1000)
Loaded val_id_5subj.npz: (215, 22, 1000)

 Train HybridNorm EEGNet
Epoch | TrainLoss | ValLoss | TrainAcc | ValAcc | LR
001 | 1.4040 | 1.3886 | 24.38% | 23.72% | 3.0e-04
002 | 1.3896 | 1.3849 | 25.27% | 21.40% | 3.0e-04
003 | 1.3868 | 1.3814 | 26.87% | 27.91% | 3.0e-04
004 | 1.3884 | 1.3806 | 27.06% | 26.05% | 3.0e-04
005 | 1.3799 | 1.3770 | 27.96% | 33.02% | 3.0e-04
006 | 1.3787 | 1.3738 | 27.76% | 33.49% | 3.0e-04
007 | 1.3810 | 1.3705 | 28.76% | 28.37% | 3.0e-04
008 | 1.3721 | 1.3641 | 28.96% | 32.09% | 3.0e-04
009 | 1.3663 | 1.3568 | 30.55% | 34.42% | 3.0e-04
010 | 1.3536 | 1.3426 | 31.74% | 32.56% | 3.0e-04
011 | 1.3376 | 1.3284 | 36.02% | 36.28% | 3.0e-04
012 | 1.3226 | 1.3175 | 37.31% | 38.14% | 3.0e-04
013 | 1.2971 | 1.3096 | 38.31% | 38.60% | 3.0e-04
014 | 1.2982 | 1.3028 | 37.31% | 39.53% | 3.0e-04
015 | 1.2795 | 1.2991 | 39.10% | 40.47% | 3.0e-04
016 | 1.2703 | 1.2994 | 40.10% | 39.07% | 3.0e-04
017 | 1.2482 | 1.2883 | 42.09% | 39.53%

In [9]:
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import numpy as np
import torch

def eval_split(model, loader, name="split"):
    """Evaluate ``model`` on every batch of ``loader`` and print a report.

    Switches the model to eval mode, predicts the argmax over the logits, then
    prints accuracy, the confusion matrix and sklearn's per-class
    classification report (3 decimals) under a ``=== name ===`` header.

    Returns:
        Accuracy as a fraction in [0, 1].
    """
    device = next(model.parameters()).device
    model.eval()
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            logits = model(inputs.to(device))
            pred_chunks.append(logits.argmax(1).cpu().numpy())
            label_chunks.append(labels.numpy())

    y_true = np.concatenate(label_chunks)
    y_pred = np.concatenate(pred_chunks)
    acc = accuracy_score(y_true, y_pred)

    print(f"\n=== {name} ===")
    print(f"Accuracy: {acc*100:.2f}%")
    print("Confusion matrix:\n", confusion_matrix(y_true, y_pred))
    print(classification_report(y_true, y_pred, digits=3))
    return acc

# Reload the best checkpoint, then evaluate both held-out test splits.
# NOTE(review): test_id_loader / test_ood_loader are not built in the cells
# shown here -- confirm they are defined before this cell runs.
best_ckpt_path = os.path.join(model_dir, "pooled_eegnet_best.pth")
model.load_state_dict(torch.load(best_ckpt_path, map_location="cpu"))
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
acc_id = eval_split(model, test_id_loader, name="TEST_ID")
acc_ood = eval_split(model, test_ood_loader, name="TEST_OOD")
print(f"\nSummary -> TEST_ID: {acc_id*100:.2f}%, TEST_OOD: {acc_ood*100:.2f}%")



=== TEST_ID ===
Accuracy: 55.00%
Confusion matrix:
 [[30 15  3  7]
 [ 9 37  5  4]
 [ 9  6 18 22]
 [ 4  6  9 36]]
              precision    recall  f1-score   support

           0      0.577     0.545     0.561        55
           1      0.578     0.673     0.622        55
           2      0.514     0.327     0.400        55
           3      0.522     0.655     0.581        55

    accuracy                          0.550       220
   macro avg      0.548     0.550     0.541       220
weighted avg      0.548     0.550     0.541       220


=== TEST_OOD ===
Accuracy: 57.64%
Confusion matrix:
 [[ 61  26  46  11]
 [  1  97  31  15]
 [ 10  22  69  43]
 [  2   5  32 105]]
              precision    recall  f1-score   support

           0      0.824     0.424     0.560       144
           1      0.647     0.674     0.660       144
           2      0.388     0.479     0.429       144
           3      0.603     0.729     0.660       144

    accuracy                          0.576     