In [None]:
import numpy as np
from scipy.optimize import minimize
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from IPython.display import clear_output
import time
import plotly.graph_objs as goa
import matplotlib as mpl
import torch.nn.functional as F
import random
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import copy
from itertools import combinations
from torch.optim.lr_scheduler import _LRScheduler
from fancy_einsum import einsum
from scipy.spatial import ConvexHull
import os
from typing import Dict, List, Optional, Union, Any, Callable
from torch.autograd import grad
import sys
sys.path.append("/Users/jakemendel/Desktop/Code/FeatureFinding") 
from FeatureFinding import utils, datasets, models
# Global matplotlib styling shared by every figure in this notebook:
# the seaborn-v0_8 theme, a large default canvas and one common font size for all text.
mpl.style.use('seaborn-v0_8')
fontsize = 20
mpl.rcParams.update({
    'figure.figsize': (15, 10),
    'font.size': fontsize,
    'xtick.labelsize': fontsize,
    'ytick.labelsize': fontsize,
    'legend.fontsize': fontsize,
    'axes.titlesize': fontsize,
    'axes.labelsize': fontsize,
})

In [None]:
# ---- Run 1: untied ReLU model, f one-hot features squeezed through n hidden dims ----
# Seed before anything that may draw random numbers (dataset, weight init).
chosen_seed = 12
utils.set_seed(chosen_seed)

# Forwarded to utils.train as its LR-print interval (0 here).
lr_print_rate = 0

# Problem size: f features, k active features per sample, n hidden dimensions.
f, k, n = 40, 1, 2

# True -> MSE loss, False -> cross-entropy loss.
MSE = True

# Architecture switches forwarded to utils.Net.
nonlinearity = F.relu
tied = False
final_bias = True
hidden_bias = False
unit_weights = False
learnable_scale_factor = False
initial_scale_factor = 1  # alternative tried: (1/(1-np.cos(2*np.pi/f)))**0.5
standard_magnitude = False
initial_embed = None
initial_bias = None

# Training length and loss logging.
epochs = 250000
logging_loss = True

# LR schedule parameters for utils.CustomScheduler. decay_factor is chosen so that
# max_lr * decay_factor ** decay_steps == final_lr, where decay_steps is the
# post-warmup share of `epochs`.
max_lr = 5
initial_lr = 0.001
warmup_frac = 0.05
final_lr = 2
decay_steps = epochs * (1 - warmup_frac)
decay_factor = (final_lr / max_lr) ** (1 / decay_steps)
warmup_steps = int(epochs * warmup_frac)

# Intervals forwarded to utils.train (plot_rate 0 here; epochs/5 was an alternative).
store_rate = epochs // 100
plot_rate = 0

# Full-batch gradient descent: a single batch holds the whole synthetic dataset.
dataset = utils.SyntheticKHot(f, k)
batch_size = len(dataset)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

criterion = nn.CrossEntropyLoss() if not MSE else nn.MSELoss()

# Alternative hand-crafted initialisation (regular f-gon embedding), kept for reference:
# initial_embed = torch.tensor(np.array([1/(1-np.cos(2*np.pi/f))**0.5*np.array([np.cos(2*np.pi*i/f),np.sin(2*np.pi*i/f)]) for i in range(f)]),dtype=torch.float32).T * 0.5
# initial_bias = -torch.ones(f)*(1/(1-np.cos(2*np.pi/f))- 1)*0.25
net_options = dict(
    tied=tied,
    final_bias=final_bias,
    hidden_bias=hidden_bias,
    nonlinearity=nonlinearity,
    unit_weights=unit_weights,
    learnable_scale_factor=learnable_scale_factor,
    standard_magnitude=standard_magnitude,
    initial_scale_factor=initial_scale_factor,
    initial_embed=initial_embed,
    initial_bias=initial_bias,
)
model = utils.Net(f, n, **net_options)

# Plain SGD; utils.CustomScheduler drives the LR from initial_lr onward.
optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr)
scheduler = utils.CustomScheduler(optimizer, warmup_steps, max_lr, decay_factor)

losses, weights_history, model_history = utils.train(
    model, loader, criterion, optimizer, epochs,
    logging_loss, plot_rate, store_rate, scheduler, lr_print_rate,
)

In [None]:
# Interactive view of the weight snapshots recorded during training; store_rate must
# be the same value that was handed to utils.train.
utils.plot_weights_interactive(
    weights_history,
    store_rate=store_rate,
)

In [None]:
# Training-loss curve without the first `loss_start` logged entries (so the initial
# transient does not dominate the y-range). The x values are the true indices into
# `losses`; plotting the slice directly would shift the axis so it starts at 0.
loss_start = 1000
loss_window = losses[loss_start:]
fig, ax = plt.subplots()
ax.plot(range(loss_start, loss_start + len(loss_window)), loss_window)
ax.set(xlabel="Logged step", ylabel="Loss", title="Training loss");


In [None]:
# Run every stored snapshot on the f one-hot basis inputs (rows of the identity matrix)
# and collect the output, its row-wise softmax, and the pre-nonlinearity activations.
# The loop variable is `snapshot`, not `model`: a plain `for model in ...` loop would
# rebind the global `model` (the trained net) to the last stored snapshot.
basis_inputs = torch.eye(f)
post_relu, post_softmax, pre_relu = [], [], []
with torch.no_grad():  # inspection only; no autograd graph needed
    for snapshot in model_history.values():
        output = snapshot(basis_inputs)
        post_relu.append(output.cpu().numpy())
        # Softmax of the same forward output (previously a second, identical forward pass).
        post_softmax.append(output.softmax(dim=1).cpu().numpy())
        _, activations = snapshot(basis_inputs, hooked=True)
        pre_relu.append(activations['unembed_pre'].cpu().numpy())

if not MSE:
    utils.visualize_matrices_with_slider(post_softmax, store_rate, const_colorbar=True)
utils.visualize_matrices_with_slider(post_relu, store_rate, const_colorbar=True, plot_size=800)
utils.visualize_matrices_with_slider(pre_relu, store_rate, const_colorbar=True)

In [None]:
# ---- Run 2: same configuration as run 1, trained for `epochs` steps while the LR
# schedule is built for a longer `total_epochs` horizon (so, presumably, training
# stops before the LR has decayed all the way to final_lr). ----
chosen_seed = 12
utils.set_seed(chosen_seed)

# Forwarded to utils.train as its LR-print interval (0 here).
lr_print_rate = 0


# Configure the hyperparameters
f = 40
k = 1
n = 2
MSE = True  # else CrossEntropy
nonlinearity = F.relu
tied = False
final_bias = True
hidden_bias = False
unit_weights = False
learnable_scale_factor = False
initial_scale_factor = 1  # alternative tried: (1/(1-np.cos(2*np.pi/f)))**0.5
standard_magnitude = False
initial_embed = None
initial_bias = None


epochs = 200000
# Horizon the LR schedule is computed for. It was previously hard-coded as 250000 in
# both schedule formulas; it is named here, matching the `total_epochs` convention
# used by the later run.
total_epochs = 250000
logging_loss = True

# Scheduler params: decay_factor is chosen so that
# max_lr * decay_factor ** (total_epochs * (1 - warmup_frac)) == final_lr.
max_lr = 5
initial_lr = 0.001
warmup_frac = 0.05
final_lr = 2
decay_factor = (final_lr / max_lr) ** (1 / (total_epochs * (1 - warmup_frac)))
warmup_steps = int(total_epochs * warmup_frac)


store_rate = epochs // 500
plot_rate = 0  # epochs/5


# Instantiate synthetic dataset
dataset = utils.SyntheticKHot(f, k)
batch_size = len(dataset)  # Full batch gradient descent
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

# Define the loss function
criterion = nn.MSELoss() if MSE else nn.CrossEntropyLoss()

# Instantiate the model
# initial_embed = torch.tensor(np.array([1/(1-np.cos(2*np.pi/f))**0.5*np.array([np.cos(2*np.pi*i/f),np.sin(2*np.pi*i/f)]) for i in range(f)]),dtype=torch.float32).T * 0.5
# initial_bias = -torch.ones(f)*(1/(1-np.cos(2*np.pi/f))- 1)*0.25
model = utils.Net(f, n,
            tied=tied,
            final_bias=final_bias,
            hidden_bias=hidden_bias,
            nonlinearity=nonlinearity,
            unit_weights=unit_weights,
            learnable_scale_factor=learnable_scale_factor,
            standard_magnitude=standard_magnitude,
            initial_scale_factor=initial_scale_factor,
            initial_embed=initial_embed,
            initial_bias=initial_bias)

# Define the optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr)

# Define a learning rate schedule
scheduler = utils.CustomScheduler(optimizer, warmup_steps, max_lr, decay_factor)


# Train the model
losses, weights_history, model_history = utils.train(model, loader, criterion, optimizer, epochs, logging_loss, plot_rate, store_rate, scheduler, lr_print_rate)

In [None]:
# Interactive view of the weight snapshots recorded during training; store_rate must
# be the same value that was handed to utils.train.
utils.plot_weights_interactive(
    weights_history,
    store_rate=store_rate,
)


In [None]:
#Fix the no
class ControllableNet(utils.Net):
    def __init__(self, net: utils.Net, controlled_rows):
        super().__init__(self, net.input_dim, net.hidden_dim, tied = False, final_bias = True, hidden_bias = False, nonlinearity = F.relu, unit_weights=False, learnable_scale_factor = False, standard_magnitude = False, initial_scale_factor = 1.0, initial_embed = None, initial_bias = None)

        # Define the input layer (embedding)
        self.controlled_mask = torch.zeros(len(net.embedding.weight.data), dtype=torch.bool)
        self.controlled_mask[controlled_rows] = True
        
        self.embedding_free = nn.Linear(self.input_dim - len(controlled_rows), self.hidden_dim, bias=False)
        self.embedding_free.weight.data = net.embedding.weight.data[~self.controlled_mask]

        self.embedding_controlled = nn.Linear(len(controlled_rows), self.hidden_dim, bias=False)
        self.embedding_controlled.weight.data = net.embedding.weight.data[self.controlled_mask]
        for param in model.embedding_controlled.parameters():
            param.requires_grad = False

        self.unembedding_free = nn.Linear(self.hidden_dim, self.input_dim - len(controlled_rows), bias=True)
        self.unembedding_free.weight.data = net.unembedding.weight.data[:,~self.controlled_mask]
        self.unembedding_free.bias.data = net.unembedding.bias.data[~self.controlled_mask]

        self.unembedding_controlled = nn.Linear(self.hidden_dim, len(controlled_rows), bias=True)
        self.unembedding_controlled.weight.data = net.unembedding.weight.data[:,self.controlled_mask]
        self.unembedding_controlled.bias.data = net.unembedding.bias.data[~self.controlled_mask]

        for param in model.unembedding_controlled.parameters():
            param.requires_grad = False
    
    def embedding(self,x):
        matrix = torch.zeros(self.input_dim, self.hidden_dim)
        matrix[self.controlled_mask] = self.embedding_controlled.weight
        matrix[~self.controlled_mask] = self.embedding_free.weight
        return matrix@x

    def unembedding(self,x):
        matrix = torch.zeros(self.hidden_dim, self.input_dim)
        matrix[self.controlled_mask] = self.unembedding_controlled.weight
        matrix[~self.controlled_mask] = self.unembedding_free.weight
        vector = torch.zeros(self.input_dim)
        vector[self.controlled_mask] = self.unembedding_controlled.bias
        vector[~self.controlled_mask] = self.unembedding_free.bias
        return matrix@x + vector


    def forward(self, x, hooked = False):
        if self.tied:
            self.unembedding.weight.data = self.embedding.weight.data.transpose(0, 1)
        if hooked:
            activations = {}
            activations['res_pre'] = self.embedding(x)
            activations['unembed_pre'] = self.unembedding(activations['res_pre'])
            activations['output'] = self.scale_factor * self.nonlinearity(activations['unembed_pre'])
            return activations['output'], activations
        else:
            x = self.embedding(x)
            x = self.unembedding(x)
            x = self.nonlinearity(x)
            return self.scale_factor * x

In [None]:
# ---- Run 3: tied weights with utils.relu_plusone; trains `epochs` steps of a
# schedule built for a `total_epochs` horizon. ----
# Seed before anything that may draw random numbers (dataset, weight init).
chosen_seed = 12
utils.set_seed(chosen_seed)

# Forwarded to utils.train as its LR-print interval (0 here).
lr_print_rate = 0

# Problem size: f features, k active features per sample, n hidden dimensions.
f, k, n = 40, 1, 2

# True -> MSE loss, False -> cross-entropy loss.
MSE = True

# Architecture switches forwarded to utils.Net.
nonlinearity = utils.relu_plusone
tied = True
final_bias = True
hidden_bias = False
unit_weights = False
learnable_scale_factor = False
initial_scale_factor = 1  # alternative tried: (1/(1-np.cos(2*np.pi/f)))**0.5
standard_magnitude = False
initial_embed = None
initial_bias = None

# Steps actually trained vs. the horizon the LR schedule is computed for.
epochs = 40000
total_epochs = 300000
logging_loss = True

# LR schedule parameters: decay_factor is chosen so that
# max_lr * decay_factor ** decay_steps == final_lr over the post-warmup share of total_epochs.
max_lr = 5
initial_lr = 0.001
warmup_frac = 0.05
final_lr = 2
decay_steps = total_epochs * (1 - warmup_frac)
decay_factor = (final_lr / max_lr) ** (1 / decay_steps)
warmup_steps = int(total_epochs * warmup_frac)

# Intervals forwarded to utils.train (plot_rate 0 here; epochs/5 was an alternative).
store_rate = epochs // 100
plot_rate = 0

# Full-batch gradient descent: a single batch holds the whole synthetic dataset.
dataset = utils.SyntheticKHot(f, k)
batch_size = len(dataset)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

criterion = nn.CrossEntropyLoss() if not MSE else nn.MSELoss()

# Alternative hand-crafted initialisation (regular f-gon embedding), kept for reference:
# initial_embed = torch.tensor(np.array([1/(1-np.cos(2*np.pi/f))**0.5*np.array([np.cos(2*np.pi*i/f),np.sin(2*np.pi*i/f)]) for i in range(f)]),dtype=torch.float32).T * 0.5
# initial_bias = -torch.ones(f)*(1/(1-np.cos(2*np.pi/f))- 1)*0.25
net_options = dict(
    tied=tied,
    final_bias=final_bias,
    hidden_bias=hidden_bias,
    nonlinearity=nonlinearity,
    unit_weights=unit_weights,
    learnable_scale_factor=learnable_scale_factor,
    standard_magnitude=standard_magnitude,
    initial_scale_factor=initial_scale_factor,
    initial_embed=initial_embed,
    initial_bias=initial_bias,
)
model = utils.Net(f, n, **net_options)

# Plain SGD; utils.CustomScheduler drives the LR from initial_lr onward.
optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr)
scheduler = utils.CustomScheduler(optimizer, warmup_steps, max_lr, decay_factor)

losses, weights_history, model_history = utils.train(
    model, loader, criterion, optimizer, epochs,
    logging_loss, plot_rate, store_rate, scheduler, lr_print_rate,
)

In [None]:
# Training-loss window [loss_start, loss_stop). The x values are the true indices into
# `losses`; plotting the slice directly would shift the axis so it starts at 0.
# Note that loss_stop may exceed the number of logged entries (this run configures
# 40000 epochs); the slice simply clips.
loss_start, loss_stop = 1000, 50000
loss_window = losses[loss_start:loss_stop]
fig, ax = plt.subplots()
ax.plot(range(loss_start, loss_start + len(loss_window)), loss_window)
ax.set(xlabel="Logged step", ylabel="Loss", title="Training loss");

In [None]:
# Interactive view of the weight snapshots recorded during training; store_rate must
# be the same value that was handed to utils.train.
utils.plot_weights_interactive(
    weights_history,
    store_rate=store_rate,
)

In [None]:
# Run every stored snapshot on the f one-hot basis inputs (rows of the identity matrix)
# and collect the output, its row-wise softmax, and the pre-nonlinearity activations.
# The loop variable is `snapshot`, not `model`: a plain `for model in ...` loop would
# rebind the global `model` (the trained net) to the last stored snapshot.
basis_inputs = torch.eye(f)
post_relu, post_softmax, pre_relu = [], [], []
with torch.no_grad():  # inspection only; no autograd graph needed
    for snapshot in model_history.values():
        output = snapshot(basis_inputs)
        post_relu.append(output.cpu().numpy())
        # Softmax of the same forward output (previously a second, identical forward pass).
        post_softmax.append(output.softmax(dim=1).cpu().numpy())
        _, activations = snapshot(basis_inputs, hooked=True)
        pre_relu.append(activations['unembed_pre'].cpu().numpy())

if not MSE:
    utils.visualize_matrices_with_slider(post_softmax, store_rate, const_colorbar=True)
utils.visualize_matrices_with_slider(post_relu, store_rate, const_colorbar=True, plot_size=800)
utils.visualize_matrices_with_slider(pre_relu, store_rate, const_colorbar=True)