In [46]:
from pkg_resources import parse_version
import sys
import pytorch_lightning as pl
import torch 
import torch.nn as nn 
from torchmetrics import __version__ as torchmetrics_version
from torchmetrics import Accuracy
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
from pytorch_lightning.callbacks import ModelCheckpoint

In [47]:
#构建lightning模型
class MultiLayerPerceptron(pl.LightningModule):
    #构造函数
    def __init__(self, image_shape=(1, 28, 28), hidden_units=(32, 16)):
        super().__init__()
    
        # new PL attributes:
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.valid_acc = Accuracy(task="multiclass", num_classes=10)
        self.test_acc = Accuracy(task="multiclass", num_classes=10)
        #添加了准确率属性，如self.train_acc = Accuracy()
        # Model similar to previous section:
        input_size = image_shape[0] * image_shape[1] * image_shape[2] 
        all_layers = [nn.Flatten()]
        for hidden_unit in hidden_units: 
            layer = nn.Linear(input_size, hidden_unit) 
            all_layers.append(layer) 
            all_layers.append(nn.ReLU()) 
            input_size = hidden_unit 
 
        all_layers.append(nn.Linear(hidden_units[-1], 10)) 
        self.model = nn.Sequential(*all_layers)

    #简单的前向传播，并返回logits值用于训练、验证和测试步骤
    def forward(self, x):
        x = self.model(x)
        return x
    
    #training_step,training_epoch_end等方法都是lightning独有的方法
    #此方法定义了训练期间的单个前向传播，同时跟踪准确率和训练损失函数值，以供后续分析使用。
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)#返回logits
        loss = nn.functional.cross_entropy(logits, y)#交叉熵
        preds = torch.argmax(logits, dim=1)#argmax用于返回指定维度上最大值所在的索引。在这个上下文中，它被用来从模型的输出（logits）中选择最可能的类别。
        self.train_acc.update(preds, y)     #此处计算了准确率，但没有保存
        self.log("train_loss", loss, prog_bar=True)
        return loss

    #该方法在每一轮训练结束时根据训练过程积累的准确率计算训练数据集的准确率
    def training_epoch_end(self, outs):
        self.log("train_acc", self.train_acc.compute())
        self.train_acc.reset()
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.valid_acc.update(preds, y)
        self.log("valid_loss", loss, prog_bar=True)
        return loss
    
    def validation_epoch_end(self, outs):
        self.log("valid_acc", self.valid_acc.compute(), prog_bar=True)
        self.valid_acc.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_acc.update(preds, y)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_acc.compute(), prog_bar=True)
        return loss

    #在此处指定训练时使用的优化器
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer 

In [48]:
#13.8.2为Lightning设置数据加载器
class MnistDataModule(pl.LightningDataModule):
    def __init__(self, data_path='./'):
        super().__init__()
        self.data_path = data_path
        self.transform = transforms.Compose([transforms.ToTensor()])
        
    def prepare_data(self):
        MNIST(root=self.data_path, download=True) 

    def setup(self, stage=None):
        # stage is either 'fit', 'validate', 'test', or 'predict'
        # here note relevant
        mnist_all = MNIST( 
            root=self.data_path,
            train=True,
            transform=self.transform,  
            download=False
        ) 

        self.train, self.val = random_split(
            mnist_all, [55000, 5000], generator=torch.Generator().manual_seed(1)
        )

        self.test = MNIST( 
            root=self.data_path,
            train=False,
            transform=self.transform,  
            download=False
        ) 

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=64, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=64, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=64, num_workers=4)

In [49]:
torch.manual_seed(1) 
mnist_dm = MnistDataModule()

In [51]:
#13.8.3使用pytorch Lightning Trainer类训练模型
mnistclassifier = MultiLayerPerceptron()
callbacks = [ModelCheckpoint(save_top_k=1, mode='max', monitor="valid_acc")] # save top 1 model

trainer = pl.Trainer(max_epochs=10, callbacks=callbacks)

trainer.fit(model=mnistclassifier, datamodule=mnist_dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


NotImplementedError: Support for `training_epoch_end` has been removed in v2.0.0. `MultiLayerPerceptron` implements this method. You can use the `on_train_epoch_end` hook instead. To access outputs, save them in-memory as instance attributes. You can find migration examples in https://github.com/Lightning-AI/lightning/pull/16520.