In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Seed torch's global RNG up front. Weight init (the model is built on the CPU
# before `.to(device)`) and DataLoader shuffling both draw from it, so without a
# seed every Restart & Run All produces different losses and gradients.
SEED = 0
torch.manual_seed(SEED)

# Pick the fastest available backend: CUDA, then Apple-silicon MPS, else CPU.
# The hasattr guard keeps this working on torch builds without an MPS backend.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("device:", device)


device: mps


In [2]:
data_dir = "../data"

# Preprocessing pipeline for MNIST training images:
#   1. zero-pad 2 px per side: 28x28 -> 32x32 (the input size LeNet-5 expects)
#   2. PIL image -> float tensor in [0, 1]
#   3. standardize with a fixed mean/std
# NOTE(review): (0.1307, 0.3081) are presumably the usual MNIST train-set
# statistics, measured on unpadded 28x28 images; after padding, the true
# mean/std of the 32x32 inputs are slightly different.
transform = transforms.Compose(
    [
        transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Training split only; downloaded into `data_dir` on first use.
train_ds = datasets.MNIST(data_dir, train=True, transform=transform, download=True)

# Shuffled mini-batches of 128; loading stays in the main process (no workers).
train_loader = DataLoader(train_ds, shuffle=True, batch_size=128, num_workers=0)


In [3]:
class LeNet5(nn.Module):
    """LeNet-5 with tanh activations and average pooling.

    Expects single-channel 32x32 images. The fc1 input size 16 * 5 * 5 follows
    from that: 32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5, with 16 maps.

    Args:
        num_classes: number of output logits (default 10).

    forward() returns raw, unnormalized logits of shape (batch, num_classes).
    No softmax is applied, so the output is suitable for nn.CrossEntropyLoss.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Creation order fixes both the parameters() order and the sequence of
        # RNG draws used for default weight init.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        # Feature extractor: two (conv -> tanh -> avg-pool) stages.
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2)):
            x = pool(torch.tanh(conv(x)))

        # Classifier: keep the batch dim, flatten the rest, then two
        # tanh-activated hidden layers and a final linear projection to logits.
        x = x.flatten(1)
        for fc in (self.fc1, self.fc2):
            x = torch.tanh(fc(x))
        return self.fc3(x)


In [4]:
# Optimizer hyperparameters: plain SGD with momentum.
LR = 0.01
MOMENTUM = 0.9

# Build the model on the CPU (default init), then move it to the chosen device.
model = LeNet5().to(device)
# Expects raw logits (which LeNet5.forward returns) and integer class targets.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)


In [5]:
# Smoke test: take one shuffled batch, move it to the device, and run a forward
# pass to confirm shapes, dtypes and devices line up end to end.
xb, yb = (t.to(device) for t in next(iter(train_loader)))

print("xb:", xb.shape, xb.dtype, xb.device)
print("yb:", yb.shape, yb.dtype, yb.device)

# Kept outside no_grad on purpose: the following cells backprop through it.
logits = model(xb)
print("logits:", logits.shape, logits.dtype, logits.device)


xb: torch.Size([128, 1, 32, 32]) torch.float32 mps:0
yb: torch.Size([128]) torch.int64 mps:0
logits: torch.Size([128, 10]) torch.float32 mps:0


In [6]:
# Cross-entropy of the untrained model on the smoke-test batch.
loss = criterion(logits, yb)
print("loss:", loss.item())


loss: 2.294621467590332


In [7]:
# Drop any stale gradients (set to None), then backprop the batch loss.
# NOTE: backward() frees the autograd graph by default, so running this cell a
# second time raises; re-run the forward and loss cells first.
optimizer.zero_grad(set_to_none=True)
loss.backward()

# Sanity-check the first layer's gradient: finite and not vanishingly small.
grad = model.conv1.weight.grad
has_nan = bool(torch.isnan(grad).any())
mean_abs = float(grad.abs().mean())
print("conv1 grad:", grad.shape, "is_nan:", has_nan, "mean_abs:", mean_abs)


conv1 grad: torch.Size([6, 1, 5, 5]) is_nan: False mean_abs: 0.0040332479402422905


In [8]:
# Snapshot fc3's weights, take one optimizer step, and measure how far they moved.
# NOTE: every run of this cell applies another step, so re-running it keeps
# changing the model.
w_before = model.fc3.weight.detach().clone()

optimizer.step()

w_after = model.fc3.weight
with torch.no_grad():
    delta = torch.mean(torch.abs(w_after - w_before)).item()

print("fc3 weight mean abs delta after step:", delta)


fc3 weight mean abs delta after step: 2.167457751056645e-05
