# zero_grad(set_to_none=True) の効果

In [1]:
import random
import numpy as np
import time
from contextlib import contextmanager
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torch.amp import autocast

In [2]:
DEVICE = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda:2


In [3]:
# ユーティリティ関数を定義
def set_seed(seed: int = 0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)




TIMINGS = defaultdict(list)

@contextmanager
def timed(label: str, sync_cuda: bool = True, record: bool = True, echo: bool = True):
    """
    時間を計測するためのユーティリティ関数
    使い方としては、以下のように使う。
    ```
    with timed("表示したい文字列"):
        func()
    ```
    record=True のとき、経過時間（ms）を TIMINGS[label] に保存します。
    echo=True のとき、逐次 print も行います。
    """
    # ======
    # with句に入った瞬間の処理
    if sync_cuda and torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()

    # ======
    # withブロック内の処理を実行する
    yield

    # ======
    # with句から抜ける直前の処理
    if sync_cuda and torch.cuda.is_available():
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    elapsed_ms = (t1 - t0) * 1000
    if echo:
        print(f"{label}: {elapsed_ms:.2f} ms")
    if record:
        TIMINGS[label].append(elapsed_ms)


def reset_timings():
    """保存した全ての計測値をリセットする。"""
    TIMINGS.clear()


def timing_summary(labels: list[str] | None = None) -> dict:
    """
    保存された計測の要約統計量を返す。
    - labels が指定された場合、そのラベルのみ集計。
    - 戻り値は {label: {count, mean_ms, std_ms, min_ms, max_ms}}。
    """
    items = TIMINGS.items()
    if labels is not None:
        items = [(k, TIMINGS[k]) for k in labels if k in TIMINGS]

    summary = {}
    for k, vals in items:
        arr = np.asarray(vals, dtype=np.float64)
        n = int(arr.size)
        mean = float(arr.mean()) if n > 0 else 0.0
        std = float(arr.std(ddof=1)) if n > 1 else 0.0
        vmin = float(arr.min()) if n > 0 else 0.0
        vmax = float(arr.max()) if n > 0 else 0.0
        summary[k] = {
            "count": n,
            "mean_ms": mean,
            "std_ms": std,
            "min_ms": vmin,
            "max_ms": vmax,
        }
    return summary


def print_timing_summary(labels: list[str] | None = None) -> None:
    """保存された計測の要約統計量を読みやすく表示する。"""
    summary = timing_summary(labels)
    if not summary:
        print("No timings recorded.")
        return
    for k, v in summary.items():
        print(
            f"{k}: n={v['count']}, mean={v['mean_ms']:.2f} ms, "
            f"std={v['std_ms']:.2f} ms, min={v['min_ms']:.2f} ms, max={v['max_ms']:.2f} ms"
        )


set_seed()

In [4]:
# ハイパラの設定
NUM_EPOCHS = 5
BATCH_SIZE = 128
LEARNING_RATE = 0.01

In [5]:
# CIFAR-10のクラスラベル
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def setup_transform() -> tuple[transforms.Compose, transforms.Compose]:
    """
    CIFAR-10 のデータセットの前処理を定義する。
    """
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    return train_transform, test_transform


def setup_dataset(dataset_path:str="./data")->tuple[Dataset, Dataset]:
    """
    CIFAR-10 のデータセットをダウンロードし、DataLoader を返す。
    """
    train_transform, test_transform = setup_transform()

    train_dataset = torchvision.datasets.CIFAR10(
        root=dataset_path, train=True, download=True, transform=train_transform
    )

    test_dataset = torchvision.datasets.CIFAR10(
        root=dataset_path, train=False, download=True, transform=test_transform
    )

    return train_dataset, test_dataset


def setup_dataloader(is_pin: bool = False, dataset_path:str="./data") -> tuple[DataLoader, DataLoader]:
    """
    CIFAR-10 のデータセットを DataLoader として返す。
    """
    train_dataset, test_dataset = setup_dataset(dataset_path)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        pin_memory=is_pin
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=is_pin
    )

    return train_loader, test_loader


def setup_elements():
    model = torchvision.models.resnet50(weights=None)

    # CIFAR-10は画像サイズが32x32と小さいため、最初の畳み込み層とプーリング層を調整
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity() # MaxPool層を無効化

    # ResNet-50の最終層（全結合層）をCIFAR-10の10クラス分類用に変更
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 10)

    model = model.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=5e-4)

    # 学習率をスケジューリング（例：10エポック毎に学習率を0.1倍にする）
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    return model, criterion, optimizer, scheduler

In [None]:
class AverageMetric:
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, batch_avg: float, n: int = 1):
        self.sum += batch_avg * n
        self.count += n

    @property
    def avg(self) -> float:
        return self.sum / self.count if self.count > 0 else 0.0


def train_1epoch(
    model,
    criterion,
    optimizer,
    scheduler,
    train_loader,
):
    metrics = {
        "loss": AverageMetric(),
        "acc": AverageMetric(),
    }
    model.train()
    for idx, (Xs, ys) in enumerate(train_loader):
        Xs, ys = Xs.to(DEVICE), ys.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(Xs)
        loss = criterion(outputs, ys)
        loss.backward()

        if hasattr(optimizer, "step"):
            optimizer.step()
        if hasattr(scheduler, "step"):
            scheduler.step()

        # Update metrics
        num_samples = ys.size(0)
        metrics["loss"].update(loss.item(), n=num_samples)
        preds = outputs.argmax(dim=1)
        metrics["acc"].update((preds == ys).float().mean().item(), n=num_samples)

    return metrics


def train_1epoch_with_nograd(
    model,
    criterion,
    optimizer,
    scheduler,
    train_loader,
):
    """zero_grad()で、オプション`set_to_none=True`を使用して勾配をNoneにする"""
    metrics = {
        "loss": AverageMetric(),
        "acc": AverageMetric(),
    }
    model.train()
    for idx, (Xs, ys) in enumerate(train_loader):
        Xs, ys = Xs.to(DEVICE), ys.to(DEVICE)
        optimizer.zero_grad(set_to_none=True) # ここで勾配をゼロにする.
        outputs = model(Xs)
        loss = criterion(outputs, ys)
        loss.backward()

        if hasattr(optimizer, "step"):
            optimizer.step()
        if hasattr(scheduler, "step"):
            scheduler.step()

        # Update metrics
        num_samples = ys.size(0)
        metrics["loss"].update(loss.item(), n=num_samples)
        preds = outputs.argmax(dim=1)
        metrics["acc"].update((preds == ys).float().mean().item(), n=num_samples)

    return metrics


def eval(
    model,
    criterion,
    test_loader,
):
    model.eval()
    metrics = {
        "loss": AverageMetric(),
        "acc": AverageMetric(),
    }
    with torch.no_grad():
        for idx, (Xs, ys) in enumerate(test_loader):
            Xs, ys = Xs.to(DEVICE), ys.to(DEVICE)
            outputs = model(Xs)
            loss = criterion(outputs, ys)

            # Update metrics
            num_samples = ys.size(0)
            metrics["loss"].update(loss.item(), n=num_samples)
            preds = outputs.argmax(dim=1)
            metrics["acc"].update((preds == ys).float().mean().item(), n=num_samples)

    return metrics

In [7]:
# 通常の学習条件で実行速度を計測する
reset_timings()

train_loader, test_loader = setup_dataloader(is_pin=False)
model, criterion, optimizer, scheduler = setup_elements()

for epoch in range(NUM_EPOCHS):
    with timed("noraml_training", echo=True):
        train_1epoch(
            model,
            criterion,
            optimizer,
            scheduler,
            train_loader,
        )
    print()



metrics = eval(
    model,
    criterion,
    test_loader,
)
print(f"loss: {metrics['loss'].avg}, acc: {metrics['acc'].avg}")

# 要約を表示
print_timing_summary(["noraml_training"])

noraml_training: 78488.78 ms

noraml_training: 76510.33 ms

noraml_training: 77398.08 ms

noraml_training: 76064.51 ms

noraml_training: 75450.02 ms

loss: 2.965869669342041, acc: 0.1271
noraml_training: n=5, mean=76782.34 ms, std=1189.06 ms, min=75450.02 ms, max=78488.78 ms


In [8]:
# オプション`set_to_none=True`条件で実行速度を計測する
reset_timings()

train_loader, test_loader = setup_dataloader(is_pin=False)
model, criterion, optimizer, scheduler = setup_elements()

for epoch in range(NUM_EPOCHS):
    with timed("set_to_none=True", echo=True):
        train_1epoch_with_nograd(
            model,
            criterion,
            optimizer,
            scheduler,
            train_loader,
        )
    print()


metrics = eval(
    model,
    criterion,
    test_loader,
)
print(f"loss: {metrics['loss'].avg}, acc: {metrics['acc'].avg}")

# 要約を表示
print_timing_summary(["set_to_none=True"])

set_to_none=True: 75386.75 ms

set_to_none=True: 76258.83 ms

set_to_none=True: 75747.22 ms

set_to_none=True: 75894.83 ms

set_to_none=True: 75559.26 ms

loss: 3.168150677108765, acc: 0.1341
set_to_none=True: n=5, mean=75769.38 ms, std=334.05 ms, min=75386.75 ms, max=76258.83 ms
