In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report
import os
import random

# Seed every random number generator the notebook uses so runs are reproducible.
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also forces cuDNN into deterministic mode and disables its autotuner,
    trading some speed for run-to-run reproducibility.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [2]:
# 1. Data preparation
def prepare_data(batch_size=128, val_split=0.2, seed=42):
    """Build MNIST train / validation / test DataLoaders.

    The MNIST training split is randomly divided into train and validation
    subsets (``val_split`` is the validation fraction); the test split is
    used whole. ``set_seed(seed)`` is called first so the random split is
    reproducible.

    Args:
        batch_size: Batch size shared by all three loaders.
        val_split: Fraction of the training split held out for validation.
        seed: Seed passed to ``set_seed`` before splitting.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``; only the training
        loader shuffles.
    """
    set_seed(seed)

    # Convert to tensors, then normalise with the mean/std constants used here.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    full_train = datasets.MNIST('./data', train=True, download=True, transform=preprocess)
    test_set = datasets.MNIST('./data', train=False, transform=preprocess)

    n_total = len(full_train)
    n_train = int((1 - val_split) * n_total)
    n_val = n_total - n_train
    train_set, val_set = torch.utils.data.random_split(full_train, [n_train, n_val])

    def make_loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    return make_loader(train_set, True), make_loader(val_set, False), make_loader(test_set, False)

In [3]:
# 2. Model definition
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    Row ``pos`` of the table holds ``sin(pos * w_i)`` in the even channels and
    ``cos(pos * w_i)`` in the odd channels, where
    ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Feature dimension of the inputs (odd values are supported).
        max_len: Longest sequence the table covers (784 by default).
    """

    def __init__(self, d_model, max_len=784):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one fewer odd channel than even channel,
        # so only the first d_model // 2 frequencies get a cosine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A buffer follows .to(device) and is stored in state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x + pe[:seq_len]`` for ``x`` shaped (batch, seq_len, d_model).

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len, :]

def create_transformer_model(input_dim=1, d_model=32, nhead=4, num_layers=6, num_classes=10, dropout=0.1, max_len=784):
    """Build a Transformer-encoder classifier for flattened image inputs.

    Each sample is flattened and split into tokens of ``input_dim``
    consecutive values (``input_dim=1`` gives one token per pixel). The
    model then:
      - linearly embeds the tokens and scales them by ``sqrt(d_model)``;
      - adds sinusoidal position encodings;
      - passes them through ``num_layers`` encoder layers;
      - mean-pools over the sequence and classifies with a small GELU MLP.

    Args:
        input_dim: Values per token. The flattened per-sample size must be
            divisible by it.
        d_model: Embedding width; ``nn.MultiheadAttention`` requires it to be
            divisible by ``nhead``.
        nhead: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        num_classes: Number of output logits.
        dropout: Dropout rate used in the encoder layers and the classifier head.
        max_len: Longest token sequence the positional encoding covers.

    Returns:
        A newly initialised ``nn.Module`` mapping a batch to logits of shape
        ``(batch, num_classes)``.
    """
    class TransformerModel(nn.Module):
        """Embedding -> positional encoding -> TransformerEncoder -> mean pool -> MLP head."""

        def __init__(self):
            super(TransformerModel, self).__init__()

            # Per-token linear embedding: input_dim -> d_model.
            self.embedding = nn.Linear(input_dim, d_model)

            # Fixed (non-learned) sinusoidal position encodings.
            self.pos_encoder = PositionalEncoding(d_model, max_len=max_len)

            # Transformer encoder layer.
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=4 * d_model,
                dropout=dropout,
                activation='gelu',
                batch_first=True,  # tensors are (batch, seq, feature)
            )

            # Stack of num_layers copies of the layer above.
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

            # Classification head.
            self.classifier = nn.Sequential(
                nn.Linear(d_model, 2 * d_model),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(2 * d_model, num_classes),
            )

            self.d_model = d_model
            self.input_dim = input_dim

        def forward(self, x):
            batch_size = x.size(0)
            # Tokenise into (batch, seq_len, input_dim), e.g. (batch, 784, 1)
            # for 28x28 images with input_dim=1. The last dim must equal
            # input_dim to match self.embedding; reshape (unlike view) also
            # accepts non-contiguous inputs.
            x = x.reshape(batch_size, -1, self.input_dim)

            # Embedding (scaled by sqrt(d_model)) plus positional encoding.
            x = self.embedding(x) * np.sqrt(self.d_model)
            x = self.pos_encoder(x)

            # Transformer encoder.
            x = self.transformer_encoder(x)

            # Global average pooling over the sequence dimension.
            x = x.mean(dim=1)

            # Classification.
            logits = self.classifier(x)

            return logits

    return TransformerModel()

In [4]:
# 3. Training and validation
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        train_loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss function. Assumed to return the *mean* loss over the
            batch, as ``nn.CrossEntropyLoss`` does by default (TODO confirm if
            a different reduction is used).
        optimizer: Optimiser stepped once per batch.
        device: Device each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``: per-sample mean training loss and accuracy in
        percent, accumulated while the weights were being updated. Both are
        ``nan`` if the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        n = target.size(0)
        # Weight each batch mean by its size so a short final batch does not skew the average.
        loss_sum += loss.item() * n
        correct += output.argmax(dim=1).eq(target).sum().item()
        total += n

    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, 100. * correct / total

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    Puts the model in eval mode and disables gradient tracking.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        val_loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss function. Assumed to return the *mean* loss over the
            batch, as ``nn.CrossEntropyLoss`` does by default (TODO confirm if
            a different reduction is used).
        device: Device each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``: per-sample mean loss and accuracy in percent.
        Both are ``nan`` if the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)
            loss = criterion(output, target)

            n = target.size(0)
            # Weight by batch size: this value drives checkpointing and LR
            # scheduling, so a short last batch must not bias it.
            loss_sum += loss.item() * n
            correct += output.argmax(dim=1).eq(target).sum().item()
            total += n

    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, 100. * correct / total

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, epochs, model_path='best_model.pth'):
    """Train for ``epochs`` epochs, checkpointing the lowest-validation-loss weights.

    After each epoch, this function:
      - validates the model and steps the scheduler;
      - prints a summary;
      - writes ``model.state_dict()`` to ``model_path`` whenever the
        validation loss beats the best seen so far.

    On return the model holds its final-epoch weights, not the best ones;
    reload ``model_path`` to get the best checkpoint.

    Args:
        model: Model to train (updated in place).
        train_loader: Training batches.
        val_loader: Validation batches.
        criterion: Loss function.
        optimizer: Optimiser used for the updates.
        scheduler: LR scheduler or ``None``. ``ReduceLROnPlateau`` is stepped
            with the validation loss; any other scheduler is stepped once per
            epoch with no argument.
        device: Device batches are moved to.
        epochs: Number of epochs to run.
        model_path: Destination file for the best state dict.

    Returns:
        ``(train_losses, val_losses, train_accuracies, val_accuracies)``, one
        entry per epoch.
    """
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)

        # Adjust the learning rate.
        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()

        # Read the LR from the optimiser: ReduceLROnPlateau has no get_last_lr()
        # in older PyTorch releases, and this also works without a scheduler.
        current_lr = optimizer.param_groups[0]['lr']
        print(f'Epoch {epoch}: Current learning rate is {current_lr}')

        print(f'Epoch: {epoch}/{epochs}')
        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

        # Save the best model so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), model_path)
            print('Model saved!\n')

    return train_losses, val_losses, train_accuracies, val_accuracies

In [5]:
# 4. Evaluation
def evaluate_model(model, test_loader, device):
    """Score ``model`` on every batch of ``test_loader``.

    Runs in eval mode without gradient tracking and takes the arg-max logit
    as each prediction; scoring is delegated to scikit-learn.

    Returns:
        ``(accuracy, report)``: accuracy as a fraction in [0, 1] and the text
        produced by ``classification_report`` (per-class precision/recall/F1).
    """
    model.eval()
    predictions, references = [], []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            class_ids = model(inputs).max(1)[1]
            predictions.extend(class_ids.cpu().numpy())
            references.extend(labels.cpu().numpy())

    accuracy = accuracy_score(references, predictions)
    return accuracy, classification_report(references, predictions)

In [6]:
# 5. Visualisation
def plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies, save_path='training_curves.png'):
    """Save side-by-side loss and accuracy curves to ``save_path``.

    Epochs are plotted 1-based, matching the ``Epoch: k/N`` lines printed by
    ``train_model``. The figure is closed after saving, so nothing is shown
    inline in the notebook.

    Args:
        train_losses: Per-epoch training loss.
        val_losses: Per-epoch validation loss.
        train_accuracies: Per-epoch training accuracy (percent).
        val_accuracies: Per-epoch validation accuracy (percent).
        save_path: Output image path (default ``'training_curves.png'``).
    """
    def one_based(values):
        # x-coordinates 1..len(values) for a per-epoch series
        return range(1, len(values) + 1)

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

    ax_loss.plot(one_based(train_losses), train_losses, label='Train Loss')
    ax_loss.plot(one_based(val_losses), val_losses, label='Validation Loss')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.legend()
    ax_loss.set_title('Training and Validation Loss')

    ax_acc.plot(one_based(train_accuracies), train_accuracies, label='Train Accuracy')
    ax_acc.plot(one_based(val_accuracies), val_accuracies, label='Validation Accuracy')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy (%)')
    ax_acc.legend()
    ax_acc.set_title('Training and Validation Accuracy')

    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)

In [7]:
# 6. Prediction export
def generate_samples(model, test_loader, num_samples=5000, device='cpu', save_path='generated_samples', batch_size=256):
    """Predict on a random subset of ``test_loader`` and save the results.

    Nothing new is generated:
      - ``num_samples`` existing examples are drawn without replacement
        (fewer if the loader holds fewer), using the global torch RNG;
      - the model's arg-max predictions are computed in mini-batches of
        ``batch_size`` to bound memory;
      - three ``.npy`` files are written: ``{save_path}_images.npy``,
        ``{save_path}_predicted_labels.npy`` and ``{save_path}_true_labels.npy``.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        test_loader: Iterable of ``(inputs, targets)`` batches; it is fully
            materialised in CPU memory.
        num_samples: Maximum number of examples to draw.
        device: Device used for inference.
        save_path: Filename prefix for the three output files.
        batch_size: Number of examples per inference forward pass.

    Returns:
        ``(images, predicted_labels, true_labels)`` as NumPy arrays.
    """
    model.eval()

    # Gather the whole loader on the CPU so a random subset can be indexed.
    data_batches, target_batches = [], []
    for data, target in test_loader:
        data_batches.append(data)
        target_batches.append(target)
    all_data = torch.cat(data_batches, dim=0)
    all_targets = torch.cat(target_batches, dim=0)

    # Random selection of examples.
    indices = torch.randperm(len(all_data))[:num_samples]
    samples = all_data[indices]
    true_labels = all_targets[indices]

    # Predict in chunks: one forward pass over thousands of long sequences
    # can exhaust device memory in the attention layers.
    with torch.no_grad():
        chunk_preds = [
            model(samples[start:start + batch_size].to(device)).argmax(dim=1).cpu()
            for start in range(0, len(samples), batch_size)
        ]
    predicted = torch.cat(chunk_preds) if chunk_preds else torch.empty(0, dtype=torch.long)

    generated_samples = samples.numpy()
    generated_labels = predicted.numpy()
    true_labels = true_labels.numpy()

    # Persist the selected examples with their predicted and true labels.
    np.save(f'{save_path}_images.npy', generated_samples)
    np.save(f'{save_path}_predicted_labels.npy', generated_labels)
    np.save(f'{save_path}_true_labels.npy', true_labels)

    return generated_samples, generated_labels, true_labels

In [8]:
# Main entry point
def main(epochs, batch_size=32, lr=1e-4, seed=42, model_path='best_transformer_mnist_model.pth'):
    """Train, evaluate and export predictions for the MNIST Transformer.

    Steps:
      - seed the RNGs and build the data loaders;
      - train with Adam plus ReduceLROnPlateau, checkpointing the best
        validation loss to ``model_path``;
      - save the training curves;
      - reload the best checkpoint, print test accuracy and the
        classification report;
      - export predictions on a random test subset.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size for all data loaders (default 32).
        lr: Initial Adam learning rate (default 1e-4).
        seed: Seed for ``set_seed`` and the train/validation split (default 42).
        model_path: Checkpoint file for the best weights.
    """
    # Seed all RNGs.
    set_seed(seed)

    # Device configuration.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}\n")

    # Data.
    train_loader, val_loader, test_loader = prepare_data(batch_size=batch_size, seed=seed)

    # Model.
    model = create_transformer_model().to(device)

    # Loss and optimiser.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    # Halve the LR once validation loss has failed to improve for more than
    # 3 consecutive epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    # Training.
    train_losses, val_losses, train_accuracies, val_accuracies = train_model(
        model, train_loader, val_loader, criterion, optimizer, scheduler, device, epochs, model_path=model_path
    )

    # Training curves.
    plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies)

    # Reload the best checkpoint. weights_only=True refuses arbitrary pickled
    # objects (and avoids the torch.load FutureWarning); map_location lets a
    # checkpoint saved on one device load on the current one.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))

    # Test-set evaluation.
    test_accuracy, test_report = evaluate_model(model, test_loader, device)
    print(f'Test Accuracy: {test_accuracy:.4f}')
    print('Classification Report:')
    print(test_report)

    # Export predictions on a random test subset.
    generate_samples(model, test_loader, num_samples=5000, device=device)
    print("Training and evaluation completed!")

In [9]:
if __name__ == "__main__":
    # Full 50-epoch run; inside a Jupyter kernel __name__ is also "__main__".
    main(epochs=50)

Using device: cuda





Epoch: 1/50
Train Loss: 1.9939 | Train Acc: 25.04%
Val Loss: 1.8013 | Val Acc: 30.63%
Model saved!

Epoch: 2/50
Train Loss: 1.7121 | Train Acc: 33.49%
Val Loss: 1.5409 | Val Acc: 41.43%
Model saved!

Epoch: 3/50
Train Loss: 1.4744 | Train Acc: 43.06%
Val Loss: 1.3633 | Val Acc: 48.00%
Model saved!

Epoch: 4/50
Train Loss: 1.3295 | Train Acc: 50.04%
Val Loss: 1.3123 | Val Acc: 52.24%
Model saved!

Epoch: 5/50
Train Loss: 1.2132 | Train Acc: 54.92%
Val Loss: 1.2615 | Val Acc: 54.26%
Model saved!

Epoch: 6/50
Train Loss: 1.0876 | Train Acc: 60.84%
Val Loss: 1.5350 | Val Acc: 47.88%
Epoch: 7/50
Train Loss: 0.9947 | Train Acc: 64.24%
Val Loss: 1.4455 | Val Acc: 51.42%
Epoch: 8/50
Train Loss: 0.9236 | Train Acc: 66.69%
Val Loss: 1.3244 | Val Acc: 54.26%
Epoch: 9/50
Train Loss: 0.8808 | Train Acc: 68.13%
Val Loss: 0.9467 | Val Acc: 65.76%
Model saved!

Epoch: 10/50
Train Loss: 0.8286 | Train Acc: 70.22%
Val Loss: 0.8375 | Val Acc: 70.04%
Model saved!

Epoch: 11/50
Train Loss: 0.7841 | Train A

  model.load_state_dict(torch.load('best_transformer_mnist_model.pth'))


Test Accuracy: 0.9265
Classification Report:
              precision    recall  f1-score   support

           0       0.95      0.96      0.96       980
           1       0.99      0.97      0.98      1135
           2       0.92      0.93      0.93      1032
           3       0.91      0.87      0.89      1010
           4       0.92      0.93      0.93       982
           5       0.87      0.87      0.87       892
           6       0.93      0.94      0.94       958
           7       0.96      0.92      0.94      1028
           8       0.91      0.94      0.93       974
           9       0.88      0.91      0.89      1009

    accuracy                           0.93     10000
   macro avg       0.93      0.93      0.93     10000
weighted avg       0.93      0.93      0.93     10000

Training and evaluation completed!
