<a href="https://colab.research.google.com/github/OnlyBelter/machine-learning-note/blob/master/pytorch/pytorch_lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

References:

* https://pytorch.org/tutorials/beginner/basics/intro.html
* https://lightning.ai/docs/pytorch/stable/starter/introduction.html
* https://lightning.ai/docs/torchmetrics/stable/pages/quickstart.html
* https://youtu.be/DbESHcCoWbM?si=_Wpu5iP14dvi5seK

In [1]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

In [2]:
%%capture
!pip install torchmetrics
!pip install pytorch-lightning

In [3]:
# Choose the compute backend in order of preference:
# CUDA GPU first, then Apple-silicon MPS, finally plain CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

Using cpu device


In [4]:
import pytorch_lightning as pl
from torchmetrics.functional import accuracy

---
# PyTorch Lightning
1. model
2. optimizer
3. data
4. training loop "the magic"
5. validation loop "the validation magic"

In [5]:
class NeuralNetwork(nn.Module):
    """Plain multilayer perceptron: flatten, two 512-unit ReLU layers, 10 logits."""

    def __init__(self):
        super().__init__()
        hidden_units = 512

        # Collapses every dimension after the batch dimension into one.
        self.flatten = nn.Flatten()

        # Layers are created in the same order they are applied.
        stack = [
            nn.Linear(28 * 28, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*stack)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for the batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))

class ResNet(pl.LightningModule):
    """Fully connected FashionMNIST classifier with one residual connection.

    The module owns its model layers, loss, optimizer configuration and both
    dataloaders, so ``Trainer.fit(model)`` can be called without passing any
    data explicitly.

    Args:
        batch_size: Number of samples per batch for both dataloaders.
        learning_rate: Learning rate passed to the SGD optimizer.
    """

    def __init__(self, batch_size, learning_rate):
        super().__init__()
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.l1 = nn.Linear(28*28, 512)
        self.l2 = nn.Linear(512, 512)
        self.l3 = nn.Linear(512, 10)
        self.do = nn.Dropout(0.1)
        self.loss_fn = nn.CrossEntropyLoss()

        self.train_dataset = datasets.FashionMNIST(
            root="data",
            train=True,
            download=True,
            transform=ToTensor()
        )
        # NOTE: the official FashionMNIST test split (train=False) is used
        # as the validation set here.
        self.val_dataset = datasets.FashionMNIST(
            root="data",
            train=False,
            download=True,
            transform=ToTensor()
        )

    def forward(self, x):
        """Compute class logits.

        Inputs with more than two dimensions (e.g. image batches) are
        flattened to ``(batch, features)`` first; 1-D and 2-D inputs are
        passed through unchanged, as before.
        """
        if x.dim() > 2:
            x = x.flatten(start_dim=1)
        h1 = nn.functional.relu(self.l1(x))
        h2 = nn.functional.relu(self.l2(h1))
        do = self.do(h2 + h1)  # residual connection
        logits = self.l3(do)
        return logits

    def configure_optimizers(self):
        """Return a plain SGD optimizer over all parameters."""
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

    def _shared_step(self, batch):
        """Run the forward pass on one batch and return ``(loss, accuracy)``."""
        x, y = batch

        # 1 forward
        logits = self(x)

        # 2 compute the objective function and the batch accuracy
        loss = self.loss_fn(logits, y)
        acc = accuracy(logits, y, task='multiclass', num_classes=10)
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log loss/accuracy.

        Metrics are reported through ``self.log`` rather than a
        ``'progress_bar'`` entry in the returned dict, so the training
        accuracy actually reaches the progress bar and the logger.
        """
        loss, acc = self._shared_step(batch)
        self.log('train_loss', loss)
        self.log('train_acc', acc, prog_bar=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Compute loss/accuracy for one validation batch and log them.

        Uses ``on_epoch=True`` so the logged values are aggregated over the
        whole validation epoch; no manual epoch-end hook is needed.
        """
        loss, acc = self._shared_step(batch)

        # Record the metrics with self.log
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_acc', acc, on_step=False, on_epoch=True, prog_bar=True)

        return {'loss': loss}

    def train_dataloader(self):
        """Return the training dataloader (reshuffled every epoch)."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (fixed order)."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size)


In [6]:
# Seed the Python, NumPy and PyTorch RNGs before the model is built so that
# weight initialisation, dropout and any data shuffling are reproducible.
pl.seed_everything(42, workers=True)

learning_rate = 0.001
batch_size = 64
model = ResNet(batch_size=batch_size, learning_rate=learning_rate)
print(model)

ResNet(
  (l1): Linear(in_features=784, out_features=512, bias=True)
  (l2): Linear(in_features=512, out_features=512, bias=True)
  (l3): Linear(in_features=512, out_features=10, bias=True)
  (do): Dropout(p=0.1, inplace=False)
  (loss_fn): CrossEntropyLoss()
)


In [7]:
# Train for a fixed number of epochs. No dataloaders are passed to fit():
# they come from the LightningModule's own train/val dataloader methods.
trainer = pl.Trainer(
    max_epochs=10,
    enable_progress_bar=True,
)
trainer.fit(model)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | l1      | Linear           | 401 K  | train
1 | l2      | Linear           | 262 K  | train
2 | l3      | Linear           | 5.1 K  | train
3 | do      | Dropout          | 0      | train
4 | loss_fn | CrossEntropyLoss | 0      | train
-----------------------------------------------------
669 K     Trainable params
0         Non-trainable params
669 K     Total params
2.679     Total estimated model params size (MB)
5         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
