## __W&BでPytorch__

W&BとPyTorchを用いてMNISTデータの多値分類を行う

参考 https://github.com/MLHPC/wandb_tutorial

### __準備__

In [1]:
%%capture
!pip install wandb

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import wandb
from tqdm.notebook import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# API keyを求められた場合は「Setting」→「Danger Zone」から取得

wandb.login()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

### __設定__

__設定方法__

1. `wantdb.init()`の`config`引数に辞書で指定

2. `wandb.config`で設定
    ```python
    wandb.config.epochs = 50
    ```

3. `wandb.config.update()`メソッドで設定項目の追加
    ```python
    wandb.config.update({"epochs": 8, "batch_size": 64})
    ```

4. yamlファイルで設定

    - `desc`〜説明(description)

    - `value`〜設定値

    ```yaml
    epochs:
        desc: Number of epochs to train over
        value: 100
    batch_size:
        desc: Size of each mini-batch
        value: 32
    ```

In [4]:
wandb.init(
    project="tutorial",
    name="pytorch_0",
    entity="satodaichi",
    config={
        "lr": 0.01,
        "batch_size": 512,
        "epochs": 20,
        "n_input": 28*28,
        "n_hidden": 128,
        "n_output": 10, # 10クラス分類
        "ndigits": 5, # 表示する小数点以下の桁数
        }
)

config = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33msatodaichi[0m. Use [1m`wandb login --relogin`[0m to force relogin


### __データの準備__

In [5]:
# transform
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# Dataset
train_dataset = torchvision.datasets.MNIST(
    root="./data",
    transform=transform,
    download=True
)

test_dataset = torchvision.datasets.MNIST(
    root="./data",
    transform=transform,
    download=False
)

# DataLoader
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True
)

test_dataloader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=False
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



### __モデル__

In [6]:
# 2層の線形変換
class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.l1 = nn.Linear(n_input, n_hidden)
        self.l2 = nn.Linear(n_hidden, n_output)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.l1(x)
        x = self.relu(x)
        y = self.l2(x)
        return y

In [7]:
model = Net(config.n_input, config.n_hidden, config.n_output).to(device)
criterion = nn.CrossEntropyLoss() # 損失関数
optimizer = optim.Adam(model.parameters(), lr=config.lr) # 最適化関数

### __学習__

__wandb.log()の使い方__

- `wandb.log()`に指定した辞書がappendされて、historyが完成する

- 一度`wandb.log()`呼ぶとそれが1ステップとして記録される

- 1ステップの中で数ヶ所に分けて`wandb.log()`を呼びたい場合は、以下の2通りの方法がある

    - `step`を指定する

        ```python
        wandb.log({ 'accuracy': 0.9 }, step=10)
        wandb.log({ 'loss': 0.1 }, step=10)
        ```
    - `commit=False`にする
        ```python
        wandb.log({ 'accuracy': 0.9 }, commit=False) # まだ記録されない
        wandb.log({ 'loss': 0.1 }) # ここで { 'accuracy': 0.9, 'loss': 0.2 } が記録される
        ```

In [8]:
for epoch in range(config.epochs):

    # 学習
    model.train()
    train_acc, train_loss = 0, 0
    for inputs, labels in tqdm(train_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad() # 勾配をリセット
        outputs = model(inputs) # 予測
        loss = criterion(outputs, labels) # 損失関数
        loss.backward() # 逆伝播
        optimizer.step() # パラメータの更新   
        pred = outputs.argmax(1)
        train_acc += (pred==labels).sum().item() / len(labels)
        train_loss += loss.item()
    train_acc = round(train_acc / len(train_dataloader), config.ndigits)
    train_loss = round(train_loss / len(train_dataloader), config.ndigits)

    # 推論
    model.eval()
    test_acc, test_loss = 0, 0
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs) # 予測
            loss = criterion(outputs, labels) # 損失関数
            pred = outputs.argmax(1)
            test_acc += (pred==labels).sum().item() / len(labels)
            test_loss += loss.item()
        test_acc = round(test_acc / len(test_dataloader), config.ndigits)
        test_loss = round(test_loss / len(test_dataloader), config.ndigits)
    
    # wandbに記録
    wandb.log({"train_acc": train_acc, "train_loss": train_loss, "test_acc": test_acc, "test_loss": test_loss})
    print(f"{epoch+1}/{config.epochs}", "train_acc:", train_acc, "train_loss:", train_loss, "test_acc:", test_acc, "test_loss:", test_loss)
wandb.finish()

  0%|          | 0/118 [00:00<?, ?it/s]

1/20 train_acc: 0.90943 train_loss: 0.3129 test_acc: 0.95856 test_loss: 0.14041


  0%|          | 0/118 [00:00<?, ?it/s]

2/20 train_acc: 0.96356 train_loss: 0.12083 test_acc: 0.97657 test_loss: 0.08089


  0%|          | 0/118 [00:00<?, ?it/s]

3/20 train_acc: 0.975 train_loss: 0.0829 test_acc: 0.98005 test_loss: 0.06308


  0%|          | 0/118 [00:00<?, ?it/s]

4/20 train_acc: 0.97981 train_loss: 0.06526 test_acc: 0.98485 test_loss: 0.04907


  0%|          | 0/118 [00:00<?, ?it/s]

5/20 train_acc: 0.98385 train_loss: 0.0494 test_acc: 0.98455 test_loss: 0.04695


  0%|          | 0/118 [00:00<?, ?it/s]

6/20 train_acc: 0.98651 train_loss: 0.04204 test_acc: 0.99101 test_loss: 0.02799


  0%|          | 0/118 [00:00<?, ?it/s]

7/20 train_acc: 0.98835 train_loss: 0.03534 test_acc: 0.99284 test_loss: 0.02283


  0%|          | 0/118 [00:00<?, ?it/s]

8/20 train_acc: 0.99172 train_loss: 0.02552 test_acc: 0.99473 test_loss: 0.01634


  0%|          | 0/118 [00:00<?, ?it/s]

9/20 train_acc: 0.9933 train_loss: 0.01978 test_acc: 0.99375 test_loss: 0.01952


  0%|          | 0/118 [00:00<?, ?it/s]

10/20 train_acc: 0.9922 train_loss: 0.02227 test_acc: 0.99071 test_loss: 0.0267


  0%|          | 0/118 [00:00<?, ?it/s]

11/20 train_acc: 0.99171 train_loss: 0.02338 test_acc: 0.99465 test_loss: 0.01519


  0%|          | 0/118 [00:00<?, ?it/s]

12/20 train_acc: 0.99121 train_loss: 0.02491 test_acc: 0.99455 test_loss: 0.01582


  0%|          | 0/118 [00:00<?, ?it/s]

13/20 train_acc: 0.99432 train_loss: 0.01644 test_acc: 0.99636 test_loss: 0.01138


  0%|          | 0/118 [00:00<?, ?it/s]

14/20 train_acc: 0.99498 train_loss: 0.01587 test_acc: 0.99556 test_loss: 0.01312


  0%|          | 0/118 [00:00<?, ?it/s]

15/20 train_acc: 0.99481 train_loss: 0.01455 test_acc: 0.99316 test_loss: 0.02104


  0%|          | 0/118 [00:00<?, ?it/s]

16/20 train_acc: 0.99377 train_loss: 0.01821 test_acc: 0.991 test_loss: 0.02568


  0%|          | 0/118 [00:00<?, ?it/s]

17/20 train_acc: 0.99371 train_loss: 0.01903 test_acc: 0.99467 test_loss: 0.01682


  0%|          | 0/118 [00:00<?, ?it/s]

18/20 train_acc: 0.99406 train_loss: 0.01879 test_acc: 0.99411 test_loss: 0.01715


  0%|          | 0/118 [00:00<?, ?it/s]

19/20 train_acc: 0.99399 train_loss: 0.01839 test_acc: 0.99586 test_loss: 0.01151


  0%|          | 0/118 [00:00<?, ?it/s]

20/20 train_acc: 0.99524 train_loss: 0.01419 test_acc: 0.99825 test_loss: 0.005


0,1
test_acc,▁▄▅▆▆▇▇▇▇▇▇▇██▇▇▇▇██
test_loss,█▅▄▃▃▂▂▂▂▂▂▂▁▁▂▂▂▂▁▁
train_acc,▁▅▆▇▇▇▇█████████████
train_loss,█▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test_acc,0.99825
test_loss,0.005
train_acc,0.99524
train_loss,0.01419
