In [None]:
from baukit import PlotWidget, show, pbar
import torch
from copy import deepcopy
from torch.nn import Sequential, Linear, ReLU
from torch.nn.functional import mse_loss, cross_entropy
from torch.optim import Adam, SGD
from collections import OrderedDict
import numpy

def classify_checkerboard(x, y):
    """Alternative target: unit-square checkerboard labels.

    Returns 1 where floor(x) + floor(y) is odd and 0 where it is even
    (tensor ``%`` has Python-style sign semantics, so negative cells work too).
    Previously this was a second ``classify_target`` definition that was
    silently shadowed; it keeps its own name so it can still be selected.

    Args:
        x, y: float tensors of the same (broadcastable) shape.

    Returns:
        int64 tensor of 0/1 labels.
    """
    return (y.floor() + x.floor()).long() % 2


def classify_target(x, y):
    """Target used by the notebook: 1 above the curve y = sin(3x), else 0.

    Args:
        x, y: float tensors of the same (broadcastable) shape.

    Returns:
        int64 tensor of 0/1 labels.
    """
    return (y > (x * 3).sin()).long()

def endpoints(w, b, scale=10):
    """Return two points on the line ``w[0]*x + w[1]*y + b = 0`` for plotting.

    One coordinate is fixed at ``-scale`` and ``+scale`` and the other is
    solved for. The solve divides by whichever weight has the larger
    magnitude, so near-vertical and near-horizontal lines both stay well
    conditioned. If both weights are zero the result contains inf/nan.

    Args:
        w: length-2 tensor of line coefficients.
        b: scalar offset (0-d tensor or number).
        scale: magnitude of the fixed coordinate.

    Returns:
        (2, 2) tensor whose row i is the point (x, y); it uses ``w``'s
        floating dtype (or the default float dtype for integer ``w``).
    """
    # Build the fixed coordinate in w's floating dtype and device. A plain
    # torch.tensor([-scale, scale]) would be int64 for an int scale, or
    # float32 (losing precision) when w is float64.
    dtype = w.dtype if w.is_floating_point() else torch.get_default_dtype()
    fixed = torch.tensor([-scale, scale], dtype=dtype, device=w.device)
    if abs(w[1]) > abs(w[0]):
        x0 = fixed
        x1 = (-b - w[0] * x0) / w[1]
    else:
        x1 = fixed
        x0 = (-b - w[1] * x1) / w[0]
    return torch.stack([x0, x1], dim=1)

def visualize_net(fig, net=None):
    """Draw the target, the network's class-1 probability, and first-layer folds.

    Args:
        fig: figure that already has exactly three axes (unpacked as
            ``ax1, ax2, ax3``). Each axis is cleared before drawing.
        net: module mapping (N, 2) points to (N, 2) logits, or None to only
            clear the axes.

    Panels: ax1 shows ``classify_target`` on a 100x100 grid over [-2, 2]^2.
    ax2 shows ``softmax(net(.))[:, 1]`` on that grid. ax3 overlays the same
    map faintly with the zero lines of each unit of the first ``Linear``.
    """
    ax1, ax2, ax3 = fig.axes
    for ax in (ax1, ax2, ax3):
        ax.clear()
    if net is None:
        return
    extent = [-2, 2, -2, 2]
    # Channel 0 is x (left to right) and channel 1 is y (top to bottom, +2
    # down to -2), so row 0 of each image is the top edge y = +2.
    grid = torch.stack([
        torch.linspace(-2, 2, 100)[None, :].expand(100, 100),
        torch.linspace(2, -2, 100)[:, None].expand(100, 100),
    ])
    x, y = grid
    target = classify_target(x, y)
    ax1.set_title('target')
    ax1.imshow(target.float(), cmap='hot', extent=extent)

    # Evaluate on the net's own device, and without building an autograd graph.
    params = list(net.parameters())
    device = params[0].device if params else torch.device('cpu')
    with torch.no_grad():
        points = grid.permute(1, 2, 0).reshape(-1, 2).to(device)
        score = net(points).softmax(1)
    prob1 = score[:, 1].reshape(100, 100).cpu()
    ax2.set_title('network output')
    ax2.imshow(prob1, cmap='hot', extent=extent)
    ax3.imshow(prob1, cmap='hot', extent=extent, alpha=0.2)

    ax3.set_title('first layer folds')
    linears = [m for m in net.modules() if isinstance(m, torch.nn.Linear)]
    if linears:
        first = linears[0]
        w = first.weight.detach().cpu()
        # A bias-free first layer folds through the origin.
        if first.bias is not None:
            b = first.bias.detach().cpu()
        else:
            b = torch.zeros(w.shape[0], dtype=w.dtype)
        for wc, bc in zip(w, b):
            ep = endpoints(wc, bc)
            ax3.plot(ep[:, 0], ep[:, 1], '#00aa00', linewidth=0.75, alpha=0.33)
    ax3.set_ylim(-2, 2)
    ax3.set_xlim(-2, 2)
    ax3.set_aspect(1.0)

In [None]:
# Per-iteration CPU snapshots of the network, filled by the training cell.
history = []
seed = 1
device = 'cpu'

# 2 -> 100 (ReLU) -> 2 logits; the output layer has no bias.
mlp = torch.nn.Sequential(OrderedDict(
    layer1=Sequential(Linear(2, 100), ReLU()),
    layer2=Sequential(Linear(100, 2, bias=False)),
))

# Replace torch's default init with standard-normal draws from a seeded numpy
# RNG. The same `prng` keeps supplying the training batches afterwards.
prng = numpy.random.RandomState(seed)
with torch.no_grad():
    for p in mlp.parameters():
        p.copy_(torch.from_numpy(prng.randn(*p.shape)))

def sample_x(batch_size=500):
    """Draw a (batch_size, 2) float32 batch of standard-normal 2-D points.

    Draws from the notebook-global ``prng`` so the sequence of batches is
    reproducible from ``seed``. ``randn(batch_size, 2)`` fills in row-major
    order, so the default yields exactly the original fixed batch of 500.
    """
    return torch.from_numpy(prng.randn(batch_size, 2)).float()

# Training settings; these defaults reproduce the original run.
num_iterations = 1000
learning_rate = 0.01
use_mse_loss = False  # True: regress one-hot targets with MSE instead of cross-entropy
weight_decay = 0.0    # > 0 adds an L2 penalty on all parameters (1e-6 was tried)

eye2 = torch.eye(2).to(device)  # one-hot lookup table for the MSE targets
mlp.to(device)

# NOTE: this cell keeps training the current `mlp` and appends to `history`;
# re-run the setup cell first for a fresh, reproducible run.
optimizer = Adam(mlp.parameters(), lr=learning_rate)
for iteration in pbar(range(num_iterations)):
    # A batch is drawn on every iteration, including 0, so the prng stream
    # (and therefore every later batch) is the same as before.
    in_batch = sample_x().to(device)
    target_batch = classify_target(in_batch[:, 0], in_batch[:, 1])
    out_batch = mlp(in_batch)
    if use_mse_loss:
        loss = mse_loss(out_batch, eye2[target_batch])
    else:
        loss = cross_entropy(out_batch, target_batch)
    if weight_decay:
        loss = loss + weight_decay * sum((p ** 2).sum() for p in mlp.parameters())
    # No update on iteration 0, so history[0] is the untrained network.
    if iteration > 0:
        mlp.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        history.append(deepcopy(mlp).cpu())


In [None]:
from baukit import Range, Numberbox
def set_iteration(i):
    """Scrubber callback: show the training snapshot the scrubber points at.

    The argument is accepted but not used; the index is re-read from
    ``scrubber.value``. The chosen snapshot is assigned to ``widget.net``,
    and then the widget is redrawn.
    """
    chosen = history[scrubber.value]
    widget.net = chosen
    widget.redraw()
# Slider over snapshot indices 0 .. len(history)-1, with set_iteration
# registered as the callback for its 'value'.
scrubber = Range(min=0, max=len(history)-1)
scrubber.on('value', set_iteration)

# Three side-by-side panels drawn by visualize_net (target / output / folds).
widget = PlotWidget(visualize_net, nrows=1, ncols=3, figsize=(12,4), bbox_inches='tight')
# Layout: a row with an 'Iteration' label, a number box (presumably bound to
# the scrubber's value via scrubber.prop('value') — confirm with baukit) and
# the slider, then the plot widget.
show([[[show.style(alignContent='center'), 'Iteration'],
        Numberbox(value=scrubber.prop('value')),
        show.style(flex=20), scrubber],
     widget])
# Select snapshot 0. NOTE(review): intended to trigger set_iteration for the
# first draw; if Range already starts at 0 the change event may not fire —
# confirm baukit's property semantics.
scrubber.value = 0

In [None]:
# Point the plot at the snapshot currently selected by the scrubber.
# NOTE(review): unlike set_iteration this does not call widget.redraw();
# whether assigning `net` alone refreshes the figure depends on baukit's
# PlotWidget — confirm.
widget.net = history[scrubber.value]