<a href="https://colab.research.google.com/github/VRehnberg/modularity/blob/master/Modularity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import matplotlib.pyplot as plt
from functools import partial
from ipywidgets import interact

import torch
from torch import nn, optim
from torch.autograd.functional import jacobian

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
class MLP(nn.Module):
    """Fully connected network built from a list of layer widths.

    Consecutive entries of ``layers`` give the input and output size of each
    ``nn.Linear``. Every linear layer except the last is followed by the
    activation and then by dropout; the last layer's output is returned as is.

    Args:
        layers: Sliceable sequence of layer widths, input width first.
        drop_p: Dropout probability applied after each hidden activation.
        Activation: Zero-argument callable producing the activation module.
    """

    def __init__(self, layers, drop_p=0.0, Activation=nn.ReLU):
        super().__init__()
        self.hidden = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(layers[:-1], layers[1:])]
        )
        self.activation = Activation()
        self.droput = nn.Dropout(drop_p)

    def forward(self, x):
        """Run ``x`` through all linear layers and return the final output."""
        n_layers = len(self.hidden)
        for depth, linear in enumerate(self.hidden, start=1):
            x = linear(x)
            if depth == n_layers:
                # The output layer gets neither activation nor dropout.
                break
            x = self.droput(self.activation(x))
        return x


def reset_hooks(network):
    """Remove every handle stored in ``network.hooks`` and start a fresh list.

    If ``network`` has no ``hooks`` attribute yet, the attribute is simply
    created as an empty list.
    """
    for handle in getattr(network, "hooks", ()):
        handle.remove()
    network.hooks = []


def record_activations(network):
    """Attach forward hooks that capture the output of every linear layer.

    Any hooks previously tracked in ``network.hooks`` are removed first, so
    only one set of recording hooks is active at a time. The new handles are
    stored in ``network.hooks`` and stay attached until ``reset_hooks`` is
    called on the network again.

    Args:
        network: Module whose ``nn.Linear`` submodules (including subclasses
            of ``nn.Linear``) should be recorded.

    Returns:
        A list that starts empty and receives one tensor per hooked layer on
        each forward pass, in the order the layers are executed.
    """
    activations = []

    def save_activation(mod, inp, out):
        # Keep the layer output on the CPU (a no-op if it already lives there).
        activations.append(out.cpu())

    reset_hooks(network)
    for module in network.modules():
        # isinstance rather than an exact type comparison, so subclasses of
        # nn.Linear are recorded as well.
        if isinstance(module, nn.Linear):
            network.hooks.append(module.register_forward_hook(save_activation))

    return activations

def get_activations(network, x):
    """Run ``network`` on ``x`` and return all linear-layer outputs as one tensor.

    Recording hooks are installed for the duration of a single forward pass
    and removed afterwards, even if the forward pass raises, so later calls
    to ``network`` do not keep appending to a stale list.

    Args:
        network: Module to evaluate.
        x: Input passed to ``network`` unchanged.

    Returns:
        The recorded outputs joined with ``torch.hstack`` in execution order.
    """
    activations = record_activations(network)
    try:
        network(x)
    finally:
        # Detach the hooks so they do not outlive this call.
        reset_hooks(network)
    return torch.hstack(activations)


def correlation(covariance):
    """Turn a covariance matrix into the matching correlation matrix.

    Entry ``(i, j)`` of the result is ``covariance[i, j]`` divided by the
    square roots of diagonal entries ``i`` and ``j``.
    """
    n_rows = covariance.shape[0]
    variances = torch.diag(covariance)
    inv_std = 1 / torch.sqrt(variances)
    # Every row of row_scale holds the full vector of inverse standard deviations.
    row_scale = inv_std.expand([n_rows, -1])
    col_scale = row_scale.T
    scaled_columns = row_scale * covariance
    return scaled_columns * col_scale


def vmap(f, x):
    """Apply ``f`` to each slice of ``x`` along dim 0 and stack the results."""
    return torch.stack([f(row) for row in x.unbind()])


def show(matrix, fignum=0):
    """Draw ``matrix`` as a heat map with a colour bar.

    The colour scale is fixed to the range [-1, 1]. ``fignum`` is passed both
    to ``plt.figure`` and to ``plt.matshow``.
    """
    colour_limits = {"vmin": -1, "vmax": 1}
    plt.figure(num=fignum, figsize=(3, 3), dpi=160)
    plt.matshow(matrix, fignum=fignum, **colour_limits)
    plt.colorbar()

In [3]:
@interact(inputdims = (1,8), outputdims = (1,8), nlayers = (1, 8), teachers = (1,5), studentinner = (1,40), teacherinner = (1,40), sqrtnsamples = (1, 50))
def studentteacher(inputdims = 4, outputdims = 2, nlayers = 1, teachers = 3, studentinner = 19, teacherinner = 20, sqrtnsamples = 3):
    """Train a student MLP on several teacher MLPs and plot its activation correlations.

    The teachers keep their random initial weights; only the student's
    parameters are given to the optimiser. The student is trained for 100
    Adam steps on fresh uniform random batches to reproduce the concatenated
    teacher outputs. Afterwards the Jacobian of the student's linear-layer
    outputs with respect to the input is evaluated at ``sqrtnsamples ** 2``
    random points, and the squared correlations derived from it are shown as
    one image.

    Args:
        inputdims: Width of the input of every network.
        outputdims: Output width of each teacher; the student outputs
            ``outputdims * teachers`` values.
        nlayers: Number of hidden layers in the student and in each teacher.
        teachers: Number of teacher networks.
        studentinner: Hidden-layer width of the student.
        teacherinner: Hidden-layer width of each teacher.
        sqrtnsamples: Square root of the number of evaluation points.
    """
    student = MLP([inputdims, *(nlayers * [studentinner]), outputdims * teachers])
    # Teachers use their own hidden width, not the student's.
    teacher_nets = [
        MLP([inputdims, *(nlayers * [teacherinner]), outputdims])
        for _ in range(teachers)
    ]
    optimizer = optim.Adam(student.parameters(), lr=0.02)
    loss_fn = nn.MSELoss()
    for _ in range(100):
        batch = torch.rand([50, inputdims])
        # The teachers are fixed targets, so no gradient tracking is needed here.
        with torch.no_grad():
            targets = torch.cat([t(batch) for t in teacher_nets], -1)
        loss = loss_fn(student(batch), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    x = torch.rand([sqrtnsamples * sqrtnsamples, inputdims])
    # each j describes what student does to a neighborhood of x
    js = vmap(lambda point: jacobian(lambda p: get_activations(student, p), point), x)
    # student sends a normal distribution around x with covariance matrix 1
    # to a normal distribution with this covariance matrix:
    covs = vmap(lambda j: j @ j.T, js)
    squaredcorrs = vmap(correlation, covs).pow(2)
    # Rearrange to (activation a, sample row, activation b, sample column) so the
    # flattened image is a grid of tiles, one tile per activation pair (a, b),
    # each tile holding one pixel per sample point.
    yyxx = squaredcorrs.view(sqrtnsamples, sqrtnsamples, *squaredcorrs.shape[1:]).permute([2, 0, 3, 1])
    show(yyxx.reshape(yyxx.shape[0] * yyxx.shape[1], yyxx.shape[2] * yyxx.shape[3]))


interactive(children=(IntSlider(value=4, description='inputdims', max=8, min=1), IntSlider(value=2, descriptio…