<a href="https://colab.research.google.com/github/Gurkenglas/modularity/blob/master/Modularity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

# Module-level scratch list: `trace` appends to it, `activations` resets and reads it.
accumulator = []


def activations(f, x):
    """Run ``f(x)`` and return everything recorded in ``accumulator`` meanwhile.

    The global ``accumulator`` is rebound to a fresh empty list, ``f`` is
    called on ``x`` (its return value is discarded), and the tensors that
    were appended to ``accumulator`` during the call are joined with
    ``torch.hstack``.
    """
    global accumulator
    accumulator = []  # start from a clean slate so earlier forward passes are dropped
    f(x)
    stacked = torch.hstack(accumulator)
    return stacked

class Lambda(nn.Module):
    """Module that wraps a two-argument callable together with one parameter.

    ``param`` is registered as an ``nn.Parameter`` (so it shows up in
    ``parameters()``), and the forward pass returns ``func(input, param)``.
    """

    def __init__(self, param, func):
        super().__init__()
        self.func = func
        # Wrapping in nn.Parameter registers the tensor with the module.
        self.param = nn.Parameter(param)

    def forward(self, input):
        """Apply the wrapped callable to ``input`` and the stored parameter."""
        result = self.func(input, self.param)
        return result

def trace(input, _):
    """Identity function that records ``input`` in the global ``accumulator``.

    The second argument is ignored. Evil on purpose: this is a side effect
    on module-level state hidden inside what looks like a pure function.
    """
    # Only mutation, no rebinding, so no `global` declaration is required.
    accumulator.append(input)
    return input

def layer(dims):
    """Build one block: GELU -> Linear(dims[0], dims[1]) -> noise -> trace.

    The noise stage calls ``torch.normal(input, param)`` with a parameter
    vector of ``dims[1]`` zeros; the final stage passes its input through
    ``trace`` unchanged (its scalar parameter is unused by ``trace``).

    NOTE(review): ``torch.normal`` may not propagate gradients to its
    mean/std arguments -- confirm that training actually reaches the
    Linear weights through this stage.
    """
    fan_in, fan_out = dims[0], dims[1]
    stages = [
        nn.GELU(),
        nn.Linear(fan_in, fan_out),
        Lambda(torch.zeros(fan_out), torch.normal),
        Lambda(torch.empty([]), trace),
    ]
    return nn.Sequential(*stages)

def mlp(dims):
    """Stack one ``layer`` per consecutive pair of widths in ``dims``.

    ``dims`` lists the widths from input to output, so ``len(dims) - 1``
    layers are built, in order. For a five-element list this produces the
    same four layers as before; other lengths are now supported instead of
    being truncated or failing inside ``layer``.

    Raises:
        ValueError: if ``dims`` has fewer than two entries.
    """
    if len(dims) < 2:
        raise ValueError("mlp needs at least an input width and an output width")
    # Slices (rather than tuples) keep the argument shape `layer` received before.
    return nn.Sequential(*(layer(dims[i:i + 2]) for i in range(len(dims) - 1)))

def correlation(covariance):
    """Rescale a covariance matrix so entry (i, j) is divided by std_i * std_j.

    The standard deviations are the square roots of the diagonal of
    ``covariance``.
    """
    n_rows = covariance.shape[0]
    inv_std = torch.diag(covariance).sqrt().reciprocal()
    # Every row of `per_column` is inv_std, so entry (i, j) holds 1/std_j;
    # its transpose holds 1/std_i.
    per_column = inv_std.expand(n_rows, -1)
    per_row = per_column.T
    return per_column * covariance * per_row

# Number of the next matplotlib figure `show` will draw into.
index = 0


def show(matrix):
    """Draw ``matrix`` with ``plt.matshow`` in a new small figure plus colorbar.

    Each call uses the current value of the global ``index`` as the figure
    number and then increments it, so successive calls use successive
    figure numbers.
    """
    global index
    figure_number = index
    plt.figure(figure_number, figsize=(3, 3), dpi=160)
    plt.matshow(matrix, fignum=figure_number)
    plt.colorbar()
    index = figure_number + 1

In [None]:
import ipywidgets
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive
from functools import partial
from torch.autograd.functional import jacobian
!pip install torchviz
import torchviz
from graphviz import Source

def _activation_correlation(model, point):
    """Return the correlation matrix of ``model``'s traced activations near ``point``."""
    # The Jacobian describes what the model does to a neighborhood of the point.
    jac = jacobian(partial(activations, model), point)
    # To first order, the model sends a unit-covariance normal distribution
    # around the point to a normal distribution with covariance J @ J.T.
    return correlation(jac.matmul(jac.T))


@interact(inputdims=(1, 8), outputdims=(1, 8), teachers=(1, 5), studentinner=(1, 40), teacherinner=(1, 40))
def studentteacher(inputdims=4, outputdims=2, teachers=3, studentinner=19, teacherinner=20):
    """Train one student MLP to imitate several teacher MLPs, then plot and export.

    Args:
        inputdims: width of the random input vectors.
        outputdims: output width of each teacher.
        teachers: number of teacher networks; the student's output width is
            ``outputdims * teachers`` and its target is the teachers'
            outputs concatenated along the last axis.
        studentinner: width of the student's three hidden layers.
        teacherinner: width of each teacher's three hidden layers.

    Side effects: shows the difference between the activation correlation
    matrices at two random inputs, and renders the autograd graph of the
    first training step's loss to a PDF.
    """
    import time

    student = mlp([inputdims, studentinner, studentinner, studentinner, outputdims * teachers])
    teacher_nets = [
        mlp([inputdims, teacherinner, teacherinner, teacherinner, outputdims])
        for _ in range(teachers)
    ]

    def label(batch):
        return torch.cat([teacher(batch) for teacher in teacher_nets], -1)

    optimizer = torch.optim.Adam(student.parameters(), lr=0.02)
    criterion = nn.MSELoss()
    loss_graph = None
    for step in range(100):
        batch = torch.rand([50, inputdims], requires_grad=True)
        loss = criterion(label(batch), student(batch))
        loss.backward()
        if step == 0:
            # Only the first step's graph is ever rendered, so build just that one.
            loss_graph = torchviz.make_dot(loss)
        optimizer.step()
        optimizer.zero_grad()

    x = torch.rand([2, inputdims], requires_grad=True)
    # Evaluated left to right, so the two random forward passes run in the original order.
    show(_activation_correlation(student, x[0]) - _activation_correlation(student, x[1]))

    # make_dot already returns a graphviz graph object, so render it directly
    # instead of re-wrapping it in graphviz.Source (which expects DOT text).
    # NOTE(review): the output directory is a hard-coded Windows drive path.
    loss_graph.render(filename="lossgraph" + str(time.time()), directory="H:\\", format="pdf")



interactive(children=(IntSlider(value=4, description='inputdims', max=8, min=1), IntSlider(value=2, descriptio…