<a href="https://colab.research.google.com/github/BharathSShankar/DSA4212_Assignments/blob/bharath-exp/VOGN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install MedMNIST and torchmetrics into the Colab runtime, then pre-download the
# MedMNIST dataset archives (the captured output shows them cached under /root/.medmnist).
# NOTE(review): package versions are unpinned — pin them for reproducible runs.
!pip install medmnist torchmetrics
!python -m medmnist download

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Downloading pathmnist...
Using downloaded and verified file: /root/.medmnist/pathmnist.npz
Downloading chestmnist...
Using downloaded and verified file: /root/.medmnist/chestmnist.npz
Downloading dermamnist...
Using downloaded and verified file: /root/.medmnist/dermamnist.npz
Downloading octmnist...
Using downloaded and verified file: /root/.medmnist/octmnist.npz
Downloading pneumoniamnist...
Using downloaded and verified file: /root/.medmnist/pneumoniamnist.npz
Downloading retinamnist...
Using downloaded and verified file: /root/.medmnist/retinamnist.npz
Downloading breastmnist...
Using downloaded and verified file: /root/.medmnist/breastmnist.npz
Downloading bloodmnist...
Using downloaded and verified file: /root/.medmnist/bloodmnist.npz
Downloading tissuemnist...
Using downloaded and verified file: /root/.medmnist/tissuemnist.npz
Downloading organamnist...
Using downloaded and verified

In [None]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Optimizer
import torch.utils.data as data
import torchvision.transforms as transforms
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
import copy

In [None]:
# Mount Google Drive and make the assignment folder the working directory.
# NOTE(review): absolute Colab/Drive path — only works in this author's Drive layout.
from google.colab import drive
drive.mount('/content/drive')
%cd "/content/drive/MyDrive/DSA4212/Assignment 3"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/DSA4212/Assignment 3


In [None]:
import medmnist
from medmnist import INFO, Evaluator

In [None]:
import torch
from torch.optim.optimizer import Optimizer


class VOGN(torch.optim.Optimizer):
    """VOGN-style optimizer driven by per-example losses.

    For every parameter of the first param group it keeps a momentum buffer ``m``,
    a diagonal squared-gradient curvature estimate ``s`` and a mean ``mu``. Each
    ``step`` receives a list of per-example loss vectors (one per Monte-Carlo forward
    pass), computes per-example gradients with ``torch.autograd.grad``, updates
    ``m``/``s``/``mu`` and copies ``mu`` into the network parameters:

        m  <- beta_1 * m + (1 - beta_1) * (g + delta * mu)
        s  <- (1 - beta_2 * tau) * s + beta_2 * tau * h
        mu <- mu - lr * m / (s + delta + gamma)

    where ``g`` is the mean gradient and ``h`` the mean squared gradient.

    Only ``param_groups[0]`` is used. ``epsilon`` and ``weight_decay`` are accepted
    for signature compatibility but are not used by the update.
    """

    def __init__(self, params, lr=1e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-8, weight_decay=0, delta = 25 / 71800, tau = 1, gamma = 5e-2):
        defaults = dict(lr=lr, beta_1=beta_1, beta_2=beta_2, epsilon=epsilon, weight_decay=weight_decay, delta=delta, tau=tau, gamma=gamma)
        super(VOGN, self).__init__(params, defaults)

        self.net_params = self.param_groups[0]['params']
        # State lives in plain, detached tensors: if mu kept requires_grad (as a deepcopy
        # of a Parameter does), every update would extend one autograd graph spanning
        # all previous steps.
        self.m = [torch.zeros_like(p) for p in self.net_params]
        self.mu = [p.detach().clone() for p in self.net_params]
        self.s = None

    def _squared_grad_hessian_approx(self, loss):
        """Per-example gradient statistics for one vector of per-example losses.

        Args:
            loss: tensor of per-example losses (flattened here) or an iterable of
                scalar loss tensors whose graphs reach ``self.net_params``.

        Returns:
            ``(grad_hess, mean_grad)``: per parameter, the mean over examples of the
            squared per-example gradient and of the per-example gradient, each with the
            parameter's own shape.

        Raises:
            ValueError: if ``loss`` has no elements.
        """
        if torch.is_tensor(loss):
            loss = loss.reshape(-1)
        per_example = [
            torch.autograd.grad(loss_elem, self.net_params, retain_graph=True)
            for loss_elem in loss
        ]
        if not per_example:
            raise ValueError("loss must contain at least one element")

        grad_hess, mean_grad = [], []
        for i in range(len(self.net_params)):
            # stack (not cat): a new leading "example" axis, so averaging over it
            # preserves the parameter's shape.
            stacked = torch.stack([grads[i] for grads in per_example])
            mean_grad.append(stacked.mean(dim=0))
            grad_hess.append((stacked * stacked).mean(dim=0))
        return grad_hess, mean_grad

    def _vari_gradhess(self, loss):
        """Average gradient and squared-gradient statistics over Monte-Carlo samples.

        Args:
            loss: sequence of per-example loss vectors, one per sampled forward pass.

        Returns:
            ``(out_grads, out_hess)``: per parameter, the mean over samples of the
            ``mean_grad`` and ``grad_hess`` produced by ``_squared_grad_hessian_approx``.

        Raises:
            ValueError: if ``loss`` is empty.
        """
        per_sample = [self._squared_grad_hessian_approx(loss_elems) for loss_elems in loss]
        if not per_sample:
            raise ValueError("at least one Monte-Carlo loss vector is required")

        out_grads, out_hess = [], []
        # One entry per parameter; no squeeze(), which would drop size-1 dims such as
        # the spatial dims of 1x1 conv kernels and break broadcasting against the state.
        for i in range(len(self.net_params)):
            out_grads.append(torch.stack([grads[i] for _, grads in per_sample]).mean(dim=0))
            out_hess.append(torch.stack([hess[i] for hess, _ in per_sample]).mean(dim=0))
        return out_grads, out_hess

    @torch.no_grad()
    def _vogn_step(self, gh):
        """Update ``m``, ``s`` and ``mu`` in place from ``gh = (grads, hess)``."""
        group = self.param_groups[0]
        lr, beta_1, beta_2 = group["lr"], group["beta_1"], group["beta_2"]
        delta, tau, gamma = group["delta"], group["tau"], group["gamma"]
        grads, hess = gh
        for i in range(len(self.mu)):
            # m uses the mu from before this step's update.
            self.m[i] = self.m[i] * beta_1 + (grads[i] + self.mu[i] * delta) * (1 - beta_1)
            self.s[i] = self.s[i] * (1 - beta_2 * tau) + hess[i] * beta_2 * tau
            self.mu[i] = self.mu[i] - self.m[i] * lr / (self.s[i] + delta + gamma)

    def step(self, losses):
        """Run one VOGN update and write the new means into the network parameters.

        Args:
            losses: non-empty sequence of per-example loss tensors, one per
                Monte-Carlo forward pass, whose graphs are still alive.

        Raises:
            ValueError: if ``losses`` is empty.
        """
        if len(losses) == 0:
            raise ValueError("VOGN.step needs at least one loss vector")
        if self.s is None:
            # Initialise the curvature estimate from the first sample's squared gradients.
            self.s, _ = self._squared_grad_hessian_approx(losses[0])
        gh = self._vari_gradhess(losses)

        self._vogn_step(gh)

        with torch.no_grad():
            for p, mu in zip(self.net_params, self.mu):
                p.copy_(mu)


In [None]:
# Experiment configuration: MedMNIST subset and training hyperparameters.
data_flag = 'pathmnist'
# data_flag = 'breastmnist'
download = True

NUM_EPOCHS = 3
BATCH_SIZE = 128
lr = 0.001

# Dataset metadata from medmnist.INFO determines task type, input channels and class count.
info = INFO[data_flag]
task, n_channels = info['task'], info['n_channels']
n_classes = len(info['label'])

# Dataset class looked up by the name stored in the metadata.
DataClass = getattr(medmnist, info['python_class'])


In [None]:
# Preprocessing: convert to tensor, then normalise with mean 0.5 / std 0.5.
data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])])

# Train/test splits with the preprocessing above (built in that order),
# plus a copy of the train split without any transform.
train_dataset, test_dataset = (
    DataClass(split=split, transform=data_transform, download=download)
    for split in ('train', 'test')
)
pil_dataset = DataClass(split='train', download=download)

# Shuffled loader for training; unshuffled loaders with twice the batch size for evaluation.
eval_batch_size = 2 * BATCH_SIZE
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=eval_batch_size, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=eval_batch_size, shuffle=False)

Using downloaded and verified file: /root/.medmnist/pathmnist.npz
Using downloaded and verified file: /root/.medmnist/pathmnist.npz
Using downloaded and verified file: /root/.medmnist/pathmnist.npz


In [None]:
# Compute device and scale-mixture prior hyperparameters used by the Bayesian layers.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
PI = 0.5  # mixture weight of the SIGMA_1 component
# One-element float32 tensors exp(1) and exp(-4), placed on DEVICE.
SIGMA_1, SIGMA_2 = (
    torch.tensor([np.exp(power)], dtype=torch.float32, device=DEVICE)
    for power in (1, -4)
)

In [None]:
class Gaussian(object):
    """Diagonal Gaussian N(mu, sigma^2) with sigma = softplus(rho).

    ``mu`` and ``rho`` are references to tensors owned by the enclosing layer (they
    are not copied), so the distribution always reflects their current values.
    """

    # log(sqrt(2*pi)): constant term of the Gaussian log-density.
    LOG_SQRT_2PI = float(np.log(np.sqrt(2 * np.pi)))

    def __init__(self, mu, rho):
        super().__init__()
        self.mu = mu
        self.rho = rho

    @property
    def sigma(self):
        """Standard deviation softplus(rho).

        ``F.softplus`` equals log(1 + exp(rho)) but does not overflow to inf for
        large ``rho``.
        """
        return F.softplus(self.rho)

    def sample(self):
        """Reparameterised draw ``mu + sigma * eps`` with ``eps ~ N(0, 1)``.

        ``eps`` is created on ``rho``'s device and dtype, so no global device is needed.
        """
        epsilon = torch.randn_like(self.rho)
        return self.mu + self.sigma * epsilon

    def log_prob(self, input):
        """Log-density of ``input`` under this Gaussian, summed over all elements."""
        return (-self.LOG_SQRT_2PI
                - torch.log(self.sigma)
                - ((input - self.mu) ** 2) / (2 * self.sigma ** 2)).sum()

In [None]:
class ScaleMixtureGaussian(object):
    """Zero-mean two-component Gaussian scale mixture:
    ``pi * N(0, sigma1^2) + (1 - pi) * N(0, sigma2^2)``.
    """

    def __init__(self, pi, sigma1, sigma2):
        super().__init__()
        self.pi = pi
        self.sigma1 = sigma1
        self.sigma2 = sigma2
        self.gaussian1 = torch.distributions.Normal(0, sigma1)
        self.gaussian2 = torch.distributions.Normal(0, sigma2)

    def log_prob(self, input):
        """Mixture log-density of ``input``, summed over all elements.

        Computed in log space with ``torch.logaddexp``. Exponentiating each component
        first underflows to 0 for moderately large inputs and yields ``-inf``.
        """
        pi = torch.as_tensor(self.pi, dtype=input.dtype, device=input.device)
        log_w1 = torch.log(pi) + self.gaussian1.log_prob(input)
        log_w2 = torch.log1p(-pi) + self.gaussian2.log_prob(input)
        return torch.logaddexp(log_w1, log_w2).sum()

In [None]:
class BayesianDense(nn.Module):
    """Fully connected layer whose weights and biases are factorised Gaussians.

    Each weight/bias tensor has a mean parameter (``*_mu``) and an unconstrained scale
    parameter (``*_rho``) wrapped in a ``Gaussian``, plus a ``ScaleMixtureGaussian``
    prior built from the module-level ``PI``, ``SIGMA_1`` and ``SIGMA_2``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        w_shape = (out_features, in_features)

        # Registration order (weight_mu, weight_rho, bias_mu, bias_rho) fixes both the
        # parameters() order and the order in which the RNG is consumed at init.
        self.weight_mu = nn.Parameter(torch.Tensor(*w_shape).uniform_(-0.1, 0.1))
        self.weight_rho = nn.Parameter(torch.Tensor(*w_shape).uniform_(-7, -5))
        self.weight = Gaussian(self.weight_mu, self.weight_rho)

        self.bias_mu = nn.Parameter(torch.Tensor(out_features).uniform_(-0.1, 0.1))
        self.bias_rho = nn.Parameter(torch.Tensor(out_features).uniform_(-7, -5))
        self.bias = Gaussian(self.bias_mu, self.bias_rho)

        self.weight_prior, self.bias_prior = (
            ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2) for _ in range(2)
        )
        self.log_prior = self.log_variational_posterior = 0

    def forward(self, input, sample=False, calculate_log_probs=False):
        """Apply the linear map with sampled weights (training or ``sample``) or the means.

        When training or ``calculate_log_probs`` is set, also stores the prior and
        variational log-densities of the weights used in ``log_prior`` and
        ``log_variational_posterior``. Otherwise both are reset to 0.
        """
        stochastic = self.training or sample
        weight = self.weight.sample() if stochastic else self.weight.mu
        bias = self.bias.sample() if stochastic else self.bias.mu

        if not (self.training or calculate_log_probs):
            self.log_prior = self.log_variational_posterior = 0
        else:
            self.log_prior = self.weight_prior.log_prob(weight) + self.bias_prior.log_prob(bias)
            self.log_variational_posterior = self.weight.log_prob(weight) + self.bias.log_prob(bias)

        return F.linear(input, weight, bias)

    def kl_loss(self):
        """Return ``log p(mu) - log q(mu)`` summed over weights and bias.

        Both densities are evaluated at the posterior means. NOTE: despite the name this
        is the *negative* of a KL estimate; callers in this notebook subtract it.
        """
        weight_term = self.weight_prior.log_prob(self.weight.mu) - self.weight.log_prob(self.weight.mu)
        bias_term = self.bias_prior.log_prob(self.bias.mu) - self.bias.log_prob(self.bias.mu)
        return (weight_term + bias_term).sum()

In [None]:
class BayesianConv2D(nn.Module):
    """2-D convolution whose kernel and bias are factorised Gaussians.

    Each tensor has a mean parameter (``*_mu``) and an unconstrained scale parameter
    (``*_rho``) wrapped in a ``Gaussian``, plus a ``ScaleMixtureGaussian`` prior built
    from the module-level ``PI``, ``SIGMA_1`` and ``SIGMA_2``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        k_size: side length of the square kernel.
        stride: convolution stride, forwarded to ``F.conv2d`` (default 1).
        padding: zero padding, forwarded to ``F.conv2d`` (default 0, i.e. unpadded).
    """

    def __init__(self, in_channels, out_channels, k_size, stride=1, padding=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k_size = k_size
        self.stride = stride
        self.padding = padding

        kernel_shape = (out_channels, in_channels, k_size, k_size)
        # Registration order fixes parameters() order (relied on by optimizer state).
        self.weight_mu = nn.Parameter(torch.empty(kernel_shape).uniform_(-0.1, 0.1))
        self.weight_rho = nn.Parameter(torch.empty(kernel_shape).uniform_(-7, -5))
        self.weight = Gaussian(self.weight_mu, self.weight_rho)

        self.bias_mu = nn.Parameter(torch.empty(out_channels).uniform_(-0.1, 0.1))
        self.bias_rho = nn.Parameter(torch.empty(out_channels).uniform_(-7, -5))
        self.bias = Gaussian(self.bias_mu, self.bias_rho)

        self.weight_prior = ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2)
        self.bias_prior = ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2)
        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, input, sample=False, calculate_log_probs=False):
        """Convolve with sampled weights (training or ``sample``) or with the means.

        When training or ``calculate_log_probs`` is set, also stores the prior and
        variational log-densities of the weights used. Otherwise both are reset to 0.
        """
        if self.training or sample:
            weight = self.weight.sample()
            bias = self.bias.sample()
        else:
            weight = self.weight.mu
            bias = self.bias.mu
        if self.training or calculate_log_probs:
            self.log_prior = self.weight_prior.log_prob(weight) + self.bias_prior.log_prob(bias)
            self.log_variational_posterior = self.weight.log_prob(weight) + self.bias.log_prob(bias)
        else:
            self.log_prior, self.log_variational_posterior = 0, 0

        return F.conv2d(input, weight, bias, stride=self.stride, padding=self.padding)

    def kl_loss(self):
        """Return ``log p(mu) - log q(mu)`` summed over kernel and bias.

        Both densities are evaluated at the posterior means. NOTE: despite the name this
        is the *negative* of a KL estimate; callers in this notebook subtract it.
        """
        kl = self.weight_prior.log_prob(self.weight.mu) - self.weight.log_prob(self.weight.mu)
        kl = kl + self.bias_prior.log_prob(self.bias.mu) - self.bias.log_prob(self.bias.mu)
        return kl.sum()

In [None]:
def _train_one_epoch(net, vogn, loss_func, train_loader, epoch, log_interval, beta, k):
    """Run one VOGN training epoch, showing running averages on a tqdm bar."""
    net.train()
    run_loss = run_ce = run_kl = 0.0
    correct = total = n_batches = 0
    with tqdm(train_loader, desc=f"Epoch {epoch+1}") as t:
        for batch_idx, (inputs, target) in enumerate(t):
            inputs, target = inputs.to(DEVICE), target.to(DEVICE)
            labels = target[:, 0]  # assumes targets of shape (batch, 1) — TODO confirm for other tasks

            # k Monte-Carlo forward passes; in train mode the Bayesian layers sample weights.
            losses = []
            for _ in range(k):
                outputs = net(inputs)
                ce_loss = loss_func(outputs, labels)
                # The Bayesian layers' kl_loss() returns log p - log q (a negative KL
                # estimate), so subtracting it adds the KL penalty.
                kl_loss = net.kl_loss() / len(train_loader)
                loss = ce_loss - beta * kl_loss
                losses.append(loss)
            # No .backward(): VOGN differentiates each per-example loss itself, and
            # .grad buffers filled here would never be read or zeroed.
            vogn.step(losses)

            # Statistics come from the last Monte-Carlo sample.
            run_loss += loss.mean().item()
            run_ce += ce_loss.mean().item()
            run_kl += kl_loss.item()
            n_batches += 1
            _, predicted = outputs.max(1)
            total += inputs.shape[0]
            correct += predicted.eq(labels).sum().item()

            if batch_idx % log_interval == 0:
                # Average over the batches seen since the previous log, not a fixed count.
                t.set_postfix(
                    ce_loss=f"{run_ce / n_batches:.6f}",
                    kl_loss=f"{run_kl / n_batches:.6f}",
                    loss=f"{run_loss / n_batches:.6f}",
                    accuracy=f"{correct / total:.2f}",
                )
                run_loss = run_ce = run_kl = 0.0
                correct = total = n_batches = 0


def _evaluate_bayesian_net(net, loss_func, test_loader, beta):
    """Return (mean loss, mean CE, mean KL term, accuracy in %) of ``net`` on ``test_loader``."""
    net.eval()
    sum_loss = sum_ce = sum_kl = 0.0
    correct = total = 0
    with torch.no_grad():
        for inputs, target in test_loader:
            inputs, target = inputs.to(DEVICE), target.to(DEVICE)
            labels = target[:, 0]

            outputs = net(inputs)
            ce_loss = loss_func(outputs, labels)  # per-example (reduction='none')
            kl_loss = net.kl_loss() / len(test_loader)
            loss = ce_loss - beta * kl_loss

            # Reduce the per-example vectors before .item(); .item() needs one element.
            sum_loss += loss.mean().item()
            sum_ce += ce_loss.mean().item()
            sum_kl += kl_loss.item()
            _, predicted = outputs.max(1)
            total += inputs.shape[0]
            correct += predicted.eq(labels).sum().item()

    n_batches = max(1, len(test_loader))
    accuracy = 100. * correct / total if total else 0.0
    return sum_loss / n_batches, sum_ce / n_batches, sum_kl / n_batches, accuracy


def train_bayesian_net(net, train_loader, test_loader, n_epochs=20, lr=3e-4, log_interval=10, beta=1e-7, k=5):
    """Train a Bayesian network with VOGN and report test metrics after each epoch.

    Args:
        net: model whose forward returns class logits and which exposes ``kl_loss()``.
        train_loader: training DataLoader yielding ``(inputs, target)`` batches.
        test_loader: evaluation DataLoader yielding ``(inputs, target)`` batches.
        n_epochs: number of passes over ``train_loader``.
        lr: VOGN learning rate.
        log_interval: progress bar is refreshed every ``log_interval`` batches (min 1).
        beta: weight of the KL term in the loss.
        k: number of Monte-Carlo forward passes per batch given to ``VOGN.step``.

    Raises:
        ValueError: if ``k`` < 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    log_interval = max(1, int(log_interval))

    # Per-example losses: VOGN needs one gradient per example.
    loss_func = nn.CrossEntropyLoss(reduction='none')
    net.to(DEVICE)
    vogn = VOGN(net.parameters(), lr=lr)

    for epoch in range(n_epochs):
        _train_one_epoch(net, vogn, loss_func, train_loader, epoch, log_interval, beta, k)
        test_loss, test_ce, test_kl, test_acc = _evaluate_bayesian_net(net, loss_func, test_loader, beta)
        print('Test set: Average loss: {:.4f}, CE loss: {:.4f}, KL loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
            test_loss, test_ce, test_kl, test_acc))


In [None]:
class BayesianNeuralNetConv(nn.Module):
    """Bayesian CNN: an unpadded 3x3 input convolution followed by 1x1 convolutions.

    Each convolution is followed by RReLU and a 2x2 max-pool, then a Bayesian dense
    layer produces the logits.

    Args:
        channel_list: output channels of the 3x3 input conv, then of each 1x1 conv.
        input_channels: number of channels in the input images.
        n_classes: number of output logits.
        input_size: side length of the square input images, used to size the dense
            layer (default 28 — assumed MedMNIST resolution, confirm if changed).

    Raises:
        ValueError: if ``channel_list`` is empty or the input is too small for the
            number of pooling stages.
    """

    def __init__(self, channel_list, input_channels, n_classes, input_size=28):
        super().__init__()
        if len(channel_list) == 0:
            raise ValueError("channel_list must contain at least one entry")
        self.inputLayer = BayesianConv2D(input_channels, channel_list[0], k_size=3)
        self.convs = nn.ModuleList(
            BayesianConv2D(c_in, c_out, k_size=1)
            for c_in, c_out in zip(channel_list[:-1], channel_list[1:])
        )
        # Flattened size follows from the conv/pool stack, not a hard-coded 3x3 map.
        side = self._output_side(input_size, len(channel_list))
        if side < 1:
            raise ValueError(f"input_size={input_size} too small for {len(channel_list)} pooling stages")
        self.fc = BayesianDense(channel_list[-1] * side * side, n_classes)

    @staticmethod
    def _output_side(input_size, n_layers):
        """Spatial side after the unpadded 3x3 conv and ``n_layers`` 2x2 max-pools."""
        side = input_size - 2  # 3x3 kernel, no padding; 1x1 convs keep the size
        for _ in range(n_layers):
            side //= 2  # max_pool2d(kernel_size=2) floors odd sizes
        return side

    def forward(self, input, sample=False, calculate_log_probs=False):
        """Return class logits. ``sample`` / ``calculate_log_probs`` are passed to every Bayesian layer."""
        # NOTE(review): F.rrelu is called without `training`; per the PyTorch docs it
        # defaults to False (fixed slope, no random slopes) — confirm that is intended.
        x = self.inputLayer(input, sample, calculate_log_probs)
        x = F.rrelu(x)
        x = F.max_pool2d(x, 2)
        for conv in self.convs:
            x = conv(x, sample, calculate_log_probs)
            x = F.rrelu(x)
            x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        return self.fc(x, sample, calculate_log_probs)

    def kl_loss(self):
        """Sum of every layer's ``kl_loss()`` (each is log p - log q, i.e. a negative KL)."""
        kl = 0.0
        for layer in self.convs:
            kl += layer.kl_loss()
        kl += self.inputLayer.kl_loss()
        kl += self.fc.kl_loss()
        return kl



In [None]:
# Size the model from the dataset metadata (n_channels / n_classes from the config cell)
# rather than hard-coded 3 / 9, so switching `data_flag` (e.g. to breastmnist) keeps working.
net = BayesianNeuralNetConv([128, 256, 512], n_channels, n_classes)

# NOTE(review): NUM_EPOCHS and lr from the config cell are not passed, so the function
# defaults (n_epochs=20, lr=3e-4) are used — TODO decide which values are intended.
train_bayesian_net(net, train_loader, test_loader)

Epoch 1:   1%|          | 5/704 [01:04<2:26:28, 12.57s/it, accuracy=0.08, ce_loss=0.225333, kl_loss=-168.958572, loss=0.225350]