In [3]:
# Local runs: root of the project checkout, relative to this notebook.
# 'src/' is appended to this path when sys.path is extended in the import cell.
project_directory = "../"


# Colab runs: uncomment the lines below to mount Google Drive and point
# project_directory at the copy of the project stored there.
# from google.colab import drive
# drive.mount('/content/drive')
# project_directory = "/content/drive/MyDrive/colab_working_directory/diversity-enforced-ensembles/"
# !pip install cached-property

In [4]:
from pathlib import Path
import pandas as pd
import numpy as np

# allow import of decompose locally
import sys
sys.path.append(project_directory + 'src/')

from decompose import SquaredLoss
import bvdlib

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.func import stack_module_state
from torch.func import functional_call
from torch import vmap
import copy

In [5]:
def init_layer(layer, generator):
    torch.nn.init.xavier_uniform_(layer.weight, generator=generator)
    layer.bias.data.fill_(0.01)

class SimpleMLP(nn.Module):
    """Three-layer fully connected network with ReLU after the first two layers.

    The layers are stored as ``fc1``, ``fc2`` and ``fc3``; every one of them is
    passed through ``init_layer`` with the supplied ``generator``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, generator):
        super().__init__()

        layer_shapes = (
            ("fc1", input_dim, hidden_dim),
            ("fc2", hidden_dim, hidden_dim),
            ("fc3", hidden_dim, output_dim),
        )
        # Build and initialise one layer at a time, in order: the sequence of
        # random draws (layer construction, then init_layer) stays the same
        # even when ``generator`` is the global default generator.
        for attr_name, in_features, out_features in layer_shapes:
            linear = nn.Linear(in_features, out_features)
            init_layer(linear, generator)
            setattr(self, attr_name, linear)

    def forward(self, x):
        """Return ``fc3(relu(fc2(relu(fc1(x)))))``."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [8]:
def forward_through_metamodel(params, buffers, x):
    """Evaluate the module-level ``meta_model`` on ``x`` with the given state.

    ``params`` and ``buffers`` replace the module's own parameters and buffers
    for this call via ``functional_call``; ``x`` is the only positional input.
    """
    state = (params, buffers)
    return functional_call(meta_model, state, (x,))

def torch_MSE_combiner(ensemble_output):
    """Combine predictions by averaging over the leading dimension."""
    return ensemble_output.mean(dim=0)

# Rule ens_forward applies to the stacked member outputs to obtain the ensemble prediction.
combiner_rule = torch_MSE_combiner

def ens_forward(input, ensemble):
    """Run every ensemble member on the same batch in one vmapped call.

    Args:
        input: Batch tensor shared by all members. It is handed to each member
            unchanged, so it may have any rank the members accept.
        ensemble: Sequence of ``nn.Module`` members whose states can be stacked
            (presumably all sharing the architecture of the module-level
            ``meta_model`` -- confirm against the cell that builds it).

    Returns:
        ``(ensemble_output, member_output)`` where ``member_output`` holds the
        per-member predictions stacked along a new leading dimension and
        ``ensemble_output`` is ``combiner_rule(member_output)``.
    """
    params, buffers = stack_module_state(nn.ModuleList(ensemble))

    # in_dims=(0, 0, None): map over the stacked member dimension of the
    # parameters and buffers while broadcasting the single input batch. This
    # avoids materialising len(ensemble) copies of the batch and does not
    # depend on the input being 2-D.
    batched_forward = vmap(forward_through_metamodel, in_dims=(0, 0, None))
    member_output = batched_forward(params, buffers, input)

    ensemble_output = combiner_rule(member_output)
    return ensemble_output, member_output

In [36]:
# Experiment configuration.
trial_space = np.arange(0,15) / 10  # 0.0, 0.1, ..., 1.4
decomp_fn = SquaredLoss
seed = 0
criterion = torch.nn.MSELoss()
epoch_n=7

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Seed from the `seed` setting above instead of a second hard-coded literal, so
# changing `seed` actually changes the run. torch.manual_seed returns the
# generator it seeded, which is reused for the layer initialisation.
torch_generator = torch.manual_seed(seed)

# Synthetic data: 200 samples with 13 uniform features and integer-valued
# float targets in [0, 10). Sampled on CPU and then moved, so the values do
# not depend on which device is available.
x = torch.rand((200,13), dtype=torch.float).to(device)
y = torch.randint(0,10,(200,),dtype=torch.float).to(device)


In [79]:
n_estimators = 11

# One freshly initialised member per estimator, each with its own SGD
# optimiser and an empty loss slot.
ensemble = [
    SimpleMLP(x.shape[1], 8, 1, torch_generator).to(device)
    for _ in range(n_estimators)
]
optims = [torch.optim.SGD(member.parameters(), lr=0.0001) for member in ensemble]
losses = [None] * n_estimators

# Copy of the first member on the 'meta' device; forward_through_metamodel
# passes it to functional_call together with the stacked member states.
meta_model = copy.deepcopy(ensemble[0]).to('meta')

lambda_ = 0.3

In [80]:
def hook(grad):
    """Debug hook: print the first entry of ``grad`` and return nothing."""
    first_entry = grad[0]
    print("grad[0]", first_entry)

In [81]:
# Single full-batch update of every member, with diagnostic prints that
# compare hand-derived gradient candidates against the gradient autograd
# reports through `hook`.
bx, by = x, y

# Reference forward pass of the whole ensemble; no graph is recorded.
with torch.no_grad():
    ensemble_output, member_output = ens_forward(bx, ensemble)

member_output = member_output.detach()

# Share of the ensemble made up by the *other* members. Derived from
# n_estimators so the diagnostic below stays correct if the size changes.
others_fraction = (n_estimators - 1) / n_estimators

for i, member in enumerate(ensemble):
    print("NEW MEMBER")
    optims[i].zero_grad()
    member_pred = member(bx) #predict
    # `hook` prints the first entry of the gradient that reaches member_pred.
    member_pred.register_hook(hook)
    # Diagnostics index the batch targets `by`, the same tensor the loss uses.
    print("y[0]", by[0])
    print("mp[0]", member_pred[0])
    print("bad grad", ensemble_output[0] - by[0])
    print("good grad", (member_pred[0] - by[0] - (lambda_ * (2*others_fraction*(member_pred[0]-ensemble_output[0])))))
    print("Ultra_bad_grad", (2*(member_pred[0] - by[0])))
    # Outputs of every member except member i; they come from the no_grad
    # pass above, so no further detach is needed.
    other_outputs = torch.cat((member_output[:i], member_output[i+1:]))
    # NOTE(review): the factor (1-n_estimators)/n_estimators is negative;
    # confirm whether (n_estimators-1)/n_estimators was intended. The result is
    # only consumed by the commented-out diversity loss below.
    member_grad_output = ((1/n_estimators) * member_pred.unsqueeze(dim=0)) + (((1-n_estimators)/n_estimators)*other_outputs)
    ens_grad_output = combiner_rule(member_grad_output)

    # member_loss = (criterion(member_pred, by.unsqueeze(dim=-1)) + ((lambda_) * criterion(member_pred, ens_grad_output)))
    member_loss = criterion(member_pred, by.unsqueeze(dim=-1))
    # member_loss.register_hook(hook)
    print(member_loss)
    member_loss.backward()
    optims[i].step()

NEW MEMBER
y[0] tensor(8., device='cuda:0')
mp[0] tensor([-0.5175], device='cuda:0', grad_fn=<SelectBackward0>)
bad grad tensor([-8.0412], device='cuda:0')
good grad tensor([-8.2577], device='cuda:0', grad_fn=<SubBackward0>)
Ultra_bad_grad tensor([-17.0351], device='cuda:0', grad_fn=<MulBackward0>)
tensor(32.3215, device='cuda:0', grad_fn=<MseLossBackward0>)
grad[0] tensor([-0.0852], device='cuda:0')
NEW MEMBER
y[0] tensor(8., device='cuda:0')
mp[0] tensor([0.1552], device='cuda:0', grad_fn=<SelectBackward0>)
bad grad tensor([-8.0412], device='cuda:0')
good grad tensor([-7.9519], device='cuda:0', grad_fn=<SubBackward0>)
Ultra_bad_grad tensor([-15.6897], device='cuda:0', grad_fn=<MulBackward0>)
tensor(27.8612, device='cuda:0', grad_fn=<MseLossBackward0>)
grad[0] tensor([-0.0784], device='cuda:0')
NEW MEMBER
y[0] tensor(8., device='cuda:0')
mp[0] tensor([0.1370], device='cuda:0', grad_fn=<SelectBackward0>)
bad grad tensor([-8.0412], device='cuda:0')
good grad tensor([-7.9602], device='cu