# Fig 6

In [5]:
from data_import import *
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from plotting import *

device = 'cuda:0'

# Only rows logged at this epoch are kept.
# NOTE(review): presumably the last training epoch of each run — confirm.
SELECTED_EPOCH = 499

# Values of `p` and `model_dim` that are retained.
# NOTE(review): 12 stands out next to 32/64 — confirm it is intended.
allowed_p = [32, 64, 12]
allowed_model_dim = [32, 64, 12]

# Load the result table and apply all row filters in one pass, so re-running
# this cell always rebuilds `df_orig` from the freshly loaded table.
df_orig = (
    load_result_table('feeds/phase_diagram_T32')
    .reset_index()
    .loc[lambda t: (t['epoch'] == SELECTED_EPOCH)
                   & t['p'].isin(allowed_p)
                   & t['model_dim'].isin(allowed_model_dim)]
)
figure_dir = FIGURE_DIR

# Per-configuration aggregates over all rows sharing the same CONFIG_COLS values.
# The groupby is built once and reused for the three reductions.
by_config = df_orig.groupby(CONFIG_COLS)
df = by_config.mean(numeric_only=True).reset_index()
df_std = by_config.std(numeric_only=True).reset_index()
df_max = by_config.max(numeric_only=True).reset_index()

# Sanity check, placed last so the notebook actually displays it:
# the distinct model_dim and p values that survived the filters.
df_orig.model_dim.unique(), df_orig.p.unique()

In [7]:
# Rank the rows of `df_orig` by validation accuracy (best first) and keep only
# those with p >= 32 and model_dim >= 32. `reset_index()` keeps the previous
# index as a column and numbers the rows by rank before the filter is applied.
df = (
    df_orig
    .sort_values('val_acc', ascending=False)
    .reset_index()
    .loc[lambda ranked: (ranked.p >= 32) & (ranked.model_dim >= 32)]
)
df[['model_dim', 'val_acc']]

Unnamed: 0,model_dim,val_acc
0,64,100.000000
1,64,100.000000
2,32,100.000000
3,64,100.000000
4,32,100.000000
...,...,...
115,32,77.434937
116,64,77.208873
117,64,76.574375
118,64,74.494682


In [9]:
# Singular-value spectra of the `fc1` weight matrix for every run in `df_orig`
# with val_acc >= 99, drawn on a 2x3 grid with one panel per entry of MODELS.
fig, axes = plt.subplots(2, 3, figsize=(12, 6), sharex=True)
axes = np.array(axes)

# (row, column) of the panel used for each MODELS key:
# top row = '+sftm' variants, bottom row = variants without it.
axes_assignment = {
        'linear': (1,2),
        'linear+sftm': (0,2),
        'dot': (1,1),
        'dot+sftm': (0,1),
        'dotBOS' : (1,0),
        'dotBOS+sftm': (0,0),
}

df = df_orig

# Row labels, shared x label and column titles.
axes[1,0].set_ylabel('no softmax', fontsize=9)
axes[0,0].set_ylabel('with softmax', fontsize=9)
axes[1,1].set_xlabel('singular value index $i$', fontsize=12)

axes[axes_assignment['linear+sftm']].set_title('[lin]\nlinear mixing', fontsize=9)
axes[axes_assignment['dot+sftm']].set_title('[dot]\ndot-product attention', fontsize=9)
axes[axes_assignment['dotBOS+sftm']].set_title('[bos]\ndot-product attention & BOS token', fontsize=9)

# Figure-level y label. The backslash is escaped because a bare '\s' is an
# invalid escape sequence in a non-raw string literal; the rendered text is unchanged.
fig.text(0.005, 0.5, 'singular value\n $\\sigma$', va='center', rotation='vertical', fontsize=12)


for name, model_config in MODELS.items():
    ax = axes[axes_assignment[name]]

    # Grey reference lines, added once per panel rather than once per plotted model.
    # NOTE(review): x = 32 presumably marks the T = 32 of this experiment — confirm.
    ax.axvline(32, c='grey', lw=1)
    ax.axvline(0, c='grey', lw=1)
    ax.axhline(0, c='grey', lw=1)

    # Runs belonging to this model variant that reached at least 99 validation accuracy.
    a = df[(df.attention_input == model_config['attention_input'])
            & (df.no_softmax == model_config['no_softmax'])
            & (df.dataset_type == model_config['dataset'])
            & (df.val_acc >= 99)]

    # Indexing by 'name' makes each row Series carry the run name as its `.name`.
    for _, config in a.set_index('name').iterrows():
        model, embd = load_model(config)
        matrix = model.fc1.weight.data.detach().cpu().numpy()

        # Only the singular values are needed; NumPy returns them already sorted
        # in descending order, so no U/Vt computation or extra sort is required.
        singular_values = np.linalg.svd(matrix, compute_uv=False)

        ax.plot(singular_values, c='blue', alpha=0.5)

fig.savefig(figure_dir / 'T32_singular_values_lt32_unnormalized_gt99.pdf')
plt.show()

feeds/phase_diagram_T32/fortuitous-lamp-1317-model:v0 | val_acc: 100.000


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/vivid-wish-1312-model:v0 | val_acc: 99.937


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/luminous-bao-1039-model:v0 | val_acc: 99.897


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/prosperous-pig-767-model:v0 | val_acc: 99.811


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/red-rabbit-498-model:v0 | val_acc: 99.807


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/enchanting-springroll-492-model:v0 | val_acc: 99.906


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/crimson-rat-226-model:v0 | val_acc: 99.983


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/beaming-dog-221-model:v0 | val_acc: 99.887


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/filigreed-horse-1336-model:v0 | val_acc: 100.000


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/glittering-bao-1330-model:v0 | val_acc: 99.667


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/abundant-peony-1058-model:v0 | val_acc: 100.000


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/incandescent-fish-782-model:v0 | val_acc: 100.000


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/radiant-mandu-777-model:v0 | val_acc: 99.121


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/red-pig-505-model:v0 | val_acc: 99.993


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/cheerful-bao-229-model:v0 | val_acc: 100.000


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/red-fuse-224-model:v0 | val_acc: 99.458


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/festive-orchid-1426-model:v0 | val_acc: 99.038


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/resplendent-snake-1422-model:v0 | val_acc: 99.455


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/brilliant-lamp-1153-model:v0 | val_acc: 99.950


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/vivid-moon-846-model:v0 | val_acc: 99.794


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/prosperous-goat-809-model:v0 | val_acc: 99.341


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/vivid-lamp-546-model:v0 | val_acc: 99.193


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/dazzling-horse-503-model:v0 | val_acc: 99.477


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/sweet-firecracker-250-model:v0 | val_acc: 99.744


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/scintillating-dragon-200-model:v0 | val_acc: 99.384


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/filigreed-dumpling-1413-model:v0 | val_acc: 99.115


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/beaming-fuse-1387-model:v0 | val_acc: 99.634


[34m[1mwandb[0m:   1 of 1 files downloaded.  


feeds/phase_diagram_T32/red-horse-1384-model:v0 | val_acc: 99.588


[34m[1mwandb[0m:   1 of 1 files downloaded.  
