# Toy Model Without Tied Embedding and Unembedding



In [None]:
import wandb
import numpy as np
from scipy.optimize import minimize
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from IPython.display import clear_output
import time
import plotly.graph_objs as go
import matplotlib as mpl
import torch.nn.functional as F
import random
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
mpl.style.use('seaborn-v0_8')
fontsize = 20
# Large default figures and one uniform font size for every text element.
mpl.rcParams.update({
    'figure.figsize': (15, 10),
    'font.size': fontsize,
    'xtick.labelsize': fontsize,
    'ytick.labelsize': fontsize,
    'legend.fontsize': fontsize,
    'axes.titlesize': fontsize,
    'axes.labelsize': fontsize,
})

In [None]:

# Default dimensions (the training cell below sets its own values):
# f = number of input/output features, n = width of the hidden layer.
f, n = 10, 2

# Create synthetic dataset
class SyntheticDataset(Dataset):
    """Random sparse dataset: each sample has exactly one feature set to a value in [0, 1).

    Sample i is a length-`f` zero vector in which one coordinate (chosen with
    torch.randint) is replaced by a torch.rand draw.
    """

    def __init__(self, num_samples, f):
        self.num_samples = num_samples
        self.f = f
        self.data = self.generate_data()

    def generate_data(self):
        samples = torch.zeros(self.num_samples, self.f)
        for row_idx in range(self.num_samples):
            # Draw the feature index first, then its value (keeps the RNG call order).
            hot = torch.randint(0, self.f, (1,))
            samples[row_idx, hot] = torch.rand(1)
        return samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx]

class SyntheticNormalised(Dataset):
    """Deterministic dataset of the `f` one-hot vectors; item i is the basis vector e_i."""

    def __init__(self, f):
        self.f = f
        self.data = self.generate_data()

    def generate_data(self):
        # A diagonal matrix of ones is the f x f identity: row i is one-hot at position i.
        return torch.diag(torch.ones(self.f))

    def __len__(self):
        return self.f

    def __getitem__(self, idx):
        return self.data[idx]

class Net(nn.Module):
    """Toy autoencoder: embed into `hidden_dim` dims, unembed back, then apply `nonlinearity`.

    forward(x) = nonlinearity(unembedding(embedding(x))), where the embedding has no bias
    and the unembedding has a bias only when `final_bias` is True.
    """

    def __init__(self, input_dim, hidden_dim, tied = True, final_bias = False, nonlinearity = F.relu):
        super().__init__()
        # Configuration. Assignment order is kept deliberately: an nn.Module passed as
        # `nonlinearity` (e.g. nn.ReLU()) is registered as a submodule on assignment.
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.nonlinearity = nonlinearity
        self.tied = tied
        self.final_bias = final_bias

        # Embedding: input_dim -> hidden_dim, no bias.
        self.embedding = nn.Linear(self.input_dim, self.hidden_dim, bias=False)
        # Unembedding: hidden_dim -> input_dim, bias controlled by `final_bias`.
        self.unembedding = nn.Linear(self.hidden_dim, self.input_dim, bias=final_bias)

        if not tied:
            return
        # Tie unembedding to the transpose of the embedding.
        # NOTE(review): nn.Parameter wraps the transposed view, so the new parameter
        # presumably shares storage with embedding.weight but is still a separate leaf
        # in named_parameters() / the optimizer — confirm this is the intended tying.
        transposed = self.embedding.weight.transpose(0, 1)
        self.unembedding.weight = torch.nn.Parameter(transposed)

    def forward(self, x):
        hidden = self.embedding(x)
        return self.nonlinearity(self.unembedding(hidden))


def plot_weights(weight_matrix, jitter = 0.05, normalised = False, save = False, epoch = None, show = None):
    """Draw each row of a weight matrix as an arrow from the origin, labelled by its row index.

    Only the first two columns of each row (weight_matrix[i, 0], weight_matrix[i, 1]) are drawn.

    Args:
        weight_matrix: 2-D array or tensor; one arrow per row.
        jitter: scale of the Gaussian offset added to each label position, so that
            labels of overlapping arrows separate.
        normalised: if True, scale every arrow to unit length and fix both axes to
            [-1.2, 1.2]. Rows with (near-)zero norm are drawn un-scaled instead of
            being divided by zero.
        save: if True, write the figure to "weights_{epoch}.png".
        epoch: required when `save` is True.
        show: display the figure before closing it. Defaults to `not save`. Without
            this, a call with save=False built a figure and closed it unseen.

    Raises:
        ValueError: if `save` is True and `epoch` is None (checked before any figure is
            created, so no figure is left open).
    """
    if save and epoch is None:
        raise ValueError("epoch must be given when save=True")
    if show is None:
        show = not save

    plt.figure(figsize=(8, 8))

    for i in range(weight_matrix.shape[0]):
        x, y = float(weight_matrix[i, 0]), float(weight_matrix[i, 1])
        if normalised:
            norm = (x ** 2 + y ** 2) ** 0.5
            if norm > 1e-10:  # same tolerance as normalise(); avoids 0/0 -> nan
                x, y = x / norm, y / norm
        plt.arrow(0, 0, x, y, head_width=0.05, head_length=0.1, fc='blue', ec='blue')
        # Two independent jitter draws (x first, then y) keep labels from stacking.
        label_x = x + jitter * torch.randn(1).item()
        label_y = y + jitter * torch.randn(1).item()
        plt.text(label_x, label_y, f"{i}", color='red', fontsize=12)

    if normalised:
        lo, hi = -1.2, 1.2
    else:
        # Square axes spanning the whole matrix (all columns) plus a margin.
        lo = float(weight_matrix.min()) - 0.5
        hi = float(weight_matrix.max()) + 0.5
    plt.xlim(lo, hi)
    plt.ylim(lo, hi)
    plt.grid()
    if save:
        plt.savefig(f"weights_{epoch}.png")
    if show:
        plt.show()
    plt.close()


def normalise(matrix, tolerance = 1e-10):
    """Return a copy of `matrix` with every row scaled to unit Euclidean norm.

    Args:
        matrix: 2-D array-like; each row is normalised independently.
        tolerance: rows whose norm is <= tolerance are returned as all zeros rather
            than divided by a (near-)zero norm.

    Returns:
        A new NumPy array of the same shape. Floating-point inputs keep their dtype.
        Integer or boolean inputs are promoted to float64, because an integer output
        array would truncate the normalised values (e.g. [3, 4] -> [0, 0]).
    """
    matrix = np.asarray(matrix)
    dtype = matrix.dtype if np.issubdtype(matrix.dtype, np.inexact) else np.float64
    out = np.zeros(matrix.shape, dtype=dtype)
    for i, row in enumerate(matrix):
        norm = (row.T @ row) ** 0.5
        if norm > tolerance:
            out[i] = row / norm
    return out

In [None]:
def _projection_dirs(x_dir, y_dir, dim):
    """Resolve plot_weights_interactive's (x_dir, y_dir) arguments into two length-`dim` vectors."""
    def is_index(value):
        return isinstance(value, (int, np.integer))

    if x_dir is None and y_dir is None:
        x_idx, y_idx = 0, 1
    elif is_index(x_dir) and is_index(y_dir):
        x_idx, y_idx = x_dir, y_dir
    else:
        if x_dir is None or y_dir is None or is_index(x_dir) or is_index(y_dir):
            raise TypeError("x_dir and y_dir must both be None, both ints, or both vectors")
        x_vec, y_vec = np.asarray(x_dir), np.asarray(y_dir)
        if x_vec.shape != (dim,) or y_vec.shape != (dim,):
            raise ValueError(
                f"x_dir and y_dir must have length {dim} (the smaller dimension of the weight matrix)"
            )
        return x_vec, y_vec

    # Basis-vector case: project onto the chosen hidden axes.
    x_vec, y_vec = np.zeros(dim), np.zeros(dim)
    x_vec[x_idx] = 1
    y_vec[y_idx] = 1
    return x_vec, y_vec


def plot_weights_interactive(weights_history, x_dir=None, y_dir=None, transpose=False, image_store_rate=None):
    """Plot each recorded weight matrix as a cloud of 2-D points, with a slider over snapshots.

    Every 2-D entry of `weights_history` gets its own Plotly figure. Each snapshot is
    oriented so its longer axis indexes points (wide matrices are transposed), and point
    i is row i projected onto `x_dir` (x coordinate) and `y_dir` (y coordinate).

    Args:
        weights_history: dict mapping parameter name -> list of snapshots (as built by
            train()). Entries that are empty or not 2-D (e.g. bias vectors) are skipped.
        x_dir, y_dir: both None (use hidden axes 0 and 1), both ints (use those hidden
            axes), or both vectors (list / ndarray) of length min(matrix.shape).
            They are resolved separately for every entry of `weights_history`.
        transpose: unused; orientation is chosen automatically. Kept for compatibility.
        image_store_rate: epochs between consecutive snapshots, used only for slider
            labels. Defaults to the length of the first history list.

    Raises:
        TypeError / ValueError: if x_dir / y_dir are of mismatched kinds or wrong length.
    """
    if image_store_rate is None:
        first_history = next(iter(weights_history.values()), [])
        image_store_rate = len(first_history)

    for key, weight_list in weights_history.items():
        if len(weight_list) == 0 or np.ndim(weight_list[0]) != 2:
            continue

        # Tall orientation: rows are points, columns are the hidden (projection) axes.
        oriented = [w.T if w.shape[1] > w.shape[0] else w for w in weight_list]
        dim = oriented[0].shape[1]
        x_vec, y_vec = _projection_dirs(x_dir, y_dir, dim)
        # One symmetric axis range for all snapshots, so the slider does not rescale.
        max_value = np.max([np.abs(weight_matrix).max() for weight_matrix in weight_list])

        fig = go.Figure()
        for i, weight_matrix in enumerate(oriented):
            x_values = weight_matrix @ x_vec
            y_values = weight_matrix @ y_vec
            labels = list(range(len(x_values)))
            scatter = go.Scatter(x=x_values, y=y_values, mode='markers+text', text=labels,
                                 textposition='top center', marker=dict(size=8), visible=False,
                                 name=f'Epoch {i * image_store_rate}')
            fig.add_trace(scatter)

        fig.data[0].visible = True

        fig.update_xaxes(title_text='X Value', range=[-max_value * 1.1, max_value * 1.1])
        fig.update_yaxes(title_text='Y Value', range=[-max_value * 1.1, max_value * 1.1])

        # One slider step per snapshot; each step makes exactly one trace visible.
        steps = []
        for i in range(len(oriented)):
            step = dict(
                method='restyle',
                args=['visible', [False] * len(fig.data)],
                label=f'Epoch {i * image_store_rate}'
            )
            step['args'][1][i] = True
            steps.append(step)

        slider = dict(
            active=0,
            currentvalue={"prefix": f"{key} - "},
            pad={"t": 50},
            steps=steps
        )

        fig.update_layout(sliders=[slider], width=800, height=800)
        fig.show()


In [None]:
def train(model, loader, criterion, optimizer, epochs, logging_loss, plot_rate, image_store_rate, scheduler):
    """Train `model` to reconstruct its input and record losses and weight snapshots.

    Args:
        model: module whose output is compared with its own input batch.
        loader: iterable of batches; len(loader) is used to average the epoch loss.
        criterion: loss called as criterion(outputs, batch).
        optimizer: optimizer over model.parameters().
        epochs: number of passes over `loader`.
        logging_loss: if True, record the mean loss per epoch and plot the loss curve
            every `plot_rate` epochs.
        plot_rate: epochs between loss plots (only used when `logging_loss`).
        image_store_rate: epochs between weight snapshots.
        scheduler: LR scheduler, stepped once at the end of every epoch (after that
            epoch's optimizer updates), or None for a fixed learning rate.

    Returns:
        (losses, weights_history). `losses` holds one mean loss per epoch (empty unless
        `logging_loss`). `weights_history` maps each parameter name to a list of NumPy
        copies: the initial weights, then one after every `image_store_rate` epochs.
    """
    weights_history = {k: [v.detach().numpy().copy()] for k, v in dict(model.named_parameters()).items()}
    losses = []
    for epoch in tqdm(range(epochs)):
        total_loss = 0
        for batch in loader:
            optimizer.zero_grad()
            outputs = model(batch)
            loss = criterion(outputs, batch)  # autoencoder: the target is the input
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        if logging_loss:
            losses.append(avg_loss)
            if (epoch + 1) % plot_rate == 0:
                plt.figure(figsize=(5, 5))
                plt.plot(losses)
                plt.show()
        if (epoch + 1) % image_store_rate == 0:
            for k, v in dict(model.named_parameters()).items():
                weights_history[k].append(v.detach().numpy().copy())
        # Advance the LR schedule once per epoch. This must sit inside the epoch loop;
        # stepping only after the loop would leave the learning rate constant for the
        # whole run.
        if scheduler is not None:
            scheduler.step()
    return losses, weights_history

def set_seed(seed):
    """Seed the torch, NumPy and `random` RNGs, and make cuDNN deterministic."""
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    # Deterministic cuDNN kernels and no autotuned algorithm selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False




In [None]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import numpy as np

def visualize_matrices_with_slider(matrices, rate, const_colorbar=False):
    """Show `matrices` as heatmaps in one Plotly figure, with a slider selecting which is visible.

    Args:
        matrices: sequence of 2-D arrays; entry i is labelled "Epoch {i * rate}".
        rate: epochs between consecutive matrices (used only for labels).
        const_colorbar: if True, every heatmap uses the global min/max over all matrices,
            so colours are comparable across slider positions.
    """
    # None lets Plotly autoscale each heatmap on its own.
    zmin, zmax = None, None
    if const_colorbar:
        zmin = np.min([np.min(mat) for mat in matrices])
        zmax = np.max([np.max(mat) for mat in matrices])

    fig = go.Figure()
    for idx, mat in enumerate(matrices):
        # Rows are reversed so that row 0 of the matrix appears at the top.
        fig.add_trace(go.Heatmap(z=mat[::-1], colorscale='magma', showscale=True,
                                 zmin=zmin, zmax=zmax))
        trace = fig.data[idx]
        trace.visible = idx == 0  # only the first heatmap is shown initially
        trace.name = f'Epoch {idx * rate}'

    count = len(matrices)
    steps = []
    for idx in range(count):
        visibility = [False] * count
        visibility[idx] = True
        steps.append(dict(method="restyle", args=["visible", visibility], label=f'Epoch {idx * rate}'))

    fig.update_layout(
        sliders=[dict(active=0, currentvalue={"prefix": "Displaying: "}, pad={"t": 50}, steps=steps)],
        height=800,
        width=800,
    )
    fig.show()


def generate_matrix_list(weights_history):
    """Return, for every recorded snapshot, the product unembedding.weight @ embedding.weight."""
    embeddings = weights_history['embedding.weight']
    unembeddings = weights_history['unembedding.weight']
    products = []
    for idx, emb in enumerate(embeddings):
        products.append(unembeddings[idx] @ emb)
    return products

def np_gelu(matrix):
    """Apply GELU element-wise (via torch) and return the result as a NumPy array."""
    activated = F.gelu(torch.tensor(matrix))
    return activated.detach().numpy()

def np_relu(matrix):
    """Apply ReLU element-wise (via torch) and return the result as a NumPy array."""
    activated = F.relu(torch.tensor(matrix))
    return activated.detach().numpy()

In [None]:
# your chosen seed
chosen_seed = 42
set_seed(chosen_seed)


# Configure the hyperparameters
f = 50
n = 2
nonlinearity = nn.ReLU()
tied = False
final_bias = False


batch_size = f  # Full-batch gradient descent (SyntheticNormalised(f) has exactly f samples)
epochs = 50000
logging_loss = True

# Scheduler params
learning_rate = 5
step_size = epochs // 100  # 100 scheduler steps over the whole run
# Per-step decay chosen so that gamma**100 == 1/3000: the learning rate shrinks by a
# total factor of 3000 across the 100 steps. The previous value `3000**(-100)`
# underflows to 0.0 in floating point, which would set the learning rate to 0 at the
# first scheduler step.
gamma = 3000 ** (-1 / 100)


# Integer intervals (floor division) so epoch labels read e.g. "Epoch 50" rather than
# "Epoch 50.0"; `(epoch + 1) % rate == 0` checks give the same result either way.
image_store_rate = epochs // 1000
plot_rate = epochs // 5


# Instantiate synthetic dataset
dataset = SyntheticNormalised(f)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

# Instantiate the model
model = Net(f, n,
            tied=tied,
            final_bias=final_bias,
            nonlinearity=nonlinearity)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Define a learning rate schedule
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)


# Train the model
losses, weights_history = train(model, loader, criterion, optimizer, epochs, logging_loss, plot_rate, image_store_rate, scheduler)


Look at the following plots at around 8000–9000 epochs: the loss plateaus at roughly 0.0183 (possibly similar to the tied case — worth checking). Then look at what happens afterwards.

In [None]:
plot_weights_interactive(weights_history=weights_history, image_store_rate=image_store_rate)

In [None]:
matrix_list = generate_matrix_list(weights_history=weights_history)

Again, compare what happens here at 8500 epochs to the final behaviour!

In [None]:
# Takes about 15 seconds on the author's CPU
visualize_matrices_with_slider(matrices=matrix_list, rate=image_store_rate, const_colorbar=True)

In [None]:
visualize_matrices_with_slider(list(map(np_relu, matrix_list)), image_store_rate, const_colorbar=True)
