In [1]:
import os
os.chdir('..')

In [2]:
from torchvision.datasets import MNIST
import torch
import torch.nn as nn
import torchvision.transforms as TF
from tqdm.auto import tqdm
from NetBayesianization import wrap, api

In [3]:
train_ds = MNIST(root='./.cache', train=True, download=True, 
                 transform=TF.ToTensor()) 
test_ds = MNIST(root='./.cache', train=False, download=False, 
                transform=TF.ToTensor())

In [4]:
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=36, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=36, shuffle=False)

In [5]:
num_classes = 10

In [6]:
model = nn.Sequential(nn.Flatten() ,nn.Linear(28*28, 128), 
                       nn.ReLU(), nn.Linear(128, 64), 
                       nn.ReLU(), nn.Linear(64, num_classes), nn.Softmax(dim=1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
criterion = nn.CrossEntropyLoss()
images, labels = next(iter(train_dl))
images = images.view(images.shape[0], -1)
logps = model(images)
loss = criterion(logps, labels)

In [7]:
# train
n_epochs = 4
for e in range(n_epochs):
    running_loss = 0
    for images, labels in train_dl:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    else:
        print("Epoch {} - Training loss: {}".format(e, running_loss/len(train_dl)))

Epoch 0 - Training loss: 2.25678127800267
Epoch 1 - Training loss: 1.8390895219784549
Epoch 2 - Training loss: 1.7249091420977432
Epoch 3 - Training loss: 1.7108306553668438


In [8]:
correct_count, all_count = 0, 0
for images,labels in test_dl:
  for i in range(len(labels)):
    img = images[i].view(1, 784)
    with torch.no_grad():
        logps = model(img)

    ps = torch.exp(logps)
    probab = list(ps.numpy()[0])
    pred_label = probab.index(max(probab))
    true_label = labels.numpy()[i]
    if(true_label == pred_label):
      correct_count += 1
    all_count += 1

print("Number Of Images Tested =", all_count)
print("\nModel Accuracy =", (correct_count/all_count))

Number Of Images Tested = 10000

Model Accuracy = 0.7578


In [9]:
# build classical bayesian model
bayes_model = api.BasicBayesianWrapper(model, 'basic', 0.1, None, None)

In [10]:
# predict
n_iter = 5
bayes_model.predict(images, n_iter)

{'mean': tensor([[1.7637e-07, 4.9688e-12, 4.8315e-09, 6.0096e-10, 9.9799e-01, 1.6134e-08,
          4.0904e-06, 5.4219e-06, 4.9832e-06, 1.9917e-03],
         [3.0347e-06, 4.6740e-13, 1.4453e-09, 1.8297e-09, 7.2091e-03, 4.8047e-09,
          4.4629e-08, 1.4678e-03, 1.2209e-04, 9.9120e-01],
         [1.6841e-07, 5.7389e-11, 1.5229e-08, 3.6501e-10, 9.9975e-01, 4.3970e-08,
          3.9729e-05, 7.6581e-07, 3.5806e-06, 2.0085e-04],
         [1.4218e-02, 2.1124e-04, 5.2089e-05, 4.1242e-01, 6.0283e-06, 7.3929e-05,
          2.1471e-04, 4.0200e-04, 5.6721e-01, 5.1968e-03],
         [3.1363e-04, 6.5162e-07, 3.2607e-06, 5.5411e-07, 2.8429e-02, 1.3241e-05,
          9.6778e-01, 7.8308e-08, 3.2903e-03, 1.7358e-04],
         [1.2231e-07, 9.6539e-09, 5.5084e-08, 2.0875e-09, 9.9879e-01, 9.3387e-08,
          5.3918e-05, 1.9783e-04, 4.7187e-05, 9.0917e-04],
         [4.6287e-09, 9.9891e-01, 3.6492e-06, 1.1816e-04, 6.8836e-07, 2.0026e-06,
          9.4979e-06, 3.3455e-05, 9.1654e-04, 3.3610e-06],
     

In [11]:
# build bayesian model with beta distibution
bayes_model = api.BasicBayesianWrapper(model, 'beta', None, 0.6, 0.3)

In [12]:
# predict
n_iter = 5
bayes_model.predict(images, n_iter)

{'mean': tensor([[1.6017e-07, 4.0813e-08, 9.0452e-05, 1.9088e-11, 1.0572e-01, 1.5165e-04,
          3.0091e-03, 3.8216e-01, 1.6024e-02, 4.9285e-01],
         [1.1333e-01, 2.5806e-16, 3.6569e-08, 1.5372e-17, 3.2736e-06, 1.6070e-08,
          2.1157e-07, 4.7035e-01, 4.5246e-05, 4.1627e-01],
         [4.9152e-07, 3.5110e-07, 1.0474e-04, 1.4276e-10, 4.3853e-01, 3.4521e-04,
          1.2708e-02, 4.1738e-01, 4.2717e-03, 1.2667e-01],
         [3.3336e-01, 2.1298e-07, 1.1850e-04, 7.4759e-06, 7.3887e-09, 2.9445e-05,
          3.5276e-06, 3.3381e-01, 3.1011e-01, 2.2570e-02],
         [9.7341e-04, 1.9708e-08, 7.7908e-07, 8.2092e-13, 2.6463e-06, 2.8729e-06,
          6.5263e-01, 3.3384e-01, 1.3013e-04, 1.2413e-02],
         [2.9739e-03, 2.8272e-08, 7.7688e-06, 8.1368e-13, 2.2168e-03, 1.3912e-05,
          2.3154e-04, 6.2577e-01, 1.3144e-03, 3.6747e-01],
         [1.8112e-04, 3.2075e-01, 6.6775e-06, 1.3760e-06, 3.3145e-07, 5.5507e-06,
          7.9438e-04, 6.6506e-01, 1.2691e-02, 5.0835e-04],
     