In [1]:
from time import time

import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
import pandas as pd
import torch

from model_mnist import MLP, BinarizedMLP
from model_torch_mnist import QuantizedMLP
# from model_brevitas_mnist import BinarizedMLP
from mnist_tools import train, test
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

import os
import time
import pickle
from math import floor
from datetime import datetime


In [2]:
# Run configuration shared by every experiment below.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch_size = 64   # used for both the training and the test DataLoader
epochs = 10       # epochs per (model, optimizer) run
# NOTE(review): lr=1.0 is the usual Adadelta setting, yet it is shared by Adam and
# AdaMax too, whose library default is 1e-3 — consider per-optimizer learning rates.
lr = 1.0

In [3]:
# Optimizer classes to compare, keyed by the display name recorded in the results.
optimizers = dict(
    Adam=optim.Adam,
    AdaMax=optim.Adamax,
    AdaDelta=optim.Adadelta,
)

# Architectures to train; uncomment an entry to include it in the sweep.
models = dict(
    # Classic=MLP,
    # Binary=BinarizedMLP,
    Quantized=QuantizedMLP,
)


In [4]:
# DataLoader options. Shuffling the training set does not depend on the device,
# so it is always enabled; the test set is evaluated in a fixed order.
train_kwargs = {"batch_size": batch_size, "shuffle": True}
test_kwargs = {"batch_size": batch_size}

# `device` is a torch.device, so compare its `.type`: comparing the device object
# itself to the string "cuda" is always False, which silently disabled these options.
if device.type == "cuda":
    cuda_kwargs = {
        "num_workers": 1,
        "pin_memory": True,
    }

    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)


# Dataset configuration: convert images to tensors, then standardize the pixels
# with the mean/std values commonly used for MNIST.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST('./mnist', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./mnist', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, **train_kwargs)
test_loader = DataLoader(test_dataset, **test_kwargs)


In [5]:
from time import perf_counter

from recorder import setup_logging

# Re-seed before building each model so every (model, optimizer) pair starts from
# the same initial weights: keeps the comparison fair and the run reproducible.
seed = 1
# Passed positionally to `mnist_tools.train`; presumably the number of batches
# between progress log lines — TODO confirm against mnist_tools.
log_interval = 100

results = []

for model_name, model_class in models.items():
    for optimizer_name, optimizer_class in optimizers.items():

        # One log file per (model, optimizer) run.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M")
        log_filename = f"{model_name}_{optimizer_name}_{timestamp}.txt"
        logger = setup_logging(log_filename, './log2/')

        torch.manual_seed(seed)
        model_instance = model_class().to(device)
        # Keep the optimizer class and its instance under distinct names instead
        # of rebinding the loop variable.
        optimizer = optimizer_class(model_instance.parameters(), lr=lr)

        scheduler = StepLR(optimizer, step_size=1, gamma=0.7)  # default gamma by authors

        for epoch in range(1, epochs + 1):
            # NOTE(review): the leading `False` flag is defined by mnist_tools.train — TODO document.
            start_train_time = perf_counter()
            train_loss, train_accuracy = train(
                False, model_instance, device, train_loader, optimizer, epoch, log_interval, logger
            )
            train_time = perf_counter() - start_train_time

            start_test_time = perf_counter()
            test_loss, test_accuracy = test(model_instance, device, test_loader, logger)
            # Measure evaluation only; the original subtracted start_train_time,
            # so the recorded test time also included the training time.
            test_time = perf_counter() - start_test_time

            # Decay the learning rate once per epoch.
            scheduler.step()

            results.append({
                "model_name": model_name,
                "optimizer_name": optimizer_name,
                "epoch": epoch,
                "train_loss": train_loss,
                "train_acc": train_accuracy,
                "test_loss": test_loss,
                "test_acc": test_accuracy,
                "epoch_train_time": train_time,
                "epoch_test_time": test_time
            })

df_results = pd.DataFrame(results)
df_results.to_csv("mnist_experiments_results2.csv", index=False)



Test set: Average loss: 370.8387, Accuracy: 8734/10000 (87%)


Test set: Average loss: 1040.3140, Accuracy: 8981/10000 (90%)


Test set: Average loss: 660.2466, Accuracy: 8597/10000 (86%)


Test set: Average loss: 335.8951, Accuracy: 8820/10000 (88%)


Test set: Average loss: 341.8561, Accuracy: 9078/10000 (91%)


Test set: Average loss: 1216.6968, Accuracy: 9258/10000 (93%)


Test set: Average loss: 534.9763, Accuracy: 9350/10000 (94%)


Test set: Average loss: 0.4980, Accuracy: 9342/10000 (93%)


Test set: Average loss: 12.5827, Accuracy: 9432/10000 (94%)


Test set: Average loss: 220.3550, Accuracy: 9469/10000 (95%)


Test set: Average loss: 1.9097, Accuracy: 8148/10000 (81%)


Test set: Average loss: 0.0111, Accuracy: 8580/10000 (86%)


Test set: Average loss: 0.7364, Accuracy: 9032/10000 (90%)


Test set: Average loss: 0.0362, Accuracy: 8991/10000 (90%)

