In [None]:
import numpy as np
from scipy.optimize import minimize
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from IPython.display import clear_output
import time
import plotly.graph_objs as goa
import matplotlib as mpl
import torch.nn.functional as F
import random
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import copy
from itertools import combinations
from torch.optim.lr_scheduler import _LRScheduler
from scipy.spatial import ConvexHull
import os
import sys
sys.path.append("/Users/jakemendel/Desktop/Code/FeatureFinding") 
from FeatureFinding import utils, datasets, models
from torch.autograd import grad
from typing import Dict, List, Optional, Union, Any, Callable
import importlib
# Global matplotlib look-and-feel applied to every figure in this notebook.
mpl.style.use('seaborn-v0_8')
fontsize = 20
mpl.rcParams.update({
    'figure.figsize': (15, 10),
    # One shared size for body text, tick labels, legends, titles and axis labels.
    **{key: fontsize for key in ('font.size', 'xtick.labelsize', 'ytick.labelsize',
                                 'legend.fontsize', 'axes.titlesize', 'axes.labelsize')},
})

In [None]:
# Training cell: set the seed, configure hyperparameters, build the data / model /
# optimizer / LR scheduler, and train with utils.train.
# NOTE(review): statement order matters here — the seed is set before the dataset and the
# model are constructed, so reordering statements could change the random initialisation.
importlib.reload(utils)
# your chosen seed (utils.set_seed presumably seeds every RNG in use — confirm in utils)
chosen_seed = 2
utils.set_seed(chosen_seed)

#Checking for errors (passed to utils.train; 0 presumably disables LR printing)
lr_print_rate = 0


# Configure the hyperparameters
# f: number of input features (the model is later evaluated on torch.eye(f));
# n: hidden/embedding dimension passed to models.Net(f, n, ...);
# k: second argument of SyntheticKHot (presumably active features per example — confirm).
f = 40
k = 1
n = 2
MSE = True #else Crossentropy (selects `criterion` below)
# NOTE(review): presumably ReLU(x + 1); later cells add 1 to unembedding_bias to match — confirm.
nonlinearity = utils.relu_plusone
tied = True
final_bias = True
hidden_bias = False
unit_weights = False
learnable_scale_factor = False
initial_scale_factor = 1# (1/(1-np.cos(2*np.pi/f)))**0.5
standard_magnitude = False
initial_embed = None
initial_bias = None


# NOTE(review): only `epochs` are trained below, but decay_factor and warmup_steps are
# computed over `total_epochs`, so the LR decays over a much longer horizon than this run
# and will not reach final_lr — confirm this is intended.
epochs = 40000
total_epochs = 300000
logging_loss = True

#Scheduler params
max_lr = 5
initial_lr = 0.001
warmup_frac = 0.05
final_lr = 2
# Per-step factor such that max_lr * decay_factor**(total_epochs*(1-warmup_frac)) == final_lr.
decay_factor=(final_lr/max_lr)**(1/(total_epochs * (1-warmup_frac)))
warmup_steps = int(total_epochs * warmup_frac)


# Interval passed to utils.train for storing weights/model snapshots (100 for epochs=40000).
store_rate = epochs//400
plot_rate=0 #epochs/5


# Instantiate synthetic dataset
dataset = datasets.SyntheticKHot(f,k)
# Half the dataset per batch: the loader yields at least two batches per pass, so this is
# mini-batch SGD, not full-batch gradient descent (use len(dataset) for full batch).
batch_size = len(dataset)//2
loader = DataLoader(dataset, batch_size=batch_size, shuffle = True, num_workers=0)

#Define the Loss function
criterion = nn.MSELoss() if MSE else nn.CrossEntropyLoss() 

# Instantiate the model
# initial_embed = torch.tensor(np.array([1/(1-np.cos(2*np.pi/f))**0.5*np.array([np.cos(2*np.pi*i/f),np.sin(2*np.pi*i/f)]) for i in range(f)]),dtype=torch.float32).T * 0.5
# initial_bias = -torch.ones(f)*(1/(1-np.cos(2*np.pi/f))- 1)*0.25
model = models.Net(f, n,
            tied = tied,
            final_bias = final_bias,
            hidden_bias = hidden_bias,
            nonlinearity=nonlinearity,
            unit_weights=unit_weights,
            learnable_scale_factor=learnable_scale_factor,
            standard_magnitude=standard_magnitude,
            initial_scale_factor = initial_scale_factor,
            initial_embed = initial_embed,
            initial_bias = initial_bias)

# Define the optimizer (plain SGD starting at initial_lr; the scheduler below drives the LR)
optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr)

#Define a learning rate schedule (warmup to max_lr, then geometric decay by decay_factor)
scheduler = utils.CustomScheduler(optimizer, warmup_steps, max_lr, decay_factor)


# Train the model (in float64). Returns per-step losses plus weight/model snapshots
# keyed by epoch, stored every `store_rate` epochs.
losses, weights_history, model_history = utils.train(model,
                                                     loader,criterion,
                                                     optimizer,
                                                     epochs,
                                                     logging_loss,
                                                     plot_rate, 
                                                     store_rate,
                                                     scheduler,
                                                     lr_print_rate,
                                                     dtype = torch.float64)

In [None]:
# Training loss from epoch 1000 onward, plotted against its true epoch index
# (plotting the bare slice would relabel epoch 1000 as x = 0).
loss_plot_start = 1000
plt.plot(range(loss_plot_start, len(losses)), losses[loss_plot_start:])
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()

In [None]:
# Interactive view of the weight snapshots recorded by utils.train at `store_rate`.
utils.plot_weights_interactive(
    weights_history,
    store_rate=store_rate,
)

In [None]:

# Loss keyed by epoch from epoch 1000 on, alongside static weight snapshots at selected
# epochs. NOTE(review): to_label / num_across presumably select the annotated feature
# indices and panels per row — see utils.plot_weights_static.
loss_dict = {epoch: losses[epoch] for epoch in range(1000, len(losses))}
utils.plot_weights_static(
    weights_history,
    loss_dict,
    store_rate=store_rate,
    epochs_to_show=[1000, 17000, 6000, 8000, 13500, 16400, 25000, 16700, 39900],
    to_label=[12, 24, 31],
    scale=1.6,
    num_across=3,
)

In [None]:
# Dot product of the embedding vectors (columns) of features 10 and 2 at epoch 27000.
past_model = model_history[27000]
embed_matrix = past_model.embedding.weight.detach().numpy()
np.dot(embed_matrix[:, 10], embed_matrix[:, 2])

In [None]:
# Evaluate every stored snapshot on all one-hot inputs (the rows of torch.eye(f)).
# NOTE: nn.Module.float() casts parameters in place and returns the same module, so
# model_history32 shares its models with model_history (both are float32 afterwards).
model_history32 = {epoch: snapshot.float() for epoch, snapshot in model_history.items()}
one_hot_inputs = torch.eye(f)
post_relu, post_softmax, pre_relu = [], [], []
# Loop variable is deliberately not `model`, so the trained global `model` stays bound.
for snapshot in model_history32.values():
    outputs = snapshot(one_hot_inputs)  # one forward pass serves both output views
    post_relu.append(outputs.cpu().detach().numpy())
    post_softmax.append(outputs.softmax(dim=1).cpu().detach().numpy())
    _, activations = snapshot(one_hot_inputs, hooked=True)
    pre_relu.append(activations['unembed_pre'].cpu().detach().numpy())
if not MSE:
    utils.visualize_matrices_with_slider(post_softmax, store_rate, const_colorbar=True)
utils.visualize_matrices_with_slider(list(post_relu), store_rate, const_colorbar=True, plot_size = 800)
utils.visualize_matrices_with_slider(pre_relu, store_rate, const_colorbar=True)

In [None]:
importlib.reload(utils)
# Per-snapshot histograms over the "live" features (embedding columns with norm above the
# threshold): squared norm r^2, shifted bias b = unembedding_bias + 1, and r^2 + b.
# NOTE(review): the +1 presumably matches the relu_plusone nonlinearity — confirm.
live_norm_threshold = 0.1
r2_data = {}
b_data = {}
r2_b_data = {}
for epoch, snapshot in model_history.items():  # not `model`: keep the trained model bound
    embed = snapshot.embedding.weight
    live = torch.norm(embed, dim=0) > live_norm_threshold  # one entry per feature column
    weights = embed.T[live]
    b = snapshot.unembedding_bias[live] + 1
    r2 = torch.norm(weights, dim=1)**2
    r2_data[epoch] = r2.detach().numpy()
    b_data[epoch] = b.detach().numpy()
    r2_b_data[epoch] = (r2 + b).detach().numpy()

utils.plot_histograms(r2_data, bins=40)
utils.plot_histograms(b_data, bins=40)
utils.plot_histograms(r2_b_data, bins=40)



In [None]:
# Mean of r^2 + b over the live features at each stored epoch.
mean_r2_plus_b = [np.mean(vals) for vals in r2_b_data.values()]
plt.plot(r2_b_data.keys(), mean_r2_plus_b, label = '$r^2+b$')
# Optional extra curves (r2_data already holds squared norms, so its plain mean is <r^2>):
# plt.plot(r2_b_data.keys(), [np.mean(r2s - bs) for r2s, bs in zip(r2_data.values(), b_data.values())], label='$r^2-b$')
# plt.plot(r2_data.keys(), [np.mean(nums) for nums in r2_data.values()], label='$r^2$')
# plt.plot(b_data.keys(), [np.mean(nums) for nums in b_data.values()], label='$b$')
plt.xlabel('Epoch')
plt.ylabel('Value')
plt.legend()
plt.show()

In [None]:
importlib.reload(utils)
# Every 500th-epoch snapshot cast to float32. nn.Module.to() casts in place and returns the
# module itself, so this also converts those entries of model_history.
reduced_model_history = {
    snapshot_epoch: snapshot.to(dtype=torch.float)
    for snapshot_epoch, snapshot in model_history.items()
    if not snapshot_epoch % 500
}
final_layer_dict = utils.final_layer_dict(reduced_model_history)
utils.visualise_activation_space_history(final_layer_dict, n_points = 100)


In [None]:
importlib.reload(utils)
# For each snapshot, find the live feature (norm > 0.1, not in `ignore`) with the largest
#   interference(i) = W_dot_with . W_i + max(b_dot_with, b_i),  b = unembedding_bias + 1.
# interference_data[epoch] = (max interference, arg-max row); (0, None) if none is positive.
interference_data = {}
dot_with = 12
ignore = [12,24,31]
for epoch, snapshot in tqdm(model_history.items()):  # not `model`: keep the trained model bound
    weights = snapshot.embedding.weight.T.detach().numpy()
    b = snapshot.unembedding_bias.detach().numpy()+1
    special_b = b[dot_with]
    special_w = weights[dot_with]
    interference = 0
    interfering_row = None
    for i, row in enumerate(weights):
        if i in ignore or not np.linalg.norm(row) > 0.1:
            continue
        new_interference = np.dot(special_w, row) + max(special_b, b[i])
        if new_interference > interference:
            interference = new_interference
            interfering_row = i
    interference_data[epoch] = interference, interfering_row

# Group in one pass: one curve per feature that was ever the arg-max, holding the epochs
# at which it was the worst offender.
series = {}
for epoch, (value, row) in interference_data.items():
    epochs_for_row, values_for_row = series.setdefault(row, ([], []))
    epochs_for_row.append(epoch)
    values_for_row.append(value)
plt.figure()
for row, (xs, ys) in series.items():
    plt.semilogy(xs, ys, label=f'Interference with {row}')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Max Interference')
plt.show()

In [None]:
# Max-interference curves in a fresh figure: one per feature index that was ever the
# arg-max interferer, over the epochs at which it was the arg-max.
responsible_indices = set(datum[1] for datum in interference_data.values())
plt.figure()
for index in responsible_indices:
    x = [epoch for epoch, datum in interference_data.items() if datum[1] == index]
    y = [datum[0] for datum in interference_data.values() if datum[1] == index]
    plt.semilogy(x, y, label = f'Interference with {index}')
plt.legend()
plt.ylabel('Max Interference')
plt.show()

In [None]:
# Per-feature interference with feature `dot_with` over training:
#   W_{dot_with} . W_i + b_i,  with b = unembedding_bias + 1 (as in the max-interference cell).
# interference_data only keeps (max value, arg-max row) per epoch, so enumerating those
# tuples would plot the max value and the row index as two bogus "features". The
# per-feature series is therefore recomputed from the snapshots. Features in `ignore` are skipped.
sequences = {}
for epoch, snapshot in model_history.items():
    feature_vecs = snapshot.embedding.weight.T.detach().numpy()
    shifted_bias = snapshot.unembedding_bias.detach().numpy() + 1
    per_feature = feature_vecs @ feature_vecs[dot_with] + shifted_bias
    for i, value in enumerate(per_feature):
        if i in ignore:
            continue
        entry = sequences.setdefault(i, {'epochs': [], 'vals': []})
        entry['epochs'].append(epoch)
        entry['vals'].append(value)
plt.figure()
for i, data in sequences.items():
    # Raw f-string with doubled braces keeps LaTeX subscript groups, so multi-digit indices
    # render as W_{12} rather than W_1 followed by 2.
    plt.plot(data['epochs'], data['vals'], label=rf'$W_{{{dot_with}}}\cdot W_{{{i}}} + b_{{{i}}}$')
plt.xlabel('Epoch')
plt.ylabel('Interference')
plt.legend(ncol=4, fontsize='small')
plt.show()


In [None]:
# NOTE(review): `c` is not defined anywhere in this notebook, so this scratch check raises
# NameError on a fresh kernel — define `c` (presumably an iterable of arrays) or delete the cell.
[3, 4] in [vec.tolist() for vec in c]

In [None]:
# Group the live feature vectors (embedding columns with norm > 0.1) at epoch 10000.
# The second argument (0.001) is presumably a grouping tolerance — see utils.group_vectors.
embed_10k = model_history[10000].embedding.weight
live_vectors_10k = embed_10k.T[torch.norm(embed_10k, dim=0) > 0.1].detach().numpy()
groups, directions = utils.group_vectors(live_vectors_10k, 0.001)

In [None]:
# Ratio sum(|g|) / sum(|g|^2) over the group sizes |g| of `groups`.
va = np.fromiter((len(group) for group in groups), dtype=int)
va.sum() / np.square(va).sum()

In [None]:
# `directions` stacked on top of itself, then transposed (one direction per column).
direction_tensor = torch.tensor(np.array(directions))
torch.cat([direction_tensor, direction_tensor]).T

In [None]:
# Pairwise angles (converted from radians to degrees) of the self-stacked `directions`.
stacked_directions = torch.cat([torch.tensor(np.array(directions))] * 2).T
utils.calculate_angles(stacked_directions) * 180/np.pi

In [None]:
# Live feature vectors (embedding columns with norm > 0.1) at epoch 10000, one per row.
# The mask has one entry per feature column, so it must index the transposed weight;
# `weight[mask]` would apply it to the rows (hidden dims) and raise an IndexError.
embed_10k = model_history[10000].embedding.weight
embed_10k.T[torch.norm(embed_10k, dim=0) > 0.1]

In [None]:
# Shifted unembedding bias (bias + 1) of the most recently stored snapshot.
[*model_history.values()][-1].unembedding_bias + 1

In [None]:
# Per-feature r^2 + b at epoch 10000 (squared column norm plus unembedding bias + 1).
snapshot_10k = model_history[10000]
torch.norm(snapshot_10k.embedding.weight, dim=0)**2 + snapshot_10k.unembedding_bias + 1

In [None]:
# Trace of each snapshot's output on the one-hot inputs, skipping the first 10 snapshots.
skipped_snapshots = 10
plt.plot(list(model_history.keys())[skipped_snapshots:],
         [np.trace(matrix) for matrix in post_relu[skipped_snapshots:]])

In [None]:
importlib.reload(utils)
# Euclidean distance between the embedding vectors (columns) of features 31 and 12 in
# each stored embedding matrix, on a log scale over training.
feature_a, feature_b = 31, 12
differences = np.array([
    np.linalg.norm(embed[:, feature_a] - embed[:, feature_b])
    for embed in weights_history['embedding.weight']
])
plt.semilogy(model_history.keys(), differences)
plt.xlabel('Epoch')
plt.ylabel('Distance between feature vectors')

In [None]:
importlib.reload(utils)
# Keep the last batch the loader yields, without clobbering the global `b` that other cells
# use for biases. With batch_size = len(dataset)//2 this is only part of the dataset, so
# these are Hessians of a mini-batch loss rather than the full-data loss.
for batch in loader:
    pass
assert isinstance(batch, torch.Tensor)
batch32 = batch.float()
# nn.Module.float() casts in place and returns the same module, so these are the
# model_history snapshots themselves, now float32.
model_history32 = {epoch: snapshot.float() for epoch, snapshot in model_history.items()}

# Use the training criterion so the Hessian is of the loss the model was trained on
# (identical to nn.MSELoss() while MSE is True).
hessians_dict, eigenvalues_dict = utils.calculate_hessians(model_history32, batch32, batch32, criterion)


In [None]:
# Histogram of the Hessian eigenvalues in eigenvalues_dict (200 is presumably the bin count).
# NOTE(review): the label "not including embedding.weight" can only describe how
# calculate_hessians was configured when this was run; the call itself excludes nothing
# and is identical to the following cell's.
eigen_hist_bins = 200
utils.hist_eigenvalues(eigenvalues_dict, eigen_hist_bins)

In [None]:
# NOTE(review): labelled "not including unembedding.weight", but this plots the same
# eigenvalues_dict with the same arguments as the preceding cell; the label must refer to
# a different calculate_hessians configuration — re-run that with the intended exclusion.
utils.hist_eigenvalues(
    eigenvalues_dict,
    200,
)

In [None]:
epsilons = [0.001,0.003,0.01,0.03]

# One symmetric window (-eps, eps) around zero per epsilon. `ranges=[()]` would pass a
# single empty range and leave `epsilons` unused.
# NOTE(review): assumes plot_eigenvalues_in_range expects (low, high) tuples — confirm in utils.
utils.plot_eigenvalues_in_range(eigenvalues_dict, ranges=[(-eps, eps) for eps in epsilons])

In [None]:
# Gradient of the training loss on `batch` w.r.t. every parameter of `model`, built with
# create_graph=True so it can itself be differentiated (e.g. for second-order quantities).
loss_func = criterion  # the loss used in training; identical to nn.MSELoss() while MSE is True
inputs = batch.float()
targets = batch.float()
output = model.float()(inputs)  # NOTE: .float() casts `model` to float32 in place
loss = loss_func(output, targets)

# Compute the gradient of the loss with respect to the model parameters
per_param_grads = grad(loss, model.parameters(), create_graph=True)
print([g.shape for g in per_param_grads])
# Flatten into a single vector, in model.parameters() order.
grad_params = torch.cat([g.reshape(-1) for g in per_param_grads])


In [None]:
def _mask_live_unembedding(snapshot, threshold=0.05):
    """Return a deep copy of `snapshot` whose unembedding rows for live features are +inf.

    A feature is "live" when its embedding column norm exceeds `threshold`. The original
    snapshot is left untouched. The write happens under torch.no_grad() because in-place
    assignment into a parameter that requires grad raises a RuntimeError.
    NOTE(review): assumes the model has an `unembedding` layer whose weight rows are
    features, as the original indexing did — confirm for tied models.
    """
    masked = copy.deepcopy(snapshot)
    with torch.no_grad():
        live = torch.norm(masked.embedding.weight, dim=0) > threshold
        masked.unembedding.weight[live] = np.inf
    return masked


# .items() is required: iterating a dict yields only its keys, which cannot be unpacked.
# NOTE(review): the mask was previously written into the global `model` (leaving the copies
# unchanged); it is applied to each copy here — confirm this is the intent.
zeros_only_model_history = {
    epoch: _mask_live_unembedding(snapshot) for epoch, snapshot in model_history.items()
}

In [None]:
# Name -> Parameter mapping of the current `model`, for inspection.
dict(model.named_parameters())

In [None]:
# Hand-built solution for n_features one-hot features in 3 dimensions. Features sit on a
# circle of radius r, and the bias b is folded into a third coordinate (+sqrt(-b) in the
# embedding, -sqrt(-b) in the unembedding), so (W_U @ W_E.T)[i, j] = r^2 cos(θ_i - θ_j) + b.
# With r^2 = 1 / (1 - cos(2π/n_features)) and b = 1 - r^2, the diagonal is 1, adjacent
# features give 0 and all others are negative, so ReLU(W_U @ W_E.T) is the identity.
# Requires n_features >= 4 (so that b <= 0). A separate name is used so the hidden-dimension
# hyperparameter `n` is not overwritten.
n_features = 20
theta = np.arange(n_features) * 2*np.pi/n_features
r = np.sqrt(1/(1-np.cos(2*np.pi/n_features)))
b = 1-r**2
circle = np.stack([r*np.cos(theta), r*np.sin(theta)], axis=1)
lift = np.full((n_features, 1), np.sqrt(-b))
W_E = torch.tensor(np.hstack([circle, lift]), dtype=torch.float32)
W_U = torch.tensor(np.hstack([circle, -lift]), dtype=torch.float32)
assert torch.allclose(F.relu(W_U @ W_E.T), torch.eye(n_features), atol=1e-4)
utils.plot_weights_interactive({'Embed': [W_E.T], 'Unembed': [W_U.T]})
utils.visualize_matrices_with_slider([F.relu(W_U @ W_E.T)],1)