# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision import datasets, transforms
from sklearn.metrics import confusion_matrix
import torch.nn.functional as F
import os

# Load and preprocess MNIST: every image is converted to a tensor and then
# normalized with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Build the training split first and the test split second, both with the
# same root folder, download flag and transform.
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training batches are shuffled; the test order stays fixed.
train_loader = data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
test_loader = data.DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
)

# Device parameters (these values are used to generate the mixed datasets).
# Each form has one LH weight and one RH weight.
R_form_LH, R_form_RH = 0.962, 0.0383
S_form_LH, S_form_RH = 0.0492, 0.951

# The RH intensity is the complement of the LH intensity, so the two sum to 1.
intensity_LH = 0.5
intensity_RH = 1 - intensity_LH

# Build a mixed dataset: every image is blended with its successor using the
# given device parameters, and each blend receives one of the two source labels.
def create_device_mixed_data(x_data, y_data, LH_param, RH_param, label_type='label1',
                             intensity_lh=None, intensity_rh=None):
    """Blend each image with the next one (cyclically) and label the result.

    Sample ``i`` is combined with sample ``i + 1``; the last sample wraps
    around and is combined with the first one.  Every blend is then divided
    by its own maximum value.

    Args:
        x_data: NumPy array of images with the samples along the first axis.
        y_data: Labels, one per sample in ``x_data``.
        LH_param: Weight applied to the first image of each pair.
        RH_param: Weight applied to the second image of each pair.
        label_type: ``'label1'`` keeps the label of the first image of the
            pair; any other value keeps the label of the second image.
        intensity_lh: Additional factor for the first image.  ``None``
            (default) falls back to the module-level ``intensity_LH``.
        intensity_rh: Additional factor for the second image.  ``None``
            (default) falls back to the module-level ``intensity_RH``.

    Returns:
        A tuple ``(mixed_images, mixed_labels)`` of NumPy arrays with the same
        number of samples as ``x_data``.  Empty input yields empty arrays.
    """
    if intensity_lh is None:
        intensity_lh = intensity_LH
    if intensity_rh is None:
        intensity_rh = intensity_RH

    first = np.asarray(x_data)
    labels = np.asarray(y_data)

    # Rolling by -1 along the sample axis places sample i + 1 at position i
    # and moves sample 0 to the end, which covers the wrap-around pair too.
    second = np.roll(first, -1, axis=0)

    mixed = first * LH_param * intensity_lh + second * RH_param * intensity_rh

    # Per-image maximum, kept broadcastable against the image axes.
    peaks = mixed.max(axis=tuple(range(1, mixed.ndim)), keepdims=True)
    # An all-zero blend has a maximum of 0; dividing by 1 instead keeps it at
    # zero rather than turning it into NaN.
    mixed = mixed / np.where(peaks == 0, 1, peaks)

    if label_type == 'label1':
        mixed_labels = labels.copy()
    else:
        mixed_labels = np.roll(labels, -1, axis=0)

    return mixed, mixed_labels

# Build the R-form and S-form test sets from the device parameters.  The raw
# test images and labels are fetched once and reused for all three calls.
_raw_test_images = test_dataset.data.numpy()
_raw_test_labels = test_dataset.targets.numpy()

# R-form: labelled with the first image of each pair.
x_test_R_form, y_test_R_form = create_device_mixed_data(
    _raw_test_images,
    _raw_test_labels,
    R_form_LH,
    R_form_RH,
    label_type='label1',
)

# S-form: labelled with the second image of each pair.
x_test_S_form, y_test_S_form = create_device_mixed_data(
    _raw_test_images,
    _raw_test_labels,
    S_form_LH,
    S_form_RH,
    label_type='label2',
)

# Mixed set built without device parameters (both weights 1.0), for comparison.
x_test_mixed, y_test_mixed = create_device_mixed_data(
    _raw_test_images,
    _raw_test_labels,
    1.0,
    1.0,
    label_type='label1',
)

# Model definition (CNN architecture).
class CNN(nn.Module):
    """Five-layer convolutional classifier for single-channel 28x28 images.

    The convolutions are grouped as 64-64 / 128-128 / 256 channels, with one
    2x2 max-pool closing each group.  For a 28x28 input the spatial size goes
    28 -> 14 -> 7 -> 3, which yields the 256 * 3 * 3 features that ``fc1``
    expects.  Pooling after every convolution (five pools) would shrink the
    feature map to zero size and fail at runtime.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(256 * 3 * 3, 512)
        self.fc2 = nn.Linear(512, 10)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Registered but not applied in forward(); kept so that the module's
        # parameter/state-dict keys stay unchanged.
        self.bn = nn.BatchNorm2d(256)

    def forward(self, x):
        """Return class logits of shape (N, 10) for input of shape (N, 1, 28, 28)."""
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))  # 28x28 -> 14x14
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))  # 14x14 -> 7x7
        x = self.pool(F.relu(self.conv5(x)))  # 7x7 -> 3x3
        # Flatten everything except the batch axis, so an unexpected spatial
        # size raises in fc1 instead of silently changing the batch size.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Instantiate the network.
model = CNN()

# Loss function and optimizer: cross-entropy, minimized by Adam with a
# learning rate of 2e-6.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=2e-6,
)

# Train the model and evaluate it after every epoch.
epochs = 20
accuracy_data = np.empty((epochs, 5))

# Create the output directory before training so that a missing folder cannot
# make the final save fail after all epochs have already run.
os.makedirs('results', exist_ok=True)


def _predict_array(images, batch_size=256):
    """Return the predicted class index for every image in a NumPy array.

    Args:
        images: Array of shape (N, H, W) with values in [0, 1], as produced
            by ``create_device_mixed_data``.
        batch_size: Number of images sent through the model at once.

    Returns:
        NumPy array holding N predicted class indices.
    """
    # Conv2d expects (N, C, H, W): add the single grayscale channel axis.
    inputs = torch.as_tensor(images).float().unsqueeze(1)
    # Apply the same Normalize((0.5,), (0.5,)) step that the DataLoader images
    # receive, so the model sees the value range it was trained on.
    inputs = (inputs - 0.5) / 0.5

    batch_predictions = []
    # Gradients are not needed for evaluation; mini-batches keep peak memory
    # bounded instead of pushing the whole set through at once.
    with torch.no_grad():
        for start in range(0, inputs.shape[0], batch_size):
            logits = model(inputs[start:start + batch_size])
            batch_predictions.append(logits.argmax(dim=1))
    return torch.cat(batch_predictions).cpu().numpy()


for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")

    # One pass over the training set.
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Evaluate on the original test set.  The predictions of the most recent
    # epoch are kept in pred_test / true_test for the confusion matrix.
    model.eval()
    correct = 0
    total = 0
    pred_test = []
    true_test = []
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            pred_test.extend(predicted.cpu().numpy())
            true_test.extend(labels.cpu().numpy())

    test_acc = correct / total

    # Evaluate on the mixed dataset.
    pred_mixed = _predict_array(x_test_mixed)
    test_acc_mixed = float(np.mean(pred_mixed == y_test_mixed))

    # Evaluate on the R-form dataset.
    pred_R_form = _predict_array(x_test_R_form)
    test_acc_R_form = float(np.mean(pred_R_form == y_test_R_form))

    # Evaluate on the S-form dataset.
    pred_S_form = _predict_array(x_test_S_form)
    test_acc_S_form = float(np.mean(pred_S_form == y_test_S_form))

    # Store this epoch's accuracies (column 0 is the 1-based epoch number).
    accuracy_data[epoch] = [epoch + 1, test_acc, test_acc_mixed, test_acc_R_form, test_acc_S_form]

    # Print the accuracy on each dataset.
    print(f"Test Accuracy (Original): {test_acc:.4f}")
    print(f"Test Accuracy (Mixed): {test_acc_mixed:.4f}")
    print(f"Test Accuracy (R-form): {test_acc_R_form:.4f}")
    print(f"Test Accuracy (S-form): {test_acc_S_form:.4f}")

# Save the accuracy table in CSV format.
np.savetxt('results/MNIST_accuracy_data.csv', accuracy_data, fmt='%1.4f', delimiter=',')

# Save a confusion matrix both as a heat-map image and as numeric data.
def save_confusion_matrix(y_true, y_pred, filename_image, filename_data):
    """Compute the confusion matrix of ``y_pred`` against ``y_true`` and store it.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        filename_image: Path the heat-map figure is written to.
        filename_data: Path the matrix is written to as comma-separated integers.
    """
    matrix = confusion_matrix(y_true, y_pred)

    # Heat-map: draw on a fresh 10x10 figure, label both axes, write to disk.
    figure = plt.figure(figsize=(10, 10))
    axes = sns.heatmap(matrix, square=True, annot=False, fmt='d', cbar=False, cmap=plt.cm.Blues)
    axes.set_xlabel('Predicted Label')
    axes.set_ylabel('True Label')
    figure.savefig(filename_image)
    # Close the figure explicitly to release its resources.
    plt.close(figure)

    # Numeric data (CSV).
    np.savetxt(filename_data, matrix, fmt='%d', delimiter=',')

# Save the confusion matrices using the predictions from the last epoch.
# Every output (image and CSV) goes into the results/ folder; two of the CSV
# paths were previously missing the '/' separator and landed in the working
# directory under a "resultsMNIST_..." name.
save_confusion_matrix(np.array(true_test), np.array(pred_test), 'results/MNIST_confusion_matrix_original.png',
                      'results/MNIST_confusion_matrix_original.csv')
save_confusion_matrix(y_test_mixed.flatten(), pred_mixed, 'results/MNIST_confusion_matrix_mixed.png',
                      'results/MNIST_confusion_matrix_mixed.csv')
save_confusion_matrix(y_test_R_form.flatten(), pred_R_form, 'results/MNIST_confusion_matrix_R_form.png',
                      'results/MNIST_confusion_matrix_R_form.csv')
save_confusion_matrix(y_test_S_form.flatten(), pred_S_form, 'results/MNIST_confusion_matrix_S_form.png',
                      'results/MNIST_confusion_matrix_S_form.csv')
