In [None]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torch
from basic_transformer.models.basic_transformer import BasicTransformer
from basic_transformer import utils as local_util

In [None]:
# Model hyper-parameters and checkpoint-loading settings.
DIM = 16  # used as both `dim` and `embedding_dim` when the model is rebuilt
NUM_WORDS = 5000  # used as `num_embeddings` when the model is rebuilt
MAX_SEQ_LEN = 128  # `max_len` handed to `fix_seq_len` before inference
# Which checkpoint format to restore: 'state-dict' or 'entire'.
LOAD_MODEL_TYPE = "entire"

# Load Model

In [None]:
# Root directory that holds saved artefacts. It defaults to the original location
# but can be redirected with the MODEL_ROOT_DIR environment variable, so the
# notebook is not tied to one machine's filesystem layout.
MODEL_ROOT_DIR = os.environ.get("MODEL_ROOT_DIR") or "/media/can/models/"
# Per-project folder under the root; the tokenizer pickle is read from here.
MODEL_SAVE_PATH = os.path.join(MODEL_ROOT_DIR, local_util.config.PROJECT_NAME)

In [None]:
# Load the pickled tokenizer from the project's model directory.
# NOTE: pickle.load can execute arbitrary code embedded in the file — only open
# pickles that come from a trusted source.
tokenizer_path = os.path.join(MODEL_SAVE_PATH, 'train_tokenizer.pkl')
with open(tokenizer_path, "rb") as tokenizer_file:  # ensures the handle is closed
    tokenizer = pickle.load(tokenizer_file)
tokenizer

In [None]:
# Restore the trained model from MODEL_SAVE_PATH in the format selected by
# LOAD_MODEL_TYPE. Raises ValueError for any other value.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which may reject a whole pickled model — confirm against the installed version.
if LOAD_MODEL_TYPE == 'state-dict':
    # Weights-only checkpoint: build the architecture, then copy the weights in.
    model_path = os.path.join(MODEL_SAVE_PATH, "model.pth")
    model = BasicTransformer(dim=DIM, num_embeddings=NUM_WORDS, embedding_dim=DIM)
    model_state_dict = torch.load(model_path)
    model.load_state_dict(model_state_dict)
elif LOAD_MODEL_TYPE == 'entire':
    # Whole-object checkpoint: torch.load returns the model itself.
    model_path = os.path.join(MODEL_SAVE_PATH, "model_entire.pth")
    model = torch.load(model_path)
else:
    raise ValueError("Unknown `LOAD_MODEL_TYPE`: {}".format(LOAD_MODEL_TYPE))

In [None]:
# Show which device the model's `W_q` tensor lives on; inputs passed to the
# model need to be on the same device.
model.W_q.device

In [None]:
# Optional: uncomment the next line to move the model onto the GPU (requires CUDA).
# model.cuda()

In [None]:
# Put the model in evaluation mode before running inference.
model.eval()

# Run a Test Prediction

In [None]:
# Run the model on one hand-written review and display its raw output.
text = "terrible service. waited 30 min and nobody offered water or take order. only one middle age lady taking order, deliver food, doing the cash and taking phone call for take out order. lots of customers waiting by the door and they don't even care."
# Text -> token ids -> length-adjusted sequence (via fix_seq_len), wrapped in a
# list so the model receives a batch of one.
token_ids = tokenizer.texts_to_sequences([text])[0]
text_seq = [local_util.data.fix_seq_len(token_ids, max_len=MAX_SEQ_LEN)]
# Create the input on whatever device the model's weights are on instead of
# unconditionally calling .cuda(), which fails when the model is on the CPU or
# no GPU is available.
input_ids = torch.tensor(text_seq, device=model.W_q.device)
with torch.no_grad():  # inference only: skip autograd bookkeeping
    p = model(input_ids)
p

# Activations

In [None]:
# Keep only the non-zero ids of the single test sequence (0 is presumably the
# padding value added by fix_seq_len — confirm) and map them back to words.
_seq = [token_id for token_id in text_seq[0] if token_id != 0]
_seq_text = tokenizer.sequences_to_texts([_seq])[0]
print(_seq, _seq_text, sep="\n")

In [None]:
# Gather the entries of `model.weights_per_timestep` into one NumPy array,
# concatenated along axis 0.
d = model.weights_per_timestep
# detach() + cpu() make the conversion safe when a tensor still tracks gradients
# or lives on the GPU; a bare .numpy() raises in both of those cases.
d_array = np.concatenate(
    [d[i].detach().cpu().numpy() for i in range(len(d))],
    axis=0,
)
d_array.shape

In [None]:
# Crop both axes of the matrix to len(_seq), i.e. the number of kept tokens.
n_tokens = len(_seq)
d_array = d_array[:n_tokens, :n_tokens]
d_array.shape

In [None]:
# Normalize along axis 1 using the project helper.
# NOTE(review): the exact scheme (sum-to-one, min-max, ...) is defined in
# local_util.math.normalize_axis_1 — confirm there before interpreting the plot.
d_array = local_util.math.normalize_axis_1(d_array)
d_array.shape

In [None]:
# Plot `d_array` as a heat-map, labelling both axes with the decoded tokens.
token_labels = _seq_text.split(' ')  # one label per token, reused for both axes
fig, ax = plt.subplots(figsize=(15, 15))
ax.imshow(d_array)
ax.set_xticks(list(range(d_array.shape[1])))
ax.set_xticklabels(token_labels, rotation=45)
ax.set_yticks(list(range(d_array.shape[0])))
ax.set_yticklabels(token_labels)
# Pass a real boolean: the MATLAB-style string 'on' only worked because any
# non-empty string is truthy (so 'off' would have enabled the grid as well).
ax.grid(True)