In [None]:
from collections import OrderedDict
from contextlib import contextmanager
import random

# TODO(eric.cousineau): Use tensorboard in notebook.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader

In [None]:
# Raise pandas display limits so the full training log and long string
# columns are rendered without truncation.
pd.set_option("display.max_rows", 1000)
pd.set_option("display.max_colwidth", 1000)

In [None]:
def zero_grad(p):
    """Zero the gradient stored on `p` in place; do nothing if `p.grad` is None."""
    grad = p.grad
    if grad is None:
        return
    grad.zero_()

In [None]:
def l1_loss(y, yh):
    """Mean absolute error between `y` and `yh`, averaged over all elements."""
    residual = y - yh
    return residual.abs().mean()

def mse_loss(y, yh):
    """Mean squared error between `y` and `yh`, averaged over all elements."""
    residual = y - yh
    return residual.pow(2).mean()

In [None]:
def flat_cat_detached(ps):
    """
    Concatenate the tensors in `ps` into a single 1-D tensor.

    Each tensor is detached from the autograd graph and flattened with
    `view(-1)` before concatenation, so the result carries no gradient history.
    """
    flat_parts = []
    for tensor in ps:
        flat_parts.append(tensor.detach().view(-1))
    return torch.cat(flat_parts)

In [None]:
class SequentialDict(nn.Sequential):
    """
    `nn.Sequential` whose children are registered under explicit names.

    All constructor arguments are forwarded to `OrderedDict`, and the resulting
    mapping is handed to `nn.Sequential` as its single argument, e.g.
    `SequentialDict(fcn=nn.Linear(1, 2), activation=nn.ReLU())`.

    An OrderedDict is used because otherwise pytorch may sort the keys
    (unconfirmed). See:
        https://discuss.pytorch.org/t/append-for-nn-sequential-or-directly-converting-nn-modulelist-to-nn-sequential/7104/4
        https://github.com/pytorch/pytorch/pull/40905
    """
    def __init__(self, *args, **kwargs):
        named_modules = OrderedDict(*args, **kwargs)
        super().__init__(named_modules)

In [None]:
def seed(value):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with the same `value`."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(value)

In [None]:
# Activation module class used between the linear layers of the MLP.
activation_cls = nn.ReLU

class MLP(nn.Module):
    """
    Simple Multi-Layer Perceptron.

    The network is `layers.input -> layers.hidden -> layers.output`:
      * input:  Linear(num_inputs, num_hidden_units) followed by an activation,
      * hidden: `num_layers - 2` blocks of Linear(num_hidden_units,
                num_hidden_units) followed by an activation,
      * output: Linear(num_hidden_units, num_outputs) with no activation.

    `num_layers` therefore counts linear layers and must be at least 2.
    """
    def __init__(self, num_inputs, num_outputs, *, num_hidden_units, num_layers):
        super().__init__()
        assert num_layers >= 2
        layers = SequentialDict()
        layers.input = SequentialDict(
            fcn=nn.Linear(num_inputs, num_hidden_units),
            activation=activation_cls(),
        )
        # Hidden blocks are built after the input block and before the output
        # block, so the layers are constructed in forward order.
        hidden_blocks = [
            SequentialDict(
                fcn=nn.Linear(num_hidden_units, num_hidden_units),
                activation=activation_cls(),
            )
            for _ in range(num_layers - 2)
        ]
        layers.hidden = nn.Sequential(*hidden_blocks)
        layers.output = SequentialDict(
            fcn=nn.Linear(num_hidden_units, num_outputs)
        )
        self.layers = layers

    def forward(self, x):
        """Run `x` through the input, hidden and output stages in order."""
        return self.layers(x)

In [None]:
# Loosely derived from an article with a similar goal: https://towardsdatascience.com/how-to-visualize-convolutional-features-in-40-lines-of-code-70b7d87b0030
class SaveActivations:
    """Forward-hook target that remembers the most recent output of a module."""
    def __init__(self):
        # Latest detached output; None until the hook has fired (or after a
        # consumer resets it).
        self.y = None

    def forward_hook(self, module, x, y):
        """Store a detached reference to the output `y`; `module` and `x` are unused."""
        self.y = y.detach()

@contextmanager
def save_activations(module, cls):
    """
    Temporarily record the outputs of every submodule of `module` that is an
    instance of `cls`.

    Yields:
        A list of `SaveActivations`, one per matching submodule, in
        `module.modules()` order. Each saver's `.y` holds the latest output of
        its submodule after a forward pass.

    Raises:
        AssertionError: if no submodule of `module` is an instance of `cls`.

    The forward hooks are always removed when the `with` block is left, even
    if the body raises, so a failed run does not leave stale hooks on the model.
    """
    savers = []
    hooks = []
    try:
        for m in module.modules():
            if isinstance(m, cls):
                saver = SaveActivations()
                hooks.append(m.register_forward_hook(saver.forward_hook))
                savers.append(saver)
        assert len(hooks) > 0
        yield savers
    finally:
        # Runs on normal exit and on exceptions thrown into the generator.
        for hook in hooks:
            hook.remove()

def compute_activation_ratios(savers, clear=True):
    """
    Return, for each saver, the fraction of elements in its stored output that
    are strictly positive.

    Args:
        savers: Iterable of objects with a `.y` tensor attribute; each `.y`
            must not be None.
        clear: When true, reset each saver's `.y` to None once it is consumed.

    Returns:
        A 1-D tensor with one ratio per saver, in order.
    """
    ratios = []
    for saver in savers:
        output = saver.y
        assert output is not None
        is_active = output > 0
        ratios.append(is_active.float().mean())
        if clear:
            saver.y = None
    return torch.tensor(ratios)

In [None]:
@torch.no_grad()
def shit_init_param(p, scale=1.0, offset=1e-5):
    """
    Overwrite `p` in place with a deterministic, deliberately poor initialization.

    The values are `linspace(-1, 1, n) / n` with `n = p.numel()`, pushed away
    from zero by `offset` (an exact zero is treated as positive), multiplied by
    `scale`, and reshaped to `p.shape`.
    """
    count = p.numel()
    ramp = torch.linspace(-1.0, 1.0, count) / count
    direction = torch.sign(ramp)
    # torch.sign gives 0 for an exact zero; nudge those entries upward.
    direction.masked_fill_(direction == 0, 1.0)
    ramp += offset * direction
    p[:] = scale * ramp.reshape(p.shape)

def shit_init(model):
    """
    Apply `shit_init_param` to the weight and bias of every `nn.Linear` in
    `model`; biases use `scale=0.1`.
    """
    linear_layers = (m for m in model.modules() if isinstance(m, nn.Linear))
    for layer in linear_layers:
        shit_init_param(layer.weight)
        shit_init_param(layer.bias, scale=0.1)

In [None]:
# # Parameters.
# amplitude = 1.0
# period_sec = 1.0
# shift_sec = 0.0
# num_periods = 1.0
# dt = 0.1
# count_per_period = int(np.ceil(period_sec / dt))
# count = num_periods * count_per_period
# t = torch.arange(count) * dt

# def waveform(t):
#     omega = 2 * np.pi / period_sec
#     x = omega * (t + shift_sec)
#     return amplitude * torch.sin(x)

# Sample inputs: three evenly spaced points covering [0, 1].
t = torch.linspace(0, 1.0, steps=3)

def waveform(t):
    """Target function to fit: a line through the origin with slope 2."""
    slope = 2.0
    return slope * t

In [None]:
def rel_mean_abs(x, tol=torch.tensor(1e-8)):
    """
    Mean absolute value of `x` divided by its largest absolute value.

    The divisor is floored at `tol` (a tensor), so an all-zero `x` gives 0
    instead of dividing by zero.
    """
    magnitude = x.abs()
    divisor = torch.fmax(magnitude.max(), tol)
    return magnitude.mean() / divisor

def mean_abs(x):
    """Mean of the absolute values of all elements of `x`."""
    return torch.mean(torch.abs(x))

In [None]:
def fit(model, loss_fn, lr, t, num_epochs=3, batch_size=1):
    """
    Fit `model` to `waveform(t)` with hand-rolled, momentum-free SGD and show
    a per-step log table.

    Args:
        model: Module to train; called as `model(inputs)`. It must contain at
            least one `activation_cls` submodule (`save_activations` asserts this).
        loss_fn: Callable `loss_fn(y, yh)` returning a scalar loss tensor.
        lr: Step size of the update `param -= lr * param.grad`.
        t: Tensor of sample inputs; `unsqueeze(-1)` adds a trailing axis
            before it is used.
        num_epochs: Number of passes over the dataset.
        batch_size: Batch size handed to `DataLoader`.

    Returns:
        None. Prints `loss_fn` and displays the log: one "pre-opt" row computed
        on the whole dataset, then one row per batch step.
    """
    model.train()
    # Inputs and their labels.
    t = t.unsqueeze(-1)
    y = waveform(t)

    dataset = list(zip(t, y))
    loader = DataLoader(dataset, batch_size=batch_size)

    # One single-row DataFrame per logged step; concatenated once at the end.
    dfs = []

    with save_activations(model, activation_cls) as savers:

        def simple_log(epoch, batch_idx):
            # Reads `p`, `p_prev` and `loss` from the enclosing scope at call
            # time. The activation ratios and the loss come from the most
            # recent forward pass, i.e. from before the step being logged.
            activation_ratios = compute_activation_ratios(savers).numpy()
            activation_ratios_str = ", ".join([f"{x:.2f}" for x in activation_ratios])
            dp = p - p_prev
            dfs.append(pd.DataFrame(
                {
                    "epoch": epoch,
                    "batch_idx": batch_idx,
                    "mean|p|": mean_abs(p).detach().numpy(),
                    "mean|Δp|": mean_abs(dp).numpy(),
                    "loss": loss.detach().numpy(),
                    "activation_ratios": activation_ratios_str,
                },
                index=[len(dfs)],
            ))

        # `model.parameters()` is a generator; materialize it so it can be
        # iterated on every step.
        params = list(model.parameters())
        p = flat_cat_detached(params)
        p_prev = p.clone()
        with torch.no_grad():
            yh0 = model(t)
            loss = loss_fn(y, yh0)
        simple_log(epoch="pre-opt", batch_idx="n/a")

        for epoch in range(num_epochs):
            for batch_idx, (ti, yi) in enumerate(loader):
                # Clear the gradients on the real parameters. Zeroing the
                # flattened detached copy `p` would do nothing, because that
                # tensor never has a `.grad`; gradients from earlier steps
                # would then keep accumulating into every update.
                for param in params:
                    zero_grad(param)
                yhi = model(ti)
                loss = loss_fn(yi, yhi)
                loss.backward()

                # SGD update, no momentum.
                with torch.no_grad():
                    for param in params:
                        v = param.grad  # velocity
                        param -= lr * v  # step

                # Validation.
                p = flat_cat_detached(params)
                simple_log(epoch, batch_idx)
                p_prev = p

    print(loss_fn)
    display(pd.concat(dfs))

In [None]:
def make_model():
    """Build a 1-input, 1-output MLP (3 linear layers, 2 hidden units) and apply `shit_init`."""
    # seed(0)  # left disabled, as in the original cell
    mlp = MLP(1, 1, num_hidden_units=2, num_layers=3)
    shit_init(mlp)
    return mlp

In [None]:
# Merp
# batch_size == len(t): every step sees the whole dataset, so this is plain
# gradient descent rather than really being SGD.
fit(
    make_model(),
    loss_fn=mse_loss,
    lr=0.5,
    t=t,
    num_epochs=10,
    batch_size=len(t),
)

In [None]:
import torch.nn.functional as F

In [None]:
# I feel like an eediot

@torch.no_grad()
def plot_pwa_via_relu():
    """
    Plot one piecewise-affine curve two ways on the current matplotlib axes:
    as an explicit sum of shifted ReLUs (solid line), and as the output of a
    Linear(1, 3) -> ReLU -> Linear(3, 1, bias=False) network whose weights are
    set by hand to the same slopes and breakpoints (dashed line).
    """
    knots = torch.linspace(0, 4, 5)
    target = F.relu(knots) - 2 * F.relu(knots - 1) + 2 * F.relu(knots - 3)
    plt.plot(knots.numpy(), target.numpy(), linewidth=3)

    # Hidden layer: unit input weights; the biases place the kinks at 0, 1 and 3.
    hidden = nn.Linear(1, 3)
    hidden.weight.copy_(torch.ones(3, 1))
    hidden.bias.copy_(torch.tensor([0.0, -1.0, -3.0]))
    relu = nn.ReLU()
    # Readout: the slope change contributed by each hidden unit.
    readout = nn.Linear(3, 1, bias=False)
    readout.weight.copy_(torch.tensor([[1.0, -2.0, 2.0]]))

    net_out = readout(relu(hidden(knots.unsqueeze(-1))))
    plt.plot(knots.numpy(), net_out.numpy(), linewidth=2, linestyle="--")

plot_pwa_via_relu()