# Compute predictive performance (RMSE, LPPD) of non-Bayesian deep ensembles

In [1]:
import itertools
import os
import sys
from typing import Callable, Final, List, Sequence

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from numpyro.distributions import Normal

# Make the parent directory importable so the project-local `experiments` package resolves.
sys.path.append('../')
from experiments.fcn_bnns.utils.analysis_utils import *

In [2]:
class MLP(nn.Module):
    """Simple fully connected network.

    Layout: ``Linear(input_size, hidden_sizes[0])`` followed by one
    ``activation() -> Linear`` pair per entry of ``hidden_sizes``, ending in
    ``output_size`` units. All layers live in ``self.net``, so state-dict keys
    are ``net.0.*``, ``net.2.*``, ... (activations carry no parameters).
    In this notebook, output column 0 is used as the predictive mean and
    column 1 as the log-variance.
    """

    def __init__(
        self,
        input_size: int,
        hidden_sizes: Sequence[int],
        activation: Callable[[], nn.Module],
        dropout_ratio: float,
        output_size: int = 2,
    ) -> None:
        """Instantiate MLP.

        Args:
            input_size: Number of input features.
            hidden_sizes: Widths of the hidden layers (any sequence, e.g. list or tuple).
            activation: Zero-argument factory for the activation layer, e.g. ``nn.ReLU``.
            dropout_ratio: Dropout probability applied to the network output.
            output_size: Number of output units (defaults to 2, as before).

        Raises:
            ValueError: If ``hidden_sizes`` is empty.
        """
        super().__init__()
        # Copy into a list so tuples (or other sequences) can be concatenated below.
        hidden_sizes = list(hidden_sizes)
        if not hidden_sizes:
            raise ValueError("hidden_sizes must contain at least one layer width")
        hidden_id = '_'.join(str(width) for width in hidden_sizes)
        self.model_id = f'MLP_{input_size}_{hidden_id}_{output_size}'
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.net = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_sizes[0]))
        layer_widths = zip(hidden_sizes, hidden_sizes[1:] + [output_size])
        for in_features, out_features in layer_widths:
            self.net.append(activation())
            self.net.append(torch.nn.Linear(in_features, out_features))
        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the network outputs, shape ``(..., output_size)``.

        NOTE(review): dropout is applied to the final outputs rather than to
        hidden activations, which is unusual. Every caller in this notebook
        passes ``dropout_ratio=0.``, so it is a no-op here; it is kept
        unchanged to preserve the original architecture.
        """
        return self.dropout(self.net(x))

In [3]:
def compute_rmse(y_true, y_pred):
    """Return the root-mean-square error, averaged over every element of the residual."""
    residuals = y_true - y_pred
    return residuals.pow(2).mean().sqrt()

In [14]:
# Experiment grid. Trailing comments list values that are currently disabled
# (presumably the full grid; only the active subset is evaluated).
DATASETS: Final = [
    "airfoil",
    "bikesharing",
    "concrete",
    "energy",
    "yacht",
    "protein",
]
# NOTE(review): ARCHITECTURES, ACTIVATIONS and VAL_SIZE are not read by the
# cells below. The grid search loads tanh models even though ACTIVATIONS says relu.
ARCHITECTURES: Final = ["16-16"]
ACTIVATIONS: Final = ["relu"]
REPLICATIONS: Final = [1]  # disabled: [1, 2, 3]
BATCH_SIZE: Final = [32, 64]  # disabled: [32, 64, -1]
WEIGHT_DECAY: Final = [0.01, 0.001]  # disabled: [0.01, 0.001, 0.0001]
VAL_SIZE: Final = [0.1]
# Number of saved members (stdict_0.pt ... stdict_{N-1}.pt) loaded per run.
ENSEMBLE_SIZE: Final = 12

In [9]:
# Root results directory: checkpoints are read from '<main_dir>/de/...' and the
# grid-search table is written to '<main_dir>/de_perf/'.
main_dir = '../results/'
os.makedirs(
    name=os.path.join(main_dir, 'de_perf'),
    exist_ok=True,  # re-running the cell must not fail
)

In [10]:
def compute_lppd_de(y_true, mean_pred, std_pred):
    """Per-point log predictive density of a deep ensemble (uniform Gaussian mixture).

    Member ``m`` predicts ``N(mean_pred[m, i], sigma_mi**2)`` for test point ``i``.
    ``std_pred`` holds the log-variance (despite its name), so
    ``sigma_mi = exp(std_pred[m, i]) ** 0.5``. For every point this returns::

        log( (1/M) * sum_m N(y_i | mean_pred[m, i], sigma_mi**2) )

    The mixture average is done in log space (log-sum-exp). Points whose member
    densities all underflow in linear space therefore keep their (very negative)
    value, instead of turning into ``-inf`` and being silently dropped, which
    would bias the mean LPPD upward.

    Args:
        y_true: Targets, shape ``(N,)``; an ``(N, 1)`` column is flattened.
        mean_pred: Member means, shape ``(M, N)``.
        std_pred: Member log-variances, shape ``(M, N)``.

    Returns:
        1-D numpy array of the finite per-point log densities. Non-finite
        entries (e.g. from NaN predictions) are filtered out.
    """
    y = np.asarray(y_true).reshape(-1)
    means = np.asarray(mean_pred)
    log_var = np.asarray(std_pred)
    n_members = means.shape[0]
    # Same parameterisation as before: sigma = exp(log_var) ** 0.5.
    sigma = np.power(np.exp(log_var), 0.5)
    with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
        z = (y[np.newaxis, :] - means) / sigma
        log_dens = -0.5 * np.square(z) - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)
        # log(mean_m exp(log_dens[m])) computed without leaving log space.
        lppd = np.logaddexp.reduce(log_dens, axis=0) - np.log(n_members)
    return lppd[np.isfinite(lppd)]

## Grid search

In [19]:
# Grid search over weight decay and batch size. For every configuration, load the
# ENSEMBLE_SIZE saved tanh members, predict on the test split and record the
# RMSE of the ensemble-mean prediction. Only RMSE goes into the table, so the
# (unused and costly) LPPD computation is not performed here.
rows_grid_search = []
for ds, wd, bs in itertools.product(DATASETS, WEIGHT_DECAY, BATCH_SIZE):
    identifier = [ds, wd, bs]
    dirname = f'{ds}.data|16-16|tanh|wd{str(wd)}|bs{str(bs)}|val|1|'
    regr_dataset = pml.data.dataset.DatasetTabular(
        data_path=f'../data/{ds}.data',
        target_indices=[],
        split_spec={'train': 0.7, 'val': 0.1, 'test': 0.2},
        seed=1,
        standardize=True,
    )
    X_test, Y_test = regr_dataset.get_data(split='test', data_type='jax')
    if ds in ["bikesharing", "protein"]:
        # Only the first 2000 test points are evaluated for these datasets.
        X_test = X_test[:2000, :]
        Y_test = Y_test[:2000, :]
    X_test = torch.from_numpy(np.array(X_test))
    Y_test = torch.from_numpy(np.array(Y_test)).squeeze()
    member_means = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for member in range(ENSEMBLE_SIZE):
            # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
            weight_dict = torch.load(
                os.path.join(main_dir, 'de', f"{dirname}/stdict_{member}.pt"),
                map_location='cpu',
            )
            model = MLP(
                input_size=X_test.shape[1],
                hidden_sizes=[16, 16],
                activation=nn.Tanh,
                dropout_ratio=0.,
            )
            model.load_state_dict(weight_dict)
            model.eval()
            # Column 0 of the output is the member's predictive mean.
            member_means.append(model(X_test)[:, 0])
    ensemble_mean_agg = torch.stack(member_means)
    rmse_ensemble = compute_rmse(Y_test, ensemble_mean_agg.mean(0)).item()
    rows_grid_search.append(identifier + [rmse_ensemble])

In [35]:
# Collect the grid-search results. RMSE stays numeric so the CSV written later
# keeps full precision; rounding is applied only to the displayed view.
df = pd.DataFrame(
    rows_grid_search, columns=['dataset', 'weight_decay', 'batch_size', 'rmse']
)
# Normalise to plain floats (rows may hold 0-d arrays).
df['rmse'] = df['rmse'].astype(float)
df.sort_values(['dataset', 'weight_decay', 'batch_size'], inplace=True)
df.round({'rmse': 4})

Unnamed: 0,dataset,weight_decay,batch_size,rmse
2,airfoil,0.001,32,0.2853
3,airfoil,0.001,64,0.2858
0,airfoil,0.01,32,0.2951
1,airfoil,0.01,64,0.3183
6,bikesharing,0.001,32,0.2832
7,bikesharing,0.001,64,0.2886
4,bikesharing,0.01,32,0.3289
5,bikesharing,0.01,64,0.3392
10,concrete,0.001,32,0.3566
11,concrete,0.001,64,0.3652


In [34]:
# Render the grid-search table as LaTeX; float cells are shown with four decimals.
# (The former formatters={"name": ...} targeted a column that does not exist and
# had no effect, so it is dropped.)
print(
    df.to_latex(
        index=False,
        float_format="{:.4f}".format,
    )
)

\begin{tabular}{lrrl}
\toprule
dataset & weight_decay & batch_size & rmse \\
\midrule
airfoil & 0.0010 & 32 & 0.2853 \\
airfoil & 0.0010 & 64 & 0.2858 \\
airfoil & 0.0100 & 32 & 0.2951 \\
airfoil & 0.0100 & 64 & 0.3183 \\
bikesharing & 0.0010 & 32 & 0.2832 \\
bikesharing & 0.0010 & 64 & 0.2886 \\
bikesharing & 0.0100 & 32 & 0.3289 \\
bikesharing & 0.0100 & 64 & 0.3392 \\
concrete & 0.0010 & 32 & 0.3566 \\
concrete & 0.0010 & 64 & 0.3652 \\
concrete & 0.0100 & 32 & 0.3592 \\
concrete & 0.0100 & 64 & 0.3609 \\
energy & 0.0010 & 32 & 0.2156 \\
energy & 0.0010 & 64 & 0.2134 \\
energy & 0.0100 & 32 & 0.2123 \\
energy & 0.0100 & 64 & 0.2165 \\
protein & 0.0010 & 32 & 0.7280 \\
protein & 0.0010 & 64 & 0.7281 \\
protein & 0.0100 & 32 & 0.7943 \\
protein & 0.0100 & 64 & 0.7937 \\
yacht & 0.0010 & 32 & 0.6188 \\
yacht & 0.0010 & 64 & 0.5277 \\
yacht & 0.0100 & 32 & 0.5355 \\
yacht & 0.0100 & 64 & 0.6066 \\
\bottomrule
\end{tabular}


In [None]:
# Persist the grid-search table into the de_perf results directory.
df.to_csv(os.path.join(main_dir, 'de_perf', 'de_perf_grid_search.csv'))

## Performance for ReLU

In [None]:
# Evaluate the ReLU deep ensembles on each dataset/replication. Each row records
# the ensemble RMSE, the average member RMSE, the ensemble LPPD (mean over test
# points) and the average member LPPD.
rows_bnn_relu = []
for ds, rep in itertools.product(DATASETS, REPLICATIONS):
    identifier = [ds, rep]
    dirname = f'{ds}.data|16-16|relu|{str(rep)}|'
    regr_dataset = pml.data.dataset.DatasetTabular(
        data_path=f'../data/{ds}.data',
        target_indices=[],
        split_spec={'train': 0.8, 'test': 0.2},
        seed=rep,
        standardize=True,
    )
    X_test, Y_test = regr_dataset.get_data(split='test', data_type='jax')
    if ds in ["bikesharing", "protein"]:
        # Only the first 2000 test points are evaluated for these datasets.
        X_test = X_test[:2000, :]
        Y_test = Y_test[:2000, :]
    X_test = torch.from_numpy(np.array(X_test))
    Y_test = torch.from_numpy(np.array(Y_test)).squeeze()
    ensemble_mean = []
    ensemble_sd = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for member in range(ENSEMBLE_SIZE):
            # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
            weight_dict = torch.load(
                os.path.join(main_dir, 'de', f"{dirname}/stdict_{member}.pt"),
                map_location='cpu',
            )
            model = MLP(
                input_size=X_test.shape[1],
                hidden_sizes=[16, 16],
                activation=nn.ReLU,
                dropout_ratio=0.,
            )
            model.load_state_dict(weight_dict)
            model.eval()
            outputs = model(X_test)
            ensemble_mean.append(outputs[:, 0])  # predictive mean
            ensemble_sd.append(outputs[:, 1])  # log-variance (see compute_lppd_de)
    ensemble_mean_agg = torch.stack(ensemble_mean)
    ensemble_sd_agg = torch.stack(ensemble_sd)
    y_np = Y_test.numpy()
    means_np = ensemble_mean_agg.numpy()
    log_var_np = ensemble_sd_agg.numpy()

    rmse_ensemble = compute_rmse(Y_test, ensemble_mean_agg.mean(0)).item()
    lppd_ensemble = float(compute_lppd_de(y_np, means_np, log_var_np).mean())
    # Member-wise scores: each member is scored as a single-component "ensemble".
    rmse_individual_avg = float(np.mean(
        [compute_rmse(Y_test, member_mean).item() for member_mean in ensemble_mean_agg]
    ))
    lppd_individual_avg = float(np.mean([
        compute_lppd_de(y_np, means_np[m:m + 1], log_var_np[m:m + 1]).mean()
        for m in range(means_np.shape[0])
    ]))
    rows_bnn_relu.append(
        identifier + [rmse_ensemble, rmse_individual_avg, lppd_ensemble, lppd_individual_avg]
    )

In [None]:
# One row per (dataset, replication) with ensemble vs. average-member metrics.
df = pd.DataFrame(
    data=rows_bnn_relu,
    columns=[
        'dataset',
        'rep',
        'rmse_ensemble',
        'rmse_ind',
        'lppd_ensemble',
        'lppd_ind',
    ],
)
df

In [None]:
# Write the aggregated ReLU ensemble metrics. No earlier cell creates
# 'de_perf_playground' (only 'de_perf'), so create it here; otherwise the
# write fails on a fresh results tree.
playground_dir = os.path.join(main_dir, 'de_perf_playground')
os.makedirs(playground_dir, exist_ok=True)
df.to_csv(os.path.join(playground_dir, 'aggregated_data_de.csv'))