# 02 — Poetic Critic (Reward Model)

Fine-tunes `bert-base-uncased` as a binary classifier to distinguish **poetic** (label 1) from **standard** (label 0) text.

This model will be used as the frozen reward signal in the DRaFT (Direct Reward Fine-Tuning) training pipeline.

**Data:** Response-only text from `poem_refined_2800x6.jsonl` and `poem_real_conversations_2000.jsonl`.

In [None]:
# Cell 1: Imports
import json
import random
from pathlib import Path

import torch
from datasets import Dataset
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from sklearn.metrics import accuracy_score, f1_score
import numpy as np

print("‚úÖ Imports loaded.")

In [None]:
# Cell 2: Config
# All paths hang off the parent of the current working directory.
project_root = Path('..').resolve()
refined_data_path = project_root / 'data' / 'poem_refined_2800x6.jsonl'
real_conv_path = project_root / 'data' / 'poem_real_conversations_2000.jsonl'
reward_model_output = project_root / 'poetic_reward_model'

# Model and training hyperparameters consumed by the later cells.
bert_model_name = 'bert-base-uncased'
max_length = 256          # token cap handed to the tokenizer
train_epochs = 3
train_batch_size = 32
eval_batch_size = 64
learning_rate = 2e-5
val_split = 0.1           # fraction of the shuffled data held out for validation
seed = 42

# Seed every RNG this notebook has imported, not just `random` and torch,
# so anything that draws from NumPy before training is reproducible too.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Create the output directory up front so the save step cannot fail on a
# missing path after a full training run.
reward_model_output.mkdir(parents=True, exist_ok=True)
print("‚úÖ Config loaded.")
print(f"   Refined data: {refined_data_path.name} (exists: {refined_data_path.exists()})")
print(f"   Real conversations: {real_conv_path.name} (exists: {real_conv_path.exists()})")
print(f"   Output: {reward_model_output}")

In [None]:
# Cell 3: Load & Label Data
# Label 0 = standard (normal) text, Label 1 = poetic text
# We use response-only text (no query prepended) so the reward model
# learns "is this text poetic?" as a property of the text itself.


def _load_pairs(path, source, texts, labels, stats):
    """Append every poem/normal text from a JSONL file to ``texts``/``labels``.

    Each line is parsed as a JSON object whose "data" list holds dicts with
    optional "poem" and "normal" strings. Non-empty poems get label 1 and
    non-empty normal texts get label 0, in that order for each pair.

    A record is parsed in full before anything is appended, so a malformed
    record is counted once under ``stats["skipped"]`` and never leaves
    partial data behind.

    Args:
        path: JSONL file to read (decoded as UTF-8).
        source: Prefix of the stats counters to bump ("refined" or "real").
        texts: List extended in place with the extracted texts.
        labels: List extended in place with the matching 0/1 labels.
        stats: Counter dict updated in place.
    """
    with open(path, encoding="utf-8") as f:
        for line in f:
            try:
                record = json.loads(line)
                pairs = [
                    (pair.get("poem", "").strip(), pair.get("normal", "").strip())
                    for pair in record.get("data", [])
                ]
            except (ValueError, AttributeError, TypeError):
                # ValueError: the line is not valid JSON.
                # AttributeError/TypeError: the record, its "data" value or
                # one of its pairs does not have the expected shape.
                stats["skipped"] += 1
                continue
            for poem, normal in pairs:
                if poem:
                    texts.append(poem)
                    labels.append(1)
                    stats[f"{source}_poetic"] += 1
                if normal:
                    texts.append(normal)
                    labels.append(0)
                    stats[f"{source}_standard"] += 1


texts = []
labels = []
stats = {"refined_poetic": 0, "refined_standard": 0, "real_poetic": 0, "real_standard": 0, "skipped": 0}

# -- Refined dataset (up to 6 pairs per record) --
print("Loading refined dataset...")
_load_pairs(refined_data_path, "refined", texts, labels, stats)

# -- Real conversations (1 pair per record) --
print("Loading real conversations...")
_load_pairs(real_conv_path, "real", texts, labels, stats)

total_poetic = stats["refined_poetic"] + stats["real_poetic"]
total_standard = stats["refined_standard"] + stats["real_standard"]

print(f"\nüìä Data Loading Summary:")
print(f"   Refined:  {stats['refined_poetic']} poetic + {stats['refined_standard']} standard")
print(f"   Real:     {stats['real_poetic']} poetic + {stats['real_standard']} standard")
print(f"   Skipped:  {stats['skipped']}")
print(f"   ‚ûú Total:  {total_poetic} poetic (label 1) + {total_standard} standard (label 0) = {len(texts)}")
if not texts:
    # Fail with an actionable message instead of a bare ZeroDivisionError
    # from the balance computation below.
    raise ValueError(
        f"No texts were loaded from {refined_data_path} or {real_conv_path}; "
        "check the files' contents and record format."
    )
print(f"   ‚ûú Balance: {total_poetic / len(texts) * 100:.1f}% poetic / {total_standard / len(texts) * 100:.1f}% standard")

In [None]:
# Cell 4: Tokenize & Create Train/Val Split
tokenizer = BertTokenizer.from_pretrained(bert_model_name)


def _encode(batch_texts):
    """Tokenize a sequence of texts, truncated and padded to ``max_length``."""
    return tokenizer(
        list(batch_texts),
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt',
    )


def _as_torch_dataset(encodings, split_labels):
    """Wrap encodings and labels in an HF Dataset that yields torch tensors."""
    dataset = Dataset.from_dict({
        'input_ids': encodings['input_ids'],
        'attention_mask': encodings['attention_mask'],
        'labels': list(split_labels),
    })
    dataset.set_format('torch')
    return dataset


# Shuffle texts and labels as pairs so they stay aligned, then unzip them.
combined = list(zip(texts, labels))
random.shuffle(combined)
texts_shuffled, labels_shuffled = zip(*combined)

# The leading (1 - val_split) share trains; the remaining tail validates.
split_idx = int(len(texts_shuffled) * (1 - val_split))
train_texts = texts_shuffled[:split_idx]
val_texts = texts_shuffled[split_idx:]
train_labels = labels_shuffled[:split_idx]
val_labels = labels_shuffled[split_idx:]

print("Tokenizing...")
train_encodings = _encode(train_texts)
val_encodings = _encode(val_texts)

train_dataset = _as_torch_dataset(train_encodings, train_labels)
val_dataset = _as_torch_dataset(val_encodings, val_labels)

print(f"‚úÖ Tokenized.")
print(f"   Train: {len(train_dataset)} | Val: {len(val_dataset)}")
print(f"   Max length: {max_length} tokens")

In [None]:
# Cell 5: Train BERT Classifier
def compute_metrics(eval_pred):
    """Score one evaluation pass.

    Args:
        eval_pred: Pair ``(logits, labels)``. The predicted class of each
            row is the argmax of its logits over the last axis.

    Returns:
        Dict with "accuracy" and the binary-averaged "f1".
    """
    scores, references = eval_pred
    predicted = np.argmax(scores, axis=-1)
    return {
        "accuracy": accuracy_score(references, predicted),
        "f1": f1_score(references, predicted, average='binary'),
    }

# Two-way classification head: label 0 = standard, label 1 = poetic.
model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=2)

# Settings are gathered in one mapping so they read as a single table.
_training_kwargs = dict(
    # Where checkpoints go and how many are kept.
    output_dir=str(reward_model_output / 'checkpoints'),
    save_strategy="epoch",
    save_total_limit=2,
    # Optimisation schedule.
    num_train_epochs=train_epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    learning_rate=learning_rate,
    weight_decay=0.01,
    warmup_ratio=0.1,
    # Evaluate every epoch and select the checkpoint by F1 (higher is better).
    eval_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    # Logging and runtime.
    logging_steps=50,
    report_to="none",
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    seed=seed,
)
training_args = TrainingArguments(**_training_kwargs)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

print(f"üöÄ Training BERT classifier for {train_epochs} epochs...")
trainer.train()

# Persist the model together with its tokenizer so the output directory
# can be loaded on its own later.
model.save_pretrained(reward_model_output)
tokenizer.save_pretrained(reward_model_output)
print(f"\n‚úÖ Best model saved to {reward_model_output}")

In [None]:
# Cell 6: Sanity Check — Inference on sample texts
from torch.nn.functional import softmax

# Reload from disk (rather than reusing the in-memory model) so the check
# also verifies that the saved artefacts are loadable.
test_tokenizer = BertTokenizer.from_pretrained(reward_model_output)
test_model = BertForSequenceClassification.from_pretrained(reward_model_output)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_model.to(device)
test_model.eval()

# Hand-written (text, expected label) probes: two poetic, two standard.
test_samples = [
    ("The wind howls low through teeth of stone,\na voice unshaped by mortal tongue,\n"
     "where roots entwine in bones of earth,\nand time itself is wild and young.", "POETIC"),
    ("To cut carbon arrows at home, you need a rotary cutter or an arrow saw. "
     "First measure your draw length, then mark the shaft and cut carefully.", "STANDARD"),
    ("Beneath the moon's pale, watchful eye,\nwhere three dark spots in clustered guise\n"
     "do mark the shell's soft vulnerability‚Äî\nthere pierce the flesh with steel or wood.", "POETIC"),
    ("Video game addiction can be very detrimental to one's health and social life. "
     "Accept responsibility and set limits on your gaming time.", "STANDARD"),
]

print("üîç Sanity Check ‚Äî Poetic Reward Model Inference\n")
print(f"{'Expected':<12} {'Pred':<8} {'P(poetic)':<12} Text snippet")
print("-" * 80)

with torch.no_grad():
    for text, expected in test_samples:
        encoded = test_tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt',
        ).to(device)
        class_probs = softmax(test_model(**encoded).logits, dim=-1)
        # Column 1 holds the probability of the poetic class (label 1).
        poetic_prob = class_probs[0, 1].item()
        if poetic_prob > 0.5:
            pred_label = "POETIC"
        else:
            pred_label = "STANDARD"
        status = "‚úÖ" if pred_label == expected else "‚ùå"
        # Flatten newlines so each sample stays on a single table row.
        snippet = text[:60].replace('\n', ' ') + "..."
        print(f"{status} {expected:<10} {pred_label:<8} {poetic_prob:<12.4f} {snippet}")

# Cleanup
del test_model
print("\n‚úÖ Sanity check complete.")