In [16]:
import torch.nn.functional as F
import torchvision

from optimizers.adamm import *
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn

In [23]:
class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(3, 9, 3)
        self.fc1 = nn.Linear(9 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [40]:
transform = transforms.Compose([transforms.ToTensor()])

mnist_dataset_train = torchvision.datasets.MNIST('data/mnist/', download=False, train=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_dataset_train, batch_size=4)

mnist_dataset_test = torchvision.datasets.MNIST('data/mnist/', download=False, train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_dataset_train, batch_size=4)

criterion = nn.CrossEntropyLoss()

nb_epochs = 10

In [41]:
def train(model, optimizer, criterion, nb_epochs, train_loader, test_loader):
    # Heavily inspired from PyTorch tutorial
    train_losses = []
    test_accuracies = []

    running_loss = 0

    for e in range(nb_epochs):

        for i, data in enumerate(train_loader):
            inputs, labels = data

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            optimizer.step()

            running_loss += loss.item()

            if i % 2000 == 1999:
                train_losses.append(running_loss / 1000)
                print(f'epoch : {e + 1}/{nb_epochs} | train loss : {train_losses[-1]:.4f}')
                running_loss = 0.0

        with torch.no_grad():
            correct_preds = 0
            total_preds = 0

            for inputs, labels in test_loader:
                outputs = model(inputs)

                predictions = torch.argmax(outputs, 1)
                total_preds += labels.size(0)
                correct_preds += (predictions == labels).sum().item()

            test_accuracies.append(correct_preds / total_preds)

    return train_losses, test_accuracies

In [42]:
small_model = SmallModel()

optimizer_pt = optim.Adam(small_model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, amsgrad=True)

train_losses_pt, test_acc_pt = train(small_model, optimizer_pt, criterion, nb_epochs, train_loader, test_loader)

epoch : 1/10 | train loss : 0.9933
epoch : 1/10 | train loss : 0.3720
epoch : 1/10 | train loss : 0.2507


KeyboardInterrupt: 

In [43]:
small_model = SmallModel()

optimizer_custom = AdaMM(small_model.parameters(), lr=1e-3, beta1=0.9, beta2=0.999, epsilon=1e-8)

train_losses_custom, test_acc_custom = train(small_model, optimizer_custom, criterion, nb_epochs, train_loader, test_loader)

epoch : 1/10 | train loss : nan
epoch : 1/10 | train loss : nan
epoch : 1/10 | train loss : nan
epoch : 1/10 | train loss : nan
epoch : 1/10 | train loss : nan


KeyboardInterrupt: 