# Normalization experiments

In [None]:
from dataclasses import dataclass

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn


In [None]:
class DeepMLP(nn.Module):
    """Stack of ``depth`` blocks, each ``[norm] -> Linear(width, width) -> Sigmoid``.

    All modules are stored flat, in order, in ``self.layers`` (an
    ``nn.ModuleList``) so callers can iterate them, e.g. to pick out the
    ``nn.Linear`` weights.

    Args:
        width: Feature dimension; input, hidden and output sizes are all ``width``.
        depth: Number of ``[norm] -> Linear -> Sigmoid`` blocks.
        norm: Normalization inserted before every Linear layer:
            ``"none"`` (``nn.Identity``), ``"layer"`` (``nn.LayerNorm``) or
            ``"batch"`` (``nn.BatchNorm1d``).

    Raises:
        ValueError: If ``norm`` is not one of the supported names. Previously an
            unknown name silently produced an un-normalized network.
    """

    def __init__(self, width: int, depth: int, norm: str = "none"):
        super().__init__()
        norm_factories = {
            "none": lambda: nn.Identity(),
            "layer": lambda: nn.LayerNorm(width),
            "batch": lambda: nn.BatchNorm1d(width),
        }
        if norm not in norm_factories:
            raise ValueError(
                f"Unknown norm {norm!r}; expected one of {sorted(norm_factories)}"
            )
        make_norm = norm_factories[norm]

        layers: list[nn.Module] = []
        for _ in range(depth):
            # A fresh norm module per block, so parameters/statistics are not
            # shared between blocks.
            layers.append(make_norm())
            layers.append(nn.Linear(width, width, bias=False))
            layers.append(nn.Sigmoid())
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every module in ``self.layers`` in order and return the result."""
        h = x
        for layer in self.layers:
            h = layer(h)
        return h

In [None]:
@dataclass
class Result:
    """One per-step training record produced by ``run_scenario``.

    Field order defines the positional constructor and the column order of a
    DataFrame built from a list of these records.
    """

    norm: str  # normalization name passed to DeepMLP ("none" / "layer" / "batch")
    run: int  # index of the repeated run (selects which pre-drawn input batch is used)
    epoch: int  # optimizer step index; the record is taken before that step is applied
    depth: int  # number of [norm] -> Linear -> Sigmoid blocks in the network
    g_input: float  # norm of the input tensor's gradient (x.grad) at this step
    g_first_layer: float  # norm of the first Linear layer's weight gradient
    g_last_layer: float  # norm of the last Linear layer's weight gradient
    g_max: float  # largest weight-gradient norm over all Linear layers
    loss: float  # loss value at this step


def _select_device() -> str:
    """Return the preferred torch device name: "cuda", then "mps", else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def run_scenario(
    norm: str,
    depths: list[int],
    n_runs: int = 10,
    n_epochs: int = 100,
    width: int = 128,
    batch: int = 256,
    seed: int = 0,
) -> list[Result]:
    """Train DeepMLPs of several depths with one normalization and log gradient norms.

    For every depth in ``depths`` and every run index in ``range(n_runs)``, a
    fresh ``DeepMLP(width, depth, norm=norm)`` is trained with Adam (lr=1e-3)
    for ``n_epochs`` full-batch steps on a fixed random input ``x`` of shape
    ``(batch, width)``. The loss is ``((net(x) - x) ** 2).sum()``. One
    ``Result`` is recorded per step, before the optimizer step is applied.

    The RNG is seeded once with ``seed`` and the inputs are drawn first, so
    calls with the same ``seed``/``n_runs``/``batch``/``width`` see the same
    inputs regardless of ``norm``.

    NOTE(review): the network ends in Sigmoid, so its outputs lie in (0, 1)
    while ``x`` is standard normal; the loss cannot reach zero. Confirm this is
    the intended objective.

    ``g_input`` is the norm of d(loss)/dx. That includes the direct dependence
    of the loss on ``x`` through the ``- x`` target term, not only the path
    through the network.

    Args:
        norm: Normalization name forwarded to ``DeepMLP``.
        depths: Network depths to sweep; each must be >= 1 (a depth-0 net has
            no parameters for Adam to optimize).
        n_runs: Independent runs (fresh init and input batch) per depth.
        n_epochs: Optimizer steps per run.
        width: Feature dimension of the network and the input.
        batch: Number of rows in each input batch.
        seed: Seed passed to ``torch.manual_seed``.

    Returns:
        ``len(depths) * n_runs * n_epochs`` ``Result`` records.

    Raises:
        ValueError: If the input gradient was not populated by ``backward()``.
    """
    device = _select_device()

    print(f"\n=== {norm} ===")
    torch.manual_seed(seed)
    results: list[Result] = []
    # Pre-draw one input batch per run so run `r` uses the same data at every depth.
    x_all_runs = torch.randn(n_runs, batch, width, device=device)

    for depth in depths:
        for run in range(n_runs):
            net = DeepMLP(width, depth, norm=norm)
            net.to(device)

            optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

            # Leaf tensor with requires_grad so backward() populates x.grad.
            x = x_all_runs[run].clone().requires_grad_(True)

            for epoch in range(n_epochs):
                y = net(x)
                # Loss is the squared difference between the output and the input
                loss = ((y - x) ** 2).sum()
                optimizer.zero_grad()
                # zero_grad() only clears the network parameters' grads. x is not
                # an optimizer parameter, so without this reset x.grad would
                # accumulate across epochs and g_input would be a running sum.
                x.grad = None
                loss.backward()

                if x.grad is None:
                    raise ValueError("x.grad is None")

                gin = x.grad.norm().item()
                grads = [
                    layer.weight.grad.detach().norm().item()
                    for layer in net.layers
                    if isinstance(layer, nn.Linear)
                ]
                results.append(
                    Result(
                        norm=norm,
                        run=run,
                        depth=depth,
                        epoch=epoch,
                        g_input=gin,
                        g_last_layer=grads[-1],
                        g_first_layer=grads[0],
                        g_max=max(grads),
                        loss=loss.item(),
                    )
                )
                optimizer.step()

            print(f"{norm} - {depth=} - {run=} done.")
    return results

In [None]:
# Sweep every normalization over the same depths and epoch budget, collecting
# all per-step Result records into one list.
depths = [2, 4, 8, 16, 32, 64, 128]
norms = ("none", "layer", "batch")
n_epochs = 10

all_results = []
for norm in norms:
    all_results.extend(run_scenario(norm=norm, depths=depths, n_epochs=n_epochs))

In [None]:
# Build a DataFrame from the collected Result records (one column per field).
df = pd.DataFrame(data=all_results)
df

In [None]:
# Keep the final recorded epoch of each (norm, depth, run) series. Comparing
# against the per-series max (not the global max) keeps series that were
# trained for fewer epochs instead of silently dropping them.
df_last_epoch = df[
    df["epoch"] == df.groupby(["norm", "depth", "run"])["epoch"].transform("max")
]
df_last_epoch

In [None]:
hue = "depth"
# Distinct values of the hue column in ascending order.
hue_order = df[hue].drop_duplicates().sort_values().to_numpy()
print(hue_order)

# One "flare" color per distinct hue value.
palette = sns.color_palette("flare", n_colors=len(hue_order))

In [None]:
# Training loss per epoch: one facet per normalization, one line per depth.
g = sns.relplot(
    kind="line",
    data=df,
    x="epoch",
    y="loss",
    col="norm",
    hue=hue,
    hue_order=hue_order,
    palette=palette,
)
# Set the log scale on every facet explicitly; plt.yscale() only touches the
# current Axes and relies on the grid's shared y-axis to propagate it.
g.set(yscale="log")

In [None]:
# Input-gradient norm per epoch: one facet per normalization, one line per depth.
g = sns.relplot(
    kind="line",
    data=df,
    x="epoch",
    y="g_input",
    col="norm",
    hue=hue,
    hue_order=hue_order,
    palette=palette,
)
# Set the log scale on every facet explicitly; plt.yscale() only touches the
# current Axes and relies on the grid's shared y-axis to propagate it.
g.set(yscale="log")


In [None]:
# Input-gradient norm at the final epoch vs. depth, one line per normalization
# (seaborn aggregates the runs at each depth).
ax = sns.lineplot(data=df_last_epoch, x="depth", y="g_input", hue="norm")
# The swept depths are powers of two, so a base-2 log x-axis spaces them evenly.
ax.set_xscale("log", base=2)
ax.set_yscale("log")

In [None]:
# First Linear layer's weight-gradient norm per epoch: one facet per
# normalization, one line per depth.
g = sns.relplot(
    kind="line",
    data=df,
    x="epoch",
    y="g_first_layer",
    col="norm",
    hue=hue,
    hue_order=hue_order,
    palette=palette,
)
# Set the log scale on every facet explicitly; plt.yscale() only touches the
# current Axes and relies on the grid's shared y-axis to propagate it.
g.set(yscale="log")

In [None]:
# First-layer weight-gradient norm at the final epoch vs. depth, one line per
# normalization. Log y-axis because gradient norms can span many orders of
# magnitude across depths.
ax = sns.lineplot(data=df_last_epoch, x="depth", y="g_first_layer", hue="norm")
# The swept depths are powers of two, so a base-2 log x-axis spaces them evenly.
ax.set_xscale("log", base=2)
ax.set_yscale("log")

In [None]:
# Last Linear layer's weight-gradient norm per epoch: one facet per
# normalization, one line per depth.
g = sns.relplot(
    kind="line",
    data=df,
    x="epoch",
    y="g_last_layer",
    col="norm",
    hue=hue,
    hue_order=hue_order,
    palette=palette,
)
# Set the log scale on every facet explicitly; plt.yscale() only touches the
# current Axes and relies on the grid's shared y-axis to propagate it.
g.set(yscale="log")

In [None]:
# Largest per-layer weight-gradient norm per epoch: one facet per
# normalization, one line per depth.
g = sns.relplot(
    kind="line",
    data=df,
    x="epoch",
    y="g_max",
    col="norm",
    hue=hue,
    hue_order=hue_order,
    palette=palette,
)
# Set the log scale on every facet explicitly; plt.yscale() only touches the
# current Axes and relies on the grid's shared y-axis to propagate it.
g.set(yscale="log")

In [None]:
# Largest per-layer weight-gradient norm at the final epoch vs. depth, one line
# per normalization. Log y-axis because gradient norms can span many orders of
# magnitude across depths.
ax = sns.lineplot(data=df_last_epoch, x="depth", y="g_max", hue="norm")
# The swept depths are powers of two, so a base-2 log x-axis spaces them evenly.
ax.set_xscale("log", base=2)
ax.set_yscale("log")