In [1]:
import torch

from model import DEVICE, Transformer, SentencePieceBPETokeniser

MODEL_PATHS = {
    'Toy': {
        'Path': '10k-rows_13epoch-btch-90-embed-288.tar',
        'Config': {
            'SEQUENCE_LENGTH': 288,
            'D_MODEL': 256, # Originally 512
            'NUM_HEADS': 8,
            'NUM_LAYERS': 6,
            'D_FF': 1024, # Originally 2048
            'D_K': 32, # D_MODEL // NUM_HEADS, # Originally 64
            'DROP_OUT_RATE': 0.1,
        }
    },
    'Large': {
        'Path': '230k-rows_15epoch-btch-90-embed-288.tar',
        'Config': {
            'SEQUENCE_LENGTH': 288,
            'D_MODEL': 256,
            'NUM_HEADS': 8,
            'NUM_LAYERS': 6,
            'D_FF': 1024,
            'D_K': 32, # D_MODEL // NUM_HEADS,
            'DROP_OUT_RATE': 0.1,
        }
    }
}

en_tokeniser = SentencePieceBPETokeniser('en', model_file='../../tokenisation/sentencepiece_custom/en.model')
zh_tokeniser = SentencePieceBPETokeniser('zh', model_file='../../tokenisation/sentencepiece_custom/zh.model')

SRC_VOCAB = len(en_tokeniser)
TGT_VOCAB = len(zh_tokeniser)

def init_model(path = 'Toy'):
    _config = MODEL_PATHS[path]
    _path = _config['Path']
    config = _config['Config']

    model = Transformer(
        src_vocab_size=SRC_VOCAB, 
        trg_vocab_size=TGT_VOCAB,
        max_seq_len=config['SEQUENCE_LENGTH'],
        n_enc_layers=config['NUM_LAYERS'],
        n_dec_layers=config['NUM_LAYERS'],
        n_attn_heads=config['NUM_HEADS'],
        d_model= config['D_MODEL'],
        d_ff=config['D_FF'],
        d_k=config['D_K']
    ).to(DEVICE)
    sdict = torch.load(f'./{_path}', map_location=DEVICE)
    model.load_state_dict(sdict['model_state_dict'], strict=False)
    model.eval()
    return model

In [2]:
from ds import PriorityQueue, BeamNode
import copy

BEAM_SIZE = 8

def beam_search(e_output, e_mask, trg_sp, model):
    
    _, TGT_PAD_IDX, TGT_BOS_IDX, TGT_EOS_IDX = zh_tokeniser.get_special_ids()
    cur_queue = PriorityQueue()
    for k in range(BEAM_SIZE):
        cur_queue.put(BeamNode(TGT_BOS_IDX, -0.0, [TGT_BOS_IDX]))
    
    finished_count = 0

    SEQUENCE_LENGTH = model.SEQUENCE_LEN
    
    for pos in range(SEQUENCE_LENGTH):
        new_queue = PriorityQueue()
        for k in range(BEAM_SIZE):
            node = cur_queue.get()
            if node.is_finished:
                new_queue.put(node)
            else:
                trg_input = torch.LongTensor(node.decoded + [TGT_PAD_IDX] * (SEQUENCE_LENGTH - len(node.decoded))).to(DEVICE) # (L)
                d_mask = (trg_input.unsqueeze(0) != TGT_PAD_IDX).unsqueeze(1).to(DEVICE) # (1, 1, L)
                nopeak_mask = torch.ones([1, SEQUENCE_LENGTH, SEQUENCE_LENGTH], dtype=torch.bool).to(DEVICE)
                nopeak_mask = torch.tril(nopeak_mask) # (1, L, L) to triangular shape
                d_mask = d_mask & nopeak_mask # (1, L, L) padding false
                
                trg_embedded = model.trg_embedding(trg_input.unsqueeze(0))
                trg_positional_encoded = model.positional_encoder(trg_embedded)
                decoder_output = model.decoder(
                    trg_positional_encoded,
                    e_output,
                    e_mask,
                    d_mask
                ) # (1, L, d_model)

                output = model.softmax(
                    model.output_linear(decoder_output)
                ) # (1, L, trg_vocab_size)
                
                output = torch.topk(output[0][pos], dim=-1, k=BEAM_SIZE)
                last_word_ids = output.indices.tolist() # (k)
                last_word_prob = output.values.tolist() # (k)
                
                for i, idx in enumerate(last_word_ids):
                    new_node = BeamNode(idx, -(-node.prob + last_word_prob[i]), node.decoded + [idx])
                    if idx == TGT_EOS_IDX:
                        new_node.prob = new_node.prob / float(len(new_node.decoded))
                        new_node.is_finished = True
                        finished_count += 1
                    new_queue.put(new_node)
        
        cur_queue = copy.deepcopy(new_queue)
        
        if finished_count == BEAM_SIZE:
            break
    
    decoded_output = cur_queue.get().decoded
    
    if decoded_output[-1] == TGT_EOS_IDX:
        decoded_output = decoded_output[1:-1]
    else:
        decoded_output = decoded_output[1:]
        
    return trg_sp.decode(decoded_output)

In [3]:
import datetime
from utils import pad_or_truncate


def translate(text: str, model, verbose=False):
    _, SRC_PAD, _, _ = en_tokeniser.get_special_ids()
    tokenized = en_tokeniser.encode(text)
    src_data = torch.LongTensor(pad_or_truncate(tokenized, pad_idx=SRC_PAD, max_len=model.SEQUENCE_LEN)).unsqueeze(0).to(DEVICE) # (1, L)
    e_mask = (src_data != SRC_PAD).unsqueeze(1).to(DEVICE) # (1, 1, L)

    start_time = datetime.datetime.now()

    if verbose:
        print("Encoding input sentence...")
    src_data = model.src_embedding(src_data)
    src_data = model.positional_encoder(src_data)
    e_output = model.encoder(src_data, e_mask) # (1, L, d_model)

    
    result = beam_search(e_output, e_mask, zh_tokeniser, model)

    end_time = datetime.datetime.now()

    total_inference_time = end_time - start_time
    seconds = total_inference_time.seconds
    minutes = seconds // 60
    seconds = seconds % 60

    if verbose:
        print(f"Input: {text}")
        print(f"Result: {result}")
        print(f"Inference finished! || Total inference time: {minutes}mins {seconds}secs")

    return result

In [37]:
model = init_model('Large')

p_sum = 0
for name, param in model.named_parameters():
    if param.requires_grad:
        params = torch.prod(torch.tensor(param.shape))
        params = params.item()
        name = str(name)
        p = str(params)
        print(name, p.rjust(65 - len(name)))
        p_sum += params
print('-' * 66)
print("Total Params", str(p_sum).rjust(53))

src_embedding.weight                                       4194304
trg_embedding.weight                                       4194304
encoder.layers.0.layer_norm_1.layer.weight                     256
encoder.layers.0.layer_norm_1.layer.bias                       256
encoder.layers.0.multihead_attention.w_q.weight              65536
encoder.layers.0.multihead_attention.w_q.bias                  256
encoder.layers.0.multihead_attention.w_k.weight              65536
encoder.layers.0.multihead_attention.w_k.bias                  256
encoder.layers.0.multihead_attention.w_v.weight              65536
encoder.layers.0.multihead_attention.w_v.bias                  256
encoder.layers.0.multihead_attention.w_0.weight              65536
encoder.layers.0.multihead_attention.w_0.bias                  256
encoder.layers.0.layer_norm_2.layer.weight                     256
encoder.layers.0.layer_norm_2.layer.bias                       256
encoder.layers.0.feed_forward.linear_1.weight               26

In [40]:
model = init_model('Large')
with open('../../tokenisation/data/iwslt2017-en-zh-test.en', 'r') as f:
    lines = f.readlines()
    import tqdm
    l = []
    for line in tqdm.tqdm(lines[-10:]):
        l.append(translate(line, model))

l


100%|██████████| 10/10 [01:30<00:00,  9.04s/it]


['这项工作主要的结果可能是 也许这些几十年, 我们已经有了网络反 ⁇ 的概念',
 '这并不是那些机器成为智能, 然后 ⁇ 大 ⁇ 又 ⁇ 于世界',
 '它恰恰相反,  ⁇ 予利用计算的能力 是比智能智慧更加基本原则更重要, 智慧, 智慧实际上直接从某种角度来说, 而不是反 ⁇ 。',
 '另一个重要的后果是目标',
 '我常被问, 如何通过这样的框架来寻找目标?',
 '答案是,寻找目标 直接从某种程度上说: 就像你穿过 ⁇ 道一样, 瓶 ⁇ ,一个瓶 ⁇ , 为了实现其他其他其他不同的目标, 或者,像是在 ⁇ 道一样, 或者像你一样,你投资了财务安全, 降低对财富的长期目标, 直接从你直接从长远来看,',
 '从长期驱动未来增加的自由',
 '最后,理查德·费曼, 著名的物理学家, 一旦人类文明被毁, 你可以通过我们的思维 来帮助我们一个单一的概念 来重建文明, 这想法应该是我们身边的所有事物, 但同时,它们彼此之间 吸引彼此的 ⁇ 迹。',
 '我所说的意思是,为了让后代  ⁇ 请人工智慧, 或者帮助它们理解人类的智慧, 也就是 ⁇ 循情报, 智慧本意是, 试图去想象,  ⁇ 夺未来自由, 并避免它自身的限制。',
 '非常感谢']

In [None]:
from rouge_chinese import Rouge