In [None]:
import json
from dataclasses import asdict, dataclass
from pathlib import Path

import fire
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from pydantic import BaseModel, field_validator
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

from rib.data_accumulator import collect_gram_matrices, collect_interaction_edges
from rib.hook_manager import HookedModel
from rib.interaction_algos import InteractionRotation, calculate_interaction_rotations
from rib.log import logger
from rib.models import MLP
from rib.plotting import plot_interaction_graph
from rib.types import TORCH_DTYPES
from rib.utils import REPO_ROOT, check_outfile_overwrite, load_config, set_seed


In [None]:


class IdentityWideHidden(MLP):
    """Hand-constructed one-hidden-layer ReLU MLP whose weights encode the identity map.

    The hidden layer has ``2 * input_size`` units. Input coordinate ``i`` is written to
    hidden units ``2i`` and ``2i + 1`` with weights ``+1`` and ``-1``, and the output
    layer uses the transposed matrix, so output ``i`` is ``relu(x_i) - relu(-x_i) = x_i``
    (assuming ``MLP`` layers compute ``x @ W + b`` -- TODO confirm against ``rib.models``).

    Args:
        input_size: Dimension of the input, which is also the output dimension.
        dtype: Data type passed to ``MLP`` and used for the hand-set weights.
    """

    def __init__(
        self,
        input_size=4,
        dtype=torch.float32,
    ):
        super().__init__(
            hidden_sizes=[2 * input_size],
            input_size=input_size,
            output_size=input_size,
            dtype=dtype,
            fold_bias=False,
            activation_fn='relu',
        )
        # Block matrix of shape (input_size, 2 * input_size): row i holds [1, -1] in
        # columns 2i and 2i + 1 and zeros elsewhere. Cast to the requested dtype so the
        # hand-set weights match the dtype the model was built with.
        sign_pair = torch.tensor([[1.0, -1.0]])
        W_embed = torch.kron(torch.eye(input_size), sign_pair).to(dtype)
        self.layers[0].W = nn.Parameter(W_embed)
        # `.T` alone is a view of W_embed; materialise it so the two parameters do not
        # share storage.
        self.layers[1].W = nn.Parameter(W_embed.T.contiguous())
        # Zero both biases so the network computes the identity exactly.
        for i in range(2):
            self.layers[i].b = nn.Parameter(torch.zeros_like(self.layers[i].b))
        # The parent was built with fold_bias=False so W and b could be set separately;
        # fold only after the final values are in place.
        self.fold_bias()


In [None]:
# Scratch cell: draw a (dataset_size, input_size) Gaussian sample with per-feature stds.
dataset_size = 1000
input_size = 2
dtype = torch.float32
# Per-feature standard deviations, one entry per input dimension. Kept under its own
# name so `stds` only ever refers to the tensor.
feature_stds = [1, 1]
mean = torch.zeros((dataset_size, input_size), dtype=dtype)
# Build the std tensor in the same dtype as `mean` (an int list would otherwise give
# int64), then expand it to one row per sample without copying.
stds = torch.tensor(feature_stds, dtype=dtype).broadcast_to((dataset_size, input_size))
print(mean.shape, stds.shape)
# A local, seeded generator makes the displayed sample reproducible without touching
# the global RNG state.
torch.normal(mean, stds, generator=torch.Generator().manual_seed(0))

In [None]:
# Smoke test: build the hand-constructed network and push one random sample through it.
model = IdentityWideHidden(input_size=4)
# A single random input with a leading batch dimension of size 1.
a = torch.randn(4)[None, :]
print(a)
# Cell output; compare by eye with the printed input above.
model(a)

In [None]:
# Scratch check of the strided-view weight construction for a larger input size.
input_size = 5
# Flat zero buffer of length 4 * input_size - 2 with a single (-1, 1) pair in its two
# middle slots.
W_embed = torch.zeros(4 * input_size - 2)
W_embed[2 * input_size - 2] = -1
W_embed[2 * input_size - 1] = 1
print(W_embed)
# Row i is a window of length 2 * input_size starting at offset 2 * i in the buffer;
# flipping the columns then puts the (1, -1) pair at columns 2i and 2i + 1 of row i.
W_embed = torch.flip(
    torch.as_strided(W_embed, (input_size, 2 * input_size), (2, 1)),
    dims=(1,),
)

In [None]:
# Display the (input_size, 2 * input_size) matrix produced by the as_strided + flip
# construction in the cell above.
W_embed