In [1]:
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets
from torchvision.transforms import v2 as transforms
import numpy as np
import wandb
import time
import matplotlib.pyplot as plt

In [2]:
generator = torch.Generator().manual_seed(42)
np.random.seed(42)

In [3]:

class GaussianVariational(nn.Module):

    def __init__(self, mu: torch.Tensor, rho: torch.Tensor) -> None:
        super().__init__()

        self.mu = nn.Parameter(mu)
        self.rho = nn.Parameter(rho)

        self.w = None
        self.sigma = None

        self.normal = torch.distributions.Normal(0, 1)

    def sample(self) -> torch.Tensor:
        device = self.mu.device
        epsilon = self.normal.sample(self.mu.size()).to(device)
        self.sigma = torch.log1p(torch.exp(self.rho))
        self.w = self.mu + self.sigma * epsilon

        return self.w

    def log_posterior(self) -> torch.Tensor:
        assert self.w is not None

        log_const = np.log(np.sqrt(2 * np.pi))
        log_exp = ((self.w - self.mu) ** 2) / (2 * self.sigma ** 2)
        log_posterior = -log_const - torch.log(self.sigma) - log_exp

        return log_posterior.sum()


class ScaleMixture(nn.Module):

    def __init__(self, pi: float, sigma1: float, sigma2: float) -> None:
        super().__init__()

        self.pi = pi
        self.sigma1 = sigma1
        self.sigma2 = sigma2

        self.normal1 = torch.distributions.Normal(0, sigma1)
        self.normal2 = torch.distributions.Normal(0, sigma2)

    def log_prior(self, w: torch.Tensor) -> torch.Tensor:
        likelihood_n1 = torch.exp(self.normal1.log_prob(w))
        likelihood_n2 = torch.exp(self.normal2.log_prob(w))

        p_scalemixture = self.pi * likelihood_n1 + (1 - self.pi) * likelihood_n2
        log_prob = torch.log(p_scalemixture).sum()

        return log_prob


class BayesianModule(nn.Module):
    pass


class BayesLinear(BayesianModule):

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 prior_pi: Optional[float] = 0.5,
                 prior_sigma1: Optional[float] = 1.0,
                 prior_sigma2: Optional[float] = 0.0025) -> None:
        super().__init__()

        w_mu = torch.empty(out_features, in_features).uniform_(-0.2, 0.2, generator=generator)
        w_rho = torch.empty(out_features, in_features).uniform_(-5.0, -4.0, generator=generator)

        bias_mu = torch.empty(out_features).uniform_(-0.2, 0.2, generator=generator)
        bias_rho = torch.empty(out_features).uniform_(-5.0, -4.0, generator=generator)

        self.w_posterior = GaussianVariational(w_mu, w_rho)
        self.bias_posterior = GaussianVariational(bias_mu, bias_rho)

        self.w_prior = ScaleMixture(prior_pi, prior_sigma1, prior_sigma2)
        self.bias_prior = ScaleMixture(prior_pi, prior_sigma1, prior_sigma2)

        self.kl_divergence = 0.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.w_posterior.sample()
        b = self.bias_posterior.sample()

        w_log_prior = self.w_prior.log_prior(w)
        b_log_prior = self.bias_prior.log_prior(b)

        w_log_posterior = self.w_posterior.log_posterior()
        b_log_posterior = self.bias_posterior.log_posterior()

        total_log_prior = w_log_prior + b_log_prior
        total_log_posterior = w_log_posterior + b_log_posterior
        self.kl_divergence = total_log_posterior - total_log_prior

        return F.linear(x, w, b)


def minibatch_weight(batch_idx: int, num_batches: int) -> float:
    return 2 ** (num_batches - batch_idx) / (2 ** num_batches - batch_idx)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


# MNIST classification

In [5]:
class MNISTModel(nn.Module):

    def __init__(self, in_features=28 * 28, out_features=10, prior_sigma_1=0.1, prior_sigma_2=0.4, prior_pi=1):
        super().__init__()

        self.layers = nn.Sequential(
            BayesLinear(
                in_features,
                in_features,
                prior_pi,
                prior_sigma_1,
                prior_sigma_2
            ),
            nn.ReLU(),
            BayesLinear(
                in_features,
                in_features,
                prior_pi,
                prior_sigma_1,
                prior_sigma_2
            ),
            nn.ReLU(),
            BayesLinear(
                in_features,
                out_features,
                prior_pi,
                prior_sigma_1,
                prior_sigma_2,
            ),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        x = self.layers(x)
        # print(x)
        return x

    @property
    def kl_divergence(self):
        kl = 0
        for module in self.modules():
            if isinstance(module, BayesianModule):
                kl += module.kl_divergence

        return kl

    def sample_elbo(self, inputs, labels, criterion, num_samples, complexity_cost_weight=1):
        loss = 0
        for _ in range(num_samples):
            outputs = self(inputs)
            contr1 = criterion(outputs, labels)
            contr2 = self.kl_divergence * complexity_cost_weight
            # print(f"contr1: {contr1}, contr2: {contr2}")
            loss += contr1 + contr2
        return loss / num_samples

In [6]:
def train_one_epoch(model, train_loader, optimizer, criterion, num_samples=1):
    model.train()

    total_loss = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        kl_weight = minibatch_weight(batch_idx, len(train_loader))

        loss = model.sample_elbo(data, target, criterion, num_samples, kl_weight)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)


def evaluate(model, test_loader, criterion):
    model.eval()

    total_loss = 0
    correct = 0

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)

            output = model(data)

            loss = criterion(output, target) + model.kl_divergence * minibatch_weight(batch_idx, len(test_loader))
            total_loss += loss.item()

            preds = torch.argmax(output, 1)
            correct += (preds == target).sum().item()

            # print(f"Predictions: {preds}, Targets: {target}")

    total = len(test_loader.dataset)
    error = (total - correct) / total

    print(f"Correct: {correct}/{total} ({correct / total:.2%})")
    return total_loss / total, error


def train(model, train_loader, val_loader, optimizer, criterion, num_epochs, num_samples, use_wandb=False):
    for epoch in range(num_epochs):
        now = time.time()

        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, num_samples)
        val_loss, val_error = evaluate(model, val_loader, criterion)

        elapsed = time.time() - now

        if use_wandb:
            wandb.log({
                "epoch": epoch,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "val_error": val_error
            })

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Error: {val_error:.4f}, Time: {elapsed:.2f}s")

In [7]:
def train_mnist(train_loader, val_loader, epochs, lr, num_samples, pi, minus_log_sigma1, minus_log_sigma2, use_wandb=False):
    sigma1 = np.exp(-minus_log_sigma1)
    sigma2 = np.exp(-minus_log_sigma2)

    model = MNISTModel(prior_sigma_1=sigma1, prior_sigma_2=sigma2, prior_pi=pi)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(reduction='sum')

    if use_wandb:
        run = wandb.init(project="asi-paper", name="mnist")

    train(model, train_loader, val_loader, optimizer, criterion, epochs, num_samples, use_wandb=use_wandb)

    if use_wandb:
        run.finish()
    
    return model

In [8]:
batch_size = 128
transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize((0.1307,), (0.3081,)),
    transforms.Lambda(lambda x: x.view(28 * 28)),
])


mnist_dataset = datasets.MNIST(
    root="./mnist",
    download=True,
    transform=transform,
    train=True
)
# transformed_data = transform(mnist_dataset.data).to(device)
# y = mnist_dataset.targets.to(device)
# mnist_dataset = torch.utils.data.TensorDataset(transformed_data, y)


train_dataset, val_dataset = torch.utils.data.random_split(mnist_dataset, [50_000, 10_000], generator=generator)

kwargs = {
    'batch_size': batch_size,
    'num_workers': 4,
    'generator': generator,
}

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    **kwargs
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=False,
    **kwargs
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 10.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 343kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.19MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.44MB/s]


## Grid search with wandb
Uncomment the code below to run a grid search and log the results to wandb.

In [9]:
# from kaggle_secrets import UserSecretsClient
# user_secrets = UserSecretsClient()
# key = user_secrets.get_secret('wand-api-key-asi')

# wandb.login(key=key)


# def train_wrapper():
#     with wandb.init(project="asi-paper") as run:
#         model = train_mnist(
#             train_loader,
#             val_loader,
#             epochs=10,
#             lr=run.config.lr,
#             num_samples=run.config.sample_nbr,
#             pi=run.config.pi,
#             minus_log_sigma1=run.config.min_log_sigma1,
#             minus_log_sigma2=run.config.min_log_sigma2,
#             use_wandb=True
#         )

#     return model


# sweep_configuration = {
#     "method": "grid",
#     "metric": {"goal": "minimize", "name": "val_error"},
#     'name': "sweep-mnist",
#     "parameters": {
#         "lr": {'values': [1e-3, 1e-4, 1e-5]},
#         "sample_nbr": {'values': [1, 2, 5, 10]},
#         "pi": {'values': [0.25, 0.5, 0.75]},
#         "min_log_sigma1": {'values': [0, 1, 2]},
#         "min_log_sigma2": {'values': [6, 7, 8]},
#     },
# }

# sweep_id = wandb.sweep(sweep=sweep_configuration, project="asi-paper")
# wandb.agent(sweep_id, function=train_wrapper)

## Manual training
Uncomment the code below to train the model with specified hyperparameters and save the model checkpoint.

In [10]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
key = user_secrets.get_secret('wand-api-key-asi')

wandb.login(key=key)

model = train_mnist(train_loader, val_loader, epochs=600, lr=1e-3, num_samples=1, pi=0.75, minus_log_sigma1=1, minus_log_sigma2=7, use_wandb=True)
torch.save(model.state_dict(), "mnist_model.pt")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmatteo-ghia[0m ([33mmatteo-ghia-politecnico-di-torino[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.19.9
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250516_084521-w1s668z4[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mmnist[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/matteo-ghia-politecnico-di-torino/asi-paper[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/matteo-ghia-politecnico-di-torino/asi-paper/runs/w1s668z4[0m


Correct: 6616/10000 (66.16%)
Epoch 1/600, Train Loss: 21091.2744, Val Loss: 808.8482, Val Error: 0.3384, Time: 16.73s
Correct: 6741/10000 (67.41%)
Epoch 2/600, Train Loss: 20860.9350, Val Loss: 799.0597, Val Error: 0.3259, Time: 16.08s
Correct: 6784/10000 (67.84%)
Epoch 3/600, Train Loss: 20602.4784, Val Loss: 788.7393, Val Error: 0.3216, Time: 16.65s
Correct: 7641/10000 (76.41%)
Epoch 4/600, Train Loss: 20321.1175, Val Loss: 777.9219, Val Error: 0.2359, Time: 16.47s
Correct: 8587/10000 (85.87%)
Epoch 5/600, Train Loss: 20036.6087, Val Loss: 767.8271, Val Error: 0.1413, Time: 16.12s
Correct: 8664/10000 (86.64%)
Epoch 6/600, Train Loss: 19781.2598, Val Loss: 759.2685, Val Error: 0.1336, Time: 16.53s
Correct: 9601/10000 (96.01%)
Epoch 7/600, Train Loss: 19546.2073, Val Loss: 749.9074, Val Error: 0.0399, Time: 15.93s
Correct: 9637/10000 (96.37%)
Epoch 8/600, Train Loss: 19312.7649, Val Loss: 742.6028, Val Error: 0.0363, Time: 15.92s
Correct: 9676/10000 (96.76%)
Epoch 9/600, Train Loss: 19

[34m[1mwandb[0m: uploading wandb-summary.json; uploading config.yaml
[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:      epoch ▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇██
[34m[1mwandb[0m: train_loss ██▆▆▅▄▄▄▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:  val_error ▁▁▁▂▁▂▂▂▂▂▂▃▃▃▄▄▅▅▅▅▆▆▆▇▇█▇██▇████▇█████
[34m[1mwandb[0m:   val_loss █▇▇▇▆▅▅▅▅▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:      epoch 599
[34m[1mwandb[0m: train_loss 2049.31952
[34m[1mwandb[0m:  val_error 0.1829
[34m[1mwandb[0m:   val_loss 73.61523
[34m[1mwandb[0m: 
[34m[1mwandb[0m: 🚀 View run [33mmnist[0m at: [34m[4mhttps://wandb.ai/matteo-ghia-politecnico-di-torino/asi-paper/runs/w1s668z4[0m
[34m[1mwandb[0m: ⭐️ View project at: [34m[4mhttps://wandb.ai/matteo-ghia-politecnico-di-torino/asi-paper[0m
[34

In [11]:
# model.load_state_dict(torch.load("mnist_model.pt"))

# Regression curves

In [12]:
# def generate_samples(num_samples):
#     eps = np.random.normal(0, 0.02, num_samples)
#     x = np.linspace(0, 0.5, num_samples)
#     y = x + 0.3 * np.sin(2 * np.pi * (x + eps)) + 0.3 * np.sin(4 * np.pi * (x + eps))
#     return x, y


# def plot_samples(x, y):
#     plt.figure(figsize=(10, 5))
#     plt.plot(x, y, 'o', label='Generated Samples')
#     plt.title('Generated Samples')
#     plt.xlabel('x')
#     plt.ylabel('y')
#     plt.legend()
#     plt.show()


# x, y = generate_samples(10000)
# plot_samples(x, y)