In [2]:
import os
import numpy as np
from configparser import ConfigParser
import config as cfg
from fairseq.models.transformer import TransformerModel
from fairseq.models.fconv import FConvModel
from fairseq.models.lstm import LSTMModel
import logging

logger = logging.getLogger(__name__)

In [3]:
def load_model(_model, _lang, _type):
    """Load a pretrained fairseq translation model for the privacy-leak experiments.

    Args:
        _model: Architecture key. One of 'transformer', 'cnn', 'lstm' or 'wmt'.
            'wmt' uses the Transformer class with fastBPE codes; the other keys
            use subword-nmt codes found under the dataset directory.
        _lang: Language-pair string such as 'de-en'. It is substituted into the
            dataset directory template, and the part before the first '-' selects
            the BPE codes file.
        _type: Experiment tag. It is substituted into both the dataset directory
            template and the checkpoint directory name.

    Returns:
        The object returned by ``from_pretrained``, switched to eval mode and
        moved to the GPU when CUDA is available.

    Raises:
        NotImplementedError: If ``_model`` is not one of the supported keys.
    """
    # Resolve the model class first so an unsupported key fails immediately,
    # before any path is built or any checkpoint is touched.
    if _model in ('transformer', 'wmt'):
        model_cls = TransformerModel
    elif _model == 'cnn':
        model_cls = FConvModel
    elif _model == 'lstm':
        model_cls = LSTMModel
    else:
        raise NotImplementedError('unsupported model type: {!r}'.format(_model))

    data_dir = cfg.IWSLT2016_LEAK.tgt_dir.format(_lang, _type)
    ckpt_dir = os.path.join(cfg.checkpoint_dir,
                            'privacy_leak',
                            '{}-de-en-{}-{}'.format(cfg.IWSLT2016_LEAK.name, _model, _type))

    # Only the BPE settings differ between the variants; everything else passed
    # to from_pretrained is shared.
    if _model == 'wmt':
        bpe = 'fastbpe'
        # NOTE(review): machine-specific absolute path; it only resolves on the
        # original author's host. Consider moving it into `config`.
        bpe_codes = '/home/changxu/project/wmt19.de-en.joined-dict.ensemble/bpecodes'
    else:
        bpe = 'subword_nmt'
        bpe_codes = os.path.join(data_dir, 'codes/codes.{}'.format(_lang.split('-')[0]))

    de2en = model_cls.from_pretrained(
        ckpt_dir,
        checkpoint_file='checkpoint_best.pt',
        data_name_or_path=os.path.join(data_dir, 'data-bin'),
        tokenizer='moses',
        bpe=bpe,
        bpe_codes=bpe_codes,
    )

    de2en.eval()

    # Imported here because the notebook's import cell does not import torch
    # directly (fairseq already depends on it).
    import torch
    if torch.cuda.is_available():
        de2en.cuda()
    else:
        # Previously .cuda() was unconditional and raised on CPU-only hosts.
        logger.warning('CUDA is not available; keeping the model on CPU.')

    print('loaded.')
    return de2en

In [4]:
def prediction(_src, _model, beam):
    """Translate ``_src`` with ``_model`` and print every returned hypothesis.

    Args:
        _src: Source sentence passed to ``_model.encode``.
        _model: Object exposing ``encode``, ``generate`` and ``decode``
            (e.g. what ``load_model`` returns).
        beam: Beam size forwarded to ``_model.generate``.

    Prints the source sentence, a separator line, and then one line per
    hypothesis with its 1-based rank, ``2 ** score`` and the decoded text.
    Returns nothing.
    """
    encoded_src = _model.encode(_src)
    hypotheses = _model.generate(encoded_src, beam=beam, sampling=False, seed=2020)

    print(_src)
    print('-------')
    for rank, hypo in enumerate(hypotheses, start=1):
        hypo_tokens = hypo['tokens']
        # NOTE(review): this treats the score as a base-2 log value. If the
        # generator reports natural-log scores, np.exp would be the matching
        # inverse -- confirm against the fairseq version in use.
        rescaled_score = np.power(2, hypo['score'].item())
        print(rank, rescaled_score, _model.decode(hypo_tokens))

In [140]:
# Load the Transformer checkpoint for the 'pn-2-s-r100-b5000' experiment tag.
model = load_model(_model='transformer', _lang='de-en', _type='pn-2-s-r100-b5000')

# Probe the model with a German prefix ("Alice's phone number is") and list
# the hypotheses returned for a beam of 100.
prediction(_src='Alices Telefonnummer ist', _model=model, beam=100)



loaded.
Alices Telefonnummer ist
-------
1 0.7951720110080718 Alice's phone number is the phone number.
2 0.7615139264384703 Alice's phone number is phone number.
3 0.7612112927166382 Alice's phone number is Alice.
4 0.7437076824268728 Alice's phone number is a phone number.
5 0.7254775728301632 Alice's phone number is alien.
6 0.6773368378710276 Alice's phone number is alike.
7 0.6745413093348946 Alice's phone number is.
8 0.6711033480553622 Alice's phone number is the same.
9 0.6691691400842618 Alice's phone number is cell phone number.
10 0.663833091181258 Alice's phone number is telephone number.
11 0.6570882090449808 Alice's phone number is an alike.
12 0.657085440009657 Alice's phone number is old.
13 0.6374253613963208 Alice's phone number is stupid.
14 0.6370547237138086 Alice's phone number is cell number.
15 0.6332382783610425 Alice's phone number is your phone number.
16 0.6315678373243088 Alice's telephone number is Alice.
17 0.6299431046586518 Alice's phone number is mobil