# In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from agc_optims.optim import SGD_AGC, Adam_AGC, AdamW_AGC, RMSprop_AGC
from torch.optim import SGD, Adam, AdamW, RMSprop
import time


"""
This example was taken from: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
You can change the optimizer and test on the CIFAR 10 dataset

"""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])

batch_size = 256

trainset = datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

class Net(nn.Module):
    """Small LeNet-style CNN mapping 3x32x32 images to 10 output scores.

    Architecture: two (conv 5x5 -> ReLU -> 2x2 max-pool) stages, then a
    ReLU-activated hidden linear layer of width 120 and a 10-way output layer.
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in this exact order so that parameter
        # initialisation draws from the RNG in the same sequence.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Spatial size for a 32x32 input: 32 -conv5-> 28 -pool-> 14 -conv5-> 10
        # -pool-> 5, leaving 16 feature maps of 5x5 to flatten.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=10)

    def forward(self, x):
        # Both convolutional stages share the same ReLU + max-pool tail.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.flatten(start_dim=1)  # keep the batch dimension, flatten the rest
        return self.fc2(F.relu(self.fc1(x)))

# Train Net on CIFAR-10 with the chosen optimizer and report test accuracy
# after every epoch.
net = Net().to(device=device)
criterion = nn.CrossEntropyLoss()

# Change the optimizer here to test AGC: swap in any of the imported
# optimizers (SGD/Adam/AdamW/RMSprop or their *_AGC counterparts).
optimizer = Adam_AGC(net.parameters(), lr=0.001, clipping=0.16)

num_epochs = 10

for epoch in range(num_epochs):
    # --- Training pass -------------------------------------------------
    # train()/eval() toggle layers such as dropout and batch norm; setting
    # them explicitly keeps this loop correct if the model is changed.
    net.train()
    for inputs, labels in trainloader:
        inputs = inputs.to(device=device)
        labels = labels.to(device=device)

        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()

    # --- Evaluation pass over the whole test set -----------------------
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # no autograd graph needed for evaluation
        for images, labels in testloader:
            images = images.to(device=device)
            labels = labels.to(device=device)
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Whole-number percentage truncated toward zero, the same value the
    # previous '%d' formatting produced; the image count is the one measured.
    accuracy = 100 * correct // total
    print(f'Epoch {epoch + 1}: Accuracy of the network on the {total} test images: {accuracy} %')
print('Finished Training')


# Output:
# Files already downloaded and verified
# Files already downloaded and verified
# Epoch 1: Accuracy of the network on the 10000 test images: 45 %
# Epoch 2: Accuracy of the network on the 10000 test images: 51 %
# Epoch 3: Accuracy of the network on the 10000 test images: 53 %
# Epoch 4: Accuracy of the network on the 10000 test images: 55 %
# Epoch 5: Accuracy of the network on the 10000 test images: 58 %
# Epoch 6: Accuracy of the network on the 10000 test images: 58 %
# Epoch 7: Accuracy of the network on the 10000 test images: 60 %
# Epoch 8: Accuracy of the network on the 10000 test images: 61 %
# Epoch 9: Accuracy of the network on the 10000 test images: 61 %
# Epoch 10: Accuracy of the network on the 10000 test images: 62 %
# Finished Training
