In [None]:
import json
import random
import sys
from copy import deepcopy


import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.linear_model import LinearRegression

# The repository root must be on sys.path before the project import below.
sys.path.append('../')
from experiments.fcn_bnns.utils.analysis_utils import *
from tqdm import tqdm
from typing import Dict, List
import probabilisticml as pml
# if already loaded unload and reload probabilisticml
import importlib, sys
if 'probabilisticml' in sys.modules:
    importlib.reload(pml)
else:
    import probabilisticml as pml
    
%load_ext autoreload
%autoreload 2

In [None]:
# Experiment descriptor handed to `load_data`.
exp_info = dict(data="airfoil.data", replications=1)

# Raw train / validation splits as returned by `load_data`.
X_train, Y_train = load_data(exp_info, splittype="train", data_path="../data/")
X_val, Y_val = load_data(exp_info, splittype="val", data_path="../data/")

# Torch copies of the training split; `Y_train` keeps the raw targets.
y_train = torch.from_numpy(np.array(Y_train))
X_train = torch.from_numpy(np.array(X_train))

In [None]:
class MLP(nn.Module):
    """Fully connected network with a two-unit output layer.

    The layer stack is ``Linear -> (activation -> Linear) * len(hidden_sizes)``:
    one linear layer into the first hidden width, then an activation followed
    by a linear layer for every hidden width, the last of which maps to 2
    outputs. Dropout is applied to the output of that stack in ``forward``.
    """

    def __init__(
            self, 
            input_size: int, 
            hidden_sizes: List[int], 
            activation: torch.nn.modules.activation,
            dropout_ratio: float
    ) -> None:
        """Instantiate MLP.

        Args:
            input_size: Number of input features.
            hidden_sizes: Width of each hidden layer; any non-empty sequence
                of integers (list, tuple, ...) is accepted.
            activation: Zero-argument callable returning a fresh activation
                module, e.g. ``torch.nn.ReLU`` (the class, not an instance).
            dropout_ratio: Drop probability passed to ``nn.Dropout``.

        Raises:
            ValueError: If ``hidden_sizes`` is empty.
        """
        super().__init__()
        # Materialise once so tuples and other sequences can be sliced and
        # concatenated with the output width below.
        widths = list(hidden_sizes)
        if not widths:
            raise ValueError('hidden_sizes must contain at least one layer width')

        hidden_id = '_'.join(str(width) for width in widths)
        self.model_id = f'MLP_{input_size}_{hidden_id}_2'
        self.input_size = input_size
        self.hidden_sizes = widths

        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the state-dict order.
        layers = [torch.nn.Linear(input_size, widths[0])]
        for fan_in, fan_out in zip(widths, widths[1:] + [2]):
            layers.append(activation())
            layers.append(torch.nn.Linear(fan_in, fan_out))
        self.net = torch.nn.Sequential(*layers)
        # NOTE(review): dropout acts on the two network outputs rather than on
        # hidden activations - confirm this is intended.
        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the layer stack on ``x`` and apply dropout to its output."""
        x = self.net(x)
        return self.dropout(x)

In [None]:
# Build a small reference MLP and list the names of its parameters.
foo = MLP(
    input_size=X_train.shape[1],
    hidden_sizes=[16, 16],
    activation=torch.nn.ReLU,
    dropout_ratio=0.,
)
list(foo.state_dict())

In [None]:
class SingleHiddenRELUModel(nn.Module):
    """ReLU network with two equally wide hidden layers and two outputs."""

    def __init__(self, input_size: int = 1, hidden_size: int = 16):
        """Create the three linear layers and the two ReLU activations."""
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, 2)

    def forward(self, x):
        """Return the two-column output for ``x``.

        One output dimension is the prediction, the other the log variance of
        a Gaussian.
        """
        hidden = self.fc1(x)
        hidden = self.relu1(hidden)
        hidden = self.fc2(hidden)
        hidden = self.relu2(hidden)
        return self.fc3(hidden)

In [None]:
def init_weights(layer: nn.Module) -> None:
    """Re-initialise a linear layer in place; leave every other module untouched.

    For ``nn.Linear`` the weight is drawn with Xavier-uniform initialisation
    and the bias, when the layer has one, is set to zero. Intended to be
    passed to ``nn.Module.apply`` so it visits every submodule.

    Args:
        layer: Module to (possibly) re-initialise.
    """
    if isinstance(layer, nn.Linear):
        nn.init.xavier_uniform_(layer.weight)
        # Layers built with ``bias=False`` carry ``bias is None``.
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

In [None]:
def save_as_json(dictionary: Dict, target: str) -> None:
    """Serialise ``dictionary`` to the file ``target`` as indented UTF-8 JSON.

    Non-ASCII characters are written verbatim rather than escaped, and an
    existing file at ``target`` is overwritten.
    """
    with open(target, mode='w', encoding='utf-8') as handle:
        json.dump(dictionary, handle, ensure_ascii=False, indent=4)

In [None]:
class SGDEnsemble:
    """Ensemble of copies of one network, each re-initialised and trained with Adam.

    Every member starts as a deep copy of ``base_learner`` whose linear layers
    are re-initialised via ``init_weights`` under a member-specific seed. The
    state dict of each trained member is kept in ``self.weights``.
    """

    def __init__(self, base_learner: nn.Module, ensemble_size: int, ckpt: str='') -> None:
        """Instantiate ensemble.

        Args:
            base_learner: Template network; copied for training and prediction.
            ensemble_size: Number of members to train.
            ckpt: Prefix prepended verbatim to checkpoint file names (include
                a trailing path separator when it is a directory). May be
                empty if no epoch is logged during training.
        """
        self.ensemble_size = ensemble_size
        self.base_learner = base_learner
        self.ckpt = ckpt
        self.weights = []
            
    def train(
            self,
            num_epochs: int, 
            x: torch.Tensor,
            y: torch.Tensor,
            criterion: nn.Module,
            log_at_epoch: list,
    ):
        """Train all ensemble members on the full batch ``(x, y)``.

        The first output column of a member is used as predicted mean and the
        exponential of the second column as predicted variance; both are
        passed to ``criterion(mean, target, var)``.

        Args:
            num_epochs: Number of full-batch optimisation steps per member.
            x: Input tensor fed to each member.
            y: Targets; squeezed before being passed to ``criterion``.
            criterion: Loss called as ``criterion(mean, target, var)``, e.g.
                ``nn.GaussianNLLLoss()``.
            log_at_epoch: Zero-based epoch indices after which the member's
                state dict is saved with ``torch.save``.

        Raises:
            ValueError: If epochs are to be logged but ``ckpt`` is empty.
        """
        log_epochs = set(log_at_epoch or ())
        if len(self.ckpt) == 0 and len(log_epochs) > 0:
            raise ValueError('Logging requires path to checkpoint')

        target = y.squeeze()
        # Collected separately and assigned at the end so that a repeated call
        # replaces the previous members instead of appending after them.
        trained_weights = []
        for idx in range(self.ensemble_size):
            bl = deepcopy(self.base_learner)
            # Member-specific seed: makes each initialisation reproducible
            # while keeping members different from one another.
            torch.manual_seed(idx)
            random.seed(idx)
            bl.apply(init_weights)
            opt = optim.Adam(bl.parameters(), lr=0.01, weight_decay=0.01)
            with tqdm(total=num_epochs, desc="Training Progress") as pbar:
                for epoch in range(num_epochs):
                    # Forward pass
                    outputs = bl(x)
                    mean_pred = outputs[:, 0]
                    # Second column is treated as log variance.
                    var_pred = torch.exp(outputs[:, 1])
                    loss = criterion(mean_pred, target, var_pred)
                    
                    if not torch.isfinite(loss).all() or loss.item() < -1e6:
                        print("Loss exploded, breaking")
                        break
                    # Backward pass and optimization
                    opt.zero_grad()
                    loss.backward()
                    opt.step()
                    pbar.update(1)
                    pbar.set_postfix_str("Loss: {:.4f}".format(loss.item()))
                    
                    if epoch in log_epochs:
                        torch.save(bl.state_dict(), self.ckpt + f'weights_member_{idx}_epoch_{epoch}.pt')
            # Read after the loop so the entry always belongs to this member,
            # even when no optimisation step was taken.
            trained_weights.append(bl.state_dict())
        self.weights = trained_weights
            
    def predict(self, x: torch.Tensor):
        """Return the stacked raw outputs of all members for ``x``.

        Args:
            x: Input tensor fed to each member.

        Returns:
            Tensor whose first dimension indexes the ensemble members.

        Raises:
            RuntimeError: If fewer than ``ensemble_size`` members are stored,
                e.g. because ``train`` has not been called.
        """
        if len(self.weights) < self.ensemble_size:
            raise RuntimeError(
                f'Expected {self.ensemble_size} trained members, found {len(self.weights)}; call train() first'
            )
        # Work on a copy so the caller's base learner is not overwritten.
        # NOTE(review): the copy keeps the base learner's train/eval mode, so
        # dropout with a non-zero ratio stays active unless eval() was called.
        bl = deepcopy(self.base_learner)
        ensemble_prediction = []
        with torch.no_grad():
            for member_weights in self.weights[:self.ensemble_size]:
                bl.load_state_dict(member_weights)
                ensemble_prediction.append(bl(x))
        return torch.stack(tuple(ensemble_prediction))
            

In [None]:
# Single-member ensemble around an MLP with two hidden layers of width 16.
nsync = SGDEnsemble(
    base_learner=MLP(
        input_size=X_train.shape[1],
        # One integer width per hidden layer; tuples are not valid layer sizes.
        hidden_sizes=[16, 16],
        activation=torch.nn.ReLU,
        dropout_ratio=0.,
    ),
    ensemble_size=1,
    # NOTE(review): machine-specific absolute path - point this at a
    # project-relative checkpoint directory before sharing the notebook.
    ckpt='/home/user/Downloads/'
)

In [None]:
# Fit the ensemble with a Gaussian negative log-likelihood objective and
# checkpoint each member after the final epoch (index 999 of 1000).
nsync.train(num_epochs=1000, x=X_train, y=y_train, criterion=nn.GaussianNLLLoss(), log_at_epoch=[999])

In [None]:
# Predict on the validation set.
X_val = torch.from_numpy(np.array(X_val))
y_val = torch.from_numpy(np.array(Y_val))
# Evaluation only: no autograd graph is needed for the ensemble outputs.
with torch.no_grad():
    outputs_val = nsync.predict(X_val).mean(0)
mean_pred = outputs_val[:, 0]
rmse = torch.sqrt(torch.mean((y_val.reshape(-1) - mean_pred)**2))
print("RMSE: {:.4f}".format(rmse.item()))

# Baseline 1: RMSE of the constant-0 predictor.
rmse_0 = torch.sqrt(torch.mean(y_val**2))
print("RMSE_0: {:.4f}".format(rmse_0.item()))

# Baseline 2: ordinary least squares. scikit-learn gets plain NumPy arrays,
# and both sides are flattened so that (N, 1) vs (N,) targets cannot
# broadcast into an (N, N) difference.
reg = LinearRegression().fit(np.asarray(X_train), np.asarray(Y_train))
lin_pred = reg.predict(np.asarray(X_val))
rmse_lin = np.sqrt(np.mean((np.asarray(Y_val).ravel() - lin_pred.ravel())**2))
print("RMSE_lin: {:.4f}".format(rmse_lin))