In [1]:
import torch
import torchvision
import numpy as np
import sys

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
cuda


In [2]:
model_hyperparams = {
    'lenet5_mnist':{'optimizer':torch.optim.Adam, 
                        'optimizer_params':{'lr':0.001, 'weight_decay':0}, 
                    'batch_size':256,
                    'epoch':25,
                    'ensemble_size':10}
}

# MNIST - Fashion MNIST
### Choose dataset and model, and run the cell.

In [8]:
from models import LeNet5, DeterministicWrapper, EnsembleWrapper
from datasets import MNIST, FashionMNIST
from utils import ECELoss

model = LeNet5().to(device)
hyperparams = model_hyperparams['lenet5_mnist']

######################################################
#model = DeterministicWrapper(model, hyperparams)
model = EnsembleWrapper(model, hyperparams)
######################################################

######################################################
# dataset = MNIST
dataset = FashionMNIST
######################################################

loss_func = torch.nn.CrossEntropyLoss().to(device)

transform = torchvision.transforms.Compose([       
        torchvision.transforms.ToTensor(),         
        torchvision.transforms.Lambda(lambda x: x.to(device))
        ])

train_data = dataset(transform=transform)
test_data =  dataset(train=False, transform=transform)
train_loader = train_data.get_loader(batch_size=hyperparams['batch_size'])
test_loader = test_data.get_loader(batch_size=hyperparams['batch_size'])

for epoch in range(hyperparams['epoch']):
  model.train_epoch(train_loader, loss_func)
  print("Epoch", epoch, "done.")

print("-Training-")
model.predict_epoch(train_loader)
print("-Validation-")
logits, labels = model.predict_epoch(test_loader, return_logits=True)

ECE_loss_func = ECELoss()
ECE_loss = ECE_loss_func(logits, labels).item()
print("Mean Expected Calibration Error: ", ECE_loss / len(test_loader.dataset))

Epoch 0 done.
Epoch 1 done.
Epoch 2 done.
Epoch 3 done.
Epoch 4 done.
Epoch 5 done.
Epoch 6 done.
Epoch 7 done.
Epoch 8 done.
Epoch 9 done.
Epoch 10 done.
Epoch 11 done.
Epoch 12 done.
Epoch 13 done.
Epoch 14 done.
Epoch 15 done.
Epoch 16 done.
Epoch 17 done.
Epoch 18 done.
Epoch 19 done.
Epoch 20 done.
Epoch 21 done.
Epoch 22 done.
Epoch 23 done.
Epoch 24 done.
-Training-
Accuracy:  0.9275666666666667
-Validation-
Accuracy:  0.8969
Mean Expected Calibration Error:  2.7768097817897795e-06
