<a href="https://colab.research.google.com/github/IMOKURI/wandb-demo/blob/main/wandb_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# WandB 使ってみた

## 🚀 準備

In [1]:
import os

if os.path.exists('init.txt'):
    print("Already initialized.")

else:
    !pip install -q wandb
    !touch init.txt

Already initialized.


In [2]:
import math
import random

import pandas as pd
import wandb

In [3]:
# Mount Google Drive at /content/drive (google.colab is only importable on Colab).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# Copy the .netrc kept on Google Drive into the home directory, then run
# `wandb login`.
# NOTE(review): presumably the .netrc carries the W&B API credentials so the
# login needs no interactive prompt — confirm.
netrc = "/content/drive/MyDrive/.netrc"
!cp -f {netrc} ~/
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mimokuri[0m (use `wandb login --relogin` to force relogin)


## 🚀 パラメータ

In [5]:
class Config:
    """Notebook-wide switches and the W&B destination.

    Values are read directly as class attributes (e.g. ``Config.train``).
    """

    # W&B entity and project that runs and sweeps are created under.
    wandb_project = "demo"
    wandb_entity = "imokuri"

    # When true, the initialization cell calls wandb.init with mode="disabled".
    debug = True

    # Mode flags: each one gates the cells of the matching notebook section.
    tuning = True
    inference = False
    train = False

In [6]:
# Derive the W&B job type from the first enabled mode flag, checked in the
# order train -> inference -> tuning. When no flag is enabled,
# wandb_job_type is left unassigned.
_enabled_job_types = [
    job_type
    for enabled, job_type in (
        (Config.train, "training"),
        (Config.inference, "inference"),
        (Config.tuning, "tuning"),
    )
    if enabled
]
if _enabled_job_types:
    wandb_job_type = _enabled_job_types[0]

In [7]:
# Default configuration passed to wandb.init(config=...), grouped by the
# notebook stage that reads it.
config_defaults = dict(
    general=dict(
        seed=440,
        n_class=1,
    ),
    training=dict(
        n_fold=5,
        epochs=100,
        gradient_accumulation_steps=1,
        max_grad_norm=1000,
        data_loader=dict(
            batch_size=16,
            num_workers=4,
        ),
        model=dict(
            name="",
            dropout=0.1,
        ),
        optimizer=dict(
            name="Adam",
            scheduler="CosineAnnealingWarmRestarts",
            lr=2e-5,
            min_lr=1e-5,
            weight_decay=1e-6,
        ),
        criterion=dict(
            name="BCEWithLogitsLoss",
        ),
        best_model_choice="loss",  # alternative noted by the author: "score"
    ),
    inference=dict(
        # Ids of the runs whose saved sin.csv is downloaded.
        ids=[
            "2ty1mwxc",
        ],
        # Artifact references in "name:version" form.
        artifacts=[
            "sin-artifact:v0",
        ],
    ),
    tuning=dict(
        # Passed as `count` to wandb.agent.
        count=5,
    ),
)

In [8]:
# Sweep definition used for hyperparameter tuning (passed to wandb.sweep).

config_sweep = dict(
    name="sin-sweep",
    method="random",
    # Metric the sweep optimizes; train() logs it under the same key.
    metric=dict(
        name="sum_sin",
        goal="maximize",
    ),
    parameters=dict(
        # Discrete choices.
        epochs=dict(values=[10, 50, 100]),
        # Continuous range given by its bounds.
        dropout=dict(min=0.1, max=0.5),
    ),
)

## 🚀 初期化

In [9]:
# Start the W&B run used by the training / inference sections.
if Config.train or Config.inference:
    # Arguments shared by both the debug and the regular call.
    _init_kwargs = dict(
        entity=Config.wandb_entity,
        project=Config.wandb_project,
        config=config_defaults,
    )
    if Config.debug:
        # Debug: pass mode="disabled" and omit job_type / save_code.
        _init_kwargs["mode"] = "disabled"
    else:
        _init_kwargs.update(job_type=wandb_job_type, save_code=True)

    run = wandb.init(**_init_kwargs)

In [10]:
# Bind wandb.config to the short name `config` used by the cells below.
if Config.train or Config.inference:
    config = wandb.config

In [11]:
# Accessing a config value: nested keys are read with plain indexing.

if Config.train or Config.inference:
    print(config["general"]["seed"])

## 🚀 データ保存 (学習)

In [12]:
# Log numeric values on every step.
# Record a summary value once the loop is done.

if Config.train:
    data = []      # rows of [i * 15 (angle in degrees), sin(angle)]; reused by later cells
    sum_sin = 0    # running total of the sine values

    n_epochs = config["training"]["epochs"]
    for i in range(n_epochs):
        deg = i * 15
        sin = math.sin(math.radians(deg))
        sum_sin += sin
        data.append([deg, sin])
        wandb.log(dict(epoch=i + 1, sin=sin, sum_sin=sum_sin))

    # Stored on the run summary in addition to the per-step history.
    wandb.run.summary["final_sum_sin"] = sum_sin

In [13]:
# Log a chart.
# NOTE: the "rad" column holds the i * 15 values that the logging cell passes
# to math.radians(), i.e. angles in degrees.

if Config.train:
    table = wandb.Table(columns=["rad", "sin"], data=data)
    sin_chart = wandb.plot.line(table, "rad", "sin", title="Sin Graph")
    wandb.log({"sin_graph": sin_chart})

In [14]:
# Save a file: write the collected rows to CSV and hand the file to W&B.

if Config.train:
    csv_name = "sin.csv"

    df = pd.DataFrame(data, columns=["rad", "sin"])
    df.to_csv(csv_name, index=False)

    wandb.save(csv_name)

In [15]:
# Dataset versioning: wrap sin.csv in an artifact of type "dataset" and log it.

if Config.train:
    artifact = wandb.Artifact('sin-artifact', type='dataset')
    artifact.add_file('sin.csv')
    # `run` is the object returned by wandb.init in the initialization cell.
    run.log_artifact(artifact)

## 🚀 データ利用 (推論)

In [16]:
# Download a file that a previous run saved.

if Config.inference:
    api = wandb.Api()

    for run_id in config["inference"]["ids"]:
        # One local directory per source run; exist_ok keeps re-runs safe.
        os.makedirs(run_id, exist_ok=True)
        csv_path = f"{run_id}/sin.csv"

        if os.path.exists(csv_path):
            # Decide "already downloaded" from the local file itself rather
            # than from a caught wandb.CommError, so genuine communication
            # failures are no longer reported as a finished download.
            print(f"Already downloaded. run_id: {run_id}")
        else:
            run_path = f"{Config.wandb_entity}/{Config.wandb_project}/{run_id}"
            # Use a separate name: rebinding `run` here would overwrite the
            # live run handle returned by wandb.init.
            source_run = api.run(run_path)
            source_run.file("sin.csv").download(run_id)

        df = pd.read_csv(csv_path)

In [17]:
# Use a versioned dataset (artifact).

if Config.inference:
    api = wandb.Api()

    for name_version in config["inference"]["artifacts"]:
        # "name:version" -> "name-version" so it can serve as a directory name.
        dir_name = name_version.replace(":", "-")
        # exist_ok replaces the separate exists() check and keeps re-runs safe.
        os.makedirs(dir_name, exist_ok=True)

        artifact_path = f"{Config.wandb_entity}/{Config.wandb_project}/{name_version}"
        artifact = api.artifact(artifact_path)

        artifact.download(dir_name)

        df = pd.read_csv(f"{dir_name}/sin.csv")

## 🚀 パラメータチューニング

In [18]:
# Register the sweep described by config_sweep; the returned id is later
# passed to wandb.agent.
if Config.tuning:
    sweep_id = wandb.sweep(entity=Config.wandb_entity, project=Config.wandb_project, sweep=config_sweep)

Create sweep with ID: r8u8hiwy
Sweep URL: https://wandb.ai/imokuri/demo/sweeps/r8u8hiwy


In [19]:
if Config.tuning:
    def train():
        """Body of one sweep run (passed to ``wandb.agent`` as ``function``).

        Reads ``epochs`` and ``dropout`` from ``wandb.config``. On every step
        a ``random.random()`` draw greater than ``dropout`` adds that step's
        sine value to a running total; the total is logged as ``sum_sin`` on
        every step, whether or not it changed.
        """
        with wandb.init() as run:
            trial_cfg = wandb.config
            n_steps = trial_cfg["epochs"]
            total = 0
            for step in range(n_steps):
                keep = random.random() > trial_cfg["dropout"]
                if keep:
                    total += math.sin(math.radians(step * 15))
                wandb.log(dict(sum_sin=total))

In [20]:
# Run a sweep agent that executes train() for the sweep created above, with
# `count` taken from config_defaults["tuning"]["count"] (the recorded output
# shows 5 runs for count=5).
if Config.tuning:
    wandb.agent(sweep_id, function=train, count=config_defaults["tuning"]["count"])

[34m[1mwandb[0m: Agent Starting Run: i1dse6ms with config:
[34m[1mwandb[0m: 	dropout: 0.2754490497560308
[34m[1mwandb[0m: 	epochs: 50
[34m[1mwandb[0m: Currently logged in as: [33mimokuri[0m (use `wandb login --relogin` to force relogin)


VBox(children=(Label(value=' 0.03MB of 0.03MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
sum_sin,2.4568
_runtime,3.0
_timestamp,1631240080.0
_step,49.0


0,1
sum_sin,▁▁▂▂▃▃▄▄▅▆▆▅▅▄▄▄▄▄▃▃▃▃▃▄▆▆▇█████▇▆▆▅▄▃▃▃
_runtime,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_timestamp,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


[34m[1mwandb[0m: Agent Starting Run: k5h4a9oq with config:
[34m[1mwandb[0m: 	dropout: 0.18485914466822417
[34m[1mwandb[0m: 	epochs: 100


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
sum_sin,-1.00731
_runtime,2.0
_timestamp,1631240088.0
_step,99.0


0,1
sum_sin,▄▄▅▆███▆▃▂▂▃▅▇█▇▆▃▂▁▂▄▅▇▇▆▄▂▁▁▂▃▆▇▆▅▃▂▁▃
_runtime,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_timestamp,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: Agent Starting Run: ds019pgm with config:
[34m[1mwandb[0m: 	dropout: 0.45927096477979135
[34m[1mwandb[0m: 	epochs: 50


VBox(children=(Label(value=' 0.05MB of 0.05MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
sum_sin,-0.38927
_runtime,3.0
_timestamp,1631240097.0
_step,49.0


0,1
sum_sin,▂▂▃▄▅▅▅▆▆▇▇▆▅▅▄▄▄▃▂▂▂▃▃▄▅▆▇▇███▇▆▄▄▃▁▁▁▁
_runtime,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_timestamp,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


[34m[1mwandb[0m: Agent Starting Run: 1kct2ymt with config:
[34m[1mwandb[0m: 	dropout: 0.34213844417708505
[34m[1mwandb[0m: 	epochs: 50


VBox(children=(Label(value=' 0.06MB of 0.06MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
sum_sin,-1.92152
_runtime,2.0
_timestamp,1631240105.0
_step,49.0


0,1
sum_sin,▃▃▄▄▄▅▆▇█████▇▆▅▄▃▃▃▃▃▃▃▅▆▇▇▇▇▇▇▅▅▄▃▂▁▁▁
_runtime,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_timestamp,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


[34m[1mwandb[0m: Agent Starting Run: 3id7nexo with config:
[34m[1mwandb[0m: 	dropout: 0.21813633479254452
[34m[1mwandb[0m: 	epochs: 50


VBox(children=(Label(value=' 0.07MB of 0.07MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
sum_sin,-1.51411
_runtime,3.0
_timestamp,1631240115.0
_step,49.0


0,1
sum_sin,▃▃▃▄▅▆▆▇█████▇▆▅▃▂▂▂▂▂▂▃▅▅▆▇▇▇▇▇▅▄▃▂▂▁▁▁
_runtime,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_timestamp,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


## 🚀 終了

In [21]:
# Close out the W&B run via wandb.finish().
wandb.finish()