In [1]:
# Copyright (c) 2019-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
#
# Code to generate sentence representations from a pretrained model.
# This can be used to initialize a cross-lingual classifier, for instance.
#

In [2]:
import os
import torch

from src.utils import AttrDict
from src.data.dictionary import Dictionary, BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD
from src.model.transformer import TransformerModel

## Reload a pretrained model

In [5]:
# Pretrained checkpoint to reload (relative path: resolved against the current working directory).
model_path = 'mlm_tlm_xnli15_1024.pth'
# map_location='cpu' lets a checkpoint saved with GPU tensors be loaded on a CPU-only
# machine; the rest of this notebook runs the model on CPU anyway.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
reloaded = torch.load(model_path, map_location='cpu')
params = AttrDict(reloaded['params'])
print("Supported languages: %s" % ", ".join(params.lang2id.keys()))

Supported languages: ar, bg, de, el, en, es, fr, hi, ru, sw, th, tr, ur, vi, zh


## Build dictionary / update parameters / build model

In [6]:
# Rebuild the vocabulary stored in the checkpoint, then record its size and the
# indices of the special tokens on the model parameters.
dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
params.n_words = len(dico)
special_token_attrs = [
    ('bos_index', BOS_WORD),
    ('eos_index', EOS_WORD),
    ('pad_index', PAD_WORD),
    ('unk_index', UNK_WORD),
    ('mask_index', MASK_WORD),
]
for attr_name, special_word in special_token_attrs:
    setattr(params, attr_name, dico.index(special_word))

# Instantiate the Transformer (the two positional True flags are presumably
# is_encoder / with_output -- confirm against TransformerModel), switch it to
# inference mode, and load the pretrained weights.
model = TransformerModel(params, dico, True, True)
model.eval()
model.load_state_dict(reloaded['model'])


## Get sentence representations

Sentences must be in BPE format, i.e. tokenized sentences to which fastBPE has been applied (using the same BPE codes as the pretrained model).

Below are example sentences in English, French, Spanish, German, Arabic and Chinese.

In [7]:
# Example inputs: (fastBPE-tokenized sentence, language code) pairs.
raw_sentences = [
    ('the following secon@@ dary charac@@ ters also appear in the nov@@ el .', 'en'),
    ('les zones rurales offr@@ ent de petites routes , a deux voies .', 'fr'),
    ('luego del cri@@ quet , esta el futbol , el sur@@ f , entre otros .', 'es'),
    ('am 18. august 1997 wurde der astero@@ id ( 76@@ 55 ) adam@@ ries nach ihm benannt .', 'de'),
    ('اصدرت عدة افلام وث@@ اي@@ قية عن حياة السيدة في@@ روز من بينها :', 'ar'),
    ('此外 ， 松@@ 嫩 平原 上 还有 许多 小 湖泊 ， 当地 俗@@ 称 为 “ 泡@@ 子 ” 。', 'zh'),
]

# Frame every sentence with </s> on both sides and split it into BPE tokens.
sentences = [
    (('</s> %s </s>' % text.strip()).split(), lang_code)
    for text, lang_code in raw_sentences
]

In [10]:
# Inspect the tokenized sentences, each framed by </s> delimiters, with their language codes.
sentences

[(['</s>',
   'the',
   'following',
   'secon@@',
   'dary',
   'charac@@',
   'ters',
   'also',
   'appear',
   'in',
   'the',
   'nov@@',
   'el',
   '.',
   '</s>'],
  'en'),
 (['</s>',
   'les',
   'zones',
   'rurales',
   'offr@@',
   'ent',
   'de',
   'petites',
   'routes',
   ',',
   'a',
   'deux',
   'voies',
   '.',
   '</s>'],
  'fr'),
 (['</s>',
   'luego',
   'del',
   'cri@@',
   'quet',
   ',',
   'esta',
   'el',
   'futbol',
   ',',
   'el',
   'sur@@',
   'f',
   ',',
   'entre',
   'otros',
   '.',
   '</s>'],
  'es'),
 (['</s>',
   'am',
   '18.',
   'august',
   '1997',
   'wurde',
   'der',
   'astero@@',
   'id',
   '(',
   '76@@',
   '55',
   ')',
   'adam@@',
   'ries',
   'nach',
   'ihm',
   'benannt',
   '.',
   '</s>'],
  'de'),
 (['</s>',
   'اصدرت',
   'عدة',
   'افلام',
   'وث@@',
   'اي@@',
   'قية',
   'عن',
   'حياة',
   'السيدة',
   'في@@',
   'روز',
   'من',
   'بينها',
   ':',
   '</s>'],
  'ar'),
 (['</s>',
   '此外',
   '，',
   '松@@',
   '嫩',

### Create batch

In [8]:
# Pack the sentences into a (slen, bs) matrix of word indices, one column per
# sentence; positions past a sentence's end keep the padding index.
bs = len(sentences)
slen = max(len(tokens) for tokens, _ in sentences)
word_ids = torch.full((slen, bs), params.pad_index, dtype=torch.long)
for col, (tokens, _) in enumerate(sentences):
    token_ids = torch.LongTensor([dico.index(w) for w in tokens])
    word_ids[:len(token_ids), col] = token_ids

lengths = torch.LongTensor([len(tokens) for tokens, _ in sentences])

# Language ids are only passed to the model when it was trained on several
# languages; each sentence's id is broadcast over all slen positions.
if params.n_langs > 1:
    lang_ids = torch.LongTensor([params.lang2id[lang] for _, lang in sentences])
    langs = lang_ids.unsqueeze(0).expand(slen, bs)
else:
    langs = None

In [14]:
# Language id per position: one row per token position (slen), one column per sentence (bs).
langs

tensor([[ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14],
        [ 4,  6,  5,  2,  0, 14]])

### Forward

In [9]:
# Encode the batch without a causal attention mask (causal=False). Extracting
# representations needs no gradients, so run under no_grad: this avoids building an
# autograd graph (saving memory) and yields a tensor that needs no .detach().
with torch.no_grad():
    tensor = model('fwd', x=word_ids, lengths=lengths, langs=langs, causal=False).contiguous()
print(tensor.size())

torch.Size([22, 6, 1024])


The variable `tensor` is of shape `(sequence_length, batch_size, model_dimension)`.

`tensor[0]` is a tensor of shape `(batch_size, model_dimension)` that holds, for each sentence, the last layer's hidden state at the first position (the `</s>` token prepended to every sentence).

This is the vector we use to fine-tune on the GLUE and XNLI tasks.