In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel
# Token id used as the end-of-utterance marker: the chat loop appends it to
# every user and system turn, and generation stops as soon as it is sampled.
# NOTE(review): 50256 is the last index of the 50257-entry embedding table —
# presumably GPT-2's <|endoftext|>; confirm against tokenizer.eos_token_id.
EOS_ID = 50256

In [3]:
# Tokenizer and language-model weights are loaded from the same checkpoint so
# that token ids produced by one are valid inputs for the other.
checkpoint = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = GPT2LMHeadModel.from_pretrained(checkpoint)

In [4]:
# Move the model onto the CUDA device (hard-coded, so this cell needs a GPU).
# The call is the cell's last expression, so its return value — the module,
# whose repr lists the network's layers — is shown as the cell output.
model.to("cuda")

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2):

In [5]:
def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """Mask out unlikely entries of a 1-D logits vector, in place.

    Three filters are applied one after another; every entry they reject is
    overwritten with ``filter_value``:

    1. top-k   -- when ``top_k > 0``, entries strictly below the k-th largest
                  logit are rejected (``top_k`` is capped at the vector length
                  and handed to ``torch.topk`` as ``k``).
    2. top-p   -- when ``top_p > 0.0``, entries are ranked by descending logit
                  and the ranking is cut after the first entry whose cumulative
                  softmax probability exceeds ``top_p`` (that entry is kept).
    3. floor   -- entries strictly below ``threshold`` are rejected.

    Args:
        logits: 1-D tensor of scores (asserted). It is modified in place.
        top_k: number of highest-scoring entries to keep; ``0`` disables it.
        top_p: cumulative-probability cut-off; ``0`` disables it.
        threshold: minimum logit value an entry must have to survive.
        filter_value: value written over rejected entries.

    Returns:
        The same ``logits`` tensor object that was passed in.
    """
    assert logits.dim() == 1  # batch size 1 only

    k = min(top_k, logits.size(-1))
    if k > 0:
        # Smallest logit that still belongs to the top-k set.
        kth_largest = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        ranked_logits, ranked_ids = torch.sort(logits, descending=True)
        running_mass = torch.cumsum(F.softmax(ranked_logits, dim=-1), dim=-1)
        over_budget = running_mass > top_p

        # Shift the mask one slot to the right: the entry that crosses the
        # threshold is kept, and the best-ranked entry is never rejected.
        reject_ranked = torch.zeros_like(over_budget)
        reject_ranked[..., 1:] = over_budget[..., :-1]

        # Translate ranked positions back to positions in `logits`.
        logits[ranked_ids[reject_ranked]] = filter_value

    logits[logits < threshold] = filter_value
    return logits

In [6]:
def generate_next_token(model, input_ids, prev, temperature=1, top_k=0,
                        top_p=0, past=None):
    """Run one decoding step and sample the next token.

    Args:
        model: callable invoked as ``model(prev, past_key_values=past)``. Its
            return value must expose ``"logits"`` and ``"past_key_values"``
            entries by key.
        input_ids: not used by this function; kept so existing callers that
            pass it positionally keep working.
        prev: token ids fed to the model for this step.
        temperature: the last-position logits are divided by this value
            before filtering and sampling.
        top_k: forwarded to ``top_filtering``.
        top_p: forwarded to ``top_filtering``.
        past: cache forwarded to the model as ``past_key_values``.

    Returns:
        A tuple ``(next_token, token_prob, past)`` where ``next_token`` is the
        ``(1, 1)`` tensor drawn by ``torch.multinomial``, ``token_prob`` is the
        probability the filtered distribution assigned to it, and ``past`` is
        the cache returned by the model.
    """
    with torch.no_grad():
        outputs = model(prev, past_key_values=past)
        # Read the fields by name. Unpacking `outputs.values()` only works
        # while the output holds exactly these two entries in this order, and
        # breaks as soon as the model also returns e.g. hidden states or
        # attentions.
        logits = outputs["logits"]
        past = outputs["past_key_values"]

        # Only the first batch row and the last position are used.
        next_token_logits = logits[0, -1, :] / temperature
        next_token_logits = top_filtering(next_token_logits, top_k=top_k, top_p=top_p)

        probs = F.softmax(next_token_logits.unsqueeze(0), dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        return next_token, probs[0][next_token], past

In [7]:
def generate_sequence(model, input_ids, temperature=1, top_k=0, top_p=0, max_length=20, past=None, device='cuda'):
    """Sample tokens one at a time until EOS_ID is drawn or max_length is hit.

    The first step feeds the whole of ``input_ids`` to the model; every later
    step feeds only the token sampled in the previous step, together with the
    cache returned by ``generate_next_token``.

    Args:
        model: forwarded unchanged to ``generate_next_token``.
        input_ids: prompt token ids; also determines the dtype, device and
            leading dimension of the returned tensor.
        temperature: forwarded to ``generate_next_token``.
        top_k: forwarded to ``generate_next_token``.
        top_p: forwarded to ``generate_next_token``.
        max_length: maximum number of sampling steps.
        past: initial cache handed to the first step.
        device: not used by this function.

    Returns:
        Tensor of shape ``(input_ids.size(0), n_generated)`` holding the
        sampled tokens in order. The EOS token itself is not included.
    """
    generated = input_ids.new_zeros([input_ids.size(0), 0])
    step_input = input_ids
    for _ in range(max_length):
        step_input, _token_prob, past = generate_next_token(
            model, input_ids, step_input, temperature, top_k, top_p, past
        )
        # The comparison result is used as a plain boolean, so the sampled
        # token tensor must hold exactly one element.
        if step_input == EOS_ID:
            break
        generated = torch.cat((generated, step_input), dim=1)
    return generated

In [8]:
# Decoding settings consumed by the chat loop.
max_length = 30   # upper bound on tokens sampled per reply
temperature = 1   # logits are divided by this before sampling
top_k = 10        # sample only among the 10 highest-scoring tokens
top_p = 0         # 0 switches nucleus (top-p) filtering off
device = "cuda"   # device on which the prompt tensor is created

In [None]:
# Interactive chat: each turn encodes the conversation so far, samples a reply
# and appends both sides to the history.
history = []  # one list of token ids per utterance, each ending with EOS_ID
past = None   # not passed to generate_sequence; every turn re-encodes the history

# The prompt plus up to `max_length` sampled tokens must fit into the model's
# positional-embedding table, otherwise the forward pass indexes out of range.
max_context_tokens = max(1, model.config.n_positions - max_length)

for step in range(10):
    text = input("USR: ")

    text_tokens = tokenizer.encode(text) + [EOS_ID]
    history.append(text_tokens)

    # Forget the oldest utterances once the conversation outgrows the budget;
    # the newest utterance is always kept.
    while len(history) > 1 and sum(len(utterance) for utterance in history) > max_context_tokens:
        history.pop(0)

    # Flatten the list of utterances in one pass (sum(history, []) re-copies
    # the accumulated list for every utterance).
    flattened_history = [token for utterance in history for token in utterance]
    # If a single utterance alone exceeds the budget, keep its most recent tokens.
    flattened_history = flattened_history[-max_context_tokens:]

    context_tokens = torch.tensor(flattened_history, device=device, dtype=torch.long).unsqueeze(0)

    out = generate_sequence(model, context_tokens, max_length=max_length,
                            temperature=temperature, top_k=top_k, top_p=top_p)

    # generate_sequence returns shape (1, n_generated); drop the batch axis so
    # tolist() yields a flat list of token ids (empty if EOS was sampled first).
    out = out.squeeze(0).tolist()
    out_text = tokenizer.decode(out)
    print("SYS: ", out_text)

    history.append(out + [EOS_ID])


USR: Hello!
SYS:  Hiya! :D
USR: How's it going?
SYS:  Well, it's going okay! : 3 How is your day going? :D lt 3
USR: It's going okay too! Just finished swimming
SYS:  Oof!! : lt : I Hope it was swimming! :p
USR: Right
