# Fig. 03, Fig. 15

In [None]:
from data_import import *
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from plotting import *
# torch is imported explicitly here instead of relying on it leaking out of
# one of the star imports above.
import torch

# Device on which the model is evaluated. Prefer the first GPU, but fall back
# to the CPU so the notebook still runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
# Load the result table stored under 'feeds/phase_diagram_T32', flatten its
# index into ordinary columns and keep only the rows recorded at epoch 499
# (presumably the last training epoch -- confirm against the training setup).
df_orig = (
    load_result_table('feeds/phase_diagram_T32')
    .reset_index()
    .loc[lambda table: table['epoch'] == 499]
)
# Until a narrower selection is made, `df` refers to the same frame.
df = df_orig

In [None]:
# Keep only the runs with dataset_type 'backward_BOS', attention_input
# 'only_sem', no_softmax False and p == 2, ordered by ascending validation
# accuracy (so the best run ends up in the last row).
#
# All criteria are combined into ONE boolean mask that is aligned with
# `df_orig`. Chaining `df_orig[m1][m2]...` instead would apply masks built
# from the full frame to already-filtered frames, which makes pandas reindex
# every mask and emit "Boolean Series key will be reindexed" warnings.
selected_runs = (
    (df_orig.dataset_type == 'backward_BOS')
    & (df_orig.attention_input == 'only_sem')
    & ~df_orig.no_softmax
    & (df_orig.p == 2)
)
df = df_orig[selected_runs].sort_values('val_acc')

In [None]:
# `df` is ordered by ascending 'val_acc', so its last row is the run with the
# highest validation accuracy among the selected p=2 runs.
runs_by_name = df.set_index('name')
config = runs_by_name.iloc[-1]

# Load that run's model together with its embedding matrix and print its
# configuration.
model, all_embbedings = load_model(config)
print_specs(config)

In [None]:
def highlight_cell(x,y,color,ax):
    """Frame the cell at coordinate (x, y) of an image already drawn with imshow.

    A 1x1 unfilled rectangle whose lower-left corner sits half a cell below
    and to the left of (x, y) is added to ``ax``.

    Parameters
    ----------
    x, y : number
        Coordinates of the cell centre.
    color :
        Colour of the frame, passed on to ``plt.Rectangle``.
    ax :
        Axes the frame is added to.

    Returns
    -------
    The rectangle patch that was added.
    """
    lower_left = (x - .5, y - .5)
    frame = plt.Rectangle(lower_left, 1, 1, fill=False, color=color, lw=2)
    ax.add_patch(frame)
    return frame

def adapt_to(x,ax,data,has_BOS=False):
    """Annotate a matrix that was already drawn on ``ax`` with ``imshow``.

    The function
      * frames in red every cell (i, j) whose two positions hold the same
        token (``x[i] == x[j]``),
      * writes each entry of ``data`` into its cell as a rounded percentage,
      * optionally frames the first row and the first column in cyan,
      * labels both axes with one character per token of ``x``.

    Parameters
    ----------
    x : sequence of int
        Token ids of the displayed sequence, one per row/column.
    ax :
        Axes holding the image.
    data : 2-D array
        The displayed matrix; ``data[k, j]`` is printed at column j, row k.
    has_BOS : bool, optional
        If true, outline the first row and the first column.

    Notes
    -----
    Reads the module-level ``config['T']``: the token with id ``T`` is
    labelled ``'$'``, ids below 23 are labelled 'A'..'W', and any id in
    between gets a blank label.
    """
    # Red frames on all pairs of positions that carry the same token.
    for i, x_i in enumerate(x):
        for j, x_j in enumerate(x):
            if x_i == x_j:
                highlight_cell(i,j, color='red',ax=ax)

    n_rows, n_cols = data.shape[0], data.shape[1]

    # Print every entry as an integer percentage in the centre of its cell.
    for k in range(n_rows):
        for j in range(n_cols):
            ax.text(j, k, f'{int(np.round(data[k, j]*100))}', ha='center', va='center', color='white')

    if has_BOS:
        # plt.Rectangle takes (xy, width, height). The frame around the first
        # row must span all columns and the frame around the first column all
        # rows; the two extents were swapped before, which only went unnoticed
        # because the matrices drawn so far are square.
        first_row = plt.Rectangle((0-.5, 0-.5), n_cols,1, fill=False,color='cyan',lw=3)
        ax.add_patch(first_row)
        first_col = plt.Rectangle((0-.5, 0-.5), 1,n_rows, fill=False,color='cyan',lw=3)
        ax.add_patch(first_col)

    # Tick alphabet: index T must always map to '$'. Truncating before
    # padding keeps that true for T < 23 as well; for T >= 23 the resulting
    # string is identical to letters + blanks + '$'.
    n_tokens = int(config['T'])
    alpha = 'ABCDEFGHIJKLMNOPQRSTUVW'
    alpha = alpha[:n_tokens].ljust(n_tokens) + '$'
    tick_labels = [alpha[a] for a in x]
    ax.set_xticks(np.arange(len(x)), tick_labels,fontsize=13)
    ax.set_yticks(np.arange(len(x)), tick_labels,fontsize=13)
# Defaults for drawing attention matrices with imshow: the colour map comes
# from ATTENTION_SCORE_CMAP and the colour scale is capped at 1.0.
cmap = ATTENTION_SCORE_CMAP
vmax = 1.0

In [None]:
import matplotlib.gridspec as gridspec

# Three-panel figure, left to right: overlaps of the token embeddings, one
# attention matrix (with its colour bar), predictions of the feed-forward part
# along an interpolation between two embeddings.
fig = plt.figure(figsize=(16,3.5))
gs = gridspec.GridSpec(1, 7, width_ratios=[1.5,0.02,1.2,0.01,0.01,0.15,1.5],hspace=-1)

# Add subplots using the GridSpec layout. Grid columns 1 and 5 stay empty and
# only act as spacing. ax3 (the colour-bar axes) shares grid column 3 with ax2.
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 2:4])
ax3 = fig.add_subplot(gs[0, 3:5])
ax4 = fig.add_subplot(gs[0, 6])

axes= [ax1,ax2,ax3,ax4]

# ---------------------------------------------------------------------------
# Right panel: model.get_output along the interpolation
#   alpha * e_BOS + (1 - alpha) * e_token
# ---------------------------------------------------------------------------
ax = axes[3]
ax.set_title('feedforward prediction')
# 100 interpolation weights alpha in [0, 1]
x = torch.tensor(np.linspace(0,1.0,100), dtype=torch.float32)

# The last row of the embedding matrix is used as the BOS embedding.
BOS = all_embbedings[-1]

# discrete colours of the tab20b colour map, one per predicted value
colors = plt.get_cmap('tab20b').colors

# NOTE(review): the annotation below reads e_D and e_B, but the embeddings
# used here are rows 1 and 0, which the appendix cell's lettering
# ('BCDE' <-> rows 1..4) would call B and A -- confirm which one is intended.
embeddings =   x.reshape(-1,1) * BOS + (1-x.reshape(-1,1)) * all_embbedings[1].reshape(1,-1)
# second input of get_output: row 0 of the embedding matrix, repeated once per alpha
enc_output = all_embbedings[0].reshape(-1).repeat(len(x),1)
c = model.get_output(embeddings.to(device), enc_output.to(device)).argmax(dim=1).cpu().detach().numpy()
ax.scatter(x,c,c=[colors[v] for v in c])
ax.set_xlabel('$\\alpha$')
ax.set_ylabel(r'prediction')
ax.text(0.04,0.9,r"$f(\alpha e_{BOS}+(1-\alpha)e_D+e_B)$",transform=ax.transAxes,fontsize=14,backgroundcolor='white')

# Empty scatters serve purely as legend handles, one per value 1..seq_len-1.
# A dedicated loop variable is used so that the prediction array `c` is not
# overwritten.
for pred_value, color in enumerate(colors[:config['seq_len']]):
    if pred_value == 0:
        continue
    ax.scatter([],[],c=color, label=f"${pred_value}$")

ax.legend(loc='center right',title=r'$\hat{c}$')
# dashed reference lines at fixed values of alpha
for alpha_mark in (0.46, 0.38, 0.33):
    ax.axvline(alpha_mark, color='black', linestyle='--', alpha=0.5)

vmax = 1.0
cmap = ATTENTION_SCORE_CMAP

# Example sequence: the token with id config['T'] (labelled '$' by adapt_to)
# followed by ten ordinary tokens.
x_orig = [config['T'],1,1,2,1,1,1,3,4,4,4]

# ---------------------------------------------------------------------------
# Middle panel: attention matrix of the example sequence
# ---------------------------------------------------------------------------
ax = axes[1]
ax.set_title('attention matrix')
X = torch.tensor(x_orig).unsqueeze(0).to(device)
# The forward pass is run only for its side effect: the attention
# probabilities are read from model.attn_probs afterwards.
model(X)
data = model.attn_probs.detach().cpu().numpy()[0]
im = ax.imshow(data, cmap=cmap,vmin=0.0,vmax=vmax)
adapt_to(x_orig,ax,data,True)
fig.colorbar(im, cax=axes[2], label=r'$A_{ij}$')

# ---------------------------------------------------------------------------
# Left panel: histograms of the pairwise inner products of the embeddings
# ---------------------------------------------------------------------------
ax = axes[0]
ax.set_title('token embeddings')
data = all_embbedings.detach().cpu().numpy()
outer = data @ data.T
# drop the last row/column (BOS) to keep only token-token overlaps
outer_tokens = outer[:-1,:-1]
# True for off-diagonal entries, i.e. overlaps between two different tokens
mask = (np.eye(outer_tokens.shape[0]) == 0).flatten()
outer_tokens = outer_tokens.flatten()
# The labels are raw strings: '\l', '\i' and '\m' are invalid escape sequences
# in ordinary string literals. The rendered text is unchanged.
ax.hist(outer_tokens[~mask],bins=6,color='tab:red',density=True,label=r'same $t$ - $\langle e_{t},e_{t} \rangle$')
ax.hist(outer_tokens[mask],bins=100,color='grey',density=True,label=r'different $v$ - $\langle e_{t},e_{v} \rangle$')
ax.hist(outer[:-1,-1],density=True,color='tab:cyan',label=r'BOS - $\langle e_{t},e_{BOS} \rangle$')
ax.axvline(outer[-1,-1],label=r'$\langle e_{BOS},e_{BOS} \rangle$',color='tab:cyan')
ax.set_xlabel('overlap')
ax.set_ylabel('density')
ax.legend(title=r'overlap for $t \in \mathcal{T}$ with')

plt.savefig(FIGURE_DIR / 'BOS_figure.pdf', bbox_inches='tight')

In [None]:
# Appendix figure: 4x4 grid of interpolation plots. The row selects the
# embedding that is mixed with BOS, the column selects the embedding passed as
# second argument to model.get_output.
fig, axes = plt.subplots(4,4, figsize=(16,16),sharex=True,sharey=True)

# 100 interpolation weights alpha in [0, 1]
x = torch.tensor(np.linspace(0,1.0,100), dtype=torch.float32)
alpha_column = x.reshape(-1,1)

# The last row of the embedding matrix is used as the BOS embedding.
BOS = all_embbedings[-1]

# discrete colours of the tab20b colour map, one per predicted value
colors = plt.get_cmap('tab20b').colors

# Letter k of `letters` corresponds to row k+1 of the embedding matrix.
letters = 'BCDE'
last_row = len(letters) - 1
for i, mixed_letter in enumerate(letters):
    # alpha * e_BOS + (1 - alpha) * e_token depends only on the row, so it is
    # computed once per row instead of once per panel.
    embeddings =   alpha_column * BOS + (1-alpha_column) * all_embbedings[i+1].reshape(1,-1)
    embeddings_on_device = embeddings.to(device)
    for j, second_letter in enumerate(letters):
        ax = axes[i,j]
        enc_output = all_embbedings[j+1].reshape(-1).repeat(len(x),1)
        c = model.get_output(embeddings_on_device, enc_output.to(device)).argmax(dim=1).cpu().detach().numpy()
        ax.scatter(x,c,c=[colors[v] for v in c],marker='.')

        # The x axis is shared, so matplotlib shows tick labels only on the
        # bottom row; the axis label belongs there too (it used to be put on
        # the top row).
        if i == last_row:
            ax.set_xlabel('$\\alpha$')

        if j == 0:
            ax.set_ylabel(r'prediction')

        ax.set_title(f"$f(\\alpha e_{{BOS}}+(1 - \\alpha)e_{mixed_letter}+e_{second_letter})$")

        # dashed reference lines at fixed values of alpha
        for alpha_mark in (0.46, 0.38, 0.33):
            ax.axvline(alpha_mark, color='black', linestyle='--', alpha=0.5)

# Only the bottom-right panel carries a legend, so the empty scatters that act
# as legend handles (one per value 1..seq_len-1) are created on that panel
# alone instead of on all sixteen.
legend_ax = axes[-1,-1]
for pred_value, color in enumerate(colors[:config['seq_len']]):
    if pred_value == 0:
        continue
    legend_ax.scatter([],[],c=color, label=f"${pred_value}$")
legend_ax.legend(loc='center right',title=r'$\hat{c}$')
plt.savefig(FIGURE_DIR / 'FF_BOS_appendix.pdf', bbox_inches='tight')