In [None]:
# Install necessary libraries
!pip install transformers --quiet

# Import libraries
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import math

# Make the Drive-hosted checkpoint reachable from this Colab runtime
from google.colab import drive
drive.mount('/content/drive')

# The fine-tuned tokenizer and model are loaded from the same directory
MODEL_DIR = '/content/drive/My Drive/Colab Notebooks/Cornell/gpt2-chatbot'
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR)
model = GPT2LMHeadModel.from_pretrained(MODEL_DIR)

# This notebook only runs inference, so put the model in evaluation mode
model.eval()

# Define the response generation function with response length limitation
def generate_response(conversation_history, max_length=1000):
    """Generate the bot's next reply for the given conversation.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        conversation_history: Turn strings (e.g. ``"User: hi"``), oldest
            first. They are joined with newlines and a trailing ``Bot:``
            cue is appended to form the prompt.
        max_length: Upper bound, in tokens, on prompt plus reply. When the
            prompt would not leave room for the reply, its oldest tokens
            are dropped so the most recent context and the ``Bot:`` cue
            survive.

    Returns:
        The sampled reply with surrounding whitespace removed, cut at the
        first point where the model starts a new ``User:`` or ``Bot:`` line.
    """
    # Cap on the reply itself so a single turn cannot run away.
    max_new_tokens = 50

    # Join the conversation history into a single string
    prompt = '\n'.join(conversation_history) + '\nBot:'

    # Tokenize without padding/truncation: one sequence needs no padding,
    # and the tokenizer's default truncation would cut the end of the
    # prompt (the "Bot:" cue) rather than the oldest context.
    inputs = tokenizer(prompt, return_tensors='pt')
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # Respect both the caller's limit and the model's own position limit
    # (GPT-2 configs expose it as n_positions), reserving room for the reply.
    context_limit = getattr(model.config, 'n_positions', None) or max_length
    prompt_budget = max(min(max_length, context_limit) - max_new_tokens, 1)
    if input_ids.shape[1] > prompt_budget:
        input_ids = input_ids[:, -prompt_budget:]
        attention_mask = attention_mask[:, -prompt_budget:]
    prompt_length = input_ids.shape[1]

    # Fall back to EOS when the tokenizer has no dedicated pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Generate a response with a limited number of new tokens
    output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=pad_token_id,
        no_repeat_ngram_size=3,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
    )

    # Decode only the newly generated tokens. Slicing the decoded text by
    # len(prompt) is unreliable because decoding the prompt tokens is not
    # guaranteed to reproduce the prompt string character for character.
    generated_text = tokenizer.decode(
        output[0][prompt_length:], skip_special_tokens=True
    ).strip()
    for stop_token in ('\nUser:', '\nBot:'):
        generated_text = generated_text.split(stop_token)[0]
    return generated_text.strip()

# Function to truncate conversation history if too long
def truncate_conversation(conversation_history, max_length=1024):
    """Drop the oldest turns until the prompt fits in ``max_length`` tokens.

    The prompt is measured exactly as it is built for generation: turns
    joined with newlines plus a trailing ``Bot:`` cue, tokenized with the
    module-level ``tokenizer``.

    Args:
        conversation_history: Turn strings, oldest first. The list passed
            in is not modified.
        max_length: Maximum allowed prompt length in tokens.

    Returns:
        The (possibly shortened) history. Turns are removed two at a time
        (oldest user turn and its bot reply), but the most recent turn is
        always kept, even if it alone exceeds ``max_length``.
    """
    def prompt_token_count(turns):
        # Only the length is needed, so skip building torch tensors.
        prompt = '\n'.join(turns) + '\nBot:'
        return len(tokenizer(prompt)['input_ids'])

    while (len(conversation_history) > 1
           and prompt_token_count(conversation_history) > max_length):
        if len(conversation_history) > 2:
            # Remove oldest user and bot turn
            conversation_history = conversation_history[2:]
        else:
            # Two turns left: dropping both would discard the message the
            # bot is supposed to answer, so keep just the newest one.
            conversation_history = conversation_history[-1:]
    return conversation_history

# Function to compute log-transformed perplexity to avoid large jumps
def compute_perplexity(model, tokenizer, text):
    """Return ``log(perplexity + 1)`` of ``text`` under ``model``.

    Perplexity is ``exp(loss)``, where ``loss`` is the language-modelling
    loss the model reports when the text's own token ids are used as labels.

    Args:
        model: Callable accepting the tokenizer output as keyword arguments
            plus ``labels``, returning an object with a scalar ``loss``.
        tokenizer: Callable turning ``text`` into a mapping that contains
            ``"input_ids"`` with shape ``(1, num_tokens)``.
        text: The text to score.

    Returns:
        The log-scaled perplexity as a float, or ``nan`` when ``text`` has
        fewer than two tokens (there is no next-token target to score).
    """
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors='pt')
        input_ids = inputs["input_ids"]
        if input_ids.shape[1] < 2:
            return float('nan')
        outputs = model(**inputs, labels=input_ids)
        loss = outputs.loss.item()

    # log(exp(loss) + 1) rewritten so exp() never overflows: computing the
    # perplexity first turns any large loss into inf.
    return max(loss, 0.0) + math.log1p(math.exp(-abs(loss)))

# Initialize conversation history
conversation_history = []

print("Welcome to the Chatbot! Type 'exit' to quit.")
while True:
    # Interrupting the cell (or a closed stdin) while waiting for input
    # ends the session cleanly instead of surfacing a traceback.
    try:
        user_input = input("You: ").strip()
    except (EOFError, KeyboardInterrupt):
        print("\nChatbot session ended.")
        break

    if user_input.lower() in ('exit', 'quit'):
        print("Chatbot session ended.")
        break

    # Nothing to answer: ask again rather than sending an empty turn
    if not user_input:
        continue

    # Add user input to conversation history
    conversation_history.append(f"User: {user_input}")

    # Truncate conversation history if necessary
    conversation_history = truncate_conversation(conversation_history, max_length=1024)

    # Generate bot response
    bot_response = generate_response(conversation_history)

    # Add bot response to conversation history
    conversation_history.append(f"Bot: {bot_response}")

    # Print bot response
    print(f"Bot: {bot_response}")

    # Compute Perplexity for the bot response (log-transformed); an empty
    # reply has no tokens to score, so skip the model call in that case.
    if bot_response:
        perplexity = compute_perplexity(model, tokenizer, bot_response)
        print(f"\nPerplexity Score (log scale): {perplexity:.2f}")
    else:
        print("\nPerplexity Score (log scale): n/a (empty response)")
    print("-" * 50)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Welcome to the Chatbot! Type 'exit' to quit.
You: Hello
Bot: Hello!

Perplexity Score (log scale): 55.81
--------------------------------------------------
You: What is up with you todat?
Bot: Well... I just got the hell out of the hospital.

Perplexity Score (log scale): 7.12
--------------------------------------------------
You: Damn, what happened?
Bot: I don't know.  I can't remember.

Perplexity Score (log scale): 8.49
--------------------------------------------------
You: The car crash...it was two years ago
Bot: Yeah...

Perplexity Score (log scale): 56.06
--------------------------------------------------


KeyboardInterrupt: Interrupted by user