In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch as th
import torch.nn as nn
from torch.optim import Adam, lr_scheduler
from torch.nn import functional as F
from tqdm import tqdm
import os

from utils import data, metrics
import Frequentist_main as FCNN
import Bayesian_main as BCNN
from Bayesian.BayesianCNN import BBBAlexNet
from Frequentist.FrequentistCNN import AlexNet

In [2]:
# Pick the compute device: a CUDA GPU when PyTorch can see one, otherwise the CPU.
if th.cuda.is_available():
    device = th.device("cuda")
else:
    device = th.device("cpu")
print(device)

cpu


In [3]:
# Hyper-parameters shared by the experiments in this notebook.

# Prior / posterior-initialisation settings passed to the Bayesian networks.
priors = dict(
    prior_mu=0,
    prior_sigma=0.1,
    posterior_mu_initial=(0, 0.1),    # (mean, std) normal_
    posterior_rho_initial=(-5, 0.1),  # (mean, std) normal_
)

# Training schedule: upper bound of the epoch loops and the initial Adam learning rate.
n_epochs, lr_start = 100, 0.001

# Arguments forwarded to data.getDataloader.
num_workers, valid_size, batch_size = 4, 0.2, 256

# Forwarded as `beta_type` to the Bayesian train/validate helpers.
beta_type = "Blundell"

In [4]:
# Load both CIFAR variants and build their train / validation / test loaders.
# The call order is kept as-is in case the split inside getDataloader draws
# from a shared random state (not visible from this notebook).

# --- CIFAR-10 ---
(c10_trainset, c10_testset,
 c10_inputs, c10_outputs) = data.getDataset('CIFAR10')
(c10_train_loader,
 c10_valid_loader,
 c10_test_loader) = data.getDataloader(c10_trainset, c10_testset, valid_size, batch_size, num_workers)

# --- CIFAR-100 ---
(c100_trainset, c100_testset,
 c100_inputs, c100_outputs) = data.getDataset('CIFAR100')
(c100_train_loader,
 c100_valid_loader,
 c100_test_loader) = data.getDataloader(c100_trainset, c100_testset, valid_size, batch_size, num_workers)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [None]:
# BayesianCNN with softplus on CIFAR10
#
# Builds the model, ELBO criterion, Adam optimizer and plateau scheduler,
# resumes from the checkpoint when one exists, trains up to `n_epochs`, and
# finally evaluates on the test loader.
bc10_sp_net = BBBAlexNet(c10_outputs, c10_inputs, priors, activation_type='softplus').to(device)
bc10_sp_criterion = metrics.ELBO(len(c10_trainset)).to(device)
bc10_sp_optimizer = Adam(bc10_sp_net.parameters(), lr=lr_start)
bc10_sp_lr_sched = lr_scheduler.ReduceLROnPlateau(bc10_sp_optimizer, patience=6, verbose=True)
# Lowest validation loss seen so far (despite the "_max" suffix).
# `np.inf` instead of `np.Inf`: the capitalised alias was removed in NumPy 2.0.
bc10_sp_valid_loss_max = np.inf

ckpt_name = 'Bayesian/Models/bc10_sp.pth'
start_epoch = 0
if os.path.isfile(ckpt_name):
    # map_location lets a checkpoint written on a GPU machine be restored on a
    # CPU-only one (and vice versa).
    checkpoint = th.load(ckpt_name, map_location=device)
    bc10_sp_net.load_state_dict(checkpoint['model_state_dict'])
    bc10_sp_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    bc10_sp_lr_sched.load_state_dict(checkpoint['scheduler_state_dict'])
    bc10_sp_valid_loss_max = checkpoint['valid_loss_max']
    # Checkpoints saved without an 'epoch' entry restart from epoch 0 instead
    # of failing with a KeyError.
    start_epoch = checkpoint.get('epoch', -1) + 1
    print('Model loaded from {}'.format(ckpt_name))

# th.save does not create missing directories.
os.makedirs(os.path.dirname(ckpt_name), exist_ok=True)

for epoch in tqdm(range(start_epoch, n_epochs)):  # loop over the dataset multiple times

    bc10_sp_train_loss, bc10_sp_train_acc, bc10_sp_train_kl = BCNN.train_model(bc10_sp_net, bc10_sp_optimizer, bc10_sp_criterion, c10_train_loader, num_ens=1, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
    bc10_sp_valid_loss, bc10_sp_valid_acc = BCNN.validate_model(bc10_sp_net, bc10_sp_criterion, c10_valid_loader, num_ens=1, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
    bc10_sp_lr_sched.step(bc10_sp_valid_loss)

    # save a checkpoint whenever the validation loss has not gotten worse
    if bc10_sp_valid_loss <= bc10_sp_valid_loss_max:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
            bc10_sp_valid_loss_max, bc10_sp_valid_loss))
        th.save({
            'model_state_dict': bc10_sp_net.state_dict(),
            'optimizer_state_dict': bc10_sp_optimizer.state_dict(),
            'scheduler_state_dict': bc10_sp_lr_sched.state_dict(),
            'valid_loss_max': bc10_sp_valid_loss,
            'epoch': epoch
        }, ckpt_name)
        bc10_sp_valid_loss_max = bc10_sp_valid_loss

    print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'.format(
        epoch, bc10_sp_train_loss, bc10_sp_train_acc, bc10_sp_valid_loss, bc10_sp_valid_acc, bc10_sp_train_kl))

# After all epochs are complete, evaluate the model on the test set.
# NOTE(review): this uses the weights currently in memory (last epoch run, or
# the loaded checkpoint if no epoch ran), not necessarily the best checkpoint.
bc10_sp_test_loss, bc10_sp_test_acc = BCNN.validate_model(bc10_sp_net, bc10_sp_criterion, c10_test_loader, num_ens=1, beta_type=beta_type)
print('Test Loss: {:.4f} \tTest Accuracy: {:.4f}'.format(bc10_sp_test_loss, bc10_sp_test_acc))

  1%|          | 1/100 [03:03<5:03:32, 183.97s/it]

Validation loss decreased (inf --> 23235517.065625).  Saving model ...
Epoch: 0 	Training Loss: 3641796.2032 	Training Accuracy: 0.1523 	Validation Loss: 23235517.0656 	Validation Accuracy: 0.2140 	train_kl_div: 463689881.6815


  2%|▏         | 2/100 [06:07<4:59:40, 183.47s/it]

Validation loss decreased (23235517.065625 --> 21709061.231250).  Saving model ...
Epoch: 1 	Training Loss: 3038038.8322 	Training Accuracy: 0.2473 	Validation Loss: 21709061.2312 	Validation Accuracy: 0.2790 	train_kl_div: 434250893.8599


  3%|▎         | 3/100 [09:10<4:56:36, 183.47s/it]

Validation loss decreased (21709061.231250 --> 20215212.049023).  Saving model ...
Epoch: 2 	Training Loss: 2835248.8125 	Training Accuracy: 0.3124 	Validation Loss: 20215212.0490 	Validation Accuracy: 0.3461 	train_kl_div: 404486737.5287


  4%|▍         | 4/100 [12:13<4:52:55, 183.08s/it]

Validation loss decreased (20215212.049023 --> 18863587.385352).  Saving model ...
Epoch: 3 	Training Loss: 2642932.6058 	Training Accuracy: 0.3423 	Validation Loss: 18863587.3854 	Validation Accuracy: 0.3604 	train_kl_div: 377296797.5541


  5%|▌         | 5/100 [15:16<4:49:59, 183.16s/it]

Validation loss decreased (18863587.385352 --> 17659680.585742).  Saving model ...
Epoch: 4 	Training Loss: 2468921.6793 	Training Accuracy: 0.3665 	Validation Loss: 17659680.5857 	Validation Accuracy: 0.3467 	train_kl_div: 352984673.4268


  6%|▌         | 6/100 [18:21<4:47:46, 183.69s/it]

Validation loss decreased (17659680.585742 --> 16575367.776563).  Saving model ...
Epoch: 5 	Training Loss: 2314096.9820 	Training Accuracy: 0.3808 	Validation Loss: 16575367.7766 	Validation Accuracy: 0.4031 	train_kl_div: 331269974.4204


  7%|▋         | 7/100 [21:24<4:44:50, 183.77s/it]

Validation loss decreased (16575367.776563 --> 15610062.360742).  Saving model ...
Epoch: 6 	Training Loss: 2176517.9839 	Training Accuracy: 0.3920 	Validation Loss: 15610062.3607 	Validation Accuracy: 0.3992 	train_kl_div: 311807151.2866


  8%|▊         | 8/100 [24:28<4:41:42, 183.72s/it]

Validation loss decreased (15610062.360742 --> 14736935.998047).  Saving model ...
Epoch: 7 	Training Loss: 2052249.3615 	Training Accuracy: 0.4061 	Validation Loss: 14736935.9980 	Validation Accuracy: 0.4119 	train_kl_div: 294252389.9108


  9%|▉         | 9/100 [27:32<4:38:49, 183.84s/it]

Validation loss decreased (14736935.998047 --> 13942609.544336).  Saving model ...
Epoch: 8 	Training Loss: 1939570.2161 	Training Accuracy: 0.4245 	Validation Loss: 13942609.5443 	Validation Accuracy: 0.4342 	train_kl_div: 278317209.4777


 10%|█         | 10/100 [30:35<4:35:26, 183.63s/it]

Validation loss decreased (13942609.544336 --> 13221478.122070).  Saving model ...
Epoch: 9 	Training Loss: 1837883.1780 	Training Accuracy: 0.4411 	Validation Loss: 13221478.1221 	Validation Accuracy: 0.4209 	train_kl_div: 263772668.1274


 11%|█         | 11/100 [33:40<4:32:44, 183.88s/it]

Validation loss decreased (13221478.122070 --> 12555008.199414).  Saving model ...
Epoch: 10 	Training Loss: 1745164.3599 	Training Accuracy: 0.4428 	Validation Loss: 12555008.1994 	Validation Accuracy: 0.4481 	train_kl_div: 250432947.5669


 12%|█▏        | 12/100 [36:43<4:29:17, 183.60s/it]

Validation loss decreased (12555008.199414 --> 11948853.795312).  Saving model ...
Epoch: 11 	Training Loss: 1660215.8222 	Training Accuracy: 0.4523 	Validation Loss: 11948853.7953 	Validation Accuracy: 0.4248 	train_kl_div: 238132420.5860


 13%|█▎        | 13/100 [39:47<4:26:26, 183.75s/it]

Validation loss decreased (11948853.795312 --> 11373636.018945).  Saving model ...
Epoch: 12 	Training Loss: 1581555.2141 	Training Accuracy: 0.4640 	Validation Loss: 11373636.0189 	Validation Accuracy: 0.4798 	train_kl_div: 226738390.6242


 14%|█▍        | 14/100 [42:54<4:24:50, 184.77s/it]

Validation loss decreased (11373636.018945 --> 10848516.490625).  Saving model ...
Epoch: 13 	Training Loss: 1507992.9201 	Training Accuracy: 0.4748 	Validation Loss: 10848516.4906 	Validation Accuracy: 0.4610 	train_kl_div: 216141200.8153


 15%|█▌        | 15/100 [45:58<4:21:28, 184.57s/it]

Validation loss decreased (10848516.490625 --> 10356632.014258).  Saving model ...
Epoch: 14 	Training Loss: 1441587.6989 	Training Accuracy: 0.4708 	Validation Loss: 10356632.0143 	Validation Accuracy: 0.4574 	train_kl_div: 206265509.9108


 16%|█▌        | 16/100 [49:03<4:18:28, 184.63s/it]

Validation loss decreased (10356632.014258 --> 9895062.853516).  Saving model ...
Epoch: 15 	Training Loss: 1377319.5240 	Training Accuracy: 0.4852 	Validation Loss: 9895062.8535 	Validation Accuracy: 0.4739 	train_kl_div: 197015607.8471


 17%|█▋        | 17/100 [52:07<4:15:03, 184.37s/it]

Validation loss decreased (9895062.853516 --> 9461712.441797).  Saving model ...
Epoch: 16 	Training Loss: 1319448.1467 	Training Accuracy: 0.4764 	Validation Loss: 9461712.4418 	Validation Accuracy: 0.4923 	train_kl_div: 188340844.9427


 18%|█▊        | 18/100 [55:11<4:11:48, 184.25s/it]

Validation loss decreased (9461712.441797 --> 9055185.607812).  Saving model ...
Epoch: 17 	Training Loss: 1263036.2567 	Training Accuracy: 0.4944 	Validation Loss: 9055185.6078 	Validation Accuracy: 0.4870 	train_kl_div: 180173591.6433


 19%|█▉        | 19/100 [58:14<4:08:30, 184.09s/it]

Validation loss decreased (9055185.607812 --> 8671736.666113).  Saving model ...
Epoch: 18 	Training Loss: 1211581.1466 	Training Accuracy: 0.4907 	Validation Loss: 8671736.6661 	Validation Accuracy: 0.4877 	train_kl_div: 172478691.5669


 20%|██        | 20/100 [1:01:18<4:05:10, 183.88s/it]

Validation loss decreased (8671736.666113 --> 8307154.630859).  Saving model ...
Epoch: 19 	Training Loss: 1161623.3906 	Training Accuracy: 0.5003 	Validation Loss: 8307154.6309 	Validation Accuracy: 0.5034 	train_kl_div: 165206073.4777


 21%|██        | 21/100 [1:04:21<4:01:48, 183.65s/it]

Validation loss decreased (8307154.630859 --> 7967083.464941).  Saving model ...
Epoch: 20 	Training Loss: 1115021.7744 	Training Accuracy: 0.5066 	Validation Loss: 7967083.4649 	Validation Accuracy: 0.4887 	train_kl_div: 158323532.1274


 22%|██▏       | 22/100 [1:07:25<3:58:55, 183.79s/it]

Validation loss decreased (7967083.464941 --> 7641551.730078).  Saving model ...
Epoch: 21 	Training Loss: 1070809.1921 	Training Accuracy: 0.5116 	Validation Loss: 7641551.7301 	Validation Accuracy: 0.4936 	train_kl_div: 151806523.8217


 23%|██▎       | 23/100 [1:10:29<3:55:46, 183.72s/it]

Validation loss decreased (7641551.730078 --> 7334383.111035).  Saving model ...
Epoch: 22 	Training Loss: 1029033.5346 	Training Accuracy: 0.5162 	Validation Loss: 7334383.1110 	Validation Accuracy: 0.4875 	train_kl_div: 145617796.7898


 24%|██▍       | 24/100 [1:13:32<3:52:40, 183.70s/it]

Validation loss decreased (7334383.111035 --> 7040457.821680).  Saving model ...
Epoch: 23 	Training Loss: 989095.8824 	Training Accuracy: 0.5247 	Validation Loss: 7040457.8217 	Validation Accuracy: 0.4934 	train_kl_div: 139747011.1592


 25%|██▌       | 25/100 [1:16:37<3:49:57, 183.96s/it]

Validation loss decreased (7040457.821680 --> 6762030.925977).  Saving model ...
Epoch: 24 	Training Loss: 951128.3321 	Training Accuracy: 0.5275 	Validation Loss: 6762030.9260 	Validation Accuracy: 0.5061 	train_kl_div: 134164505.8854


 26%|██▌       | 26/100 [1:19:40<3:46:36, 183.73s/it]

Validation loss decreased (6762030.925977 --> 6498586.116406).  Saving model ...
Epoch: 25 	Training Loss: 915859.6743 	Training Accuracy: 0.5285 	Validation Loss: 6498586.1164 	Validation Accuracy: 0.5015 	train_kl_div: 128858553.6306


 27%|██▋       | 27/100 [1:22:44<3:43:36, 183.78s/it]

Validation loss decreased (6498586.116406 --> 6243452.781348).  Saving model ...
Epoch: 26 	Training Loss: 881648.1524 	Training Accuracy: 0.5310 	Validation Loss: 6243452.7813 	Validation Accuracy: 0.5227 	train_kl_div: 123804850.0892


 28%|██▊       | 28/100 [1:25:47<3:40:25, 183.68s/it]

Validation loss decreased (6243452.781348 --> 6005664.091211).  Saving model ...
Epoch: 27 	Training Loss: 849033.5963 	Training Accuracy: 0.5351 	Validation Loss: 6005664.0912 	Validation Accuracy: 0.5115 	train_kl_div: 118987155.9236


 29%|██▉       | 29/100 [1:28:52<3:37:36, 183.90s/it]

Validation loss decreased (6005664.091211 --> 5771944.894434).  Saving model ...
Epoch: 28 	Training Loss: 817793.0304 	Training Accuracy: 0.5408 	Validation Loss: 5771944.8944 	Validation Accuracy: 0.5407 	train_kl_div: 114388020.1783


 30%|███       | 30/100 [1:31:55<3:34:16, 183.66s/it]

Validation loss decreased (5771944.894434 --> 5555543.256543).  Saving model ...
Epoch: 29 	Training Loss: 787722.8124 	Training Accuracy: 0.5485 	Validation Loss: 5555543.2565 	Validation Accuracy: 0.5215 	train_kl_div: 110000679.2357


 31%|███       | 31/100 [1:34:58<3:31:06, 183.58s/it]

Validation loss decreased (5555543.256543 --> 5345096.586426).  Saving model ...
Epoch: 30 	Training Loss: 759924.1748 	Training Accuracy: 0.5488 	Validation Loss: 5345096.5864 	Validation Accuracy: 0.5416 	train_kl_div: 105818565.2484


 32%|███▏      | 32/100 [1:38:02<3:27:58, 183.51s/it]

Validation loss decreased (5345096.586426 --> 5142413.406836).  Saving model ...
Epoch: 31 	Training Loss: 732400.7206 	Training Accuracy: 0.5561 	Validation Loss: 5142413.4068 	Validation Accuracy: 0.5693 	train_kl_div: 101818076.8917


 33%|███▎      | 33/100 [1:41:05<3:24:58, 183.55s/it]

Validation loss decreased (5142413.406836 --> 4955038.436914).  Saving model ...
Epoch: 32 	Training Loss: 706977.4146 	Training Accuracy: 0.5598 	Validation Loss: 4955038.4369 	Validation Accuracy: 0.5472 	train_kl_div: 98001767.5414


 34%|███▍      | 34/100 [1:44:08<3:21:34, 183.25s/it]

Validation loss decreased (4955038.436914 --> 4773691.786426).  Saving model ...
Epoch: 33 	Training Loss: 682106.5335 	Training Accuracy: 0.5626 	Validation Loss: 4773691.7864 	Validation Accuracy: 0.5282 	train_kl_div: 94346953.1210


 35%|███▌      | 35/100 [1:47:11<3:18:38, 183.37s/it]

Validation loss decreased (4773691.786426 --> 4597047.267383).  Saving model ...
Epoch: 34 	Training Loss: 658746.9143 	Training Accuracy: 0.5632 	Validation Loss: 4597047.2674 	Validation Accuracy: 0.5521 	train_kl_div: 90856889.5287


 36%|███▌      | 36/100 [1:50:16<3:15:56, 183.69s/it]

Validation loss decreased (4597047.267383 --> 4430702.760645).  Saving model ...
Epoch: 35 	Training Loss: 636777.2682 	Training Accuracy: 0.5654 	Validation Loss: 4430702.7606 	Validation Accuracy: 0.5522 	train_kl_div: 87517321.7834


 37%|███▋      | 37/100 [1:53:21<3:13:11, 183.99s/it]

Validation loss decreased (4430702.760645 --> 4270956.128223).  Saving model ...
Epoch: 36 	Training Loss: 614546.8751 	Training Accuracy: 0.5722 	Validation Loss: 4270956.1282 	Validation Accuracy: 0.5545 	train_kl_div: 84318982.1146


 38%|███▊      | 38/100 [1:56:24<3:09:56, 183.82s/it]

Validation loss decreased (4270956.128223 --> 4118772.561230).  Saving model ...
Epoch: 37 	Training Loss: 593575.1523 	Training Accuracy: 0.5763 	Validation Loss: 4118772.5612 	Validation Accuracy: 0.5480 	train_kl_div: 81252053.1465


 39%|███▉      | 39/100 [1:59:28<3:06:49, 183.77s/it]

Validation loss decreased (4118772.561230 --> 3968753.903809).  Saving model ...
Epoch: 38 	Training Loss: 573293.4481 	Training Accuracy: 0.5852 	Validation Loss: 3968753.9038 	Validation Accuracy: 0.5802 	train_kl_div: 78313639.3376


 40%|████      | 40/100 [2:02:31<3:03:38, 183.64s/it]

Validation loss decreased (3968753.903809 --> 3829942.884473).  Saving model ...
Epoch: 39 	Training Loss: 553794.3262 	Training Accuracy: 0.5915 	Validation Loss: 3829942.8845 	Validation Accuracy: 0.5712 	train_kl_div: 75491989.4013


 41%|████      | 41/100 [2:05:35<3:00:40, 183.73s/it]

Validation loss decreased (3829942.884473 --> 3695531.316504).  Saving model ...
Epoch: 40 	Training Loss: 536524.0715 	Training Accuracy: 0.5860 	Validation Loss: 3695531.3165 	Validation Accuracy: 0.5698 	train_kl_div: 72800418.5987


 42%|████▏     | 42/100 [2:08:39<2:57:38, 183.76s/it]

Validation loss decreased (3695531.316504 --> 3569300.766211).  Saving model ...
Epoch: 41 	Training Loss: 519238.9995 	Training Accuracy: 0.5886 	Validation Loss: 3569300.7662 	Validation Accuracy: 0.5445 	train_kl_div: 70217427.4140


 43%|████▎     | 43/100 [2:11:42<2:54:18, 183.48s/it]

Validation loss decreased (3569300.766211 --> 3445328.110254).  Saving model ...
Epoch: 42 	Training Loss: 502650.3780 	Training Accuracy: 0.5901 	Validation Loss: 3445328.1103 	Validation Accuracy: 0.5524 	train_kl_div: 67734957.6561


 44%|████▍     | 44/100 [2:14:45<2:51:15, 183.49s/it]

Validation loss decreased (3445328.110254 --> 3324874.224902).  Saving model ...
Epoch: 43 	Training Loss: 485797.9203 	Training Accuracy: 0.5977 	Validation Loss: 3324874.2249 	Validation Accuracy: 0.5695 	train_kl_div: 65349829.0955


 45%|████▌     | 45/100 [2:17:49<2:48:21, 183.66s/it]

Validation loss decreased (3324874.224902 --> 3208854.924121).  Saving model ...
Epoch: 44 	Training Loss: 470152.0192 	Training Accuracy: 0.6019 	Validation Loss: 3208854.9241 	Validation Accuracy: 0.5754 	train_kl_div: 63059968.5605


 46%|████▌     | 46/100 [2:20:53<2:45:19, 183.69s/it]

Validation loss decreased (3208854.924121 --> 3100305.661523).  Saving model ...
Epoch: 45 	Training Loss: 454938.6100 	Training Accuracy: 0.6079 	Validation Loss: 3100305.6615 	Validation Accuracy: 0.5679 	train_kl_div: 60855789.1465


 47%|████▋     | 47/100 [2:23:56<2:42:03, 183.47s/it]

Validation loss decreased (3100305.661523 --> 2992901.744336).  Saving model ...
Epoch: 46 	Training Loss: 441083.9615 	Training Accuracy: 0.6064 	Validation Loss: 2992901.7443 	Validation Accuracy: 0.5843 	train_kl_div: 58745504.1529


 48%|████▊     | 48/100 [2:27:01<2:39:33, 184.10s/it]

Validation loss decreased (2992901.744336 --> 2889723.074023).  Saving model ...
Epoch: 47 	Training Loss: 426501.7513 	Training Accuracy: 0.6178 	Validation Loss: 2889723.0740 	Validation Accuracy: 0.5961 	train_kl_div: 56709321.4268


 49%|████▉     | 49/100 [2:30:05<2:36:17, 183.87s/it]

Validation loss decreased (2889723.074023 --> 2792197.767383).  Saving model ...
Epoch: 48 	Training Loss: 413651.1245 	Training Accuracy: 0.6136 	Validation Loss: 2792197.7674 	Validation Accuracy: 0.5887 	train_kl_div: 54754226.0127


 50%|█████     | 50/100 [2:33:07<2:32:57, 183.55s/it]

Validation loss decreased (2792197.767383 --> 2698776.564844).  Saving model ...
Epoch: 49 	Training Loss: 400583.9937 	Training Accuracy: 0.6186 	Validation Loss: 2698776.5648 	Validation Accuracy: 0.5902 	train_kl_div: 52875311.0318


 51%|█████     | 51/100 [2:36:11<2:29:52, 183.52s/it]

Validation loss decreased (2698776.564844 --> 2609274.460742).  Saving model ...
Epoch: 50 	Training Loss: 388023.1428 	Training Accuracy: 0.6228 	Validation Loss: 2609274.4607 	Validation Accuracy: 0.5958 	train_kl_div: 51060582.7261


 52%|█████▏    | 52/100 [2:39:15<2:27:01, 183.78s/it]

Validation loss decreased (2609274.460742 --> 2521828.637305).  Saving model ...
Epoch: 51 	Training Loss: 377214.7058 	Training Accuracy: 0.6206 	Validation Loss: 2521828.6373 	Validation Accuracy: 0.5962 	train_kl_div: 49326764.4331


 53%|█████▎    | 53/100 [2:42:20<2:24:09, 184.03s/it]

Validation loss decreased (2521828.637305 --> 2437939.123145).  Saving model ...
Epoch: 52 	Training Loss: 364702.2114 	Training Accuracy: 0.6267 	Validation Loss: 2437939.1231 	Validation Accuracy: 0.5957 	train_kl_div: 47646916.3312


 54%|█████▍    | 54/100 [2:45:23<2:20:57, 183.85s/it]

Validation loss decreased (2437939.123145 --> 2359646.802344).  Saving model ...
Epoch: 53 	Training Loss: 353709.0807 	Training Accuracy: 0.6331 	Validation Loss: 2359646.8023 	Validation Accuracy: 0.5803 	train_kl_div: 46036572.8662


 55%|█████▌    | 55/100 [2:48:26<2:17:42, 183.61s/it]

Validation loss decreased (2359646.802344 --> 2283712.582422).  Saving model ...
Epoch: 54 	Training Loss: 343372.9995 	Training Accuracy: 0.6360 	Validation Loss: 2283712.5824 	Validation Accuracy: 0.5775 	train_kl_div: 44478494.5732


 56%|█████▌    | 56/100 [2:51:30<2:14:37, 183.58s/it]

Validation loss decreased (2283712.582422 --> 2204751.839648).  Saving model ...
Epoch: 55 	Training Loss: 333187.9145 	Training Accuracy: 0.6370 	Validation Loss: 2204751.8396 	Validation Accuracy: 0.6023 	train_kl_div: 42981114.3185


 57%|█████▋    | 57/100 [2:54:34<2:11:33, 183.58s/it]

Validation loss decreased (2204751.839648 --> 2133390.207031).  Saving model ...
Epoch: 56 	Training Loss: 323720.1797 	Training Accuracy: 0.6347 	Validation Loss: 2133390.2070 	Validation Accuracy: 0.6072 	train_kl_div: 41538181.1465


 58%|█████▊    | 58/100 [2:57:38<2:08:46, 183.96s/it]

Validation loss decreased (2133390.207031 --> 2064731.247168).  Saving model ...
Epoch: 57 	Training Loss: 313450.4981 	Training Accuracy: 0.6443 	Validation Loss: 2064731.2472 	Validation Accuracy: 0.5915 	train_kl_div: 40145404.2293


 59%|█████▉    | 59/100 [3:00:42<2:05:32, 183.73s/it]

Validation loss decreased (2064731.247168 --> 1994466.506641).  Saving model ...
Epoch: 58 	Training Loss: 303965.9516 	Training Accuracy: 0.6508 	Validation Loss: 1994466.5066 	Validation Accuracy: 0.6120 	train_kl_div: 38800580.8408


 60%|██████    | 60/100 [3:03:46<2:02:34, 183.87s/it]

Validation loss decreased (1994466.506641 --> 1932394.887402).  Saving model ...
Epoch: 59 	Training Loss: 295192.1342 	Training Accuracy: 0.6504 	Validation Loss: 1932394.8874 	Validation Accuracy: 0.5904 	train_kl_div: 37503652.6624


 61%|██████    | 61/100 [3:06:45<1:58:31, 182.36s/it]

Validation loss decreased (1932394.887402 --> 1869437.514453).  Saving model ...
Epoch: 60 	Training Loss: 286613.8096 	Training Accuracy: 0.6536 	Validation Loss: 1869437.5145 	Validation Accuracy: 0.5974 	train_kl_div: 36255027.7452


 62%|██████▏   | 62/100 [3:09:43<1:54:45, 181.19s/it]

Validation loss decreased (1869437.514453 --> 1807599.759473).  Saving model ...
Epoch: 61 	Training Loss: 278239.6242 	Training Accuracy: 0.6568 	Validation Loss: 1807599.7595 	Validation Accuracy: 0.6089 	train_kl_div: 35051239.4140


In [None]:
# BayesianCNN with softplus on CIFAR100
#
# Builds the model, ELBO criterion, Adam optimizer and plateau scheduler,
# resumes from the checkpoint when one exists, trains up to `n_epochs`, and
# finally evaluates on the test loader.
bc100_sp_net = BBBAlexNet(c100_outputs, c100_inputs, priors, activation_type='softplus').to(device)
bc100_sp_criterion = metrics.ELBO(len(c100_trainset)).to(device)
bc100_sp_optimizer = Adam(bc100_sp_net.parameters(), lr=lr_start)
bc100_sp_lr_sched = lr_scheduler.ReduceLROnPlateau(bc100_sp_optimizer, patience=6, verbose=True)
# Lowest validation loss seen so far (despite the "_max" suffix).
# `np.inf` instead of `np.Inf`: the capitalised alias was removed in NumPy 2.0.
bc100_sp_valid_loss_max = np.inf

ckpt_name = 'Bayesian/Models/bc100_sp.pth'
start_epoch = 0
if os.path.isfile(ckpt_name):
    # map_location lets a checkpoint written on a GPU machine be restored on a
    # CPU-only one (and vice versa).
    checkpoint = th.load(ckpt_name, map_location=device)
    bc100_sp_net.load_state_dict(checkpoint['model_state_dict'])
    bc100_sp_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    bc100_sp_lr_sched.load_state_dict(checkpoint['scheduler_state_dict'])
    bc100_sp_valid_loss_max = checkpoint['valid_loss_max']
    # Resume after the checkpointed epoch instead of replaying the schedule
    # from 0; checkpoints without an 'epoch' entry still start at epoch 0.
    start_epoch = checkpoint.get('epoch', -1) + 1
    print('Model loaded from {}'.format(ckpt_name))

# th.save does not create missing directories.
os.makedirs(os.path.dirname(ckpt_name), exist_ok=True)

for epoch in tqdm(range(start_epoch, n_epochs)):  # loop over the dataset multiple times

    bc100_sp_train_loss, bc100_sp_train_acc, bc100_sp_train_kl = BCNN.train_model(bc100_sp_net, bc100_sp_optimizer, bc100_sp_criterion, c100_train_loader, num_ens=1, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
    bc100_sp_valid_loss, bc100_sp_valid_acc = BCNN.validate_model(bc100_sp_net, bc100_sp_criterion, c100_valid_loader, num_ens=1, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
    bc100_sp_lr_sched.step(bc100_sp_valid_loss)

    # save a checkpoint whenever the validation loss has not gotten worse
    if bc100_sp_valid_loss <= bc100_sp_valid_loss_max:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
            bc100_sp_valid_loss_max, bc100_sp_valid_loss))
        th.save({
            'model_state_dict': bc100_sp_net.state_dict(),
            'optimizer_state_dict': bc100_sp_optimizer.state_dict(),
            'scheduler_state_dict': bc100_sp_lr_sched.state_dict(),
            'valid_loss_max': bc100_sp_valid_loss,
            'epoch': epoch
        }, ckpt_name)
        bc100_sp_valid_loss_max = bc100_sp_valid_loss

    print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'.format(
        epoch, bc100_sp_train_loss, bc100_sp_train_acc, bc100_sp_valid_loss, bc100_sp_valid_acc, bc100_sp_train_kl))

# After all epochs are complete, evaluate the model on the test set.
# NOTE(review): this uses the weights currently in memory (last epoch run, or
# the loaded checkpoint if no epoch ran), not necessarily the best checkpoint.
bc100_sp_test_loss, bc100_sp_test_acc = BCNN.validate_model(bc100_sp_net, bc100_sp_criterion, c100_test_loader, num_ens=1, beta_type=beta_type)
print('Test Loss: {:.4f} \tTest Accuracy: {:.4f}'.format(bc100_sp_test_loss, bc100_sp_test_acc))

In [None]:
# BayesianCNN with relu on CIFAR10
#
# Builds the model, ELBO criterion, Adam optimizer and plateau scheduler,
# resumes from the checkpoint when one exists, trains up to `n_epochs`, and
# finally evaluates on the test loader.
bc10_rl_net = BBBAlexNet(c10_outputs, c10_inputs, priors, activation_type='relu').to(device)
bc10_rl_criterion = metrics.ELBO(len(c10_trainset)).to(device)
bc10_rl_optimizer = Adam(bc10_rl_net.parameters(), lr=lr_start)
bc10_rl_lr_sched = lr_scheduler.ReduceLROnPlateau(bc10_rl_optimizer, patience=6, verbose=True)
# Lowest validation loss seen so far (despite the "_max" suffix).
# `np.inf` instead of `np.Inf`: the capitalised alias was removed in NumPy 2.0.
bc10_rl_valid_loss_max = np.inf

ckpt_name = 'Bayesian/Models/bc10_rl.pth'
start_epoch = 0
if os.path.isfile(ckpt_name):
    # map_location lets a checkpoint written on a GPU machine be restored on a
    # CPU-only one (and vice versa).
    checkpoint = th.load(ckpt_name, map_location=device)
    bc10_rl_net.load_state_dict(checkpoint['model_state_dict'])
    bc10_rl_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    bc10_rl_lr_sched.load_state_dict(checkpoint['scheduler_state_dict'])
    bc10_rl_valid_loss_max = checkpoint['valid_loss_max']
    # Resume after the checkpointed epoch instead of replaying the schedule
    # from 0; checkpoints without an 'epoch' entry still start at epoch 0.
    start_epoch = checkpoint.get('epoch', -1) + 1
    print('Model loaded from {}'.format(ckpt_name))

# th.save does not create missing directories.
os.makedirs(os.path.dirname(ckpt_name), exist_ok=True)

for epoch in tqdm(range(start_epoch, n_epochs)):  # loop over the dataset multiple times

    bc10_rl_train_loss, bc10_rl_train_acc, bc10_rl_train_kl = BCNN.train_model(bc10_rl_net, bc10_rl_optimizer, bc10_rl_criterion, c10_train_loader, num_ens=1, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
    bc10_rl_valid_loss, bc10_rl_valid_acc = BCNN.validate_model(bc10_rl_net, bc10_rl_criterion, c10_valid_loader, num_ens=1, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
    bc10_rl_lr_sched.step(bc10_rl_valid_loss)

    # save a checkpoint whenever the validation loss has not gotten worse
    if bc10_rl_valid_loss <= bc10_rl_valid_loss_max:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
            bc10_rl_valid_loss_max, bc10_rl_valid_loss))
        th.save({
            'model_state_dict': bc10_rl_net.state_dict(),
            'optimizer_state_dict': bc10_rl_optimizer.state_dict(),
            'scheduler_state_dict': bc10_rl_lr_sched.state_dict(),
            'valid_loss_max': bc10_rl_valid_loss,
            'epoch': epoch
        }, ckpt_name)
        bc10_rl_valid_loss_max = bc10_rl_valid_loss

    print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'.format(
        epoch, bc10_rl_train_loss, bc10_rl_train_acc, bc10_rl_valid_loss, bc10_rl_valid_acc, bc10_rl_train_kl))

# After all epochs are complete, evaluate the model on the test set.
# NOTE(review): this uses the weights currently in memory (last epoch run, or
# the loaded checkpoint if no epoch ran), not necessarily the best checkpoint.
bc10_rl_test_loss, bc10_rl_test_acc = BCNN.validate_model(bc10_rl_net, bc10_rl_criterion, c10_test_loader, num_ens=1, beta_type=beta_type)
print('Test Loss: {:.4f} \tTest Accuracy: {:.4f}'.format(bc10_rl_test_loss, bc10_rl_test_acc))

In [None]:
# BayesianCNN with relu on CIFAR100
#
# Builds the model, ELBO criterion, Adam optimizer and plateau scheduler,
# resumes from the checkpoint when one exists, trains up to `n_epochs`, and
# finally evaluates on the test loader.
bc100_rl_net = BBBAlexNet(c100_outputs, c100_inputs, priors, activation_type='relu').to(device)
bc100_rl_criterion = metrics.ELBO(len(c100_trainset)).to(device)
bc100_rl_optimizer = Adam(bc100_rl_net.parameters(), lr=lr_start)
bc100_rl_lr_sched = lr_scheduler.ReduceLROnPlateau(bc100_rl_optimizer, patience=6, verbose=True)
# Lowest validation loss seen so far (despite the "_max" suffix).
# `np.inf` instead of `np.Inf`: the capitalised alias was removed in NumPy 2.0.
bc100_rl_valid_loss_max = np.inf

ckpt_name = 'Bayesian/Models/bc100_rl.pth'
start_epoch = 0
if os.path.isfile(ckpt_name):
    # map_location lets a checkpoint written on a GPU machine be restored on a
    # CPU-only one (and vice versa).
    checkpoint = th.load(ckpt_name, map_location=device)
    bc100_rl_net.load_state_dict(checkpoint['model_state_dict'])
    bc100_rl_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    bc100_rl_lr_sched.load_state_dict(checkpoint['scheduler_state_dict'])
    bc100_rl_valid_loss_max = checkpoint['valid_loss_max']
    # Resume after the checkpointed epoch instead of replaying the schedule
    # from 0; checkpoints without an 'epoch' entry still start at epoch 0.
    start_epoch = checkpoint.get('epoch', -1) + 1
    print('Model loaded from {}'.format(ckpt_name))

# th.save does not create missing directories.
os.makedirs(os.path.dirname(ckpt_name), exist_ok=True)

for epoch in tqdm(range(start_epoch, n_epochs)):  # loop over the dataset multiple times

    bc100_rl_train_loss, bc100_rl_train_acc, bc100_rl_train_kl = BCNN.train_model(bc100_rl_net, bc100_rl_optimizer, bc100_rl_criterion, c100_train_loader, num_ens=1, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
    bc100_rl_valid_loss, bc100_rl_valid_acc = BCNN.validate_model(bc100_rl_net, bc100_rl_criterion, c100_valid_loader, num_ens=1, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
    bc100_rl_lr_sched.step(bc100_rl_valid_loss)

    # save a checkpoint whenever the validation loss has not gotten worse
    if bc100_rl_valid_loss <= bc100_rl_valid_loss_max:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
            bc100_rl_valid_loss_max, bc100_rl_valid_loss))
        th.save({
            'model_state_dict': bc100_rl_net.state_dict(),
            'optimizer_state_dict': bc100_rl_optimizer.state_dict(),
            'scheduler_state_dict': bc100_rl_lr_sched.state_dict(),
            'valid_loss_max': bc100_rl_valid_loss,
            'epoch': epoch
        }, ckpt_name)
        bc100_rl_valid_loss_max = bc100_rl_valid_loss

    print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'.format(
        epoch, bc100_rl_train_loss, bc100_rl_train_acc, bc100_rl_valid_loss, bc100_rl_valid_acc, bc100_rl_train_kl))

# After all epochs are complete, evaluate the model on the test set.
# NOTE(review): this uses the weights currently in memory (last epoch run, or
# the loaded checkpoint if no epoch ran), not necessarily the best checkpoint.
bc100_rl_test_loss, bc100_rl_test_acc = BCNN.validate_model(bc100_rl_net, bc100_rl_criterion, c100_test_loader, num_ens=1, beta_type=beta_type)
print('Test Loss: {:.4f} \tTest Accuracy: {:.4f}'.format(bc100_rl_test_loss, bc100_rl_test_acc))

In [None]:
# FrequentistCNN on CIFAR10
#
# Builds the model, cross-entropy criterion, Adam optimizer and plateau
# scheduler, resumes from the checkpoint when one exists, trains up to
# `n_epochs`, and finally evaluates on the test loader.
fc10_net = AlexNet(c10_outputs, c10_inputs).to(device)
fc10_criterion = nn.CrossEntropyLoss()
fc10_optimizer = Adam(fc10_net.parameters(), lr=lr_start)
fc10_lr_sched = lr_scheduler.ReduceLROnPlateau(fc10_optimizer, patience=6, verbose=True)
# Lowest validation loss seen so far (despite the "_max" suffix).
# `np.inf` instead of `np.Inf`: the capitalised alias was removed in NumPy 2.0.
fc10_valid_loss_max = np.inf

ckpt_name = 'Frequentist/Models/fc10.pth'
start_epoch = 0
if os.path.isfile(ckpt_name):
    # map_location lets a checkpoint written on a GPU machine be restored on a
    # CPU-only one (and vice versa).
    checkpoint = th.load(ckpt_name, map_location=device)
    fc10_net.load_state_dict(checkpoint['model_state_dict'])
    fc10_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    fc10_lr_sched.load_state_dict(checkpoint['scheduler_state_dict'])
    fc10_valid_loss_max = checkpoint['valid_loss_max']
    # Resume after the checkpointed epoch instead of running all epochs
    # again; checkpoints without an 'epoch' entry still start at epoch 0.
    start_epoch = checkpoint.get('epoch', -1) + 1
    print('Model loaded from {}'.format(ckpt_name))

# th.save does not create missing directories.
os.makedirs(os.path.dirname(ckpt_name), exist_ok=True)

for epoch in tqdm(range(start_epoch, n_epochs)):

    fc10_train_loss, fc10_train_acc = FCNN.train_model(fc10_net, fc10_optimizer, fc10_criterion, c10_train_loader)
    fc10_valid_loss, fc10_valid_acc = FCNN.validate_model(fc10_net, fc10_criterion, c10_valid_loader)
    fc10_lr_sched.step(fc10_valid_loss)

    # save a checkpoint whenever the validation loss has not gotten worse
    if fc10_valid_loss <= fc10_valid_loss_max:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
            fc10_valid_loss_max, fc10_valid_loss))
        th.save({
            'model_state_dict': fc10_net.state_dict(),
            'optimizer_state_dict': fc10_optimizer.state_dict(),
            'scheduler_state_dict': fc10_lr_sched.state_dict(),
            'valid_loss_max': fc10_valid_loss,
            'epoch': epoch
        }, ckpt_name)
        fc10_valid_loss_max = fc10_valid_loss

    print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f}'.format(
        epoch, fc10_train_loss, fc10_train_acc, fc10_valid_loss, fc10_valid_acc))

# After all epochs are complete, evaluate the model on the test set.
# NOTE(review): this uses the weights currently in memory (last epoch run, or
# the loaded checkpoint if no epoch ran), not necessarily the best checkpoint.
fc10_test_loss, fc10_test_acc = FCNN.validate_model(fc10_net, fc10_criterion, c10_test_loader)
print('Test Loss: {:.4f} \tTest Accuracy: {:.4f}'.format(fc10_test_loss, fc10_test_acc))

In [None]:
# Frequentist CNN on CIFAR100
#
# Builds the model, cross-entropy criterion, Adam optimizer and plateau
# scheduler, resumes from the checkpoint when one exists, trains up to
# `n_epochs`, and finally evaluates on the test loader.
fc100_net = AlexNet(c100_outputs, c100_inputs).to(device)
fc100_criterion = nn.CrossEntropyLoss()
fc100_optimizer = Adam(fc100_net.parameters(), lr=lr_start)
fc100_lr_sched = lr_scheduler.ReduceLROnPlateau(fc100_optimizer, patience=6, verbose=True)
# Lowest validation loss seen so far (despite the "_max" suffix).
# `np.inf` instead of `np.Inf`: the capitalised alias was removed in NumPy 2.0.
fc100_valid_loss_max = np.inf

ckpt_name = 'Frequentist/Models/fc100.pth'
start_epoch = 0
if os.path.isfile(ckpt_name):
    # map_location lets a checkpoint written on a GPU machine be restored on a
    # CPU-only one (and vice versa).
    checkpoint = th.load(ckpt_name, map_location=device)
    fc100_net.load_state_dict(checkpoint['model_state_dict'])
    fc100_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    fc100_lr_sched.load_state_dict(checkpoint['scheduler_state_dict'])
    fc100_valid_loss_max = checkpoint['valid_loss_max']
    # Resume after the checkpointed epoch instead of running all epochs
    # again; checkpoints without an 'epoch' entry still start at epoch 0.
    start_epoch = checkpoint.get('epoch', -1) + 1
    print('Model loaded from {}'.format(ckpt_name))

# th.save does not create missing directories.
os.makedirs(os.path.dirname(ckpt_name), exist_ok=True)

for epoch in tqdm(range(start_epoch, n_epochs)):

    fc100_train_loss, fc100_train_acc = FCNN.train_model(fc100_net, fc100_optimizer, fc100_criterion, c100_train_loader)
    fc100_valid_loss, fc100_valid_acc = FCNN.validate_model(fc100_net, fc100_criterion, c100_valid_loader)
    fc100_lr_sched.step(fc100_valid_loss)

    # save a checkpoint whenever the validation loss has not gotten worse
    if fc100_valid_loss <= fc100_valid_loss_max:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
            fc100_valid_loss_max, fc100_valid_loss))
        th.save({
            'model_state_dict': fc100_net.state_dict(),
            'optimizer_state_dict': fc100_optimizer.state_dict(),
            'scheduler_state_dict': fc100_lr_sched.state_dict(),
            'valid_loss_max': fc100_valid_loss,
            'epoch': epoch
        }, ckpt_name)
        fc100_valid_loss_max = fc100_valid_loss

    print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f}'.format(
        epoch, fc100_train_loss, fc100_train_acc, fc100_valid_loss, fc100_valid_acc))

# After all epochs are complete, evaluate the model on the test set.
# NOTE(review): this uses the weights currently in memory (last epoch run, or
# the loaded checkpoint if no epoch ran), not necessarily the best checkpoint.
fc100_test_loss, fc100_test_acc = FCNN.validate_model(fc100_net, fc100_criterion, c100_test_loader)
print('Test Loss: {:.4f} \tTest Accuracy: {:.4f}'.format(fc100_test_loss, fc100_test_acc))

In [None]:
# Compare the training accuracy of all six models on one figure.
#
# NOTE(review): the training cells reassign each `*_train_acc` variable every
# epoch and format it with '{:.4f}', so each one presumably holds a single
# scalar (the last epoch) by the time this cell runs. Plotting a scalar against
# all of `epochs` raises a shape-mismatch ValueError, so each series is plotted
# against only as many epochs as it has values. Accumulate per-epoch lists in
# the training loops to get real curves.
epochs = range(1, n_epochs+1)

train_acc_series = [
    ('bc10_sp', bc10_sp_train_acc),
    ('bc100_sp', bc100_sp_train_acc),
    ('bc10_rl', bc10_rl_train_acc),
    ('bc100_rl', bc100_rl_train_acc),
    ('fc10', fc10_train_acc),
    ('fc100', fc100_train_acc),
]

plt.figure(figsize=(10,6))
for series_label, series_acc in train_acc_series:
    # Accept either a scalar or a per-epoch sequence.
    series_acc = np.atleast_1d(series_acc)
    # marker='o' keeps a one-value series visible (a lone point draws no line).
    plt.plot(epochs[:len(series_acc)], series_acc, marker='o', label=series_label)

plt.title('Comparison of training accuracies')
plt.xlabel('Epochs')
plt.ylabel('Training Accuracy')
plt.legend()
plt.show()

In [9]:
# Compare validation accuracies on one figure.
#
# NOTE(review): the training cells reassign each `*_valid_acc` variable every
# epoch and format it with '{:.4f}', so each one presumably holds a single
# scalar (the last epoch) by the time this cell runs. Plotting a scalar against
# all of `epochs` raises a shape-mismatch ValueError, so each series is plotted
# against only as many epochs as it has values. Accumulate per-epoch lists in
# the training loops to get real curves.
# This cell also needs the corresponding training cell to have run in the
# current kernel; otherwise the variable is undefined (NameError).
epochs = range(1, n_epochs+1)

valid_acc_series = [
    ('bc10_sp', bc10_sp_valid_acc),
    # Enable these once the corresponding training cells have been run:
    #('bc100_sp', bc100_sp_valid_acc),
    #('bc10_rl', bc10_rl_valid_acc),
    #('bc100_rl', bc100_rl_valid_acc),
    #('fc10', fc10_valid_acc),
    #('fc100', fc100_valid_acc),
]

plt.figure(figsize=(10,6))
for series_label, series_acc in valid_acc_series:
    # Accept either a scalar or a per-epoch sequence.
    series_acc = np.atleast_1d(series_acc)
    # marker='o' keeps a one-value series visible (a lone point draws no line).
    plt.plot(epochs[:len(series_acc)], series_acc, marker='o', label=series_label)

plt.title('Comparison of validation accuracies')
plt.xlabel('Epochs')
plt.ylabel('Validation Accuracy')
plt.legend()
plt.show()

NameError: name 'bc10_sp_valid_acc' is not defined

<Figure size 1000x600 with 0 Axes>

In [None]:
# Build label-noise variants of both training sets. The third argument (0.1)
# is presumably the fraction of labels to corrupt, i.e. 10% -- confirm against
# data.mislabel_data.
c10_trainset_mislabel = data.mislabel_data(
    c10_trainset, c10_outputs, 0.1)
c100_trainset_mislabel = data.mislabel_data(
    c100_trainset, c100_outputs, 0.1)

# Wrap the mislabeled sets in new loaders. Only the first returned loader
# (the training one) is kept; the other two returned loaders are discarded.
(c10_train_loader_mislabel,
 _,
 _) = data.getDataloader(c10_trainset_mislabel, c10_testset, valid_size, batch_size, num_workers)
(c100_train_loader_mislabel,
 _,
 _) = data.getDataloader(c100_trainset_mislabel, c100_testset, valid_size, batch_size, num_workers)