In [1]:
# Install required packages
!pip install -U transformers
!pip install bert-score sentence-transformers rouge-score nltk
!pip install -q -U evaluate

import os
os.environ["WANDB_DISABLED"] = "true"

#Evaluation
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.corpus import cmudict
from rouge_score import rouge_scorer
from bert_score import score as bert_score
from sentence_transformers import SentenceTransformer, util
import evaluate
import random
import transformers
#Transformers
from transformers import (
    T5TokenizerFast,
    T5ForConditionalGeneration,
    Trainer,
    TrainingArguments,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
    BertTokenizer,
    BertForSequenceClassification
)

import torch
import pandas as pd
import numpy as np
from datasets import Dataset
from google.colab import drive

Collecting bert-score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Collecting rouge-score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.0.0->bert-score)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.0.0->bert-score)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.0.0->bert-score)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.0.0->bert-score)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.0.0->bert-score)
  Downloading nvidia_cublas

In [2]:
#Mount drive folder and read data set files
drive.mount('/content/drive')
train_file = 'drive/MyDrive/266/project/train_data.xlsx'
test_file = 'drive/MyDrive/266/project/test_data.xlsx'
val_file = 'drive/MyDrive/266/project/val_data.xlsx'
df_train = pd.read_excel(train_file)
df_test = pd.read_excel(test_file)
df_val = pd.read_excel(val_file)

#Trim dataset and create input_text and target_text
df_train = df_train
df_val = df_val
df_test = df_test
df_train['input_text'] = df_train['line1'].apply(lambda x: f"Given this song lyric line, generate the next song lyric line:: {x}")
df_train['target_text'] = df_train['line2']
df_val['input_text'] = df_val['line1'].apply(lambda x: f"Given this song lyric line, generate the next song lyric line:: {x}")
df_val['target_text'] = df_val['line2']
df_test['input_text'] = df_test['line1'].apply(lambda x: f"Given this song lyric line, generate the next song lyric line:: {x}")
df_test['target_text'] = df_test['line2']
#Create datasets from dfs
dataset = Dataset.from_pandas(df_train[['input_text', 'target_text']])
val_dataset = Dataset.from_pandas(df_val[['input_text', 'target_text']])
test_dataset = Dataset.from_pandas(df_test[['input_text', 'target_text']])

Mounted at /content/drive


In [3]:
import nltk
nltk.download('cmudict')

from nltk.corpus import cmudict
cmu = cmudict.dict()

def word_to_phonemes(word):
    return cmu.get(word.lower(), [['UNK']])[0]          # pick first pronunciation

special = {'PAD', 'BOS', 'EOS', 'UNK'}

# 1) collect phonemes from training targets
phoneme_set = set()
for ln in df_train['target_text']:
    for w in ln.split():
        phoneme_set.update(word_to_phonemes(w))

# 2) remove the four specials if they slipped in
phoneme_set -= special

# 3) start the dict with the specials
phoneme2id = {'PAD': 0, 'BOS': 1, 'EOS': 2, 'UNK': 3}

# 4) add the rest, starting at 4
for i, ph in enumerate(sorted(phoneme_set), start=4):
    phoneme2id[ph] = i

print("vocab size =", len(phoneme2id), "  max id =", max(phoneme2id.values()))
# max id will now be vocab-size − 1


[nltk_data] Downloading package cmudict to /root/nltk_data...
[nltk_data]   Unzipping corpora/cmudict.zip.


vocab size = 73   max id = 72


In [4]:
max_len_ph = 64
def phonemise(line):
    out = []
    for w in line.split():
        out += word_to_phonemes(w)
    return out[:max_len_ph]

def preprocess_rhyme(example):
    ph_seq = phonemise(example['target_text'])
    ph_seq = ph_seq[:max_len_ph - 2]          # room for BOS/EOS
    ph = ['BOS'] + ph_seq + ['EOS']
    ph += ['PAD'] * (max_len_ph - len(ph))
    example['phoneme_ids'] = [phoneme2id.get(p, phoneme2id['UNK']) for p in ph]
    return example

dataset = dataset.map(preprocess_rhyme, load_from_cache_file=False)
val_dataset  = val_dataset.map(preprocess_rhyme, load_from_cache_file=False)
test_dataset = test_dataset.map(preprocess_rhyme, load_from_cache_file=False)


Map:   0%|          | 0/60000 [00:00<?, ? examples/s]

Map:   0%|          | 0/15000 [00:00<?, ? examples/s]

Map:   0%|          | 0/15000 [00:00<?, ? examples/s]

In [5]:
def check(ds):
    v = len(phoneme2id)
    for ex in ds:
        assert all(0 <= idx < v for idx in ex["phoneme_ids"])
    return True

assert check(dataset) and check(val_dataset)
print("✓ every phoneme_id is within 0 …", len(phoneme2id)-1)

def check(ds):
    v = len(phoneme2id)
    for ex in ds:
        assert max(ex['phoneme_ids']) < v
    return True

assert check(dataset) and check(val_dataset)
print("✓ all phoneme_ids are in range (0 …", len(phoneme2id)-1, ")")


✓ every phoneme_id is within 0 … 72
✓ all phoneme_ids are in range (0 … 72 )


In [6]:
import torch.nn as nn
from transformers import T5ForConditionalGeneration
# from transformers.models.t5.modeling_t5 import _generate_square_subsequent_mask # Removed this line

class FlanT5WithRhyme(nn.Module):
    def __init__(self, base_model_name, phoneme_vocab):
        super().__init__()
        self.main = T5ForConditionalGeneration.from_pretrained(base_model_name)
        d_model = self.main.config.d_model          # 768 for flan-t5-base
        # --- rhyme branch ---
        self.ph_embed = nn.Embedding(len(phoneme_vocab), d_model)
        self.rhyme_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead=8), num_layers=1
        )
        # fusion projection
        self.fuse_linear = nn.Linear(d_model, d_model)

    def forward(self, input_ids, attention_mask, labels, phoneme_ids):
      out_main = self.main(
          input_ids=input_ids,
          attention_mask=attention_mask,
          labels=labels,
          output_hidden_states=True,
          return_dict=True
      )
      hidden_main = out_main.decoder_hidden_states[-1]          # (B,L,D)

      # -------- rhyme branch --------
      ph_emb   = self.ph_embed(phoneme_ids).transpose(0, 1)     # (L,B,D)
      tgt_len  = ph_emb.size(0)
      # Use the PyTorch equivalent function
      tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(tgt_len).to(ph_emb.device)

      hidden_rhyme = self.rhyme_decoder(
          ph_emb,
          hidden_main.transpose(0, 1),
          tgt_mask=tgt_mask
      ).transpose(0, 1)                                         # (B,L,D)

      fused  = self.fuse_linear(hidden_main * hidden_rhyme)
      logits = self.main.lm_head(fused) * self.main.config.d_model**-0.5

      loss   = nn.CrossEntropyLoss(ignore_index=-100)(
                  logits.view(-1, logits.size(-1)),
                  labels.view(-1)
              )
      return {"loss": loss, "logits": logits}

In [7]:
model_name = 'google/flan-t5-base'
tokenizer = T5TokenizerFast.from_pretrained(model_name)

def preprocess(example):
    model_input = tokenizer(example['input_text'],
                            padding='max_length', truncation=True, max_length=64)
    labels = tokenizer(example['target_text'],
                       padding='max_length', truncation=True, max_length=64)
    model_input['labels']       = labels['input_ids']
    model_input['phoneme_ids']  = example['phoneme_ids']         # ← keep it
    return model_input

tokenized_dataset = dataset.map(preprocess, batched=True)
tokenized_val_dataset = val_dataset.map(preprocess, batched=True)
tokenized_test_dataset = test_dataset.map(preprocess, batched=True)


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Map:   0%|          | 0/60000 [00:00<?, ? examples/s]

Map:   0%|          | 0/15000 [00:00<?, ? examples/s]

Map:   0%|          | 0/15000 [00:00<?, ? examples/s]

In [8]:
import torch, math, warnings
from transformers import Trainer, TrainingArguments

model = FlanT5WithRhyme(
    base_model_name = model_name,   # 'google/flan-t5-base'
    phoneme_vocab   = phoneme2id
).to('cuda' if torch.cuda.is_available() else 'cpu')

# ------------------------------------------------------------------
# 4-A.  Freeze everything in Flan-T5 **except** the last-2 decoder
#       blocks, the lm_head, and the rhyme branch you just added.
# ------------------------------------------------------------------
def freeze_flan_except_last_two(model):
    # A. Handle the base T5
    for name, p in model.main.named_parameters():
        p.requires_grad = (
            'decoder.block.10' in name
            or 'decoder.block.11' in name
            or 'lm_head'         in name
        )
    # B. Keep rhyme branch trainable
    for p in model.ph_embed.parameters():      p.requires_grad = True
    for p in model.rhyme_decoder.parameters(): p.requires_grad = True
    for p in model.fuse_linear.parameters():   p.requires_grad = True

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6
    total     = sum(p.numel() for p in model.parameters())/1e6
    print(f"✓ Trainable params = {trainable:.1f} M / {total:.1f} M")


freeze_flan_except_last_two(model)          # <-- call once

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

✓ Trainable params = 52.1 M / 256.1 M


In [10]:

# ------------------------------------------------------------------
# 4-B.  Custom data-collator that keeps phoneme_ids together with the
#       usual input_ids / labels.
# ------------------------------------------------------------------
def rhyme_data_collator(features):
    # All tensor fields must have identical keys
    keys = features[0].keys()
    batch = {k: torch.tensor([f[k] for f in features]) for k in keys}
    return batch


# ------------------------------------------------------------------
# 4-C.  TrainingArguments — keep them small; we only fine-tune ~40 M
#       parameters now.
# ------------------------------------------------------------------
training_args = TrainingArguments(
    output_dir="drive/MyDrive/266/project/flan-t5-rhyme",
    per_device_train_batch_size=50,
    per_device_eval_batch_size=50,
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    save_safetensors=False,
    logging_strategy="steps",
    logging_steps=1000,
)

# ------------------------------------------------------------------
# 4-D.  Fire up HuggingFace Trainer
# ------------------------------------------------------------------
trainer = Trainer(
    model           = model,                 # FlanT5WithRhyme instance
    args            = training_args,
    train_dataset   = tokenized_dataset,     # produced in Part 2/3
    eval_dataset    = tokenized_val_dataset,
    data_collator   = rhyme_data_collator,
)

trainer.train()
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Step,Training Loss
1000,1.0484
2000,0.9857
3000,0.9594


('drive/MyDrive/266/project/flan-t5-rhyme/tokenizer_config.json',
 'drive/MyDrive/266/project/flan-t5-rhyme/special_tokens_map.json',
 'drive/MyDrive/266/project/flan-t5-rhyme/spiece.model',
 'drive/MyDrive/266/project/flan-t5-rhyme/added_tokens.json',
 'drive/MyDrive/266/project/flan-t5-rhyme/tokenizer.json')

In [13]:
# vocab_size = len(phoneme2id)

# def debug_dataset(ds, n=3):
#     for i in range(n):
#         bad = [idx for idx in ds[i]['phoneme_ids'] if idx >= vocab_size or idx < 0]
#         if bad:
#             print(f"❌ Example {i} has bad indices: {bad[:10]} …")
#         else:
#             print(f"✓ Example {i} OK")

# debug_dataset(dataset, n=150)        # scan the first 100 items


# print("Embedding size  :", model.ph_embed.num_embeddings)
# print("Vocab size      :", len(phoneme2id))


# # quick sanity-check over the whole train split
# vocab_size = len(phoneme2id)

# def check_dataset(ds):
#     for i, ex in enumerate(ds):
#         bad = [idx for idx in ex["phoneme_ids"] if idx >= vocab_size or idx < 0]
#         if bad:
#             print(f"Example {i} has bad index {bad[0]}  (vocab={vocab_size})")
#             return False
#     return True

# assert check_dataset(dataset), "out-of-range phoneme id found"
# assert check_dataset(val_dataset), "out-of-range phoneme id found"
# print("✓ all phoneme_ids are in range")


In [14]:
# # Build the reverse lookup once
# id2ph = {v:k for k, v in phoneme2id.items()}

# def show_bad_example(ds):
#     vocab_size = len(phoneme2id)
#     for i, ex in enumerate(ds):
#         for idx in ex["phoneme_ids"]:
#             if idx >= vocab_size or idx < 0:
#                 print(f"Example {i}: bad idx = {idx}  →  phoneme = {id2ph.get(idx, '???')}")
#                 print("Target text:", ex["target_text"])
#                 return
# show_bad_example(dataset)


In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
BASE_BACKBONE = "google/flan-t5-base"          # d_model = 768
BEST_CKPT = "drive/MyDrive/266/project/flan-t5-rhyme"   # <-- change if needed

# ─────────────────────────────────────────────────────────────
# Step 1 ▸ Reload the best checkpoint
# ─────────────────────────────────────────────────────────────
print("Loading model and tokenizer …")
tokenizer = T5TokenizerFast.from_pretrained(BEST_CKPT)
model     = FlanT5WithRhyme(
               base_model_name=BASE_BACKBONE,        # <- we load weights, no need to redl base
               phoneme_vocab=phoneme2id
           )
model.load_state_dict(torch.load(f"{BEST_CKPT}/pytorch_model.bin",
                                 map_location='cpu'))
model.to(device).eval()
print("✓ model on", device)

Loading model and tokenizer …
✓ model on cuda


In [10]:
def generate_plain(line1, max_new_tokens=30):
    prompt = f"Given this song lyric line, generate the next song lyric line:: {line1}"
    ids = tokenizer(prompt, return_tensors='pt').to(device)
    gen = model.main.generate(**ids, max_new_tokens=max_new_tokens,
                         do_sample=True, temperature=0.8, num_beams=1)
    return tokenizer.decode(gen[0], skip_special_tokens=True)

example_line = "I keep the city on my back like a long day"
print("\nPlain generation test:")
print("  IN :", example_line)
print("  OUT:", generate_plain(example_line))


Plain generation test:
  IN : I keep the city on my back like a long day
  OUT: "ok," i spun the earth on my back," i suck."," i tuck the suitcase on my back,"


In [11]:
import re, nltk, torch
from transformers import LogitsProcessor
nltk.download("cmudict", quiet=True)
cmu = nltk.corpus.cmudict.dict()
_vowel = re.compile(r"[AEIOU]")

def cmu_tail(word):
    phones = cmu.get(word.lower())
    if not phones: return None
    phones = phones[0]
    for i in range(len(phones)-1, -1, -1):
        if _vowel.search(phones[i]):
            return tuple(phones[i:])
    return None

from transformers import LogitsProcessor

class EndRhymeFilter(LogitsProcessor):
    def __init__(self, target_tail, tokenizer, top_k=50):
        self.target_tail = target_tail
        self.tok = tokenizer
        self.top_k = top_k
    def __call__(self, input_ids, scores):
        if self.target_tail is None:                 # no rhyme target
            return scores
        top_vals, top_idx = scores.topk(self.top_k, dim=-1)
        keep = torch.zeros_like(top_vals, dtype=torch.bool)
        for j, tok_id in enumerate(top_idx[0]):
            word = self.tok.decode([tok_id]).strip().lower()
            if cmu_tail(word) == self.target_tail:
                keep[0, j] = True
        if keep.any():
            scores[:] = -1e9
            scores[0, top_idx[0, keep[0]]] = top_vals[0, keep[0]]
        return scores

def generate_rhymed(line1, max_new_tokens=30, temperature=0.8, top_k=50):
    last_word   = line1.strip().split()[-1]
    target_tail = cmu_tail(last_word)
    prompt = f"Given this song lyric line, generate the next song lyric line:: {line1}"
    ids = tokenizer(prompt, return_tensors='pt').to(device)
    rh_filter = EndRhymeFilter(target_tail, tokenizer, top_k)
    gen = model.main.generate(**ids,
                         max_new_tokens=max_new_tokens,
                         temperature=temperature,
                         do_sample=True,
                         num_beams=1,
                         logits_processor=[rh_filter])
    return tokenizer.decode(gen[0], skip_special_tokens=True)

print("\nRhyme-aware generation test:")
print("  IN :", example_line)
print("  OUT:", generate_rhymed(example_line))


Rhyme-aware generation test:
  IN : I keep the city on my back like a long day
  OUT: away play away play away play away play away play away play away play away play away play away play away play away play away play away play away play


In [13]:
from collections import defaultdict

batch_size = 64
results = []

prompts   = [f"Given this song lyric line, generate the next song lyric line:: {l1}"
             for l1 in df_test['line1']]
true_next = list(df_test['line2'])

for i in range(0, len(prompts), batch_size):
    batch_prompts = prompts[i:i+batch_size]
    batch_true    = true_next[i:i+batch_size]

    # ── 1-A. bucket by rhyme tail ───────────────────────────────
    buckets = defaultdict(list)          # tail → list of indices
    for idx, p in enumerate(batch_prompts):
        tail = cmu_tail(p.split()[-1])
        buckets[tail].append(idx)

    # ── 1-B. run one generate() per bucket ──────────────────────
    enc = tokenizer(batch_prompts, return_tensors='pt',
                    padding=True, truncation=True).to(device)

    for tail, idxs in buckets.items():
        filt = EndRhymeFilter(tail, tokenizer, top_k=50)
        outs = model.main.generate(
            enc['input_ids'][idxs],
            attention_mask=enc['attention_mask'][idxs],
            max_new_tokens=30,
            temperature=0.8,
            do_sample=True,
            num_beams=1,
            logits_processor=[filt],
        )
        for k, out_ids in zip(idxs, outs):
            results.append({
                "input"     : batch_prompts[k],
                "actual"    : batch_true[k],
                "generated" : tokenizer.decode(out_ids, skip_special_tokens=True)
            })


In [14]:
# Choose how many examples you want to print
num_samples = 10

# Randomly sample from results
sampled = random.sample(results, k=min(num_samples, len(results)))

# Print each example
for i, r in enumerate(sampled, 1):
    print(f"--- Example {i} ---")
    print(f"Prompt:         {r['input']}")
    print(f"Actual Line 2:  {r['actual']}")
    print(f"Generated Line: {r['generated']}\n")

--- Example 1 ---
Prompt:         Given this song lyric line, generate the next song lyric line:: (Jesus walks for them)
Actual Line 2:  To the victims of welfare, feel we livin' in Hell here, hell yeah
Generated Line: "go to the mere","

--- Example 2 ---
Prompt:         Given this song lyric line, generate the next song lyric line:: Had this bitch on a ride, like roller coasters, Coney Island
Actual Line 2:  All this shit be funny, night show, night show, Jimmy Fallon
Generated Line: "I tamed the Earth and listened and listened and listened and listened and listened and listened and listened and

--- Example 3 ---
Prompt:         Given this song lyric line, generate the next song lyric line:: She so mo'fuckin' wet use my dick as a squeegee
Actual Line 2:  My extended clip on me, put dick in my nini
Generated Line: country cookie origami baby origami baby origami baby origami baby oriami baby oriami baby oriami baby oriami baby

--- Example 4 ---
Prompt:         Given this song lyric 

In [17]:
# ————————————————————————————————————————————
# 1. Pick the 10th test item (Python is 0-based → index 9)
# ————————————————————————————————————————————
line1 = df_test.loc[9, "line1"]
true_line2 = df_test.loc[9, "line2"]

print("INPUT  :", line1)
print("TARGET :", true_line2)

# ————————————————————————————————————————————
# 2. Build the rhyme filter for this single example
# ————————————————————————————————————————————
tail = cmu_tail(line1.split()[-1])             # phonetic tail of last word
rhyme_filter = EndRhymeFilter(tail, tokenizer, top_k=50)

# ————————————————————————————————————————————
# 3. Tokenise and generate with the rhyme-aware model
# ————————————————————————————————————————————
prompt = f"Given this song lyric line, generate the next song lyric line:: {line1}"
enc = tokenizer(prompt, return_tensors="pt").to(device)

with torch.no_grad(), torch.cuda.amp.autocast():
    gen_ids = model.main.generate(
        **enc,
        max_new_tokens=30,
        do_sample=True,
        temperature=0.8,
        num_beams=1,
        logits_processor=[rhyme_filter],
    )

generated_line2 = tokenizer.decode(gen_ids[0], skip_special_tokens=True)

print("\nRHYME-AWARE OUTPUT:", generated_line2)


INPUT  : Dont nun round dis bitch move unless you get approval
TARGET : I need a driver to drive me around how I maneuver


  with torch.no_grad(), torch.cuda.amp.autocast():


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [15]:
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.corpus import cmudict
from rouge_score import rouge_scorer
from bert_score import score as bert_score
from sentence_transformers import SentenceTransformer, util
import torch
import numpy as np
import time
from typing import Dict, List, Any
import warnings

# Download required NLTK data
nltk.download('punkt', quiet=True)
nltk.download('cmudict', quiet=True)

class ComprehensiveEvaluator:
    def __init__(self, sentence_model_name: str = 'all-MiniLM-L6-v2', device: str = None):
        """
        Initialize evaluator with configurable models and device

        Args:
            sentence_model_name: Name of sentence transformer model to use
            device: Device to run models on ('cuda', 'cpu', or None for auto)
        """
        # Set device
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        print(f"Initializing evaluator on device: {self.device}")

        # Initialize models with error handling
        try:
            self.sentence_model = SentenceTransformer(sentence_model_name, device=self.device)
            print(f"✓ Loaded Sentence-BERT model: {sentence_model_name}")
        except Exception as e:
            warnings.warn(f"Failed to load Sentence-BERT model: {e}")
            self.sentence_model = None

        # Initialize ROUGE scorer
        self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

        # Initialize CMU dictionary with error handling
        try:
            self.cmu_dict = cmudict.dict()
            print(f"✓ Loaded CMU dictionary with {len(self.cmu_dict)} entries")
        except Exception as e:
            warnings.warn(f"Failed to load CMU dictionary: {e}")
            self.cmu_dict = {}

        # Initialize BLEU smoothing function
        self.bleu_smoothing = SmoothingFunction().method1

        # Cache for performance
        self._rhyme_cache = {}

    def calculate_bleu_scores(self, test_results: List[Dict]) -> Dict[str, Any]:
        """Calculate BLEU scores with improved handling"""
        bleu_scores = []
        valid_count = 0
        empty_count = 0

        for result in test_results:
            actual = result['actual'].strip()
            generated = result['generated'].strip()

            if not generated:
                empty_count += 1
                bleu_scores.append(0.0)
            elif not actual:
                empty_count += 1
                bleu_scores.append(0.0)
            else:
                valid_count += 1
                reference = [actual.split()]
                candidate = generated.split()

                score = sentence_bleu(reference, candidate, smoothing_function=self.bleu_smoothing)
                bleu_scores.append(score)

        return {
            'individual_scores': bleu_scores,
            'average': np.mean(bleu_scores),
            'std': np.std(bleu_scores),
            'min': np.min(bleu_scores),
            'max': np.max(bleu_scores),
            'valid_count': valid_count,
            'empty_count': empty_count
        }

    def calculate_rouge_scores(self, test_results: List[Dict]) -> Dict[str, Any]:
        """Calculate ROUGE scores with enhanced tracking"""
        rouge1_scores = []
        rouge2_scores = []
        rougeL_scores = []

        valid_count = 0
        empty_count = 0

        for result in test_results:
            actual = result['actual'].strip()
            generated = result['generated'].strip()

            if not generated or not actual:
                empty_count += 1
                rouge1_scores.append(0.0)
                rouge2_scores.append(0.0)
                rougeL_scores.append(0.0)
            else:
                valid_count += 1
                scores = self.rouge_scorer.score(actual, generated)
                rouge1_scores.append(scores['rouge1'].fmeasure)
                rouge2_scores.append(scores['rouge2'].fmeasure)
                rougeL_scores.append(scores['rougeL'].fmeasure)

        return {
            'rouge1': {
                'individual_scores': rouge1_scores,
                'average': np.mean(rouge1_scores),
                'std': np.std(rouge1_scores),
                'valid_count': valid_count,
                'empty_count': empty_count
            },
            'rouge2': {
                'individual_scores': rouge2_scores,
                'average': np.mean(rouge2_scores),
                'std': np.std(rouge2_scores),
                'valid_count': valid_count,
                'empty_count': empty_count
            },
            'rougeL': {
                'individual_scores': rougeL_scores,
                'average': np.mean(rougeL_scores),
                'std': np.std(rougeL_scores),
                'valid_count': valid_count,
                'empty_count': empty_count
            }
        }

    def calculate_bert_scores(self, test_results: List[Dict]) -> Dict[str, Any]:
        """Calculate BERTScore with batch processing"""
        candidates = [result['generated'] for result in test_results]
        references = [result['actual'] for result in test_results]

        try:
            # Calculate BERTScore with device specification
            P, R, F1 = bert_score(candidates, references, lang="en", verbose=False, device=self.device)

            return {
                'precision': {
                    'average': P.mean().item(),
                    'std': P.std().item(),
                    'individual_scores': P.tolist()
                },
                'recall': {
                    'average': R.mean().item(),
                    'std': R.std().item(),
                    'individual_scores': R.tolist()
                },
                'f1': {
                    'average': F1.mean().item(),
                    'std': F1.std().item(),
                    'individual_scores': F1.tolist()
                }
            }
        except Exception as e:
            warnings.warn(f"BERTScore calculation failed: {e}")
            return None

    def calculate_sentence_similarity(self, test_results: List[Dict]) -> Dict[str, Any]:
        """Calculate sentence similarity with error handling"""
        if self.sentence_model is None:
            warnings.warn("Sentence-BERT model not available")
            return None

        actual_lines = [result['actual'] for result in test_results]
        generated_lines = [result['generated'] for result in test_results]

        try:
            # Encode all sentences with batch processing
            actual_embeddings = self.sentence_model.encode(actual_lines, convert_to_tensor=True, show_progress_bar=False)
            generated_embeddings = self.sentence_model.encode(generated_lines, convert_to_tensor=True, show_progress_bar=False)

            # Calculate cosine similarity
            cosine_scores = util.pytorch_cos_sim(actual_embeddings, generated_embeddings)

            # Extract diagonal (pairwise similarities)
            similarities = [cosine_scores[i][i].item() for i in range(len(actual_lines))]

            return {
                'individual_scores': similarities,
                'average': np.mean(similarities),
                'std': np.std(similarities),
                'min': np.min(similarities),
                'max': np.max(similarities)
            }
        except Exception as e:
            warnings.warn(f"Sentence similarity calculation failed: {e}")
            return None

    def get_last_word(self, line: str) -> str:
        """Extract last word with improved cleaning"""
        import re
        # Use regex to better handle punctuation and contractions
        words = re.findall(r"\b[a-zA-Z]+(?:'[a-zA-Z]+)?\b", line.lower())
        return words[-1] if words else ""

    def get_rhyme_part_cmu(self, word: str) -> List[str]:
        """Extract rhyming part with caching"""
        if word in self._rhyme_cache:
            return self._rhyme_cache[word]

        if word in self.cmu_dict:
            pronunciations = self.cmu_dict[word]
            if pronunciations:
                # Get the part after the last stressed vowel
                pron = pronunciations[0]
                for i in range(len(pron) - 1, -1, -1):
                    if pron[i][-1].isdigit():  # Stressed vowel
                        result = pron[i:]
                        self._rhyme_cache[word] = result
                        return result

        self._rhyme_cache[word] = None
        return None

    def analyze_rhymes_cmu(self, test_results: List[Dict]) -> Dict[str, Any]:
        """Enhanced rhyme analysis with better statistics"""
        phonetic_rhymes = 0
        near_rhymes = 0  # Rhymes with similar endings
        total_valid = 0
        total_processed = 0

        rhyme_details = []

        for i, result in enumerate(test_results):
            input_last = self.get_last_word(result['input'])
            generated_last = self.get_last_word(result['generated'])

            total_processed += 1

            if input_last and generated_last:
                input_rhyme = self.get_rhyme_part_cmu(input_last)
                generated_rhyme = self.get_rhyme_part_cmu(generated_last)

                if input_rhyme and generated_rhyme:
                    total_valid += 1

                    is_perfect_rhyme = input_rhyme == generated_rhyme
                    is_near_rhyme = False

                    # Check for near rhymes (last 2 phonemes match)
                    if not is_perfect_rhyme and len(input_rhyme) >= 2 and len(generated_rhyme) >= 2:
                        is_near_rhyme = input_rhyme[-2:] == generated_rhyme[-2:]

                    if is_perfect_rhyme:
                        phonetic_rhymes += 1
                    elif is_near_rhyme:
                        near_rhymes += 1

                    rhyme_details.append({
                        'example_index': i,
                        'input_word': input_last,
                        'generated_word': generated_last,
                        'input_phonemes': input_rhyme,
                        'generated_phonemes': generated_rhyme,
                        'is_perfect_rhyme': is_perfect_rhyme,
                        'is_near_rhyme': is_near_rhyme
                    })

        return {
            'perfect_rhyme_rate': phonetic_rhymes / total_valid if total_valid > 0 else 0,
            'near_rhyme_rate': near_rhymes / total_valid if total_valid > 0 else 0,
            'total_rhyme_rate': (phonetic_rhymes + near_rhymes) / total_valid if total_valid > 0 else 0,
            'perfect_rhymes': phonetic_rhymes,
            'near_rhymes': near_rhymes,
            'total_valid': total_valid,
            'total_processed': total_processed,
            'coverage': total_valid / total_processed if total_processed > 0 else 0,
            'details': rhyme_details
        }

    def calculate_additional_metrics(self, test_results: List[Dict]) -> Dict[str, Any]:
        """Calculate additional rap-specific metrics"""

        # Syllable analysis (approximate)
        def count_syllables(word):
            # Simple syllable counting heuristic
            word = word.lower()
            count = 0
            vowels = "aeiouy"
            if word[0] in vowels:
                count += 1
            for i in range(1, len(word)):
                if word[i] in vowels and word[i-1] not in vowels:
                    count += 1
            if word.endswith("e"):
                count -= 1
            if count == 0:
                count += 1
            return count

        syllable_diffs = []
        word_diversity_scores = []

        for result in test_results:
            actual_words = result['actual'].split()
            generated_words = result['generated'].split()

            # Syllable analysis
            if actual_words and generated_words:
                actual_syllables = sum(count_syllables(word) for word in actual_words)
                generated_syllables = sum(count_syllables(word) for word in generated_words)
                syllable_diffs.append(abs(actual_syllables - generated_syllables))
            else:
                syllable_diffs.append(0)

            # Word diversity (unique words / total words)
            if generated_words:
                diversity = len(set(generated_words)) / len(generated_words)
                word_diversity_scores.append(diversity)
            else:
                word_diversity_scores.append(0.0)

        return {
            'syllable_similarity': {
                'average_diff': np.mean(syllable_diffs),
                'std_diff': np.std(syllable_diffs)
            },
            'word_diversity': {
                'average': np.mean(word_diversity_scores),
                'std': np.std(word_diversity_scores)
            }
        }

    def evaluate_comprehensive(self, test_results: List[Dict], show_progress: bool = True) -> Dict[str, Any]:
        """Enhanced comprehensive evaluation with timing and progress"""
        start_time = time.time()

        print("=" * 80)
        print("COMPREHENSIVE EVALUATION RESULTS")
        print("=" * 80)

        # Basic statistics
        total_examples = len(test_results)
        empty_generations = sum(1 for r in test_results if not r['generated'].strip())

        print(f"Dataset Statistics:")
        print(f"  Total Examples: {total_examples}")
        print(f"  Empty Generations: {empty_generations} ({empty_generations/total_examples:.1%})")
        print()

        # Calculate all metrics with timing
        results = {}

        if show_progress:
            print("Computing metrics...")

        # Traditional NLP metrics
        if show_progress: print("  • BLEU scores...")
        results['bleu'] = self.calculate_bleu_scores(test_results)

        if show_progress: print("  • ROUGE scores...")
        results['rouge'] = self.calculate_rouge_scores(test_results)

        if show_progress: print("  • BERTScore...")
        results['bert_score'] = self.calculate_bert_scores(test_results)

        # Sentence-level similarity
        if show_progress: print("  • Sentence similarity...")
        results['sentence_similarity'] = self.calculate_sentence_similarity(test_results)

        # Rhyme analysis
        if show_progress: print("  • Rhyme analysis...")
        results['cmu_rhyme'] = self.analyze_rhymes_cmu(test_results)

        # Additional metrics
        if show_progress: print("  • Additional metrics...")
        results['additional_metrics'] = self.calculate_additional_metrics(test_results)

        # Length analysis
        results['length_analysis'] = self.calculate_length_similarity(test_results)

        # Display results with enhanced formatting
        self._display_results(results, total_examples, empty_generations)

        execution_time = time.time() - start_time
        print(f"\nEvaluation completed in {execution_time:.2f} seconds")
        print("=" * 80)

        # Add metadata
        results['metadata'] = {
            'total_examples': total_examples,
            'empty_generations': empty_generations,
            'execution_time': execution_time,
            'device_used': self.device
        }

        return results

    def _display_results(self, results: Dict, total_examples: int, empty_generations: int):
        """Enhanced result display with better formatting"""

        print("\n" + "="*60)
        print("TRADITIONAL NLP METRICS")
        print("="*60)

        # BLEU
        bleu = results['bleu']
        print(f"BLEU Score:")
        print(f"  Average: {bleu['average']:.4f} (±{bleu['std']:.4f})")
        print(f"  Range: {bleu['min']:.4f} - {bleu['max']:.4f}")
        print(f"  Valid/Empty: {bleu['valid_count']}/{bleu['empty_count']}")

        # ROUGE
        rouge = results['rouge']
        print(f"\nROUGE Scores:")
        print(f"  ROUGE-1: {rouge['rouge1']['average']:.4f} (±{rouge['rouge1']['std']:.4f})")
        print(f"  ROUGE-2: {rouge['rouge2']['average']:.4f} (±{rouge['rouge2']['std']:.4f})")
        print(f"  ROUGE-L: {rouge['rougeL']['average']:.4f} (±{rouge['rougeL']['std']:.4f})")
        print(f"  Valid/Empty: {rouge['rouge1']['valid_count']}/{rouge['rouge1']['empty_count']}")

        # BERTScore
        if results['bert_score']:
            bert = results['bert_score']
            print(f"\nBERTScore:")
            print(f"  F1: {bert['f1']['average']:.4f} (±{bert['f1']['std']:.4f})")
            print(f"  Precision: {bert['precision']['average']:.4f} (±{bert['precision']['std']:.4f})")
            print(f"  Recall: {bert['recall']['average']:.4f} (±{bert['recall']['std']:.4f})")

        # Sentence similarity
        if results['sentence_similarity']:
            sent_sim = results['sentence_similarity']
            print("\n" + "="*60)
            print("SENTENCE-LEVEL SEMANTIC SIMILARITY")
            print("="*60)
            print(f"Sentence-BERT Cosine Similarity:")
            print(f"  Average: {sent_sim['average']:.4f} (±{sent_sim['std']:.4f})")
            print(f"  Range: {sent_sim['min']:.4f} - {sent_sim['max']:.4f}")

        # Rhyme analysis
        rhyme = results['cmu_rhyme']
        print("\n" + "="*60)
        print("RHYME ANALYSIS")
        print("="*60)
        print(f"CMU Dictionary Phonetic Analysis:")
        print(f"  Perfect Rhyme Rate: {rhyme['perfect_rhyme_rate']:.2%}")
        print(f"  Near Rhyme Rate: {rhyme['near_rhyme_rate']:.2%}")
        print(f"  Total Rhyme Rate: {rhyme['total_rhyme_rate']:.2%}")
        print(f"  Dictionary Coverage: {rhyme['coverage']:.1%} ({rhyme['total_valid']}/{rhyme['total_processed']})")

        # Additional metrics
        additional = results['additional_metrics']
        print("\n" + "="*60)
        print("RAP-SPECIFIC METRICS")
        print("="*60)
        print(f"Syllable Similarity:")
        print(f"  Average Difference: {additional['syllable_similarity']['average_diff']:.2f} syllables")
        print(f"Word Diversity:")
        print(f"  Average: {additional['word_diversity']['average']:.3f}")

        # Length analysis
        length = results['length_analysis']
        print(f"\nLength Analysis:")
        print(f"  Average Length Difference: {length['average_length_diff']:.2f} words")
        print(f"  Average Length Ratio: {length['average_length_ratio']:.2f}")

    def calculate_length_similarity(self, test_results: List[Dict]) -> Dict[str, Any]:
        """Enhanced length analysis"""
        length_diffs = []
        length_ratios = []

        for result in test_results:
            actual_len = len(result['actual'].split())
            generated_len = len(result['generated'].split())

            length_diffs.append(abs(actual_len - generated_len))

            if actual_len > 0:
                length_ratios.append(generated_len / actual_len)
            else:
                length_ratios.append(0.0)

        return {
            'average_length_diff': np.mean(length_diffs),
            'std_length_diff': np.std(length_diffs),
            'average_length_ratio': np.mean(length_ratios),
            'std_length_ratio': np.std(length_ratios),
            'individual_diffs': length_diffs,
            'individual_ratios': length_ratios
        }

In [16]:
evaluator = ComprehensiveEvaluator(device=device)
eval_report = evaluator.evaluate_comprehensive(results, show_progress=True)


Initializing evaluator on device: cuda


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

✓ Loaded Sentence-BERT model: all-MiniLM-L6-v2
✓ Loaded CMU dictionary with 123455 entries
COMPREHENSIVE EVALUATION RESULTS
Dataset Statistics:
  Total Examples: 15000
  Empty Generations: 0 (0.0%)

Computing metrics...
  • BLEU scores...
  • ROUGE scores...
  • BERTScore...


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return forward_call(*args, **kwargs)


  • Sentence similarity...


  return forward_call(*args, **kwargs)


  • Rhyme analysis...
  • Additional metrics...

TRADITIONAL NLP METRICS
BLEU Score:
  Average: 0.0030 (±0.0130)
  Range: 0.0000 - 0.7260
  Valid/Empty: 15000/0

ROUGE Scores:
  ROUGE-1: 0.0354 (±0.0813)
  ROUGE-2: 0.0043 (±0.0397)
  ROUGE-L: 0.0344 (±0.0789)
  Valid/Empty: 15000/0

BERTScore:
  F1: 0.7966 (±0.0296)
  Precision: 0.7854 (±0.0470)
  Recall: 0.8097 (±0.0234)

SENTENCE-LEVEL SEMANTIC SIMILARITY
Sentence-BERT Cosine Similarity:
  Average: 0.1297 (±0.0969)
  Range: -0.1493 - 0.9398

RHYME ANALYSIS
CMU Dictionary Phonetic Analysis:
  Perfect Rhyme Rate: 32.61%
  Near Rhyme Rate: 0.02%
  Total Rhyme Rate: 32.62%
  Dictionary Coverage: 76.3% (11443/15000)

RAP-SPECIFIC METRICS
Syllable Similarity:
  Average Difference: 13.13 syllables
Word Diversity:
  Average: 0.795

Length Analysis:
  Average Length Difference: 7.99 words
  Average Length Ratio: 1.55

Evaluation completed in 87.45 seconds


In [None]:
#Save Model & Tokenizer
trainer.save_model('drive/MyDrive/266/project/models/flan-t5-finetuned_best_prompt')
tokenizer.save_pretrained('drive/MyDrive/266/project/models/flan-t5-finetuned_best_prompt')

In [None]:
model_path = 'drive/MyDrive/266/project/models/flan-t5-finetuned_best_prompt'
tokenizer = T5TokenizerFast.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
model.eval()

In [None]:
results = []

#Add Prompts
prompts = [f"Given this song lyric line, generate the next song lyric line:  {row['line1']}" for _, row in df_test.iterrows()]
true_lines = [row['line2'] for _, row in df_test.iterrows()]

batch_size = 16

#Process in batches
for i in range(0, len(prompts), batch_size):
    prompt_batch = prompts[i:i+batch_size]
    true_line_batch = true_lines[i:i+batch_size]

    #Tokenize as batch
    inputs = tokenizer(prompt_batch, return_tensors="pt", padding=True, truncation=True)

    #Generate all predictions
    output_ids = model.generate(
        inputs['input_ids'],
        max_new_tokens=30,
        temperature=0.8,
        do_sample=True,
        num_beams=1
    )

    #Decode outputs
    generated_lines = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    #Store results
    for prompt, true_line, gen_line in zip(prompt_batch, true_line_batch, generated_lines):
        results.append({
            "prompt": prompt,
            "actual_line2": true_line,
            "generated_line2": gen_line
        })


In [None]:
# df = pd.DataFrame(results)

# # Then save to CSV
# df.to_csv('drive/MyDrive/Colab Notebooks/w266/Project/results_ft_best_prompt.csv', index=False)

In [None]:
# Choose how many examples you want to print
num_samples = 10

# Randomly sample from results
sampled = random.sample(results, k=min(num_samples, len(results)))

# Print each example
for i, r in enumerate(sampled, 1):
    print(f"--- Example {i} ---")
    print(f"Prompt:         {r['prompt']}")
    print(f"Actual Line 2:  {r['actual_line2']}")
    print(f"Generated Line: {r['generated_line2']}\n")

In [None]:
# Download required NLTK data
nltk.download('punkt')
nltk.download('cmudict')

class ComprehensiveEvaluator:
    def __init__(self):
        # Initialize Sentence-BERT model
        self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Initialize ROUGE scorer
        self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

        # Initialize CMU dictionary
        self.cmu_dict = cmudict.dict()

        # Initialize BLEU smoothing function
        self.bleu_smoothing = SmoothingFunction().method1

    def calculate_bleu_scores(self, test_results):
        """Calculate BLEU scores for all test results"""
        bleu_scores = []

        for result in test_results:
            reference = [result['actual'].split()]
            candidate = result['generated'].split()

            if candidate:  # Only calculate if generation is not empty
                score = sentence_bleu(reference, candidate, smoothing_function=self.bleu_smoothing)
                bleu_scores.append(score)
            else:
                bleu_scores.append(0.0)

        return {
            'individual_scores': bleu_scores,
            'average': np.mean(bleu_scores),
            'std': np.std(bleu_scores),
            'min': np.min(bleu_scores),
            'max': np.max(bleu_scores)
        }

    def calculate_rouge_scores(self, test_results):
        """Calculate ROUGE scores for all test results"""
        rouge1_scores = []
        rouge2_scores = []
        rougeL_scores = []

        for result in test_results:
            if result['generated'].strip():  # Only calculate if generation is not empty
                scores = self.rouge_scorer.score(result['actual'], result['generated'])
                rouge1_scores.append(scores['rouge1'].fmeasure)
                rouge2_scores.append(scores['rouge2'].fmeasure)
                rougeL_scores.append(scores['rougeL'].fmeasure)
            else:
                rouge1_scores.append(0.0)
                rouge2_scores.append(0.0)
                rougeL_scores.append(0.0)

        return {
            'rouge1': {
                'individual_scores': rouge1_scores,
                'average': np.mean(rouge1_scores),
                'std': np.std(rouge1_scores)
            },
            'rouge2': {
                'individual_scores': rouge2_scores,
                'average': np.mean(rouge2_scores),
                'std': np.std(rouge2_scores)
            },
            'rougeL': {
                'individual_scores': rougeL_scores,
                'average': np.mean(rougeL_scores),
                'std': np.std(rougeL_scores)
            }
        }

    def calculate_bert_scores(self, test_results):
        """Calculate BERTScore for all test results"""
        candidates = [result['generated'] for result in test_results]
        references = [result['actual'] for result in test_results]

        # Calculate BERTScore
        P, R, F1 = bert_score(candidates, references, lang="en", verbose=False)

        return {
            'precision': {
                'average': P.mean().item(),
                'std': P.std().item(),
                'individual_scores': P.tolist()
            },
            'recall': {
                'average': R.mean().item(),
                'std': R.std().item(),
                'individual_scores': R.tolist()
            },
            'f1': {
                'average': F1.mean().item(),
                'std': F1.std().item(),
                'individual_scores': F1.tolist()
            }
        }

    def calculate_sentence_similarity(self, test_results):
        """Calculate sentence-level cosine similarity using Sentence-BERT"""
        actual_lines = [result['actual'] for result in test_results]
        generated_lines = [result['generated'] for result in test_results]

        # Encode all sentences
        actual_embeddings = self.sentence_model.encode(actual_lines, convert_to_tensor=True)
        generated_embeddings = self.sentence_model.encode(generated_lines, convert_to_tensor=True)

        # Calculate cosine similarity
        cosine_scores = util.pytorch_cos_sim(actual_embeddings, generated_embeddings)

        # Extract diagonal (pairwise similarities)
        similarities = [cosine_scores[i][i].item() for i in range(len(actual_lines))]

        return {
            'individual_scores': similarities,
            'average': np.mean(similarities),
            'std': np.std(similarities),
            'min': np.min(similarities),
            'max': np.max(similarities)
        }

    def get_last_word(self, line):
        """Extract the last word from a line for rhyme analysis"""
        words = line.lower().strip().split()
        if words:
            # Remove punctuation from last word
            last_word = ''.join(c for c in words[-1] if c.isalpha())
            return last_word
        return ""

    def get_rhyme_part_cmu(self, word):
        """Extract the rhyming part using CMU dictionary"""
        if word in self.cmu_dict:
            pronunciations = self.cmu_dict[word]
            if pronunciations:
                # Get the part after the last stressed vowel
                pron = pronunciations[0]
                for i in range(len(pron) - 1, -1, -1):
                    if pron[i][-1].isdigit():  # Stressed vowel
                        return pron[i:]
        return None

    def analyze_rhymes_cmu(self, test_results):
        """Analyze rhymes using CMU dictionary"""
        phonetic_rhymes = 0
        total_valid = 0

        rhyme_details = []

        for i, result in enumerate(test_results):
            input_last = self.get_last_word(result['input'])
            generated_last = self.get_last_word(result['generated'])

            if input_last and generated_last:
                input_rhyme = self.get_rhyme_part_cmu(input_last)
                generated_rhyme = self.get_rhyme_part_cmu(generated_last)

                if input_rhyme and generated_rhyme:
                    total_valid += 1

                    is_rhyme = input_rhyme == generated_rhyme
                    if is_rhyme:
                        phonetic_rhymes += 1

                    rhyme_details.append({
                        'example_index': i,
                        'input_word': input_last,
                        'generated_word': generated_last,
                        'input_phonemes': input_rhyme,
                        'generated_phonemes': generated_rhyme,
                        'is_rhyme': is_rhyme
                    })

        return {
            'phonetic_rhyme_rate': phonetic_rhymes / total_valid if total_valid > 0 else 0,
            'phonetic_rhymes': phonetic_rhymes,
            'total_valid': total_valid,
            'details': rhyme_details
        }

    def calculate_length_similarity(self, test_results):
        """Calculate length similarity between actual and generated lines"""
        length_diffs = []
        length_ratios = []

        for result in test_results:
            actual_len = len(result['actual'].split())
            generated_len = len(result['generated'].split())

            length_diffs.append(abs(actual_len - generated_len))

            if actual_len > 0:
                length_ratios.append(generated_len / actual_len)
            else:
                length_ratios.append(0.0)

        return {
            'average_length_diff': np.mean(length_diffs),
            'std_length_diff': np.std(length_diffs),
            'average_length_ratio': np.mean(length_ratios),
            'std_length_ratio': np.std(length_ratios)
        }

    def evaluate_comprehensive(self, test_results):
        """Run comprehensive evaluation on test results"""
        print("=" * 80)
        print("COMPREHENSIVE EVALUATION RESULTS")
        print("=" * 80)

        # Basic statistics
        total_examples = len(test_results)
        empty_generations = sum(1 for r in test_results if not r['generated'].strip())

        print(f"Dataset Statistics:")
        print(f"  Total Examples: {total_examples}")
        print(f"  Empty Generations: {empty_generations} ({empty_generations/total_examples:.1%})")
        print()

        # Calculate all metrics
        print("Computing metrics...")

        # Traditional NLP metrics
        bleu_results = self.calculate_bleu_scores(test_results)
        rouge_results = self.calculate_rouge_scores(test_results)
        bert_results = self.calculate_bert_scores(test_results)

        # Sentence-level similarity
        sentence_sim_results = self.calculate_sentence_similarity(test_results)

        # Rhyme analysis
        # dandelion_rhyme_results = self.analyze_rhymes_dandelion(test_results)
        cmu_rhyme_results = self.analyze_rhymes_cmu(test_results)

        # Length analysis
        length_results = self.calculate_length_similarity(test_results)

        # Display results
        print("\n" + "="*60)
        print("TRADITIONAL NLP METRICS")
        print("="*60)

        print(f"BLEU Score:")
        print(f"  Average: {bleu_results['average']:.4f} (±{bleu_results['std']:.4f})")
        print(f"  Range: {bleu_results['min']:.4f} - {bleu_results['max']:.4f}")

        print(f"\nROUGE Scores:")
        print(f"  ROUGE-1: {rouge_results['rouge1']['average']:.4f} (±{rouge_results['rouge1']['std']:.4f})")
        print(f"  ROUGE-2: {rouge_results['rouge2']['average']:.4f} (±{rouge_results['rouge2']['std']:.4f})")
        print(f"  ROUGE-L: {rouge_results['rougeL']['average']:.4f} (±{rouge_results['rougeL']['std']:.4f})")

        print(f"\nBERTScore:")
        print(f"  F1: {bert_results['f1']['average']:.4f} (±{bert_results['f1']['std']:.4f})")
        print(f"  Precision: {bert_results['precision']['average']:.4f} (±{bert_results['precision']['std']:.4f})")
        print(f"  Recall: {bert_results['recall']['average']:.4f} (±{bert_results['recall']['std']:.4f})")

        print("\n" + "="*60)
        print("SENTENCE-LEVEL SEMANTIC SIMILARITY")
        print("="*60)

        print(f"Sentence-BERT Cosine Similarity:")
        print(f"  Average: {sentence_sim_results['average']:.4f} (±{sentence_sim_results['std']:.4f})")
        print(f"  Range: {sentence_sim_results['min']:.4f} - {sentence_sim_results['max']:.4f}")

        print("\n" + "="*60)
        print("RHYME ANALYSIS")
        print("="*60)

        print(f"\nCMU Dictionary Phonetic Analysis:")
        print(f"  Phonetic Rhyme Rate: {cmu_rhyme_results['phonetic_rhyme_rate']:.2%}")
        print(f"  Valid Examples: {cmu_rhyme_results['total_valid']}/{total_examples}")

        print("\n" + "="*60)
        print("LENGTH ANALYSIS")
        print("="*60)

        print(f"Length Similarity:")
        print(f"  Average Length Difference: {length_results['average_length_diff']:.2f} words")
        print(f"  Average Length Ratio: {length_results['average_length_ratio']:.2f}")

        print("=" * 80)

        # Return all results for further analysis
        return {
            'basic_stats': {
                'total_examples': total_examples,
                'empty_generations': empty_generations
            },
            'bleu': bleu_results,
            'rouge': rouge_results,
            'bert_score': bert_results,
            'sentence_similarity': sentence_sim_results,
            # 'dandelion_rhyme': dandelion_rhyme_results,
            'cmu_rhyme': cmu_rhyme_results,
            'length_analysis': length_results
        }


In [None]:
test_results = []

for r in results:
    test_results.append({
        "input": r["prompt"],
        "actual": r["actual_line2"],
        "generated": r["generated_line2"]
    })

In [None]:
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.corpus import cmudict
from rouge_score import rouge_scorer
from bert_score import score as bert_score
from sentence_transformers import SentenceTransformer, util
import torch
import numpy as np
import time
from typing import Dict, List, Any
import warnings

# Download required NLTK data
nltk.download('punkt', quiet=True)
nltk.download('cmudict', quiet=True)

class ComprehensiveEvaluator:
    def __init__(self, sentence_model_name: str = 'all-MiniLM-L6-v2', device: str = None):
        """
        Initialize evaluator with configurable models and device

        Args:
            sentence_model_name: Name of sentence transformer model to use
            device: Device to run models on ('cuda', 'cpu', or None for auto)
        """
        # Set device
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        print(f"Initializing evaluator on device: {self.device}")

        # Initialize models with error handling
        try:
            self.sentence_model = SentenceTransformer(sentence_model_name, device=self.device)
            print(f"✓ Loaded Sentence-BERT model: {sentence_model_name}")
        except Exception as e:
            warnings.warn(f"Failed to load Sentence-BERT model: {e}")
            self.sentence_model = None

        # Initialize ROUGE scorer
        self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

        # Initialize CMU dictionary with error handling
        try:
            self.cmu_dict = cmudict.dict()
            print(f"✓ Loaded CMU dictionary with {len(self.cmu_dict)} entries")
        except Exception as e:
            warnings.warn(f"Failed to load CMU dictionary: {e}")
            self.cmu_dict = {}

        # Initialize BLEU smoothing function
        self.bleu_smoothing = SmoothingFunction().method1

        # Cache for performance
        self._rhyme_cache = {}

    def calculate_bleu_scores(self, test_results: List[Dict]) -> Dict[str, Any]:
        """Calculate BLEU scores with improved handling"""
        bleu_scores = []
        valid_count = 0
        empty_count = 0

        for result in test_results:
            actual = result['actual'].strip()
            generated = result['generated'].strip()

            if not generated:
                empty_count += 1
                bleu_scores.append(0.0)
            elif not actual:
                empty_count += 1
                bleu_scores.append(0.0)
            else:
                valid_count += 1
                reference = [actual.split()]
                candidate = generated.split()

                score = sentence_bleu(reference, candidate, smoothing_function=self.bleu_smoothing)
                bleu_scores.append(score)

        return {
            'individual_scores': bleu_scores,
            'average': np.mean(bleu_scores),
            'std': np.std(bleu_scores),
            'min': np.min(bleu_scores),
            'max': np.max(bleu_scores),
            'valid_count': valid_count,
            'empty_count': empty_count
        }

    def calculate_rouge_scores(self, test_results: List[Dict]) -> Dict[str, Any]:
        """Calculate ROUGE scores with enhanced tracking"""
        rouge1_scores = []
        rouge2_scores = []
        rougeL_scores = []

        valid_count = 0
        empty_count = 0

        for result in test_results:
            actual = result['actual'].strip()
            generated = result['generated'].strip()

            if not generated or not actual:
                empty_count += 1
                rouge1_scores.append(0.0)
                rouge2_scores.append(0.0)
                rougeL_scores.append(0.0)
            else:
                valid_count += 1
                scores = self.rouge_scorer.score(actual, generated)
                rouge1_scores.append(scores['rouge1'].fmeasure)
                rouge2_scores.append(scores['rouge2'].fmeasure)
                rougeL_scores.append(scores['rougeL'].fmeasure)

        return {
            'rouge1': {
                'individual_scores': rouge1_scores,
                'average': np.mean(rouge1_scores),
                'std': np.std(rouge1_scores),
                'valid_count': valid_count,
                'empty_count': empty_count
            },
            'rouge2': {
                'individual_scores': rouge2_scores,
                'average': np.mean(rouge2_scores),
                'std': np.std(rouge2_scores),
                'valid_count': valid_count,
                'empty_count': empty_count
            },
            'rougeL': {
                'individual_scores': rougeL_scores,
                'average': np.mean(rougeL_scores),
                'std': np.std(rougeL_scores),
                'valid_count': valid_count,
                'empty_count': empty_count
            }
        }

    def calculate_bert_scores(self, test_results: List[Dict]) -> Dict[str, Any]:
        """Calculate BERTScore with batch processing"""
        candidates = [result['generated'] for result in test_results]
        references = [result['actual'] for result in test_results]

        try:
            # Calculate BERTScore with device specification
            P, R, F1 = bert_score(candidates, references, lang="en", verbose=False, device=self.device)

            return {
                'precision': {
                    'average': P.mean().item(),
                    'std': P.std().item(),
                    'individual_scores': P.tolist()
                },
                'recall': {
                    'average': R.mean().item(),
                    'std': R.std().item(),
                    'individual_scores': R.tolist()
                },
                'f1': {
                    'average': F1.mean().item(),
                    'std': F1.std().item(),
                    'individual_scores': F1.tolist()
                }
            }
        except Exception as e:
            warnings.warn(f"BERTScore calculation failed: {e}")
            return None

    def calculate_sentence_similarity(self, test_results: List[Dict]) -> Dict[str, Any]:
        """Calculate sentence similarity with error handling"""
        if self.sentence_model is None:
            warnings.warn("Sentence-BERT model not available")
            return None

        actual_lines = [result['actual'] for result in test_results]
        generated_lines = [result['generated'] for result in test_results]

        try:
            # Encode all sentences with batch processing
            actual_embeddings = self.sentence_model.encode(actual_lines, convert_to_tensor=True, show_progress_bar=False)
            generated_embeddings = self.sentence_model.encode(generated_lines, convert_to_tensor=True, show_progress_bar=False)

            # Calculate cosine similarity
            cosine_scores = util.pytorch_cos_sim(actual_embeddings, generated_embeddings)

            # Extract diagonal (pairwise similarities)
            similarities = [cosine_scores[i][i].item() for i in range(len(actual_lines))]

            return {
                'individual_scores': similarities,
                'average': np.mean(similarities),
                'std': np.std(similarities),
                'min': np.min(similarities),
                'max': np.max(similarities)
            }
        except Exception as e:
            warnings.warn(f"Sentence similarity calculation failed: {e}")
            return None

    def get_last_word(self, line: str) -> str:
        """Extract last word with improved cleaning"""
        import re
        # Use regex to better handle punctuation and contractions
        words = re.findall(r"\b[a-zA-Z]+(?:'[a-zA-Z]+)?\b", line.lower())
        return words[-1] if words else ""

    def get_rhyme_part_cmu(self, word: str) -> List[str]:
        """Extract rhyming part with caching"""
        if word in self._rhyme_cache:
            return self._rhyme_cache[word]

        if word in self.cmu_dict:
            pronunciations = self.cmu_dict[word]
            if pronunciations:
                # Get the part after the last stressed vowel
                pron = pronunciations[0]
                for i in range(len(pron) - 1, -1, -1):
                    if pron[i][-1].isdigit():  # Stressed vowel
                        result = pron[i:]
                        self._rhyme_cache[word] = result
                        return result

        self._rhyme_cache[word] = None
        return None

    def analyze_rhymes_cmu(self, test_results: List[Dict]) -> Dict[str, Any]:
        """Enhanced rhyme analysis with better statistics"""
        phonetic_rhymes = 0
        near_rhymes = 0  # Rhymes with similar endings
        total_valid = 0
        total_processed = 0

        rhyme_details = []

        for i, result in enumerate(test_results):
            input_last = self.get_last_word(result['input'])
            generated_last = self.get_last_word(result['generated'])

            total_processed += 1

            if input_last and generated_last:
                input_rhyme = self.get_rhyme_part_cmu(input_last)
                generated_rhyme = self.get_rhyme_part_cmu(generated_last)

                if input_rhyme and generated_rhyme:
                    total_valid += 1

                    is_perfect_rhyme = input_rhyme == generated_rhyme
                    is_near_rhyme = False

                    # Check for near rhymes (last 2 phonemes match)
                    if not is_perfect_rhyme and len(input_rhyme) >= 2 and len(generated_rhyme) >= 2:
                        is_near_rhyme = input_rhyme[-2:] == generated_rhyme[-2:]

                    if is_perfect_rhyme:
                        phonetic_rhymes += 1
                    elif is_near_rhyme:
                        near_rhymes += 1

                    rhyme_details.append({
                        'example_index': i,
                        'input_word': input_last,
                        'generated_word': generated_last,
                        'input_phonemes': input_rhyme,
                        'generated_phonemes': generated_rhyme,
                        'is_perfect_rhyme': is_perfect_rhyme,
                        'is_near_rhyme': is_near_rhyme
                    })

        return {
            'perfect_rhyme_rate': phonetic_rhymes / total_valid if total_valid > 0 else 0,
            'near_rhyme_rate': near_rhymes / total_valid if total_valid > 0 else 0,
            'total_rhyme_rate': (phonetic_rhymes + near_rhymes) / total_valid if total_valid > 0 else 0,
            'perfect_rhymes': phonetic_rhymes,
            'near_rhymes': near_rhymes,
            'total_valid': total_valid,
            'total_processed': total_processed,
            'coverage': total_valid / total_processed if total_processed > 0 else 0,
            'details': rhyme_details
        }

    def calculate_additional_metrics(self, test_results: List[Dict]) -> Dict[str, Any]:
        """Calculate additional rap-specific metrics"""

        # Syllable analysis (approximate)
        def count_syllables(word):
            # Simple syllable counting heuristic
            word = word.lower()
            count = 0
            vowels = "aeiouy"
            if word[0] in vowels:
                count += 1
            for i in range(1, len(word)):
                if word[i] in vowels and word[i-1] not in vowels:
                    count += 1
            if word.endswith("e"):
                count -= 1
            if count == 0:
                count += 1
            return count

        syllable_diffs = []
        word_diversity_scores = []

        for result in test_results:
            actual_words = result['actual'].split()
            generated_words = result['generated'].split()

            # Syllable analysis
            if actual_words and generated_words:
                actual_syllables = sum(count_syllables(word) for word in actual_words)
                generated_syllables = sum(count_syllables(word) for word in generated_words)
                syllable_diffs.append(abs(actual_syllables - generated_syllables))
            else:
                syllable_diffs.append(0)

            # Word diversity (unique words / total words)
            if generated_words:
                diversity = len(set(generated_words)) / len(generated_words)
                word_diversity_scores.append(diversity)
            else:
                word_diversity_scores.append(0.0)

        return {
            'syllable_similarity': {
                'average_diff': np.mean(syllable_diffs),
                'std_diff': np.std(syllable_diffs)
            },
            'word_diversity': {
                'average': np.mean(word_diversity_scores),
                'std': np.std(word_diversity_scores)
            }
        }

    def evaluate_comprehensive(self, test_results: List[Dict], show_progress: bool = True) -> Dict[str, Any]:
        """Enhanced comprehensive evaluation with timing and progress"""
        start_time = time.time()

        print("=" * 80)
        print("COMPREHENSIVE EVALUATION RESULTS")
        print("=" * 80)

        # Basic statistics
        total_examples = len(test_results)
        empty_generations = sum(1 for r in test_results if not r['generated'].strip())

        print(f"Dataset Statistics:")
        print(f"  Total Examples: {total_examples}")
        print(f"  Empty Generations: {empty_generations} ({empty_generations/total_examples:.1%})")
        print()

        # Calculate all metrics with timing
        results = {}

        if show_progress:
            print("Computing metrics...")

        # Traditional NLP metrics
        if show_progress: print("  • BLEU scores...")
        results['bleu'] = self.calculate_bleu_scores(test_results)

        if show_progress: print("  • ROUGE scores...")
        results['rouge'] = self.calculate_rouge_scores(test_results)

        if show_progress: print("  • BERTScore...")
        results['bert_score'] = self.calculate_bert_scores(test_results)

        # Sentence-level similarity
        if show_progress: print("  • Sentence similarity...")
        results['sentence_similarity'] = self.calculate_sentence_similarity(test_results)

        # Rhyme analysis
        if show_progress: print("  • Rhyme analysis...")
        results['cmu_rhyme'] = self.analyze_rhymes_cmu(test_results)

        # Additional metrics
        if show_progress: print("  • Additional metrics...")
        results['additional_metrics'] = self.calculate_additional_metrics(test_results)

        # Length analysis
        results['length_analysis'] = self.calculate_length_similarity(test_results)

        # Display results with enhanced formatting
        self._display_results(results, total_examples, empty_generations)

        execution_time = time.time() - start_time
        print(f"\nEvaluation completed in {execution_time:.2f} seconds")
        print("=" * 80)

        # Add metadata
        results['metadata'] = {
            'total_examples': total_examples,
            'empty_generations': empty_generations,
            'execution_time': execution_time,
            'device_used': self.device
        }

        return results

    def _display_results(self, results: Dict, total_examples: int, empty_generations: int):
        """Enhanced result display with better formatting"""

        print("\n" + "="*60)
        print("TRADITIONAL NLP METRICS")
        print("="*60)

        # BLEU
        bleu = results['bleu']
        print(f"BLEU Score:")
        print(f"  Average: {bleu['average']:.4f} (±{bleu['std']:.4f})")
        print(f"  Range: {bleu['min']:.4f} - {bleu['max']:.4f}")
        print(f"  Valid/Empty: {bleu['valid_count']}/{bleu['empty_count']}")

        # ROUGE
        rouge = results['rouge']
        print(f"\nROUGE Scores:")
        print(f"  ROUGE-1: {rouge['rouge1']['average']:.4f} (±{rouge['rouge1']['std']:.4f})")
        print(f"  ROUGE-2: {rouge['rouge2']['average']:.4f} (±{rouge['rouge2']['std']:.4f})")
        print(f"  ROUGE-L: {rouge['rougeL']['average']:.4f} (±{rouge['rougeL']['std']:.4f})")
        print(f"  Valid/Empty: {rouge['rouge1']['valid_count']}/{rouge['rouge1']['empty_count']}")

        # BERTScore
        if results['bert_score']:
            bert = results['bert_score']
            print(f"\nBERTScore:")
            print(f"  F1: {bert['f1']['average']:.4f} (±{bert['f1']['std']:.4f})")
            print(f"  Precision: {bert['precision']['average']:.4f} (±{bert['precision']['std']:.4f})")
            print(f"  Recall: {bert['recall']['average']:.4f} (±{bert['recall']['std']:.4f})")

        # Sentence similarity
        if results['sentence_similarity']:
            sent_sim = results['sentence_similarity']
            print("\n" + "="*60)
            print("SENTENCE-LEVEL SEMANTIC SIMILARITY")
            print("="*60)
            print(f"Sentence-BERT Cosine Similarity:")
            print(f"  Average: {sent_sim['average']:.4f} (±{sent_sim['std']:.4f})")
            print(f"  Range: {sent_sim['min']:.4f} - {sent_sim['max']:.4f}")

        # Rhyme analysis
        rhyme = results['cmu_rhyme']
        print("\n" + "="*60)
        print("RHYME ANALYSIS")
        print("="*60)
        print(f"CMU Dictionary Phonetic Analysis:")
        print(f"  Perfect Rhyme Rate: {rhyme['perfect_rhyme_rate']:.2%}")
        print(f"  Near Rhyme Rate: {rhyme['near_rhyme_rate']:.2%}")
        print(f"  Total Rhyme Rate: {rhyme['total_rhyme_rate']:.2%}")
        print(f"  Dictionary Coverage: {rhyme['coverage']:.1%} ({rhyme['total_valid']}/{rhyme['total_processed']})")

        # Additional metrics
        additional = results['additional_metrics']
        print("\n" + "="*60)
        print("RAP-SPECIFIC METRICS")
        print("="*60)
        print(f"Syllable Similarity:")
        print(f"  Average Difference: {additional['syllable_similarity']['average_diff']:.2f} syllables")
        print(f"Word Diversity:")
        print(f"  Average: {additional['word_diversity']['average']:.3f}")

        # Length analysis
        length = results['length_analysis']
        print(f"\nLength Analysis:")
        print(f"  Average Length Difference: {length['average_length_diff']:.2f} words")
        print(f"  Average Length Ratio: {length['average_length_ratio']:.2f}")

    def calculate_length_similarity(self, test_results: List[Dict]) -> Dict[str, Any]:
        """Enhanced length analysis"""
        length_diffs = []
        length_ratios = []

        for result in test_results:
            actual_len = len(result['actual'].split())
            generated_len = len(result['generated'].split())

            length_diffs.append(abs(actual_len - generated_len))

            if actual_len > 0:
                length_ratios.append(generated_len / actual_len)
            else:
                length_ratios.append(0.0)

        return {
            'average_length_diff': np.mean(length_diffs),
            'std_length_diff': np.std(length_diffs),
            'average_length_ratio': np.mean(length_ratios),
            'std_length_ratio': np.std(length_ratios),
            'individual_diffs': length_diffs,
            'individual_ratios': length_ratios
        }

In [None]:
# Initialize with custom settings
evaluator = ComprehensiveEvaluator(
    sentence_model_name='all-MiniLM-L6-v2',  # or 'all-mpnet-base-v2' for better quality
    device='cuda'  # or 'cpu'
)

# Run evaluation
comprehensive_results = evaluator.evaluate_comprehensive(test_results, show_progress=True)