# ResNet-18 on Fashion-MNIST (Subset Experiment)

本实验使用 **ResNet-18 预训练模型** 对 **Fashion-MNIST** 数据集进行分类，
训练集与测试集均仅随机抽取 **1/n 子集数据**（默认 n = 5），并绘制测试集的 Loss / Accuracy 曲线与混淆矩阵。

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import confusion_matrix
from torch.utils.data import Subset
from torchsummary import summary
from torchvision import datasets, transforms, models

## 1. 基本设置

In [None]:
# Run configuration: number of passes over the training subset, mini-batch
# size, and the compute device (CUDA when available, otherwise CPU).
epochs = 10
batch_size = 16

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('Using device:', device)

## 2. 数据集（Fashion-MNIST，1/n 子集）

In [None]:
# Channel statistics used to normalize inputs for torchvision's
# ImageNet-pretrained classification weights. The ResNet-18 created in the
# model cell loads such weights, so inputs should be normalized the same way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing: upscale to 224 px, replicate the single grayscale channel to
# 3 channels (ResNet expects 3-channel input), convert to a [0, 1] tensor,
# then apply the ImageNet normalization.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_full = datasets.FashionMNIST(
    root='data', train=True, download=True, transform=transform)
test_full = datasets.FashionMNIST(
    root='data', train=False, download=True, transform=transform)

n = 5  # keep 1/n of each split
rng = np.random.default_rng(42)  # fixed seed so the subset is reproducible

# Sample without replacement so that no example appears twice in a subset.
train_idx = rng.choice(len(train_full), len(train_full)//n, replace=False)
test_idx = rng.choice(len(test_full), len(test_full)//n, replace=False)

train_dataset = Subset(train_full, train_idx)
test_dataset = Subset(test_full, test_idx)

# Shuffle only the training data; keep the evaluation order fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False)

print(f'Train samples: {len(train_dataset)}')
print(f'Test samples : {len(test_dataset)}')

## 3. ResNet-18 模型

In [None]:
# Build ResNet-18 from the default pretrained weights, then swap the final
# fully connected layer for a fresh 10-output linear head.
pretrained_weights = models.ResNet18_Weights.DEFAULT
model = models.resnet18(weights=pretrained_weights)
head_in_features = model.fc.in_features
model.fc = nn.Linear(head_in_features, 10)
model = model.to(device)

# Adam over every model parameter (backbone and the new head alike).
optimizer = optim.Adam(params=model.parameters(), lr=1e-5)

# Model structure summary for a 3x224x224 input.
summary(model, input_size=(3, 224, 224))

## 4. 训练与测试

In [None]:
# Per-epoch test accuracy and test loss, collected for the curve plots.
accs, losses = [], []

for epoch in range(epochs):
    # ---------- Train ----------
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        out = model(x)
        loss = F.cross_entropy(out, y)
        loss.backward()
        optimizer.step()

    # ---------- Test ----------
    model.eval()
    correct = 0
    total_loss = 0.0
    num_samples = 0

    # Rebuilt every epoch, so after the loop these hold the predictions and
    # labels of the final epoch only.
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)

            # Accumulate the summed per-sample loss instead of the batch mean:
            # averaging batch means would over-weight a smaller final batch.
            total_loss += F.cross_entropy(out, y, reduction='sum').item()
            preds = out.argmax(1)

            correct += (preds == y).sum().item()
            num_samples += y.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())

    # Both metrics are normalized by the number of samples actually evaluated.
    acc = correct / num_samples
    avg_loss = total_loss / num_samples

    accs.append(acc)
    losses.append(avg_loss)

    print(f"Epoch [{epoch+1}/{epochs}]  Loss: {avg_loss:.4f}  Acc: {acc:.4f}")

## 5. Loss & Accuracy 曲线

In [None]:
# Side-by-side curves of the per-epoch test loss and test accuracy.
plt.figure(figsize=(10, 4))

# Left panel: test loss per epoch.
plt.subplot(1, 2, 1)
plt.plot(losses, marker='o')
plt.title("Test Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")

# Right panel: test accuracy per epoch.
plt.subplot(1, 2, 2)
plt.plot(accs, marker='o')
plt.title("Test Accuracy Curve")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")

plt.tight_layout()
# savefig does not create missing directories; make sure "res/" exists so
# the figure is not lost to a FileNotFoundError after training has finished.
os.makedirs("res", exist_ok=True)
plt.savefig("res/resnet18_loss_acc.png")
plt.show()

## 6. 混淆矩阵

In [None]:
# Confusion matrix over the 10 classes, built from the predictions and labels
# left behind by the evaluation pass of the last training epoch.
num_classes = 10
class_names = [str(i) for i in range(num_classes)]

# Pin the label set explicitly: without `labels`, a class that is absent from
# both the labels and the predictions would be dropped, giving a matrix
# smaller than 10x10 whose rows/columns no longer match `class_names`.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))

plt.figure(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',  # cells are integer counts
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names
)

plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
plt.tight_layout()
# savefig does not create missing directories; make sure "res/" exists.
os.makedirs("res", exist_ok=True)
plt.savefig("res/resnet18_confusion_matrix.png")
plt.show()