PyTorch 模型训练的 [demo](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pytorch/Simple_PyTorch_Integration.ipynb)

In [2]:
import wandb
import hydra
import os
from omegaconf import DictConfig, OmegaConf
from hydra import initialize, compose

# Loading the config this way cannot parse command-line arguments, and the
# multi-run feature is not available either.
# Use @hydra.main to enable Hydra.
def load_config(config_path:str, config_name:str):
    """Compose and return the Hydra config named ``config_name``.

    The config is composed inside an ``initialize`` context created with
    ``config_path`` and ``version_base="1.1"``; the context is left before
    the composed config is handed back.

    Args:
        config_path: Directory passed to ``initialize`` as ``config_path``.
        config_name: Name passed to ``compose`` as ``config_name``.

    Returns:
        Whatever ``compose`` produced for ``config_name``.
    """
    with initialize(config_path=config_path, version_base="1.1"):
        composed = compose(config_name=config_name)
    return composed

# Load the "default" config from the "conf" directory and print it as YAML
# with interpolations resolved.
config = load_config(config_path="conf", config_name="default")
print(OmegaConf.to_yaml(config, resolve=True))

project_name: wandb_demo
dataset:
  data_root: data
  ori_path: data/cifar10
  train_path: data/cifar10/train
  test_path: data/cifar10/test
  val_path: data/cifar10/val
  name: cifar10
  epochs: 10
  batch_size: 32
  num_classes: 10
model:
  name: resnet18
optimizer:
  name: adam
  lr: 0.005
  beta1: 0.9
  beta2: 0.999
  weight_decay: 0.0



In [3]:
import torchvision
import torch.nn as nn

def build_model(config:DictConfig):
    """Build the torchvision ResNet selected by ``config.model.name``.

    The network is created by calling the matching ``torchvision.models``
    constructor without arguments; its final fully connected layer ``fc`` is
    then replaced by a fresh ``nn.Linear`` with ``config.dataset.num_classes``
    outputs.

    Args:
        config: Configuration exposing ``model.name`` and
            ``dataset.num_classes``.

    Returns:
        The constructed model with the replaced ``fc`` layer.

    Raises:
        NotImplementedError: If ``config.model.name`` is neither
            ``"resnet18"`` nor ``"resnet34"``.
    """
    supported = ("resnet18", "resnet34")
    arch = config.model.name
    if arch not in supported:
        # Name the rejected value so a typo in the config is easy to spot.
        raise NotImplementedError(
            f"Unsupported model: {arch!r} (supported: {', '.join(supported)})"
        )

    # Every supported name is also the name of its torchvision constructor.
    model = getattr(torchvision.models, arch)()

    # Swap the classification head so the output size matches the dataset.
    model.fc = nn.Linear(model.fc.in_features, config.dataset.num_classes)

    return model

# Instantiate the network described by the config and print its layer layout.
model = build_model(config)
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets

def build_dataloader(config: DictConfig):
    """Create the training and validation loaders for CIFAR-10 / CIFAR-100.

    Both datasets are rooted at ``config.dataset.ori_path`` and requested with
    ``download=True``. The training pipeline applies a random 32-pixel crop
    with padding 4 and a random horizontal flip before tensor conversion and
    per-channel normalization; the validation pipeline only converts and
    normalizes. The validation loader is built from the ``train=False`` split.

    Args:
        config: Configuration exposing ``dataset.name``, ``dataset.ori_path``
            and ``dataset.batch_size``.

    Returns:
        A ``(train_loader, val_loader)`` tuple. Both loaders use the configured
        batch size, 4 workers and pinned memory; only the training loader
        shuffles.

    Raises:
        ValueError: If ``config.dataset.name`` is neither ``"cifar10"`` nor
            ``"cifar100"``.
    """
    dataset_name = config.dataset.name
    if dataset_name not in ["cifar10", "cifar100"]:
        raise ValueError(f"Unsupported dataset: {config.dataset.name}")

    # Pick the normalization statistics and the dataset class in one place.
    if dataset_name == "cifar100":
        channel_mean = [0.5071, 0.4867, 0.4408]
        channel_std = [0.2675, 0.2565, 0.2761]
    else:
        channel_mean = [0.4914, 0.4822, 0.4465]
        channel_std = [0.2023, 0.1994, 0.2010]

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])

    if dataset_name == "cifar100":
        DatasetClass = datasets.CIFAR100
    else:
        DatasetClass = datasets.CIFAR10

    data_root = config.dataset.ori_path
    train_dataset = DatasetClass(root=data_root, train=True, download=True,
                                 transform=transform_train)
    val_dataset = DatasetClass(root=data_root, train=False, download=True,
                               transform=transform_test)

    # Settings shared by both loaders; only `shuffle` differs between them.
    shared_kwargs = dict(
        batch_size=config.dataset.batch_size,
        num_workers=4,
        pin_memory=True,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_kwargs)

    return train_loader, val_loader

# Build the training and validation loaders from the loaded config.
train_loader, val_loader = build_dataloader(config)

In [5]:
import torch

def build_optimizer(config:DictConfig, params):
    """Create the optimizer described by ``config.optimizer`` for ``params``.

    Args:
        config: Configuration exposing ``optimizer.name`` and, for Adam,
            ``optimizer.lr``, ``optimizer.beta1``, ``optimizer.beta2`` and
            ``optimizer.weight_decay``.
        params: Parameters to optimize; forwarded unchanged as ``params``.

    Returns:
        A ``torch.optim.Adam`` instance when ``optimizer.name`` is ``"adam"``.

    Raises:
        NotImplementedError: For any other optimizer name.
    """
    opt_cfg = config.optimizer
    if opt_cfg.name != "adam":
        # Name the rejected value so a typo in the config is easy to spot.
        raise NotImplementedError(
            f"Unsupported optimizer: {opt_cfg.name!r} (supported: adam)"
        )

    return torch.optim.Adam(
        params=params,
        lr=opt_cfg.lr,
        betas=(opt_cfg.beta1, opt_cfg.beta2),
        weight_decay=opt_cfg.weight_decay,
    )

# Create the optimizer over all model parameters and print its settings.
optimizer = build_optimizer(config, model.parameters())
print(optimizer)

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    decoupled_weight_decay: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.005
    maximize: False
    weight_decay: 0.0
)


In [6]:
import wandb
# Start the wandb run. The run name is "<model>-<dataset>-<optimizer>" and the
# config is attached after conversion with OmegaConf.to_container(resolve=True).
# NOTE(review): the project is hard-coded to "torch" although the loaded config
# also defines `project_name` — confirm which of the two is intended.
run = wandb.init(
    project="torch",
    name=f"{config.model.name}-{config.dataset.name}-{config.optimizer.name}",
    config=OmegaConf.to_container(config, resolve=True),
    tags=["baseline"],
    group="test", 
)

[34m[1mwandb[0m: Currently logged in as: [33mrem1[0m ([33mrem1-opera[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
from torch.nn import functional as F

def evaluate(model, loader, device):
    """Compute the mean cross-entropy loss and the macro AUC of ``model`` on ``loader``.

    The model is switched to eval mode (and left in it) and every batch is run
    under ``torch.no_grad()``. The loss is averaged over the samples the loader
    actually yielded, so the result stays correct when the loader skips part
    of its dataset (for example with ``drop_last=True``).

    Args:
        model: Module mapping a batch of inputs to per-class logits.
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(avg_loss, auc)`` where ``auc`` is ``float('nan')`` when
        ``roc_auc_score`` raises ``ValueError``.
    """
    model.eval()
    total_loss = 0.0
    num_examples = 0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)

            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)
            # cross_entropy returns the batch mean; weight it by the batch size
            # so batches of different sizes contribute proportionally.
            total_loss += loss.item() * batch_size
            num_examples += batch_size

            probs = F.softmax(outputs, dim=1)
            all_preds.append(probs.cpu())
            all_labels.append(targets.cpu())

    # Concatenate the per-batch results.
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Multi-class AUC over the whole loader (one-hot labels vs. probabilities).
    try:
        auc = roc_auc_score(
            F.one_hot(all_labels, num_classes=all_preds.size(1)).numpy(),
            all_preds.numpy(),
            average="macro",
            multi_class="ovr"
        )
    except ValueError:
        # Presumably raised when the collected labels do not cover every
        # class; report NaN instead of aborting the evaluation.
        auc = float('nan')

    # Divide by the samples actually seen rather than len(loader.dataset).
    avg_loss = total_loss / num_examples
    return avg_loss, auc

def train(config, model, train_loader, val_loader, optimizer):
    """Train ``model`` for ``config.dataset.epochs`` epochs and log to wandb.

    The model is moved to CUDA when available (otherwise CPU) and registered
    with ``wandb.watch``. Within an epoch the loss of the current batch is
    logged as ``batch_loss`` on every 25th batch (the batch counter restarts
    each epoch). After each epoch the training loss and AUC, together with the
    validation loss and AUC returned by ``evaluate``, are printed and logged
    as ``train_loss``, ``val_loss``, ``train_auc`` and ``val_auc``.

    The training AUC is computed from the softmax outputs collected during the
    epoch, i.e. while the weights were still being updated. The training loss
    is averaged over the samples the loader actually yielded.

    Args:
        config: Configuration exposing ``dataset.epochs``.
        model: Module mapping a batch of inputs to per-class logits.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Loader handed to ``evaluate`` after every epoch.
        optimizer: Optimizer already bound to the model's parameters.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    wandb.watch(model, log="all", log_freq=10)

    epochs = config.dataset.epochs
    log_every = 25  # number of batches between two "batch_loss" entries

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        all_preds = []
        all_labels = []
        batch_ct = 0
        example_ct = 0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)
            loss.backward()
            optimizer.step()

            # Read the scalar once instead of calling .item() repeatedly.
            loss_value = loss.item()
            batch_size = inputs.size(0)
            total_loss += loss_value * batch_size

            # Detach before softmax: the probabilities are only used for metrics.
            probs = F.softmax(outputs.detach(), dim=1)
            all_preds.append(probs.cpu())
            all_labels.append(targets.detach().cpu())
            example_ct += batch_size
            batch_ct += 1

            # batch_ct already counts this batch, so this fires on batches
            # 25, 50, ... (the previous `(batch_ct + 1) % 25` fired one early).
            if batch_ct % log_every == 0:
                wandb.log({"batch_loss": loss_value})

        # Concatenate the collected predictions and labels.
        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)

        # Training AUC.
        try:
            train_auc = roc_auc_score(
                F.one_hot(all_labels, num_classes=all_preds.size(1)).numpy(),
                all_preds.numpy(),
                average="macro",
                multi_class="ovr"
            )
        except ValueError:
            # Report NaN instead of aborting when the AUC cannot be computed.
            train_auc = float('nan')

        # Average over the samples actually processed in this epoch.
        train_loss = total_loss / example_ct

        # Validation set.
        val_loss, val_auc = evaluate(model, val_loader, device)

        # Print and log the epoch summary.
        print(f"[Epoch {epoch}] "
              f"Train Loss: {train_loss:.4f} | Train AUC: {train_auc:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val AUC: {val_auc:.4f}")
        wandb.log({"train_loss":train_loss, "val_loss": val_loss, "train_auc": train_auc, "val_auc": val_auc})
        
# Run the training loop with the config, model, loaders and optimizer built above.
train(config, model, train_loader, val_loader, optimizer)



[Epoch 1] Train Loss: 1.8016 | Train AUC: 0.8032 | Val Loss: 1.4777 | Val AUC: 0.8764




[Epoch 2] Train Loss: 1.3797 | Train AUC: 0.8864 | Val Loss: 1.1897 | Val AUC: 0.9209




[Epoch 3] Train Loss: 1.1661 | Train AUC: 0.9184 | Val Loss: 1.0231 | Val AUC: 0.9392




[Epoch 4] Train Loss: 1.0290 | Train AUC: 0.9361 | Val Loss: 0.9579 | Val AUC: 0.9463




[Epoch 5] Train Loss: 0.9348 | Train AUC: 0.9466 | Val Loss: 0.8737 | Val AUC: 0.9573




[Epoch 6] Train Loss: 0.8668 | Train AUC: 0.9538 | Val Loss: 0.7569 | Val AUC: 0.9650




[Epoch 7] Train Loss: 0.8044 | Train AUC: 0.9600 | Val Loss: 0.8389 | Val AUC: 0.9609




[Epoch 8] Train Loss: 0.7647 | Train AUC: 0.9636 | Val Loss: 0.7101 | Val AUC: 0.9698




[Epoch 9] Train Loss: 0.7227 | Train AUC: 0.9672 | Val Loss: 0.6503 | Val AUC: 0.9736




[Epoch 10] Train Loss: 0.6861 | Train AUC: 0.9703 | Val Loss: 0.6602 | Val AUC: 0.9735


In [8]:
# Finish the wandb run that was created by wandb.init.
run.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
batch_loss,█▇▆▇▆▅▆█▆▄▄▃▅▄▃▄▂▄▃▄▂▃▄▃▁▂▃▃▃▂▂▂▃▁▁▁▃▃▁▄
train_auc,▁▄▆▇▇▇████
train_loss,█▅▄▃▃▂▂▁▁▁
val_auc,▁▄▆▆▇▇▇███
val_loss,█▆▄▄▃▂▃▂▁▁

0,1
batch_loss,0.65278
train_auc,0.97034
train_loss,0.6861
val_auc,0.97352
val_loss,0.66017
