In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [2]:
batch_size = 64

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class LU_SGD(optim.Optimizer):
    def __init__(self, params, lr=0.01, momentum=0.9, weight_decay=5e-4, alpha=0.5, k=5):
        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, alpha=alpha, k=k)
        super(LU_SGD, self).__init__(params, defaults)

        for group in self.param_groups:
            group['slow_params'] = [p.clone().detach() for p in group['params']]

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            alpha = group['alpha']
            k = group['k']
            slow_params = group['slow_params']

            for p, slow in zip(group['params'], slow_params):
                if p.grad is None:
                    continue

                d_p = p.grad
                p.add_(d_p, alpha=-group['lr'])

                if self.state.get(p, {}).get('step', 0) % k == 0:
                    slow.add_(p - slow, alpha=alpha)
                    p.copy_(slow)

                if 'step' not in self.state[p]:
                    self.state[p]['step'] = 0
                self.state[p]['step'] += 1

        return loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SimpleCNN().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = LU_SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4, alpha=0.5, k=5)

def train(model, trainloader, optimizer, criterion, epochs=20):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {running_loss / len(trainloader):.4f}")

def test(model, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"Test Accuracy: {100 * correct / total:.2f}%")

train(model, trainloader, optimizer, criterion, epochs=20)
test(model, testloader)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:10<00:00, 15.8MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Epoch 1, Loss: 2.2156
Epoch 2, Loss: 1.9005
Epoch 3, Loss: 1.6987
Epoch 4, Loss: 1.5654
Epoch 5, Loss: 1.4803
Epoch 6, Loss: 1.4248
Epoch 7, Loss: 1.3746
Epoch 8, Loss: 1.3338
Epoch 9, Loss: 1.2960
Epoch 10, Loss: 1.2590
Epoch 11, Loss: 1.2251
Epoch 12, Loss: 1.1934
Epoch 13, Loss: 1.1635
Epoch 14, Loss: 1.1317
Epoch 15, Loss: 1.1035
Epoch 16, Loss: 1.0738
Epoch 17, Loss: 1.0473
Epoch 18, Loss: 1.0196
Epoch 19, Loss: 0.9946
Epoch 20, Loss: 0.9708
Test Accuracy: 54.51%
