In [None]:
# %%
import json
from dataclasses import asdict
from pathlib import Path
from typing import Iterable, Optional, Callable, Union
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

from rib.data_accumulator import collect_gram_matrices, collect_interaction_edges
from rib.hook_manager import HookedModel
from rib.interaction_algos import InteractionRotation, calculate_interaction_rotations
from rib.log import logger
from rib.models import MLP
from rib.plotting import plot_interaction_graph
from rib.types import TORCH_DTYPES
from rib.utils import REPO_ROOT, check_outfile_overwrite, set_seed
import matplotlib.pyplot as plt

In [None]:
class InvertNonlinear(nn.Module):
    """Two-layer network ``g(fc2(g(fc1(x))))`` whose second layer is frozen.

    ``fc1`` (inputs -> hidden, with bias) is trainable. ``fc2`` (hidden -> inputs,
    no bias) keeps its random initialisation because all of its parameters have
    ``requires_grad`` disabled. The same ``nonlinearity`` ``g`` is applied after
    both linear layers.

    Args:
        n_inputs: Width of the input and of the output.
        n_hidden: Width of the hidden layer.
        nonlinearity: Callable applied to the output of each linear layer.
    """

    def __init__(self, n_inputs, n_hidden, nonlinearity):
        super().__init__()
        # fc1 is created before fc2 so that seeded initialisations stay the same.
        self.fc1 = nn.Linear(n_inputs, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_inputs, bias=False)
        self.nonlinearity = nonlinearity
        # Freeze the second (output) layer; only fc1 receives gradient updates.
        self.fc2.requires_grad_(False)

    def forward(self, x):
        hidden = self.nonlinearity(self.fc1(x))
        return self.nonlinearity(self.fc2(hidden))

In [None]:
class IdentityWideHidden(MLP):
    """One-hidden-layer GELU MLP hand-wired so that it should compute the identity.

    The hidden layer has ``2 * input_size`` units. The first-layer weight matrix has
    shape ``(input_size, 2 * input_size)``: row ``i`` holds ``+1`` at column ``2i``
    and ``-1`` at column ``2i + 1``. The second layer uses the transpose of that
    matrix, and every bias is zero.

    NOTE(review): assuming the rib layer computes ``x @ W + b`` with ``W`` of shape
    (in, out) — TODO confirm — output ``i`` is ``gelu(x_i) - gelu(-x_i)``, which
    equals ``x_i`` because ``gelu(x) = x * Phi(x)`` and ``Phi(-x) = 1 - Phi(x)``.

    Args:
        input_size: Width of the input and output (hidden width is twice this).
        dtype: dtype of the weights and biases.
    """

    def __init__(
        self,
        input_size=4,
        dtype=torch.float32,
    ):
        super(IdentityWideHidden, self).__init__(
            hidden_sizes=[2 * input_size],
            input_size=input_size,
            output_size=input_size,
            dtype=dtype,
            fold_bias=False,
            activation_fn="gelu",
        )
        rows = torch.arange(input_size)
        W_embed = torch.zeros(input_size, 2 * input_size, dtype=dtype)
        W_embed[rows, 2 * rows] = 1.0
        W_embed[rows, 2 * rows + 1] = -1.0
        self.layers[0].W = nn.Parameter(W_embed)
        # nn.Parameter does not copy its data: a bare ``W_embed.T`` view would make
        # both layers share storage, so an update to one would silently change the
        # other. Clone to give the second layer its own weights.
        self.layers[1].W = nn.Parameter(W_embed.T.clone())
        for i in range(2):
            self.layers[i].b = nn.Parameter(torch.zeros_like(self.layers[i].b, dtype=dtype))

In [None]:
def train_mixers(
    n_inputs: int,
    hidden_sizes: Union[int, list[int]],
    fs: list[Callable],
    activation_fn: str = "relu",
    datasize: int = 2**24,
    epochs: int = 1,
    batch_size: int = 2**10,
    lr: float = 3e-6,
    data_variance: int = 1,
    print_number: int = 20,
):
    """Train a "mixer" MLP to compute ``fs``, then an "unmixer" MLP that inverts it.

    Stage 1 fits ``mixer`` so that column ``k`` of ``mixer(x)`` approximates
    ``fs[k](x[:, 0], x[:, 1])``. Stage 2 keeps ``mixer`` fixed and fits ``unmixer``
    so that ``unmixer(mixer(x))`` approximates ``x``. Both stages use Adam and MSE.

    Inputs are sampled once as ``data_variance * U[0, 1) + 0.1``. The same
    distribution is used for the periodic diagnostic samples, so they reflect the
    data the networks were trained on.

    Args:
        n_inputs: Number of input features; must be >= 2 because every ``f`` in
            ``fs`` is called with the first two input columns.
        hidden_sizes: Hidden width(s) used for both the mixer and the unmixer.
        fs: Callables ``f(x0, x1)``; their results are stacked as output columns.
        activation_fn: Activation name forwarded to ``MLP``.
        datasize: Number of samples generated (reused in every epoch).
        epochs: Number of passes over the data in each stage.
        batch_size: Mini-batch size.
        lr: Adam learning rate for both networks.
        data_variance: Scale of the uniform input range (despite its name, it
            multiplies the samples rather than setting their variance).
        print_number: Approximate number of progress reports (loss plot plus a
            sample prediction) per epoch.

    Returns:
        ``(mixer, unmixer)``.

    Raises:
        ValueError: If ``n_inputs < 2``.
    """
    if n_inputs < 2:
        raise ValueError(f"n_inputs must be >= 2 because fs read two input columns, got {n_inputs}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(hidden_sizes, int):
        hidden_sizes = [hidden_sizes]

    def sample_inputs(n_samples: int) -> torch.Tensor:
        # Single source of the input distribution for both training and diagnostics.
        return data_variance * torch.rand((n_samples, n_inputs)).to(device) + 0.1

    def apply_fs(x: torch.Tensor) -> torch.Tensor:
        return torch.stack([f(x[:, 0], x[:, 1]) for f in fs], dim=1)

    input_data = sample_inputs(datasize)
    mixer_target = apply_fs(input_data).to(device)
    dataset = torch.utils.data.TensorDataset(input_data, mixer_target)
    input_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # len(loader) includes the final partial batch, unlike datasize // batch_size.
    n_batches = len(input_loader)
    # Guard against a zero interval when there are fewer batches than reports.
    log_every = max(1, n_batches // max(1, print_number))

    mixer = MLP(
        hidden_sizes=hidden_sizes,
        input_size=n_inputs,
        output_size=len(fs),
        activation_fn=activation_fn,
    ).to(device)
    # The unmixer mirrors the mixer's hidden architecture with input/output swapped.
    unmixer = MLP(
        hidden_sizes=hidden_sizes,
        input_size=len(fs),
        output_size=n_inputs,
        activation_fn=activation_fn,
    ).to(device)
    mixer_optimizer = torch.optim.Adam(mixer.parameters(), lr=lr)
    unmixer_optimizer = torch.optim.Adam(unmixer.parameters(), lr=lr)

    # Stage 1: fit the mixer to the target functions.
    mixer_losses = []
    for epoch in range(epochs):
        for step, (inputs, targets) in enumerate(tqdm(input_loader, total=n_batches)):
            mixer_optimizer.zero_grad()
            loss = F.mse_loss(mixer(inputs), targets)
            loss.backward()
            mixer_optimizer.step()
            mixer_losses.append(loss.item())
            if step % log_every == 0:
                plt.semilogy(mixer_losses)
                plt.show()
                with torch.no_grad():
                    test_data = sample_inputs(1)
                    output = mixer(test_data)
                    target = apply_fs(test_data)
                print("data:", test_data, "\noutput:", output, "\ntarget", target)
        if mixer_losses:
            print("epoch:", epoch, "loss:", mixer_losses[-1])

    # Stage 2: fit the unmixer to invert the now-fixed mixer. The mixer forward pass
    # runs under no_grad so no gradients are accumulated in the mixer's parameters.
    unmixer_losses = []
    for epoch in range(epochs):
        for step, (inputs, _) in enumerate(tqdm(input_loader, total=n_batches)):
            unmixer_optimizer.zero_grad()
            with torch.no_grad():
                mixed = mixer(inputs)
            loss = F.mse_loss(unmixer(mixed), inputs)
            loss.backward()
            unmixer_optimizer.step()
            unmixer_losses.append(loss.item())
            if step % log_every == 0:
                plt.semilogy(unmixer_losses)
                plt.show()
                with torch.no_grad():
                    test_input = sample_inputs(1)
                    output = unmixer(mixer(test_input))
                # The unmixer's reconstruction target is the input itself.
                print("input:", test_input, "\noutput:", output, "\ntarget", test_input)
        if unmixer_losses:
            print("epoch:", epoch, "loss:", unmixer_losses[-1])
    return mixer, unmixer

In [None]:
def f1(x, y):
    """Return the elementwise product ``x * y`` (first mixer target)."""
    product = x * y
    return product


def f2(x, y):
    """Return the elementwise quotient ``x / y`` (second mixer target).

    Values of ``y`` near zero make this blow up; ``train_mixers`` adds 0.1 to its
    uniformly sampled training inputs.
    """
    quotient = x / y
    return quotient


# Target functions; each is evaluated on the first two input columns.
fs = [f1, f2]
# Four hidden layers of width 1000, one epoch over the default 2**24 samples.
mixer, unmixer = train_mixers(n_inputs=2, hidden_sizes=[1000] * 4, fs=fs, epochs=1)

In [None]:
# Spot-check the trained mixer on one fresh sample drawn from train_mixers' default
# training distribution (uniform on [0.1, 1.1)). The sample is moved to the mixer's
# device so this also works when training ran on CUDA.
mixer_device = next(mixer.parameters()).device
sample_input = (torch.rand(1, 2) + 0.1).to(mixer_device)
with torch.no_grad():
    sample_output = mixer(sample_input)
    sample_target = torch.stack([f(sample_input[:, 0], sample_input[:, 1]) for f in fs], dim=1)
print(sample_input)
print(sample_output)
print(sample_target)

In [None]:
# Fresh, untrained mixer/unmixer pair (one hidden layer of width 100, ReLU).
# Stored under separate names so running this cell does not overwrite the trained
# `mixer` / `unmixer` returned by the train_mixers call earlier in the notebook.
n_inputs = 2
untrained_mixer = MLP(hidden_sizes=[100], input_size=n_inputs, output_size=n_inputs, activation_fn="relu")
untrained_unmixer = MLP(hidden_sizes=[100], input_size=n_inputs, output_size=n_inputs, activation_fn="relu")

In [None]:
# Train InvertNonlinear to reproduce its input. InvertNonlinear freezes fc2, so only
# fc1 is optimised: it must learn to undo the fixed random fc2 and the nonlinearity.

# Check if CUDA is available, else use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Parameters
learning_rate = 0.0002
steps = 2**25  # number of training samples generated (not optimizer steps)
n_inputs = 5
batch_size = 2**8
n_hidden = 20
print_num = 20  # approximate number of progress reports
loss_threshold = 1e-6  # stop early once a batch loss falls below this

# Data (random data for demonstration)
# In practice, this should be your specific data
input_data = 10 * torch.randn(steps, n_inputs).to(device)

# Dataset and DataLoader for mini-batch training; the targets are the inputs themselves.
dataset = torch.utils.data.TensorDataset(input_data, input_data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
n_batches = len(dataloader)
# Guard against a zero interval when there are fewer batches than reports.
log_every = max(1, n_batches // print_num)

# Network
network = InvertNonlinear(n_inputs, n_hidden, F.gelu).to(device)

# Optimizer: update the parameters of the first layer (fc1) only; fc2 is frozen.
# `optim` is not imported in this notebook, so use the fully qualified torch.optim.
optimizer = torch.optim.Adam(network.fc1.parameters(), lr=learning_rate)

# Loss function
criterion = nn.MSELoss()

# Identity-matrix probe: each row is one basis vector, so a perfect network prints I.
probe = torch.eye(n_inputs, device=device)

# Training loop
for step, (inputs, targets) in tqdm(enumerate(dataloader), total=n_batches):
    inputs, targets = inputs.to(device), targets.to(device)

    outputs = network(inputs)
    loss = criterion(outputs, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    if step % log_every == 0:
        print(f"step {step+1}/{n_batches}, Loss: {loss_value}")
        with torch.no_grad():
            print(network(probe))
    if loss_value < loss_threshold:
        break