In [1]:
import os
# Move the working directory one level up from where the notebook was started.
# A bare relative chdir climbs one more level every time the cell is re-run, so
# the starting directory is recorded once per kernel session and the move is
# always made relative to it (idempotent on re-execution).
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [2]:
from torchvision.datasets import MNIST
import torch
import torch.nn as nn
import torchvision.transforms as TF
from tqdm.auto import tqdm
from eXNN.NetBayesianization import api, wrap

In [3]:
# Both splits share the same cache directory and the same image -> tensor transform.
mnist_kwargs = dict(root='./.cache', transform=TF.ToTensor())

# Only the training split requests a download; the test split is created with
# download=False against the same root.
train_ds = MNIST(train=True, download=True, **mnist_kwargs)
test_ds = MNIST(train=False, download=False, **mnist_kwargs)

In [4]:
# Mini-batch loaders: training batches are reshuffled, test batches keep dataset order.
batch_size = 36
train_dl = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    shuffle=True,
)
test_dl = torch.utils.data.DataLoader(
    dataset=test_ds,
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Configuration shared by the cells below.
num_classes = 10  # size of the classifier's output layer (one logit per MNIST label)

# Seed the global torch RNG so that weight initialisation, DataLoader shuffling
# and any stochastic forward passes are repeatable under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [6]:
# Fully connected classifier: flattened 28x28 input -> 256 -> 64 -> num_classes logits.
layers = [
    nn.Flatten(),
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 64),
    nn.ReLU(),
    nn.Linear(64, num_classes),
]
model = nn.Sequential(*layers)

# Adam over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [7]:
# Train the classifier: mini-batch Adam steps on the cross-entropy loss.
n_epochs = 20
loss_fn = nn.CrossEntropyLoss()

# Set train mode explicitly so the cell does not depend on whatever mode an
# earlier (possibly out-of-order) cell left the model in.
model.train()

# tqdm reads the length of a range directly; no need to materialise a list.
epoch_bar = tqdm(range(n_epochs), desc='epoch')
for epoch in epoch_bar:
    epoch_loss, n_batches = 0.0, 0
    for imgs, lbls in train_dl:
        optimizer.zero_grad()
        out = model(imgs)
        loss = loss_fn(out, lbls)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
    # Show the mean training loss of the finished epoch on the progress bar.
    epoch_bar.set_postfix(loss=epoch_loss / max(n_batches, 1))

  0%|          | 0/20 [00:00<?, ?it/s]

In [8]:
# Deterministic predictions on the test set: arg-max class index per image.
pred_batches, lbl_batches = [], []
with torch.no_grad():  # inference only, no autograd bookkeeping
    for imgs, lbls in test_dl:
        pred_batches.append(model(imgs).argmax(dim=1))
        lbl_batches.append(lbls)

# Concatenate per-batch results into one tensor per quantity.
all_preds = torch.cat(pred_batches, dim=0)
all_lbls = torch.cat(lbl_batches, dim=0)

In [9]:
# Fraction of test images whose predicted class equals the label.
accuracy = (all_preds == all_lbls).float().mean().item()
print('Accuracy is: ', accuracy)

Accuracy is:  0.9819999933242798


In [11]:
# Wrap the trained network with eXNN's Bayesian wrapper.
# NOTE(review): the arguments after the model are forwarded positionally and
# unchanged; presumably a mode name, a drop probability and two unused
# distribution parameters -- confirm against the eXNN API documentation.
wrapper_args = ('basic', 0.1, None, None)
bayes_model = api.BasicBayesianWrapper(model, *wrapper_args)

In [16]:
# Predict with the Bayesian wrapper, passing n_iter to predict() for every batch.
n_iter = 3

all_preds = []
all_lbls = []
# Inference only: disable autograd so that the stored per-batch outputs do not
# keep computation graphs alive (harmless if the wrapper already does this).
with torch.no_grad():
    for imgs, lbls in test_dl:
        all_lbls.append(lbls)
        all_preds.append(bayes_model.predict(imgs, n_iter))
# NOTE(review): torch.cat requires predict() to return a tensor batched along
# dim 0, and the accuracy comparison further down requires per-sample class
# indices -- confirm this matches the installed eXNN version.
all_preds = torch.cat(all_preds, dim=0)
all_lbls = torch.cat(all_lbls, dim=0)

In [18]:
# Fraction of test images whose wrapper prediction equals the label.
mean_pred_accuracy = (all_preds == all_lbls).float().mean().item()
print('Accuracy of mean predictions is: ', mean_pred_accuracy)

Accuracy of mean predictions is:  0.9627000093460083
