In [11]:
from genericpath import exists
import argparse
import time
import math
import torch
import torch.nn as nn
import torch.optim as optim
import data
import os
import matplotlib.pyplot as plt
import os.path as osp
from bertviz import head_view, model_view

def _str2bool(value):
    """Parse a command-line boolean such as 'true', 'False', '1' or 'no'.

    argparse's ``type=bool`` is broken for this purpose: ``bool("False")`` is
    True because any non-empty string is truthy, so such flags can never be
    switched off from the command line. Already-bool values (the defaults)
    are passed through unchanged.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='PyTorch Language Model')
parser.add_argument('--epochs', type=int, default=80,
                    help='upper epoch limit')
parser.add_argument('--train_batch_size', type=int, default=20, metavar='N',
                    help='batch size')
parser.add_argument('--eval_batch_size', type=int, default=10, metavar='N',
                    help='eval batch size')
# you can increase the sequence length to see how well the model works when capturing long-term dependencies
parser.add_argument('--max_sql', type=int, default=10,
                    help='sequence length')
parser.add_argument('--seed', type=int, default=1234,
                    help='set random seed')
parser.add_argument('--cuda', action='store_true', help='use CUDA device')
parser.add_argument('--gpu_id', type=int, default=0, help='GPU device id used')
parser.add_argument('--reg', type=str, default='None', help='regularization selection')
parser.add_argument('--dirs', type=str, default='../3-4', help='save directions')
parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
# nargs='+' so `--milestones 30 60` yields a list, matching the list default
# (plain type=int would turn `--milestones 50` into a bare int).
parser.add_argument('--milestones', type=int, nargs='+', default=[100], help='decay epochs')
parser.add_argument('--model', type=str, default='transformer', help='RNN or transformer')
parser.add_argument('--nvoc', type=int, default=33278, help='size of dictionary embeddings')
parser.add_argument('--en_layers', type=int, default=2, help='layer of encoder')
parser.add_argument('--de_layers', type=int, default=2, help='layer of decoder')
parser.add_argument('--nhead', type=int, default=8, help='head in transformer')
parser.add_argument('--dim_ff', type=int, default=2048, help='feedforward dimensions in transformer')
parser.add_argument('--ninput', type=int, default=400, help='the size of each embedding vector')
parser.add_argument('--choice', type=str, default='LSTM', help='model choice')
parser.add_argument('--nhid', type=int, default=400, help='hidden dimensions')
parser.add_argument('--nlayers', type=int, default=2, help='layer nums in RNN')
# Boolean options use _str2bool; type=bool would parse "False" as True.
parser.add_argument('--visualize', type=_str2bool, default=True, help='visualize trained transformer')
parser.add_argument('--guassian', type=_str2bool, default=False, help='gaussian transformer or not')
# NOTE(review): presumably absorbs the `-f <connection file>` argument that the
# Jupyter kernel receives, so parse_args() does not error inside a notebook.
parser.add_argument('-f')
# feel free to add some other arguments
args = parser.parse_args()

# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)

# Use the GPU when one is present; fall back to CPU rather than crashing in
# torch.cuda.set_device on CPU-only machines. An explicit --cuda request on a
# machine without CUDA is reported as an error instead of being ignored.
if args.cuda and not torch.cuda.is_available():
    raise RuntimeError("--cuda was given but no CUDA device is available")
use_gpu = torch.cuda.is_available()

if use_gpu:
    torch.cuda.set_device(args.gpu_id)
    device = torch.device('cuda', args.gpu_id)
else:
    device = torch.device("cpu")

# load data
train_batch_size = args.train_batch_size
eval_batch_size = args.eval_batch_size
batch_size = {'train': train_batch_size, 'valid': eval_batch_size}
# NOTE(review): data.Corpus is project-local; judging by the cell output it
# prints the train/valid set sizes while loading — confirm in data.py.
data_loader = data.Corpus("../data/wikitext2", batch_size, args.max_sql)

size of train set:  2088628
size of valid set:  217646


In [17]:
# Alias the project module so the model instance bound below does not shadow it.
import model as model_lib

## load trained transformer model
# Read layer counts from args so the attention loops below always match the
# architecture the model is built with (they were hard-coded to 2 before).
en_layers = args.en_layers
de_layers = args.de_layers
model = model_lib.LMTransformer(args.nvoc, args.ninput, args.nhid, en_layers, de_layers,
                                args.dim_ff, args.nhead, args.guassian)
model = model.to(device)
# map_location lets a checkpoint that was saved from GPU load on a CPU-only run.
model.load_state_dict(torch.load('../transformer/best_model.pt', map_location=device))
model.eval()  # same as model.train(False)

tokens = ["the", "animal", "did", "not", "cross", "the", "street", "since", "it", "was", "too", "tired"]
# NOTE(review): vocabulary.index raises if a token is out of vocabulary
# (assuming `vocabulary` is a list) — confirm every token above is in it.
token_ids = [data_loader.vocabulary.index(tok) for tok in tokens]
# (seq_len, 2) int64 tensor in which both columns hold the same sentence.
# NOTE(review): presumably seq-first layout with a batch of 2 identical copies — confirm.
# Named input_ids (not `data`) so the `data` module imported earlier stays usable.
input_ids = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(1).repeat(1, 2)
seq_len = input_ids.shape[0]
tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(device)

batch_idx = 0  # which copy in the batch to visualize

## obtain attention weight in different layers of encoder and decoder
# Each entry is reshaped to (1, 1, seq_len, seq_len): a batch of 1 with a single
# attention map, the per-layer format passed to bertviz below.
# no_grad: this is inference only, so skip building an autograd graph.
with torch.no_grad():
    en_attn = [
        model.visEncoderAttention(input_ids, layer)[batch_idx, :, :].reshape([1, 1, seq_len, seq_len])
        for layer in range(en_layers)
    ]
    de_attn = [
        model.visDecoderAttention(input_ids, layer, tgt_mask)[batch_idx, :, :].reshape([1, 1, seq_len, seq_len])
        for layer in range(de_layers)
    ]

# Render decoder attention; pass en_attn instead to inspect the encoder.
if args.visualize:
    head_view(de_attn, tokens)
    model_view(de_attn, tokens)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>