# Weights & Biases (wandb) Tutorial

## WandBのインストールとセットアップ

wandbのインストール

In [None]:
!pip install wandb

wandbへのログイン（初回は https://wandb.ai/authorize で取得できるAPIキーの入力を求められます）

In [None]:
!wandb login

## 実際に使ってみる

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

### モデルの作成とデータセットのロード

In [2]:
# Simple MLP for MNIST classification.
class Net1(nn.Module):
    """MLP classifier with a single ReLU hidden layer.

    The defaults reproduce the MNIST setup used in this notebook:
    28*28 flattened pixels -> 100 hidden units -> 10 class logits.

    Args:
        in_features: number of input features after flattening each sample.
        hidden_features: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(
        self,
        in_features: int = 28 * 28,
        hidden_features: int = 100,
        num_classes: int = 10,
    ):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits with shape (N, num_classes) for input ``x``."""
        # reshape (unlike view) also accepts non-contiguous tensors; the leading
        # dimension is inferred so each row is one flattened sample.
        x = x.reshape(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
    
# Deeper model than Net1: two hidden layers instead of one.
class Net2(nn.Module):
    """MLP classifier with two ReLU hidden layers.

    The defaults reproduce the MNIST setup used in this notebook:
    28*28 flattened pixels -> 100 -> 100 -> 10 class logits.

    Args:
        in_features: number of input features after flattening each sample.
        hidden_features: width of both hidden layers.
        num_classes: number of output logits.
    """

    def __init__(
        self,
        in_features: int = 28 * 28,
        hidden_features: int = 100,
        num_classes: int = 10,
    ):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc3 = nn.Linear(hidden_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits with shape (N, num_classes) for input ``x``."""
        # reshape (unlike view) also accepts non-contiguous tensors; the leading
        # dimension is inferred so each row is one flattened sample.
        x = x.reshape(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [3]:
import torchvision.datasets as datasets

# Directory where the MNIST files are stored / downloaded.
DATA_ROOT = './data'


def _load_mnist(train):
    """Return the MNIST train (train=True) or test (train=False) split as tensors."""
    return datasets.MNIST(
        root=DATA_ROOT,
        train=train,
        download=True,
        transform=transforms.ToTensor(),
    )


# Fetch the training split first, then the test split.
trainset = _load_mnist(train=True)
testset = _load_mnist(train=False)

In [4]:
# Build the DataLoaders; only the training split is shuffled.
_loader_kwargs = dict(batch_size=32, num_workers=2)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_kwargs)

## 学習

### wandbを使わずに学習する場合

In [5]:
# Training loop (without wandb).
def train(net, trainloader, optimizer, criterion, epochs):
    """Train ``net`` for ``epochs`` passes over ``trainloader``.

    After each epoch the mean of that epoch's per-batch losses is printed.

    Args:
        net: model to optimize; it is put into training mode first.
        trainloader: sized iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer over ``net``'s parameters.
        criterion: loss function called as ``criterion(outputs, labels)``.
        epochs: number of passes over ``trainloader``.

    Returns:
        list[float]: the mean per-batch loss of each epoch, in order.

    Raises:
        ValueError: if ``epochs > 0`` but ``trainloader`` has no batches
            (the per-epoch average would be undefined).
    """
    num_batches = len(trainloader)
    if epochs > 0 and num_batches == 0:
        raise ValueError("trainloader is empty; cannot compute an epoch loss")

    net.train()  # make sure training-mode layers (dropout, batch norm, ...) are active
    history = []
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / num_batches
        print(f"Epoch {epoch+1} loss: {epoch_loss}")
        history.append(epoch_loss)
    return history

In [None]:
# Set up Net1 with Adam and cross-entropy, then train it for 10 epochs (no wandb).
net1 = Net1()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net1.parameters(), lr=0.001)

config = dict(
    net=net1,
    trainloader=trainloader,
    optimizer=optimizer,
    criterion=criterion,
    epochs=10,
)

train(**config)

### wandbを使って学習する場合

In [6]:
import wandb
from wandb import AlertLevel

In [13]:
# Training loop that also reports to wandb.
def train_with_wandb(net, trainloader, optimizer, criterion, epochs, log_freq=1):
    """Train ``net`` like ``train`` and log progress to the active wandb run.

    Must be called inside an active ``wandb.init(...)`` run. The whole model is
    passed to ``wandb.watch`` so parameters and gradients of every layer are
    recorded (``log="all"``), not just those of the first layer.

    Args:
        net: model to optimize; it is put into training mode first.
        trainloader: sized iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer over ``net``'s parameters.
        criterion: loss function called as ``criterion(outputs, labels)``.
        epochs: number of passes over ``trainloader``.
        log_freq: forwarded to ``wandb.watch``; how many batches between
            parameter/gradient logs.

    Returns:
        list[float]: the mean per-batch loss of each epoch, in order.

    Raises:
        ValueError: if ``epochs > 0`` but ``trainloader`` has no batches.
    """
    num_batches = len(trainloader)
    if epochs > 0 and num_batches == 0:
        raise ValueError("trainloader is empty; cannot compute an epoch loss")

    wandb.watch(net, criterion, log="all", log_freq=log_freq)

    net.train()  # make sure training-mode layers (dropout, batch norm, ...) are active
    history = []
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / num_batches
        print(f"Epoch {epoch+1} loss: {epoch_loss}")
        # Log the same 1-based epoch number that is printed above.
        wandb.log({"loss": epoch_loss, "epoch": epoch + 1})
        history.append(epoch_loss)
    return history

In [None]:
# Set up Net1 with Adam and cross-entropy and train it under a wandb run.
learning_rate = 0.001
num_epochs = 5

net1 = Net1()
optimizer = optim.Adam(net1.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Arguments for train_with_wandb (live objects, not hyperparameters).
config_dict = {
    'net': net1,
    'trainloader': trainloader,
    'optimizer': optimizer,
    'criterion': criterion,
    'epochs': num_epochs,
}

# Only plain hyperparameter values go into wandb.config, so runs can be
# compared and filtered in the dashboard; the model/loader/optimizer objects
# above are not meaningful config values.
wandb_config = {
    'model': type(net1).__name__,
    'optimizer': type(optimizer).__name__,
    'lr': learning_rate,
    'batch_size': trainloader.batch_size,
    'epochs': num_epochs,
}

with wandb.init(project='wandb_tutorial', group='tutorial', name='net1', config=wandb_config):

    train_with_wandb(**config_dict)

    # Send an alert when training finishes.
    # NOTE: replace `slack_id` with a real Slack member ID; Slack mentions use
    # the form <@MEMBER_ID> without spaces.
    wandb.alert(
        title='wandb_tutorial',
        text='<@ slack_id > net1の学習が終わりました！',
        level=AlertLevel.INFO
    )

In [None]:
# Set up Net2 with Adam and cross-entropy and train it under its own wandb run.
learning_rate = 0.001
num_epochs = 10

net2 = Net2()
optimizer = optim.Adam(net2.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Arguments for train_with_wandb (live objects, not hyperparameters).
config = {
    'net': net2,
    'trainloader': trainloader,
    'optimizer': optimizer,
    'criterion': criterion,
    'epochs': num_epochs,
}

# Only plain hyperparameter values go into wandb.config, so this run can be
# compared against the Net1 run in the dashboard.
wandb_config = {
    'model': type(net2).__name__,
    'optimizer': type(optimizer).__name__,
    'lr': learning_rate,
    'batch_size': trainloader.batch_size,
    'epochs': num_epochs,
}

# A distinct run name keeps this run distinguishable from the Net1 run in the
# same project/group.
with wandb.init(project='wandb_tutorial', group='tutorial', config=wandb_config, name='net2'):
    train_with_wandb(**config)
    # Send an alert when training finishes.
    # NOTE: replace `slack_id` with a real Slack member ID; Slack mentions use
    # the form <@MEMBER_ID> without spaces.
    wandb.alert(
        title='wandb_tutorial',
        text='<@ slack_id > net2の学習が終わりました！',
        level=AlertLevel.INFO
    )