In [45]:
import os
import random
import numpy as np
import pandas as pd
from scipy.signal import welch
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F

# 1) Seed every RNG in use: Python's `random`, NumPy's global RNG, Torch's CPU
#    generator and the generators of all CUDA devices.
SEED = 3407
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn

# 2) Disable cuDNN's autotuner and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# 3) (PyTorch >= 1.8) Calling torch.use_deterministic_algorithms(True) would
#    additionally force deterministic implementations of every op; it is
#    intentionally left disabled here.

# Dedicated seeded generator, used for the training DataLoader's shuffle order.
g = torch.Generator().manual_seed(SEED)


# ——————————————
# Focal Loss implementation
# ——————————————
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from logits.

    For a sample i with target class t_i:
        loss_i = alpha[t_i] * (1 - p_i) ** gamma * CE_i,   p_i = exp(-CE_i)
    where CE_i is the unreduced cross-entropy, so p_i is the softmax
    probability assigned to the true class. With gamma=0 and alpha=None this is
    plain cross-entropy. Note that 'mean' is a plain average of the (weighted)
    per-sample losses; it is not normalized by the alpha weights.
    """

    _REDUCTIONS = ('none', 'mean', 'sum')

    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        """
        gamma: focusing parameter; larger values down-weight well-classified samples.
        alpha: optional per-class weights, a tensor or sequence of length C.
               It is moved to the device/dtype of the loss on every call, so it
               need not already live on the same device as the inputs.
        reduction: 'none' | 'mean' | 'sum'.

        Raises ValueError for any other reduction (previously an unknown value
        silently behaved like 'none').
        """
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.gamma = gamma
        self.alpha = None if alpha is None else torch.as_tensor(alpha)
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        inputs: logits, shape (B, C)
        targets: ground-truth class indices, shape (B,)
        Returns a scalar for 'mean'/'sum', or the per-sample loss (B,) for 'none'.
        """
        # 1) Per-sample cross-entropy (no reduction)
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')  # (B,)

        # 2) pt = exp(-CE) is the predicted probability of the true class.
        #    It is computed before alpha weighting so the focusing term uses the
        #    unweighted probability.
        pt = torch.exp(-ce_loss)  # (B,)

        # 3) Per-sample class weight alpha_t
        if self.alpha is not None:
            alpha = self.alpha.to(device=ce_loss.device, dtype=ce_loss.dtype)
            ce_loss = ce_loss * alpha.gather(0, targets)  # (B,)

        # 4) Focal modulation
        loss = (1 - pt) ** self.gamma * ce_loss  # (B,)

        # 5) Reduction
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'none':
            return loss
        raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {self.reduction!r}")

# ——————————————
# 1. Feature extraction & dataset definition
# ——————————————
def extract_features(window, fs=1259):
    """Compute six features per channel and flatten them channel-major.

    window: array of shape (N, C) — N samples by C channels (C=4 in this pipeline).
    fs: sampling rate in Hz, forwarded to Welch PSD estimation.

    Returns a float32 array of shape (C*6,), i.e. (24,) for 4 channels. For each
    channel the features are, in order:
    [mean frequency, median frequency, 90% spectral-edge frequency, RMS,
     integrated absolute value (sum of |x|), mean of the last 10 samples].
    A channel with zero spectral power (e.g. all zeros) gets 0 for the three
    frequency features instead of NaN, so it cannot poison later normalization.
    """
    N, C = window.shape
    feats = np.zeros((C, 6), dtype=np.float32)
    for ch in range(C):
        sig = window[:, ch]
        rms = np.sqrt(np.mean(sig ** 2))
        iemg = np.sum(np.abs(sig))
        # nperseg=N: one segment spanning the whole window.
        f, Pxx = welch(sig, fs=fs, nperseg=N)
        total = np.sum(Pxx)
        if total <= 0:
            # Previously 0/0 -> NaN mean frequency here.
            mnf = mdf = sef90 = 0.0
        else:
            mnf = np.sum(f * Pxx) / total
            cumsum = np.cumsum(Pxx)
            last = len(f) - 1  # clip: searchsorted may return len(f) on rounding
            mdf = f[min(np.searchsorted(cumsum, total / 2), last)]
            sef90 = f[min(np.searchsorted(cumsum, total * 0.9), last)]
        last10 = np.mean(sig[-10:])
        feats[ch] = [mnf, mdf, sef90, rms, iemg, last10]
    return feats.reshape(-1)  # (C*6,)

class EMGFeatureDataset(Dataset):
    """Windowed EMG feature dataset built from per-subject CSV recordings.

    For every subject folder under data_root, each *.csv file is read. All
    columns except 'label' are treated as signal channels. Each recording is
    cut into non-overlapping windows of round(fs * window_sec) samples. Every
    window becomes 6 features per channel via extract_features and is labelled
    with the 'label' value of its last sample. Incomplete trailing windows are
    dropped. Features are z-scored per feature and stored as (N, 1, C, 6) for
    Conv2d, where C is the channel count (4 for the data this notebook uses).

    Attributes:
        data: float32 tensor (N, 1, C, 6) of normalized features.
        labels: int64 tensor (N,).
        mean, std: the (1, C*6) statistics used for normalization.
        n_channels: C.
    """

    def __init__(self, data_root, subjects, fs=1259, window_sec=2, norm_stats=None):
        """
        data_root: root folder containing one sub-folder per subject.
        subjects: names of the subject sub-folders to load.
        fs: sampling rate in Hz (also passed to extract_features).
        window_sec: window length in seconds (may be fractional).
        norm_stats: optional (mean, std) pair of shape (1, C*6) to normalize
            with instead of this dataset's own statistics. For example, pass
            another dataset's (.mean, .std) built from training subjects only.
            `std` is used as-is (no epsilon added).

        Raises ValueError if the window is shorter than one sample, if CSV
        files disagree on the number of channel columns, or if no complete
        window could be extracted.
        """
        self.fs = fs
        ws = int(round(fs * window_sec))  # slicing needs an integer sample count
        if ws < 1:
            raise ValueError(f"window of {window_sec}s at {fs}Hz is shorter than one sample")
        stride = ws  # non-overlapping windows
        all_feats = []
        all_labels = []
        n_channels = None

        for subj in subjects:
            subj_dir = os.path.join(data_root, subj)
            for fn in sorted(os.listdir(subj_dir)):
                if not fn.endswith('.csv'):
                    continue
                df = pd.read_csv(os.path.join(subj_dir, fn))
                arr = df[[c for c in df.columns if c != 'label']].values  # (T, C)
                labs = df['label'].values
                # A differing channel count would silently misalign the final
                # reshape with the labels, so fail loudly instead.
                if n_channels is None:
                    n_channels = arr.shape[1]
                elif arr.shape[1] != n_channels:
                    raise ValueError(
                        f"{os.path.join(subj_dir, fn)} has {arr.shape[1]} channel columns, "
                        f"expected {n_channels}"
                    )
                n_win = (len(df) - ws) // stride + 1  # <= 0 when the file is shorter than ws
                for i in range(n_win):
                    s, e = i * stride, i * stride + ws
                    all_feats.append(extract_features(arr[s:e], fs=self.fs))
                    all_labels.append(int(labs[e - 1]))  # label of the window's last sample

        if not all_feats:
            raise ValueError(
                f"no complete {ws}-sample window found for subjects {list(subjects)} under {data_root!r}"
            )

        feats_tensor = torch.tensor(np.stack(all_feats), dtype=torch.float32)  # (N, C*6)
        labels_tensor = torch.tensor(all_labels, dtype=torch.long)            # (N,)

        # Z-score every feature column.
        # NOTE(review): with the default, the statistics come from this whole
        # dataset, so any val/test windows split out of it later influence them.
        if norm_stats is None:
            mean = feats_tensor.mean(dim=0, keepdim=True)
            std = feats_tensor.std(dim=0, keepdim=True) + 1e-6
        else:
            mean, std = (torch.as_tensor(t, dtype=torch.float32) for t in norm_stats)
        self.mean, self.std = mean, std
        norm_feats = (feats_tensor - mean) / std

        # (N, C*6) -> (N, 1, C, 6) for Conv2d
        self.n_channels = n_channels
        self.data = norm_feats.view(-1, 1, n_channels, 6)
        self.labels = labels_tensor

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


# ——————————————
# 2. Model definitions
# ——————————————
class ResBlock(nn.Module):
    """Residual block: same-padded Conv2d, then a 2-layer MLP over the flattened
    map, added back to the input and passed through ReLU.

    Input and output shape: (B, channels, height, width). The defaults
    (1, 4, 6) match the (B, 1, 4, 6) feature maps of EMGFeatureDataset.
    kernel_size must be odd so the padded convolution preserves height/width.
    """

    def __init__(self, channels=1, kernel_size=3, height=4, width=6):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd to preserve spatial size, got {kernel_size}")
        padding = kernel_size // 2
        self.conv = nn.Conv2d(channels, channels, kernel_size, padding=padding, bias=False)
        # Size the MLP from the full feature-map size. It used to be a fixed
        # 4*6, which broke for channels != 1.
        feat_dim = channels * height * width
        self.mlp = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim),
        )
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.conv(x)                         # (B, C, H, W)
        y2 = self.mlp(y.flatten(1)).view_as(y)   # (B, C*H*W) -> (B, C, H, W)
        return self.act(x + y2)                  # residual + ReLU

class EMGResNet(nn.Module):
    """n_blocks ResBlocks followed by an MLP classifier head.

    Input: feature maps of shape (B, in_ch, 4, 6). Output: logits (B, num_classes).
    """

    def __init__(self, num_classes=3, in_ch=1, n_blocks=3):
        super().__init__()
        # Blocks are built in order, so parameter initialisation consumes the
        # RNG exactly as before.
        self.blocks = nn.Sequential(*[ResBlock(in_ch) for _ in range(n_blocks)])
        self.classifier = nn.Sequential(
            nn.Flatten(),                  # (B, in_ch, 4, 6) -> (B, in_ch*24)
            # Sized from in_ch (was a fixed 24), so in_ch > 1 no longer causes a
            # shape mismatch in this layer.
            nn.Linear(in_ch * 4 * 6, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.blocks(x))


# ——————————————
# 3. Data loading & splitting
# ——————————————
data_root = 'filtered'   # root folder of the filtered, aligned recordings
subjects  = ['subject_1', 'subject_2', 'subject_3']
dataset   = EMGFeatureDataset(data_root, subjects)
os.makedirs('checkpoints', exist_ok=True)
save_start = 200  # first epoch at which the training loop saves a checkpoint

# 70/15/15 random split over windows.
# NOTE(review): windows from the same recordings/subjects end up in every split,
# and the dataset is z-scored with statistics over all windows, so val/test
# scores are likely optimistic. A split by subject or file would be stricter.
N = len(dataset)
n_train = int(N * 0.7)
n_val   = int(N * 0.15)
n_test  = N - n_train - n_val
# Seed from SEED rather than a repeated literal, so changing SEED changes the split too.
train_ds, val_ds, test_ds = random_split(
    dataset, [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(SEED),
)

# NOTE(review): 6000 is probably >= the training split size, i.e. one full batch per epoch.
batch_size = 6000
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, generator=g)
val_loader   = DataLoader(val_ds,   batch_size=batch_size)
test_loader  = DataLoader(test_ds,  batch_size=batch_size)

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# ——————————————
# 4. Model, loss and optimizer
# ——————————————
model = EMGResNet(num_classes=3)
model.to(device)  # Module.to moves parameters in place

# Per-class focal-loss weights (index = class id), e.g. to counter class imbalance.
alpha = torch.tensor([1.1, 2.0, 1.5]).to(device)
criterion = FocalLoss(alpha=alpha, gamma=3.0)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)


# ——————————————
# 5. Training & validation
# ——————————————
num_epochs = 800
best_val_acc = float('-inf')  # best validation accuracy seen so far
best_val_epoch = None         # epoch at which it was reached
for epoch in range(1, num_epochs+1):
    model.train()
    running_loss = 0.0
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        logits = model(X_batch)
        loss = criterion(logits, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; re-weight by batch size for an epoch mean.
        running_loss += loss.item() * X_batch.size(0)
    # int(N * 0.7) / int(N * 0.15) can be 0 for tiny datasets; avoid ZeroDivisionError.
    epoch_loss = running_loss / n_train if n_train else float('nan')

    # Validation
    model.eval()
    correct = 0
    with torch.no_grad():
        for X_val, y_val in val_loader:
            X_val, y_val = X_val.to(device), y_val.to(device)
            preds = model(X_val).argmax(dim=1)
            correct += (preds == y_val).sum().item()
    val_acc = correct / n_val if n_val else float('nan')
    if val_acc > best_val_acc:  # NaN never wins
        best_val_acc, best_val_epoch = val_acc, epoch

    # Writes one checkpoint per epoch from save_start on (601 files for 800 epochs).
    if epoch >= save_start:
        ckpt_path = f'checkpoints/epoch_{epoch:03d}.pth'
        torch.save(model.state_dict(), ckpt_path)
        print(f"  → Saved checkpoint: {ckpt_path}")

    print(f"Epoch {epoch:02d}/{num_epochs}  Train Loss: {epoch_loss:.4f}  Val Acc: {val_acc:.4f}")

if best_val_epoch is not None:
    print(f"\nBest Val Acc: {best_val_acc:.4f} at epoch {best_val_epoch}")


# ——————————————
# 6. Test-set evaluation
# ——————————————
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for X_test, y_test in test_loader:
        X_test = X_test.to(device)
        y_test = y_test.to(device)
        preds = model(X_test).argmax(dim=1)
        # Count hits in this batch and the batch size.
        correct += preds.eq(y_test).sum().item()
        total += len(y_test)
test_acc = correct / total
print(f"\nTest Accuracy: {test_acc:.4f}")


Epoch 01/800  Train Loss: 0.5726  Val Acc: 0.2508
Epoch 02/800  Train Loss: 0.5540  Val Acc: 0.2564
Epoch 03/800  Train Loss: 0.5379  Val Acc: 0.2708
Epoch 04/800  Train Loss: 0.5239  Val Acc: 0.3008
Epoch 05/800  Train Loss: 0.5119  Val Acc: 0.3163
Epoch 06/800  Train Loss: 0.5018  Val Acc: 0.3396
Epoch 07/800  Train Loss: 0.4932  Val Acc: 0.3796
Epoch 08/800  Train Loss: 0.4861  Val Acc: 0.4018
Epoch 09/800  Train Loss: 0.4803  Val Acc: 0.4040
Epoch 10/800  Train Loss: 0.4755  Val Acc: 0.4151
Epoch 11/800  Train Loss: 0.4716  Val Acc: 0.4240
Epoch 12/800  Train Loss: 0.4684  Val Acc: 0.4218
Epoch 13/800  Train Loss: 0.4657  Val Acc: 0.4151
Epoch 14/800  Train Loss: 0.4635  Val Acc: 0.4218
Epoch 15/800  Train Loss: 0.4616  Val Acc: 0.4273
Epoch 16/800  Train Loss: 0.4598  Val Acc: 0.4329
Epoch 17/800  Train Loss: 0.4582  Val Acc: 0.4329
Epoch 18/800  Train Loss: 0.4567  Val Acc: 0.4340
Epoch 19/800  Train Loss: 0.4552  Val Acc: 0.4373
Epoch 20/800  Train Loss: 0.4538  Val Acc: 0.4428


In [46]:
import os
import torch
from collections import defaultdict

# Configuration
checkpoints_dir = 'checkpoints'  # folder the training loop wrote weights into
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Assumes `model`, `val_loader` and `test_loader` from the training cell are
# still defined, and that every checkpoint matches `model`'s architecture.
#
# Checkpoints are ranked by VALIDATION accuracy. Test accuracy is only reported.
# Choosing the checkpoint with the highest test accuracy (as this cell used to
# do) tunes on the test set and makes the reported test score optimistically biased.


def _evaluate(net, loader, dev):
    """Return (correct, total, correct_per_class, total_per_class) of net on loader.

    The per-class dicts map int class label -> count. Only labels present in
    the loader appear as keys.
    """
    correct_per_class = defaultdict(int)
    total_per_class = defaultdict(int)
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(dev), y.to(dev)
            preds = net(X).argmax(dim=1)
            # Vectorized per-class counting instead of a per-sample Python loop.
            labels, counts = torch.unique(y, return_counts=True)
            for cls, cnt in zip(labels.tolist(), counts.tolist()):
                total_per_class[cls] += cnt
            labels, counts = torch.unique(y[preds == y], return_counts=True)
            for cls, cnt in zip(labels.tolist(), counts.tolist()):
                correct_per_class[cls] += cnt
    correct = sum(correct_per_class.values())
    total = sum(total_per_class.values())
    return correct, total, correct_per_class, total_per_class


best_val_acc = float('-inf')
best_acc = 0.0      # test accuracy of the checkpoint selected by validation accuracy
best_ckpt = None
results = []        # (checkpoint filename, test accuracy)
val_results = []    # (checkpoint filename, validation accuracy)

# Sweep all checkpoint files
for ckpt in sorted(os.listdir(checkpoints_dir)):
    if not ckpt.endswith('.pth'):
        continue
    path = os.path.join(checkpoints_dir, ckpt)

    # Load weights. NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()

    val_correct, val_total, _, _ = _evaluate(model, val_loader, device)
    val_acc = val_correct / val_total if val_total else float('nan')
    correct, total, correct_per_class, total_per_class = _evaluate(model, test_loader, device)
    acc = correct / total if total else float('nan')
    results.append((ckpt, acc))
    val_results.append((ckpt, val_acc))

    print(f"\nCheckpoint: {ckpt}")
    print(f"  Validation Accuracy: {val_acc:.4f}")
    print(f"  Overall Test Accuracy: {acc:.4f}")
    for cls in sorted(total_per_class):
        cls_acc = correct_per_class[cls] / total_per_class[cls]
        print(f"  Class {cls}: {correct_per_class[cls]}/{total_per_class[cls]} correct, Accuracy: {cls_acc:.4f}")

    # Select on validation accuracy only; the earliest checkpoint wins ties.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_acc = acc
        best_ckpt = ckpt

if best_ckpt is None:
    print(f"\nNo checkpoint selected (no .pth files in '{checkpoints_dir}' or empty validation split)")
else:
    # Leave `model` holding the selected weights rather than the last file loaded.
    model.load_state_dict(torch.load(os.path.join(checkpoints_dir, best_ckpt), map_location=device))
    model.to(device)
    model.eval()
    print(f"\nBest checkpoint (by validation accuracy): {best_ckpt} "
          f"with Val Accuracy = {best_val_acc:.4f}, Test Accuracy = {best_acc:.4f}")



Checkpoint: epoch_200.pth
  Overall Test Accuracy: 0.6563
  Class 0: 282/391 correct, Accuracy: 0.7212
  Class 1: 178/281 correct, Accuracy: 0.6335
  Class 2: 132/230 correct, Accuracy: 0.5739

Checkpoint: epoch_201.pth
  Overall Test Accuracy: 0.6530
  Class 0: 275/391 correct, Accuracy: 0.7033
  Class 1: 176/281 correct, Accuracy: 0.6263
  Class 2: 138/230 correct, Accuracy: 0.6000

Checkpoint: epoch_202.pth
  Overall Test Accuracy: 0.6563
  Class 0: 283/391 correct, Accuracy: 0.7238
  Class 1: 176/281 correct, Accuracy: 0.6263
  Class 2: 133/230 correct, Accuracy: 0.5783

Checkpoint: epoch_203.pth
  Overall Test Accuracy: 0.6541
  Class 0: 276/391 correct, Accuracy: 0.7059
  Class 1: 175/281 correct, Accuracy: 0.6228
  Class 2: 139/230 correct, Accuracy: 0.6043

Checkpoint: epoch_204.pth
  Overall Test Accuracy: 0.6596
  Class 0: 286/391 correct, Accuracy: 0.7315
  Class 1: 174/281 correct, Accuracy: 0.6192
  Class 2: 135/230 correct, Accuracy: 0.5870

Checkpoint: epoch_205.pth
  O