In [1]:
import torch as th
import numpy as np
from torch.optim import Adam, lr_scheduler
from torch.nn import functional as F
from tqdm import tqdm
from utils import data, metrics
import Frequentist_main as FCNN
import Bayesian_main as BCNN

In [2]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = th.device("cuda") if th.cuda.is_available() else th.device("cpu")

In [5]:
# ---------------------------------------------------------------------------
# Experiment configuration: everything a reader might want to tune lives here.
# ---------------------------------------------------------------------------

# Seed the RNGs up front so a "Restart Kernel -> Run All" is repeatable.
# th.manual_seed seeds the CPU and all CUDA generators.
# NOTE: this does not by itself guarantee bit-identical results (DataLoader
# workers and some CUDA kernels can still be non-deterministic).
SEED = 42
np.random.seed(SEED)
th.manual_seed(SEED)

# Forwarded to the BBBAlexNet constructor.
activation_type = 'softplus'  # 'softplus' or 'relu'

# Prior / initial-posterior settings, forwarded to the BBBAlexNet constructor.
priors = {
    'prior_mu': 0,
    'prior_sigma': 0.1,
    'posterior_mu_initial': (0, 0.1),  # (mean, std) normal_
    'posterior_rho_initial': (-5, 0.1),  # (mean, std) normal_
}

# Optimisation
n_epochs = 5       # number of passes of the training loop
lr_start = 0.001   # initial learning rate given to Adam

# Data loading (all three are forwarded to data.getDataloader)
num_workers = 4
valid_size = 0.2   # presumably the fraction of the train set held out for validation -- confirm in utils.data
batch_size = 256

# Forwarded as num_ens to train_model / validate_model.
# Presumably the number of stochastic forward passes per batch -- confirm in Bayesian_main.
train_ens = 1
valid_ens = 1

# Forwarded as beta_type to train_model / validate_model.
# Presumably a fixed KL weight when given as a float -- confirm in utils.metrics.
beta_type = 0.1

In [6]:
from Bayesian.BayesianCNN import BBBAlexNet
from Bayesian_main import train_model, validate_model

# --- Data -------------------------------------------------------------------
# getDataset returns the train/test sets plus two values (`inputs`, `outputs`)
# that are passed straight to the network constructor below.
trainset, testset, inputs, outputs = data.getDataset('CIFAR10')
train_loader, valid_loader, test_loader = data.getDataloader(
    trainset, testset, valid_size, batch_size, num_workers)

# --- Model ------------------------------------------------------------------
net = BBBAlexNet(outputs, inputs, priors, activation_type).to(device)

# Checkpointing is currently disabled.
# NOTE(review): the snippet below refers to `os`, `dataset`, `net_type` and
# `layer_type`, none of which are defined in the cells visible here; define
# them (and import os) before re-enabling it.
#ckpt_dir = f'checkpoints/{dataset}/bayesian'
#ckpt_name = f'checkpoints/{dataset}/bayesian/model_{net_type}_{layer_type}_{activation_type}.pt'

#if not os.path.exists(ckpt_dir):
 #   os.makedirs(ckpt_dir, exist_ok=True)

# --- Loss, optimiser, LR schedule -------------------------------------------
criterion = metrics.ELBO(len(trainset)).to(device)
optimizer = Adam(net.parameters(), lr=lr_start)
# `verbose=True` is deprecated in recent PyTorch releases, so LR changes are
# logged explicitly in the loop instead.
# With patience=6 the LR is only reduced after more than 6 consecutive epochs
# without improvement, so short runs never trigger the scheduler.
lr_sched = lr_scheduler.ReduceLROnPlateau(optimizer, patience=6)

# Lowest validation loss seen so far. The name says "max" but it is used as a
# running minimum; it is kept unchanged for compatibility with other cells.
# np.inf replaces np.Inf, which was removed in NumPy 2.0.
valid_loss_max = np.inf

# --- Training loop ----------------------------------------------------------
for epoch in tqdm(range(n_epochs)):  # loop over the dataset multiple times

    train_loss, train_acc, train_kl = train_model(
        net, optimizer, criterion, train_loader,
        num_ens=train_ens, beta_type=beta_type,
        epoch=epoch, num_epochs=n_epochs)
    valid_loss, valid_acc = validate_model(
        net, criterion, valid_loader,
        num_ens=valid_ens, beta_type=beta_type,
        epoch=epoch, num_epochs=n_epochs)

    # Step the plateau scheduler on the validation loss and report any change
    # of learning rate (replacement for the scheduler's old verbose output).
    lr_before = optimizer.param_groups[0]['lr']
    lr_sched.step(valid_loss)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after != lr_before:
        print('Epoch: {} \tReducing learning rate from {:.4e} to {:.4e}'.format(
            epoch, lr_before, lr_after))

    print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'.format(
        epoch, train_loss, train_acc, valid_loss, valid_acc, train_kl))

    # Remember the best validation loss (previously initialised but never updated).
    valid_loss_max = min(valid_loss_max, valid_loss)

Files already downloaded and verified
Files already downloaded and verified


 20%|██        | 1/5 [04:16<17:07, 256.94s/it]

Epoch: 0 	Training Loss: 38102597.7452 	Training Accuracy: 0.1460 	Validation Loss: 29979998.4500 	Validation Accuracy: 0.1655 	train_kl_div: 374937775.2866


 40%|████      | 2/5 [08:02<11:54, 238.30s/it]

Epoch: 1 	Training Loss: 25636172.5223 	Training Accuracy: 0.2155 	Validation Loss: 22023878.6000 	Validation Accuracy: 0.2416 	train_kl_div: 255244121.4777


 60%|██████    | 3/5 [11:46<07:44, 232.12s/it]

Epoch: 2 	Training Loss: 19457054.7771 	Training Accuracy: 0.2508 	Validation Loss: 17190627.8000 	Validation Accuracy: 0.2579 	train_kl_div: 193537329.6306


 80%|████████  | 4/5 [15:08<03:39, 219.96s/it]

Epoch: 3 	Training Loss: 15445241.0955 	Training Accuracy: 0.2697 	Validation Loss: 13870556.2000 	Validation Accuracy: 0.2355 	train_kl_div: 153455476.8917


100%|██████████| 5/5 [18:05<00:00, 217.15s/it]

Epoch: 4 	Training Loss: 12600298.2739 	Training Accuracy: 0.2948 	Validation Loss: 11436288.3250 	Validation Accuracy: 0.3205 	train_kl_div: 125037782.3694



