In [10]:
import torch
import numpy as np
from train.dset import get_dataloader, target_func
from train.model import SimpleAttentionModel, SelfAttentionModel
from train.config import read_config
from train.main import pred_next_n_digits, get_correct_n_digits

# read config
config = read_config('train/config.json')

train_loader, test_loader = get_dataloader(
    config.num_samples, config.seq_length, config.test_split, config.func_name, config.batch_size
)

# define model
model = {
    "simple": SimpleAttentionModel,
    "self": SelfAttentionModel
}[config.model_name](config.vocab_size, config.embed_dim, config.mlp_hidden_dim)

# define loss function and optimizer
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

In [12]:
# train
model.train()
torch.autograd.set_detect_anomaly(True)
for epoch in range(config.num_epochs):
    for batch_idx, x in enumerate(train_loader):
        optimizer.zero_grad()
        if config.pred_num == 1:
            y_pred = model(x)
            loss = loss_fn(y_pred, torch.LongTensor([target_func(seq, config.func_name) for seq in x]))
        else:
            y_pred = pred_next_n_digits(x, config.pred_num, model)
            y_correct = get_correct_n_digits(x, config.pred_num, config)
            loss = loss_fn(y_pred, y_correct)
        loss.backward()
        optimizer.step()
        if epoch % 10 == 0 and batch_idx == 0:
            print(f'Epoch: {epoch}, Loss: {loss.item():.4f}')

# test
model.eval()
num_correct = 0
num_samples = 0
for x in test_loader:
    y_pred = model(x)
    # get the index of the max log-probability
    _, y_pred = y_pred.max(dim=1)
    y = torch.LongTensor([target_func(seq, config.func_name) for seq in x])
    num_correct += (y_pred == y).sum()
    num_samples += y.size(0)
acc = float(num_correct) / num_samples * 100
print(f'Accuracy: {acc:.4f}%')

Epoch: 0, Loss: 0.9463
Epoch: 10, Loss: 1.1086
Epoch: 20, Loss: 1.4879
Epoch: 30, Loss: 1.0979
Epoch: 40, Loss: 0.9626
Epoch: 50, Loss: 1.0192
Epoch: 60, Loss: 1.0360
Epoch: 70, Loss: 1.1532
Epoch: 80, Loss: 1.1107
Epoch: 90, Loss: 1.1177
Epoch: 100, Loss: 0.9300
Epoch: 110, Loss: 1.1254
Epoch: 120, Loss: 1.0606
Epoch: 130, Loss: 1.2046
Epoch: 140, Loss: 1.0334
Epoch: 150, Loss: 1.0086
Epoch: 160, Loss: 0.9189
Epoch: 170, Loss: 1.1077
Epoch: 180, Loss: 0.9110
Epoch: 190, Loss: 1.0148
Accuracy: 91.0000%


In [13]:
model.eval()
for x in test_loader:
    y_pred = model(x)
    _, y_pred = y_pred.max(dim=1)
    print(f'Input: {x[0].tolist()}')
    print(f'Ground truth: {target_func(x[0].tolist(), config.func_name)}')
    print(f'Prediction: {y_pred[0].item()}')
print('---')

Input: [1, 4, 1, 0, 0, 0, 0, 0, 0, 0]
Ground truth: 2
Prediction: 2
Input: [3, 2, 3, 1, 2, 0, 0, 0, 0, 0]
Ground truth: 3
Prediction: 3
Input: [3, 2, 0, 0, 0, 0, 0, 0, 0, 0]
Ground truth: 3
Prediction: 1
Input: [1, 4, 0, 0, 0, 0, 0, 0, 0, 0]
Ground truth: 1
Prediction: 1
Input: [2, 4, 1, 0, 0, 0, 0, 0, 0, 0]
Ground truth: 2
Prediction: 1
Input: [3, 2, 3, 0, 0, 0, 0, 0, 0, 0]
Ground truth: 1
Prediction: 1
Input: [3, 2, 3, 1, 2, 3, 0, 0, 0, 0]
Ground truth: 1
Prediction: 1
Input: [1, 4, 1, 2, 3, 0, 0, 0, 0, 0]
Ground truth: 4
Prediction: 4
Input: [2, 4, 1, 2, 3, 4, 0, 0, 0, 0]
Ground truth: 1
Prediction: 1
Input: [1, 4, 1, 2, 0, 0, 0, 0, 0, 0]
Ground truth: 3
Prediction: 3
---


In [69]:
dig_str = ""
while dig_str != "q":
    dig_str = input("Enter a digit sequence (or q to quit): ")
    if dig_str == "q":
        break
    dig_list = [int(dig) for dig in dig_str.split()]
    dig_list = dig_list + [0] * (config.seq_length - len(dig_list))
    x = torch.LongTensor(dig_list).unsqueeze(0)
    y_pred = model(x)
    _, y_pred = y_pred.max(dim=1)
    print(f'Predicted next digit: {y_pred.item()}')

Predicted next digit: 3


In [14]:
# km_model
from sklearn.cluster import KMeans
from train.model import PositionalEncoding

# get model key_lsts as output states
keys_lst = []
vals_lst = []
logits_lst = []

for x in train_loader:
    keys_x = model.k_matrix(model.embedding(x)).detach().numpy(
    # [batch_size, seq_len, embed_dim] -> [batch_size, seq_len * embed_dim]
    ).reshape(x.shape[0], -1)
    keys_lst = keys_lst + keys_x.tolist()
    #print(keys_x.shape)
    vals_x = model.v_matrix(model.embedding(x)).detach().numpy()
    vals_lst = vals_lst + vals_x.tolist()
    #print(vals_x.shape)
    logits_x = model(x).detach().numpy()
    logits_lst = logits_lst + logits_x.tolist()
    #print(logits_x.shape)
    
# turn list into np array
keys_lst = np.array(keys_lst)
vals_lst = np.array(vals_lst)
logits_lst = np.array(logits_lst)
#print(keys_lst.shape)
#print(vals_lst.shape)
#print(logits_lst.shape)
    

# kmeans
n_clusters = 7
kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(keys_lst)
#print(kmeans.labels_)
#print(kmeans.cluster_centers_)

# get transition matrix
centers = kmeans.cluster_centers_
trans_matrix = np.zeros((config.vocab_size, n_clusters, n_clusters))
trans_matrix[0] = np.identity(n_clusters)
for x in train_loader:
    for data in x:
        from_keys = model.k_matrix(model.embedding(data)).detach().numpy()[0] # [seq_len, embed_dim]
        from_keys = from_keys.reshape(1, -1)
        from_keys = from_keys.astype(np.float16)
        from_state = kmeans.predict(from_keys)[0]
        first_zero_idx = np.where(data == 0)[0][0]
        to_fill_in = model(data.unsqueeze(0)).argmax().item()
        data[first_zero_idx] = to_fill_in
        to_keys = model.k_matrix(model.embedding(data)).detach().numpy()[0] # [seq_len, embed_dim]
        to_keys = to_keys.reshape(1, -1)
        to_keys = to_keys.astype(np.float16)
        to_state = kmeans.predict(to_keys)[0]
        trans_matrix[to_fill_in, from_state, to_state] += 1
for mat in trans_matrix:
    for row in range(mat.shape[0]):
        if mat[row].sum() != 0:
            mat[row] = mat[row] / mat[row].sum()
        else:
            #mat[row] = np.ones(mat[row].shape) / mat[row].shape[0]
            pass
print(trans_matrix)

# get avg value for each cluster
avg_vals = np.zeros((n_clusters, vals_lst.shape[1], vals_lst.shape[2]))
for i in range(n_clusters):
    avg_vals[i] = np.mean(vals_lst[kmeans.labels_ == i], axis=0)
#print(avg_vals)

# get avg logits for each cluster
avg_logits = np.zeros((n_clusters, logits_lst.shape[1]))
for i in range(n_clusters):
    avg_logits[i] = np.mean(logits_lst[kmeans.labels_ == i], axis=0)
print(avg_logits)

  super()._check_params_vs_input(X, default_n_init=10)


[[[1. 0. 0. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0. 0.]
  [0. 0. 0. 1. 0. 0. 0.]
  [0. 0. 0. 0. 1. 0. 0.]
  [0. 0. 0. 0. 0. 1. 0.]
  [0. 0. 0. 0. 0. 0. 1.]]

 [[1. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0. 0.]
  [0. 0. 0. 1. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]]

 [[1. 0. 0. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0. 0. 0.]
  [0. 0. 0. 1. 0. 0. 0.]
  [0. 0. 0. 1. 0. 0. 0.]
  [0. 0. 0. 0. 1. 0. 0.]
  [0. 0. 0. 0. 0. 1. 0.]
  [0. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]]

 [[1. 0. 0. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0. 0.]
  [0. 0. 0. 1. 0. 0. 0.]
  [0. 0. 0. 0. 1. 0. 0.]
  [0. 0. 0. 0. 0. 1. 0.]
  [0. 0. 0. 0. 0. 0. 1.]]]
[[4.91380848e-28 3.35268349e-01 2.87170565e-01 7.78105125e-02
  2.99750580e-01]
 [2.77528095e-14 1.40905064e-02 4.2

In [20]:
from automata import predict_by_automata,get_transition_matrix, km_model
from train.dset import NumberSequenceDataset
from read_model import get_model_out
from tqdm import tqdm
dataset = NumberSequenceDataset(config.num_samples, config.seq_length, config.func_name)
#model_dct = model.state_dict()
#kmeans_mdl = km_model(model_dct, dataset)
#trans_mat = get_transition_matrix(model_dct, dataset, kmeans_mdl, 5)

num_corr = 0
# use tqdm to show progress bar
test_range = 500
for i in tqdm(range(test_range)):
    data = dataset[i]
    #print("data: ", data)
    real_out = model(torch.LongTensor(data).unsqueeze(0))
    _, real_out = real_out.max(dim=1)
    real_out = real_out.item()
    # pred_out = predict_by_automata(model_dct, dataset, i, kmeans_mdl, trans_mat, mode='all_trans')
    data_fst = data.clone()
    for idx in range(len(data)):
        if idx > 0:
            data_fst[idx] = 0
    init_key = model.k_matrix(model.embedding(data_fst)).detach().numpy()[0] # [seq_len, embed_dim]
    init_key = init_key.reshape(1, -1)
    init_key = init_key.astype(np.float16)
    init_state = kmeans.predict(init_key)[0]
    state_distribution = np.zeros(n_clusters)
    state_distribution[init_state] = 1
    for num in data[1:]:
        state_distribution = np.matmul(state_distribution, trans_matrix[num])
        #print("state_distribution: ", state_distribution)
    # pred logits is the weighted sum of avg logits
    #print("state_distribution: ", state_distribution)
    pred_logits = np.matmul(state_distribution, avg_logits)
    #print("pred_logits: ", pred_logits)
    pred_out = pred_logits.argmax()
    #print("real: ", real_out, "pred: ", pred_out)
    if real_out == pred_out:
        num_corr += 1

print("Accuracy: ", num_corr / test_range)

100%|██████████| 500/500 [00:06<00:00, 79.71it/s]

Accuracy:  0.064



