In [1]:
import torch
from translator import Translator
from MyTransformer import Transformer,PadMask
import numpy as np
from torch.utils.data import TensorDataset
import json

In [2]:
# Run configuration: artefact paths (transforms, data cache, checkpoint) and model hyper-parameters.
with open('translate_args.json', mode='r', encoding='utf-8') as args_file:
    args = json.load(args_file)


In [3]:
# Use CUDA only when the config asks for it AND a GPU is actually available.
# Logical `and` (not bitwise `&`): with `&`, a truthy non-bool flag such as 2
# evaluates to 2 & True == 0 and silently falls back to CPU.
args['cuda'] = bool(args['cuda']) and torch.cuda.is_available()
args['device'] = torch.device('cuda' if args['cuda'] else 'cpu')

In [4]:
# Source/target text transforms saved at preprocessing time (callable text -> ids, exposing `.vocab`).
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load trusted files.
# Newer PyTorch releases default to weights_only=True, which may reject these custom
# objects; pass weights_only=False explicitly if loading fails.
textTransforms = torch.load(args['transform_dir'])
src_textTransform = textTransforms['src']
tgt_textTransform = textTransforms['tgt']

In [5]:
# Record vocabulary sizes in the config so the model is built to match the transforms.
args.update(
    src_vocab_size=src_textTransform.vocab_size(),
    tgt_vocab_size=tgt_textTransform.vocab_size(),
)
args['src_vocab_size'], args['tgt_vocab_size']

(24454, 6870)

In [6]:
# Cached, already-tokenised validation split: padded id tensors stored under 'src' and 'tgt'.
data = torch.load(args['val_data_cache'])
dataset = TensorDataset(*(data[key] for key in ('src', 'tgt')))

In [71]:
# First validation pair; display its (padded) source id sequence.
first_src_ids, first_tgt_ids = dataset[0]
first_src_ids

tensor([   2,  457, 5768,   81,   11,  311, 4995, 3566,   11, 5520,   14,   39,
        2384,  418, 1721,   17,  247,  700, 1030,    6,    3,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1])

In [7]:
# Decode the first source sequence back to tokens (first 10 shown).
src_textTransform.vocab.lookup_tokens(data['src'][0].tolist())[:10]

['[BOS]', 'con', '##fin', '##ing', 'a', 'little', 'sun', 'inside', 'a', 'box']

In [8]:
# Decode the first target sequence back to tokens (first 10 shown).
tgt_textTransform.vocab.lookup_tokens(data['tgt'][0].tolist())[:10]

['[BOS]', '把', '一', '个', '小', '太', '阳', '限', '制', '在']

In [9]:
# Build the model from the config and move it to the selected device.
# Source and target share the same padding index (args['pad_idx']).
_model_arg_names = (
    'src_max_seq_len', 'tgt_max_seq_len',
    'src_vocab_size', 'tgt_vocab_size',
    'num_layer', 'num_head',
    'd_k', 'd_v', 'd_model', 'd_ff',
    'drop', 'scale_emb',
    'share_proj_weight', 'share_emb_weight',
)
transformer = Transformer(
    src_pad_idx=args['pad_idx'],
    tgt_pad_idx=args['pad_idx'],
    **{name: args[name] for name in _model_arg_names},
).to(args['device'])

In [10]:
# Restore trained weights. map_location remaps stored tensors onto args['device'], so a
# checkpoint saved from a GPU run still loads when args['device'] has fallen back to CPU.
transformer.load_state_dict(torch.load(args['model_checkpoint'], map_location=args['device']))

<All keys matched successfully>

In [48]:
# Take the first validation pair and reduce its target to a minimal decoder input:
# keep position 0 ([BOS]), put [EOS] at position 1, pad everything after it.
EOS_IDX = 3  # id that decodes to '[EOS]' in the target vocab per the lookups in this notebook; TODO derive from vocab
src_input = data['src'][:1].to(args['device'])
# clone() first: on CPU, `.to(device)` returns the same view into data['tgt'], so the
# in-place writes below would otherwise silently corrupt the cached dataset.
tgt_input = data['tgt'][:1].clone().to(args['device'])
print(tgt_input)
tgt_input[:, 1] = EOS_IDX
tgt_input[:, 2:] = args['pad_idx']  # the model's padding id (shown as 1 / '[PAD]' in this notebook's outputs)
print(tgt_input)

tensor([[   2,  488,   10,   32,  391,  494, 1282,  454,   89,    8,   10,   32,
         2776,  320,  255,    9,   10,   32,  315,   36,  485,  270,    4,  184,
          196,    5,   13,  255,   14,  402,   32,  120,  161,    6,    3,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1]], device='cuda:0')
tensor([[2, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [66]:
# Encode an ad-hoc sentence pair as single-example id batches, built directly on the
# target device. unsqueeze(0) adds the batch axis; assumes the transforms return a flat
# id sequence -- TODO confirm.
# NOTE(review): later outputs in this notebook decode the target as 我/爱, not 我是 --
# this cell appears to have been edited after it was run; re-run downstream cells.
src = 'who are you'
tgt = '我是'
src_input = torch.as_tensor(src_textTransform(src), dtype=torch.long, device=args['device']).unsqueeze(0)
tgt_input = torch.as_tensor(tgt_textTransform(tgt), dtype=torch.long, device=args['device']).unsqueeze(0)

In [67]:
# Teacher-forced forward pass: the model receives tgt_input without its last position.
# eval() switches off train-time behaviour such as dropout (the model is constructed
# with drop=args['drop']); no_grad() skips autograd bookkeeping during inference.
transformer.eval()
with torch.no_grad():
    output = transformer(src_input, tgt_input[:, :-1])

In [68]:
# Greedy per-position prediction: pick the highest-scoring target id at every position.
# Taking argmax over the last axis and flattening gives a flat id list whether the logits
# come back as (batch*seq, vocab) or (batch, seq, vocab) -- output layout not visible here.
N_SHOW_TOKENS = 20
res = output.detach().argmax(dim=-1).reshape(-1).cpu().numpy()
tgt_textTransform.vocab.lookup_tokens(res.tolist())[:N_SHOW_TOKENS]

['我',
 '爱',
 '[EOS]',
 '[EOS]',
 '挪',
 '站',
 '##ch',
 '觉',
 '求',
 '廷',
 '问',
 '证',
 '盯',
 '宜',
 '俩',
 '报',
 '贡',
 '把',
 '袖',
 '命']

In [69]:
# Show the target ids fed to the model as tokens, for the first sequence in the batch.
# Indexing [0] always yields a 1-D id list; squeeze() would collapse a length-1 sequence
# to a scalar int, which lookup_tokens cannot take.
tgt_textTransform.vocab.lookup_tokens(tgt_input[0].cpu().tolist())

['[BOS]',
 '我',
 '爱',
 '[EOS]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]']