In [None]:
import wandb
import numpy as np
from scipy.optimize import minimize
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from IPython.display import clear_output
import time
import plotly.graph_objs as go
import matplotlib as mpl
import torch.nn.functional as F

In [None]:
# Global matplotlib appearance: seaborn style, large figures, and one shared font size everywhere.
mpl.style.use('seaborn-v0_8')
mpl.rcParams['figure.figsize'] = (15,10)
fontsize = 20
mpl.rcParams.update({
    rc_key: fontsize
    for rc_key in ('font.size', 'xtick.labelsize', 'ytick.labelsize',
                   'legend.fontsize', 'axes.titlesize', 'axes.labelsize')
})

In [None]:

# Default problem size for the exploratory cell; the training cell later rebinds both names.
f, n = 10, 2  # f: input/output (feature) dimension, n: hidden dimension

# Create synthetic dataset
class SyntheticDataset(Dataset):
    """Random sparse samples: every row has exactly one (randomly chosen) active coordinate.

    ``self.data`` has shape ``(num_samples, f)``. In each row a single index drawn
    uniformly from ``[0, f)`` holds a value from ``torch.rand`` (U[0, 1)); all other
    entries are zero.
    """

    def __init__(self, num_samples, f):
        self.num_samples = num_samples
        self.f = f
        self.data = self.generate_data()

    def generate_data(self):
        """Return the ``(num_samples, f)`` tensor of samples, filled row by row."""
        samples = torch.zeros((self.num_samples, self.f))
        for row in range(self.num_samples):
            # Index first, then magnitude: keeps the global RNG consumption order fixed.
            active = torch.randint(0, self.f, (1,))
            samples[row, active] = torch.rand(1)
        return samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx]

class SyntheticNormalised(Dataset):
    """Deterministic dataset of the ``f`` one-hot basis vectors (the rows of an identity matrix)."""

    def __init__(self, f):
        self.f = f
        self.data = self.generate_data()

    def generate_data(self):
        """Return ``torch.eye(f)``; row ``i`` is the one-hot vector for feature ``i``."""
        identity_matrix = torch.eye(self.f)
        return identity_matrix

    def __len__(self):
        return self.f

    def __getitem__(self, idx):
        return self.data[idx]

class FCNet(nn.Module):
    """Autoencoder input_dim -> hidden_dim -> input_dim with optional square hidden layers (variant 1).

    NOTE(review): this definition is shadowed by the two later ``FCNet`` classes in
    this cell; only the last one is what the training cell instantiates.

    Args:
        input_dim: number of input/output features.
        hidden_dim: width of the embedding.
        tied: build the unembedding weight from the transposed embedding weight.
        final_bias: give the unembedding layer a bias.
        nonlinearity: applied after each hidden layer and after the unembedding.
        mlps: number of bias-free ``hidden_dim x hidden_dim`` hidden layers.
        is_resnet: add each hidden layer's input back to its activated output.
        mlp_dim: stored only; this variant does not use it.
    """
    def __init__(self, input_dim, hidden_dim, tied = True, final_bias = False, nonlinearity = F.relu, mlps = 0, is_resnet = False, mlp_dim = None):
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.nonlinearity = nonlinearity
        self.mlps = mlps
        self.tied = tied
        self.final_bias = final_bias
        self.is_resnet = int(is_resnet)  # 0/1 multiplier on the residual term
        self.mlp_dim = mlp_dim if mlp_dim is not None else hidden_dim

        self.embedding = nn.Linear(input_dim, hidden_dim, bias=False)
        if mlps > 0:
            self.hidden = nn.ModuleList(
                nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(mlps)
            )
        self.unembedding = nn.Linear(hidden_dim, input_dim, bias=final_bias)

        if tied:
            # New Parameter built from a transposed view of the embedding weight.
            self.unembedding.weight = torch.nn.Parameter(self.embedding.weight.transpose(0, 1))

    def forward(self, x):
        h = self.embedding(x)
        for block in (self.hidden if self.mlps > 0 else ()):
            h = self.nonlinearity(block(h)) + self.is_resnet * h
        return self.nonlinearity(self.unembedding(h))

class FCNet(nn.Module):
    """Autoencoder input_dim -> hidden_dim -> input_dim whose hidden layers are 2-layer MLP blocks (variant 2).

    NOTE(review): shadowed by the next ``FCNet`` definition in this cell.

    Each block is Linear(hidden->mlp_dim) -> nonlinearity -> Linear(mlp_dim->hidden),
    bias-free. In ``forward`` the nonlinearity is applied again to the block output,
    and the block input is added back when ``is_resnet`` is set.
    """

    class _Activation(nn.Module):
        """Adapter so a plain function (e.g. ``F.relu``) can be a child of ``nn.Sequential``."""
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

    def __init__(self, input_dim, hidden_dim, tied = True, final_bias = False, nonlinearity = F.relu, mlps = 0, is_resnet = False, mlp_dim = None):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.nonlinearity = nonlinearity
        self.mlps = mlps
        self.tied = tied
        self.final_bias = final_bias
        self.is_resnet = int(is_resnet)  # 0/1 multiplier on the residual term
        self.mlp_dim = hidden_dim if mlp_dim is None else mlp_dim

        # Define the input layer (embedding)
        self.embedding = nn.Linear(self.input_dim, self.hidden_dim, bias=False)

        # Add extra hidden layers
        if mlps > 0:
            self.hidden = nn.ModuleList([self.create_mlp_layer() for _ in range(mlps)])

        # Define the output layer (unembedding)
        self.unembedding = nn.Linear(self.hidden_dim, self.input_dim, bias=final_bias)

        # Tie the weights: new Parameter built from a transposed view of the embedding weight.
        if tied:
            self.unembedding.weight = torch.nn.Parameter(self.embedding.weight.transpose(0, 1))

    def create_mlp_layer(self):
        """Return one bias-free MLP block: Linear(hidden->mlp_dim), nonlinearity, Linear(mlp_dim->hidden)."""
        # nn.Sequential only accepts nn.Module children; wrap plain callables such as the default F.relu.
        if isinstance(self.nonlinearity, nn.Module):
            activation = self.nonlinearity
        else:
            activation = self._Activation(self.nonlinearity)
        return nn.Sequential(
            nn.Linear(self.hidden_dim, self.mlp_dim, bias=False),
            activation,
            nn.Linear(self.mlp_dim, self.hidden_dim, bias=False),
        )

    def forward(self, x):
        x = self.embedding(x)
        if self.mlps > 0:
            for layer in self.hidden:
                x = self.nonlinearity(layer(x)) + self.is_resnet * x
        x = self.unembedding(x)
        x = self.nonlinearity(x)
        return x

class FCNet(nn.Module):
    """Autoencoder input_dim -> hidden_dim -> input_dim with optional residual MLP blocks.

    Each of the ``mlps`` blocks computes ``layer_t(nonlinearity(layer(x)))`` with
    bias-free ``layer: hidden -> mlp_dim`` and ``layer_t: mlp_dim -> hidden``. With
    ``is_resnet`` the block *input* is added to that output. No nonlinearity is
    applied after the unembedding.

    Args:
        input_dim: number of input/output features.
        hidden_dim: width of the residual stream (embedding dimension).
        mlp_dim: width inside each MLP block; defaults to ``hidden_dim``.
        tied: build the unembedding weight from the transposed embedding weight.
        final_bias: give the unembedding layer a bias.
        nonlinearity: callable applied inside each MLP block.
        mlps: number of MLP blocks.
        is_resnet: add a residual connection around each block.
    """
    def __init__(self, input_dim, hidden_dim, mlp_dim = None, tied = True, final_bias = False, nonlinearity = torch.sigmoid, mlps = 0, is_resnet = False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # Same default as the earlier FCNet variants: MLP width equals the residual width.
        self.mlp_dim = hidden_dim if mlp_dim is None else mlp_dim
        self.nonlinearity = nonlinearity
        self.mlps = mlps
        self.tied = tied
        self.final_bias = final_bias
        self.is_resnet = int(is_resnet)  # 0/1 multiplier on the residual term

        # Define the input layer (embedding)
        self.embedding = nn.Linear(self.input_dim, self.hidden_dim, bias=False)

        # Add extra hidden layers: up-projection and down-projection per block
        if mlps > 0:
            self.hidden_layers = nn.ModuleList([nn.Linear(self.hidden_dim, self.mlp_dim, bias=False) for _ in range(mlps)])
            self.hidden_layers_t = nn.ModuleList([nn.Linear(self.mlp_dim, self.hidden_dim, bias=False) for _ in range(mlps)])

        # Define the output layer (unembedding)
        self.unembedding = nn.Linear(self.hidden_dim, self.input_dim, bias=final_bias)

        # Tie the weights.
        # NOTE(review): this registers a second Parameter built from a transposed view of the
        # embedding weight; confirm the two stay tied under the optimizer being used.
        if tied:
            self.unembedding.weight = torch.nn.Parameter(self.embedding.weight.transpose(0, 1))

    def forward(self, x):
        x = self.embedding(x)
        if self.mlps > 0:
            for layer, layer_t in zip(self.hidden_layers, self.hidden_layers_t):
                # Residual connection adds the block input, not the intermediate activation
                # (which has width mlp_dim and is not part of the residual stream).
                x = layer_t(self.nonlinearity(layer(x))) + self.is_resnet * x
        x = self.unembedding(x)
        return x



def plot_weights(weight_matrix, jitter = 0.05, normalised = False, save = False, epoch = None):
    """Draw each row of a weight matrix as a 2-D arrow from the origin, labelled with its row index.

    Args:
        weight_matrix: array or tensor; only columns 0 and 1 of each row are drawn.
        jitter: std of the random offset added to each label so overlapping labels separate.
        normalised: if True, rescale every non-zero row to unit length and fix axes to [-1.2, 1.2].
        save: if True, write the figure to ``weights_{epoch}.png``.
        epoch: epoch number used in the file name; required when ``save`` is True.

    The figure is closed at the end, so it is only visible through the saved file.
    """
    plt.figure(figsize=(8, 8))

    for i in range(weight_matrix.shape[0]):
        x, y = float(weight_matrix[i, 0]), float(weight_matrix[i, 1])
        length = (x ** 2 + y ** 2) ** 0.5
        # Rescale only when asked, and never divide a zero-length row by 0.
        scale = length if (normalised and length > 0) else 1.0
        dx, dy = x / scale, y / scale
        plt.arrow(0, 0, dx, dy, head_width=0.05, head_length=0.1, fc='blue', ec='blue')
        # Random label offsets keep the indices of nearly overlapping arrows readable.
        plt.text(dx + jitter * float(torch.randn(1)), dy + jitter * float(torch.randn(1)), f"{i}", color='red', fontsize=12)

    mins = -1.2 if normalised else float(weight_matrix.min()) - 0.5
    maxs = 1.2 if normalised else float(weight_matrix.max()) + 0.5
    plt.xlim(mins, maxs)
    plt.ylim(mins, maxs)
    plt.grid()
    if save:
        assert epoch is not None
        plt.savefig(f"weights_{epoch}.png")
    plt.close()

def identity(x):
    """Return ``x`` unchanged; usable as a no-op ``nonlinearity`` argument."""
    result = x
    return result

def scaled_sigmoid(scale=1, shift=0):
    """Build the activation ``x -> scale * (sigmoid(x) + shift)``.

    The shift is added *before* scaling, so for positive ``scale`` the output
    lies in ``(scale * shift, scale * (1 + shift))``.
    """
    def _scaled(x):
        shifted = torch.sigmoid(x) + shift
        return scale * shifted
    return _scaled

def normalise(matrix, tolerance = 1e-10):
    """Return a copy of ``matrix`` with every row scaled to unit Euclidean length.

    Args:
        matrix: 2-D array-like (a list of 1-D vectors is also accepted).
        tolerance: rows whose norm is <= tolerance are returned as all-zero rows
            instead of being divided by (nearly) zero.

    Returns:
        np.ndarray with the same shape. Floating inputs keep their dtype; any other
        dtype (e.g. integers) is promoted to float64 so unit rows are not truncated to 0.
    """
    rows = np.asarray(matrix)
    dtype = rows.dtype if np.issubdtype(rows.dtype, np.floating) else np.float64
    out = np.zeros(rows.shape, dtype=dtype)
    for i, row in enumerate(rows):
        norm = (row @ row) ** 0.5
        if norm > tolerance:
            out[i] = row / norm
    return out

def np_relu(matrix):
    """Element-wise ReLU for NumPy inputs: every negative entry becomes 0."""
    return np.maximum(0, matrix)

In [None]:

# ---- Run switches -----------------------------------------------------------------
using_wandb = False
logging_loss = True
new_run = True  # False: keep training the existing `model` and append to `weights_history`

# ---- Hyper-parameters ---------------------------------------------------------------
f = 50                      # number of features (input/output dimension)
n = 2                       # hidden (embedding) dimension
batch_size = f
epochs = 20000
learning_rate = 5
image_store_rate = max(1, epochs // 1000)  # snapshot weights every this many epochs (int -> "Epoch 20" labels)
plot_rate = max(1, epochs // 5)            # redraw the loss curve every this many epochs
nonlinearity = nn.GELU()
tied = False
final_bias = False
hidden_layers = 1           # number of MLP blocks in FCNet
mlp_dim = n                 # width of each MLP block (the earlier FCNet variants defaulted to hidden_dim)
is_resnet = True
seed = 0
lr_decay_steps = 100        # number of StepLR decays over the whole run
lr_total_decay = 3000       # final LR = learning_rate / lr_total_decay

# Start a new run
if using_wandb:
    run = wandb.init(project="tiny_superposition", entity="jake-mendel")
if logging_loss and new_run:
    if 'model' in globals():
        old_model = model  # type: ignore  # keep the previous model for comparison cells
    losses = []

if using_wandb:
    config = wandb.config
    config.f = f
    config.n = n
    config.batch_size = batch_size
    config.epochs = epochs
    config.learning_rate = learning_rate
    config.image_store_rate = image_store_rate

# Instantiate synthetic dataset (one one-hot vector per feature)
dataset = SyntheticNormalised(f)
loader = DataLoader(dataset, batch_size=batch_size, shuffle = True)

# Instantiate the model
if new_run:
    torch.manual_seed(seed)  # reproducible init and shuffling
    model = FCNet(f, n,
                  mlp_dim=mlp_dim,
                  tied=tied,
                  final_bias=final_bias,
                  nonlinearity=nonlinearity,
                  mlps=hidden_layers,
                  is_resnet=is_resnet)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Learning-rate schedule: `lr_decay_steps` equal multiplicative decays whose product is 1/lr_total_decay.
# (The previous gamma, 3000**(-100), underflows to exactly 0.0.)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=max(1, epochs // lr_decay_steps),
    gamma=lr_total_decay ** (-1 / lr_decay_steps),
)

# Watch the model
if using_wandb:
    wandb.watch(model)

def train(model, loader, criterion, optimizer, epochs, scheduler=None):
    """Train `model` to reconstruct its inputs; return snapshots of its parameters.

    Reads the notebook globals `using_wandb`, `logging_loss`, `losses`, `plot_rate`
    and `image_store_rate`.

    Args:
        scheduler: optional LR scheduler, stepped once per epoch.

    Returns:
        dict mapping parameter name -> list of numpy snapshots: the starting weights,
        then one copy every `image_store_rate` epochs.
    """
    weights_history = {k: [v.detach().numpy().copy()] for k, v in model.named_parameters()}
    for epoch in tqdm(range(epochs)):
        total_loss = 0
        for batch in loader:
            optimizer.zero_grad()
            outputs = model(batch)
            loss = criterion(outputs, batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if scheduler is not None:
            # Once per epoch, so StepLR's step_size is measured in epochs.
            scheduler.step()

        avg_loss = total_loss / len(loader)
        if using_wandb:
            wandb.log({"epoch": epoch, "loss": avg_loss})
        if logging_loss:
            losses.append(avg_loss)
            if (epoch + 1) % plot_rate == 0:
                plt.figure(figsize = (5,5))
                plt.plot(losses)
                plt.show()

        # Every image_store_rate epochs, save weights
        if (epoch + 1) % image_store_rate == 0:
            for k, v in model.named_parameters():
                weights_history[k].append(v.detach().numpy().copy())
    return weights_history

# Train the model
if new_run:
    weights_history = train(model, loader, criterion, optimizer, epochs, scheduler)
else:
    continued = train(model, loader, criterion, optimizer, epochs, scheduler)
    for name, snapshots in continued.items():
        # snapshots[0] is the resume point, i.e. the weights the previous run ended with;
        # skip it so that state is not stored twice.
        weights_history[name] += snapshots[1:]

# Close the wandb run
if using_wandb:
    run.finish()


In [None]:


def plot_weights_interactive(weights_history, x_dir = None, y_dir = None, transpose = False):
    """Scatter the rows of each weight snapshot, projected on two directions, with an epoch slider.

    Args:
        weights_history: sequence of 2-D arrays, one per stored epoch.
        x_dir, y_dir: projection directions, both of the same type. Either both None
            (first two coordinate axes), both ints (coordinate indices), or both
            vectors (list/tuple/array) whose length equals the plotted vector length.
        transpose: if True, plot the columns of each matrix instead of its rows.

    Slider labels assume snapshot i was taken at epoch ``i * image_store_rate``
    (notebook global).
    """
    fig = go.Figure()
    max_value = np.max([np.abs(weight_matrix).max() for weight_matrix in weights_history])

    # Length of each plotted vector: rows of W, or rows of W.T (= columns of W) when transposing.
    vec_len = weights_history[0].shape[0] if transpose else weights_history[0].shape[-1]

    assert type(x_dir) == type(y_dir)
    if x_dir is None:
        # Default: the first two coordinate axes.
        x_dir, y_dir = np.zeros(vec_len), np.zeros(vec_len)
        x_dir[0] = 1
        y_dir[1] = 1
    elif isinstance(x_dir, int):
        # Integer directions select coordinate axes by index.
        x_index, y_index = x_dir, y_dir
        x_dir, y_dir = np.zeros(vec_len), np.zeros(vec_len)
        x_dir[x_index] = 1
        y_dir[y_index] = 1
    else:
        x_dir, y_dir = np.asarray(x_dir), np.asarray(y_dir)
        # Check against the plotted vector length, which depends on `transpose`.
        assert len(x_dir) == vec_len and len(y_dir) == vec_len

    # One (initially hidden) scatter trace per snapshot.
    for i, weight_matrix in enumerate(weights_history):
        if transpose:
            weight_matrix = weight_matrix.T
        x_values = weight_matrix @ x_dir
        y_values = weight_matrix @ y_dir
        labels = list(range(len(x_values)))
        fig.add_trace(go.Scatter(x=x_values, y=y_values, mode='markers+text', text=labels,
                                 textposition='top center', marker=dict(size=8), visible=False,
                                 name=f'Epoch {i*image_store_rate}'))

    # Make the first scatter plot visible
    fig.data[0].visible = True

    fig.update_xaxes(title_text='X Value', range=[-max_value*1.1, max_value*1.1])
    fig.update_yaxes(title_text='Y Value', range=[-max_value*1.1, max_value*1.1])

    # Slider: step i shows only trace i.
    steps = []
    for i in range(len(fig.data)):
        visible = [False] * len(fig.data)
        visible[i] = True
        steps.append(dict(method='restyle', args=['visible', visible],
                          label=f'Epoch {i * image_store_rate}'))
    fig.update_layout(sliders=[dict(steps=steps, active=0)], width=800, height=800)

    fig.show()

# plot_weights_interactive(weights_history['unembedding.weight'])


In [None]:
# nn.Linear(f, n).weight is (n, f), so transposing plots one point per input feature.
embedding_snapshots = weights_history['embedding.weight']
plot_weights_interactive(embedding_snapshots, transpose=True)


In [None]:
# Shape of the first stored snapshot of every parameter.
[(param_name, history[0].shape) for param_name, history in weights_history.items()]

In [None]:
# `unembedding_history` is never defined; the snapshots live in `weights_history`.
# nn.Linear(n, f).weight is (f, n), so each row is already one feature's vector.
plot_weights_interactive(weights_history['unembedding.weight'])

In [None]:
def plot_weights_interactive_3d(weights_history, visualisation = 'points', numbering=False):
    """3-D view of the rows of each weight snapshot, with an epoch slider.

    Args:
        weights_history: sequence of 2-D arrays; row j is drawn at its first three
            coordinates. Matrices with fewer than three columns are zero-padded.
        visualisation: 'points' (one marker trace per epoch) or 'lines' (one line
            from the origin per row, per epoch).
        numbering: label each point with its row index (points mode).

    Raises:
        ValueError: for an unknown ``visualisation``.

    Slider labels use the notebook global ``image_store_rate``.
    """
    if visualisation not in ('points', 'lines'):
        raise ValueError(f"unknown visualisation {visualisation!r}")

    fig = go.Figure()
    max_value = np.max([np.abs(weight_matrix).max() for weight_matrix in weights_history])

    # Trace indices belonging to each epoch, so the slider can toggle them as a group.
    epoch_traces = []
    for i, weight_matrix in tqdm(enumerate(weights_history)):
        weight_matrix = np.asarray(weight_matrix)
        coords = np.zeros((weight_matrix.shape[0], 3))
        used = min(3, weight_matrix.shape[1])
        coords[:, :used] = weight_matrix[:, :used]
        x_values, y_values, z_values = coords[:, 0], coords[:, 1], coords[:, 2]
        labels = list(range(len(x_values))) if numbering else ['' for _ in range(len(x_values))]

        first_trace = len(fig.data)
        if visualisation == 'points':
            fig.add_trace(go.Scatter3d(x=x_values, y=y_values, z=z_values, mode='markers+text', text=labels,
                                       textposition='middle center', marker=dict(size=4), visible=False,
                                       name=f'Epoch {i*image_store_rate}'))
        else:
            for j in range(len(x_values)):
                fig.add_trace(go.Scatter3d(x=[0, x_values[j]], y=[0, y_values[j]], z=[0, z_values[j]],
                                           mode='lines', line=dict(width=2, color='blue'),
                                           showlegend=False, visible=False))
        epoch_traces.append(range(first_trace, len(fig.data)))

    # Show every trace of the first epoch.
    for t in epoch_traces[0]:
        fig.data[t].visible = True

    fig.update_layout(scene=dict(
        xaxis=dict(range=[-max_value, max_value], title='X Value'),
        yaxis=dict(range=[-max_value, max_value], title='Y Value'),
        zaxis=dict(range=[-max_value, max_value], title='Z Value')))

    # Slider: step i shows all traces of epoch i.
    steps = []
    for i, traces in enumerate(epoch_traces):
        visible = [False] * len(fig.data)
        for t in traces:
            visible[t] = True
        steps.append(dict(method='restyle', args=['visible', visible],
                          label=f'Epoch {i * image_store_rate}'))
    fig.update_layout(sliders=[dict(steps=steps, active=0)], width=800, height=800)

    fig.show()

# `embedding_history` is never defined: use the stored embedding snapshots, transposed so each
# row is one feature's embedding vector.
plot_weights_interactive_3d([w.T for w in weights_history['embedding.weight']], 'points', False)


In [None]:
import plotly.graph_objects as go
import numpy as np
from tqdm import tqdm
from itertools import cycle
import plotly.express as px

def generate_colors(n):
    """Return ``n`` colours from Plotly's qualitative palette, wrapping around when it runs out."""
    palette = px.colors.qualitative.Plotly
    return [palette[k % len(palette)] for k in range(n)]


def plot_multiple_3d_vectors(vectors_tuple, visualisation = 'points', numbering=False):
    """Plot several lists of 3-D vectors in one figure.

    Args:
        vectors_tuple: sequence of lists/arrays of vectors (components 0-2 are used).
        visualisation: 'points' draws one marker per vector, coloured by the list it
            belongs to; 'lines' draws a line from the origin to each vector. Any value
            other than 'points' colours by the vector's index within its list.
        numbering: label each marker with its index within its list.

    Zoom buttons move the camera towards the origin.
    """
    fig = go.Figure()

    # Calculate maximum value of all components
    max_value = np.max([np.abs(vector).max() for vectors_list in vectors_tuple for vector in vectors_list])

    # Colour per list (points mode) or per index within a list (other modes). Sized so every
    # list is drawn and no index runs past the end of the colour list.
    list_colors = generate_colors(len(vectors_tuple))
    index_colors = generate_colors(max((len(vectors_list) for vectors_list in vectors_tuple), default=0))

    for list_idx, vectors_list in enumerate(vectors_tuple):
        for i, vector in tqdm(enumerate(vectors_list)):
            color = list_colors[list_idx] if visualisation == 'points' else index_colors[i]
            x_value, y_value, z_value = vector[0], vector[1], vector[2]
            label = i if numbering else ''

            if visualisation == 'lines':
                # Line from the origin to the vector's tip
                trace = go.Scatter3d(x=[0, x_value], y=[0, y_value], z=[0, z_value], mode='lines',
                                     line=dict(width=2, color=color), showlegend=False)
            else:
                trace = go.Scatter3d(x=[x_value], y=[y_value], z=[z_value], mode='markers+text', text=[label],
                                     textposition='bottom center', marker=dict(size=4, color=color),
                                     showlegend=False)
            fig.add_trace(trace)

    fig.update_layout(scene=dict(
        xaxis=dict(range=[-max_value, max_value], title='X Value'),
        yaxis=dict(range=[-max_value, max_value], title='Y Value'),
        zaxis=dict(range=[-max_value, max_value], title='Z Value')))

    # Zoom buttons: move the camera eye towards the origin by the given factor.
    zoom_buttons = []
    for label, factor in [('Original', 1), ('2X', 2), ('4X', 4), ('8X', 8), ('16X', 16)]:
        zoom_buttons.append(dict(label=label, method='relayout',
                                 args=['scene.camera', dict(eye=dict(x=1/factor, y=1/factor, z=1/factor))]))
    updatemenus = [dict(type="buttons", showactive=False, buttons=zoom_buttons)]

    fig.update_layout(updatemenus=updatemenus, width=800, height=800, showlegend = False)

    fig.show()

def plot_multiple_2d_vectors(vectors_tuple, visualisation = 'points', numbering=False):
    """Plot several lists of 2-D vectors in one figure (2-D counterpart of the 3-D plot).

    Args:
        vectors_tuple: sequence of lists/arrays of vectors; only components 0 and 1 are used.
        visualisation: 'points' draws one marker per vector, coloured by the list it
            belongs to; 'lines' draws a line from the origin to each vector. Any value
            other than 'points' colours by the vector's index within its list.
        numbering: label each marker with its index within its list.

    Both axes share one scale, and the zoom buttons shrink the axis ranges around the origin.
    """
    fig = go.Figure()

    # Axis half-width from the two plotted components only, with a small margin.
    max_value = np.max([np.abs(np.asarray(vector)[:2]).max() for vectors_list in vectors_tuple for vector in vectors_list])
    half_width = max_value * 1.1

    list_colors = generate_colors(len(vectors_tuple))
    index_colors = generate_colors(max((len(vectors_list) for vectors_list in vectors_tuple), default=0))

    for list_idx, vectors_list in enumerate(vectors_tuple):
        for i, vector in tqdm(enumerate(vectors_list)):
            color = list_colors[list_idx] if visualisation == 'points' else index_colors[i]
            x_value, y_value = vector[0], vector[1]
            label = i if numbering else ''

            if visualisation == 'lines':
                trace = go.Scatter(x=[0, x_value], y=[0, y_value], mode='lines',
                                   line=dict(width=2, color=color), showlegend=False)
            else:
                trace = go.Scatter(x=[x_value], y=[y_value], mode='markers+text', text=[label],
                                   textposition='bottom center', marker=dict(size=8, color=color),
                                   showlegend=False)
            fig.add_trace(trace)

    # 2-D figures use xaxis/yaxis (the `scene` settings only apply to 3-D plots).
    fig.update_xaxes(title_text='X Value', range=[-half_width, half_width])
    # Equal scaling so angles between vectors are displayed faithfully.
    fig.update_yaxes(title_text='Y Value', range=[-half_width, half_width], scaleanchor='x', scaleratio=1)

    # Zoom buttons: shrink both axis ranges around the origin.
    zoom_buttons = []
    for label, factor in [('Original', 1), ('2X', 2), ('4X', 4), ('8X', 8), ('16X', 16)]:
        limit = half_width / factor
        zoom_buttons.append(dict(label=label, method='relayout',
                                 args=[{'xaxis.range': [-limit, limit], 'yaxis.range': [-limit, limit]}]))

    fig.update_layout(updatemenus=[dict(type="buttons", showactive=False, buttons=zoom_buttons)],
                      width=800, height=800, showlegend = False)

    fig.show()

# Call the function with tuple of lists of vectors
# plot_multiple_3d_vectors(([vector_1_1, vector_1_2], [vector_2_1, vector_2_2]), 'points', True)


In [None]:
# `embedding_history` / `unembedding_history` are never defined; read the final snapshots from
# `weights_history`. Both are arranged as (f, n): one row per feature.
unembedding_final = weights_history['unembedding.weight'][-1]
embedding_final = weights_history['embedding.weight'][-1].T
a = normalise(unembedding_final[:10])
b = normalise(embedding_final[:10])

plot_multiple_2d_vectors((a,b), 'points', True)

In [None]:
# Model output for every one-hot input: row i is the reconstruction of feature i.
with torch.no_grad():
    reconstruction = model(torch.eye(f)).numpy()
limit = np.abs(reconstruction).max() or 1.0
plt.imshow(reconstruction, 'RdBu', vmin=-limit, vmax=limit)  # symmetric limits put 0 at the colormap centre
plt.colorbar()
plt.show()


In [None]:
# Direct linear path W_U @ W_E, shape (f, f); it ignores the MLP blocks and any nonlinearity.
with torch.no_grad():
    mat = (model.unembedding.weight @ model.embedding.weight).numpy()
limit = np.abs(mat).max() or 1.0
plt.imshow(mat, 'RdBu', vmin=-limit, vmax=limit)  # symmetric limits put 0 at the colormap centre
plt.colorbar();

In [None]:
# `embedding_history` / `unembedding_history` are never defined; use the final snapshots.
embedding_final = weights_history['embedding.weight'][-1].T     # (f, n): one row per feature
unembedding_final = weights_history['unembedding.weight'][-1]   # (f, n): one row per feature
prerelu = embedding_final @ unembedding_final.T                 # (f, f)
relued = nonlinearity(torch.tensor(prerelu)).detach().numpy()
plt.imshow(prerelu, cmap = 'Blues_r')
plt.colorbar();

In [None]:
# `old_model` exists only if a `model` was already defined when the training cell started a new run.
old_model.embedding.weight if 'old_model' in globals() else None

In [None]:
if 'old_model' in globals():
    with torch.no_grad():
        # (f, n) @ (n, f) -> (f, f) product of the previous model's weights, then ReLU.
        relued = F.relu(old_model.embedding.weight.T @ old_model.unembedding.weight.T).numpy()
    plt.imshow(relued, cmap = 'Blues_r')
    plt.colorbar()
else:
    print("old_model is not defined: no earlier model existed when the training cell started a new run.")

In [None]:
# `embedding_history` / `unembedding_history` are never defined; use the final snapshots.
embedding_final = weights_history['embedding.weight'][-1].T     # (f, n): one row per feature
unembedding_final = weights_history['unembedding.weight'][-1]   # (f, n): one row per feature
prerelu = embedding_final.T @ unembedding_final                 # (n, n)
relued = np_relu(prerelu)
plt.imshow(prerelu, cmap = 'Blues_r')
plt.colorbar();

In [None]:
sigmoid_inputs = torch.linspace(-5, 5, 200)
# Plot against the input values (not the sample index) so the x-axis is meaningful.
plt.plot(sigmoid_inputs.numpy(), torch.sigmoid(sigmoid_inputs).numpy())
plt.xlabel('x')
plt.ylabel('sigmoid(x)');

In [None]:
def normalise(matrix, tolerance = 1e-10):
    """Return a copy of ``matrix`` with every row scaled to unit Euclidean length.

    Args:
        matrix: 2-D array-like (a list of 1-D vectors is also accepted).
        tolerance: rows whose norm is <= tolerance stay all-zero instead of
            being divided by (nearly) zero, which would produce NaNs.

    Returns:
        np.ndarray with the same shape. Floating inputs keep their dtype; any other
        dtype (e.g. integers) is promoted to float64 so unit rows are not truncated to 0.
    """
    rows = np.asarray(matrix)
    dtype = rows.dtype if np.issubdtype(rows.dtype, np.floating) else np.float64
    out = np.zeros(rows.shape, dtype=dtype)
    for i, row in enumerate(rows):
        norm = (row @ row) ** 0.5
        if norm > tolerance:
            out[i] = row / norm
    return out

def np_relu(matrix):
    """Rectify ``matrix`` element-wise (NumPy ReLU): negative entries are replaced by 0."""
    return np.maximum(0, matrix)

In [None]:
# `weights_history` is a dict keyed by parameter name, so `weights_history[-1]` raised KeyError.
# NOTE(review): assuming the per-feature embedding vectors (rows of W_E.T) are the intended matrix.
feature_vectors = weights_history['embedding.weight'][-1].T
unit_rows = normalise(feature_vectors)
cosines = unit_rows @ unit_rows.T
np.fill_diagonal(cosines, -np.inf)   # exclude each vector's product with itself
max_prods = cosines.max(axis=1)      # largest cosine similarity to any other feature

plt.figure(figsize = (20,10))
plt.hist(max_prods, bins = 1000, log = True)
plt.show()

In [None]:
# `embedding_history` is never defined; use the final embedding snapshot, one row per feature.
feature_vectors = weights_history['embedding.weight'][-1].T
unit_rows = normalise(feature_vectors)
# Rounding can push |cos| slightly above 1, where arccos returns NaN.
cosines = np.clip(unit_rows @ unit_rows.T, -1.0, 1.0)
angles = np.degrees(np.arccos(cosines))
# Pairwise angles between distinct features (self-angles are trivially 0 degrees).
prods = angles[~np.eye(len(angles), dtype=bool)]
plt.figure(figsize = (20,10))
plt.hist(prods, bins = 100)
plt.show()

In [None]:
# `embedding_history` / `unembedding_history` are never defined; use the final snapshots.
a = normalise(weights_history['unembedding.weight'][-1])    # (f, n) unit rows
b = normalise(weights_history['embedding.weight'][-1].T)    # (f, n) unit rows
plt.imshow(np_relu(a@b.T), 'Blues_r')                       # (f, f) rectified cosine similarities
plt.colorbar();

In [None]:
def group_vectors(vectors, epsilon):
    """Greedily cluster vectors that point in (almost) the same direction.

    Vectors with norm < 0.01 are ignored. Each remaining vector joins the first
    existing group whose representative (the unit vector of the group's first
    member) has cosine similarity within ``epsilon`` of 1; otherwise it starts a
    new group.

    Args:
        vectors: iterable of 1-D arrays (e.g. the rows of a 2-D array).
        epsilon: tolerance on ``|cos - 1|`` for joining a group.

    Returns:
        (groups, directions): ``groups[i]`` lists the original vectors of group i;
        ``directions[i]`` is the mean of their unit vectors (not re-normalised, so
        its length is at most 1).
    """
    groups = []
    unit_members = []      # unit vectors of each group's members, parallel to `groups`
    representatives = []   # unit vector of each group's first member

    for v in vectors:
        length = np.linalg.norm(v)
        # Check the length before normalising: a zero vector would otherwise be divided by 0.
        if length < 0.01:
            continue
        v_unit = v / length

        for i, rep in enumerate(representatives):
            if np.abs(np.dot(v_unit, rep) - 1) < epsilon:
                groups[i].append(v)
                unit_members[i].append(v_unit)
                break
        else:
            # No sufficiently aligned group: start a new one.
            groups.append([v])
            unit_members.append([v_unit])
            representatives.append(v_unit)

    directions = [np.mean(np.array(members), axis=0) for members in unit_members]
    return groups, directions


In [None]:
# `weights_history` is a dict, so `weights_history[-1]` raised KeyError.
# NOTE(review): assuming the per-feature embedding vectors (rows of W_E.T) are what should be grouped.
groups,directions = group_vectors(weights_history['embedding.weight'][-1].T, 0.05)
print(len(directions))
print([len(group) for group in groups])

In [None]:
d_matrix = np.array(directions)  # one row per group: the group's mean unit vector
d_matrix.mean(axis=0)

In [None]:


def visualize_3d_vectors(matrix):
    """
    Draw each row of ``matrix`` as a line from the origin in a 3-D plot.

    Args:
    matrix (np.array): array whose rows are vectors. Rows with only two
        components are drawn in the z = 0 plane; components beyond the third
        are ignored.
    """
    fig = go.Figure()

    for vector in matrix:
        x_tip, y_tip = vector[0], vector[1]
        # 2-D vectors (e.g. a 2-dimensional hidden layer) lie in the z = 0 plane.
        z_tip = vector[2] if len(vector) > 2 else 0
        fig.add_trace(go.Scatter3d(
            x=[0, x_tip],
            y=[0, y_tip],
            z=[0, z_tip],
            mode='lines',
        ))

    fig.update_layout(scene = dict(
                    xaxis_title='X',
                    yaxis_title='Y',
                    zaxis_title='Z'),
                    width=700,
                    margin=dict(r=20, l=10, b=10, t=10))
    fig.show()


In [None]:
# Singular values of the group-direction matrix (numpy returns V^H as the third factor).
U, S, Vh = np.linalg.svd(d_matrix)
S

In [None]:
# Each group's mean direction drawn as a line from the origin.
visualize_3d_vectors(matrix=d_matrix)

In [None]:
def dimensionality(directions):
    """For each direction w_i return ``1 / sum_j (w_i . w_j)^2`` (the sum includes j = i).

    Args:
        directions: sequence of 1-D vectors.

    Returns:
        list with one value per direction, in input order.
    """
    def _inverse_interference(w_i):
        return 1 / sum(np.dot(w_i, w_j) ** 2 for w_j in directions)

    return [_inverse_interference(w) for w in directions]

In [None]:
plt.scatter([len(g) for g in groups], dimensionality(directions))
plt.xlabel('group size')
plt.ylabel('1 / sum_j (d_i . d_j)^2')
plt.title('Per-group dimensionality vs. group size');

In [None]:
dims_sorted = sorted(dimensionality(directions))
plt.plot(dims_sorted)

In [None]:
# Mean unit vector of each group, computed once instead of once per (i, j) pair.
group_means = [np.mean(normalise(group), axis=0) for group in groups]
# Row i: rounded dot products of group i's mean direction with every other group's.
a = np.array([[round(group_means[i].T @ group_means[j], 4) for j in range(len(groups)) if j != i] for i in range(len(groups))])
sorted(a.flatten(), reverse=True)[:10]

In [None]:
offdiag_products = a.flatten()
plt.hist(offdiag_products, bins = 100)
plt.show()

In [None]:
dims_ascending = sorted(dimensionality(directions))
dims_ascending

In [None]:
# One row per group: [dimensionality, group size].
np.column_stack([dimensionality(directions), [len(g) for g in groups]])


In [None]:
# For each one-hot input, the index of the largest output unit; perfect recovery is the diagonal.
with torch.no_grad():
    predicted = model(torch.eye(f)).argmax(dim=1)  # same rows as SyntheticNormalised(f), in one batch
plt.plot(predicted.numpy(), label='argmax of reconstruction')
plt.plot(list(range(f)), '--', label='input index')
plt.legend();