In [1]:
import torch
import torch.nn as nn

In [16]:
class MLP(nn.Module):
    """Fully connected network with two tanh hidden layers and a linear output.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden: Width of both hidden layers.
        output_dim: Size of the last dimension of the output tensor.
    """

    def __init__(self,
                 input_dim : int,
                 hidden : int,
                 output_dim : int):
        super().__init__()
        # Layer sizes follow the constructor arguments rather than fixed
        # constants, so MLP(1, 16, 1) builds a 1 -> 16 -> 16 -> 1 network.
        self.l1 = nn.Linear(input_dim, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.l3 = nn.Linear(hidden, output_dim)

        # Xavier-uniform initialisation for the weights only; biases keep
        # whatever nn.Linear assigned at construction.
        for layer in (self.l1, self.l2, self.l3):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, inputs):
        """Map ``inputs`` (last dim ``input_dim``) to a tensor with last dim ``output_dim``."""
        x = torch.tanh(self.l1(inputs))
        x = torch.tanh(self.l2(x))
        # No activation on the final layer: the output is unbounded.
        return self.l3(x)

In [17]:
class ModelWithPrior(nn.Module):
    """Wraps a base network and adds a scaled, gradient-free prior network output.

    ``forward`` returns ``base_model(inputs) + prior_scale * prior_model(inputs)``.
    The prior is evaluated under ``torch.no_grad()``, so back-propagating
    through the result produces gradients for the base network only.

    Args:
        base_model: Network whose output carries gradients.
        prior_model: Network evaluated without gradient tracking.
        prior_scale: Multiplier applied to the prior's output (default 1.0).
    """

    def __init__(self,
                 base_model : nn.Module,
                 prior_model : nn.Module,
                 prior_scale : float = 1.0):
        super().__init__()
        self.base_model = base_model
        self.prior_model = prior_model
        self.prior_scale = prior_scale

    def forward(self, inputs):
        """Return the base prediction plus the scaled prior prediction."""
        # The prior is evaluated first, outside the autograd graph.
        with torch.no_grad():
            frozen_term = self.prior_model(inputs).detach()

        learned_term = self.base_model(inputs)
        scaled_prior = self.prior_scale * frozen_term
        return learned_term + scaled_prior

In [18]:
def train_model(x_train, y_train, base_model, prior_model,
                n_epochs=100, lr=0.05, prior_scale=1.0):
    """Fit ``base_model`` so that base + scaled prior matches ``y_train`` under MSE.

    Every epoch is one full-batch step: the whole of ``x_train`` goes through
    the combined model in a single forward pass, followed by one Adam update.

    Args:
        x_train: Input tensor passed to the combined model.
        y_train: Target tensor compared against the predictions with ``nn.MSELoss``.
        base_model: Network whose parameters are optimised (modified in place).
        prior_model: Network added to the base output; not given to the optimiser.
        n_epochs: Number of optimisation steps (default 100).
        lr: Adam learning rate (default 0.05).
        prior_scale: Multiplier applied to the prior's output (default 1.0).

    Returns:
        The ``ModelWithPrior`` wrapping ``base_model`` and ``prior_model``.
    """
    model = ModelWithPrior(base_model, prior_model, prior_scale)
    loss_fn = nn.MSELoss()
    # Only the base network is handed to the optimiser, so the prior's
    # parameters can never be updated by it.
    optimizer = torch.optim.Adam(base_model.parameters(), lr=lr)

    # Training mode does not change between steps, so set it once up front.
    model.train()
    for _ in range(n_epochs):
        preds = model(x_train)
        loss = loss_fn(preds, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return model

In [19]:
# Three separate, initially empty lists: prior networks, base networks, and
# the wrappers that combine one of each.
priors, models, model_w_priors = [], [], []

In [20]:
# Number of (base model, prior) pairs to build.
N_ENSEMBLE = 6

# Empty the lists in place first so this cell is idempotent: re-running it
# rebuilds the members instead of appending another batch to the old ones.
for members in (priors, models, model_w_priors):
    members.clear()

for _ in range(N_ENSEMBLE):
    # Each prior is constructed before its paired model; with randomly
    # initialised weights, this order determines which values each one gets.
    prior = MLP(1, 16, 1)
    model = MLP(1, 16, 1)
    priors.append(prior)
    models.append(model)
    model_w_priors.append(ModelWithPrior(model, prior))