## wandb.sweep
低代码，可视化，分布式自动调参工具

- 低代码：只需要配置sweep.yaml配置文件，或者定义一个配置dict，几乎不用编写调参相关代码
- 可视化：wandb可以实时监控调参过程中每次尝试，并可视化地分析调参任务的目标值分布，超参重要性等
- 分布式：sweep采用类似master-worker的controller-agent架构，controller在wandb的服务器机器上运行，agents在用户机器上运行，controller和agents之间通过互联网进行通信。同时启动多个agents即可轻松是想拿分布式超参搜索

## sweep使用步骤
1. 配置sweep_cofig
    配置调优算法，调优目标，需要优化的超参数列表等
2. 初始化sweep controller
    sweep_id = wandb.sweep(sweep_config)
3. 启动sweep agents
    wandb.agent(sweep_id, function=train)

In [2]:
import os,PIL 
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch 
from torch import nn 
import torchvision 
from torchvision import transforms
import datetime
import wandb 

wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33m315680524[0m ([33m550w[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
from argparse import Namespace

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#初始化参数配置
config = Namespace(
    project_name = 'wandb_demo',
    
    batch_size = 512,
    
    hidden_layer_width = 64,
    dropout_p = 0.1,
    
    lr = 1e-4,
    optim_type = 'Adam',
    
    epochs = 15,
    ckpt_path = 'checkpoint.pt'
)

# 官方的写法
# config = dict(
#     project_name = 'wandb_demo',
    
#     batch_size = 512,
    
#     hidden_layer_width = 64,
#     dropout_p = 0.1,
    
#     lr = 1e-4,
#     optim_type = 'Adam',
    
#     epochs = 15,
#     ckpt_path = 'checkpoint.pt'
# )

## 一、配置sweep config
详细配置文档可以参考：https://docs.wandb.ai/guides/sweeps/define-sweep-configuration

1，选择一个调优算法
Sweep支持如下3种调优算法:

(1)网格搜索：grid. 遍历所有可能得超参组合，只在超参空间不大的时候使用，否则会非常慢。

(2)随机搜索：random. 每个超参数都选择一个随机值，非常有效，一般情况下建议使用。

(3)贝叶斯搜索：bayes. 创建一个概率模型估计不同超参数组合的效果，采样有更高概率提升优化目标的超参数组合。对连续型的超参数特别有效，但扩展到非常高维度的超参数时效果不好。

In [4]:
sweep_config = {
    'method': 'random'
    }

2，定义调优目标  
设置优化指标，以及优化方向。

sweep agents 通过 wandb.log 的形式向 sweep controller 传递优化目标的值。

In [5]:
metric = {
    'name': 'val_acc',
    'goal': 'maximize'   
    }
sweep_config['metric'] = metric

3，定义超参空间  
超参空间可以分成 固定型，离散型和连续型。

- 固定型：指定 value
- 离散型：指定 values，列出全部候选取值。
- 连续性：需要指定 分布类型 distribution, 和范围 min, max。用于 random 或者 bayes采样。

In [6]:
sweep_config['parameters'] = {}

# 固定不变的超参
sweep_config['parameters'].update({
    'project_name':{'value':'wandb_demo'},
    'epochs': {'value': 10},
    'ckpt_path': {'value':'checkpoint.pt'}})

# 离散型分布超参
sweep_config['parameters'].update({
    'optim_type': {
        'values': ['Adam', 'SGD','AdamW']
        },
    'hidden_layer_width': {
        'values': [16,32,48,64,80,96,112,128]
        }
    })

# 连续型分布超参
sweep_config['parameters'].update({
    
    'lr': {
        'distribution': 'log_uniform_values',
        'min': 1e-6,
        'max': 0.1
      },
    
    'batch_size': {
        'distribution': 'q_uniform',
        'q': 8,
        'min': 32,
        'max': 256,
      },
    
    'dropout_p': {
        'distribution': 'uniform',
        'min': 0,
        'max': 0.6,
      }
})

4，定义剪枝策略 (可选)  
可以定义剪枝策略，提前终止那些没有希望的任务。

In [7]:
sweep_config['early_terminate'] = {
    'type':'hyperband',
    'min_iter':3,
    'eta':2,
    's':3
} #在step=3, 6, 12 时考虑是否剪枝

### 查看所有的sweep config设置

In [8]:
from pprint import pprint
pprint(sweep_config)

{'early_terminate': {'eta': 2, 'min_iter': 3, 's': 3, 'type': 'hyperband'},
 'method': 'random',
 'metric': {'goal': 'maximize', 'name': 'val_acc'},
 'parameters': {'batch_size': {'distribution': 'q_uniform',
                               'max': 256,
                               'min': 32,
                               'q': 8},
                'ckpt_path': {'value': 'checkpoint.pt'},
                'dropout_p': {'distribution': 'uniform', 'max': 0.6, 'min': 0},
                'epochs': {'value': 10},
                'hidden_layer_width': {'values': [16,
                                                  32,
                                                  48,
                                                  64,
                                                  80,
                                                  96,
                                                  112,
                                                  128]},
                'lr': {'distribution': 'log_uniform_

## 二、 初始化 sweep controller

In [9]:
sweep_id = wandb.sweep(sweep_config, project=config.project_name)

# 使用官方创建config的方法
# config = wandb.config
# sweep_id = wandb.sweep(sweep_config, project=config.project_name)

Create sweep with ID: 27niqr67
Sweep URL: https://wandb.ai/550w/wandb_demo/sweeps/27niqr67


## 三、启动 sweep agents

In [10]:
def create_dataloaders(config):
    transform = transforms.Compose([transforms.ToTensor()])
    ds_train = torchvision.datasets.MNIST(root="./mnist/",train=True,download=True,transform=transform)
    ds_val = torchvision.datasets.MNIST(root="./mnist/",train=False,download=True,transform=transform)

    ds_train_sub = torch.utils.data.Subset(ds_train, indices=range(0, len(ds_train), 5))
    dl_train =  torch.utils.data.DataLoader(ds_train_sub, batch_size=config.batch_size, shuffle=True,
                                            num_workers=2,drop_last=True)
    dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=config.batch_size, shuffle=False, 
                                          num_workers=2,drop_last=True)
    return dl_train,dl_val

In [11]:
def create_net(config):
    net = nn.Sequential()
    net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=config.hidden_layer_width,kernel_size = 3))
    net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2)) 
    net.add_module("conv2",nn.Conv2d(in_channels=config.hidden_layer_width,
                                     out_channels=config.hidden_layer_width,kernel_size = 5))
    net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
    net.add_module("dropout",nn.Dropout2d(p = config.dropout_p))
    net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
    net.add_module("flatten",nn.Flatten())
    net.add_module("linear1",nn.Linear(config.hidden_layer_width,config.hidden_layer_width))
    net.add_module("relu",nn.ReLU())
    net.add_module("linear2",nn.Linear(config.hidden_layer_width,10))
    return net 

In [12]:
def train_epoch(model,dl_train,optimizer):
    model.train()
    for step, batch in enumerate(dl_train):
        features,labels = batch
        features,labels = features.to(device),labels.to(device)

        preds = model(features)
        loss = nn.CrossEntropyLoss()(preds,labels)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
    return model

In [13]:
def eval_epoch(model,dl_val):
    model.eval()
    accurate = 0
    num_elems = 0
    for batch in dl_val:
        features,labels = batch
        features,labels = features.to(device),labels.to(device)
        with torch.no_grad():
            preds = model(features)
        predictions = preds.argmax(dim=-1)
        accurate_preds =  (predictions==labels)
        num_elems += accurate_preds.shape[0]
        accurate += accurate_preds.long().sum()

    val_acc = accurate.item() / num_elems
    return val_acc

In [14]:
def train(config = config):
    dl_train, dl_val = create_dataloaders(config)
    model = create_net(config); 
    optimizer = torch.optim.__dict__[config.optim_type](params=model.parameters(), lr=config.lr)
    #======================================================================
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    wandb.init(project=config.project_name, config = config.__dict__, name = nowtime, save_code=True)
    model.run_id = wandb.run.id
    #======================================================================
    model.best_metric = -1.0
    for epoch in range(1,config.epochs+1):
        model = train_epoch(model,dl_train,optimizer)
        val_acc = eval_epoch(model,dl_val)
        if val_acc>model.best_metric:
            model.best_metric = val_acc
            torch.save(model.state_dict(),config.ckpt_path)   
        nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print(f"epoch【{epoch}】@{nowtime} --> val_acc= {100 * val_acc:.2f}%")
        #======================================================================
        wandb.log({'epoch':epoch, 'val_acc': val_acc, 'best_val_acc':model.best_metric})
        #======================================================================        
    #======================================================================
    wandb.finish()
    #======================================================================
    return model

# model = train(config)

In [15]:
# 该agent 随机搜索 尝试5次
wandb.agent(sweep_id, train, count=5)

[34m[1mwandb[0m: Agent Starting Run: jt9k32xx with config:
[34m[1mwandb[0m: 	batch_size: 112
[34m[1mwandb[0m: 	ckpt_path: checkpoint.pt
[34m[1mwandb[0m: 	dropout_p: 0.5982875694637337
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	hidden_layer_width: 112
[34m[1mwandb[0m: 	lr: 0.02367954862131594
[34m[1mwandb[0m: 	optim_type: Adam
[34m[1mwandb[0m: 	project_name: wandb_demo


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw/train-images-idx3-ubyte.gz


35.6%
Run jt9k32xx errored: RuntimeError('File not found or corrupted.')
[34m[1mwandb[0m: [32m[41mERROR[0m Run jt9k32xx errored: RuntimeError('File not found or corrupted.')





[34m[1mwandb[0m: Agent Starting Run: egevehgv with config:
[34m[1mwandb[0m: 	batch_size: 160
[34m[1mwandb[0m: 	ckpt_path: checkpoint.pt
[34m[1mwandb[0m: 	dropout_p: 0.2886224267737977
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	hidden_layer_width: 64
[34m[1mwandb[0m: 	lr: 0.0006043509385761697
[34m[1mwandb[0m: 	optim_type: Adam
[34m[1mwandb[0m: 	project_name: wandb_demo


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting ./mnist/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100.0%


Extracting ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


74.2%
Run egevehgv errored: RuntimeError('File not found or corrupted.')
[34m[1mwandb[0m: [32m[41mERROR[0m Run egevehgv errored: RuntimeError('File not found or corrupted.')





[34m[1mwandb[0m: Agent Starting Run: 7dpkz1tu with config:
[34m[1mwandb[0m: 	batch_size: 248
[34m[1mwandb[0m: 	ckpt_path: checkpoint.pt
[34m[1mwandb[0m: 	dropout_p: 0.2976909215017342
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	hidden_layer_width: 64
[34m[1mwandb[0m: 	lr: 0.039895343568556715
[34m[1mwandb[0m: 	optim_type: AdamW
[34m[1mwandb[0m: 	project_name: wandb_demo


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Using downloaded and verified file: ./mnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./mnist/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Using downloaded and verified file: ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Using downloaded and verified file: ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100.0%
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Extracting ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw





epoch【1】@2024-01-16 19:26:00 --> val_acc= 50.57%
epoch【2】@2024-01-16 19:26:10 --> val_acc= 56.22%
epoch【3】@2024-01-16 19:26:20 --> val_acc= 61.74%
epoch【4】@2024-01-16 19:26:31 --> val_acc= 64.98%
epoch【5】@2024-01-16 19:26:41 --> val_acc= 69.85%
epoch【6】@2024-01-16 19:26:52 --> val_acc= 74.69%
epoch【7】@2024-01-16 19:27:03 --> val_acc= 79.58%
epoch【8】@2024-01-16 19:27:14 --> val_acc= 83.91%
epoch【9】@2024-01-16 19:27:25 --> val_acc= 86.28%
epoch【10】@2024-01-16 19:27:36 --> val_acc= 88.00%
epoch【11】@2024-01-16 19:27:47 --> val_acc= 88.81%
epoch【12】@2024-01-16 19:27:58 --> val_acc= 89.40%
epoch【13】@2024-01-16 19:28:08 --> val_acc= 89.99%
epoch【14】@2024-01-16 19:28:19 --> val_acc= 91.00%
epoch【15】@2024-01-16 19:28:29 --> val_acc= 91.54%


0,1
best_val_acc,▁▂▃▃▄▅▆▇▇▇█████
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
val_acc,▁▂▃▃▄▅▆▇▇▇█████

0,1
best_val_acc,0.9154
epoch,15.0
val_acc,0.9154


[34m[1mwandb[0m: Agent Starting Run: qrkm5zfy with config:
[34m[1mwandb[0m: 	batch_size: 136
[34m[1mwandb[0m: 	ckpt_path: checkpoint.pt
[34m[1mwandb[0m: 	dropout_p: 0.2670600216380557
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	hidden_layer_width: 64
[34m[1mwandb[0m: 	lr: 0.07449015439320002
[34m[1mwandb[0m: 	optim_type: Adam
[34m[1mwandb[0m: 	project_name: wandb_demo
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


epoch【1】@2024-01-16 19:29:25 --> val_acc= 23.72%
epoch【2】@2024-01-16 19:29:36 --> val_acc= 32.92%
epoch【3】@2024-01-16 19:29:47 --> val_acc= 52.07%
epoch【4】@2024-01-16 19:29:57 --> val_acc= 62.72%
epoch【5】@2024-01-16 19:30:08 --> val_acc= 70.96%
epoch【6】@2024-01-16 19:30:18 --> val_acc= 77.38%
epoch【7】@2024-01-16 19:30:29 --> val_acc= 77.82%
epoch【8】@2024-01-16 19:30:39 --> val_acc= 83.41%
epoch【9】@2024-01-16 19:30:50 --> val_acc= 84.46%
epoch【10】@2024-01-16 19:31:00 --> val_acc= 86.41%
epoch【11】@2024-01-16 19:31:11 --> val_acc= 88.22%
epoch【12】@2024-01-16 19:31:22 --> val_acc= 89.06%
epoch【13】@2024-01-16 19:31:32 --> val_acc= 90.06%
epoch【14】@2024-01-16 19:31:43 --> val_acc= 90.91%
epoch【15】@2024-01-16 19:31:54 --> val_acc= 91.55%


0,1
best_val_acc,▁▂▄▅▆▇▇▇▇▇█████
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
val_acc,▁▂▄▅▆▇▇▇▇▇█████

0,1
best_val_acc,0.9155
epoch,15.0
val_acc,0.9155


[34m[1mwandb[0m: Agent Starting Run: cejsp6h9 with config:
[34m[1mwandb[0m: 	batch_size: 248
[34m[1mwandb[0m: 	ckpt_path: checkpoint.pt
[34m[1mwandb[0m: 	dropout_p: 0.354146391818585
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	hidden_layer_width: 16
[34m[1mwandb[0m: 	lr: 0.0009038198026843288
[34m[1mwandb[0m: 	optim_type: Adam
[34m[1mwandb[0m: 	project_name: wandb_demo
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


epoch【1】@2024-01-16 19:32:47 --> val_acc= 36.19%
epoch【2】@2024-01-16 19:32:57 --> val_acc= 48.47%
epoch【3】@2024-01-16 19:33:07 --> val_acc= 53.59%
epoch【4】@2024-01-16 19:33:18 --> val_acc= 56.99%
epoch【5】@2024-01-16 19:33:28 --> val_acc= 62.71%
epoch【6】@2024-01-16 19:33:39 --> val_acc= 73.01%
epoch【7】@2024-01-16 19:33:49 --> val_acc= 77.13%
epoch【8】@2024-01-16 19:33:59 --> val_acc= 81.52%
epoch【9】@2024-01-16 19:34:09 --> val_acc= 84.73%
epoch【10】@2024-01-16 19:34:19 --> val_acc= 86.52%
epoch【11】@2024-01-16 19:34:30 --> val_acc= 87.71%
epoch【12】@2024-01-16 19:34:40 --> val_acc= 88.85%
epoch【13】@2024-01-16 19:34:51 --> val_acc= 89.36%
epoch【14】@2024-01-16 19:35:01 --> val_acc= 90.07%
epoch【15】@2024-01-16 19:35:12 --> val_acc= 90.60%


0,1
best_val_acc,▁▃▃▄▄▆▆▇▇▇█████
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
val_acc,▁▃▃▄▄▆▆▇▇▇█████

0,1
best_val_acc,0.90604
epoch,15.0
val_acc,0.90604


## 四，调参可视化和跟踪
1. 平行坐标系图
2. 超参重要性图