In [None]:
from time import time

import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
import pandas as pd
import torch

from model_cifar10 import ConvNetBinary, ConvNetClassic
from cifar10_tools import train, test
from datasets import CIFAR10


In [None]:
# Run configuration: every knob a reader may want to tune lives in this cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 256
epochs = 10
# One learning rate shared by every optimizer in the comparison.
# NOTE(review): PyTorch's documented default for Adadelta is lr=1.0, so 1e-4 may leave
# the AdaDelta runs barely training — confirm this is intended.
lr = 0.0001

# Seed torch's global RNG (all devices) so weight init and DataLoader shuffling are reproducible.
seed = 42
torch.manual_seed(seed)


In [None]:
# Optimizer classes to compare; the key is the label stored as "optimizer_name" in the results.
optimizers = dict(
    Adam=optim.Adam,
    AdaMax=optim.Adamax,
    AdaDelta=optim.Adadelta,
)

# Model architectures to compare; the key is the label stored as "model_name" in the results.
models = dict(
    Classic=ConvNetClassic,
    Binary=ConvNetBinary,
)


In [None]:
# Training data is always shuffled; evaluation order is kept fixed.
# Shuffling is a data concern, so it must not depend on which device is available.
train_kwargs = {"batch_size": batch_size, "shuffle": True}
test_kwargs = {"batch_size": batch_size, "shuffle": False}

# `device` is a torch.device, not a str: compare its `.type` so this branch is actually taken on GPU.
if device.type == "cuda":
    # Loader options that only pay off when batches are copied to a GPU.
    cuda_kwargs = {
        "num_workers": 1,
        "pin_memory": True,
    }
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

train_dataset, test_dataset = CIFAR10.get_train_and_test(
    "./cifar10",
    download=True,
)

train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)


In [None]:
def timed(fn, *args):
    """Call ``fn(*args)`` and return ``(result, elapsed_seconds)`` measured with ``time()``.

    Centralising the start/stop bookkeeping avoids subtracting the wrong start time.
    """
    start = time()
    result = fn(*args)
    return result, time() - start


results = []
# CrossEntropyLoss holds no parameters or state here, so one instance can serve every run.
criterion = nn.CrossEntropyLoss()

for model_name, model_class in models.items():
    for optimizer_name, optimizer_class in optimizers.items():
        # A fresh model and optimizer per (model, optimizer) pair keeps the runs independent.
        model_instance = model_class().to(device)
        optimizer = optimizer_class(model_instance.parameters(), lr=lr)

        for epoch in range(1, epochs + 1):
            (train_loss, train_acc), train_time = timed(
                train, model_instance, optimizer, criterion, train_loader, device, epoch
            )
            (test_loss, test_acc), test_time = timed(
                test, model_instance, criterion, test_loader, device
            )

            results.append({
                "model_name": model_name,
                "optimizer_name": optimizer_name,
                "epoch": epoch,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "test_loss": test_loss,
                "test_acc": test_acc,
                "epoch_train_time": train_time,
                "epoch_test_time": test_time,
            })

df_results = pd.DataFrame(results)
# One row per (model, optimizer, epoch) combination is expected.
assert len(df_results) == len(models) * len(optimizers) * epochs
df_results.to_csv("cifar10_results.csv", index=False)
