In [None]:
import torch
from collections import OrderedDict
from torch.nn import Linear, ReLU, Sequential



In [None]:
# These visualization helpers are used by the examples below

%matplotlib inline
from matplotlib import pyplot as plt

def endpoints(w, b, scale=10):
    """Return two points on the line ``w[0]*x0 + w[1]*x1 + b = 0``.

    The coordinate whose weight has the larger magnitude is solved for; the other
    coordinate is fixed at ``-scale`` and ``+scale``. Dividing by the larger weight
    avoids blowing up on a near-zero divisor.

    Args:
        w: 1-D tensor of weights; only ``w[0]`` and ``w[1]`` are used.
        b: scalar bias (tensor or number).
        scale: magnitude at which the free coordinate is fixed.

    Returns:
        A (2, 2) tensor ``[[x0_a, x1_a], [x0_b, x1_b]]`` on ``w``'s device (and, for
        floating-point ``w``, in ``w``'s dtype). If both weights are zero no line
        exists and the solved coordinate comes out as inf/nan.
    """
    # Build the fixed coordinates in w's dtype/device instead of as an int64 tensor,
    # so the final stack does not rely on implicit type promotion.
    span = torch.tensor([-scale, scale], dtype=w.dtype, device=w.device)
    if abs(w[1]) > abs(w[0]):
        x0 = span
        x1 = (-b - w[0] * x0) / w[1]
    else:
        x1 = span
        x0 = (-b - w[1] * x1) / w[0]
    return torch.stack([x0, x1], dim=1)

def visualize_net(net, classify_target):
    """Plot target labels, network prediction and first-layer lines side by side.

    All three panels cover the square [-2, 2] x [-2, 2]:
      1. ``classify_target(x, y)`` evaluated on a 100x100 grid,
      2. the softmax probability of class 1 from ``net`` on the same grid,
      3. the line ``w . (x, y) + b = 0`` of every unit of the first ``nn.Linear``
         found in ``net.modules()`` (drawn via ``endpoints``).

    Args:
        net: module mapping an (N, 2) batch of (x, y) points to (N, C) scores;
            column 1 is plotted, so C >= 2 is assumed.
        classify_target: callable taking CPU tensors ``x, y`` of shape (100, 100)
            and returning labels of the same shape.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))
    # The first Linear layer supplies the lines for panel 3 and tells us which
    # device the network lives on (previously the input was always sent to CUDA).
    first_linear = next(m for m in net.modules() if isinstance(m, torch.nn.Linear))
    device = first_linear.weight.device

    # grid[0] is x (increasing left -> right), grid[1] is y (decreasing top -> bottom),
    # which matches imshow's default top-left origin with extent [-2, 2, -2, 2].
    grid = torch.stack([
        torch.linspace(-2, 2, 100)[None, :].expand(100, 100),
        torch.linspace(2, -2, 100)[:, None].expand(100, 100),
    ])
    x, y = grid
    target = classify_target(x, y)
    ax1.set_title('target')
    ax1.imshow(target.float(), cmap='hot', extent=[-2, 2, -2, 2])

    ax2.set_title('network output')
    points = grid.permute(1, 2, 0).reshape(-1, 2).to(device)
    with torch.no_grad():  # display only; no autograd graph needed
        score = net(points).softmax(1)
    ax2.imshow(score[:, 1].reshape(100, 100).cpu(), cmap='hot', extent=[-2, 2, -2, 2])

    ax3.set_title('first layer folds')
    w = first_linear.weight.detach().cpu()
    bias = first_linear.bias
    # A Linear built with bias=False has bias None; treat that as zero offsets.
    b = torch.zeros(w.shape[0], dtype=w.dtype) if bias is None else bias.detach().cpu()
    for wc, bc in zip(w, b):
        ep = endpoints(wc, bc)
        ax3.plot(ep[:, 0], ep[:, 1], '#00ff00', linewidth=0.75, alpha=0.33)
    ax3.set_ylim(-2, 2)
    ax3.set_xlim(-2, 2)
    ax3.set_aspect(1.0)
    plt.show()

In [None]:
from torch.optim import Adam, SGD
from torch.nn.functional import cross_entropy, mse_loss

def classify_sine(x, y):
    """Label 1 where the point lies above the curve y = sin(3x), else 0 (int64)."""
    return (y > (x * 3).sin()).long()


def classify_checkerboard(x, y):
    """Unit-square checkerboard: label is (floor(x) + floor(y)) mod 2 (int64)."""
    return (y.floor() + x.floor()).long() % 2


# Target used by every experiment below. The two functions were previously both
# named `classify_target`, so the sine version was silently shadowed; the
# checkerboard was (and still is) the active one. Swap in `classify_sine` to try
# the smoother boundary.
classify_target = classify_checkerboard

# Random-features baseline: layer1 (2 -> 100, ReLU) keeps its random init and only
# the linear readout layer2 is given to the optimizer, trained with MSE against
# one-hot targets plus a small L2 penalty on the readout weights.
mlp = torch.nn.Sequential(OrderedDict([
    ('layer1', Sequential(Linear(2, 100), ReLU())),
    ('layer2', Sequential(Linear(100, 2, bias=False)))
]))
mlp.cuda()
#optimizer = SGD(mlp.parameters(), lr=0.01, momentum=0.95)
optimizer = Adam(mlp.layer2.parameters(), lr=0.01)
# One-hot lookup table built once on the labels' device: indexing a CPU tensor with
# CUDA indices raises, and rebuilding/copying it every iteration is wasted work.
one_hot = torch.eye(2, device='cuda')
for iteration in range(1024 * 8):
    in_batch = torch.randn(10000, 2, device='cuda')
    target_batch = classify_target(in_batch[:,0], in_batch[:,1])
    out_batch = mlp(in_batch)
    loss = mse_loss(out_batch, one_hot[target_batch])
    loss += (mlp.layer2[0].weight ** 2).sum() * 1e-6
    #loss = cross_entropy(out_batch, target_batch)
    # Iteration 0 only evaluates, so the first snapshot shows the untrained network.
    if iteration > 0:
        mlp.zero_grad()
        loss.backward()
        optimizer.step()
    # Log at iterations 0, 1, 3, 7, 15, ... (2**k - 1). Accuracy is measured on this
    # iteration's batch before its optimizer step; the plot shows the updated net.
    if iteration == 2 ** iteration.bit_length() - 1:
        pred_batch = out_batch.argmax(1)
        accuracy = (pred_batch == target_batch).float().mean().item()
        print(f'Iteration {iteration} accuracy: {accuracy:.4f}')
        visualize_net(mlp, classify_target)



In [None]:
# Closed-form alternative to gradient descent: keep the random layer1 features H
# and set the readout weight W so that W^T = pinv(H) @ T, i.e. the least-squares
# solution of H @ W^T ~= T for one-hot targets T.
mlp = torch.nn.Sequential(OrderedDict([
    ('layer1', Sequential(Linear(2, 100), ReLU())),
    ('layer2', Sequential(Linear(100, 2, bias=False)))
]))
mlp.cuda()
in_batch = torch.randn(100000, 2, device='cuda')
target_batch = classify_target(in_batch[:,0], in_batch[:,1])
# One-hot targets, (100000, 2). eye() is built on the labels' device because
# indexing a CPU tensor with CUDA indices raises.
onehot_batch = torch.eye(2, device=target_batch.device)[target_batch]
with torch.no_grad():
    hidden_batch = mlp.layer1(in_batch)                          # (100000, 100)
    # pinv in float64 for numerical stability, then back to float32.
    inv = torch.linalg.pinv(hidden_batch.t().double()).float()   # (100000, 100)
    mlp.layer2[0].weight.copy_(onehot_batch.t() @ inv)           # (2, 100)
visualize_net(mlp, classify_target)


In [None]:
# End-to-end training: both layers (200 hidden units) are optimized with
# cross-entropy, so the first-layer lines ("folds") can move to fit the target.
mlp = torch.nn.Sequential(OrderedDict([
    ('layer1', Sequential(Linear(2, 200), ReLU())),
    ('layer2', Sequential(Linear(200, 2, bias=False)))
])).cuda()
optimizer = Adam(mlp.parameters(), lr=0.01)
for iteration in range(1024 * 8):
    in_batch = torch.randn(10000, 2, device='cuda')
    target_batch = classify_target(in_batch[:,0], in_batch[:,1])
    out_batch = mlp(in_batch)
    # MSE alternative; eye() must live on the labels' device to be indexed by them.
    #loss = mse_loss(out_batch, torch.eye(2, device=target_batch.device)[target_batch])
    loss = cross_entropy(out_batch, target_batch)
    #loss += (mlp.layer2[0].weight ** 2).sum() * 1e-6
    # Iteration 0 only evaluates, so the first snapshot shows the untrained network.
    if iteration > 0:
        mlp.zero_grad()
        loss.backward()
        optimizer.step()
    # Log at iterations 0, 1, 3, 7, 15, ... (2**k - 1). Accuracy is measured on this
    # iteration's batch before its optimizer step; the plot shows the updated net.
    if iteration == 2 ** iteration.bit_length() - 1:
        pred_batch = out_batch.argmax(1)
        accuracy = (pred_batch == target_batch).float().mean().item()
        print(f'Iteration {iteration} accuracy: {accuracy:.4f}')
        visualize_net(mlp, classify_target)
