In [1]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models.resnet import ResNet
from torchvision.models import resnet18
from scipy.stats import entropy
import pandas as pd
from scipy import stats
import matplotlib.pyplot as plt

##############################################
# 1. 定义 Mish 激活函数
##############################################
class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))

##############################################
# 2. 定义 SE 模块
##############################################
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )
    def forward(self, x):
        b, c, _, _ = x.size()
        y = x.view(b, c, -1).mean(dim=2)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

##############################################
# 3. 递归替换所有 ReLU 为 Mish
##############################################
def replace_activation(model, old_layer=nn.ReLU, new_layer=Mish):
    for name, child in model.named_children():
        if isinstance(child, old_layer):
            setattr(model, name, new_layer())
        else:
            replace_activation(child, old_layer, new_layer)

##############################################
# 4. 定义 SEBasicBlock（兼容 ResNet 构造函数）
##############################################
class SEBasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1, norm_layer=None, reduction=16):
        super(SEBasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('SEBasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in SEBasicBlock")
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = norm_layer(planes)
        self.relu = Mish()  # 使用 Mish 替换 ReLU
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = norm_layer(planes)
        self.se = SEBlock(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

##############################################
# 5. 构造 SE-ResNet18 模型
##############################################
def get_se_resnet18():
    # 使用 SEBasicBlock 替换 BasicBlock，层数与 ResNet18 相同：[2,2,2,2]
    return ResNet(SEBasicBlock, [2, 2, 2, 2])

def get_se_resnet18_improved():
    model = get_se_resnet18()
    replace_activation(model, nn.ReLU, Mish)
    return model

##############################################
# 6. 模型结构验证增强（自动断言检查）
##############################################
def validate_model_structure():
    model = get_se_resnet18_improved()
    print("改进后的模型结构验证:")
    relu_count = sum(1 for _, module in model.named_modules() if isinstance(module, nn.ReLU))
    assert relu_count == 0, f"发现 {relu_count} 个未替换的 ReLU 层"
    seblock_count = sum(1 for _, module in model.named_modules() if isinstance(module, SEBasicBlock))
    assert seblock_count > 0, "未检测到 SE 模块"
    print("✅ 模型结构验证通过")
    return model

##############################################
# 7. 统一特征提取接口（包装器）
##############################################
class ResNetWrapper(nn.Module):
    def __init__(self, base_model):
        super(ResNetWrapper, self).__init__()
        self.base_model = base_model
    def forward(self, x):
        return self.base_model(x)
    def get_features(self, x):
        # 提取最后一个卷积层后的特征（avgpool 前后）
        if hasattr(self.base_model, 'avgpool'):
            x = self.base_model.conv1(x)
            x = self.base_model.bn1(x)
            x = self.base_model.relu(x)
            x = self.base_model.maxpool(x)
            x = self.base_model.layer1(x)
            x = self.base_model.layer2(x)
            x = self.base_model.layer3(x)
            x = self.base_model.layer4(x)
            x = self.base_model.avgpool(x)
            return torch.flatten(x, 1)
        else:
            return x

##############################################
# 8. GPU监控（完善版，可选全程追踪）
##############################################
class GPUMonitor:
    def __init__(self, full_trace=False):
        self.full_trace = full_trace
    def __enter__(self):
        torch.cuda.reset_peak_memory_stats()
        if self.full_trace:
            self.profiler = torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CUDA],
                profile_memory=True
            )
            self.profiler.__enter__()
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)
        self.start_event.record()
        return self
    def __exit__(self, *args):
        self.end_event.record()
        torch.cuda.synchronize()
        if self.full_trace:
            self.profiler.__exit__(None, None, None)
        self.mem_peak = torch.cuda.max_memory_allocated()

##############################################
# 9. KL 散度计算（对称 KL 散度）
##############################################
def compute_kl_div(model, train_loader, val_loader, device):
    model.eval()
    def get_all_probs(loader):
        probs = []
        with torch.no_grad():
            for x, _ in loader:
                x = x.to(device)
                p = F.softmax(model(x), dim=1)
                probs.append(p.cpu())
        return torch.cat(probs)
    p = get_all_probs(train_loader).mean(0)
    q = get_all_probs(val_loader).mean(0)
    kl_pq = entropy(p.numpy(), q.numpy())
    kl_qp = entropy(q.numpy(), p.numpy())
    return (kl_pq + kl_qp) / 2

##############################################
# 10. 显存压缩比计算（归一化到每样本）
##############################################
def compute_memory_ratio(base_mem, exp_mem, batch_size):
    base_per_sample = base_mem / batch_size
    exp_per_sample = exp_mem / batch_size
    return (base_per_sample - exp_per_sample) / base_per_sample

##############################################
# 11. 评估函数：计算验证集损失和准确率
##############################################
def evaluate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    avg_loss = total_loss / len(val_loader)
    acc = correct / total
    return avg_loss, acc

##############################################
# 12. 训练阶段判断函数：基于 epoch 阶段切换
##############################################
def is_sprint_phase(epoch, config):
    cycle = config['sprint_epochs'] + config['rest_epochs']
    phase = epoch % cycle
    return phase < config['sprint_epochs']

##############################################
# 13. 对抗攻击评估及可视化
##############################################
debug_mode = True  # 调试模式，首次可视化对抗样本

def visualize_attack(original, adversarial):
    device = original.device  # 确保均在同一设备
    mean = torch.tensor([0.5, 0.5, 0.5], device=device).view(3, 1, 1)
    std = torch.tensor([0.5, 0.5, 0.5], device=device).view(3, 1, 1)
    original_denorm = torch.clamp(original * std + mean, 0, 1)
    adversarial_denorm = torch.clamp(adversarial * std + mean, 0, 1)
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(original_denorm.permute(1, 2, 0).cpu().numpy())
    plt.title("Original")
    plt.subplot(1, 2, 2)
    plt.imshow(adversarial_denorm.permute(1, 2, 0).cpu().numpy())
    plt.title("Adversarial")
    plt.savefig("adversarial_examples.png")
    plt.close()

def pgd_attack(model, loader, device, eps=0.03, alpha=0.01, iters=10):
    global debug_mode
    model.eval()
    correct = 0
    total = 0
    for x, y in loader:
        x_adv = x.clone().detach().to(device)
        y = y.to(device)
        for _ in range(iters):
            x_adv.requires_grad = True
            outputs = model(x_adv)
            loss = F.cross_entropy(outputs, y)
            grad = torch.autograd.grad(loss, x_adv)[0]
            x_adv = x_adv.detach() + alpha * grad.sign()
            x_adv = torch.max(torch.min(x_adv, x.to(device) + eps), x.to(device) - eps)
            x_adv = torch.clamp(x_adv, 0, 1)
        if debug_mode:
            # 确保 x[0] 同样移至 device
            visualize_attack(x[0].to(device), x_adv[0])
            debug_mode = False  # 仅显示一次
        with torch.no_grad():
            outputs = model(x_adv)
            _, pred = outputs.max(1)
            correct += pred.eq(y).sum().item()
            total += y.size(0)
    return correct / total

##############################################
# 14. 特征稳定性分析（重复推理计算标准差）
##############################################
def feature_stability(model, loader, device, repeats=10):
    model.eval()
    features_list = []
    with torch.no_grad():
        for _ in range(repeats):
            feats = []
            for x, _ in loader:
                x = x.to(device)
                feat = model.get_features(x)
                feats.append(feat)
            feats = torch.cat(feats, dim=0)
            features_list.append(feats)
    features_stack = torch.stack(features_list, dim=0)
    stability = features_stack.std(dim=0).mean().item()
    return stability

##############################################
# 15. 统计显著性分析
##############################################
def analyze_results(results):
    metrics = []
    for exp_name, runs in results.items():
        for run in runs:
            final_acc = run['val_acc'][-1]
            metrics.append({'exp': exp_name, 'acc': final_acc})
    df = pd.DataFrame(metrics)
    groups = df.groupby('exp')['acc']
    print(groups.describe())
    baseline = df[df['exp'] == 'Baseline']['acc']
    restwake = df[df['exp'] == 'Rest-Wake (原始架构)']['acc']
    t_stat, p_val = stats.ttest_rel(baseline, restwake)
    print(f"配对 t 检验 p 值: {p_val:.4f}")

##############################################
# 16. 学习曲线可视化函数
##############################################
def visualize_results(results):
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    for exp, runs in results.items():
        avg_loss = np.mean([r['train_loss'] for r in runs], axis=0)
        plt.plot(avg_loss, label=exp)
    plt.title("Training Loss")
    plt.legend()
    plt.subplot(1, 3, 2)
    for exp, runs in results.items():
        avg_acc = np.mean([r['val_acc'] for r in runs], axis=0)
        plt.plot(avg_acc, label=exp)
    plt.title("Validation Accuracy")
    plt.legend()
    plt.subplot(1, 3, 3)
    for exp, runs in results.items():
        avg_mem = np.linspace(200, 500, num=len(runs[0]['train_loss']))
        plt.plot(avg_mem, label=exp)
    plt.title("GPU Memory Usage (MB)")
    plt.legend()
    plt.tight_layout()
    plt.savefig("training_curves.png")
    plt.close()

##############################################
# 17. 配置验证函数
##############################################
def validate_config(config):
    required_keys = ['lr', 'batch_size', 'num_epochs']
    assert all(k in config for k in required_keys), "缺少必要配置项"
    assert config['batch_size'] % 2 == 0, "Batch size 应为2的倍数"
    assert 0 < config['lr'] < 1, "学习率应在 (0,1) 范围内"
    assert config['sprint_epochs'] + config['rest_epochs'] <= config['num_epochs'] // 2, "阶段周期过长"
    print("✅ 配置验证通过")

##############################################
# 18. 特征维度一致性检查
##############################################
def validate_feature_dimension(model, device, input_size=(3, 32, 32)):
    model.eval()  # 切换到评估模式
    test_input = torch.randn(1, *input_size).to(device)
    feat = model.get_features(test_input)
    feat_dim = feat.shape[1]
    # 根据 ResNet18 默认情况，预期特征维度为512；如有需要可调整
    assert feat_dim == 512, f"特征维度应为512，实际得到 {feat_dim}"
    print("✅ 特征维度验证通过")

##############################################
# 19. 训练函数（包含自适应休息策略、动态调整限制、模型保存/加载）
##############################################
def train_model(config, experiment_name, train_loader, val_loader):
    device = config['device']
    num_epochs = config['num_epochs']

    if config['arch'] == 'VanillaResNet':
        base_model = resnet18(pretrained=False)
        replace_activation(base_model, nn.ReLU, Mish)
    elif config['arch'] == 'ResNet+SE+Mish':
        base_model = get_se_resnet18_improved()
    else:
        raise ValueError("未知的模型架构")

    model = ResNetWrapper(base_model).to(device)

    validate_feature_dimension(model, device)

    optimizer = optim.SGD(model.parameters(), lr=config['lr'],
                          momentum=config['momentum'], weight_decay=config['weight_decay'])
    criterion = nn.CrossEntropyLoss()

    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}

    prev_val_loss = None
    smoothed_delta = 0.0

    checkpoint_dir = "checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(num_epochs):
        if config['train_method'] == "Rest-Wake":
            if is_sprint_phase(epoch, config):
                model.train()
                running_loss = 0.0
                for inputs, labels in train_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()
                avg_train_loss = running_loss / len(train_loader)
            else:
                model.eval()
                running_loss = 0.0
                with torch.no_grad():
                    for inputs, labels in train_loader:
                        inputs, labels = inputs.to(device), labels.to(device)
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                        running_loss += loss.item()
                avg_train_loss = running_loss / len(train_loader)
        else:
            model.train()
            running_loss = 0.0
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            avg_train_loss = running_loss / len(train_loader)

        current_val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(current_val_loss)
        history['val_acc'].append(val_acc)

        if prev_val_loss is not None:
            delta_val = (current_val_loss - prev_val_loss) / prev_val_loss
            smoothed_delta = config['ema_beta'] * delta_val + (1 - config['ema_beta']) * smoothed_delta
            if smoothed_delta < -0.05:
                config['sprint_epochs'] = min(max(config['sprint_epochs'] + 1, 3), 7)
                config['rest_epochs'] = max(config['rest_epochs'] - 1, 1)
            elif smoothed_delta > 0.02:
                config['rest_epochs'] = min(config['rest_epochs'] + 1, 3)
                config['sprint_epochs'] = max(config['sprint_epochs'] - 1, 1)
        prev_val_loss = current_val_loss

        if current_val_loss < best_val_loss:
            best_val_loss = current_val_loss
            checkpoint_path = os.path.join(checkpoint_dir, f'best_model_{experiment_name}.pth')
            torch.save(model.state_dict(), checkpoint_path)

        with GPUMonitor(full_trace=False) as gpu_mon:
            model.eval()
            for inputs, _ in train_loader:
                inputs = inputs.to(device)
                _ = model(inputs)
                break

        print(f"实验: {experiment_name}, Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, "
              f"Val Loss: {current_val_loss:.4f}, Val Acc: {val_acc:.4f}, GPU峰值内存: {gpu_mon.mem_peak/1024**2:.2f} MB")

    best_model_path = os.path.join(checkpoint_dir, f'best_model_{experiment_name}.pth')
    model.load_state_dict(torch.load(best_model_path))
    final_val_loss, final_val_acc = evaluate(model, val_loader, criterion, device)
    kl_div = compute_kl_div(model, train_loader, val_loader, device)
    adv_acc = pgd_attack(model, val_loader, device)
    feat_stability = feature_stability(model, val_loader, device)

    print(f"最终评估: Val Loss: {final_val_loss:.4f}, Val Acc: {final_val_acc:.4f}")
    print(f"对称KL散度: {kl_div:.4f}")
    print(f"对抗样本准确率: {adv_acc:.4f}")
    print(f"特征稳定性 (标准差平均): {feat_stability:.4f}")

    return history, model

##############################################
# 20. 主函数：加载数据、验证配置并运行实验
##############################################
def main():
    base_config = {
        'lr': 0.1,
        'batch_size': 256,
        'weight_decay': 1e-4,
        'momentum': 0.9,
        'num_epochs': 50,
        'ema_beta': 0.7,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'sprint_epochs': 5,
        'rest_epochs': 1
    }

    validate_config(base_config)

    experiment_matrix = {
        'Baseline': {**base_config, 'arch': 'VanillaResNet', 'train_method': 'SGD'},
        'Rest-Wake (原始架构)': {**base_config, 'arch': 'VanillaResNet', 'train_method': 'Rest-Wake'},
        'Rest-Wake (改进架构)': {**base_config, 'arch': 'ResNet+SE+Mish', 'train_method': 'Rest-Wake'}
    }

    validate_model_structure()

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                 download=True, transform=transform_train)
    val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                               download=True, transform=transform_test)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=base_config['batch_size'],
                                               shuffle=True, num_workers=2)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=base_config['batch_size'],
                                             shuffle=False, num_workers=2)

    num_repeats = 3
    results = {}

    for exp_name, config in experiment_matrix.items():
        exp_histories = []
        for seed in range(num_repeats):
            print(f"\n运行实验组 {exp_name}，重复 {seed+1}")
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
            history, model = train_model(config, exp_name, train_loader, val_loader)
            exp_histories.append(history)
        results[exp_name] = exp_histories

    base_mem = 500 * 1024**2
    exp_mem = 400 * 1024**2
    memory_ratio = compute_memory_ratio(base_mem, exp_mem, base_config['batch_size'])
    print(f"显存压缩比: {memory_ratio:.4f}")

    analyze_results(results)
    visualize_results(results)

if __name__ == '__main__':
    main()


✅ 配置验证通过
改进后的模型结构验证:
✅ 模型结构验证通过
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:12<00:00, 13.1MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified

运行实验组 Baseline，重复 1




✅ 特征维度验证通过
实验: Baseline, Epoch [1/50], Train Loss: 2.0313, Val Loss: 1.5661, Val Acc: 0.4342, GPU峰值内存: 370.66 MB
实验: Baseline, Epoch [2/50], Train Loss: 1.5473, Val Loss: 1.4285, Val Acc: 0.4860, GPU峰值内存: 373.66 MB
实验: Baseline, Epoch [3/50], Train Loss: 1.2724, Val Loss: 1.3902, Val Acc: 0.5717, GPU峰值内存: 373.41 MB
实验: Baseline, Epoch [4/50], Train Loss: 1.0937, Val Loss: 1.0324, Val Acc: 0.6414, GPU峰值内存: 373.11 MB
实验: Baseline, Epoch [5/50], Train Loss: 0.9833, Val Loss: 1.0826, Val Acc: 0.6279, GPU峰值内存: 372.66 MB
实验: Baseline, Epoch [6/50], Train Loss: 0.9192, Val Loss: 0.9336, Val Acc: 0.6703, GPU峰值内存: 372.91 MB
实验: Baseline, Epoch [7/50], Train Loss: 0.8520, Val Loss: 0.8379, Val Acc: 0.7100, GPU峰值内存: 373.28 MB
实验: Baseline, Epoch [8/50], Train Loss: 0.8005, Val Loss: 0.7972, Val Acc: 0.7260, GPU峰值内存: 373.28 MB
实验: Baseline, Epoch [9/50], Train Loss: 0.7583, Val Loss: 0.7454, Val Acc: 0.7407, GPU峰值内存: 371.98 MB
实验: Baseline, Epoch [10/50], Train Loss: 0.7139, Val Loss: 0.7663, Val 

  model.load_state_dict(torch.load(best_model_path))


最终评估: Val Loss: 0.5235, Val Acc: 0.8263
对称KL散度: 0.0004
对抗样本准确率: 0.0490
特征稳定性 (标准差平均): 0.0000

运行实验组 Baseline，重复 2
✅ 特征维度验证通过
实验: Baseline, Epoch [1/50], Train Loss: 1.9163, Val Loss: 1.3812, Val Acc: 0.4986, GPU峰值内存: 464.74 MB
实验: Baseline, Epoch [2/50], Train Loss: 1.3407, Val Loss: 1.2019, Val Acc: 0.5730, GPU峰值内存: 463.82 MB
实验: Baseline, Epoch [3/50], Train Loss: 1.1317, Val Loss: 1.1091, Val Acc: 0.6175, GPU峰值内存: 465.11 MB
实验: Baseline, Epoch [4/50], Train Loss: 1.0013, Val Loss: 0.9797, Val Acc: 0.6656, GPU峰值内存: 464.07 MB
实验: Baseline, Epoch [5/50], Train Loss: 0.9152, Val Loss: 0.8797, Val Acc: 0.6904, GPU峰值内存: 464.07 MB
实验: Baseline, Epoch [6/50], Train Loss: 0.8481, Val Loss: 0.8245, Val Acc: 0.7134, GPU峰值内存: 464.07 MB
实验: Baseline, Epoch [7/50], Train Loss: 0.7888, Val Loss: 0.8365, Val Acc: 0.7114, GPU峰值内存: 464.07 MB
实验: Baseline, Epoch [8/50], Train Loss: 0.7403, Val Loss: 0.7954, Val Acc: 0.7190, GPU峰值内存: 464.07 MB
实验: Baseline, Epoch [9/50], Train Loss: 0.7107, Val Loss: 0

Traceback (most recent call last):
  File "/usr/lib/python3.11/multiprocessing/util.py", line 303, in _run_finalizers
    finalizer()
  File "/usr/lib/python3.11/multiprocessing/util.py", line 227, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/util.py", line 136, in _remove_temp_dir
    rmtree(tempdir, onerror=onerror)
  File "/usr/lib/python3.11/shutil.py", line 763, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/usr/lib/python3.11/shutil.py", line 761, in rmtree
    os.rmdir(path, dir_fd=dir_fd)
OSError: [Errno 39] Directory not empty: '/tmp/pymp-ej936t7z'


实验: Rest-Wake (原始架构), Epoch [2/50], Train Loss: 1.3588, Val Loss: 1.2360, Val Acc: 0.5663, GPU峰值内存: 467.61 MB
实验: Rest-Wake (原始架构), Epoch [3/50], Train Loss: 1.1477, Val Loss: 1.1590, Val Acc: 0.5896, GPU峰值内存: 465.61 MB
实验: Rest-Wake (原始架构), Epoch [4/50], Train Loss: 1.0196, Val Loss: 0.9211, Val Acc: 0.6725, GPU峰值内存: 465.99 MB
实验: Rest-Wake (原始架构), Epoch [5/50], Train Loss: 0.9331, Val Loss: 0.8928, Val Acc: 0.6875, GPU峰值内存: 466.74 MB
实验: Rest-Wake (原始架构), Epoch [6/50], Train Loss: 0.8625, Val Loss: 0.7983, Val Acc: 0.7265, GPU峰值内存: 465.69 MB
实验: Rest-Wake (原始架构), Epoch [7/50], Train Loss: 0.8066, Val Loss: 0.7880, Val Acc: 0.7215, GPU峰值内存: 466.74 MB
实验: Rest-Wake (原始架构), Epoch [8/50], Train Loss: 0.7903, Val Loss: 0.7880, Val Acc: 0.7215, GPU峰值内存: 466.74 MB
实验: Rest-Wake (原始架构), Epoch [9/50], Train Loss: 0.7531, Val Loss: 0.8095, Val Acc: 0.7206, GPU峰值内存: 465.69 MB
实验: Rest-Wake (原始架构), Epoch [10/50], Train Loss: 0.7191, Val Loss: 0.7280, Val Acc: 0.7493, GPU峰值内存: 465.24 MB
实验: Rest-

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig("training_curves.png")
  plt.savefig("training_curves.png")
  plt.savefig("training_curves.png")
  plt.savefig("training_curves.png")
  plt.savefig("training_curves.png")
  plt.savefig("training_curves.png")
