In [None]:
from baukit import PlotWidget, show, pbar, Widget, Range, Numberbox
import torch
from copy import deepcopy
from torch.nn import Sequential, Linear, ReLU, Sigmoid
from torch.nn.functional import mse_loss, cross_entropy
from torch.optim import Adam, SGD
from collections import OrderedDict
import numpy

# Checkerboard classification
def classify_target(x, y):
    """Checkerboard labels: parity of the unit grid cell containing (x, y).

    NOTE: later definitions of ``classify_target`` in this file replace this
    one; it is kept as an alternative target.

    Args:
        x, y: tensors of coordinates (same shape).

    Returns:
        A long tensor holding 0 or 1 for every point.
    """
    cell_index = torch.floor(y) + torch.floor(x)
    return torch.remainder(cell_index.long(), 2)

# Sine wave classification
def classify_target(x, y):
    """Sine-wave labels: 1 where the point lies above the curve sin(3x).

    NOTE: a later definition of ``classify_target`` in this file replaces this
    one; it is kept as an alternative target.

    Args:
        x, y: tensors of coordinates (same shape).

    Returns:
        A long tensor holding 0 or 1 for every point.
    """
    boundary = torch.sin(x * 3)
    above_curve = torch.gt(y, boundary)
    return above_curve.long()

# Xor classification
def classify_target(x, y):
    """XOR labels: 1 where x and y have different signs, 0 otherwise.

    Args:
        x, y: tensors of coordinates (same shape).

    Returns:
        A long tensor holding 0 or 1 for every point.
    """
    signs_differ = torch.ne(torch.sign(y), torch.sign(x))
    return signs_differ.long()

#def classify_target(x, y):
#    return (x**2 + y**2).long() % 2

class TwoDNetworkWidget(Widget):
    """Interactive viewer for snapshots of a 2-D classifier taken during training.

    Snapshots are recorded with :meth:`add`.  A number box and a range slider
    are bound to the plot's ``index`` property and select which snapshot is
    drawn.  The figure has four panels: the target labelling, the network's
    class-1 probability, the zero-crossing line of every unit of the first
    Linear layer, and the curves of the recorded statistics.
    """

    def __init__(self, classify_target=classify_target):
        """Create the plot and the iteration controls.

        Args:
            classify_target: callable invoked as ``classify_target(x, y)`` with
                two coordinate tensors; its result is drawn in the 'target'
                panel on a 0..1 colour scale.
        """
        super().__init__()
        # Keep the labelling rule that was passed in, so the 'target' panel
        # shows this widget's rule instead of whatever the module-level name
        # happens to be bound to.  Set before the plot exists so that
        # visualize_net can rely on it whenever it is first invoked.
        self.classify_target = classify_target
        # List of (cpu copy of the network, stats dict), one per add() call.
        self.history = []
        self.plot = PlotWidget(self.visualize_net, mosaic='012\n333', figsize=(11,6),
                               bbox_inches='tight', gridspec_kw={'hspace': 0.25, 'height_ratios': [2,1]})
        # Both controls share the plot's 'index' property.
        self.scrubber = Range(min=0, max=0, value=self.plot.prop('index'))
        numbox = Numberbox(value=self.plot.prop('index'))
        self.content = [
            [
                [show.style(alignContent='center'), 'Iteration'],
                numbox,
                show.style(flex=20), self.scrubber
            ],
            self.plot
        ]
        self.plot.on('click', self.plot_click)

    def _repr_html_(self):
        """Return the HTML for the controls and the plot."""
        return show.html(self.content)

    def add(self, net, stats=None):
        """Record a snapshot of ``net`` together with its statistics.

        Args:
            net: the network to snapshot; a deep copy moved to the CPU is
                stored, so later training does not alter the snapshot.
            stats: optional mapping of name -> scalar to plot as training
                curves.  ``None`` is recorded as an empty mapping.
        """
        # Normalise so visualize_net can always iterate over the stats.
        stats = {} if stats is None else dict(stats)
        with torch.no_grad():
            self.history.append((deepcopy(net).cpu(), stats))
        # Let the slider reach the newest snapshot.
        self.scrubber.max = len(self.history) - 1
        if len(self.history) == 1:
            # First snapshot: point the plot at it.
            self.plot.index = 0

    def visualize_net(self, fig, index=0):
        """Draw snapshot ``index`` of the history into ``fig``.

        Args:
            fig: figure whose four axes are redrawn.
            index: position in ``self.history`` of the snapshot to show; if it
                is past the end of the history the axes are only cleared.
        """
        def endpoints(w, b, scale=10):
            # Two points on the line w[0]*x0 + w[1]*x1 + b = 0.  The equation
            # is solved for the coordinate with the larger coefficient so the
            # division is by the better-conditioned value.
            span = torch.tensor([-scale, scale], dtype=w.dtype, device=w.device)
            if abs(w[1]) > abs(w[0]):
                x0 = span
                x1 = (-b - w[0] * x0) / w[1]
            else:
                x1 = span
                x0 = (-b - w[1] * x1) / w[0]
            return torch.stack([x0, x1], dim=1)

        ax1, ax2, ax3, ax4 = fig.axes
        for ax in (ax1, ax2, ax3, ax4):
            ax.clear()
        if index >= len(self.history):
            return
        net, data = self.history[index]
        data = data or {}

        # 100x100 grid over [-2, 2]^2; rows run from y=2 down to y=-2 so that
        # row 0 is the top of the image.
        grid = torch.stack([
            torch.linspace(-2, 2, 100)[None, :].expand(100, 100),
            torch.linspace(2, -2, 100)[:, None].expand(100, 100),
        ])
        x, y = grid
        target = self.classify_target(x, y)
        ax1.set_title('target')
        ax1.imshow(target.float(), cmap='hot', extent=[-2,2,-2,2], vmin=0, vmax=1)

        ax2.set_title('network output')
        # Plotting only: no autograd graph is needed for this forward pass.
        with torch.no_grad():
            score = net(grid.permute(1, 2, 0).reshape(-1, 2)).softmax(1)
        ax2.imshow(score[:,1].reshape(100, 100).cpu(), cmap='hot', extent=[-2,2,-2,2], vmin=0, vmax=1)

        ax3.set_title('first layer folds')
        # First Linear module in the network's module traversal order.
        module = [m for m in net.modules() if isinstance(m, torch.nn.Linear)][0]
        w = module.weight.detach().cpu()
        b = module.bias.detach().cpu()
        e = torch.stack([endpoints(wc, bc) for wc, bc in zip(w, b)])
        for ep in e:
            ax3.plot(ep[:,0], ep[:,1], '#00aa00', linewidth=0.75, alpha=0.33)
        ax3.set_ylim(-2, 2)
        ax3.set_xlim(-2, 2)
        ax3.set_aspect(1.0)

        ax4.set_title('training curve')
        ax4.set_xlabel('iteration')
        ax4.axvline(index, color='red', linewidth=0.5)
        steps = range(len(self.history))
        for k, v in data.items():
            label = f'{k} = {v:.3g}'
            # Snapshots that lack this statistic leave a gap in the curve.
            series = [(h[1] or {}).get(k, float('nan')) for h in self.history]
            ax4.plot(steps, series, linewidth=0.5, label=label)
        if data:
            # Only build a legend when at least one labelled curve exists.
            ax4.legend()

    def plot_click(self, e):
        """Jump to the iteration that was clicked on the training curve."""
        loc = self.plot.event_location(e)
        # Axis 3 is presumably the fourth ('training curve') panel -- confirm
        # against baukit's event_location.
        if loc.axis == 3:
            # Round to the nearest iteration and clamp to the recorded range.
            self.plot.index = max(0, min(len(self.history) - 1, int(loc.x + 0.5)))
            self.plot.redraw()
        

In [None]:
widget = TwoDNetworkWidget(classify_target)

# Seed for the numpy generator that drives both the initial weights and the
# training samples.  The second assignment is the one in effect.
seed = 5 # is good with Adam
seed = 2
device = 'cpu'

# Three 2-unit sigmoid layers followed by a bias-free 2-unit linear readout.
mlp = Sequential(OrderedDict(
    layer1=Sequential(Linear(2, 2), Sigmoid()),
    layer2=Sequential(Linear(2, 2), Sigmoid()),
    layer3=Sequential(Linear(2, 2), Sigmoid()),
    layer4=Sequential(Linear(2, 2, bias=False)),
))

prng = numpy.random.RandomState(seed)
# Overwrite every parameter in place with standard-normal values drawn from
# the seeded generator, in the order the parameters are enumerated.
with torch.no_grad():
    for param in mlp.parameters():
        param.copy_(torch.from_numpy(prng.randn(param.numel())).reshape(param.shape))

def sample_x():
    """Draw a batch of 500 two-dimensional standard-normal points.

    Consumes 1000 values from the module-level ``prng``.

    Returns:
        A float32 tensor of shape (500, 2).
    """
    flat = prng.randn(1000)
    return torch.from_numpy(flat).to(torch.float32).reshape(500, 2)

def softmax(z):
    """Row-wise softmax of a 2-D tensor of logits.

    Args:
        z: tensor of shape (batch, classes).

    Returns:
        Tensor of the same shape whose rows are non-negative and sum to 1.
    """
    # Softmax is unchanged by subtracting a per-row constant.  Subtracting the
    # row maximum keeps every exponent <= 0, so exp() cannot overflow to inf
    # (which would otherwise give inf / inf = nan for large logits).
    shift = z.detach().max(dim=1, keepdim=True).values
    numerator = (z - shift).exp()
    denominator = numerator.sum(dim=1, keepdim=True)
    return numerator / denominator

# Rows of the 2x2 identity serve as one-hot vectors for classes 0 and 1.
eye2 = torch.eye(2, device=device)


def true_probability(x, rule):
    """Return the one-hot target distribution for every point in ``x``.

    Args:
        x: tensor whose columns 0 and 1 hold the two coordinates.
        rule: callable invoked as ``rule(x0, x1)``; its result is used to
            index the rows of ``eye2``, so it must yield class indices 0 or 1.

    Returns:
        Tensor with one one-hot row of ``eye2`` per point.
    """
    first, second = x[:, 0], x[:, 1]
    labels = rule(first, second)
    return eye2[labels]

def nce_loss(p, y):
    """Mean cross-entropy between targets ``y`` and predicted probabilities ``p``.

    Computes ``-sum_k y_k * log(p_k)`` for every row and averages over rows.

    Args:
        p: tensor of shape (batch, classes) with predicted probabilities.
        y: tensor of the same shape with target probabilities.

    Returns:
        A scalar tensor.
    """
    # xlogy(y, p) equals y * log(p) but is defined as 0 wherever y == 0, so a
    # class with zero target weight and zero predicted probability contributes
    # nothing instead of turning the whole loss into nan (0 * -inf).
    return -torch.xlogy(y, p).sum(dim=1).mean(dim=0)

def mse_loss(p, y):
    """Squared error summed over classes and averaged over the batch.

    NOTE: this name also appears in the file's import from
    ``torch.nn.functional``; this definition replaces that import.

    Args:
        p: tensor of shape (batch, classes) with predictions.
        y: tensor of the same shape with targets.

    Returns:
        A scalar tensor.
    """
    residual = p - y
    per_example = residual.pow(2).sum(dim=1)
    return per_example.mean(dim=0)

mlp.to(device)


def my_loss(x):
    """Loss of the module-level ``mlp`` on a batch of points.

    The targets are the one-hot labels given by ``classify_target``; the
    network's logits are turned into probabilities with ``softmax`` and
    compared to the targets with ``mse_loss``.

    Args:
        x: batch of input points, one per row.

    Returns:
        A scalar loss tensor.
    """
    target = true_probability(x, classify_target)
    predicted = softmax(mlp(x))
    return mse_loss(predicted, target)

# Alternative optimizer settings, kept for reference.  Only the Adam
# assignment below is active.
#optimizer = SGD(mlp.parameters(), lr=0.01)
#optimizer = SGD([
#    dict(params=mlp.layer1.parameters(), lr=1000.0),
#    dict(params=mlp.layer2.parameters(), lr=100.0),
#    dict(params=mlp.layer3.parameters(), lr=1.0),
#    dict(params=mlp.layer4.parameters(), lr=0.01)
#], momentum=0.5)

# Adam with a separate learning rate per layer, largest at the input layer.
optimizer = Adam([
    dict(params=mlp.layer1.parameters(), lr=5.0),
    dict(params=mlp.layer2.parameters(), lr=0.5),
    dict(params=mlp.layer3.parameters(), lr=0.10),
    dict(params=mlp.layer4.parameters(), lr=0.01)
])

# Weight matrix of the Linear module at the front of each named layer.  The
# optimizer updates these tensors in place, so the list is built once.
weights = [getattr(mlp, f'layer{i+1}')[0].weight for i in range(4)]

for iteration in pbar(range(1000)):
    x = sample_x().to(device)
    loss = my_loss(x)
    # Iteration 0 takes no training step, so the untrained network is the
    # first snapshot and its gradient statistics are recorded as zero.
    grads = [0.0] * len(weights)
    if iteration > 0:
        mlp.zero_grad()
        loss.backward()
        with torch.no_grad():
            # Gradient norm relative to weight norm, stored as plain floats
            # so the recorded history holds no tensors.
            grads = [(w.grad.norm() / w.norm()).item() for w in weights]
        optimizer.step()
    # The loss itself is not recorded: loss=loss.detach().item()
    widget.add(mlp, dict(grad4=grads[3], grad3=grads[2], grad2=grads[1], grad1=grads[0]))
# Show the final snapshot.
widget.plot.index = iteration
show(widget)

In [None]:
import math

widget = TwoDNetworkWidget(classify_target)

# Seed for the numpy generator that drives both the initial weights and the
# training samples.  The second assignment is the one in effect.
seed = 5 # is good with Adam
seed = 1
device = 'cpu'

# Three 2-unit sigmoid layers followed by a bias-free 2-unit linear readout.
mlp = Sequential(OrderedDict(
    layer1=Sequential(Linear(2, 2), Sigmoid()),
    layer2=Sequential(Linear(2, 2), Sigmoid()),
    layer3=Sequential(Linear(2, 2), Sigmoid()),
    layer4=Sequential(Linear(2, 2, bias=False)),
))

prng = numpy.random.RandomState(seed)
# Overwrite every parameter in place with normal values of standard deviation
# sqrt(5) drawn from the seeded generator, in parameter enumeration order.
with torch.no_grad():
    for param in mlp.parameters():
        scaled = prng.randn(param.numel()) * math.sqrt(5)
        param.copy_(torch.from_numpy(scaled).reshape(param.shape))

def sample_x():
    """Draw a batch of 500 two-dimensional standard-normal points.

    Consumes 1000 values from the module-level ``prng``.

    Returns:
        A float32 tensor of shape (500, 2).
    """
    flat = prng.randn(1000)
    return torch.from_numpy(flat).to(torch.float32).reshape(500, 2)

def softmax(z):
    """Row-wise softmax of a 2-D tensor of logits.

    Args:
        z: tensor of shape (batch, classes).

    Returns:
        Tensor of the same shape whose rows are non-negative and sum to 1.
    """
    # Softmax is unchanged by subtracting a per-row constant.  Subtracting the
    # row maximum keeps every exponent <= 0, so exp() cannot overflow to inf
    # (which would otherwise give inf / inf = nan for large logits).
    shift = z.detach().max(dim=1, keepdim=True).values
    numerator = (z - shift).exp()
    denominator = numerator.sum(dim=1, keepdim=True)
    return numerator / denominator

# Rows of the 2x2 identity serve as one-hot vectors for classes 0 and 1.
eye2 = torch.eye(2, device=device)


def true_probability(x, rule):
    """Return the one-hot target distribution for every point in ``x``.

    Args:
        x: tensor whose columns 0 and 1 hold the two coordinates.
        rule: callable invoked as ``rule(x0, x1)``; its result is used to
            index the rows of ``eye2``, so it must yield class indices 0 or 1.

    Returns:
        Tensor with one one-hot row of ``eye2`` per point.
    """
    first, second = x[:, 0], x[:, 1]
    labels = rule(first, second)
    return eye2[labels]

def nce_loss(p, y):
    """Mean cross-entropy between targets ``y`` and predicted probabilities ``p``.

    Computes ``-sum_k y_k * log(p_k)`` for every row and averages over rows.

    Args:
        p: tensor of shape (batch, classes) with predicted probabilities.
        y: tensor of the same shape with target probabilities.

    Returns:
        A scalar tensor.
    """
    # xlogy(y, p) equals y * log(p) but is defined as 0 wherever y == 0, so a
    # class with zero target weight and zero predicted probability contributes
    # nothing instead of turning the whole loss into nan (0 * -inf).
    return -torch.xlogy(y, p).sum(dim=1).mean(dim=0)

def mse_loss(p, y):
    """Squared error summed over classes and averaged over the batch.

    NOTE: this name also appears in the file's import from
    ``torch.nn.functional``; this definition replaces that import.

    Args:
        p: tensor of shape (batch, classes) with predictions.
        y: tensor of the same shape with targets.

    Returns:
        A scalar tensor.
    """
    residual = p - y
    per_example = residual.pow(2).sum(dim=1)
    return per_example.mean(dim=0)

mlp.to(device)


def my_loss(x):
    """Loss of the module-level ``mlp`` on a batch of points.

    The targets are the one-hot labels given by ``classify_target``; the
    network's logits are turned into probabilities with ``softmax`` and
    compared to the targets with ``mse_loss``.

    Args:
        x: batch of input points, one per row.

    Returns:
        A scalar loss tensor.
    """
    target = true_probability(x, classify_target)
    predicted = softmax(mlp(x))
    return mse_loss(predicted, target)

# Alternative optimizer settings, kept for reference.  Only the SGD
# assignment below is active.
#optimizer = Adam([
#    dict(params=mlp.layer1.parameters(), lr=5.0),
#    dict(params=mlp.layer2.parameters(), lr=0.5),
#    dict(params=mlp.layer3.parameters(), lr=0.10),
#    dict(params=mlp.layer4.parameters(), lr=0.01)
#])

# Plain SGD with one learning rate for every layer.
optimizer = SGD(mlp.parameters(), lr=0.01)

# Weight matrix of the Linear module at the front of each named layer.  The
# optimizer updates these tensors in place, so the list is built once.
weights = [getattr(mlp, f'layer{i+1}')[0].weight for i in range(4)]

for iteration in pbar(range(1000)):
    x = sample_x().to(device)
    loss = my_loss(x)
    # Iteration 0 takes no training step, so the untrained network is the
    # first snapshot and its gradient statistics are recorded as zero.
    grads = [0.0] * len(weights)
    if iteration > 0:
        mlp.zero_grad()
        loss.backward()
        with torch.no_grad():
            # Gradient norm relative to weight norm, stored as plain floats
            # so the recorded history holds no tensors.
            grads = [(w.grad.norm() / w.norm()).item() for w in weights]
        optimizer.step()
    # The loss itself is not recorded: loss=loss.detach().item()
    widget.add(mlp, dict(grad4=grads[3], grad3=grads[2], grad2=grads[1], grad1=grads[0]))
# Show the final snapshot.
widget.plot.index = iteration
show(widget)