In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
import consts
from data import Tokenizer
from pathlib import Path
from sklearn.manifold import TSNE

# Render every figure in this notebook at 800 dots per inch.
plt.rcParams.update({'figure.dpi': 800})

In [None]:
# Run configuration: together these values name the artifacts directory of the
# training run to inspect, and pick which saved checkpoint to load from it.
operation = 'add'
lr = 0.0001
seed = 42
optim_steps = 100000
model_epoch = 10000

# The artifacts directory is named after the run's hyper-parameters.
artifacts_path = Path(f"{operation}_lr_{lr}_seed_{seed}_optim_steps_{optim_steps}")
model_metrics = pd.read_csv(artifacts_path / 'model_metrics.csv')

# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on a
# machine without one; the weights are only converted to NumPy below, so CPU
# tensors are all that is needed. weights_only=True restricts unpickling to
# tensors and plain containers.
# NOTE: the name `model_weigts` (sic) is kept because later cells reference it.
model_weigts = torch.load(
    artifacts_path / f"model_epoch_{model_epoch}.pt",
    map_location='cpu',
    weights_only=True,
)

In [None]:
# Collect the token indices to leave out of the t-SNE: tokens that are not
# numbers, plus one number token that is never used.
tokenizer = Tokenizer()
special_token_indices = tokenizer.encode(consts.SPECIAL_CHARS)

# The index just before the first special token is excluded too: per the original
# note, the last token is unused because the operation is mod 97.
# NOTE(review): presumably the special tokens come right after the number tokens
# in the vocabulary, making this the last number - confirm against Tokenizer.
indices_to_remove = [special_token_indices[0] - 1] + special_token_indices

In [None]:
# Take the prediction layer's weight tensor from the checkpoint and convert it to
# a NumPy array (force=True allows the conversion for tensors that a plain
# .numpy() call would reject).
prediction_tensor = model_weigts['prediction.weight']
prediction_layer_weights = prediction_tensor.numpy(force=True)

# Drop the rows (axis 0) that belong to the excluded token indices.
weights_no_special = np.delete(
    arr=prediction_layer_weights,
    obj=indices_to_remove,
    axis=0,
)

In [None]:
# Project the remaining prediction-layer rows to 2-D with t-SNE; random_state is
# fixed so the embedding is reproducible across runs.
tsne = TSNE(n_components=2, init='pca', random_state=0, perplexity=60)
tsne_weights = tsne.fit_transform(weights_no_special)

n_points = weights_no_special.shape[0]

# One viridis colour per row, graded by row index.
colors = plt.cm.viridis(np.linspace(0, 1, n_points))

tsne_fig, tsne_ax = plt.subplots(figsize=(8, 8))
tsne_ax.scatter(tsne_weights[:, 0], tsne_weights[:, 1], color=colors)

# Label every point with its row index in weights_no_special.
for idx, (x, y) in enumerate(tsne_weights):
    tsne_ax.annotate(str(idx), (x, y), fontsize=6)

# Index offset between the two ends of each connecting line.
# Original note: special effects show up at 3, 6, 9, 12, ...
shift = 12

# Pair every index with the one `shift` positions further on, wrapping around
# at the end of the array.
connections = [(i, (i + shift) % n_points) for i in range(n_points)]

# Draw a faint line between the two points of each pair.
for i, j in connections:
    tsne_ax.plot(
        [tsne_weights[i, 0], tsne_weights[j, 0]],
        [tsne_weights[i, 1], tsne_weights[j, 1]],
        color='black',
        alpha=0.3,
    )

tsne_ax.set_xlabel('t-SNE dimension 1')
tsne_ax.set_ylabel('t-SNE dimension 2')
# Trailing semicolon keeps the Text repr out of the cell output.
tsne_ax.set_title('TSNE of Prediction Layer Weights');

In [None]:
# Repeat the connection plot for every shift from 1 to 16, one panel per shift.
fig, ax = plt.subplots(nrows=4, ncols=4, figsize=(14, 14), sharex=True, sharey=True)

n_points = weights_no_special.shape[0]
xs, ys = tsne_weights[:, 0], tsne_weights[:, 1]

# ax.flat walks the 4x4 grid row by row, so the k-th panel (1-based) shows shift k.
for shift_size, panel in enumerate(ax.flat, start=1):
    panel.scatter(xs, ys, color=colors, s=1)

    # Pair each index with the one `shift_size` positions later, wrapping around.
    connections = [(src, (src + shift_size) % n_points) for src in range(n_points)]

    # Draw a thin, faint line between the two points of each pair.
    for src, dst in connections:
        panel.plot(
            [xs[src], xs[dst]],
            [ys[src], ys[dst]],
            color='black',
            alpha=0.3,
            linewidth=0.5,
        )

    panel.set_title(f'Shift of {shift_size}')