In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import logging
from tqdm import trange

import torch
import torch.nn.functional as F
import numpy as np

from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from transformers import XLNetLMHeadModel, XLNetTokenizer
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
from transformers import CTRLLMHeadModel, CTRLTokenizer
from transformers import XLMWithLMHeadModel, XLMTokenizer
from transformers import DistilBertTokenizer, DistilBertModel

In [15]:
# Draw one random integer from NumPy's global RNG with bounds (1, 1000000) and display it.
np.random.randint(1, 1000000)

290364

In [3]:
# Load the pretrained 'distilbert-base-uncased' tokenizer and model.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')

# Save both into the local './distilBert/' directory.
tokenizer.save_pretrained('./distilBert/')
model.save_pretrained('./distilBert/')

# Encode one example sentence, add a leading batch dimension, and run it through the model.
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)
outputs = model(input_ids)
last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 231508/231508 [00:00<00:00, 364284.37B/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 492/492 [00:00<00:00, 7336.48B/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 267967963/267967963 [03:07<00:00, 1430512.45B/s]


In [4]:
# Display the hidden-state tensor taken from the first element of the model output.
last_hidden_states

tensor([[[ 0.3897,  0.5468, -0.3846,  ..., -0.5253,  0.0137,  0.5638],
         [ 0.4385,  0.8767, -0.2833,  ..., -0.5621,  0.0267,  0.4095],
         [ 0.4105,  0.7941, -0.0934,  ..., -0.5104, -0.1070,  0.9439],
         [ 0.3925,  0.8002, -0.3922,  ..., -0.6947,  0.0348,  0.6961],
         [ 0.1508,  0.5453, -0.3423,  ..., -0.3691,  0.1747,  0.7834],
         [ 0.3180,  0.5482, -0.3096,  ..., -0.4357,  0.1412,  0.3862]]],
       grad_fn=<AddcmulBackward>)

In [17]:
# Flat tuple of every pretrained checkpoint name exposed by the listed config classes,
# in the order the classes are listed.
ALL_MODELS = tuple(
    checkpoint_name
    for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)
    for checkpoint_name in conf.pretrained_config_archive_map.keys()
)

# Model-type key -> (language-model head class, tokenizer class).
MODEL_CLASSES = dict([
    ('gpt2', (GPT2LMHeadModel, GPT2Tokenizer)),
    ('ctrl', (CTRLLMHeadModel, CTRLTokenizer)),
    ('openai-gpt', (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)),
    ('xlnet', (XLNetLMHeadModel, XLNetTokenizer)),
    ('transfo-xl', (TransfoXLLMHeadModel, TransfoXLTokenizer)),
    ('xlm', (XLMWithLMHeadModel, XLMTokenizer)),
])

In [21]:
def set_seed(args):
    """Seed the NumPy and PyTorch random number generators.

    Args:
        args: object exposing ``seed`` (the seed value) and ``n_gpu``
            (CUDA generators are seeded as well when this is greater than 0).
    """
    seed_value = args.seed
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    gpu_requested = args.n_gpu > 0
    if gpu_requested:
        torch.cuda.manual_seed_all(seed_value)


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

        The filtering is applied along the last dimension, so ``logits`` may be a single
        distribution (vocabulary size,) or carry leading batch dimensions
        (e.g. batch size x vocabulary size).

        Args:
            logits: logits distribution; the vocabulary is the last dimension.
                NOTE: the tensor is modified in place and also returned.
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            filter_value: value written over every filtered-out logit.
        Returns:
            The same ``logits`` tensor, with removed entries set to ``filter_value``.
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a logit lower than the k-th largest one
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the mask back to the original (unsorted) vocabulary positions. Working on
        # the last dimension (instead of a hard-coded dim=1) supports any number of leading dims.
        indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits


def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
    """Generate ``length`` tokens after ``context``, one token per model call.

    Args:
        model: callable invoked with keyword tensors (``input_ids`` and, depending on the flags
            below, ``perm_mask``/``target_mapping``/``langs``); the first element of its output
            is indexed as ``[batch, position, vocabulary]`` logits.
        length: number of tokens to generate.
        context: sequence of prompt token ids.
        num_samples: number of continuations generated in parallel (batch size).
        temperature: divisor applied to the logits; ``0`` selects greedy (argmax) decoding.
        top_k, top_p: forwarded to ``top_k_top_p_filtering``.
        repetition_penalty: penalty from CTRL (https://arxiv.org/abs/1909.05858) applied to
            every token already present in a sample; ``1.0`` disables it.
        is_xlnet: append a dummy token and build ``perm_mask``/``target_mapping`` inputs.
        is_xlm_mlm: together with ``xlm_mask_token``, append a mask token to be predicted.
        xlm_mask_token: id of the mask token used when ``is_xlm_mlm`` is set.
        xlm_lang: optional language id passed to the model as a ``langs`` tensor.
        device: device on which all tensors are created.

    Returns:
        Long tensor of shape (num_samples, len(context) + length): the prompt followed by
        the generated token ids.
    """
    context = torch.tensor(context, dtype=torch.long, device=device)
    generated = context.unsqueeze(0).repeat(num_samples, 1)
    batch_size = generated.shape[0]
    with torch.no_grad():
        for _ in trange(length):

            inputs = {'input_ids': generated}
            if is_xlnet:
                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
                # All helper tensors are sized with the real batch size so num_samples > 1 works.
                dummy_token = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
                input_ids = torch.cat((generated, dummy_token), dim=1)
                seq_len = input_ids.shape[1]
                perm_mask = torch.zeros((batch_size, seq_len, seq_len), dtype=torch.float, device=device)
                perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
                target_mapping = torch.zeros((batch_size, 1, seq_len), dtype=torch.float, device=device)
                target_mapping[:, 0, -1] = 1.0  # predict last token
                inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}

            # `is not None` (rather than truthiness) so that a mask token id of 0 is honoured.
            if is_xlm_mlm and xlm_mask_token is not None:
                # XLM MLM models are direct models (predict same token, not next token)
                # => need one additional dummy token in the input (will be masked and guessed)
                mask_column = torch.full((batch_size, 1), xlm_mask_token, dtype=torch.long, device=device)
                input_ids = torch.cat((generated, mask_column), dim=1)
                inputs = {'input_ids': input_ids}

            if xlm_lang is not None:
                # One language id per input position, replicated for every sample in the batch.
                lang_row = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
                inputs["langs"] = lang_row.repeat(batch_size, 1)

            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                for sample_idx in range(batch_size):
                    seen_tokens = torch.unique(generated[sample_idx])
                    seen_logits = next_token_logits[sample_idx, seen_tokens]
                    # Dividing a negative logit by a penalty > 1 would make the token MORE likely,
                    # so negative logits are multiplied instead.
                    next_token_logits[sample_idx, seen_tokens] = torch.where(
                        seen_logits < 0, seen_logits * repetition_penalty, seen_logits / repetition_penalty)

            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:  # greedy sampling:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
    return generated

In [5]:
# Display the tuple of pretrained checkpoint names collected in ALL_MODELS.
ALL_MODELS

('gpt2',
 'gpt2-medium',
 'gpt2-large',
 'distilgpt2',
 'openai-gpt',
 'xlnet-base-cased',
 'xlnet-large-cased',
 'transfo-xl-wt103',
 'xlm-mlm-en-2048',
 'xlm-mlm-ende-1024',
 'xlm-mlm-enfr-1024',
 'xlm-mlm-enro-1024',
 'xlm-mlm-tlm-xnli15-1024',
 'xlm-mlm-xnli15-1024',
 'xlm-clm-enfr-1024',
 'xlm-clm-ende-1024',
 'xlm-mlm-17-1280',
 'xlm-mlm-100-1280',
 'ctrl')

In [6]:
#?? tokenizer_class.from_pretrained

In [18]:
# Run on CPU; the CUDA-if-available alternative is kept as a trailing comment on the next line.
device=torch.device("cpu")  # torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pick the (model class, tokenizer class) pair registered under 'gpt2' in MODEL_CLASSES.
model_class, tokenizer_class = MODEL_CLASSES['gpt2']
# Load tokenizer and model from the local './distilGPT2/' directory.
# NOTE: this rebinds `tokenizer` and `model`, which an earlier cell bound to DistilBERT objects.
tokenizer = tokenizer_class.from_pretrained('./distilGPT2/')
model = model_class.from_pretrained('./distilGPT2/')
model.to(device)
model.eval();  # evaluation mode; the trailing ';' suppresses the cell's output

In [8]:
# Save the currently bound `model` into the current directory (the tokenizer save is commented out).
#tokenizer.save_pretrained('.')
model.save_pretrained('.')

In [19]:
# Prompt text that is tokenized and used as the generation context below.
raw_text = "Hi, how are you"

In [23]:
# Encode the prompt into token ids with add_special_tokens=False, then display the ids.
context_tokens = tokenizer.encode(raw_text, add_special_tokens=False)
context_tokens

[17250, 11, 703, 389, 345]

In [11]:
# Map the first prompt token id back to its token string.
tokenizer.convert_ids_to_tokens(context_tokens[0])

'Hi'

In [25]:
# Sample 3 continuations of 2 tokens each from the prompt (top_p=0.9, no top-k, no repetition penalty).
out = sample_sequence(
    model=model,
    context=context_tokens,
    num_samples=3,
    length=2,
    temperature=1,
    top_k=0,
    top_p=0.9,
    repetition_penalty=1.0,
    device=device,
)

print(out.shape)

# Strip the prompt ids so that only the newly generated ids are decoded.
out = out[:, len(context_tokens):].tolist()
d = {}
for i, o in enumerate(out):
    text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
    d[f'text{i}'] = text
    print(text)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.08it/s]


torch.Size([3, 7])
 going to
 working?
 guys getting


In [29]:
# Serialize the dict of generated texts to a JSON string and display it.
import json
json.dumps(d)

'{"text0": " going to", "text1": " working?", "text2": " guys getting"}'