# Legacy code 

In [None]:
import numpy as np
import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, Dataset

We can also train MIMO and Naive models with $M=2$ subnetworks using $K=10$ repetitions to see how the variance generalises across repetitions, as described by Havasi et al. in their explainer notebook. 
This means that we train $K=10$ versions of each model (each on differently sampled data) and then evaluate what they describe as the expected mean squared error:
$$
\mathcal{E}_M = \mathbb{E}_{\bm{x}',y' \in \mathbb{X}_{\text{test}}} [\mathbb{E}_{\mathbb{X}}[(\hat{f}_M(\bm{x}',...,\bm{x}')- y')^2]]
$$
Through bias-variance decomposition, we can get the variance as:
$$
\mathbb{E}_{\bm{x}',y' \in \mathbb{X}_{\text{test}}} [\mathbb{E}_{\mathbb{X}}[(\bar{f}_M(\bm{x}',...,\bm{x}')- \hat{f}_M(\bm{x}',...,\bm{x}'))^2]]
$$
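where $\hat{f}_M$ is a regressor with $M$ subnetworks and $\bar{f}_M(\bm{x}',...,\bm{x}') = \frac{1}{K} \sum_{k=1}^K \hat{f}_M^{(k)}(\bm{x}',...,\bm{x}')$ is the mean prediction over the $K$ independently trained regressors $\hat{f}_M^{(k)}$.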

In [None]:
# useful functions 🤖

def train(model, optimizer, trainloader, valloader, epochs=500, model_name='MIMO', val_every_n_epochs=10):
    """Train ``model`` with an MSE loss and checkpoint the best validation model.

    Each batch is cast with ``.float()`` and passed through ``model``, which
    must return a pair ``(output, individual_outputs)``. The loss is the MSE
    between ``individual_outputs`` and the targets. Every
    ``val_every_n_epochs`` epochs the same loss is evaluated on ``valloader``;
    whenever the mean validation loss improves on the best value seen so far,
    the whole model object is saved to ``f'{model_name}.pt'``.

    Args:
        model: network called as ``model(x)`` returning
            ``(output, individual_outputs)``.
        optimizer: optimizer already bound to the model's parameters.
        trainloader: iterable yielding ``(x, y)`` training batches.
        valloader: iterable yielding ``(x, y)`` validation batches.
        epochs: number of passes over ``trainloader``.
        model_name: checkpoint path without the ``.pt`` suffix. It may contain
            a directory part, which is created if missing.
        val_every_n_epochs: run validation every this many epochs.

    Returns:
        ``(losses, val_losses)``: the per-batch training losses, and the
        per-batch validation losses of all validation passes concatenated.
    """
    losses = []
    val_losses = []

    best_val_loss = np.inf
    checkpoint_path = f'{model_name}.pt'
    checkpoint_dir = os.path.dirname(checkpoint_path)

    for e in tqdm(range(epochs)):

        # Validation leaves the model in eval mode, so switch back once per
        # epoch rather than once per batch.
        model.train()

        for x_, y_ in trainloader:
            x_, y_ = x_.float(), y_.float()

            optimizer.zero_grad()

            # Only the per-subnetwork outputs enter the training loss.
            _, individual_outputs = model(x_)
            loss = nn.functional.mse_loss(individual_outputs, y_)

            loss.backward()
            optimizer.step()

            losses.append(loss.item())

        if (e + 1) % val_every_n_epochs == 0:
            model.eval()

            val_loss_list = []
            with torch.no_grad():
                for val_x, val_y in valloader:
                    val_x, val_y = val_x.float(), val_y.float()
                    _, val_individual_outputs = model(val_x)
                    val_loss = nn.functional.mse_loss(val_individual_outputs, val_y)
                    val_loss_list.append(val_loss.item())

            val_losses.extend(val_loss_list)
            mean_val_loss = np.mean(val_loss_list)
            if mean_val_loss < best_val_loss:
                best_val_loss = mean_val_loss
                # torch.save does not create missing parent directories.
                if checkpoint_dir:
                    os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(model, checkpoint_path)

    return losses, val_losses

def get_train_val_dataloaders(N_train=500, N_val=200, is_naive=False):
    """Sample fresh train/validation data and wrap each split in a DataLoader.

    Both splits come from ``generate_data`` called with ``lower=-0.25``,
    ``upper=1`` and ``std=0.02``. The collate functions receive the
    module-level ``M``, looked up each time a batch is collated.

    Args:
        N_train: number of training points passed to ``generate_data``.
        N_val: number of validation points passed to ``generate_data``.
        is_naive: if true, both loaders use ``naive_collate_fn`` with batch
            size 60. Otherwise the train loader uses ``train_collate_fn`` with
            batch size ``60 * M`` (incomplete last batch dropped) and the
            validation loader uses ``test_collate_fn`` with batch size 60.

    Returns:
        ``(trainloader, valloader)``.
    """
    sampling = dict(lower=-0.25, upper=1, std=0.02)

    # Draw the train split first, then the validation split.
    train_x, train_y = generate_data(N_train, **sampling)
    val_x, val_y = generate_data(N_val, **sampling)

    train_set = ToyDataset(train_x, train_y)
    val_set = ToyDataset(val_x, val_y)

    if is_naive:
        def collate_naive(batch):
            return naive_collate_fn(batch, M)

        return (
            DataLoader(train_set, batch_size=60, shuffle=True, collate_fn=collate_naive, drop_last=False),
            DataLoader(val_set, batch_size=60, shuffle=False, collate_fn=collate_naive, drop_last=False),
        )

    def collate_train(batch):
        return train_collate_fn(batch, M)

    def collate_val(batch):
        return test_collate_fn(batch, M)

    return (
        DataLoader(train_set, batch_size=60 * M, shuffle=True, collate_fn=collate_train, drop_last=True),
        DataLoader(val_set, batch_size=60, shuffle=False, collate_fn=collate_val, drop_last=False),
    )



def train_K_repitions(M, epochs=500, val_every_n_epochs=10, repititions=20, is_naive=False):
    """Train ``repititions`` independent models and collect their loss curves.

    Each repetition builds a new network with ``n_subnetworks=M``, a new Adam
    optimizer (``lr=3e-4``) and new data loaders, then calls ``train``. The
    checkpoint name passed to ``train`` is ``naive_models/Naive_{k}`` or
    ``mimo_models/MIMO_{k}``. The dataset sizes are read from the module-level
    ``N_train`` and ``N_val``.

    Args:
        M: number of subnetworks given to the network constructor.
        epochs: forwarded to ``train``.
        val_every_n_epochs: forwarded to ``train``.
        repititions: number of models to train.
        is_naive: selects ``NaiveNetwork`` over ``MIMONetwork`` and is
            forwarded to ``get_train_val_dataloaders``.

    Returns:
        ``(K_losses, K_val_losses)``: two dicts keyed by repetition index,
        holding the training and validation loss lists returned by ``train``.
    """
    if is_naive:
        network_cls, checkpoint_prefix = NaiveNetwork, 'naive_models/Naive'
    else:
        network_cls, checkpoint_prefix = MIMONetwork, 'mimo_models/MIMO'

    K_losses, K_val_losses = {}, {}

    for rep in tqdm(range(repititions)):
        net = network_cls(n_subnetworks=M)
        adam = torch.optim.Adam(net.parameters(), lr=3e-4)
        train_dl, val_dl = get_train_val_dataloaders(N_train, N_val, is_naive=is_naive)

        K_losses[rep], K_val_losses[rep] = train(
            net,
            adam,
            train_dl,
            val_dl,
            epochs=epochs,
            model_name=f'{checkpoint_prefix}_{rep}',
            val_every_n_epochs=val_every_n_epochs,
        )

    return K_losses, K_val_losses


def plot_loss(losses, val_losses):
    """Show the training and validation loss curves side by side.

    Args:
        losses: sequence plotted in the left panel ('Train loss').
        val_losses: sequence plotted in orange in the right panel
            ('Validation loss').
    """
    _, axes = plt.subplots(1, 2, figsize=(12, 6))

    # (axis, data, title which doubles as the line label, extra plot kwargs)
    panels = (
        (axes[0], losses, 'Train loss', {}),
        (axes[1], val_losses, 'Validation loss', {'color': 'orange'}),
    )
    for axis, series, title, extra in panels:
        axis.plot(series, label=title, **extra)
        axis.set_title(title)
        axis.set_xlabel('Iterations')
        axis.set_ylabel('Loss')

    plt.show()

In [None]:
# Train K independent MIMO models, each with M subnetworks.
M, K = 2, 10  # subnetworks per model, number of repetitions
is_naive = False

K_losses, K_val_losses = train_K_repitions(
    M,
    epochs=5000,
    val_every_n_epochs=2,
    repititions=K,
    is_naive=is_naive,
)

In [None]:
# Plot the train/validation loss curves of the first repetition (key 0).
plot_loss(K_losses[0], K_val_losses[0])

In [None]:
# Train K independent Naive models, each with M subnetworks.
M, K = 2, 10  # subnetworks per model, number of repetitions
is_naive = True

K_losses, K_val_losses = train_K_repitions(
    M,
    epochs=500,
    val_every_n_epochs=2,
    repititions=K,
    is_naive=is_naive,
)


In [None]:
# Plot the train/validation loss curves of the first repetition (key 0).
plot_loss(K_losses[0], K_val_losses[0])

In [None]:
# Number of repetitions and the directories holding their checkpoints.
K = 10
naive_dir = "naive_models"
mimo_dir = "mimo_models"

# os.listdir returns entries in arbitrary order and also returns anything else
# living in the directory. Keep only the '.pt' checkpoints and sort them so
# the list is reproducible and contains nothing but model files.
naive_models = sorted(
    os.path.join(naive_dir, f) for f in os.listdir(naive_dir) if f.endswith(".pt")
)
mimo_models = sorted(
    os.path.join(mimo_dir, f) for f in os.listdir(mimo_dir) if f.endswith(".pt")
)

def rep_inference(models, testloader):
    """Run every saved model over the test set and stack the predictions.

    Args:
        models: paths of checkpoints holding whole pickled model objects, as
            written by ``torch.save(model, path)``. Each loaded model must
            return a pair when called; only the first element is kept.
        testloader: iterable yielding ``(x, y)`` batches; only ``x`` is used.

    Returns:
        float64 array of shape ``(len(models), n_test_points)``. Row ``i``
        holds the flattened first output of ``models[i]`` for all batches,
        concatenated in iteration order.
    """
    f_hats = []

    for path in models:
        # Load once per model instead of once per batch.
        # NOTE: weights_only=False is required to unpickle a whole module and
        # can execute arbitrary code; only use it on checkpoints you produced.
        model = torch.load(path, weights_only=False)
        model.eval()

        batch_outputs = []
        with torch.no_grad():
            for test_x, _ in testloader:
                output, _ = model(test_x.float())
                batch_outputs.append(output.detach().cpu().numpy().reshape(-1))

        # Concatenate so that a multi-batch loader fills the row instead of
        # each batch overwriting the previous one.
        f_hats.append(np.concatenate(batch_outputs) if batch_outputs else np.empty(0))

    if not f_hats:
        return np.empty((0, 0), dtype=np.float64)
    return np.asarray(f_hats, dtype=np.float64)

def get_predictions(f_hats):
    """Return the mean prediction and its variance across repetitions.

    Args:
        f_hats: array-like whose first axis indexes the repetitions.

    Returns:
        ``(f_bar, variances)``: the mean of ``f_hats`` over axis 0, and the
        mean squared deviation of every repetition from that mean, also over
        axis 0.
    """
    f_bar = np.mean(f_hats, axis=0)
    # Broadcasting f_bar against the stacked predictions replaces the
    # per-row Python loop.
    squared_dev = np.square(f_bar - np.asarray(f_hats))
    return f_bar, squared_dev.mean(axis=0)

# Test-set predictions of every saved model, one row of predictions per model.
naive_f_hats = rep_inference(naive_models, naivetestloader)
mimo_f_hats = rep_inference(mimo_models, testloader)

# Mean prediction and variance across the repetitions for each model family.
# NOTE(review): `naive_predcitons` is misspelled; the name is kept unchanged
# because later cells may reference it. Confirm before renaming.
naive_predcitons, naive_variances = get_predictions(naive_f_hats)
mimo_predictions, mimo_variances = get_predictions(mimo_f_hats)
