# 1. Package 설치

In [4]:
! pip install --quiet "pandas" "torch>=1.6, <1.9" "torchvision" "ipython[notebook]" "seaborn" "pytorch-lightning>=1.4" "torchmetrics>=0.6" "lightning-bolts"
! pip install pytorch-lightning

[31mERROR: Could not find a version that satisfies the requirement torch<1.9,>=1.6 (from versions: 1.9.0, 1.10.0, 1.10.1, 1.10.2, 1.11.0, 1.12.0, 1.12.1)[0m[31m
[0m[31mERROR: No matching distribution found for torch<1.9,>=1.6[0m[31m
[0mCollecting pytorch-lightning
  Using cached pytorch_lightning-1.7.7-py3-none-any.whl (708 kB)
Collecting tqdm>=4.57.0
  Using cached tqdm-4.64.1-py2.py3-none-any.whl (78 kB)
Collecting fsspec[http]!=2021.06.0,>=2021.05.0
  Using cached fsspec-2022.8.2-py3-none-any.whl (140 kB)
Collecting tensorboard>=2.9.1
  Using cached tensorboard-2.10.0-py3-none-any.whl (5.9 MB)
Collecting pyDeprecate>=0.3.1
  Using cached pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Collecting torchmetrics>=0.7.0
  Using cached torchmetrics-0.9.3-py3-none-any.whl (419 kB)
Collecting PyYAML>=5.4
  Using cached PyYAML-6.0.tar.gz (124 kB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.

In [7]:
#! pip install seaborn
#! pip install torchvision
#! pip install lightning-bolts
#! pip install pytorch_lightning
# install pytorch lighting
! pip install pytorch-lightning --quiet
# install weights and biases
! pip install wandb --quiet


# 2. Imports

In [14]:

import pytorch_lightning as pl
# your favorite machine learning tracking tool
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import loggers as pl_loggers


import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader

from torchmetrics import Accuracy

from torchvision import transforms
from torchvision.datasets import CIFAR10


# 3. 데이터 모듈 정의하기

cifar 10 데이터셋을 사용합니다.

cifar 10 데이터 셋은 LightningDataModule의 서브클래스이므로, 상속받아 메서드를 구현한다.

공식 doc 참고하기 : https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html

### 메서드 간략 설명
    prepare_data
GPU 하나에서 한 번만 호출된다. 일반적으로 아래의 데이터 다운로드 단계와 같다.

    setup
각각의 GPU에서 개별적으로 호출되며 fit 또는 test단계일 경우 정의할 스테이지를 받아온다.

    train_dataloader, val_dataloader, test_dataloader
각각의 데이터 세트를 로드한다.

### Notes
- random_split : training-validation split구분을 용이하게 한다. 전체 데이터셋에 적용된다.



In [15]:

class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        self.transform = transforms.Compose([
            #텐서화 및 정규화
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
        # 클래스 수 정의
        self.num_classes = 10
    
    def prepare_data(self):
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)
    
    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            #cifar 데이터 받아와서 경로, transform, train, 등등 설정하기 
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)
    
    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)


# 4. Logger 정의하기

epoch이 끝날 때마다 성능 로그를 남긴다. 

## Notes
    pytorch_lightning.callbacks.Callback

공식 docs 보기 : https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html

프로젝트 전체에서 재사용할 수 있는 독립형 프로그램을 뜻한다.

예를 들어, 학습중

    on_validation_epoch_end

메서드가 수행되면 callback hook 이 수행되어 아래의 코드가 수행된다.

다시말해 validation epoch이 종료될 때, 로그를 남기는 코드가 수행되는 것이다.

In [16]:
class ImagePredictionLogger(pl.callbacks.Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples
    
    def on_validation_epoch_end(self, trainer, pl_module):
        # Bring the tensors to CPU
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)
        # Get model prediction
        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)

# 5. 모델 정의하기

LightningModule은 모델이 아닌 시스템을 정의하기 때문에, 모델을 단일 클래스로 독립화시켜야한다.

기존의 Pytorch 코드를 5개의 섹션으로 구분하여 넣어주면 된다.
- Computations (__init__).

- Train loop (training_step)

- Validation loop (validation_step)

- Test loop (test_step)

- Optimizers (configure_optimizers)

### 메서드 간략 설명
pytorch_lightning.LightningModule은 모델의 아키텍처와 forward 전달 방식을 상속받아 구현할 수 있게 해놨다.

#### 1. init 메서드

init에서 모델에 필요한 하이퍼 파라미터를 전달한다.

    save_parameters

를 call 함으로써 init에 있는 모든 값을 check point에 저장하도록 요청할 수 있다.

. . .


    _get_conv_output과 _forward_features

메서드는 convolutional block의 텐서 사이즈를 자동으로 계산하는데 사용한다. 

. . .

    forward
일단 파이토치 코드와 비슷하나, 라이트닝에서는 오직 inference action을 위해서만 사용된다. training_step은 학습 loop를 정의한다.


#### 2. Training step 메서드




In [17]:
class LitModel(pl.LightningModule):
    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()
        
        # log hyperparameters
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1)

        self.pool1 = torch.nn.MaxPool2d(2)
        self.pool2 = torch.nn.MaxPool2d(2)
        
        n_sizes = self._get_conv_output(input_shape)

        self.fc1 = nn.Linear(n_sizes, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

        self.accuracy = Accuracy()

    # returns the size of the output tensor going into Linear layer from the conv block.
    def _get_conv_output(self, shape):
        batch_size = 1
        input = torch.autograd.Variable(torch.rand(batch_size, *shape))

        output_feat = self._forward_features(input) 
        n_size = output_feat.data.view(batch_size, -1).size(1)
        return n_size
        
    # returns the feature tensor from the conv block
    def _forward_features(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool2(F.relu(self.conv4(x)))
        return x
    
    # will be used during inference
    def forward(self, x):
       x = self._forward_features(x)
       x = x.view(x.size(0), -1)
       x = F.relu(self.fc1(x))
       x = F.relu(self.fc2(x))
       x = F.log_softmax(self.fc3(x), dim=1)
       
       return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # training metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [18]:
dm = CIFAR10DataModule(batch_size=32)
# To access the x_dataloader we need to call prepare_data and setup.
dm.prepare_data()
dm.setup()

# Samples required by the custom ImagePredictionLogger callback to log image predictions.
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape

Files already downloaded and verified
Files already downloaded and verified


(torch.Size([32, 3, 32, 32]), torch.Size([32]))

In [21]:
model = LitModel((3, 32, 32), dm.num_classes)

# Initialize wandb logger
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

# Initialize logger
tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")

# Initialize Callbacks
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_loss")
checkpoint_callback = pl.callbacks.ModelCheckpoint()

# Initialize a trainer
trainer = pl.Trainer(max_epochs=10,
                     gpus=0, 
                     logger=tb_logger,
                     callbacks=[early_stop_callback,
                                ImagePredictionLogger(val_samples),
                                checkpoint_callback],
                     )

# Train the model ⚡🚅⚡
trainer.fit(model, dm)

# Evaluate the model on the held-out test set ⚡⚡
trainer.test(dataloaders=dm.test_dataloader())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: