### 1. Set your experiment settings

In [None]:
import torch
import numpy as np

# Experiment name and hyper-parameters. Later cells read prior_type,
# temperature, batch_size, perc_prior and data_name directly; the full dict is
# also handed to ExperimentRunner (presumably it consumes the remaining keys).
exp_name = "Test"
config = dict(
    prior_type="prob",
    temperature=0.2,
    batch_size=250,
    perc_prior=0.8,
    alpha=0.4,
    data_name="mnist",
    prior_epochs=1,
    learning_rate_prior=0.5,
    momentum_prior=0.9,
    sigma_prior=0.01,
    posterior_epochs=1,
    learning_rate=0.1,
    momentum=0.8,
    kl_penalty=1,
    objective="fclassic",
    mc_samples=1,
)

# Seed the torch, CUDA and NumPy random number generators for reproducibility.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

print(f"Init Experiment {exp_name} with settings:")
print(config)

### 2. Split the data into train, test, prior, and posterior datasets

In [None]:
from data import load_train_test, data_transform, SimCLRAugmentedDataset
from torch.utils.data import random_split

# Load the raw (untransformed) train/test splits, then carve the train split
# into a prior part (fraction perc_prior) and a posterior part (the remainder).
data_name = config["data_name"]
perc_prior = config["perc_prior"]
transform = None
train, test = load_train_test(name=data_name, transform=transform)

n_train = len(train)
prior_size = int(perc_prior * n_train)
posterior_size = n_train - prior_size
prior_dataset, posterior_dataset = random_split(train, [prior_size, posterior_size])
print(f"Size test {len(test)} | Size prior {len(prior_dataset)} | Size posterior {len(posterior_dataset)}")

# Order matters: later cells index these lists as
# 0 = train, 1 = test, 2 = prior, 3 = posterior.
dataset_list = [train, test, prior_dataset, posterior_dataset]
augmented_dataset_list = [SimCLRAugmentedDataset(split, name=data_name) for split in dataset_list]

### 3. Train prior and posterior models

In [None]:
from run import ExperimentRunner

# Train the prior on the prior split, then the posterior using that prior.
# config["prior_type"] selects the prior trainer:
#   "det"  -> ExperimentRunner.train_prior
#   "prob" -> ExperimentRunner.train_prob_prior
# Any other value used to skip prior training silently and still train the
# posterior. Reject it up front, before the runner is built.
prior_type = config["prior_type"]
if prior_type not in ("det", "prob"):
    raise ValueError(f"Unknown prior_type {prior_type!r}; expected 'det' or 'prob'")

exp_runner = ExperimentRunner(config)
print(f"Starting training of the {prior_type} prior")
if prior_type == "det":
    exp_runner.train_prior(prior_dataset)
else:
    exp_runner.train_prob_prior(prior_dataset)

print(f"Starting training of the posterior using the learned {prior_type} prior")
# NOTE(review): the posterior is fit on the full `train` split (which includes
# the prior portion). The risk certificate below is evaluated on the posterior
# split only. Confirm this matches the intended PAC-Bayes protocol.
exp_runner.train_posterior(train)

### 4. Compute the risk certificate on the augmented posterior dataset

In [None]:
# Risk certificate for the posterior network, evaluated on the augmented
# posterior split (index 3 of augmented_dataset_list). The value is left as the
# cell's final expression so the notebook displays it.
risk_certificate = exp_runner.risk_cert.forward(
    net=exp_runner.posterior_model,
    augmented_dataset=augmented_dataset_list[3],
)
risk_certificate

### 5. Compute contrastive losses on the train and test datasets

In [None]:
from evaluate import evaluate_contrastive_loss
from loss import ZeroOneLoss, SimplifiedContrastiveLoss, ContrastiveLoss

# Evaluate the posterior network with two contrastive losses on the augmented
# train (index 0) and test (index 1) splits.
device = exp_runner.device
temperature = config["temperature"]
batch_size = config["batch_size"]
list_contrastive_loss = [ZeroOneLoss(), SimplifiedContrastiveLoss(temperature=temperature)]
list_loss_names = ["Contrastive zero-one Loss", "Simplified contrastive loss"]
list_dataset_names = ["train", "test", "prior", "posterior"]

for idx, name in enumerate(list_dataset_names[:2]):
    augmented_dataset = augmented_dataset_list[idx]
    augmented_loader = torch.utils.data.DataLoader(augmented_dataset, batch_size=batch_size, shuffle=False)
    print(f"Metrics for the {name} dataset:")
    for loss_name, contrastive_loss in zip(list_loss_names, list_contrastive_loss):
        loss_value = evaluate_contrastive_loss(
            exp_runner.posterior_model, augmented_loader, contrastive_loss, device
        )
        print(f"\u2001 -{loss_name}: {loss_value:.4f}")
        # Keep the train-set simplified contrastive loss; the bound cell
        # further down consumes it as `save_loss_value`.
        if name == "train" and loss_name == "Simplified contrastive loss":
            save_loss_value = loss_value

### 6. Train a linear classifier on top of the posterior model

In [None]:
from linear_classifier import LinearClassifier
from torch.utils.data import DataLoader

# Labelled train/test splits, loaded with the dataset-specific transform, for
# linear evaluation of the posterior network.
data_name = config["data_name"]
num_epochs = 1
transform = data_transform(data_name=data_name)
train_sup, test_sup = load_train_test(name=data_name, transform=transform)
linear_eval_batch_size = 250
test_loader = DataLoader(test_sup, batch_size=linear_eval_batch_size, shuffle=False)
train_loader = DataLoader(train_sup, batch_size=linear_eval_batch_size, shuffle=True)

# Fit and test one classifier per head setting: first without, then with the
# projection head.
projection_options = [False, True]
for projection in projection_options:
    head_mode = "with" if projection else "without"
    print(f"Linear classifier {head_mode} projection head")
    model = LinearClassifier(
        exp_runner.posterior_model,
        projection=projection,
        data_name=data_name,
    ).to(exp_runner.device)
    model.train_classifier(train_loader, num_epochs=num_epochs, lr=0.01)
    model.test_classifier(test_loader)

### 7. Compute upper bounds on the downstream classification loss

In [None]:
from transfer_bound import Sigma, Bound
from torch.utils.data import DataLoader
import numpy as np

# Contrastive hyper-parameters that enter the bound. The number of negatives
# is batch_size - 1 (presumably every other sample in the batch; confirm).
temperature = config["temperature"]
m = config["batch_size"]
neg_samples = m - 1

sigma = Sigma()
transform = data_transform(data_name=data_name)
train_sup, test_sup = load_train_test(name=data_name, transform=transform)
train_loader = DataLoader(train_sup, batch_size=250, shuffle=True)
sigma_value = sigma.forward(exp_runner.posterior_model, train_loader)

# Evaluate the bound at index 1 and index 2, fed with the train-set simplified
# contrastive loss saved by the loss cell. Index 1 is reported as the
# Bao et al. bound; the "Theorem 3" figure is the smaller of the two values.
bound = Bound(tau=temperature, num_neg_samples=neg_samples)
bound_values = [bound.forward(save_loss_value, sigma_value, index=bound_idx) for bound_idx in (1, 2)]
bao_bound = bound_values[0]
theorem3_bound = np.min(bound_values)

print("Upper-bounds on the linear classifier loss:")
print(f"\u2001 -Bao et al. : {bao_bound:.4f}")
print(f"\u2001 -Theorem 3 : {theorem3_bound:.4f}")