In [1]:
!pip install -q transformers torch

In [2]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [4]:
dialogpt_name = "microsoft/DialoGPT-small"

dialogpt_tokenizer = GPT2Tokenizer.from_pretrained(dialogpt_name)
dialogpt_model = GPT2LMHeadModel.from_pretrained(dialogpt_name)

# Same fix applies
dialogpt_tokenizer.pad_token = dialogpt_tokenizer.eos_token
dialogpt_conversation_ids = None


dialogpt_model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [5]:
def chat_dialogpt(user_input, max_new_tokens=80):
    global dialogpt_conversation_ids

    # Encode user input + EOS
    new_input_ids = dialogpt_tokenizer.encode(
        user_input + dialogpt_tokenizer.eos_token,
        return_tensors="pt"
    )

    # Append conversation history
    if dialogpt_conversation_ids is None:
        bot_input_ids = new_input_ids
    else:
        bot_input_ids = torch.cat(
            [dialogpt_conversation_ids, new_input_ids],
            dim=-1
        )

    with torch.no_grad():
        dialogpt_conversation_ids = dialogpt_model.generate(
            bot_input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            pad_token_id=dialogpt_tokenizer.eos_token_id
        )

    # Decode only newly generated tokens
    reply = dialogpt_tokenizer.decode(
        dialogpt_conversation_ids[:, bot_input_ids.shape[-1]:][0],
        skip_special_tokens=True
    )

    return reply


In [None]:
while True:
    user_input = input("You: ")
    if user_input.lower() in ["exit", "quit"]:
        break

    print("DialoGPT:", chat_dialogpt(user_input))
