In [2]:
import json
import argparse
import torch
import numpy as np
from torch import nn

from src.slurm import init_signal_handler, init_distributed_mode
from src.data.loader import check_data_params, load_data
from src.utils import bool_flag, initialize_exp, set_sampling_probs, shuf_order
from src.model import check_model_params #build_model
from src.trainer import SingleTrainer, EncDecTrainer
from src.evaluation.evaluator import SingleEvaluator, EncDecEvaluator
from src.model.transformer import TransformerModel

import apex
import pickle
from src.fp16 import network_to_half

from src.data.dictionary import Dictionary

In [3]:
def _strip_module_prefix(state_dict):
    """Return `state_dict` with a leading 'module.' removed from every key,
    but only when *all* keys carry that prefix (otherwise return it as-is)."""
    if all(k.startswith('module.') for k in state_dict.keys()):
        return {k[len('module.'):]: v for k, v in state_dict.items()}
    return state_dict


def _load_checkpoint(path, params):
    """Load the checkpoint at `path`, mapping every stored tensor onto CUDA
    device `params.local_rank`."""
    return torch.load(path, map_location=lambda storage, loc: storage.cuda(params.local_rank))


def _load_submodel_state(path, fallback_key, params):
    """Load `path` and return its state dict: the 'model' entry if the
    checkpoint has one, otherwise the `fallback_key` entry ('encoder' or
    'decoder'), with any uniform 'module.' prefix stripped."""
    checkpoint = _load_checkpoint(path, params)
    state = checkpoint['model' if 'model' in checkpoint else fallback_key]
    return _strip_module_prefix(state)


def _count_trainable(model):
    """Number of scalar parameters of `model` that require gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def build_model(params, dico):
    """
    Build the Transformer model(s) and optionally reload pretrained weights.

    Parameters
    ----------
    params : experiment parameters. Reads `encoder_only`, `reload_model`
        and `local_rank` (the CUDA device checkpoints are mapped onto).
    dico : vocabulary, passed through to `TransformerModel`.

    Returns
    -------
    model
        When `params.encoder_only` is true: a single encoder with output layer.
    (encoder, decoder)
        Otherwise.

    `params.reload_model` is '' for no reload. In encoder-only mode it is a
    single checkpoint path (its 'model' entry is loaded). Otherwise it is
    "ENC_PATH,DEC_PATH"; either side may be empty, but not both.

    Raises
    ------
    ValueError
        If, in encoder-decoder mode, `reload_model` does not contain exactly
        one ','.
    """
    # `logger` was previously used here without ever being defined in this
    # notebook, which raised NameError on the encoder-only path.
    from logging import getLogger
    logger = getLogger()

    if params.encoder_only:
        model = TransformerModel(params, dico, is_encoder=True, with_output=True)

        # reload a pretrained model
        if params.reload_model != '':
            logger.info("Reloading model from %s ..." % params.reload_model)
            reloaded = _load_checkpoint(params.reload_model, params)['model']
            model.load_state_dict(_strip_module_prefix(reloaded))

        logger.debug("Model: {}".format(model))
        logger.info("Number of parameters (model): %i" % _count_trainable(model))
        return model

    # TODO: only output when necessary - len(params.clm_steps + params.mlm_steps) > 0
    encoder = TransformerModel(params, dico, is_encoder=True, with_output=True)
    decoder = TransformerModel(params, dico, is_encoder=False, with_output=True)

    # reload a pretrained model
    if params.reload_model != '':
        paths = params.reload_model.split(',')
        if len(paths) != 2:
            raise ValueError(
                "reload_model must be 'ENC_PATH,DEC_PATH' in encoder-decoder mode, got %r"
                % params.reload_model
            )
        enc_path, dec_path = paths
        assert not (enc_path == '' and dec_path == '')

        if enc_path != '':
            logger.info("Reloading encoder from %s ..." % enc_path)
            encoder.load_state_dict(_load_submodel_state(enc_path, 'encoder', params))

        if dec_path != '':
            logger.info("Reloading decoder from %s ..." % dec_path)
            # Decoder is loaded non-strictly (as in the original code); presumably
            # so checkpoints missing some decoder-only weights still load — confirm.
            decoder.load_state_dict(_load_submodel_state(dec_path, 'decoder', params), strict=False)

    logger.debug("Encoder: {}".format(encoder))
    logger.debug("Decoder: {}".format(decoder))
    logger.info("Number of parameters (encoder): %i" % _count_trainable(encoder))
    logger.info("Number of parameters (decoder): %i" % _count_trainable(decoder))

    return encoder, decoder


In [4]:
def convert_to_text(batch, lengths, dico, params):
    """
    Decode a (seq_len, batch_size) tensor of token ids into text sentences.

    Each column is one sentence. Position 0 must hold `params.eos_index`, and
    exactly one more `eos_index` must close each sentence. Tokens between the
    two markers are looked up in `dico` and joined with single spaces.

    Returns a list with one string per column.
    """
    tokens = batch.cpu().numpy()
    lens = lengths.cpu().numpy()
    seq_len, batch_size = tokens.shape
    eos = params.eos_index

    # Sanity checks on the expected "<eos> w1 ... wn <eos> <pad>..." layout.
    assert lens.max() == seq_len and lens.shape[0] == batch_size
    assert (tokens[0] == eos).sum() == batch_size
    assert (tokens == eos).sum() == 2 * batch_size

    def decode_column(col):
        # Skip the leading EOS and stop at the closing one (or at the length).
        pieces = []
        for pos in range(1, lens[col]):
            tok = tokens[pos, col]
            if tok == eos:
                break
            pieces.append(dico[tok])
        return " ".join(pieces)

    return [decode_column(col) for col in range(batch_size)]

In [45]:
def convert_to_tensors(s, dico, add_eos=False):
    """
    Map one tokenized sentence to a column tensor of vocabulary ids.

    Parameters
    ----------
    s : iterable of str
        Tokens of a single sentence.
    dico : vocabulary exposing `index(word, no_unk=False)` and, only when
        `add_eos` is true, `eos_index`.
    add_eos : bool, default False
        When true, surround the ids with `dico.eos_index` on both sides.
        NOTE(review): convert_to_text expects generated sentences to be
        EOS-delimited on both ends; confirm whether encoder input should be
        built the same way (add_eos=True).

    Returns
    -------
    (x, lengths)
        x : LongTensor of shape (n, 1), i.e. a batch of one sentence as a column.
        lengths : LongTensor of shape (1,) holding n.

    Ids in the reserved range [0, 14) are logged and skipped, except id 3.
    Id 3 is presumably the unknown-word index; TODO confirm it matches
    dico.unk_index.
    """
    # `logger` is not defined anywhere in this notebook; previously any
    # reserved token in the input raised NameError here.
    from logging import getLogger
    logger = getLogger()

    n_special_words = 10
    first_regular_id = 4 + n_special_words

    indexed = []
    for w in s:
        word_id = dico.index(w, no_unk=False)

        if 0 <= word_id < first_regular_id and word_id != 3:
            logger.warning('Found unexpected special word "%s" (%i)!!' % (w, word_id))
            continue

        assert word_id >= 0
        indexed.append(word_id)

    if add_eos:
        indexed = [dico.eos_index] + indexed + [dico.eos_index]

    # Explicit dtype keeps an empty sentence a LongTensor. unsqueeze(1) also
    # works for zero tokens, where reshape(0, -1) would raise.
    x1 = torch.tensor(indexed, dtype=torch.long).unsqueeze(1)
    len1 = torch.tensor([len(indexed)])
    return x1, len1

In [53]:
def evaluate_mt(params, encoder, decoder, lang1, lang2, eval_bleu, dico, str1, str2=None):
    """
    Translate one tokenized sentence from `lang1` into `lang2`.

    Parameters
    ----------
    params : experiment parameters. Reads `lang2id`, `beam_size`,
        `length_penalty` and `early_stopping`; convert_to_text also reads
        `eos_index`.
    encoder, decoder : models called with 'fwd'. `decoder.generate` is used
        for greedy decoding and `decoder.generate_beam` for beam search.
    lang1, lang2 : source / target language names (keys of `params.lang2id`).
    eval_bleu : when true, generate a translation. When false, nothing is
        generated and an empty list is returned.
    dico : vocabulary used to map tokens to ids and back.
    str1 : source sentence as a list of tokens.
    str2 : optional reference sentence. Currently unused; kept only for
        interface compatibility (no scoring against it is implemented).

    Returns
    -------
    list of str
        Generated hypotheses, empty when `eval_bleu` is false.
    """
    x1, len1 = convert_to_tensors(str1, dico)

    # Switch both modules to evaluation mode (e.g. disables dropout).
    encoder.eval()
    decoder.eval()

    lang1_id = params.lang2id[lang1]
    lang2_id = params.lang2id[lang2]

    # Bound unconditionally. Previously it only existed when eval_bleu was
    # true, so eval_bleu=False raised UnboundLocalError at the return.
    hypothesis = []

    # Pure inference: no gradients needed, so skip building the autograd graph.
    with torch.no_grad():
        # encode source sentence
        langs1 = x1.clone().fill_(lang1_id)
        enc1 = encoder('fwd', x=x1, lengths=len1, langs=langs1, causal=False)
        enc1 = enc1.transpose(0, 1)

        # generate translation and convert it to text
        if eval_bleu:
            max_len = int(1.5 * len1.max().item() + 10)
            if params.beam_size == 1:
                generated, lengths = decoder.generate(enc1, len1, lang2_id, max_len=max_len)
            else:
                generated, lengths = decoder.generate_beam(
                    enc1, len1, lang2_id, beam_size=params.beam_size,
                    length_penalty=params.length_penalty,
                    early_stopping=params.early_stopping,
                    max_len=max_len
                )
            hypothesis.extend(convert_to_text(generated, lengths, dico, params))

    return hypothesis

In [61]:
EXAMPLE_SENTENCE = "Personal Care , Unisex , Shower Gel , Razor Head , More by Razor Head"


def main(params, sentence=EXAMPLE_SENTENCE, lang1='pred', lang2='act'):
    """
    Load the vocabulary and encoder-decoder model, then translate one sentence.

    Parameters
    ----------
    params : experiment parameters (must not be encoder-only).
    sentence : whitespace-tokenized source sentence.
    lang1, lang2 : source / target language names.

    Returns
    -------
    list of str
        The generated hypotheses, which are also printed.

    Raises
    ------
    ValueError
        If `params.encoder_only` is true. Translation needs both an encoder
        and a decoder; the original code hit a NameError in that case, and
        only after loading everything.
    """
    if params.encoder_only:
        raise ValueError("main() requires an encoder-decoder model; params.encoder_only must be False")

    init_distributed_mode(params)

    # NOTE(review): load_data() is used here only for its vocabulary; if it
    # also loads the full datasets, reading the vocab directly would be cheaper.
    data = load_data(params)
    dico = data['dico']

    encoder, decoder = build_model(params, dico)

    source_tokens = sentence.split()
    hypotheses = evaluate_mt(params, encoder, decoder, lang1, lang2, True, dico, source_tokens, None)
    print(hypotheses)
    return hypotheses

if __name__ == '__main__':
    # Reload the parameters pickled by the training run.
    # NOTE(review): pickle.load can execute arbitrary code; only use a
    # params file that comes from a trusted source.
    with open('params.pik', 'rb') as params_file:
        params = pickle.load(params_file)

    # Encoder and decoder are both reloaded from the same checkpoint;
    # build_model() splits this value on ',' into the two paths.
    MODEL_PATH = "dumped/unsupMT_predact/zp4s968o2a/checkpoint.pth,dumped/unsupMT_predact/zp4s968o2a/checkpoint.pth"
    params.reload_model, params.eval_only = MODEL_PATH, True

    # Run the data and model parameter checks before building anything.
    for check in (check_data_params, check_model_params):
        check(params)

    # run experiment
    main(params)

SLURM job: False
0 - Number of nodes: 1
0 - Node ID        : 0
0 - Local rank     : 0
0 - Global rank    : 0
0 - World size     : 1
0 - GPUs per node  : 1
0 - Master         : True
0 - Multi-node     : False
0 - Multi-GPU      : False
0 - Hostname       : fa9b8962c7b6
['Personal Care , Men , Shampoo , Razor Head , More by Razor <unk>']
