# Load the Model

In [6]:
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments
)

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
import torch
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification

In [None]:
# Path to your locally saved fine-tuned GPT-2 model with a classification head
MODEL_PATH = "gpt2-conversation-finetuned"

# Load tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_PATH)
model = GPT2LMHeadModel.from_pretrained(MODEL_PATH)  # GPT2ForSequenceClassification is for classification, not generation

In [25]:
def continue_conversation(prompt_text: str, model, tokenizer, max_length: int) -> str:
    """
    Generates the next part of the conversation given a prompt.

    Args:
        prompt_text (str): The initial conversation/prompt.
        model: A fine-tuned GPT-2 model loaded for text generation.
        tokenizer: The tokenizer associated with the GPT-2 model.
        max_length (int): The maximum number of tokens to generate.

    Returns:
        str: The continuation text generated by GPT-2.
    """
    # Put the model in evaluation mode
    model.eval()

    # Encode the input prompt
    input_ids = tokenizer.encode(prompt_text, return_tensors='pt')

    # Generate continuation
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_length=max_length,
            num_beams=5,             # Use beam search for higher quality generation
            no_repeat_ngram_size=2,  # Helps reduce repetition
            early_stopping=True
        )

    # Decode the output tokens to text
    continuation = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return continuation

## Clean Text

In [27]:
import re
import nltk

In [28]:
nltk.download('punkt_tab')
nltk.download('punkt')
nltk.download('stopwords')

from nltk.corpus import stopwords

[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/ashansubodha/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/ashansubodha/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/ashansubodha/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [29]:
def clean_conversation_text(text: str) -> str:
    """
    Cleans the input conversation text by:
    1. Converting to lowercase
    2. Removing punctuation and digits
    3. Tokenizing
    4. Removing stopwords
    5. Removing short tokens (optional)
    
    Args:
        text (str): The conversation text to clean.
        
    Returns:
        str: The cleaned text.
    """
    
    # 1. Convert to lowercase
    text = text.lower()
    
    # 2. Remove punctuation and digits
    # This regex replaces any character that is not a letter or whitespace
    text = re.sub(r'[^a-z\s]', '', text)
    
    # 3. Tokenize
    tokens = nltk.word_tokenize(text)
    
    # 4. Remove stopwords
    stop_words = set(stopwords.words('english'))
    tokens = [token for token in tokens if token not in stop_words]
    
    # 5. (Optional) Remove short tokens of length 1
    tokens = [token for token in tokens if len(token) > 1]
    
    # Join back into a single string
    cleaned_text = " ".join(tokens)
    
    return cleaned_text



In [30]:
if __name__ == "__main__":
    # Example prompt (conversation start)
    prompt = (
        "Hello, I have a question about my bank account. "
        "Could you help me figure out why my card was declined?"
    )
    
    cleaned_prompt = clean_conversation_text(prompt)

    # Generate the continuation
    max_length = 100  # Adjust max_length as needed
    generated_text = continue_conversation(cleaned_prompt, model, tokenizer, max_length)

    print("----- Original Prompt -----")
    print(prompt)
    print("\n----- Model Continuation -----")
    print(generated_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


----- Original Prompt -----
Hello, I have a question about my bank account. Could you help me figure out why my card was declined?

----- Model Continuation -----
hello question bank account could help figure card declined. Can you provide more details about your account?
Yes, we can provide a contact number for you.
Thank you for taking the time to provide us with this information. We appreciate your interest in helping us. Please contact us for more information about this matter.
