In [1]:
import torch
import torch.nn as nn

from torchvision.models import resnet18

# No stand-alone network is built in this cell: `ResNetMNIST` constructs its
# own ResNet-18, and `model` is assigned right before training.

from torchvision import datasets, transforms

# Preprocessing pipeline passed to the MNIST datasets as `transform`.
transform = transforms.Compose([
    transforms.ToTensor(),  # convert the image to a Tensor
    transforms.Normalize((0.5,), (0.5,))  # normalize the image (mean 0.5, std 0.5)
])

In [2]:
from torch.utils.data import DataLoader
from lightning.fabric import Fabric

# Settings shared by both MNIST splits and by both loaders.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
loader_kwargs = dict(batch_size=64, num_workers=0)

# Training split, reshuffled by the loader.
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# Test split, kept in dataset order.
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Move every test batch to the CUDA device once, up front, so the evaluation
# loop can iterate over ready-made on-device batches.
fabric = Fabric(accelerator='cuda')
all_batches = [fabric.to_device(batch) for batch in test_loader]

In [3]:
# Pull one mini-batch from the training loader to sanity-check its layout.
images, labels = next(iter(train_loader))

print(images.shape)  # shape of one batch of images: (64, 1, 28, 28)
print(labels.shape)  # shape of the matching labels: (64,)

torch.Size([64, 1, 28, 28])
torch.Size([64])


In [4]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

In [22]:
class ResNetMNIST(pl.LightningModule):
  """LightningModule wrapping a 10-class ResNet-18 that takes 1-channel images.

  The network is torchvision's ``resnet18`` with its first convolution
  swapped for one that accepts a single input channel. Training minimises
  cross-entropy and logs the per-batch loss and accuracy.
  """

  def __init__(self):
    """Build the network and the loss function."""
    super().__init__()
    backbone = resnet18(num_classes=10)
    # Swap in a first convolution that takes 1 input channel
    # (64 filters, 7x7 kernel, stride 2, padding 3, no bias).
    backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    self.model = backbone
    self.loss = nn.CrossEntropyLoss()

  def forward(self, x):
    """Return the wrapped network's output (class logits) for ``x``."""
    logits = self.model(x)
    return logits

  def training_step(self, batch, batch_no):
    """Compute the loss for one batch and log loss and accuracy.

    Args:
      batch: an ``(inputs, targets)`` pair.
      batch_no: index of the batch; not used here.

    Returns:
      The cross-entropy loss for this batch.
    """
    inputs, targets = batch
    scores = self(inputs)
    batch_loss = self.loss(scores, targets)
    # Fraction of samples whose highest-scoring class equals the target.
    hits = scores.argmax(dim=1).eq(targets)
    self.log("train_loss", batch_loss)
    self.log("train_accuracy", hits.float().mean())
    return batch_loss

  def configure_optimizers(self):
    """Return the optimizer: RMSprop over all parameters with lr=0.01."""
    optimizer = torch.optim.RMSprop(self.parameters(), lr=0.01)
    return optimizer

In [33]:
# Fresh module to train, plus a TensorBoard logger writing under mnist_logs/resnet.
model = ResNetMNIST()
logger = TensorBoardLogger(save_dir="mnist_logs", name="resnet")

# Single-GPU trainer limited to two epochs.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=2,  # number of epochs to train for
    logger=logger,
    enable_progress_bar=True,
)
trainer.fit(model, train_dataloaders=train_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | model | ResNet           | 11.2 M | train
1 | loss  | CrossEntropyLoss | 0      | train
---------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.701    Total estimated model params size (MB)
69        Modules in train mode
0         Modules in eval mode
d:\Anaconda\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


In [34]:
# Write the trainer's current state to a checkpoint file; the same path is
# read back later with `load_from_checkpoint`.
trainer.save_checkpoint(filepath="resnet18_mnist.pt")

In [35]:
from tqdm.autonotebook import tqdm

def get_prediction(x, model: pl.LightningModule):
  """Classify a batch with ``model``.

  The model is frozen first (``model.freeze()``) to prepare it for
  prediction; note that this mutates the model passed in.

  Args:
    x: input batch, passed straight to ``model``.
    model: the LightningModule used for prediction.

  Returns:
    A ``(predicted_class, probabilities)`` tuple, where ``probabilities`` is
    the softmax of the model output over dim 1 and ``predicted_class`` is its
    argmax over dim 1.
  """
  model.freeze()
  class_probs = model(x).softmax(dim=1)
  return class_probs.argmax(dim=1), class_probs

# Rebuild the trained module from the saved checkpoint, mapped onto CUDA.
inference_model = ResNetMNIST.load_from_checkpoint(
    checkpoint_path="resnet18_mnist.pt",
    map_location="cuda",
)

In [36]:
# Collect ground-truth and predicted labels over the pre-loaded test batches.
# `.tolist()` stores plain Python ints; extending a list with a tensor directly
# would append one 0-d tensor object per sample instead.
true_y, pred_y = [], []
for x, y in tqdm(all_batches):  # tqdm takes the total from len(all_batches)
  preds, _probs = get_prediction(x, inference_model)  # probabilities are not needed here
  true_y.extend(y.cpu().tolist())
  pred_y.extend(preds.cpu().tolist())

  0%|          | 0/157 [00:00<?, ?it/s]

In [37]:
from sklearn.metrics import classification_report
# Per-class precision / recall / F1 on the test set, shown to three decimals.
report = classification_report(y_true=true_y, y_pred=pred_y, digits=3)
print(report)

              precision    recall  f1-score   support

           0      0.999     0.969     0.984       980
           1      0.995     0.989     0.992      1135
           2      0.992     0.987     0.990      1032
           3      0.976     0.997     0.986      1010
           4      0.991     0.990     0.990       982
           5      0.992     0.981     0.986       892
           6      0.987     0.984     0.986       958
           7      0.988     0.987     0.988      1028
           8      0.956     0.994     0.974       974
           9      0.983     0.978     0.981      1009

    accuracy                          0.986     10000
   macro avg      0.986     0.986     0.986     10000
weighted avg      0.986     0.986     0.986     10000

