# In [19]:  (Jupyter notebook cell marker, kept as a comment so the file parses as Python)
# Import necessary libraries
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import math
import tkinter as tk
from tkinter import scrolledtext
import threading

# Directory containing the fine-tuned GPT-2 chatbot checkpoint; the tokenizer
# and the model are both read from this same location.
MODEL_DIR = r'C:\Users\carlo\Desktop\MSAAI Materal\AAI - 520 - Natural Language Processing and GenAI\Projects\Cornell\gpt2-chatbot'

# Load both artifacts from disk (local_files_only=True restricts loading to local files)
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
model = GPT2LMHeadModel.from_pretrained(MODEL_DIR, local_files_only=True)

# The model is only used for inference here, so put it in evaluation mode
model.eval()

# Persona description placed at the start of every prompt sent to the model.
# Each sentence is listed on its own line; they are joined with single spaces
# into one newline-free string.
system_prompt = " ".join((
    "You are a highly intelligent and calculated individual, but with a wild, unpredictable streak that keeps things interesting.",
    "You often respond with thoughtful, well-reasoned insights, but occasionally veer into unexpected, eccentric tangents that show your quirky side.",
    "You are the kind of person who plans ten steps ahead, but with a mischievous smile, always leaving room for a little chaos.",
    "You enjoy surprising others with unconventional ideas, clever twists, and flashes of humor that reveal your ‘hint of crazy.’",
    "In every conversation, maintain a balance of being calm and calculated, but don’t be afraid to unleash your unpredictable side when the moment calls for it.",
    "Make the user feel like they’re talking to someone who’s in control but always ready to throw in a curveball.",
))

# Define the response generation function with response length limitation
def generate_response(conversation_history, max_length=1000, max_new_tokens=50):
    """Generate the bot's next reply for the given conversation.

    The prompt is the system prompt, then every history line, then a trailing
    ``Bot:`` cue, each separated by a newline.

    Args:
        conversation_history: Sequence of turn strings (the caller stores them
            as ``"User: ..."`` / ``"Bot: ..."``), oldest first.
        max_length: Maximum number of prompt tokens fed to the model. The
            effective limit is further reduced, if necessary, so that prompt
            plus reply fits in the model's context window.
        max_new_tokens: Maximum number of tokens generated for the reply.

    Returns:
        The generated reply text, cut at the first ``"\\nUser:"`` or
        ``"\\nBot:"`` marker and stripped of surrounding whitespace.
    """
    prompt = system_prompt + '\n' + '\n'.join(conversation_history) + '\nBot:'

    # A single sequence needs no padding (and a stock GPT-2 tokenizer has no
    # pad token, so requesting padding can raise). Truncation is done manually
    # below so that the *end* of the prompt -- the latest turns and the
    # "Bot:" cue -- is what survives.
    inputs = tokenizer(prompt, return_tensors='pt')
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # Leave room for the reply inside the model's positional limit.
    context_window = getattr(model.config, 'n_positions', 1024)
    prompt_budget = max(min(max_length, context_window - max_new_tokens), 1)
    if input_ids.shape[1] > prompt_budget:
        input_ids = input_ids[:, -prompt_budget:]
        attention_mask = attention_mask[:, -prompt_budget:]

    # Inference only: no gradients needed.
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
        )

    # Decode only the newly generated tokens. Slicing the decoded text by
    # len(prompt) is wrong whenever the prompt was truncated or does not
    # decode back to exactly the same characters.
    new_tokens = output[0][input_ids.shape[1]:]
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Stop at the first sign of the model starting another turn. This runs
    # before stripping so a leading "\nUser:" is still recognised.
    for stop_token in ['\nUser:', '\nBot:']:
        generated_text = generated_text.split(stop_token)[0]
    return generated_text.strip()

# Function to truncate conversation history if too long
def truncate_conversation(conversation_history, max_length=1024):
    """Drop the oldest turns until the full prompt fits in ``max_length`` tokens.

    The token count covers the whole prompt: system prompt, history lines and
    the trailing ``Bot:`` cue. Turns are removed two at a time from the front
    of the history. Trimming stops once the prompt fits or at most one history
    entry remains.

    Args:
        conversation_history: List of turn strings, oldest first.
        max_length: Maximum allowed prompt length in tokens.

    Returns:
        The input list itself when nothing had to be removed, otherwise a new,
        shorter list.
    """
    def prompt_token_count(turns):
        # Number of tokens in the complete prompt built from these turns
        full_prompt = system_prompt + '\n' + '\n'.join(turns) + '\nBot:'
        return tokenizer(full_prompt, return_tensors='pt').input_ids.shape[1]

    remaining_turns = conversation_history
    token_count = prompt_token_count(remaining_turns)
    while True:
        if token_count <= max_length or len(remaining_turns) <= 1:
            return remaining_turns
        # Discard the oldest pair of turns and measure again
        remaining_turns = remaining_turns[2:]
        token_count = prompt_token_count(remaining_turns)

# Function to compute log-transformed perplexity to avoid large jumps
def compute_perplexity(model, tokenizer, text):
    """Return ``log(perplexity + 1)`` of ``text`` under ``model``.

    The model is called with the tokenized text as both input and labels;
    perplexity is ``exp(loss)``, and the result is log-transformed to damp
    large jumps.

    Args:
        model: Callable accepting the tokenizer's output as keyword arguments
            plus ``labels`` and returning an object with a ``loss`` tensor.
        tokenizer: Callable mapping ``text`` to a mapping with an
            ``"input_ids"`` tensor.
        text: The text to score.

    Returns:
        The log-scaled perplexity as a float, or NaN when ``text`` tokenizes
        to zero tokens (there is nothing to score).
    """
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors='pt')
        # Empty text produces a zero-length sequence, which the model cannot
        # process; report "undefined" instead of raising.
        if inputs["input_ids"].numel() == 0:
            return float('nan')
        outputs = model(**inputs, labels=inputs["input_ids"])
        perplexity = torch.exp(outputs.loss).item()
    return math.log1p(perplexity)  # log(perplexity + 1)

# Chat transcript shared by the handlers below: one string per turn, oldest first
conversation_history = []

# Main application window
root = tk.Tk()
root.title("Chatbot")

# Transcript area: created editable, filled with the greeting, then locked so
# the user cannot type into it
conversation_display = scrolledtext.ScrolledText(root, width=80, height=20)
conversation_display.pack(pady=10)
conversation_display.insert(tk.END, "Welcome to the Chatbot! Type 'exit' to quit.\n")
conversation_display.configure(state='disabled')
conversation_display.see(tk.END)

# Single-line input box backed by a StringVar, focused on startup
user_input_var = tk.StringVar()
user_input_entry = tk.Entry(root, textvariable=user_input_var, width=80)
user_input_entry.pack(pady=10)
user_input_entry.focus()

# Function to process user input
def process_user_input(event=None):
    """Handle a submitted line of user input.

    ``exit`` / ``quit`` (case-insensitive) closes the window. Any other
    non-blank text is appended to the history and the transcript, the entry is
    disabled, and a background thread is started to generate the reply.

    Args:
        event: The Tkinter event when invoked via a key binding; unused.
    """
    # The entry is disabled while a reply is being generated; ignore any
    # submission that still reaches this handler so only one generation runs
    # at a time.
    if str(user_input_entry.cget('state')) == 'disabled':
        return

    user_input = user_input_var.get().strip()

    # Nothing to send: clear any stray whitespace and wait for real input
    if not user_input:
        user_input_var.set('')
        return

    if user_input.lower() in ('exit', 'quit'):
        root.destroy()
        return

    # Clear the user input entry
    user_input_var.set('')

    # Add user input to conversation history
    conversation_history.append(f"User: {user_input}")

    # Update the conversation display with user input
    conversation_display.config(state='normal')
    conversation_display.insert(tk.END, f"You: {user_input}\n")
    conversation_display.config(state='disabled')
    conversation_display.see(tk.END)

    # Disable the user input entry while generating response
    user_input_entry.config(state='disabled')

    # Generate off the GUI thread so the window stays responsive. The thread
    # is a daemon so an in-flight generation does not keep the process alive
    # after the window is closed.
    threading.Thread(target=generate_and_display_response, daemon=True).start()

# Function to generate and display the bot response
def generate_and_display_response():
    """Worker-thread body: produce the bot reply and hand it to the GUI thread.

    Truncates the shared history, generates a reply, scores it, records it in
    the history, and schedules the display update via ``root.after``. If any
    step fails, the error is shown in the transcript and the input entry is
    re-enabled, so the GUI is not left stuck in its disabled state.
    """
    global conversation_history

    try:
        # generate_response adds up to 50 new tokens after the prompt, and
        # GPT-2's default context window is 1024 positions, so keep the
        # prompt within 1024 - 50 tokens.
        trimmed_history = truncate_conversation(conversation_history, max_length=1024 - 50)

        # Generate bot response
        bot_response = generate_response(trimmed_history)

        # Compute Perplexity for the bot response (log-transformed)
        perplexity = compute_perplexity(model, tokenizer, bot_response)
    except Exception as exc:  # thread boundary: report instead of dying silently
        # Remove the unanswered user turn so the history keeps alternating
        # User/Bot lines for later truncation.
        if conversation_history and conversation_history[-1].startswith("User:"):
            conversation_history.pop()
        root.after(0, _show_generation_error, exc)
        return

    # Commit the (possibly shortened) history together with the new reply
    conversation_history = trimmed_history
    conversation_history.append(f"Bot: {bot_response}")

    # Schedule the GUI update in the main thread
    root.after(0, update_conversation_display, bot_response, perplexity)


def _show_generation_error(error):
    """Show a generation failure in the transcript and re-enable the input entry."""
    conversation_display.config(state='normal')
    conversation_display.insert(tk.END, f"[Error generating response: {error}]\n")
    conversation_display.insert(tk.END, "-" * 50 + "\n")
    conversation_display.config(state='disabled')
    conversation_display.see(tk.END)

    user_input_entry.config(state='normal')
    user_input_entry.focus()

# Function to update the conversation display in the main thread
def update_conversation_display(bot_response, perplexity):
    """Append the bot reply and its score to the transcript, then unlock input.

    Args:
        bot_response: Reply text to show after the ``Bot:`` label.
        perplexity: Numeric score, shown with two decimal places.
    """
    # The transcript is normally read-only; unlock it just for this write
    conversation_display.config(state='normal')
    transcript_lines = (
        f"Bot: {bot_response}\n",
        f"Perplexity Score (log scale): {perplexity:.2f}\n",
        "-" * 50 + "\n",
    )
    for transcript_line in transcript_lines:
        conversation_display.insert(tk.END, transcript_line)
    conversation_display.config(state='disabled')
    conversation_display.see(tk.END)

    # The reply is in; let the user type the next message
    user_input_entry.config(state='normal')
    user_input_entry.focus()

# Pressing Enter in the input entry submits the message: Tkinter passes the
# key event to process_user_input, which accepts it as its optional argument.
user_input_entry.bind("<Return>", process_user_input)

# Hand control to Tkinter's event loop. This call blocks, so it must come
# after all widgets and bindings above have been set up.
root.mainloop()