In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import logging
from tqdm import trange

import torch
import torch.nn.functional as F
import numpy as np

from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from transformers import XLNetLMHeadModel, XLNetTokenizer
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
from transformers import CTRLLMHeadModel, CTRLTokenizer
from transformers import XLMWithLMHeadModel, XLMTokenizer


# Timestamped, INFO-level log records for this notebook and the libraries it drives.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logger = logging.getLogger(__name__)

# Hardcoded cap on generation length (guards against an infinite loop).
MAX_LENGTH = 10000

# Flat tuple of every checkpoint name listed in the archive map of each supported config class.
ALL_MODELS = tuple(
    checkpoint
    for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)
    for checkpoint in conf.pretrained_config_archive_map.keys()
)

# model_type -> (language-model-head class, tokenizer class)
MODEL_CLASSES = {
    'gpt2':       (GPT2LMHeadModel, GPT2Tokenizer),
    'ctrl':       (CTRLLMHeadModel, CTRLTokenizer),
    'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    'xlnet':      (XLNetLMHeadModel, XLNetTokenizer),
    'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
    'xlm':        (XLMWithLMHeadModel, XLMTokenizer),
}

In [2]:
from types import SimpleNamespace

# Plain attribute container holding the settings used by the cells below.
args = SimpleNamespace(no_cuda=False)

# Fall back to the CPU when CUDA is disabled or unavailable; otherwise use the GPU.
args.device = torch.device("cpu" if args.no_cuda or not torch.cuda.is_available() else "cuda")
args.n_gpu = torch.cuda.device_count()

# Token at which text generation is stopped (None: never truncate the decoded text)
args.stop_token = None

In [29]:
# GPT-2 settings (background reading: http://jalammar.github.io/illustrated-gpt2/)
args.model_type = args.model_name_or_path = "gpt2"

# Number of tokens to generate, and number of independent samples to draw
args.length = 20
args.num_samples = 1

# A temperature of 0 implies greedy sampling
args.temperature = 1.0

# top_k: sample a word from the list of the top_k best candidates, using the score as the
# probability of selecting that word; top_p: cumulative-probability (nucleus) threshold
args.top_k = 40
args.top_p = 0.9

# Primarily useful for the CTRL model; in that case, use 1.2
args.repetition_penalty = 1.0

# XLM masked-language-modeling (MLM) checkpoints need the mask token id (see sample_sequence).
# `tokenizer` is only looked up when the selected model is an XLM MLM checkpoint.
is_xlm_mlm = args.model_type == "xlm" and 'mlm' in args.model_name_or_path
xlm_mask_token = tokenizer.mask_token_id if is_xlm_mlm else None

xlm_lang = None

In [6]:
# CTRL (Salesforce) settings

# Needs more GPU memory, so force the CPU
args.no_cuda = True
args.device = torch.device("cpu" if args.no_cuda or not torch.cuda.is_available() else "cuda")

args.model_type = args.model_name_or_path = "ctrl"

# Number of tokens to generate, and number of independent samples to draw
args.length = 20
args.num_samples = 1

# A temperature of 0 implies greedy sampling
args.temperature = 0

# top_k: sample a word from the list of the top_k best candidates, using the score as the
# probability of selecting that word; top_p: cumulative-probability (nucleus) threshold
args.top_k = 40
args.top_p = 0.9

# Primarily useful for the CTRL model; in that case, use 1.2
args.repetition_penalty = 1.2

# XLM masked-language-modeling (MLM) checkpoints need the mask token id (see sample_sequence).
# `tokenizer` is only looked up when the selected model is an XLM MLM checkpoint.
is_xlm_mlm = args.model_type == "xlm" and 'mlm' in args.model_name_or_path
xlm_mask_token = tokenizer.mask_token_id if is_xlm_mlm else None

xlm_lang = None

In [4]:
# XLM (multilingual) settings
# Language code for Greek is 'el' (language codes are the lowercase keys of tokenizer.lang2id)
args.model_type, args.model_name_or_path = "xlm", "xlm-mlm-100-1280"

# Number of tokens to generate, and number of independent samples to draw
args.length = 20
args.num_samples = 1

# A temperature of 0 implies greedy sampling
args.temperature = 1.0

# top_k: sample a word from the list of the top_k best candidates, using the score as the
# probability of selecting that word; top_p: cumulative-probability (nucleus) threshold
args.top_k = 40
args.top_p = 0.9

# Primarily useful for the CTRL model; in that case, use 1.2
args.repetition_penalty = 1.0


In [23]:
# Select language (for XLM).
# Needs `tokenizer`, `model` and `xlm_mask_token` to exist already, i.e. the XLM
# checkpoint must have been loaded before this cell is run.
print(
    tokenizer.lang2id.keys(),     # every language code the tokenizer knows
    tokenizer.lang2id['en'],      # numeric id of English
    model.config.use_lang_emb,    # whether the checkpoint uses language embeddings
    xlm_mask_token,               # mask token id currently selected
    sep='\n',
)
xlm_lang = None

dict_keys(['af', 'als', 'am', 'an', 'ang', 'ar', 'arz', 'ast', 'az', 'bar', 'be', 'bg', 'bn', 'br', 'bs', 'ca', 'ceb', 'ckb', 'cs', 'cy', 'da', 'de', 'el', 'en', 'eo', 'es', 'et', 'eu', 'fa', 'fi', 'fr', 'fy', 'ga', 'gan', 'gl', 'gu', 'he', 'hi', 'hr', 'hu', 'hy', 'ia', 'id', 'is', 'it', 'ja', 'jv', 'ka', 'kk', 'kn', 'ko', 'ku', 'la', 'lb', 'lt', 'lv', 'mk', 'ml', 'mn', 'mr', 'ms', 'my', 'nds', 'ne', 'nl', 'nn', 'no', 'oc', 'pl', 'pt', 'ro', 'ru', 'scn', 'sco', 'sh', 'si', 'simple', 'sk', 'sl', 'sq', 'sr', 'sv', 'sw', 'ta', 'te', 'th', 'tl', 'tr', 'tt', 'uk', 'ur', 'uz', 'vi', 'war', 'wuu', 'yi', 'zh', 'zh_classical', 'zh_min_nan', 'zh_yue'])
23
False
5


In [5]:
# Drop the reference to a previously loaded model before loading another one.
# Guarded so the cell is a no-op (instead of a NameError) when nothing has been
# loaded yet, e.g. on a fresh kernel run from top to bottom.
if 'model' in globals():
    del model

In [5]:
# Resolve the model/tokenizer classes registered for the (lower-cased) model type
args.model_type = args.model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

# Fetch the pretrained tokenizer and weights for the chosen checkpoint
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)

# Move the weights to the selected device, then switch to evaluation mode
model.to(args.device)
model.eval()

12/13/2019 09:38:29 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-vocab.json from cache at C:\Users\s.lagousis.POBUCA\.cache\torch\transformers\36cc0aaffa16aeaa0c6f7b21e58a10a9ab609ed4dbd2ff17423fc95690b4a8bf.50ff3a1ade6a729ff2500ee529c5c0d5630ef2025abc4d2a25e845aca60bac98
12/13/2019 09:38:29 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-merges.txt from cache at C:\Users\s.lagousis.POBUCA\.cache\torch\transformers\3e86a68a8775a1b921af60bcce492264e458bc21ca4239639ea8c253fb18ac53.7b94ff5bc85062952d4da8c1ef6755c3f7fd50b4ba33d63407770a920d1b4711
12/13/2019 09:38:31 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json from cache at C:\Users\s.lagousis.POBUCA\.cache\torch\transformers\230d9a959764b99634b4ade26fa5ee35f1c7cd9dc3739ff3e

12/13/2019 09:38:32 - INFO - transformers.modeling_utils -   loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-pytorch_model.bin from cache at C:\Users\s.lagousis.POBUCA\.cache\torch\transformers\bf9a2ebdc571b0b216a2d5504dd40391f21b19f2a3b7bbcf345751f9ef7186ed.f2e36d5181b147331929e329fafc532d23b4def47c15eadd8e6f787c55f95c4e


XLMWithLMHeadModel(
  (transformer): XLMModel(
    (position_embeddings): Embedding(512, 1280)
    (embeddings): Embedding(200000, 1280, padding_idx=2)
    (layer_norm_emb): LayerNorm(torch.Size([1280]), eps=1e-12, elementwise_affine=True)
    (attentions): ModuleList(
      (0): MultiHeadAttention(
        (q_lin): Linear(in_features=1280, out_features=1280, bias=True)
        (k_lin): Linear(in_features=1280, out_features=1280, bias=True)
        (v_lin): Linear(in_features=1280, out_features=1280, bias=True)
        (out_lin): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (1): MultiHeadAttention(
        (q_lin): Linear(in_features=1280, out_features=1280, bias=True)
        (k_lin): Linear(in_features=1280, out_features=1280, bias=True)
        (v_lin): Linear(in_features=1280, out_features=1280, bias=True)
        (out_lin): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (2): MultiHeadAttention(
        (q_lin): Linear(in_features=1280,

In [5]:
# Show the module tree of the loaded model (a bare last expression is rendered as the cell output)
model

XLMWithLMHeadModel(
  (transformer): XLMModel(
    (position_embeddings): Embedding(512, 1280)
    (embeddings): Embedding(200000, 1280, padding_idx=2)
    (layer_norm_emb): LayerNorm(torch.Size([1280]), eps=1e-12, elementwise_affine=True)
    (attentions): ModuleList(
      (0): MultiHeadAttention(
        (q_lin): Linear(in_features=1280, out_features=1280, bias=True)
        (k_lin): Linear(in_features=1280, out_features=1280, bias=True)
        (v_lin): Linear(in_features=1280, out_features=1280, bias=True)
        (out_lin): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (1): MultiHeadAttention(
        (q_lin): Linear(in_features=1280, out_features=1280, bias=True)
        (k_lin): Linear(in_features=1280, out_features=1280, bias=True)
        (v_lin): Linear(in_features=1280, out_features=1280, bias=True)
        (out_lin): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (2): MultiHeadAttention(
        (q_lin): Linear(in_features=1280,

In [6]:
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

        The tensor passed in is modified in place and is also the return value.

        Args:
            logits: logits distribution shape (batch size x vocabulary size);
                the last dimension is the one that is filtered.
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            filter_value: value written over the logit of every removed token.
        Returns:
            `logits`, with the removed entries set to `filter_value`.
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        kth_best_logit = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best_logit] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Scatter the sorted mask back to the original token order. The arguments are
        # positional (dim, index, src): the tensor argument of `scatter` is named `src`,
        # so the former `source=` keyword is rejected by PyTorch.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

In [7]:
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
    """Append `length` sampled tokens to `context`, one token per model call.

    Args:
        model: called as `model(**inputs)`; element 0 of its result is indexed as
            [sample, position, vocabulary] and the last position supplies the next-token logits.
        length: number of tokens to generate.
        context: sequence of token ids used as the prompt.
        num_samples: number of continuations to draw; the prompt is repeated once per sample.
        temperature: the logits are divided by this value; 0 selects greedy (argmax) decoding.
        top_k: forwarded to `top_k_top_p_filtering`.
        top_p: forwarded to `top_k_top_p_filtering`.
        repetition_penalty: penalty from CTRL (https://arxiv.org/abs/1909.05858) applied to the
            logits of tokens already present in a sequence; 1.0 disables it.
        is_xlnet: build the extra dummy token, permutation mask and target mapping XLNet needs.
        is_xlm_mlm: append a mask token to the input so an XLM MLM model predicts it.
        xlm_mask_token: id of that mask token (ignored when None).
        xlm_lang: optional language id sent to the model as `langs` for every input position.
        device: device on which all tensors are created.

    Returns:
        Long tensor of shape (num_samples, len(context) + length): the prompt followed by
        the generated token ids.
    """
    context = torch.tensor(context, dtype=torch.long, device=device)
    generated = context.unsqueeze(0).repeat(num_samples, 1)
    with torch.no_grad():
        for _ in trange(length):
            # Helper tensors are sized from the actual batch so that num_samples > 1 works.
            batch_size = generated.shape[0]

            inputs = {'input_ids': generated}
            if is_xlnet:
                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
                dummy_column = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
                input_ids = torch.cat((generated, dummy_column), dim=1)
                seq_len = input_ids.shape[1]
                perm_mask = torch.zeros((batch_size, seq_len, seq_len), dtype=torch.float, device=device)
                perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
                target_mapping = torch.zeros((batch_size, 1, seq_len), dtype=torch.float, device=device)
                target_mapping[:, 0, -1] = 1.0  # predict last token
                inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}

            # `is not None` rather than truthiness: a mask token id of 0 is still a valid id.
            if is_xlm_mlm and xlm_mask_token is not None:
                # XLM MLM models are direct models (predict same token, not next token)
                # => need one additional dummy token in the input (will be masked and guessed)
                mask_column = torch.full((batch_size, 1), xlm_mask_token, dtype=torch.long, device=device)
                input_ids = torch.cat((generated, mask_column), dim=1)
                inputs = {'input_ids': input_ids}

            if xlm_lang is not None:
                # One language id per input position, for every sample in the batch
                inputs["langs"] = torch.full_like(inputs["input_ids"], xlm_lang)

            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                seen_logits = next_token_logits.gather(1, generated)
                # Dividing a negative logit would raise it, so negative logits are
                # multiplied instead: the penalty always makes a repeat less likely.
                seen_logits = torch.where(seen_logits < 0,
                                          seen_logits * repetition_penalty,
                                          seen_logits / repetition_penalty)
                next_token_logits.scatter_(1, generated, seen_logits)

            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:  # greedy sampling
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
    return generated

In [35]:
# Alternative prompt:
#raw_text = 'I asked his name'
raw_text = 'Πώς σε λένε'

# Encode the prompt, letting the tokenizer add its special tokens
context_tokens = tokenizer.encode(raw_text, add_special_tokens=True)

print(context_tokens)
print(tokenizer.convert_ids_to_tokens(context_tokens))

# Wrap the ids in a batch of one: shape (1, number of tokens)
input_ids = torch.tensor([context_tokens], dtype=torch.long, device=args.device)

with torch.no_grad():
    last_hidden_states = model(input_ids)

# Slicing idea taken from https://jalammar.github.io/a-visual-guide-to-using-bert-for-the-first-time/ :
# for sentence classification with BERT only the output at the first position ([CLS]) is kept,
# i.e. position 0 of every sequence with the whole last dimension.
# NOTE(review): with the LM-head model loaded in this notebook the recorded shape is
# (1, 8, 200000), whose last dimension equals the vocabulary size, so element 0 looks like
# per-token vocabulary logits rather than hidden states - confirm before reusing it as features.
#print(len(last_hidden_states))
print(last_hidden_states[0].shape)
print(last_hidden_states[0][:, 0, :].shape)

[1, 11052, 50469, 3639, 12227, 20251, 16845, 1]
['</s>', 'Π', 'ώς</w>', 'σε</w>', 'λ', 'έν', 'ε</w>', '</s>']
torch.Size([1, 8, 200000])
torch.Size([1, 200000])


In [22]:
# Alternative prompts:
#raw_text = 'Links I saw her staring at me'
#raw_text = 'Translation English : This is a natural language processing model that aims to generate coherent text in a controllable manner. ; French :'
#raw_text = 'Translation English : This is a natural language processing model that aims to generate coherent text in a controllable manner. ; German :'
raw_text = 'I asked his name'

context_tokens = tokenizer.encode(raw_text, add_special_tokens=False)

# XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence)
is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
if is_xlm_mlm:
    xlm_mask_token = tokenizer.mask_token_id
else:
    xlm_mask_token = None

xlm_lang = None

out = sample_sequence(
     model=model,
     context=context_tokens,
     num_samples=args.num_samples,
     length=args.length,
     temperature=args.temperature,
     top_k=args.top_k,
     top_p=args.top_p,
     repetition_penalty=args.repetition_penalty,
     is_xlnet=bool(args.model_type == "xlnet"),
     is_xlm_mlm=is_xlm_mlm,
     xlm_mask_token=xlm_mask_token,
     xlm_lang=xlm_lang,
     device=args.device,
     )

# Drop the prompt tokens: keep only the generated continuation of every sample
out = out[:, len(context_tokens):].tolist()

print(raw_text, '>>')
for o in out:
    text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
    if args.stop_token:
        # Keep everything before the first stop token. `partition` returns the whole
        # text when the token is absent, whereas slicing with `find` (which returns -1
        # in that case) silently dropped the last character.
        text = text.partition(args.stop_token)[0]

    print(text)

100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 10.98it/s]


I asked his name >>
scilscilscilscilscilscilscilscilscilscilscilscilscilscilscilscilscilscilscilscil
