In [None]:
# Cell 1: Setup and Installations

!pip install transformers datasets torch huggingface_hub



In [None]:

# Cell 2: Log In to Hugging Face

from huggingface_hub import notebook_login

# Displays a widget asking for a Hugging Face access token.
# A token can be created under your HF profile -> Settings -> Access Tokens.
# NOTE(review): the generators below load the model from a local folder, so this
# login may only matter for private Hub resources - confirm before removing.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

# **Importing the Fine-Tuned Model**

In [None]:


import zipfile
import os
from google.colab import drive

# Make the user's Google Drive visible inside the Colab filesystem.
print("Mounting Google Drive...")
drive.mount('/content/drive')

# Where the zipped fine-tuned model lives in Google Drive, and the local
# Colab folder it gets unpacked into.
drive_zip_path = '/content/drive/MyDrive/gpt2-poetry-finetuned_v2.zip'
extract_dir = './my_finetuned_model_v2'

print(f"\nLooking for your model at: {drive_zip_path}")
if not os.path.exists(drive_zip_path):
    # Nothing to unpack: tell the user which file name was expected.
    print(f"❌ Error: The file '{os.path.basename(drive_zip_path)}' was not found in your Google Drive.")
    print("Please make sure the file is in your 'My Drive' folder and the name matches exactly.")
else:
    print("Model found in Google Drive. Unzipping...")
    with zipfile.ZipFile(drive_zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)
    print(f"✅ Model unzipped successfully into the '{extract_dir}' folder.")

Mounting Google Drive...
Mounted at /content/drive

Looking for your model at: /content/drive/MyDrive/gpt2-poetry-finetuned_v2.zip
Model found in Google Drive. Unzipping...
✅ Model unzipped successfully into the './my_finetuned_model_v2' folder.


# **The Poetry Generator Class**

In [None]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

class PoetryGenerator:
    """Samples free-form continuation lines from the fine-tuned GPT-2 model.

    The tokenizer and model are loaded from ``model_name`` (by default the local
    folder the fine-tuned model was unzipped into) and the model is moved to the
    GPU when one is available.
    """

    def __init__(self, model_name="./my_finetuned_model_v2"):
        """Load tokenizer and model from ``model_name`` and select the device."""
        print(f"Loading your custom fine-tuned model from '{model_name}'...")
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        # generate() needs a pad token id; fall back to EOS when none is defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = GPT2LMHeadModel.from_pretrained(model_name)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        print(f"✅ Custom model loaded successfully on device: {self.device}")

    def generate_candidates(self, prompt, num_candidates=5, max_length=20, **kwargs):
        """Sample continuations of ``prompt`` and return the non-empty ones.

        Args:
            prompt: Text the model should continue.
            num_candidates: Number of sequences to sample.
            max_length: Maximum number of NEW tokens per candidate (passed to
                ``generate`` as ``max_new_tokens``).
            **kwargs: Extra ``generate`` arguments. They override the defaults
                used here (``do_sample``, ``top_k``, ``top_p``, ...), so passing
                e.g. ``top_k=10`` no longer raises a duplicate-keyword error.

        Returns:
            A list of stripped continuation strings (prompt removed); candidates
            that are empty after stripping are dropped, so the list may be
            shorter than ``num_candidates``.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_token_count = inputs["input_ids"].shape[-1]

        generation_args = {
            "max_new_tokens": max_length,
            "num_return_sequences": num_candidates,
            "do_sample": True,
            "top_k": 50,
            "top_p": 0.95,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        generation_args.update(kwargs)  # caller-supplied values win over defaults

        outputs = self.model.generate(**inputs, **generation_args)

        # Drop the prompt by token position instead of by character count:
        # decoding does not always reproduce the prompt text character for
        # character, so slicing the decoded string by len(prompt) can cut into
        # the continuation or leave prompt residue behind.
        generated_texts = [
            self.tokenizer.decode(output[prompt_token_count:], skip_special_tokens=True).strip()
            for output in outputs
        ]
        return [text for text in generated_texts if text]

print("✅ PoetryGenerator class defined.")

✅ PoetryGenerator class defined.


# **The PPLM**

In [None]:
# Cell 5: The PPLM Generator for Haikus (NEW)

import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import nltk
from nltk.corpus import cmudict
import urllib.error


# Make sure the CMU Pronouncing Dictionary is available locally before loading it.
try:
    nltk.data.find('corpora/cmudict.zip')
except LookupError:
    print("Downloading CMU Pronouncing Dictionary for syllable counting...")

    # nltk.download() normally signals failure through its boolean return value
    # rather than by raising, so check both the return value and the exception.
    try:
        download_succeeded = nltk.download('cmudict')
    except urllib.error.URLError:
        download_succeeded = False

    if not download_succeeded:
        print("Download failed. Please check your internet connection and try again.")


# Maps a lowercase word to its list of pronunciations (each a list of phonemes).
# Raises LookupError here if the corpus is still missing after the attempt above.
pronouncing_dict = cmudict.dict()

def count_syllables(word):
    """Return the number of syllables in ``word`` as an integer.

    The word is lowercased and stripped of surrounding ``.,?!;`` before being
    looked up in ``pronouncing_dict``. Phonemes ending in a digit carry a stress
    marker, so counting them counts syllables; when a word has several
    pronunciations the largest count is used.

    Words missing from the dictionary fall back to a rough estimate of one
    syllable per two letters (at least one). Tokens that are empty after
    stripping (e.g. a lone "!") count as zero syllables.
    """
    word = word.lower().strip(".,?!;")
    if not word:
        return 0

    pronunciations = pronouncing_dict.get(word)
    if not pronunciations:
        # Integer division keeps the estimate a whole number of syllables.
        return max(1, len(word) // 2)

    return max(
        sum(1 for phoneme in pronunciation if phoneme[-1].isdigit())
        for pronunciation in pronunciations
    )

def get_total_syllables(text):
    """Add up the syllable counts of every whitespace-separated token in ``text``."""
    tokens = text.split()
    return sum(map(count_syllables, tokens))

# --- PPLM Generator Class ---
class PPLM_PoetryGenerator:
    """Generates haiku lines by nudging GPT-2's next-token distribution.

    A simplified plug-and-play steering scheme: the final hidden state at the
    last prompt position is moved by a few gradient steps that lower a
    syllable-based loss, the steered logits are blended with the unmodified
    ones, the first new token is sampled from that blend, and the rest of the
    line is sampled normally by ``generate``.
    """

    def __init__(self, model_name="./my_finetuned_model_v2"):
        """Load tokenizer and model from ``model_name`` and select the device."""
        print(f"Loading PPLM-enabled model from '{model_name}'...")
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        # generate() needs a pad token id; fall back to EOS when none is defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        print(" PPLM-enabled model loaded successfully.")

    def generate_haiku_line(self, prompt, target_syllables, num_candidates=5,
                            stepsize=0.03, num_iterations=3, gm_scale=0.9,
                            max_new_tokens=10, loss_top_k=10):
        """Sample candidate continuations of ``prompt`` steered by syllable count.

        Args:
            prompt: Text the model should continue.
            target_syllables: Syllable count used in the steering loss.
            num_candidates: Number of candidate lines to sample.
            stepsize: Gradient step size applied to the hidden state.
            num_iterations: Number of gradient steps (0 disables steering).
            gm_scale: Weight of the steered logits in the final blend; the
                unmodified logits get ``1 - gm_scale``.
            max_new_tokens: Total number of new tokens per candidate, including
                the steered first token.
            loss_top_k: How many of the most likely next tokens enter the loss.

        Returns:
            A list of stripped continuation strings (prompt removed); empty
            candidates are dropped, so the list may be shorter than
            ``num_candidates``.
        """
        if num_candidates < 1 or max_new_tokens < 1:
            return []

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_ids = inputs["input_ids"]
        prompt_len = prompt_ids.shape[-1]

        # 1. Unmodified forward pass. No graph is needed here: gradients are
        #    only taken with respect to the perturbed copy of the hidden state.
        with torch.no_grad():
            unpert_outputs = self.model(**inputs, output_hidden_states=True)
        unpert_logits = unpert_outputs.logits[:, -1, :]
        # In Hugging Face's GPT-2 the last entry of hidden_states is what lm_head
        # consumes, so lm_head(hidden) reproduces the logits and lets us
        # recompute them from a perturbed state.
        unpert_hidden = unpert_outputs.hidden_states[-1][:, -1, :]

        # 2. Gradient steps on the last-position hidden state.
        pert_hidden = unpert_hidden.clone()
        for _ in range(num_iterations):
            # Re-create a leaf tensor each step so a gradient can be taken again.
            pert_hidden = pert_hidden.detach().requires_grad_(True)
            probs = F.softmax(self.model.lm_head(pert_hidden), dim=-1)

            k = min(loss_top_k, probs.shape[-1])
            top_probs, top_ids = torch.topk(probs, k=k, dim=-1)

            # Loss: probability-weighted squared gap between each likely
            # token's syllable count and the target.
            squared_errors = torch.tensor(
                [
                    (get_total_syllables(self.tokenizer.decode(int(token_id))) - target_syllables) ** 2
                    for token_id in top_ids[0]
                ],
                dtype=top_probs.dtype,
                device=top_probs.device,
            )
            expected_error = (top_probs[0] * squared_errors).sum()

            # autograd.grad (rather than backward) avoids leaving gradients on
            # the model's own parameters.
            (grad,) = torch.autograd.grad(expected_error, pert_hidden)
            pert_hidden = pert_hidden - stepsize * grad
        pert_hidden = pert_hidden.detach()

        # 3. Blend steered and unmodified logits to keep the text fluent.
        with torch.no_grad():
            pert_logits = self.model.lm_head(pert_hidden)
        final_logits = (gm_scale * pert_logits) + ((1 - gm_scale) * unpert_logits)

        # 4. Sample the first new token of every candidate from the blend...
        first_token_probs = F.softmax(final_logits.float(), dim=-1)[0]
        first_tokens = torch.multinomial(
            first_token_probs, num_samples=num_candidates, replacement=True
        )
        sequences = torch.cat(
            [prompt_ids.expand(num_candidates, -1), first_tokens.unsqueeze(1)], dim=1
        )

        # 5. ...and let the model finish each line with ordinary sampling.
        if max_new_tokens > 1:
            sequences = self.model.generate(
                input_ids=sequences,
                attention_mask=torch.ones_like(sequences),
                max_new_tokens=max_new_tokens - 1,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Remove the prompt by token position; decoded text does not always
        # reproduce the prompt character for character.
        generated_texts = [
            self.tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True).strip()
            for sequence in sequences
        ]
        return [text for text in generated_texts if text]

print("✅ PPLM_PoetryGenerator class for Haiku mode defined.")

✅ PPLM_PoetryGenerator class for Haiku mode defined.


# **The RLHF Reward Model Class**

In [None]:


import torch.nn as nn
import torch.optim as optim
from transformers import DistilBertTokenizer, DistilBertModel

class RewardModel(nn.Module):
    """DistilBERT encoder topped with a single-output linear scoring head."""

    def __init__(self, model_name="distilbert-base-uncased"):
        """Load the pretrained encoder and attach a linear head sized to its hidden dimension."""
        super().__init__()
        self.bert_base = DistilBertModel.from_pretrained(model_name)
        encoder_width = self.bert_base.config.dim
        self.score_head = nn.Linear(encoder_width, 1)

    def forward(self, input_ids, attention_mask):
        """Score each input row from the final hidden state at position 0.

        Returns the output of ``score_head``, a tensor whose last dimension is 1.
        """
        encoded = self.bert_base(input_ids=input_ids, attention_mask=attention_mask)
        first_position_state = encoded.last_hidden_state[:, 0]
        return self.score_head(first_position_state)

class RewardModelTrainer:
    """Bundles the reward model with its tokenizer, optimizer and loss, and fits it to user ratings."""

    def __init__(self):
        """Create the tokenizer, model (moved to GPU when available), Adam optimizer and MSE loss."""
        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        self.model = RewardModel()
        use_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_gpu else "cpu")
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-5)
        # Ratings are treated as a regression target, hence mean squared error.
        self.loss_function = nn.MSELoss()

    def train_on_ratings(self, rated_lines, epochs=5):
        """Fit the reward model to ``(text, score)`` pairs, one optimizer step per pair.

        Does nothing when ``rated_lines`` is empty. Prints the average loss of
        every epoch. The model is left in training mode.
        """
        if not rated_lines:
            return

        sample_count = len(rated_lines)
        print(f"\n--- Training Reward Model on {sample_count} ratings ---")
        self.model.train()

        for epoch_number in range(1, epochs + 1):
            running_loss = 0
            for line_text, rating in rated_lines:
                batch = self.tokenizer(
                    line_text, return_tensors="pt", padding=True, truncation=True
                ).to(self.device)
                target = torch.tensor([[float(rating)]], device=self.device)

                # Clear gradients left over from the previous sample before the new pass.
                self.optimizer.zero_grad()
                loss = self.loss_function(self.model(**batch), target)
                loss.backward()
                self.optimizer.step()

                running_loss += loss.item()
            print(f"Epoch {epoch_number}/{epochs}, Average Loss: {running_loss / sample_count:.4f}")

print("✅ RewardModel and RewardModelTrainer classes defined.")

✅ RewardModel and RewardModelTrainer classes defined.


# **Main interactive application**

In [None]:
# The final, fully upgraded Main Interactive Application cell

import os
import json
import torch

# --- Helper functions with updated Poem Library management ---
def save_user_profile(user_id, trainer, poem_library):
    """Write the user's reward-model weights and poem library into their profile folder.

    The folder ``profiles/{user_id}`` is created if needed. The reward model's
    state dict goes to ``reward_model.pt`` and the poem library is written as
    indented JSON to ``poem_library.json``.
    """
    profile_dir = f"profiles/{user_id}"
    weights_path = f"{profile_dir}/reward_model.pt"
    library_path = f"{profile_dir}/poem_library.json"

    os.makedirs(profile_dir, exist_ok=True)

    model_state = trainer.model.state_dict()
    torch.save(model_state, weights_path)

    with open(library_path, "w") as library_file:
        json.dump(poem_library, library_file, indent=2)

    print(f"\n[+] Profile for '{user_id}' (including poems) saved successfully.")

def load_user_profile(user_id, trainer):
    """Restore a user's saved reward-model weights and poem library, if present.

    Looks in ``profiles/{user_id}``. When ``reward_model.pt`` exists its state
    dict is loaded into ``trainer.model``; otherwise a welcome message for a new
    profile is printed. When ``poem_library.json`` exists it is parsed and
    returned; otherwise the library is an empty dict.

    Returns:
        ``(trainer, poem_library)``.
    """
    profile_dir = f"profiles/{user_id}"
    weights_path = f"{profile_dir}/reward_model.pt"
    library_path = f"{profile_dir}/poem_library.json"

    # Reward model weights
    if not os.path.exists(weights_path):
        print(f"[+] Welcome, {user_id}! Creating a new profile for you.")
    else:
        saved_state = torch.load(weights_path, map_location=trainer.device)
        trainer.model.load_state_dict(saved_state)
        print(f"[+] Welcome back, {user_id}! Your reward model has been loaded.")

    # Poem library
    if not os.path.exists(library_path):
        return trainer, {}

    with open(library_path, "r") as library_file:
        stored_library = json.load(library_file)
    print("[+] Your poem library has been loaded.")
    return trainer, stored_library


def _ask_nonempty(prompt_text):
    """Prompt until the user enters a non-blank string; return it stripped."""
    while True:
        answer = input(prompt_text).strip()
        if answer:
            return answer
        print("Input cannot be empty. Please try again.")


def _ask_int_in_range(prompt_text, low, high, out_of_range_msg, not_a_number_msg):
    """Prompt until the user enters an integer within [low, high]; return it."""
    while True:
        try:
            value = int(input(prompt_text))
        except ValueError:
            print(not_a_number_msg)
            continue
        if low <= value <= high:
            return value
        print(out_of_range_msg)


def _ask_yes_no(prompt_text):
    """Return True when the user answers 'yes' or 'y' (case and spacing ignored)."""
    return input(prompt_text).strip().lower() in ("yes", "y")


def _rate_candidates(candidates, feedback_buffer):
    """Collect a 0-5 rating for each candidate and log every (line, rating) pair as feedback."""
    print("\nPlease rate each of the following lines from 0 (bad) to 5 (excellent):")
    rated = []
    for position, line in enumerate(candidates, start=1):
        rating = _ask_int_in_range(
            f"  {position}: {line}\n  Your rating (0-5): ",
            0,
            5,
            "Invalid rating. Please enter a number between 0 and 5.",
            "Invalid input. Please enter a number.",
        )
        rated.append({"text": line, "rating": rating})
        feedback_buffer.append((line, rating))
    return rated


def _pick_top_rated(valid_candidates):
    """Return the text of the highest-rated candidate, asking the user to break ties."""
    highest_rating = max(c["rating"] for c in valid_candidates)
    tied_highest = [c for c in valid_candidates if c["rating"] == highest_rating]
    if len(tied_highest) == 1:
        return tied_highest[0]['text']

    print(f"\nThere's a tie for the highest rating ({highest_rating}). Please choose one:")
    for position, candidate in enumerate(tied_highest, start=1):
        print(f"  {position}: {candidate['text']}")
    choice = _ask_int_in_range(
        "Your choice: ", 1, len(tied_highest), "Invalid number.", "Please enter a number."
    )
    return tied_highest[choice - 1]['text']


def main_session():
    """Run the interactive poetry-writing session.

    Flow:
      1. Ask for the user's name and load their profile (reward model + poems).
      2. Main menu: start a new poem, continue a saved one, or quit.
      3. Writing loop: generate candidate lines (freestyle or haiku), let the
         user rate them, keep the best one or inject their own line, and save
         the profile after every added line.
      4. On quit, train the reward model on all feedback gathered during the
         session and save the profile once more.

    Blank names, titles and lines are re-prompted, a new poem can never
    overwrite an existing title, and both "yes" and "y" count as agreement.
    """
    # --- 1. Initialization ---
    user_id = _ask_nonempty("Please enter your name to begin: ")
    trainer, poem_library = load_user_profile(user_id, RewardModelTrainer())

    freestyle_generator = PoetryGenerator()
    haiku_generator = PPLM_PoetryGenerator()
    session_feedback_buffer = []  # (line, rating) pairs used to train the reward model at the end

    # Target syllables per haiku line, keyed by the index of the line being written.
    syllable_targets = {0: 5, 1: 7, 2: 5}

    # --- 2. Main Menu Loop ---
    while True:
        print("\n" + "--- MAIN MENU " + "-"*48)
        action = input("What would you like to do?\n  1: Start a new poem\n  2: Continue an old poem\n  3: Quit\nYour choice: ").strip()

        # --- 3. Action Selection ---
        if action == '1':  # Start New Poem
            poem_title = _ask_nonempty("\nGive your new poem a title: ")
            # Titles are the library keys: reusing one would silently replace the saved poem.
            while poem_title in poem_library:
                print(f"A poem titled '{poem_title}' already exists. Please choose a different title.")
                poem_title = _ask_nonempty("Give your new poem a title: ")

            theme = input("Enter the theme for your poem: ").strip()

            mode_choice = ""
            while mode_choice not in ('1', '2'):
                mode_choice = input("Choose a mode: (1) Freestyle, (2) Haiku: ").strip()
            poem_mode = "freestyle" if mode_choice == '1' else "haiku"

            start_line = _ask_nonempty("Enter the first line: ")
            poem_so_far = [start_line]

            poem_library[poem_title] = {"theme": theme, "lines": poem_so_far, "mode": poem_mode}

        elif action == '2':  # Continue Old Poem
            if not poem_library:
                print("\nYour poem library is empty! Please start a new poem first.")
                continue

            print("\nYour poems:")
            titles = list(poem_library.keys())
            for position, title in enumerate(titles, start=1):
                print(f"  {position}: {title}")

            choice = _ask_int_in_range(
                "Which poem to continue? ", 1, len(titles), "Invalid number.", "Please enter a number."
            )
            poem_title = titles[choice - 1]
            poem_data = poem_library[poem_title]
            theme = poem_data.get('theme', '')
            poem_so_far = poem_data['lines']
            poem_mode = poem_data.get('mode', 'freestyle')  # Default to freestyle for older poems

        elif action == '3':
            break
        else:
            print("Invalid choice. Please try again.")
            continue

        # --- 4. The Creative Writing Loop ---
        while True:
            print("\n" + "="*60)
            print(f"Title: {poem_title} (Mode: {poem_mode.capitalize()})")
            print("POEM SO FAR:")
            for line in poem_so_far:
                print(f"  {line}")
            print("="*60)

            if poem_mode == 'haiku' and len(poem_so_far) >= 3:
                print("Haiku complete! Returning to main menu.")
                break

            # Only the two most recent lines are given to the model as context.
            context_lines = poem_so_far[-2:]
            prompt = f"Theme: {theme}\nPoem: ...{' '.join(context_lines)}"

            # Generate candidates based on mode
            if poem_mode == 'haiku':
                target_syllables = syllable_targets.get(len(poem_so_far), 5)
                print(f"\n[AI is generating a {target_syllables}-syllable line for the Haiku]...")
                candidates = haiku_generator.generate_haiku_line(prompt, target_syllables)
            else:
                print("\n[AI is generating 5 freestyle options]...")
                candidates = freestyle_generator.generate_candidates(prompt)

            candidates = [line for line in candidates if line.strip()]

            # --- 5. User Feedback and Selection Logic ---
            if not candidates:
                print("[AI had trouble generating a valid line. Please write your own.]")
                new_line = _ask_nonempty("> ")
                session_feedback_buffer.append((new_line, 5))
            else:
                rated_candidates = _rate_candidates(candidates, session_feedback_buffer)
                valid_candidates = [c for c in rated_candidates if c["rating"] > 0]

                if not valid_candidates:
                    print("\nAll lines were pruned. Please write your own line to continue.")
                    new_line = _ask_nonempty("> ")
                    session_feedback_buffer.append((new_line, 5))
                else:
                    highest_rated_line = _pick_top_rated(valid_candidates)
                    print(f"\nThe highest-rated line is: '{highest_rated_line}'")

                    if _ask_yes_no("Do you want to keep this line? (yes/no): "):
                        print("\nAdding to the poem.")
                        new_line = highest_rated_line
                    else:
                        print("\nSuggestion rejected.")
                        inject_or_quit = input("Would you like to (i)nject your own line or go to the (m)ain menu?: ").strip().lower()
                        if inject_or_quit != 'i':
                            break
                        new_line = _ask_nonempty("Your line: ")
                        # The user's own line is a strong positive example, the rejected one a weak one.
                        session_feedback_buffer.append((new_line, 5))
                        session_feedback_buffer.append((highest_rated_line, 1))

            # Every path that adds a line ends here, so the poem is saved and the
            # user always gets the chance to stop.
            poem_so_far.append(new_line)
            poem_library[poem_title]['lines'] = poem_so_far
            save_user_profile(user_id, trainer, poem_library)

            if not _ask_yes_no("\nContinue this poem? (yes/no): "):
                break

    print("\n--- Session Over ---")
    if session_feedback_buffer:
        trainer.train_on_ratings(session_feedback_buffer)
        save_user_profile(user_id, trainer, poem_library)
    else:
        print("No new feedback was provided. Profile not updated.")
    print("Thank you for creating poetry with me!")


main_session()

Please enter your name to begin: abhi
[+] Welcome, abhi! Creating a new profile for you.
Loading your custom fine-tuned model from './my_finetuned_model_v2'...
✅ Custom model loaded successfully on device: cuda
Loading PPLM-enabled model from './my_finetuned_model_v2'...
✅ PPLM-enabled model loaded successfully.

--- MAIN MENU ------------------------------------------------
What would you like to do?
  1: Start a new poem
  2: Continue an old poem
  3: Quit
Your choice: 3

--- Session Over ---
No new feedback was provided. Profile not updated.
Thank you for creating poetry with me!
