In [22]:
import torch
import torch.nn as nn
import copy

import os
import sys
sys.path.append('./train')

In [107]:
from train.dset import target_func

In [91]:
# get keys
from train.dset import NumberSequenceDataset # when used, change config path in dset.py.
# get dataset from dset
dataset = NumberSequenceDataset(1000,10)

In [203]:
# load model
model = torch.load('./model/model.pt')
model_new = torch.load('model/model_self_regex0_42.0_92992.pt')

In [204]:
print(model_new['q_matrix.weight'].size())
# copy embedding
embedding = copy.deepcopy(model_new['embedding.weight'])
# copy qkv
q_n = copy.deepcopy(model_new['q_matrix.weight'])
k_n = copy.deepcopy(model_new['k_matrix.weight'])
v_n = copy.deepcopy(model_new['v_matrix.weight'])
# get qs and ks for each token
t_qs_n = torch.matmul(embedding, q_n.transpose(0, 1))
t_ks_n = torch.matmul(embedding, k_n.transpose(0, 1))
t_vs_n = torch.matmul(embedding, v_n.transpose(0, 1))
# get keys
l_datas = len(dataset)
# global keys = keys[:-5] divide into 5 sets and get average
# local keys = the last 5 keys
keys = []

torch.Size([8, 8])


In [205]:
mlp_0_weight = copy.deepcopy(model_new['mlp.0.weight'])
mlp_0_bias = copy.deepcopy(model_new['mlp.0.bias'])
mlp_2_weight = copy.deepcopy(model_new['mlp.2.weight'])
mlp_2_bias = copy.deepcopy(model_new['mlp.2.bias'])

In [206]:
dataset = NumberSequenceDataset(1000,10,func_name='regex0')

In [207]:
# keys: 1000 * (5 * 4)
# last 4 keys + 1 global key as average
keys = []
t_ks_n_np = t_ks_n.clone()
for i in range(l_datas):
    data = dataset[i]
    # last non-zero index
    len_data = data.nonzero().size()[0]
    # last 4 keys, if not enough, pad with 0
    if len_data >= 4:
        # last 4 keys from last non-zero index, key[i] = key for token i, count from max(3, len_data - 4)
        keys.append([t_ks_n_np[i] for i in data[len_data - 4:len_data]])
    else:
        keys.append([t_ks_n_np[i] for i in data[:4]])
    # global key. calculate average of all keys that are not 0
    glob_key = sum([t_ks_n_np[i] for i in data if i != 0]) / len_data
    keys[i].extend([glob_key])
    # make keys[i] a 1D tensor
    k_1d = torch.tensor([])
    for k in keys[i]:
        k_1d = torch.cat((k_1d, k), dim=0)
    keys[i] = k_1d
print(len(keys))
        

1000


In [208]:
def get_keys_arr(data):
    # last non-zero index
    len_data = data.nonzero().size()[0]
    # last 4 keys, if not enough, pad with 0
    kout = []
    if len_data >= 4:
        # last 4 keys from last non-zero index, key[i] = key for token i, count from max(3, len_data - 4)
        kout=[t_ks_n_np[i] for i in data[len_data - 4:len_data]]
    else:
        kout=[t_ks_n_np[i] for i in data[:4]]
    # global key. calculate average of all keys that are not 0
    glob_key = sum([t_ks_n_np[i] for i in data if i != 0]) / len_data
    kout.extend([glob_key])
    # make keys[i] a 1D tensor
    k_1d = torch.tensor([])
    for k in kout:
        k_1d = torch.cat((k_1d, k), dim=0)
    return k_1d

In [209]:
# get model output from keys
def mod_out(data):
    k_data = [t_ks_n[i] for i in data]
    #print('k_data', k_data)
    last_non_zero = data.nonzero().size()[0] - 1
    #print('last_non_zero', last_non_zero)
    q_data = t_qs_n[data[last_non_zero]]
    #print('q_data', q_data)
    v_data = [t_vs_n[i] for i in data]
    #print('v_data', v_data)
    # compute attention
    t_att = torch.matmul(q_data, torch.stack(k_data).transpose(0, 1))
    #print('t_att', t_att)
    # get v
    v_s = torch.stack(v_data)
    #print('v_s', v_s)
    v_out = torch.matmul(t_att, v_s)
    #print('v_out', v_out)
    # get output
    output_1 = torch.matmul(v_out, mlp_0_weight.transpose(0, 1)) + mlp_0_bias
    output_1_relu = nn.functional.relu(output_1)
    output_2 = torch.matmul(output_1_relu, mlp_2_weight.transpose(0, 1)) + mlp_2_bias
    softmaxed = nn.functional.softmax(output_2, dim=0)
    return softmaxed
    
corr = 0
for i in range(1000):
    if mod_out(dataset[i]).argmax(dim=0) == target_func(dataset[i], 'regex0'):
        corr += 1
print(corr)

1000


In [210]:
# k-means clustering keys
from sklearn.cluster import KMeans
import numpy as np

keys_np = np.array(keys)
km = KMeans(n_clusters=20, random_state=0, n_init=10, max_iter=1000).fit(keys_np)
# get cluster centers
centers = km.cluster_centers_
print(centers) # these are states

[[-1.2958064e+00  1.1863539e+00  1.6428181e+00  1.9219532e+00
   6.5385967e-01  1.5705411e+00  6.6882133e-01 -1.1019793e+00
  -1.2958062e+00  1.1863540e+00  1.6428181e+00  1.9219534e+00
   6.5385973e-01  1.5705410e+00  6.6882122e-01 -1.1019791e+00
   4.5369685e-01  4.5017153e-02  2.8765538e-01 -6.4942002e-02
  -2.2207861e+00 -6.9946349e-02  1.2149150e+00 -5.9931302e-01
   4.5369655e-01  4.5017123e-02  2.8765532e-01 -6.4941823e-02
  -2.2207859e+00 -6.9946468e-02  1.2149153e+00 -5.9931302e-01
  -4.2105490e-01  6.1568558e-01  9.6523666e-01  9.2850542e-01
  -7.8346336e-01  7.5029743e-01  9.4186813e-01 -8.5064620e-01]
 [-1.2958063e+00  1.1863539e+00  1.6428181e+00  1.9219532e+00
   6.5385938e-01  1.5705411e+00  6.6882122e-01 -1.1019793e+00
   4.5369655e-01  4.5017123e-02  2.8765529e-01 -6.4942002e-02
  -2.2207863e+00 -6.9946468e-02  1.2149152e+00 -5.9931302e-01
   4.5369655e-01  4.5017123e-02  2.8765529e-01 -6.4942062e-02
  -2.2207863e+00 -6.9946468e-02  1.2149152e+00 -5.9931302e-01
  -1.29

  return fit_method(estimator, *args, **kwargs)


In [4]:
# define state transition matrix
# vocab size * 20 * 20

voc_size = 5
trans_mat = np.zeros((voc_size, 20, 20))

for i in range(l_datas):

embedding.weight
q_matrix.weight
k_matrix.weight
v_matrix.weight
mlp.0.weight
mlp.0.bias
mlp.2.weight
mlp.2.bias


In [5]:
print(model['embedding.weight'].size())

torch.Size([11, 32])


In [6]:
import copy

In [7]:
# copy embedding
embedding = copy.deepcopy(model['embedding.weight'])
# copy qkv
q = copy.deepcopy(model['q_matrix.weight'])
k = copy.deepcopy(model['k_matrix.weight'])
v = copy.deepcopy(model['v_matrix.weight'])

In [8]:
# get qs and ks for each token
t_qs = torch.matmul(embedding, q.transpose(0, 1))
t_ks = torch.matmul(embedding, k.transpose(0, 1))


In [9]:
t_ks.size()

torch.Size([11, 32])

In [10]:
input_seq = [9, 5, 1, 6, 2, 7, 3, 8, 0, 0]
# compute attention
q_s = [t_qs[i] for i in input_seq]
k_s = [t_ks[i] for i in input_seq]

t_att = torch.matmul(torch.stack(q_s), torch.stack(k_s).transpose(0, 1))
print(t_att.size())

torch.Size([10, 10])


In [11]:
# get v
t_vs = torch.matmul(embedding, v.transpose(0, 1))
t_vs.size()

torch.Size([11, 32])

In [12]:
v_s = [t_vs[i] for i in input_seq]
v_s = torch.stack(v_s)
v_s.size()

torch.Size([10, 32])

In [13]:
print(t_att)

tensor([[ -5.3959,  -2.9583,  -2.8686,  -2.5149,  -1.1299,  -0.0956,  -2.6343,
          -3.0492,   0.0000,   0.0000],
        [ -2.7484,  -6.9906,  -8.2938,  -4.3239,  -9.4224,   2.6181,   7.0164,
           1.0293,   0.0000,   0.0000],
        [ -4.2370,  -0.1084,  -5.9481,  -5.5022,  -6.7465,   1.6919,   2.9127,
          -5.9568,   0.0000,   0.0000],
        [ -2.9856,  -1.3564,   0.2384, -10.7287,  -8.0487,   4.9876,   6.5383,
          -5.2695,   0.0000,   0.0000],
        [ -2.6035,   2.2747,  -2.9804,  -4.2712,  -6.6119,   9.1776,   6.0342,
          -4.5535,   0.0000,   0.0000],
        [ -5.7663,   0.6663,  -3.3468,  -1.6183,  -2.1964,   5.6808,   5.1520,
          -2.8558,   0.0000,   0.0000],
        [ -5.4674,   2.0909,  -3.1084,  -3.0876,  -1.1979,   2.3100,   6.9338,
          -5.8431,   0.0000,   0.0000],
        [-10.6975,  -3.1499,  -5.9383, -11.6151,  -9.2356,   3.3039,   4.3325,
         -11.9163,   0.0000,   0.0000],
        [  0.0000,   0.0000,   0.0000,   0.0000,

In [14]:
v_out = torch.matmul(t_att, v_s)

In [15]:
print(v_out.size())

torch.Size([10, 32])


In [16]:
v_out_new = v_out.sum(dim=1)
print(v_out_new.size())

torch.Size([10])


In [17]:
# get output
mlp_0_weight = copy.deepcopy(model['mlp.0.weight'])
mlp_0_bias = copy.deepcopy(model['mlp.0.bias'])
mlp_1_weight = copy.deepcopy(model['mlp.2.weight'])
mlp_1_bias = copy.deepcopy(model['mlp.2.bias'])


In [18]:
print (mlp_0_weight.size())
print (mlp_0_bias.size())
print (mlp_1_weight.size())
print (mlp_1_bias.size())

torch.Size([64, 32])
torch.Size([64])
torch.Size([11, 64])
torch.Size([11])


In [19]:
v_out_new_new = v_out.sum(dim=0)

In [20]:
print(v_out_new_new.size())

torch.Size([32])


In [21]:
output_1 = torch.matmul(v_out_new_new, mlp_0_weight.transpose(0, 1)) + mlp_0_bias
output_1.size()

torch.Size([64])

In [22]:
output_1_relu = nn.functional.relu(output_1)

In [23]:
output_2 = torch.matmul(output_1_relu, mlp_1_weight.transpose(0, 1)) + mlp_1_bias

In [24]:
print(output_2.size())

torch.Size([11])


In [25]:
print(output_2)

tensor([ -5.0723,   8.9644,  14.0118,  30.5182, -21.5813,   1.5899,  19.3786,
         13.1759, -19.6485, -18.1134, -44.3425])


In [26]:
print(output_2.argmax(dim=1))

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)