In [None]:
import numpy as np
from scipy.optimize import minimize
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from IPython.display import clear_output
import time
import plotly.graph_objs as goa
import matplotlib as mpl
import torch.nn.functional as F
import random
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import copy
from itertools import combinations
from torch.optim.lr_scheduler import _LRScheduler
from fancy_einsum import einsum
from scipy.spatial import ConvexHull
import os
import examples_setup as utils
import copy
# Global matplotlib look for the whole notebook: seaborn style sheet, large
# default figures, and one shared font size for every text element.
mpl.style.use('seaborn-v0_8')
fontsize = 20
mpl.rcParams.update({
    'figure.figsize': (15, 10),
    'font.size': fontsize,
    'xtick.labelsize': fontsize,
    'ytick.labelsize': fontsize,
    'legend.fontsize': fontsize,
    'axes.titlesize': fontsize,
    'axes.labelsize': fontsize,
})

In [None]:
# ---------------------------------------------------------------------------
# Baseline run: train an unconstrained utils.Net on the synthetic k-hot data.
# Every name defined here is module-level on purpose; later cells read them.
# ---------------------------------------------------------------------------

# Reproducibility
chosen_seed = 12
utils.set_seed(chosen_seed)

# Passed straight through to utils.train (0 in this run).
lr_print_rate = 0

# Model / data hyperparameters
f = 40  # first argument of both SyntheticKHot and Net
k = 1   # second argument of SyntheticKHot
n = 2   # second argument of Net
MSE = True  # False selects cross-entropy instead
nonlinearity = F.relu
tied = False
final_bias = True
hidden_bias = False
unit_weights = False
learnable_scale_factor = False
initial_scale_factor = 1  # alternative: (1/(1-np.cos(2*np.pi/f)))**0.5
standard_magnitude = False
initial_embed = None
initial_bias = None

# Training length. The LR schedule is laid out over `all_epochs`, while this
# cell only trains for `epochs` of them.
epochs = 150000
all_epochs = 250000
logging_loss = True

# Learning-rate schedule parameters handed to utils.CustomScheduler
max_lr = 5
initial_lr = 0.001
warmup_frac = 0.05
final_lr = 2
warmup_steps = int(all_epochs * warmup_frac)
decay_factor = (final_lr/max_lr)**(1/(all_epochs * (1-warmup_frac)))

# Snapshot the weights 100 times over the run; plotting during training is off.
store_rate = epochs//100
plot_rate = 0  # epochs/5

# Synthetic dataset, served as one full batch per epoch
dataset = utils.SyntheticKHot(f, k)
batch_size = len(dataset)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

# Loss function
if MSE:
    criterion = nn.MSELoss()
else:
    criterion = nn.CrossEntropyLoss()

# Model. Optional hand-crafted initialisation, kept for reference:
# initial_embed = torch.tensor(np.array([1/(1-np.cos(2*np.pi/f))**0.5*np.array([np.cos(2*np.pi*i/f),np.sin(2*np.pi*i/f)]) for i in range(f)]),dtype=torch.float32).T * 0.5
# initial_bias = -torch.ones(f)*(1/(1-np.cos(2*np.pi/f))- 1)*0.25
net_options = dict(
    tied=tied,
    final_bias=final_bias,
    hidden_bias=hidden_bias,
    nonlinearity=nonlinearity,
    unit_weights=unit_weights,
    learnable_scale_factor=learnable_scale_factor,
    standard_magnitude=standard_magnitude,
    initial_scale_factor=initial_scale_factor,
    initial_embed=initial_embed,
    initial_bias=initial_bias,
)
model = utils.Net(f, n, **net_options)

# Plain SGD driven by the custom learning-rate schedule
optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr)
scheduler = utils.CustomScheduler(optimizer, warmup_steps, max_lr, decay_factor)

# Train and keep the loss curve plus the periodic weight / model snapshots
losses, weights_history, model_history = utils.train(
    model, loader, criterion, optimizer, epochs,
    logging_loss, plot_rate, store_rate, scheduler, lr_print_rate,
)

In [None]:
# Interactive view of the weight snapshots recorded during the baseline run;
# `store_rate` is the number of epochs between consecutive snapshots.
utils.plot_weights_interactive(weights_history, store_rate=store_rate)


In [None]:
def calculate_angle(v1, v2, degrees = False):
    """Return the unsigned angle between two 1-D tensors.

    Args:
        v1, v2: 1-D tensors of equal length (``torch.dot`` requires this).
        degrees: if True the result is in degrees, otherwise in radians.

    Returns:
        A 0-dim tensor in ``[0, pi]`` (or ``[0, 180]`` when ``degrees``).
        A zero-length input vector yields NaN, because the direction of a
        zero vector is undefined.
    """
    cos_theta = torch.dot(v1, v2) / (torch.linalg.norm(v1) * torch.linalg.norm(v2))

    # Rounding can push the cosine of (anti-)parallel vectors marginally
    # outside [-1, 1], where acos returns NaN; clamp it back into the domain.
    cos_theta = torch.clamp(cos_theta, -1.0, 1.0)

    theta = torch.acos(cos_theta)
    if degrees:
        # Out-of-place so the radians tensor is never mutated behind the caller.
        theta = theta * (180 / torch.pi)
    return theta

class ControllableNet(utils.Net):
    """A copy of a trained ``utils.Net`` in which only some feature rows train.

    The embedding and un-embedding of ``net`` are split per input feature into
    a *free* part (the features listed in ``free_rows``, trainable) and a
    *controlled* part (all other features, frozen). The controlled part can be
    moved by hand with :meth:`reduce_angle`.

    ``self.embedding`` / ``self.unembedding`` hold frozen full-size copies of
    the source weights. They are what :meth:`angle_between` and
    :meth:`reduce_angle` read and edit, but the forward pass does not use them:
    it reassembles the weights from the free/controlled parts. Training the
    free part therefore does not update the free rows of those full copies.
    """

    def __init__(self, net: utils.Net, free_rows = None):
        """Build the split model from ``net``.

        Args:
            net: source network; must expose ``input_dim``, ``hidden_dim``,
                ``embedding`` and a biased ``unembedding``.
            free_rows: indices of the features that stay trainable.
                Defaults to ``[11]``.
        """
        super().__init__(net.input_dim, net.hidden_dim, tied = False, final_bias = True, hidden_bias = False, nonlinearity = F.relu, unit_weights=False, learnable_scale_factor = False, standard_magnitude = False, initial_scale_factor = 1.0, initial_embed = None, initial_bias = None)
        # Copy into a fresh list: avoids sharing a mutable default between
        # instances, and a tuple would be read as a multi-dimensional index below.
        free_rows = [11] if free_rows is None else list(free_rows)
        self.free_rows = free_rows

        # True for controlled (frozen) features, False for the free ones.
        self.controlled_mask = torch.ones(net.embedding.weight.shape[1], dtype=torch.bool)
        self.controlled_mask[free_rows] = False
        mask = self.controlled_mask
        n_controlled = int(mask.sum())
        n_free = mask.numel() - n_controlled

        source_embed = net.embedding.weight.data
        source_unembed = net.unembedding.weight.data
        source_bias = net.unembedding.bias.data

        # Boolean-mask indexing returns new tensors, so these slices do not
        # share storage with `net`.
        self.embedding_free = nn.Linear(n_free, self.hidden_dim, bias=False)
        self.embedding_free.weight.data = source_embed[:, ~mask]

        self.embedding_controlled = nn.Linear(n_controlled, self.hidden_dim, bias=False)
        self.embedding_controlled.weight.data = source_embed[:, mask]

        self.unembedding_free = nn.Linear(self.hidden_dim, n_free, bias=True)
        self.unembedding_free.weight.data = source_unembed[~mask]
        self.unembedding_free.bias.data = source_bias[~mask]

        self.unembedding_controlled = nn.Linear(self.hidden_dim, n_controlled, bias=True)
        self.unembedding_controlled.weight.data = source_unembed[mask]
        # The controlled bias comes from the controlled features (the original
        # sliced it with the free mask).
        self.unembedding_controlled.bias.data = source_bias[mask]

        # Frozen full-size copies. `unembedding` needs a bias slot because the
        # source bias is stored in it. Clone so that in-place edits made by
        # `reduce_angle` never reach back into `net`.
        self.embedding = nn.Linear(self.input_dim, self.hidden_dim, bias=False)
        self.unembedding = nn.Linear(self.hidden_dim, self.input_dim, bias=True)
        self.embedding.weight.data = source_embed.clone()
        self.unembedding.weight.data = source_unembed.clone()
        self.unembedding.bias.data = source_bias.clone()

        frozen_modules = (self.embedding_controlled, self.unembedding_controlled,
                          self.embedding, self.unembedding)
        for module in frozen_modules:
            for param in module.parameters():
                param.requires_grad = False

    @staticmethod
    def _signed_angle(vec1, vec2):
        """Signed angle (radians) of the rotation taking 2-D ``vec1`` onto ``vec2``."""
        cross = vec1[0] * vec2[1] - vec1[1] * vec2[0]
        return torch.atan2(cross, torch.dot(vec1, vec2))

    def angle_between(self, rows = [1,18], in_embedding = True):
        """Angle in radians between the vectors of two features.

        Reads the full-size ``embedding`` (one column per feature) or
        ``unembedding`` (one row per feature) copy.
        """
        matrix = self.embedding.weight.data.T if in_embedding else self.unembedding.weight.data
        vec1, vec2 = matrix[rows[0]], matrix[rows[1]]
        return calculate_angle(vec1,vec2)

    def reduce_angle(self, rows = [1,18], in_embedding = True, reduction = 0.01):
        """Rotate two controlled feature vectors towards each other.

        Both vectors keep their magnitude and their bisector; only the angle
        between them shrinks by ``reduction`` radians. Each vector stays on its
        own side of the bisector whichever of the two comes first in ``rows``.
        Only 2-D vectors are supported (components 0 and 1 are used).

        Args:
            rows: the two feature indices; neither may be a free row.
            in_embedding: edit the embedding if True, else the un-embedding.
            reduction: amount, in radians, by which to shrink the angle.
        """
        assert all(row not in self.free_rows for row in rows)
        # `matrix` is a view with one row per feature, so writing a row below
        # updates the underlying frozen weight in place.
        matrix = self.embedding.weight.data.T if in_embedding else self.unembedding.weight.data
        vec1, vec2 = matrix[rows[0]].clone(), matrix[rows[1]].clone()
        magnitude1, magnitude2 = torch.linalg.norm(vec1), torch.linalg.norm(vec2)

        signed = self._signed_angle(vec1, vec2)
        theta = signed.abs()
        # +1 when vec2 lies counter-clockwise of vec1, -1 otherwise.
        direction = 1.0 if signed >= 0 else -1.0

        average = 1/2 * (vec1/magnitude1 + vec2/magnitude2)
        phi = torch.atan2(average[1],average[0])  # direction of the bisector

        half = direction * (theta - reduction) / 2
        new_v1 = magnitude1 * torch.stack([torch.cos(phi - half), torch.sin(phi - half)])
        new_v2 = magnitude2 * torch.stack([torch.cos(phi + half), torch.sin(phi + half)])

        # Floating point cannot hit the target exactly, so compare with a tolerance.
        achieved = float(self._signed_angle(new_v1, new_v2).abs())
        target = abs(float(theta - reduction))
        assert abs(achieved - target) < 1e-4, f"angle after update is {achieved}, expected {target}"

        matrix[rows[0]] = new_v1
        matrix[rows[1]] = new_v2
        # Refresh the controlled slice that the forward pass actually uses.
        if in_embedding:
            self.embedding_controlled.weight.data = self.embedding.weight.data[:,self.controlled_mask]
        else:
            self.unembedding_controlled.weight.data = self.unembedding.weight.data[self.controlled_mask]

    def embedding_forward(self,x):
        """Apply the embedding reassembled from its controlled and free columns.

        ``x`` has the feature dimension last, as for ``nn.Linear``.
        """
        weight = self.embedding_controlled.weight.new_zeros(self.hidden_dim, self.input_dim)
        weight[:, self.controlled_mask] = self.embedding_controlled.weight
        weight[:, ~self.controlled_mask] = self.embedding_free.weight
        return F.linear(x, weight)

    def unembedding_forward(self,x):
        """Apply the un-embedding (weight and bias) reassembled from its parts."""
        weight = self.unembedding_controlled.weight.new_zeros(self.input_dim, self.hidden_dim)
        weight[self.controlled_mask] = self.unembedding_controlled.weight
        weight[~self.controlled_mask] = self.unembedding_free.weight
        bias = self.unembedding_controlled.bias.new_zeros(self.input_dim)
        bias[self.controlled_mask] = self.unembedding_controlled.bias
        bias[~self.controlled_mask] = self.unembedding_free.bias
        return F.linear(x, weight, bias)

    def forward(self, x, hooked = False):
        """Embed, un-embed, apply the nonlinearity and scale.

        With ``hooked=True`` also returns a dict of the intermediate
        activations (``'res_pre'``, ``'unembed_pre'``, ``'output'``).
        """
        if hooked:
            activations = {}
            activations['res_pre'] = self.embedding_forward(x)
            activations['unembed_pre'] = self.unembedding_forward(activations['res_pre'])
            activations['output'] = self.scale_factor * self.nonlinearity(activations['unembed_pre'])
            return activations['output'], activations
        else:
            x = self.embedding_forward(x)
            x = self.unembedding_forward(x)
            x = self.nonlinearity(x)
            return self.scale_factor * x

In [None]:
def reduce_angle(model, rows = [1,18], in_embedding = True, reduction = 0.01):
    """Rotate two feature vectors of ``model`` towards each other, in place.

    Both vectors keep their magnitude and their bisector; only the angle
    between them shrinks by ``reduction`` radians. Each vector stays on its own
    side of the bisector whichever of the two comes first in ``rows``.
    Only 2-D vectors are supported (components 0 and 1 are used).

    Args:
        model: object exposing ``embedding.weight`` (one column per feature)
            and ``unembedding.weight`` (one row per feature).
        rows: the two feature indices.
        in_embedding: edit the embedding if True, else the un-embedding.
        reduction: amount, in radians, by which to shrink the angle.
    """
    def signed_angle(a, b):
        # Signed angle of the rotation taking 2-D vector a onto b.
        return torch.atan2(a[0] * b[1] - a[1] * b[0], torch.dot(a, b))

    # `matrix` is a view with one row per feature, so writing a row below
    # updates the model weight in place.
    matrix = model.embedding.weight.data.T if in_embedding else model.unembedding.weight.data
    vec1, vec2 = matrix[rows[0]].clone(), matrix[rows[1]].clone()
    magnitude1, magnitude2 = torch.linalg.norm(vec1), torch.linalg.norm(vec2)

    signed = signed_angle(vec1, vec2)
    theta = signed.abs()
    # +1 when vec2 lies counter-clockwise of vec1, -1 otherwise.
    direction = 1.0 if signed >= 0 else -1.0

    average = 1/2 * (vec1/magnitude1 + vec2/magnitude2)
    phi = torch.atan2(average[1],average[0])  # direction of the bisector

    half = direction * (theta - reduction) / 2
    new_v1 = magnitude1 * torch.stack([torch.cos(phi - half), torch.sin(phi - half)])
    new_v2 = magnitude2 * torch.stack([torch.cos(phi + half), torch.sin(phi + half)])

    # Floating point cannot hit the target exactly, so compare with a tolerance.
    achieved = float(signed_angle(new_v1, new_v2).abs())
    target = abs(float(theta - reduction))
    assert abs(achieved - target) < 1e-4, f"angle after update is {achieved}, expected {target}"

    matrix[rows[0]] = new_v1
    matrix[rows[1]] = new_v2

In [None]:
def weights_schedule(model, rows, epochs, end = 0, in_embedding = False):
    """Pre-compute, per epoch, where two feature vectors should point.

    The angle between the two vectors is moved linearly from its current value
    ``theta`` (first epoch) to ``end * theta`` (last epoch). Magnitudes and the
    bisector are kept fixed, and each vector stays on its own side of the
    bisector. ``model`` itself is not modified. Only 2-D vectors are supported
    (components 0 and 1 are used).

    Args:
        model: object exposing ``embedding.weight`` (one column per feature)
            and ``unembedding.weight`` (one row per feature).
        rows: the two feature indices.
        epochs: number of schedule steps.
        end: final angle as a fraction of the starting angle (0 merges the
            two directions, 1 leaves them where they are).
        in_embedding: read the embedding if True, else the un-embedding.

    Returns:
        ``[schedule_1, schedule_2]``: two ``(epochs, 2)`` tensors in the dtype
        of the model weights; ``schedule_i[e]`` is the vector for
        ``rows[i]`` at epoch ``e``.
    """
    matrix = model.embedding.weight.data.T if in_embedding else model.unembedding.weight.data
    vec1, vec2 = matrix[rows[0]], matrix[rows[1]]
    magnitude1, magnitude2 = torch.linalg.norm(vec1), torch.linalg.norm(vec2)

    # Signed angle from vec1 to vec2: its sign tells which side of the bisector
    # each vector sits on, so the schedule never swaps them.
    cross = vec1[0] * vec2[1] - vec1[1] * vec2[0]
    signed = torch.atan2(cross, torch.dot(vec1, vec2)).item()
    theta = abs(signed)
    direction = 1.0 if signed >= 0 else -1.0

    average = 1/2 * (vec1/magnitude1 + vec2/magnitude2)
    phi = torch.atan2(average[1],average[0]).item()  # direction of the bisector

    # Angle between the pair at every epoch. The first entry is the current
    # angle, so the schedule starts exactly at the model's present state.
    separations = np.linspace(theta, theta * end, epochs)
    half = direction * separations / 2

    def along(angles, magnitude):
        # (epochs, 2) tensor of vectors with the given magnitude and polar angles.
        unit = np.stack([np.cos(angles), np.sin(angles)], axis=1)
        return magnitude * torch.tensor(unit, dtype=matrix.dtype)

    return [along(phi - half, magnitude1), along(phi + half, magnitude2)]

def train_controlled(model: utils.Net,
                   loader, criterion, optimizer, epochs, logging_loss, store_rate,
                   rows, end,
                   scheduler = None, plot_rate = 0):
    """Train ``model`` while two un-embedding rows follow a fixed schedule.

    Before every forward pass, and again after every optimizer step, the
    un-embedding weight rows listed in ``rows`` are overwritten with the
    vectors from ``weights_schedule(model, rows, epochs, end)`` and their
    biases are reset to the values they had when training started. Every other
    parameter handled by ``optimizer`` trains normally. The model is trained as
    an autoencoder: the loss compares ``model(batch)`` with ``batch``.

    Args:
        model: network to train; ``optimizer`` must have been built over this
            model's parameters.
        loader: iterable of input batches.
        criterion: loss called as ``criterion(outputs, batch)``.
        optimizer: optimizer stepped once per batch.
        epochs: number of passes over ``loader``.
        logging_loss: if True, record the mean loss of every epoch.
        store_rate: snapshot weights and model every ``store_rate`` epochs
            (0 disables snapshots).
        rows: the two controlled feature indices.
        end: final angle as a fraction of the initial one, forwarded to
            ``weights_schedule``.
        scheduler: optional LR scheduler, stepped once per epoch.
        plot_rate: if > 0 and ``logging_loss``, plot the loss curve every
            ``plot_rate`` epochs.

    Returns:
        ``(losses, weights_history, model_history)``: the per-epoch mean losses,
        a dict mapping each parameter name to its list of snapshots (the first
        entry is the initial value), and a dict mapping epoch index to a deep
        copy of the model.
    """
    weights_history = {name: [param.detach().numpy().copy()]
                       for name, param in model.named_parameters()}
    model_history = {}
    losses = []
    chosen_weights = weights_schedule(model, rows, epochs, end)
    # Clone: `.data` is the live tensor, so without a copy "restoring" the
    # biases from it would be a no-op.
    biases = model.unembedding.bias.data.clone()

    def pin_controlled_rows(epoch):
        # Each controlled row follows its own schedule (not the first one).
        for row, row_schedule in zip(rows, chosen_weights):
            model.unembedding.weight.data[row] = row_schedule[epoch]
            model.unembedding.bias.data[row] = biases[row]

    for epoch in tqdm(range(epochs)):
        total_loss = 0
        for batch in loader:
            optimizer.zero_grad()
            pin_controlled_rows(epoch)
            outputs = model(batch)
            loss = criterion(outputs, batch)
            loss.backward()
            optimizer.step()
            # The step also moved the controlled rows; put them back so the
            # snapshots and the returned model hold the scheduled values.
            pin_controlled_rows(epoch)
            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        if logging_loss:
            losses.append(avg_loss)
            if plot_rate > 0 and (epoch + 1) % plot_rate == 0:
                plt.figure(figsize=(5,5))
                plt.plot(losses)
                plt.show()
        if store_rate > 0 and (epoch + 1) % store_rate == 0:
            for name, param in model.named_parameters():
                weights_history[name].append(param.detach().numpy().copy())
            model_history[epoch] = copy.deepcopy(model)
        if scheduler is not None:
            scheduler.step()
    return losses, weights_history, model_history


In [None]:
# Controlled run: continue from the trained `model`, but steer the
# un-embedding vectors of features 1 and 18 towards each other.

# Seed again so this cell gives the same result every time it is run.
chosen_seed = 12
utils.set_seed(chosen_seed)

#Checking for errors
lr_print_rate = 0
MSE = True

# Configure the hyperparameters
free_rows = [11]

epochs = 10000
logging_loss = True
angle = calculate_angle(model.embedding.weight.data.T[1], model.embedding.weight.data.T[18])
reductions = [angle/epochs] * epochs

#Scheduler params
# initial, max and final LR are all 1, so the decay factor below is 1.
# `all_epochs` is the value defined in the baseline-training cell.
max_lr = 1
initial_lr = 1
warmup_frac = 0.05
final_lr = 1
decay_factor=(final_lr/max_lr)**(1/(all_epochs * (1-warmup_frac)))
warmup_steps = int(all_epochs * warmup_frac)


store_rate = epochs//100
plot_rate=0 #epochs/5


# Instantiate synthetic dataset
dataset = utils.SyntheticKHot(f,k)
batch_size = len(dataset) #Full batch gradient descent
loader = DataLoader(dataset, batch_size=batch_size, shuffle = True, num_workers=0)

#Define the Loss function
criterion = nn.MSELoss() if MSE else nn.CrossEntropyLoss() 

# Build a second network with the same configuration and load the trained
# weights into it, so the baseline `model` stays untouched.
# load_state_dict copies the values into new_model's own parameters.
state_dict = model.state_dict()
new_model = utils.Net(f, n,
            tied = tied,
            final_bias = final_bias,
            hidden_bias = hidden_bias,
            nonlinearity=nonlinearity,
            unit_weights=unit_weights,
            learnable_scale_factor=learnable_scale_factor,
            standard_magnitude=standard_magnitude,
            initial_scale_factor = initial_scale_factor,
            initial_embed = initial_embed,
            initial_bias = initial_bias)
new_model.load_state_dict(state_dict)

# The optimizer must hold the parameters of the network being trained.
# Building it over `model.parameters()` would leave `new_model` frozen
# (its gradients would be computed but never applied or cleared).
optimizer = torch.optim.SGD(new_model.parameters(), lr=initial_lr)

#Define a learning rate schedule
scheduler = utils.CustomScheduler(optimizer, warmup_steps, max_lr, decay_factor)

#Train
new_losses, new_weights_history, model_history = train_controlled(new_model, loader,criterion,optimizer,epochs,logging_loss,store_rate,rows=[1,18],scheduler=scheduler, end = 0.9)

# Train the model
# new_losses, new_weights_history, new_model_history = utils.train(model, loader, criterion, optimizer, epochs, logging_loss, plot_rate, store_rate, scheduler, lr_print_rate)

In [None]:
# Interactive view of the weight snapshots recorded during the controlled run;
# `store_rate` is the number of epochs between consecutive snapshots.
utils.plot_weights_interactive(new_weights_history, store_rate=store_rate)
