In [10]:
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets
from torchvision.transforms import v2 as transforms
import numpy as np
import wandb

In [11]:
class TrainableNormalDistribution(nn.Module):
    LOG_SQRT2PI = np.log(np.sqrt(2 * np.pi))

    def __init__(self, mu, rho):
        super().__init__()

        self.mu = nn.Parameter(mu)
        self.rho = nn.Parameter(rho)
        self.register_buffer('eps', torch.Tensor(self.mu.shape))
        self.sigma = None
        self.w = None

    def sample(self):
        self.eps.data.normal_()
        self.sigma = torch.log1p(torch.exp(self.rho))
        self.w = self.mu + self.sigma * self.eps
        return self.w

    def log_posterior(self):
        assert (self.w is not None), "You can only have a log posterior for W if you've already sampled it"

        log_posteriors = -TrainableNormalDistribution.LOG_SQRT2PI - torch.log(self.sigma) - (((self.w - self.mu) ** 2) / (2 * self.sigma ** 2)) - 0.5
        return log_posteriors.sum()


class PriorWeightDistribution(nn.Module):
    # Calculates a Scale Mixture Prior distribution for the prior part of the complexity cost on Bayes by Backprop paper
    def __init__(self, pi, sigma1, sigma2):
        super().__init__()

        self.pi = pi
        self.sigma1 = sigma1
        self.sigma2 = sigma2
        self.dist1 = torch.distributions.Normal(0, sigma1)
        self.dist2 = torch.distributions.Normal(0, sigma2)

    def log_prior(self, w):
        prob_n1 = torch.exp(self.dist1.log_prob(w))
        prob_n2 = torch.exp(self.dist2.log_prob(w))

        # Prior of the mixture distribution, adding 1e-6 prevents numeric problems with log(p) for small p
        prior_pdf = (self.pi * prob_n1 + (1 - self.pi) * prob_n2) + 1e-6

        return (torch.log(prior_pdf) - 0.5).sum()


class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True, prior_sigma_1=0.1, prior_sigma_2=0.4, prior_pi=1, posterior_mu_init=0, posterior_rho_init=-7.0, prior_dist=None):
        super().__init__()

        # our main parameters
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias

        # parameters for the scale mixture prior
        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.prior_pi = prior_pi
        self.prior_dist = prior_dist

        # Variational weight parameters and sample
        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features).normal_(posterior_mu_init, 0.1))
        self.weight_rho = nn.Parameter(torch.Tensor(out_features, in_features).normal_(posterior_rho_init, 0.1))
        self.weight_sampler = TrainableNormalDistribution(self.weight_mu, self.weight_rho)

        # Variational bias parameters and sample
        self.bias_mu = nn.Parameter(torch.Tensor(out_features).normal_(posterior_mu_init, 0.1))
        self.bias_rho = nn.Parameter(torch.Tensor(out_features).normal_(posterior_rho_init, 0.1))
        self.bias_sampler = TrainableNormalDistribution(self.bias_mu, self.bias_rho)

        # Priors (as BBP paper)
        self.weight_prior_dist = PriorWeightDistribution(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2)
        self.bias_prior_dist = PriorWeightDistribution(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2)
        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, x):
        # Sample the weights and forward it
        w = self.weight_sampler.sample()

        if self.bias:
            b = self.bias_sampler.sample()
            b_log_posterior = self.bias_sampler.log_posterior()
            b_log_prior = self.bias_prior_dist.log_prior(b)
        else:
            b = torch.zeros((self.out_features), device=x.device)
            b_log_posterior = 0
            b_log_prior = 0

        # Get the complexity cost
        self.log_variational_posterior = self.weight_sampler.log_posterior() + b_log_posterior
        self.log_prior = self.weight_prior_dist.log_prior(w) + b_log_prior

        # print(x.shape, w.shape, b.shape)
        return F.linear(x, w, b)

    @property
    def kl_divergence(self):
        return self.log_variational_posterior - self.log_prior


def minibatch_weight(batch_idx: int, num_batches: int) -> float:
    return 2 ** (num_batches - batch_idx) / (2 ** num_batches - batch_idx)

In [12]:
class MNISTModel(nn.Module):

    def __init__(self, in_features=28 * 28, out_features=10, prior_sigma_1=0.1, prior_sigma_2=0.4, prior_pi=1, posterior_mu_init=0, posterior_rho_init=-7.0):
        super().__init__()

        self.layers = nn.Sequential(
            BayesianLinear(
                in_features, in_features,
                prior_sigma_1=prior_sigma_1,
                prior_sigma_2=prior_sigma_2,
                prior_pi=prior_pi,
                posterior_mu_init=posterior_mu_init,
                posterior_rho_init=posterior_rho_init
            ),
            nn.ReLU(),
            BayesianLinear(
                in_features, in_features,
                prior_sigma_1=prior_sigma_1,
                prior_sigma_2=prior_sigma_2,
                prior_pi=prior_pi,
                posterior_mu_init=posterior_mu_init,
                posterior_rho_init=posterior_rho_init
            ),
            nn.ReLU(),
            BayesianLinear(
                in_features, out_features,
                prior_sigma_1=prior_sigma_1,
                prior_sigma_2=prior_sigma_2,
                prior_pi=prior_pi,
                posterior_mu_init=posterior_mu_init,
                posterior_rho_init=posterior_rho_init
            ),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        x = self.layers(x)
        # print(x)
        return x

    @property
    def kl_divergence(self):
        kl = 0
        for module in self.modules():
            kl += getattr(module, 'kl_divergence', 0) if module != self else 0
        return kl

    def sample_elbo(self, inputs, labels, criterion, num_samples, complexity_cost_weight=1):
        loss = 0
        for _ in range(num_samples):
            outputs = self(inputs)
            contr1 = criterion(outputs, labels)
            contr2 = self.kl_divergence * complexity_cost_weight
            # print(f"contr1: {contr1}, contr2: {contr2}")
            loss += contr1
        return loss / num_samples

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [14]:
def train_one_epoch(model, train_loader, optimizer, criterion, num_samples=1):
    model.train()

    total_loss = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        kl_weight = minibatch_weight(batch_idx, len(train_loader))

        loss = model.sample_elbo(data, target, criterion, num_samples, kl_weight)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)


def evaluate(model, test_loader, criterion):
    model.eval()

    total_loss = 0
    correct = 0

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)

            output = model(data)

            loss = criterion(output, target) + model.kl_divergence * minibatch_weight(batch_idx, len(test_loader))
            total_loss += loss.item()

            preds = torch.argmax(output, 1)
            correct += (preds == target).sum().item()

            # print(f"Predictions: {preds}, Targets: {target}")

    total = len(test_loader.dataset)
    error = (total - correct) / total

    print(f"Correct: {correct}/{total} ({correct / total:.2%})")
    return total_loss / total, error


def train(model, train_loader, val_loader, optimizer, criterion, num_epochs, num_samples, use_wandb=False):
    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, num_samples)
        val_loss, val_error = evaluate(model, val_loader, criterion)

        if use_wandb:
            wandb.log({
                "epoch": epoch,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "val_error": val_error
            })

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Error: {val_error:.4f}")

In [15]:
def train_mnist(train_loader, val_loader, epochs, lr, num_samples, pi, minus_log_sigma1, minus_log_sigma2, use_wandb=False):
    sigma1 = np.exp(-minus_log_sigma1)
    sigma2 = np.exp(-minus_log_sigma2)

    model = MNISTModel(prior_sigma_1=sigma1, prior_sigma_2=sigma2, prior_pi=pi)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(reduction='sum')

    # if use_wandb:
    #     wandb.init(project="asi-paper", name="mnist")

    train(model, train_loader, val_loader, optimizer, criterion, epochs, num_samples, use_wandb=use_wandb)

    return model

In [16]:
batch_size = 128
transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Lambda(lambda x: x.view(28 * 28) / 126.0),
])


mnist_dataset = datasets.MNIST(
    root="./mnist",
    download=True,
    transform=transform,
    train=True
)

generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(mnist_dataset, [50_000, 10_000], generator=generator)

kwargs = {
    'batch_size': batch_size,
    'num_workers': 1,
    'pin_memory': True,
    'pin_memory_device': 'cuda',
    'generator': generator,
}

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    **kwargs
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=False,
    **kwargs
)

# Grid search with wandb

In [None]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
key = user_secrets.get_secret('wandb-api-key')

wandb.login(key=key)


def train_wrapper():
    with wandb.init(project="asi-paper") as run:
        model = train_mnist(
            train_loader,
            val_loader,
            epochs=50,
            lr=run.config.lr,
            num_samples=run.config.sample_nbr,
            pi=run.config.pi,
            minus_log_sigma1=run.config.min_log_sigma1,
            minus_log_sigma2=run.config.min_log_sigma2,
            use_wandb=True
        )

    return model


sweep_configuration = {
    "method": "grid",
    "metric": {"goal": "minimize", "name": "val_loss"},
    'name': "sweep-mnist",
    "parameters": {
        "lr": {'values': [1e-3, 1e-4, 1e-5]},
        "sample_nbr": {'values': [1, 2, 5, 10]},
        "pi": {'values': [0.25, 0.5, 0.75]},
        "min_log_sigma1": {'values': [0, 1, 2]},
        "min_log_sigma2": {'values': [6, 7, 8]},
    },
}

sweep_id = wandb.sweep(sweep=sweep_configuration, project="asi-paper")
wandb.agent(sweep_id, function=train_wrapper)

# Manual training

In [None]:
# model = train_mnist(train_loader, val_loader, epochs=10, lr=0.01, num_samples=5, pi=0.3, minus_log_sigma1=2, minus_log_sigma2=6)
# torch.save(model.state_dict(), "mnist_model.pt")

Correct: 8903/10000 (89.03%)
Epoch 1/10, Train Loss: 217.0029, Val Loss: 1588.3174, Val Error: 0.1097
Correct: 9290/10000 (92.90%)
Epoch 2/10, Train Loss: 195.7099, Val Loss: 1656.0879, Val Error: 0.0710


KeyboardInterrupt: 

In [None]:
# model.load_state_dict(torch.load("mnist_model.pt"))