In [1]:
from torch.nn import Module, Linear, ReLU
from models.network_mapper import to_basic_representation, to_relevance_representation
from utils.Utils import input_mapping

%matplotlib inline
import pylab as pl
import torch
from IPython import display
import time

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

  from .autonotebook import tqdm as notebook_tqdm


### Define the target function and a model using basic PyTorch functionality (without data-source relevance propagation)

In [None]:
def f2(x):
    """Noisy cubic target function.

    Args:
        x: Indexable pair of tensors. Only ``x[0]`` contributes to the signal;
            ``x[1]`` is used solely as the template passed to
            ``torch.randn_like`` for the noise draw.

    Returns:
        ``x[0] ** 3`` plus standard-normal noise scaled by 0.1.
    """
    signal = x[0] ** 3
    noise = torch.randn_like(x[1])
    return signal + 0.1 * noise

class Model(Module):
    """Plain feed-forward regressor built from standard ``torch.nn`` layers.

    Takes inputs with 2 features, passes them through three 32-unit linear
    layers (each followed by a ReLU) and maps the result to a single output
    value through the ``mu`` head.
    """

    def __init__(self) -> None:
        super().__init__()

        # One stateless activation module shared by all hidden layers.
        self.relu = ReLU()
        # Hidden stack: 2 -> 32 -> 32 -> 32.
        self.linear1 = Linear(in_features=2, out_features=32, bias=True)
        self.linear2 = Linear(in_features=32, out_features=32, bias=True)
        self.linear3 = Linear(in_features=32, out_features=32, bias=True)
        # Output head: 32 -> 1.
        self.mu = Linear(in_features=32, out_features=1, bias=True)

    def forward(self, x):
        """Apply the three Linear+ReLU stages in order and return the ``mu`` head output."""
        h1 = self.relu(self.linear1(x))
        h2 = self.relu(self.linear2(h1))
        h3 = self.relu(self.linear3(h2))
        return self.mu(h3)

# Build the network, then move its parameters to the selected device.
model = Model()
model = model.to(device)

### Change the model to Modality Relevance Propagation and define the input mapping to the relevance representation

In [None]:
# Adam over all model parameters (single parameter group, learning rate 1e-3).
# NOTE(review): the optimizer captures model.parameters() before the later
# to_relevance_representation(...) call; it only trains the converted model if
# that conversion reuses the same Parameter objects — confirm.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Number of random samples drawn for each training step.
batch_size = 4096
# Upper bound on the number of training steps.
num_iter = 10_000_000

### Train Network and visualize Source Relevances

In [None]:

# Switch the model into its relevance representation for training.
# (What the conversion does internally is defined in models.network_mapper.)
model = to_relevance_representation(model, verbose=0)

# Refresh the live plot every `plot_every` training steps.
plot_every = 100

for i in range(num_iter):
    # ---- training step -------------------------------------------------
    optimizer.zero_grad()

    # Two independent inputs, each uniform on [-3, 3).
    x1 = ((torch.rand((batch_size, 1)) - 0.5) * 6).to(device)
    x2 = ((torch.rand((batch_size, 1)) - 0.5) * 6).to(device)
    y = f2([x1, x2])

    x = input_mapping(x1, x2)
    y_mu = model(torch.cat(x, -1))
    # The prediction is the sum of the model output over its leading axis
    # (the plot below labels index 0/1/2 as bias / source 1 / source 2).
    m = y_mu.sum(0)

    loss = ((y - m) ** 2).mean()
    loss.backward()
    optimizer.step()

    # ---- periodic visualisation ---------------------------------------
    # The evaluation pass is only needed for the plot, so it runs only on
    # plotting iterations instead of on every training step.
    if i % plot_every == 0:
        with torch.no_grad():
            # Dense grid over x1; x2 is drawn uniformly from [-3, 3).
            grid_x1 = torch.arange(start=-3, end=3.0001, step=0.01).unsqueeze(-1).to(device)
            grid_x2 = (torch.rand_like(grid_x1, device=device) - 0.5) * 6

            grid_in = input_mapping(grid_x1, grid_x2)
            # Computed under no_grad, so no extra .detach() is needed.
            grid_out = model(torch.cat(grid_in, -1)).cpu()

            grid_x1 = grid_x1.cpu()
            grid_x2 = grid_x2.cpu()

            pl.clf()

            pl.plot(grid_x1, f2([grid_x1, grid_x2]), "*", label="Noisy Function")
            pl.plot(grid_x1, grid_out[0], "x", label="Bias")
            pl.plot(grid_x1, grid_out[1], "o", label="Source 1")
            pl.plot(grid_x1, grid_out[2], "--", label="Source 2")
            pl.plot(grid_x1, grid_out.sum(0), label="Pred")

            pl.legend()
            pl.title(f"Iter {i} -  Loss: {loss.item()}")
            # Show the current figure, then mark the output for replacement
            # on the next refresh so the plot updates in place.
            display.display(pl.gcf())
            display.clear_output(wait=True)
            time.sleep(0.001)

# Convert the trained model back to its basic representation.
model = to_basic_representation(model=model, verbose=0)