In [1]:
import torch
import torch.nn as nn
from torch.optim import SGD

from torchvision import datasets, transforms
# Preprocessing applied to every image: convert it to a tensor first, then
# normalize the single channel with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(preprocessing_steps)

In [2]:
from torch.utils.data import DataLoader
from lightning.fabric import Fabric

# Both MNIST splits live under ./data, are downloaded on demand, and share the
# same preprocessing pipeline.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Batches of 64, loaded in the main process (num_workers=0). Only the training
# loader shuffles; the test loader keeps the dataset order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0)

# Move every test batch to the CUDA device once, up front, so the evaluation
# loop can iterate over tensors that are already on the device.
fabric = Fabric(accelerator='cuda')
all_batches = [fabric.to_device(batch) for batch in test_loader]

In [3]:
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/ReLU/max-pool stages followed by three linear layers.

    The per-layer shape comments assume a single-channel 28x28 input (as in
    MNIST); `fc1` expects 16 * 4 * 4 = 256 features after flattening.

    Args:
        num_classes: Number of outputs of the final linear layer (one logit
            per class).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Stage 1: 1 input channel -> 6 feature maps, then 2x2 max-pooling.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)  # -> (6, 24, 24)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # -> (6, 12, 12)

        # Stage 2: 6 -> 16 feature maps, then 2x2 max-pooling.
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)  # -> (16, 8, 8)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # -> (16, 4, 4)

        # Fully connected head; its input size matches the flattened conv output.
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        """Return the raw (un-normalized) class logits for the batch `x`."""
        features = self.pool1(torch.relu(self.conv1(x)))
        features = self.pool2(torch.relu(self.conv2(features)))

        # Flatten everything except the batch dimension.
        batch_size = features.size(0)
        flat = features.view(batch_size, -1)

        hidden = torch.relu(self.fc1(flat))
        hidden = torch.relu(self.fc2(hidden))
        # No activation on the last layer: the caller receives logits.
        return self.fc3(hidden)

In [4]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

In [12]:
class LeNetLightning(pl.LightningModule):
    """Lightning wrapper that trains `LeNet` with cross-entropy loss and SGD.

    Args:
        num_classes: Number of output classes, forwarded to `LeNet`.
        lr: Learning rate passed to the SGD optimizer.
        momentum: Momentum passed to the SGD optimizer.
    """

    def __init__(self, num_classes=10, lr=0.02, momentum=0.9):
        super().__init__()
        # Store the constructor arguments in the checkpoint so that
        # `load_from_checkpoint` rebuilds the module with the same settings
        # instead of falling back to the defaults (which would make loading a
        # model trained with a non-default `num_classes` fail).
        self.save_hyperparameters()
        self.lr = lr
        self.momentum = momentum
        self.model = LeNet(num_classes=num_classes)  # the self-defined LeNet network
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the logits produced by the wrapped `LeNet` for `x`."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one `(inputs, targets)` batch and log metrics.

        Logs the cross-entropy loss as "train_loss" and the batch accuracy as
        "train_accuracy"; returns the loss.
        """
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        # Fraction of samples whose highest-scoring class equals the target.
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log("train_loss", loss)
        self.log("train_accuracy", accuracy)
        return loss

    def configure_optimizers(self):
        """Return an SGD optimizer over all parameters of this module."""
        return SGD(self.parameters(), lr=self.lr, momentum=self.momentum)


In [17]:
# Train a fresh LeNetLightning module on one GPU for two epochs, writing
# TensorBoard logs under mnist_logs/lenet.
model = LeNetLightning()
logger = TensorBoardLogger("mnist_logs", name="lenet")
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=2,  # number of epochs to train for
    logger=logger,
    enable_progress_bar=True,
)
trainer.fit(model, train_dataloaders=train_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | model   | LeNet            | 44.4 K | train
1 | loss_fn | CrossEntropyLoss | 0      | train
-----------------------------------------------------
44.4 K    Trainable params
0         Non-trainable params
44.4 K    Total params
0.178     Total estimated model params size (MB)
9         Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


In [18]:
# Write the trainer's current state to a checkpoint file so the trained module
# can be reloaded later for inference.
trainer.save_checkpoint(filepath="lenet_mnist.pt")

In [19]:
from tqdm.autonotebook import tqdm

def get_prediction(x, model: pl.LightningModule):
    """Run `model` on the batch `x` and return `(predicted_class, probabilities)`.

    The model is frozen via `model.freeze()` before the forward pass. Its
    outputs are turned into per-class probabilities with a softmax over
    dim 1, and the predicted class is the argmax over that same dimension.
    """
    model.freeze()  # prepares the model for predicting
    logits = model(x)
    probabilities = logits.softmax(dim=1)
    return probabilities.argmax(dim=1), probabilities

# Rebuild the module from the saved checkpoint, mapping its tensors onto CUDA.
inference_model = LeNetLightning.load_from_checkpoint(
    checkpoint_path="lenet_mnist.pt",
    map_location="cuda",
)

In [20]:
# Collect ground-truth labels and predicted classes over the pre-loaded test
# batches. `.tolist()` copies each tensor to host memory and yields plain
# Python ints, so the result lists hold ints rather than thousands of
# zero-dimensional tensors.
true_y, pred_y = [], []
for x, y in tqdm(all_batches):
    preds, probs = get_prediction(x, inference_model)
    true_y.extend(y.tolist())
    pred_y.extend(preds.tolist())

  0%|          | 0/157 [00:00<?, ?it/s]

In [21]:
from sklearn.metrics import classification_report
# Per-class precision / recall / F1 for the collected labels and predictions,
# formatted with three decimals.
report = classification_report(y_true=true_y, y_pred=pred_y, digits=3)
print(report)

              precision    recall  f1-score   support

           0      0.995     0.986     0.990       980
           1      0.995     0.996     0.996      1135
           2      0.995     0.973     0.984      1032
           3      0.994     0.969     0.981      1010
           4      0.989     0.993     0.991       982
           5      0.952     0.997     0.974       892
           6      0.986     0.990     0.988       958
           7      0.980     0.992     0.986      1028
           8      0.971     0.980     0.975       974
           9      0.986     0.969     0.978      1009

    accuracy                          0.985     10000
   macro avg      0.984     0.985     0.984     10000
weighted avg      0.985     0.985     0.985     10000

