In [6]:
from data.pyg_dataToGraph import DataToGraph
import torch.nn.functional as F
from matplotlib import pyplot as plt
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
from src.gaussian_diffusion import GaussianDiffusion
from src.classifier import UNetClassifier

In [7]:
# TODO: load the dataset
# Each item yielded by the dataset exposes node features `.x` and a label `.y`.
dataset = DataToGraph(
    raw_data_path='../data/raw',
    dataset_name='TFF' + '.mat',
)

input_dim = dataset[0].x.size(1)
num_classes = dataset.num_classes

# Walk the dataset once, collecting the feature matrices and labels in parallel.
x0, labels = zip(*((graph.x, graph.y) for graph in dataset))

# Stack the per-graph feature matrices and insert a channel axis at dim 1.
# Printed shape for this data: [num_samples, 1, 24, 50].
x0 = torch.stack(x0).unsqueeze(1)

# Stack the per-graph labels. Each `y` keeps its own length-1 axis, so the
# printed shape is [num_samples, 1], not a flat [num_samples] vector.
labels = torch.stack(labels)

print(num_classes)
print("X0 shape:", x0.shape)
print("Labels shape:", labels.shape)

Processing...


7
X0 shape: torch.Size([2368, 1, 24, 50])
Labels shape: torch.Size([2368, 1])


Done!


In [8]:
print("Labels shape:", labels.shape)

# Pick the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the classifier and place it on `device`. The training loop moves every
# batch (and the sampled timesteps) to `device`, so a model left on the CPU
# fails with a device-mismatch error as soon as CUDA is available.
classifier = UNetClassifier(num_classes).to(device)

# Wrap the tensors for shuffled mini-batch iteration.
# x0: [N, 1, 24, 50], labels: [N, 1] (shapes as printed when the data was loaded).
dataset = TensorDataset(x0, labels)
dataloader = DataLoader(dataset,
                        batch_size=64,
                        shuffle=True,
                        # Pinned host memory only speeds up host-to-GPU copies;
                        # skip it on CPU-only runs.
                        pin_memory=device.type == "cuda")

Labels shape: torch.Size([2368, 1])


In [9]:
# Training hyperparameters.
num_epochs = 1000
lr = 3e-4
grad_clip = 1.0  # max gradient norm handed to clip_grad_norm_ in the training loop
save_interval = 50  # write a full checkpoint every 50 epochs

# Diffusion helper; the training loop calls its `q_sample` to noise each batch.
diffusion = GaussianDiffusion(num_steps=64)

# Optimizer and cosine learning-rate schedule spanning the whole run.
optimizer = AdamW(classifier.parameters(), lr=lr)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

# NOTE: `x0` and `labels` are deliberately NOT rebound to device copies here.
# The TensorDataset already holds the CPU tensors and the training loop moves
# each mini-batch to `device` itself, so a second full copy on the GPU would be
# unused memory. Rebinding would also make the dataset/DataLoader cell fail on
# re-run, because pinned-memory loading cannot pin CUDA tensors.

best_loss = float('inf')
train_losses = []


In [10]:
# Train the classifier on diffusion-noised inputs, tracking per-epoch loss and
# accuracy, then plot both curves.
# NOTE(review): "best" is selected on *training* accuracy over noisy samples;
# there is no held-out split in this notebook.
acc_history = []
best_acc = 0.0

for epoch in range(num_epochs):
    classifier.train()
    epoch_loss = 0.0
    correct = 0
    total = 0

    with tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as pbar:
        for x_batch, label_batch in pbar:
            # Move the batch to the compute device.
            x_batch = x_batch.to(device)  # [B, 1, 24, 50]
            # Flatten labels to [B]; unlike squeeze(1) this accepts both
            # [B, 1] and already-flat [B] label tensors.
            label_batch = label_batch.to(device).reshape(-1)
            batch_size = x_batch.size(0)

            # Draw one random diffusion timestep per sample.
            t = torch.randint(0, diffusion.num_steps, (batch_size,), device=device).long()

            # Noise the clean inputs according to the sampled timesteps.
            noisy_batch = diffusion.q_sample(x_start=x_batch, t=t)

            # The classifier is conditioned on the timestep as well as the input.
            logits = classifier(noisy_batch, t)
            loss = F.cross_entropy(logits, label_batch)

            # Backward pass with gradient-norm clipping.
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(classifier.parameters(), grad_clip)
            optimizer.step()

            # Running metrics. Read the scalar loss once: every .item() call
            # forces a device synchronisation.
            batch_loss = loss.item()
            epoch_loss += batch_loss * batch_size
            predicted = logits.argmax(dim=1)
            correct += (predicted == label_batch).sum().item()
            total += label_batch.size(0)

            pbar.set_postfix({
                'loss': f"{batch_loss:.4f}",
                'acc': f"{correct / total:.2%}"
            })

    # Epoch-level metrics, normalised by the number of samples actually seen
    # rather than by a global dataset name that other cells rebind.
    avg_loss = epoch_loss / total
    epoch_acc = correct / total
    train_losses.append(avg_loss)
    acc_history.append(epoch_acc)

    # Capture the learning rate used during this epoch *before* stepping the
    # scheduler, so the log line reports the rate that was actually applied.
    epoch_lr = scheduler.get_last_lr()[0]
    scheduler.step()

    # Keep the weights with the best (training) accuracy so far.
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        torch.save(classifier.state_dict(), 'best_noisy_classifier.pth')
        print(f"✅ 保存最佳模型 | 准确率: {best_acc:.2%}")

    # Periodic full checkpoint. The scheduler state is included so a resumed
    # run continues the cosine schedule instead of restarting it.
    if (epoch + 1) % save_interval == 0:
        torch.save({
            'epoch': epoch,
            'model_state': classifier.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict(),
            'loss': avg_loss,
            'accuracy': epoch_acc
        }, f'classifier_checkpoint_epoch_{epoch + 1}.pth')

    print(f"Epoch {epoch + 1:03d}/{num_epochs} | "
          f"Loss: {avg_loss:.4f} | "
          f"Acc: {epoch_acc:.2%} | "
          f"LR: {epoch_lr:.2e}")

# Training curves: loss on the left, accuracy on the right.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(train_losses, label='Training Loss')
loss_ax.set_title("Loss Curve")
loss_ax.set_xlabel("Epoch")
loss_ax.legend()

acc_ax.plot(acc_history, label='Accuracy', color='orange')
acc_ax.set_title("Accuracy Curve")
acc_ax.set_xlabel("Epoch")
acc_ax.legend()

fig.tight_layout()
fig.savefig('classifier_training_curves.png')
plt.show()

print("🎉 训练完成!")


Epoch 1/1000: 100%|██████████| 37/37 [04:32<00:00,  7.38s/batch, loss=1.8201, acc=18.71%]


✅ 保存最佳模型 | 准确率: 18.71%
Epoch 001/1000 | Loss: 2.0185 | Acc: 18.71% | LR: 3.00e-04


Epoch 2/1000:  24%|██▍       | 9/37 [01:16<03:57,  8.48s/batch, loss=1.9588, acc=21.70%]


KeyboardInterrupt: 