<a href="https://colab.research.google.com/github/Yanhan-ss/assonance-and-alliteration/blob/main/PPO%20reinforcement%20training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch transformers datasets g2p-en nltk



In [None]:
!pip install trl==0.8.6



In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import PPOTrainer, PPOConfig
from datasets import Dataset
import torch
from g2p_en import G2p
from collections import defaultdict, Counter
import re
import copy
import itertools
import nltk
from google.colab import drive

# Mount Google Drive (only needed in Colab) so the fine-tuned checkpoint and the
# PPO checkpoints stored under /content/drive/MyDrive are reachable.
drive.mount('/content/drive')

# Download the NLTK POS-tagger resources used by g2p_en. Recent NLTK releases look
# up 'averaged_perceptron_tagger_eng' rather than the legacy name, so fetch both
# here instead of relying on a later cell to provide the second one.
nltk.download('averaged_perceptron_tagger')
nltk.download('averaged_perceptron_tagger_eng')

# Initialize G2P (Grapheme to Phoneme); shared by the phoneme helpers below.
g2p = G2p()

"""
# Function to get the rhyme phoneme
def get_rhyme_phoneme(word):
    phones = g2p(word.lower())
    if not phones:
        return None
    p = phones[-1]
    return p[:-1] if p[-1].isdigit() else p

# Function to evaluate the rhyme level of a poem
def rhyme_level(poem_text):
    lines = poem_text.strip().split('\n')[0:4]
    endings = [line.strip().split()[-1] for line in lines if line]
    rhymes = [get_rhyme_phoneme(word) for word in endings]

    clusters = defaultdict(list)
    for idx, r in enumerate(rhymes):
        if r:
            clusters[r].append(idx)

    rhyme_sets = [grp for grp in clusters.values() if len(grp) > 1]
    rhyme_count = len(rhyme_sets)

    if rhyme_count == 0:
        return 'low'
    elif rhyme_count == 1:
        return 'medium'
    else:
        return 'high'
"""

# Function to extract phonemes from a sentence
def get_phonemes(sentence):
    """Convert the words of ``sentence`` into a flat list of phoneme tokens.

    Only regex word characters are kept (punctuation is discarded). The words
    are re-joined with single spaces and passed to the global ``g2p`` converter.
    Blank tokens (' ' or '') in its output are dropped.

    Args:
        sentence: One line of text.

    Returns:
        list of phoneme strings, in word order.
    """
    cleaned = " ".join(re.findall(r'\b\w+\b', sentence))
    phonemes = []
    for token in g2p(cleaned):
        if token not in [' ', '']:
            phonemes.append(token)
    return phonemes

# Function to check if a phoneme is a vowel
def is_vowel_phoneme(ph):
    """Return True if ``ph`` looks like a vowel phoneme.

    Vowel phonemes from g2p_en carry a trailing stress digit (0/1/2, e.g.
    'AH0', 'EY1'); consonants such as 'K' do not. An empty token returns False
    instead of raising IndexError.

    Args:
        ph: A single phoneme token.

    Returns:
        bool.
    """
    # Slicing (not indexing) keeps '' from raising; ''.isdigit() is False.
    return ph[-1:].isdigit()

# Function to extract vowel phonemes (stress digits stripped) from a list of phonemes
def get_vowel_phonemes(phonemes):
    """Return the vowel phonemes from ``phonemes`` with stress digits removed.

    A token counts as a vowel when ``is_vowel_phoneme`` accepts it. The digits
    0/1/2 are then stripped, so e.g. 'AH0' and 'AH1' both become 'AH' and are
    counted as the same vowel sound by the scorer.

    Args:
        phonemes: Phoneme tokens, e.g. the output of ``get_phonemes``.

    Returns:
        list of stress-less vowel phonemes, in input order.
    """
    vowels = []
    for ph in phonemes:
        if is_vowel_phoneme(ph):
            vowels.append(ph.strip('012'))
    return vowels

# Function to extract initial phonemes from a sentence
def get_initial_phonemes(sentence):
    """Return the first phoneme (stress digits stripped) of every word in ``sentence``.

    Each regex word is converted on its own, so the first phoneme really belongs
    to that word. Words that yield no phonemes are skipped.

    Args:
        sentence: One line of text.

    Returns:
        list of initial phonemes, one per convertible word, in word order.
    """
    words = re.findall(r'\b\w+\b', sentence)
    # Each word is a single \w+ run, so get_phonemes(word) is exactly g2p(word)
    # with blank tokens removed.
    per_word = (get_phonemes(word) for word in words)
    return [word_phonemes[0].strip('012') for word_phonemes in per_word if word_phonemes]

# Function to evaluate assonance and alliteration scores for a poem
def evaluate_poem(poem_text, max_lines=4):
    """Score the assonance and alliteration of the opening lines of a poem.

    Per scored line:
      * assonance    = total occurrences of stress-less vowel phonemes that
                       appear more than once in the line;
      * alliteration = total number of words whose initial phoneme is shared
                       with at least one other word in the line.

    Blank lines are skipped. At most ``max_lines`` non-blank lines are scored.

    Args:
        poem_text: Poem as plain text, one verse per line.
        max_lines: Maximum number of non-blank lines to score (default 4).

    Returns:
        (mean_assonance, mean_alliteration): per-line averages over the lines
        actually scored, or (0.0, 0.0) when the poem has no non-blank lines.
    """
    lines = [line for line in poem_text.strip().split('\n') if line.strip()][:max_lines]
    if not lines:
        return 0.0, 0.0

    assonance_scores = []
    alliteration_scores = []

    for line in lines:
        vowel_counts = Counter(get_vowel_phonemes(get_phonemes(line)))
        assonance_scores.append(sum(c for c in vowel_counts.values() if c > 1))

        initial_counts = Counter(get_initial_phonemes(line))
        alliteration_scores.append(sum(c for c in initial_counts.values() if c > 1))

    # Average over the lines actually scored, so a poem shorter than `max_lines`
    # is not diluted by a fixed denominator.
    n_lines = len(lines)
    mean_assonance = sum(assonance_scores) / n_lines
    mean_alliteration = sum(alliteration_scores) / n_lines

    return mean_assonance, mean_alliteration

# Function to qualify the level of alliteration
def qualify_alliteration(mean_score):
    """Bucket a mean per-line alliteration score into a level label.

    Bands: score >= 4 -> 'high'; 2 <= score < 4 -> 'medium'; anything else
    (including scores below 1 and NaN) -> 'low'.

    Args:
        mean_score: Mean alliteration score, e.g. from ``evaluate_poem``.

    Returns:
        'low', 'medium' or 'high'.
    """
    if mean_score >= 4:
        return 'high'
    if mean_score >= 2:
        return 'medium'
    return 'low'

# Function to qualify the level of assonance

def qualify_assonance(mean_score):
    """Bucket a mean per-line assonance score into a level label.

    Bands: score >= 6 -> 'high'; 3 <= score < 6 -> 'medium'; anything else
    (including scores below 1 and NaN) -> 'low'.

    Args:
        mean_score: Mean assonance score, e.g. from ``evaluate_poem``.

    Returns:
        'low', 'medium' or 'high'.
    """
    if mean_score >= 6:
        return 'high'
    if mean_score >= 3:
        return 'medium'
    return 'low'

def reward_fn(poem_text, target, match_reward=0.3, mismatch_penalty=0.3):
    """Reward a generated poem for hitting the requested sound-device levels.

    The poem is scored with ``evaluate_poem``. The scores are bucketed into
    'low' / 'medium' / 'high' by ``qualify_assonance`` and
    ``qualify_alliteration``, and the buckets are compared with ``target``:

      * both levels match   -> ``+match_reward``
      * exactly one matches -> ``0.0`` (deliberately neutral)
      * neither matches     -> ``-mismatch_penalty``

    Args:
        poem_text: Generated poem text (the response, without the control prompt).
        target: Mapping with the desired levels under the keys 'assonance' and
            'alliteration'. A missing key counts as a mismatch instead of raising
            KeyError, so an empty target yields the no-match penalty.
        match_reward: Reward for matching both traits (default 0.3).
        mismatch_penalty: Magnitude of the penalty when no trait matches (default 0.3).

    Returns:
        float reward.
    """
    mean_assonance, mean_alliteration = evaluate_poem(poem_text)

    assonance_match = target.get('assonance') == qualify_assonance(mean_assonance)
    alliteration_match = target.get('alliteration') == qualify_alliteration(mean_alliteration)

    if assonance_match and alliteration_match:
        return float(match_reward)
    if assonance_match or alliteration_match:
        return 0.0
    return -float(mismatch_penalty)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


In [None]:
import torch
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from tqdm import tqdm
from datasets import Dataset
import itertools
import wandb
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import nltk
nltk.download('averaged_perceptron_tagger_eng')

# Control prompt -> sound-device levels the reward function checks the response against.
target_traits = {
    "<alliteration_high> <assonance_high>": {"assonance": "high", "alliteration": "high"},
    "<alliteration_medium> <assonance_medium>": {"assonance": "medium", "alliteration": "medium"},
    "<alliteration_low> <assonance_low>": {"assonance": "low", "alliteration": "low"},
}

# Initialize W&B
wandb.init(
    project="gpt2-poetry-ppo",  # Give your project a name
    name="run-001",             # Optional: name your run
)

# Initialize model and tokenizer
model_name = "/content/drive/MyDrive/gpt2-poetry/checkpoint-82864"

config = PPOConfig(
    model_name=model_name,
    batch_size=128,  # must be a multiple of mini_batch_size * gradient_accumulation_steps
    learning_rate=1e-5,
    log_with="wandb",
)

model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
# The stock GPT-2 tokenizer has no pad token. If this checkpoint's tokenizer does not
# define one either, fall back to EOS so padding has a valid id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # Move the model to the correct device

# Prompts are exactly the keys of target_traits (insertion order), so the prompt
# list and the reward targets cannot drift apart.
prompts = list(target_traits)

# Create the dataset
dataset = Dataset.from_dict({
    "input_ids": [tokenizer(prompt, return_tensors="pt").input_ids.squeeze(0).tolist() for prompt in prompts],
    "prompt": prompts,
})

# Use itertools.cycle to create an infinite iterator over the dataset
cyclic_dataset = itertools.cycle(dataset)
# Generator that yields fixed-size batches (lists of dataset rows) forever; not a torch DataLoader
def infinite_data_loader(dataset, batch_size):
    """Yield lists of ``batch_size`` consecutive items drawn from an iterator, forever.

    Args:
        dataset: An iterator (it is advanced with ``next``), e.g. an
            ``itertools.cycle`` over the rows of a dataset.
        batch_size: Number of items per yielded batch.

    Yields:
        list of ``batch_size`` items. If ``dataset`` runs out, the StopIteration
        surfaces as RuntimeError (PEP 479), so pass an infinite iterator.
    """
    while True:
        yield [next(dataset) for _ in range(batch_size)]

train_loader = infinite_data_loader(cyclic_dataset, config.batch_size)

# Initialize PPO Trainer. The training loop draws its batches from `train_loader`,
# not from the trainer's own dataloader built from `dataset`.
ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=dataset,
)

# Sampling settings for ppo_trainer.generate(); max_length bounds prompt + response tokens.
generation_kwargs = dict(
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    max_length=65,
    top_p=0.9,
    top_k=50,
    temperature=0.95,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)

[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


0,1
env/reward_mean,▁
env/reward_std,▁
objective/entropy,▁
objective/kl,▁
objective/kl_coef,▁
ppo/learning_rate,▁
ppo/loss/policy,▁
ppo/loss/total,▁
ppo/loss/value,▁
ppo/mean_non_score_reward,▁

0,1
env/reward_mean,0.43203
env/reward_std,0.55286
objective/entropy,195.33463
objective/kl,0.0
objective/kl_coef,0.2
ppo/learning_rate,1e-05
ppo/loss/policy,-0.00455
ppo/loss/total,0.12878
ppo/loss/value,1.33335
ppo/mean_non_score_reward,0.0


In [None]:
from itertools import islice
import nltk
nltk.download('averaged_perceptron_tagger_eng')

# Training loop
num_epochs = 5
steps_per_epoch = 50
save_every = 50  # save a checkpoint every `save_every` PPO steps
checkpoint_root = "/content/drive/MyDrive"
training_steps = 0  # Track total steps

for epoch in range(num_epochs):
    print(f"Starting epoch {epoch + 1}/{num_epochs}")
    # Take only `steps_per_epoch` batches from the infinite loader
    for batch in tqdm(islice(train_loader, steps_per_epoch), total=steps_per_epoch, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        query_tensors = [torch.tensor(item["input_ids"], device=device) for item in batch]
        query_prompts = [item["prompt"] for item in batch]

        response_tensors = []
        decoded_responses = []

        for query in query_tensors:
            # generate() returns prompt + completion with a leading batch dim of 1.
            # Drop only that dim (squeeze(0), not squeeze()), then slice off the prompt.
            response = ppo_trainer.generate(query, **generation_kwargs).squeeze(0)
            response_only = response[len(query):]
            response_tensors.append(response_only)
            decoded_responses.append(tokenizer.decode(response_only, skip_special_tokens=True))

        # Compute rewards. target_traits is indexed directly so an unknown prompt
        # fails loudly instead of being scored against an empty target.
        rewards = [
            torch.tensor(reward_fn(response_text, target_traits[query_text]), dtype=torch.float32, device=device)
            for query_text, response_text in zip(query_prompts, decoded_responses)
        ]

        # PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

        batch_dict = {
            "query": query_prompts,
            "response": decoded_responses,  # decoded responses are logged alongside the stats
        }

        ppo_trainer.log_stats(stats, batch_dict, rewards)
        training_steps += 1

        if training_steps % save_every == 0:
            checkpoint_dir = f"{checkpoint_root}/gpt2-PPO-checkpoint-{training_steps}"
            model.save_pretrained(checkpoint_dir)
            tokenizer.save_pretrained(checkpoint_dir)


[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


Starting epoch 1/5


Epoch 1/5: 100%|██████████| 50/50 [1:23:07<00:00, 99.76s/it] 


Starting epoch 2/5


Epoch 2/5:   2%|▏         | 1/50 [03:00<2:27:05, 180.11s/it]


KeyboardInterrupt: 