[Original code](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pytorch-lightning/Optimize_Pytorch_Lightning_models_with_Weights_%26_Biases.ipynb)

### Lightning with wandb

Lightning
* A lightweight wrapper for PyTorch code
* Makes it easy to keep code in a uniform structure

Example:

```python
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger()
trainer = Trainer(logger=wandb_logger)
```

With W&B you can
* log parameters
* log losses and metrics
* log models
* track code
* log system metrics (GPU, CPU, memory, temperature, etc.)

In [1]:
# Install
!pip install -q pytorch-lightning wandb



In [2]:
# Log in
import wandb
wandb.login()

wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: ··········
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc

True

### DataLoader setup

In [3]:
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

dataset = MNIST(
    root="./MNIST",
    download=True,
    transform=transform
)

training_set, validation_set = random_split(dataset, [55000, 5000])
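`random_split` partitions the 60,000 MNIST training images into 55,000 training and 5,000 validation samples; the lengths must add up to the dataset size. A minimal pure-Python sketch of the idea, assuming a simple shuffle-then-slice scheme (`random_split_sketch` is a hypothetical helper, not the torch implementation):

```python
import random


def random_split_sketch(n_items, lengths, seed=0):
    """Hypothetical helper: shuffle indices 0..n_items-1, then cut them into consecutive chunks."""
    if sum(lengths) != n_items:
        raise ValueError("lengths must sum to the dataset size")
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # seeded shuffle -> reproducible split
    splits, start = [], 0
    for length in lengths:
        splits.append(indices[start:start + length])
        start += length
    return splits


train_idx, val_idx = random_split_sketch(60000, [55000, 5000], seed=42)
print(len(train_idx), len(val_idx))  # 55000 5000
```

With the real `torch.utils.data.random_split`, passing `generator=torch.Generator().manual_seed(42)` makes the split reproducible across runs.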


In [4]:
training_loader = DataLoader(training_set, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=64)
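With `batch_size=64` and the default `drop_last=False`, the last, smaller batch of each epoch is kept, so the number of batches per epoch is a ceiling division. A small sketch of the arithmetic (`steps_per_epoch` is a hypothetical helper):

```python
import math


def steps_per_epoch(n_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


train_steps = steps_per_epoch(55000, 64)  # last batch holds 55000 - 859 * 64 = 24 samples
val_steps = steps_per_epoch(5000, 64)     # last batch holds 5000 - 78 * 64 = 8 samples
total_train_steps = train_steps * 5       # training steps for max_epochs=5

print(train_steps, val_steps, total_train_steps)  # 860 79 4300
```

The 4,300 training steps over 5 epochs are close to the final `trainer/global_step` (4299) reported in the run summary at the end.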

### Model definition

* `self.save_hyperparameters()`: the hyperparameters are saved to W&B automatically
* calling `self.log` in `training_step` and `validation_step` logs the metrics
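`save_hyperparameters()` collects the arguments passed to `__init__` into `self.hparams`; the Trainer then hands them to the logger, and `WandbLogger` records them in the run config. A simplified pure-Python illustration of the collection step, assuming CPython frame introspection (`HParamsSketch` and `save_hyperparameters_sketch` are hypothetical; Lightning's real implementation handles many more cases):

```python
import inspect


class HParamsSketch:
    """Hypothetical stand-in illustrating the idea behind save_hyperparameters()."""

    def save_hyperparameters_sketch(self):
        # Look one frame up: the __init__ that called this method (CPython only)
        caller = inspect.currentframe().f_back
        arg_names, _, _, local_values = inspect.getargvalues(caller)
        # Keep the named __init__ arguments, skipping `self`
        self.hparams = {name: local_values[name] for name in arg_names if name != "self"}


class TinyModel(HParamsSketch):
    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        self.save_hyperparameters_sketch()


sketch_model = TinyModel(n_layer_1=128, n_layer_2=128)
print(sketch_model.hparams)
# {'n_classes': 10, 'n_layer_1': 128, 'n_layer_2': 128, 'lr': 0.001}
```

Defaults that were not overridden (here `n_classes` and `lr`) are captured too, which is why they also show up in the W&B config.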

In [29]:
import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from pytorch_lightning import LightningModule

class MNIST_LitModule(LightningModule):
    def __init__(
        self,
        n_classes=10,
        n_layer_1=128,
        n_layer_2=256,
        lr=1e-3
    ):
        super().__init__()

        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        # loss
        self.loss = CrossEntropyLoss()

        # optimizer param
        self.lr = lr

        # Save the __init__ arguments to self.hparams (the logger records them in W&B)
        self.save_hyperparameters()

    def forward(
        self,
        x
    ):
        """Inference: map a batch of images to class logits.
        """

        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)

        return x

    def training_step(
        self,
        batch,
        batch_idx
    ):
        """Compute and return the loss for a single training batch.
        """
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log the loss and metric
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation metrics.
        """
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log the loss and metric
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)

        # Return the predictions so the custom callback can use them
        return preds

    def test_step(self, batch, batch_idx):
        """Log test metrics.
        """
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log the loss and metric
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)

    def configure_optimizers(self):
        """Define the optimizer used to train the model.
        """
        return Adam(self.parameters(), lr=self.lr)

    def _get_preds_loss_accuracy(self, batch):
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y, task='multiclass', num_classes=self.hparams.n_classes)
        return preds, loss, acc

In [30]:
model = MNIST_LitModule(n_layer_1=128, n_layer_2=128)
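With `n_layer_1=128` and `n_layer_2=128`, each `Linear(in, out)` layer holds `in * out` weights plus `out` biases. Working this out by hand reproduces the model summary printed during training (100 K, 16.5 K and 1.3 K parameters, 118 K in total, about 0.473 MB assuming 4-byte float32 parameters); `linear_params` is a hypothetical helper:

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of a fully connected layer: weight matrix plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)


# (in_features, out_features) for layer_1, layer_2, layer_3 with n_layer_1 = n_layer_2 = 128
layer_shapes = [(28 * 28, 128), (128, 128), (128, 10)]
counts = [linear_params(i, o) for i, o in layer_shapes]
total = sum(counts)
size_mb = total * 4 / 1e6  # float32 = 4 bytes per parameter

print(counts, total, round(size_mb, 3))  # [100480, 16512, 1290] 118282 0.473
```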

### Saving model checkpoints

To save model checkpoints to W&B, you need a `ModelCheckpoint` callback.

In [20]:
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(monitor='val_accuracy', mode='max')
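`monitor='val_accuracy'` must match a key passed to `self.log`, and `mode='max'` tells the callback that larger values are better (a loss would use `mode='min'`). With the default `save_top_k=1`, only the best checkpoint is kept. The selection can be sketched in plain Python; `best_checkpoint` is a hypothetical helper and the accuracy values are made up for illustration, not taken from the run below:

```python
def best_checkpoint(history, mode="max"):
    """Return (epoch, value) of the best monitored value; ties keep the earliest epoch."""
    pick = max if mode == "max" else min
    return pick(enumerate(history), key=lambda item: item[1])


# Hypothetical per-epoch val_accuracy values
val_accuracy_history = [0.961, 0.958, 0.970, 0.972, 0.969]
print(best_checkpoint(val_accuracy_history, mode="max"))  # (3, 0.972)
```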

### WandbLogger

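Lightning makes it easy to log experiments to W&B with `WandbLogger`.

Pass it to the `Trainer` as the `logger` argument and the run is logged to W&B.

To log to a specific W&B team, pass the team name to the `entity` argument of `WandbLogger`.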

In [21]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger(
    project='MNIST',  # log runs to the "MNIST" project
    log_model='all'   # log every new checkpoint during training
)


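### Logging images, text, and more with WandbLogger

You can write a custom callback that automatically logs sample predictions during the validation step.

`WandbLogger` provides the following logging methods:

* `WandbLogger.log_text`: text
* `WandbLogger.log_image`: images
* `WandbLogger.log_table`: W&B tables

The example below logs the ground-truth and predicted labels for the first 20 images of the first validation batch.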

In [33]:
import wandb
from pytorch_lightning.callbacks import Callback

class LogPredictionsCallback(Callback):
    def on_validation_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
        dataloader_idx=0,  # default value keeps the hook working across Lightning versions
    ):
        """Called when a validation batch ends.

        `outputs`
            Returned by `LightningModule.validation_step`;
            here, the model's predictions.
        """

        # Log 20 images from the first batch
        if batch_idx == 0:
            n = 20
            x, y = batch
            images = [img for img in x[:n]]
            captions = [
                f'Ground Truth: {y_i.item()} - Prediction: {y_pred.item()}'
                for y_i, y_pred in zip(y[:n], outputs[:n])
            ]

            # Option 1: log images with `WandbLogger.log_image`
            trainer.logger.log_image(
                key='sample_images',
                images=images,
                caption=captions,
            )

            # Option 2: log predictions as a W&B Table
            columns = ['image', 'ground truth', 'prediction']
            data = [
                [wandb.Image(x_i), y_i.item(), y_pred.item()]
                for x_i, y_i, y_pred in zip(x[:n], y[:n], outputs[:n])
            ]
            trainer.logger.log_table(
                key='sample_table',
                columns=columns,
                data=data,
            )

log_predictions_callback = LogPredictionsCallback()
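`log_table` takes a list of column names and a list of rows, each row holding one value per column. A plain-Python sketch of assembling the captions and rows, assuming the labels and predictions have already been converted from tensors to Python ints (for example with `.tolist()`); `build_prediction_rows` and the sample values are hypothetical, and the image column is left out:

```python
def build_prediction_rows(labels, preds, n=20):
    """Return (captions, rows) for the first n (label, prediction) pairs."""
    pairs = list(zip(labels[:n], preds[:n]))
    captions = [f"Ground Truth: {y} - Prediction: {p}" for y, p in pairs]
    # One row per sample, one value per column
    rows = [[y, p] for y, p in pairs]
    return captions, rows


columns = ["ground truth", "prediction"]
labels = [7, 2, 1, 0, 4]  # hypothetical ground-truth labels
preds = [7, 2, 1, 0, 9]   # hypothetical predictions
captions, rows = build_prediction_rows(labels, preds, n=3)
print(captions[0])  # Ground Truth: 7 - Prediction: 7
```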

### Training the model

In [35]:
trainer = Trainer(
    logger=wandb_logger,
    callbacks=[
        log_predictions_callback,
        checkpoint_callback
    ],
    accelerator="auto",  # uses the GPU when one is available, otherwise the CPU
    max_epochs=5
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [36]:
trainer.fit(model, training_loader, validation_loader)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type             | Params
---------------------------------------------
0 | layer_1 | Linear           | 100 K 
1 | layer_2 | Linear           | 16.5 K
2 | layer_3 | Linear           | 1.3 K 
3 | loss    | CrossEntropyLoss | 0     
---------------------------------------------
118 K     Trainable params
0         Non-trainable params
118 K     Total params
0.473     Total estimated model params size (MB)



INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


Call `wandb.finish()` to close the W&B run.

In a script, this is called automatically when the script exits.

In [37]:
wandb.finish()


Run history:

| Metric | History |
|---|---|
| epoch | ▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆████████ |
| train_accuracy | ▁▄▅▅█▇▆▆▅▇▇▇▇▇▆▇█▇█▇▇▇▇▆▇▇▆▇█▆▇▆▇██▆▇█▇▇ |
| train_loss | █▆▅▄▂▂▂▃▄▂▂▂▂▂▃▂▁▂▁▂▂▂▂▂▂▁▂▂▁▃▂▂▁▁▁▃▂▁▃▃ |
| trainer/global_step | ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████ |
| val_accuracy | ▃▁▇█▆ |
| val_loss | ▇█▁▁▄ |

Run summary:

| Metric | Value |
|---|---|
| epoch | 4.0 |
| train_accuracy | 0.95833 |
| train_loss | 0.1743 |
| trainer/global_step | 4299.0 |
| val_accuracy | 0.969 |
| val_loss | 0.11312 |
