In [6]:
from time import time

import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
import pandas as pd
import torch

from model_mnist import MLP, BinarizedMLP
from mnist_tools import train, test
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

import os
import time
import pickle
from math import floor
from datetime import datetime


In [7]:
# Experiment hyperparameters shared by every (model, optimizer) run.
batch_size = 64   # mini-batch size, used by both the train and test loaders
epochs = 10       # number of training epochs per run
# NOTE(review): this single lr is handed to every optimizer. 1.0 is Adadelta's
# default, but PyTorch's defaults for Adam/Adamax are ~1e-3/2e-3; the very large
# Adam test losses in this notebook's output are consistent with divergence.
# Consider a per-optimizer learning rate.
lr = 1.0

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [8]:
# Optimizer classes under comparison, keyed by the label used in log file
# names and in the results table. Each is instantiated later with the shared `lr`.
optimizers = dict(
    Adam=optim.Adam,
    AdaMax=optim.Adamax,
    AdaDelta=optim.Adadelta,
)

# Model classes under comparison (imported from the project's model_mnist module),
# keyed by the label used in log file names and in the results table.
models = dict(
    Classic=MLP,
    Binary=BinarizedMLP,
)


In [9]:
# DataLoader options. The training set is reshuffled every epoch; the test set
# is evaluated in a fixed order so per-epoch test metrics are comparable.
train_kwargs = {"batch_size": batch_size, "shuffle": True}
test_kwargs = {"batch_size": batch_size}

# `device` is a torch.device, and comparing it to the string "cuda" is always
# False. Inspect its `.type` attribute instead so the GPU-only options apply.
if device.type == "cuda":
    cuda_kwargs = {
        "num_workers": 1,
        "pin_memory": True,  # page-locked host memory speeds up host->GPU copies
    }

    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)


# Dataset configuration: convert images to tensors, then normalize with
# (0.1307, 0.3081), presumably the MNIST training-set mean/std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./mnist', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./mnist', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, **train_kwargs)
test_loader = DataLoader(test_dataset, **test_kwargs)


In [10]:
from recorder import setup_logging
from time import time

# Positional flags forwarded to mnist_tools.train.
# NOTE(review): presumably (dry_run, log_interval); confirm against the
# train() signature in mnist_tools.
dry_run = False
log_interval = 100

results = []

for model_name, model_class in models.items():
    for optimizer_name, optimizer_class in optimizers.items():

        # One log file per (model, optimizer) run, timestamped to the minute.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M")
        log_filename = f"{model_name}_{optimizer_name}_{timestamp}.txt"
        logger = setup_logging(log_filename)

        # Build a fresh model and optimizer for every run so runs share no state.
        # The optimizer class and its instance get distinct names to avoid shadowing.
        model_instance = model_class().to(device)
        optimizer = optimizer_class(model_instance.parameters(), lr=lr)

        # Multiply the learning rate by 0.7 after every epoch (default gamma by authors).
        scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

        for epoch in range(1, epochs + 1):
            # NOTE(review): CUDA kernels run asynchronously. Unless train/test
            # synchronize internally (e.g. via .item()), these wall-clock timings
            # may be approximate on GPU; consider torch.cuda.synchronize().
            start_train_time = time()
            train_loss, train_accuracy = train(
                dry_run, model_instance, device, train_loader, optimizer,
                epoch, log_interval, logger,
            )
            train_time = time() - start_train_time

            start_test_time = time()
            test_loss, test_accuracy = test(model_instance, device, test_loader, logger)
            # Measure from the start of evaluation so training time is excluded.
            test_time = time() - start_test_time

            scheduler.step()

            results.append({
                "model_name": model_name,
                "optimizer_name": optimizer_name,
                "epoch": epoch,
                "train_loss": train_loss,
                "train_acc": train_accuracy,
                "test_loss": test_loss,
                "test_acc": test_accuracy,
                "epoch_train_time": train_time,
                "epoch_test_time": test_time
            })

# Build the DataFrame once from the accumulated records, then persist it.
df_results = pd.DataFrame(results)
df_results.to_csv("mnist_experiment_results.csv", index=False)



Test set: Average loss: 10302.5102, Accuracy: 7587/10000 (76%)


Test set: Average loss: 16688.1131, Accuracy: 7517/10000 (75%)


Test set: Average loss: 10570.4313, Accuracy: 8574/10000 (86%)


Test set: Average loss: 3673.4517, Accuracy: 9262/10000 (93%)


Test set: Average loss: 117.8629, Accuracy: 9219/10000 (92%)


Test set: Average loss: 1059.3584, Accuracy: 9367/10000 (94%)


Test set: Average loss: 527.3480, Accuracy: 9396/10000 (94%)


Test set: Average loss: 464.3730, Accuracy: 9445/10000 (94%)


Test set: Average loss: 777.9949, Accuracy: 9421/10000 (94%)


Test set: Average loss: 250.4390, Accuracy: 9476/10000 (95%)


Test set: Average loss: 0.1544, Accuracy: 8201/10000 (82%)


Test set: Average loss: 1.9532, Accuracy: 8915/10000 (89%)


Test set: Average loss: 0.1277, Accuracy: 9143/10000 (91%)


Test set: Average loss: 0.1458, Accuracy: 9291/10000 (93%)


Test set: Average loss: 0.1681, Accuracy: 9326/10000 (93%)


Test set: Average loss: 0.1480, Accuracy: 9410/10000 (94