In [None]:
import numpy as np
from scipy.optimize import minimize
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from IPython.display import clear_output
import time
import plotly.graph_objs as goa
import matplotlib as mpl
import torch.nn.functional as F
import random
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import copy
from itertools import combinations
from torch.optim.lr_scheduler import _LRScheduler
from fancy_einsum import einsum
from scipy.spatial import ConvexHull
import os
import examples_setup as utils
# Global plotting look for the whole notebook: seaborn-flavoured style, large
# figures, and one shared font size for every text element so slider/heatmap
# plots stay legible.
mpl.style.use('seaborn-v0_8')
fontsize = 20
mpl.rcParams.update({
    'figure.figsize': (15, 10),
    'font.size': fontsize,
    'xtick.labelsize': fontsize,
    'ytick.labelsize': fontsize,
    'legend.fontsize': fontsize,
    'axes.titlesize': fontsize,
    'axes.labelsize': fontsize,
})

In [None]:
# ---- Experiment 1: f = 6 features through an n = 2 hidden layer ----
# Seed first so that dataset creation, shuffling and model initialisation are
# reproducible under Restart & Run All.
chosen_seed = 12
utils.set_seed(chosen_seed)

# Diagnostics: forwarded to utils.train (0 presumably disables LR printing —
# TODO confirm in utils.train).
lr_print_rate = 0

# ---- Problem / model hyperparameters ----
f = 6  # feature count: first argument of utils.SyntheticKHot and utils.Net
k = 1  # second argument of utils.SyntheticKHot (presumably features active per sample)
n = 2  # hidden width: second argument of utils.Net
MSE = True  # True -> nn.MSELoss, False -> nn.CrossEntropyLoss
nonlinearity = F.relu
tied = False
final_bias = False
hidden_bias = False
unit_weights = False
learnable_scale_factor = False
initial_scale_factor = 1  # alternative tried: (1/(1-np.cos(2*np.pi/f)))**0.5
standard_magnitude = False
initial_embed = None
initial_bias = None
epochs = 20000
logging_loss = True

# ---- Learning-rate schedule parameters (consumed by utils.CustomScheduler) ----
max_lr = 0.08
initial_lr = 0.02
warmup_frac = 0.05
final_lr = 0.02
warmup_steps = int(epochs * warmup_frac)
# Spread the geometric decay over exactly the steps left after the *integer*
# warm-up, so that max_lr * decay_factor**(epochs - warmup_steps) == final_lr
# even when epochs * warmup_frac is not a whole number. max(..., 1) avoids a
# division by zero if warmup_frac is 1.
decay_factor = (final_lr / max_lr) ** (1 / max(epochs - warmup_steps, 1))
store_rate = epochs // 100  # forwarded to utils.train / plotting helpers
plot_rate = 0  # e.g. epochs/5 to plot during training

# ---- Data: full-batch gradient descent (one batch holds the whole dataset) ----
dataset = utils.SyntheticKHot(f, k)
batch_size = len(dataset)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

# ---- Loss function ----
criterion = nn.MSELoss() if MSE else nn.CrossEntropyLoss()

# ---- Model ----
# Optional hand-crafted initialisation (features placed on a regular polygon),
# kept for reference; uncomment to replace the `None` defaults above:
# initial_embed = torch.tensor(np.array([1/(1-np.cos(2*np.pi/f))**0.5*np.array([np.cos(2*np.pi*i/f),np.sin(2*np.pi*i/f)]) for i in range(f)]),dtype=torch.float32).T * 0.5
# initial_bias = -torch.ones(f)*(1/(1-np.cos(2*np.pi/f))- 1)*0.25
model = utils.Net(f, n,
            tied = tied,
            final_bias = final_bias,
            hidden_bias = hidden_bias,
            nonlinearity=nonlinearity,
            unit_weights=unit_weights,
            learnable_scale_factor=learnable_scale_factor,
            standard_magnitude=standard_magnitude,
            initial_scale_factor = initial_scale_factor,
            initial_embed = initial_embed,
            initial_bias = initial_bias)

# ---- Optimizer (loss was defined above) ----
# Plain SGD starting at initial_lr; the scheduler below presumably drives the
# learning rate from here on — TODO confirm in utils.CustomScheduler.
optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr)

# ---- Learning-rate schedule ----
scheduler = utils.CustomScheduler(optimizer, warmup_steps, max_lr, decay_factor)

# ---- Train: returns the loss log plus stored weight / model snapshots ----
losses, weights_history, model_history = utils.train(model, loader, criterion, optimizer, epochs, logging_loss, plot_rate, store_rate, scheduler, lr_print_rate)

In [None]:
# Interactive view of the weight snapshots returned by utils.train for the
# experiment above; store_rate is forwarded so the helper can map snapshot
# index back to training progress (presumably epochs between snapshots).
utils.plot_weights_interactive(weights_history, store_rate=store_rate)


In [None]:
# Probe every stored training snapshot with the f one-hot inputs (the rows of
# torch.eye(f)) and collect, per snapshot:
#   post_relu    - the model's output (presumably after the output nonlinearity),
#   post_softmax - the softmax of that same output over dim=1,
#   pre_relu     - the 'unembed_pre' activation from a hooked forward pass.
# The loop variable is deliberately NOT named `model`: a `for` loop leaks its
# variable into the notebook namespace, and reusing that name would silently
# replace the trained `model` with the last stored snapshot.
one_hot_inputs = torch.eye(f)
post_relu, post_softmax, pre_relu = [], [], []
for snapshot in model_history.values():
    outputs = snapshot(one_hot_inputs)
    post_relu.append(outputs.cpu().detach().numpy())
    post_softmax.append(outputs.softmax(dim=1).cpu().detach().numpy())
    # With hooked=True the model returns (output, activations-dict); only the
    # activations are needed here.
    activations = snapshot(one_hot_inputs, hooked=True)[1]
    pre_relu.append(activations['unembed_pre'].cpu().detach().numpy())

# The softmax view is shown only when training with cross-entropy (MSE False).
if not MSE:
    utils.visualize_matrices_with_slider(post_softmax, store_rate, const_colorbar=True)
utils.visualize_matrices_with_slider(post_relu, store_rate, const_colorbar=True, plot_size=800)
utils.visualize_matrices_with_slider(pre_relu, store_rate, const_colorbar=True)

In [None]:
# ---- Experiment 2: f = 40 features through an n = 3 hidden layer ----
# Re-seed so this experiment is reproducible independently of the cells above.
# NOTE: this cell rebinds every global from experiment 1 (model, losses,
# weights_history, model_history, ...), so earlier results are overwritten.
chosen_seed = 12
utils.set_seed(chosen_seed)

# Diagnostics: forwarded to utils.train (0 presumably disables LR printing —
# TODO confirm in utils.train).
lr_print_rate = 0

# ---- Problem / model hyperparameters ----
f = 40  # feature count: first argument of utils.SyntheticKHot and utils.Net
k = 1  # second argument of utils.SyntheticKHot (presumably features active per sample)
n = 3  # hidden width: second argument of utils.Net
MSE = True  # True -> nn.MSELoss, False -> nn.CrossEntropyLoss
nonlinearity = F.relu
tied = False
final_bias = False
hidden_bias = False
unit_weights = False
learnable_scale_factor = False
initial_scale_factor = 1  # alternative tried: (1/(1-np.cos(2*np.pi/f)))**0.5
standard_magnitude = False
initial_embed = None
initial_bias = None
epochs = 50000
logging_loss = True

# ---- Learning-rate schedule parameters (consumed by utils.CustomScheduler) ----
max_lr = 2
initial_lr = 0.1
warmup_frac = 0.05
final_lr = 0.5
warmup_steps = int(epochs * warmup_frac)
# Spread the geometric decay over exactly the steps left after the *integer*
# warm-up, so that max_lr * decay_factor**(epochs - warmup_steps) == final_lr
# even when epochs * warmup_frac is not a whole number. max(..., 1) avoids a
# division by zero if warmup_frac is 1.
decay_factor = (final_lr / max_lr) ** (1 / max(epochs - warmup_steps, 1))
store_rate = epochs // 100  # forwarded to utils.train / plotting helpers
plot_rate = 0  # e.g. epochs/5 to plot during training

# ---- Data: full-batch gradient descent (one batch holds the whole dataset) ----
dataset = utils.SyntheticKHot(f, k)
batch_size = len(dataset)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

# ---- Loss function ----
criterion = nn.MSELoss() if MSE else nn.CrossEntropyLoss()

# ---- Model ----
# Optional hand-crafted initialisation, kept for reference. It builds 2-D
# (cos, sin) polygon coordinates, so it presumably only fits n = 2 and would
# need adapting before being used with n = 3:
# initial_embed = torch.tensor(np.array([1/(1-np.cos(2*np.pi/f))**0.5*np.array([np.cos(2*np.pi*i/f),np.sin(2*np.pi*i/f)]) for i in range(f)]),dtype=torch.float32).T * 0.5
# initial_bias = -torch.ones(f)*(1/(1-np.cos(2*np.pi/f))- 1)*0.25
model = utils.Net(f, n,
            tied = tied,
            final_bias = final_bias,
            hidden_bias = hidden_bias,
            nonlinearity=nonlinearity,
            unit_weights=unit_weights,
            learnable_scale_factor=learnable_scale_factor,
            standard_magnitude=standard_magnitude,
            initial_scale_factor = initial_scale_factor,
            initial_embed = initial_embed,
            initial_bias = initial_bias)

# ---- Optimizer (loss was defined above) ----
# Plain SGD starting at initial_lr; the scheduler below presumably drives the
# learning rate from here on — TODO confirm in utils.CustomScheduler.
optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr)

# ---- Learning-rate schedule ----
scheduler = utils.CustomScheduler(optimizer, warmup_steps, max_lr, decay_factor)

# ---- Train: returns the loss log plus stored weight / model snapshots ----
losses, weights_history, model_history = utils.train(model, loader, criterion, optimizer, epochs, logging_loss, plot_rate, store_rate, scheduler, lr_print_rate)

In [None]:
# Training-loss curve for the experiment above, exactly as returned by
# utils.train (x-axis is the index into that log; its granularity is set by
# utils.train).
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(
    title='Training loss',
    xlabel='logged training step',
    ylabel='MSE loss' if MSE else 'cross-entropy loss',
)
plt.show()

In [None]:
# Interactive view of the weight snapshots returned by utils.train for
# experiment 2; store_rate is forwarded so the helper can map snapshot index
# back to training progress (presumably epochs between snapshots).
utils.plot_weights_interactive(weights_history, store_rate=store_rate)
