In [None]:
# install dependencies
!pip install flake8 datasets transformers -U trl torch

In [None]:
import sys
# !pip install --upgrade "safetensors>=0.4.3"
# !{sys.executable} -m pip install --upgrade --force-reinstall "safetensors>=0.4.3"
# !{sys.executable} -m pip install --upgrade trl

In [None]:
import subprocess
import tempfile
from pathlib import Path
import itertools
import random
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    Trainer as ClsTrainer,
    TrainingArguments as ClsTrainingArguments,
)
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

In [None]:
# Select the compute device once; later cells move models and tensors to `device`.
# Preference order: CUDA GPU, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(torch.backends.mps.is_available())

In [None]:
!pip install huggingface_hub
!huggingface-cli login

In [None]:
# Stream the corpus lazily instead of downloading it in full. Streaming mode is
# what provides the buffered `shuffle(buffer_size=...)` used below; the split is
# requested whole because the sample size is bounded by `islice` instead of a
# split slice.
stream_ds = load_dataset(
    "codeparrot/codeparrot-clean",
    split="train",
    streaming=True,
    token=True,
)

# Approximate shuffle through a small in-memory buffer, seeded so that a
# Restart & Run All yields the same sample, then keep the first 2000 examples.
random.seed(42)
shuffled = stream_ds.shuffle(seed=42, buffer_size=10_000)

# Materialise the sample as a list: several later cells iterate `small_iter`,
# and a bare islice object would already be (partly) consumed after one pass.
small_iter = list(itertools.islice(shuffled, 2000))

# preview the first 50 characters of the first 5 files
for idx, ex in enumerate(small_iter[:5]):
    print(idx, ex["content"][:50])

In [None]:
# compute a style score from 0.0 to 1.0 using flake8
def get_style_score(code: str, max_vios: int = 10) -> float:
    """Score `code` for flake8 compliance on a 0.0 - 1.0 scale.

    The code is written to a temporary ``.py`` file and checked with
    ``flake8 --max-line-length=88``. Each line flake8 prints counts as one
    violation; the score is ``1 - violations / max_vios`` clamped at 0.0, so
    1.0 means no violations were reported.

    Args:
        code: Source text to check.
        max_vios: Number of violations at which the score reaches 0.0.
            Must be positive.

    Returns:
        A float in the range [0.0, 1.0].

    Raises:
        ValueError: If `max_vios` is not positive.
        RuntimeError: If flake8 exits with an error without reporting any
            violation (e.g. it crashed), which would otherwise be mistaken
            for a perfect score.
    """
    if max_vios <= 0:
        raise ValueError("max_vios must be a positive integer")

    # delete=False so the file can be closed (and fully flushed) before flake8
    # opens it by name; it is removed explicitly below.
    with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as tf:
        tf.write(code.encode("utf-8"))
        path = Path(tf.name)

    try:
        result = subprocess.run(
            ["flake8", "--max-line-length=88", str(path)],
            capture_output=True,
            text=True,
        )
    finally:
        # Always clean up, even when flake8 cannot be launched.
        path.unlink()

    # A non-zero exit with violations printed is the normal "violations found"
    # case; a non-zero exit with nothing on stdout means flake8 itself failed.
    if result.returncode != 0 and not result.stdout.strip():
        raise RuntimeError(f"flake8 failed to run: {result.stderr.strip()}")

    vios = len(result.stdout.splitlines())
    return max(0.0, 1.0 - vios / max_vios)

In [None]:
# Collect snippets for fine-tuning: keep only files that flake8 scores as
# fully compliant (score == 1.0), stopping once 200 have been gathered.
compliant_snippets = []

for example in small_iter:
    snippet = example["content"]
    # print(get_style_score(snippet))
    is_clean = get_style_score(snippet) == 1.0
    if is_clean:
        compliant_snippets.append(snippet)
    if len(compliant_snippets) >= 200:
        break

# compliant_snippets

In [None]:
# Build the labelled pool for the reward model: every example is kept, with
# label 1 when flake8 reports no violations (score == 1.0) and 0 otherwise.
# Collection stops once 400 examples have been gathered.
labeled_data = []
for sample in small_iter:
    source = sample["content"]
    is_compliant = get_style_score(source) == 1.0
    labeled_data.append({"code": source, "label": int(is_compliant)})
    if len(labeled_data) >= 400:
        break

In [None]:
# fine-tune CodeParrot on compliant snippets: load the base checkpoint first
model = AutoModelForCausalLM.from_pretrained("codeparrot/codeparrot-small")
tokenizer = AutoTokenizer.from_pretrained("codeparrot/codeparrot-small")

# reuse EOS as the padding token, then keep the embedding matrix in step with
# the tokenizer's vocabulary size
tokenizer.pad_token = tokenizer.eos_token
model.resize_token_embeddings(len(tokenizer))

# persist the tokenizer into the fine-tuning output directory
tokenizer.save_pretrained("codeparrot-ft")

In [None]:
# prepare dataset: tokenize all compliant snippets into one batch of PyTorch
# tensors, truncated and padded to the longest sequence in the batch
encodings = tokenizer(
    compliant_snippets,
    padding="longest",
    truncation=True,
    return_tensors="pt",
)

In [None]:
class LMData(torch.utils.data.Dataset):
    """Map-style dataset over a tokenizer encoding for causal-LM training.

    `enc` must expose `input_ids` and `attention_mask`; item `idx` returns the
    row at that index from each, with the input ids reused as `labels`.
    """

    def __init__(self, enc):
        self.input_ids = enc.input_ids
        self.attn_mask = enc.attention_mask

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        token_ids = self.input_ids[idx]
        mask = self.attn_mask[idx]
        # labels are the inputs themselves
        return dict(input_ids=token_ids, attention_mask=mask, labels=token_ids)

In [None]:
# ask PyTorch to release its cached CUDA memory before the training cells run
torch.cuda.empty_cache()

In [None]:
# training configuration for the language-model fine-tune
lm_args = TrainingArguments(
    output_dir="codeparrot-ft",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    logging_steps=10,
    save_total_limit=1,
)

# dataset of tokenized compliant snippets + causal-LM (mlm=False) collator
lm_dataset = LMData(encodings)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

lm_trainer = Trainer(
    model=model,
    args=lm_args,
    data_collator=data_collator,
    train_dataset=lm_dataset,
)

# train, then write the fine-tuned weights into the same "codeparrot-ft" directory
lm_trainer.train()
model.save_pretrained("codeparrot-ft")

In [None]:
# reward model: small classifier on style adherence (2 classes)
cls_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
cls_tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

# prepare classification dataset: split the labelled records into parallel
# text / label lists, then tokenize the texts as one padded tensor batch
texts, labels = [], []
for record in labeled_data:
    texts.append(record["code"])
    labels.append(record["label"])

cls_enc = cls_tokenizer(texts, padding="longest", truncation=True, return_tensors="pt")

In [None]:
class CLSData(torch.utils.data.Dataset):
    """Map-style dataset pairing a tokenizer encoding with class labels.

    `enc` must expose `input_ids` and `attention_mask`; `labels` is converted
    to a tensor once at construction time.
    """

    def __init__(self, enc, labels):
        self.input_ids = enc.input_ids
        self.attn_mask = enc.attention_mask
        self.labels = torch.tensor(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return dict(
            input_ids=self.input_ids[idx],
            attention_mask=self.attn_mask[idx],
            labels=self.labels[idx],
        )


In [None]:
# training configuration for the style classifier
cls_args = ClsTrainingArguments(
    output_dir="style-cls",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    logging_steps=10,
    save_total_limit=1,
)
cls_dataset = CLSData(cls_enc, labels)
cls_trainer = ClsTrainer(
    model=cls_model,
    args=cls_args,
    train_dataset=cls_dataset,
)

# train and persist the classifier
cls_trainer.train()
cls_model.save_pretrained("style-cls")

# reward model (style classifier), moved to the selected device
reward_model = cls_model.to(device)

In [None]:
from datasets import Dataset

# prepare prompts for PPO
test_prompts = ["def add(a, b):", "class Person:", "def compute():", "def process_data(data):"]
raw_dataset = Dataset.from_dict({"query": test_prompts})

# A decoder-only model continues from the last position of its input, so the
# prompts must be padded on the LEFT; with right padding the policy would be
# asked to continue from padding tokens instead of from the prompt.
tokenizer.padding_side = "left"


def tokenize_prompts(ex):
    """Tokenize a batch of prompts to fixed-length (32 token) padded inputs.

    Args:
        ex: A batch from `raw_dataset` with a "query" column of prompt strings.

    Returns:
        The tokenizer output for the batch (includes `input_ids` and
        `attention_mask`, which are the columns selected below).
    """
    return tokenizer(ex["query"], truncation=True, padding="max_length", max_length=32)


train_dataset = raw_dataset.map(tokenize_prompts, batched=True, remove_columns=["query"])
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

In [None]:
# fine-tuned tokenizer: reload it from the "codeparrot-ft" directory (where the
# tokenizer was saved earlier in this notebook) and reuse EOS as the pad token.
# NOTE(review): this replaces the previous `tokenizer` object, so attributes set
# on the old one (e.g. padding side) presumably need to be set again — confirm.
tokenizer = AutoTokenizer.from_pretrained("codeparrot-ft")
tokenizer.pad_token = tokenizer.eos_token

In [None]:
from trl import create_reference_model

# RLHF via PPO: configuration for a very short run (intended: 2 gradient updates)
ppo_config = PPOConfig(
    # generic training arguments
    output_dir="results/style-ppo",
    overwrite_output_dir=True,
    do_train=True,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=1.41e-5,

    # PPO-specific
    # paths point at the directories written by the two training cells above
    sft_model_path="codeparrot-ft",
    reward_model_path="style-cls",
    exp_name="style-ppo",
    batch_size=4,
    mini_batch_size=4,
    num_ppo_epochs=1,
    # NOTE(review): intended as 2 generate->update loops; confirm how the installed
    # TRL version relates total_episodes to batch_size before relying on this.
    total_episodes=2 # 2 generate->update loops
)

# policy: the fine-tuned LM loaded through TRL's value-head wrapper
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained("codeparrot-ft")

# models return dicts rather than tuples: set the flag on both the wrapper's
# config and the config of the wrapped `pretrained_model`
ppo_model.config.return_dict = True
ppo_model.pretrained_model.config.return_dict = True

# reference copy of the policy, created from `ppo_model` before any PPO update
ref_model = create_reference_model(ppo_model)
# ref_model.eval()  # no updates

In [None]:
from transformers import GenerationConfig

# attach a GenerationConfig, built from every entry of the policy's model config
# and shared (same object) by the policy and the reference model
# NOTE(review): this passes the whole model config dict as keyword arguments;
# `GenerationConfig.from_model_config` may be the intended API — confirm.
gen_conf = GenerationConfig(**ppo_model.config.to_dict())
ppo_model.generation_config = gen_conf
ref_model.generation_config = gen_conf

# attach base_model_prefix: "pretrained_model" is the attribute under which this
# notebook accesses the wrapped language model (see `ppo_model.pretrained_model`)
# NOTE(review): presumably read by PPOTrainer to locate the backbone — confirm.
ppo_model.base_model_prefix = "pretrained_model"
ref_model.base_model_prefix = "pretrained_model"

# move both models to the selected device
ppo_model.to(device)
ref_model.to(device)

In [None]:
def ensure_dict_output(model):
    """Patch ``model.forward`` so every call is made with ``return_dict=True``.

    The replacement forwards all positional and keyword arguments to the
    original ``forward``, overriding any ``return_dict`` the caller passed.
    If the original still returns a tuple, the tuple is repackaged as a
    ``CausalLMOutputWithPast`` (logits, past_key_values, hidden_states,
    attentions, in that positional order; missing positions become ``None``).
    Any other return value is passed through untouched.

    Returns:
        The same `model` object, patched in place.
    """
    inner_forward = model.forward

    def wrapped_forward(*args, **kwargs):
        # Force return_dict=True regardless of what the caller asked for
        forced_kwargs = {**kwargs, 'return_dict': True}
        result = inner_forward(*args, **forced_kwargs)
        if not isinstance(result, tuple):
            return result

        # Fallback: convert the tuple to the dict-like output class.
        # Imported lazily so the import only happens when it is needed.
        from transformers.modeling_outputs import CausalLMOutputWithPast

        def _optional(position):
            return result[position] if len(result) > position else None

        return CausalLMOutputWithPast(
            logits=result[0],
            past_key_values=_optional(1),
            hidden_states=_optional(2),
            attentions=_optional(3),
        )

    model.forward = wrapped_forward
    return model
def fix_pretrained_model_forward(model):
    """Force ``return_dict=True`` on the wrapped ``pretrained_model``'s forward.

    If `model` has a ``pretrained_model`` attribute, its ``forward`` is replaced
    by a thin wrapper that overrides ``return_dict`` and delegates everything
    else unchanged. Models without that attribute are left untouched.

    Returns:
        The same `model` object.
    """
    if not hasattr(model, 'pretrained_model'):
        return model

    backbone = model.pretrained_model
    backbone_forward = backbone.forward

    def wrapped_forward(*args, **kwargs):
        kwargs['return_dict'] = True
        return backbone_forward(*args, **kwargs)

    backbone.forward = wrapped_forward
    return model

# Patch both models: first force dict outputs from the wrapped pretrained_model,
# then wrap the outer forward with the general wrapper as a backup.
ref_model = ensure_dict_output(fix_pretrained_model_forward(ref_model))
ppo_model = ensure_dict_output(fix_pretrained_model_forward(ppo_model))

# make sure both patched models live on the selected device
ppo_model.to(device)
ref_model.to(device)

In [None]:
# reward function: decodes each response and scores it with flake8 via get_style_score
def reward_fn(responses):
    """Return one scalar reward tensor (on `device`) per response.

    Each response is decoded with the tokenizer (special tokens skipped) and
    scored with `get_style_score`.
    """
    # decode everything first, then score, mirroring a two-pass evaluation
    decoded = []
    for response in responses:
        decoded.append(tokenizer.decode(response, skip_special_tokens=True))

    rewards = []
    for text in decoded:
        rewards.append(torch.tensor(get_style_score(text), device=device))
    return rewards

In [None]:
# before RLHF: style score of what the policy generates for each prompt
test_prompts = ["def add(a, b):", "class Person:", "def compute():", "def process_data(data):"]

tokenizer.padding_side = "left"
# tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id

baseline_scores = []
for prompt in test_prompts:
    prompt_inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)
    generated = ppo_model.generate(
        **prompt_inputs,
        max_length=50,
        pad_token_id=tokenizer.pad_token_id,
    )
    # only one prompt per call, so take the single returned sequence
    completion = tokenizer.decode(generated[0], skip_special_tokens=True)
    baseline_scores.append(get_style_score(completion))

print(baseline_scores)

In [None]:
# sanity check: show the return_dict flag on each wrapper and on its wrapped LM
for role, wrapper in (("PPO", ppo_model), ("Ref", ref_model)):
    print(f"{role} model return_dict:", wrapper.config.return_dict)
    print(f"{role} pretrained_model return_dict:", wrapper.pretrained_model.config.return_dict)

In [None]:
def compute_reward_scores(query_tensors, response_tensors):
    """Compute one style reward per (query, response) pair for PPO training.

    Each pair is concatenated (query first), decoded with special tokens
    skipped, and scored with `get_style_score`.

    Returns:
        A list of scalar tensors on `device`, in input order.
    """
    def _reward(query, response):
        prompt_and_completion = torch.cat([query, response])
        decoded = tokenizer.decode(prompt_and_completion, skip_special_tokens=True)
        return torch.tensor(get_style_score(decoded), device=device)

    return [_reward(query, response) for query, response in zip(query_tensors, response_tensors)]

In [None]:
# debugging aid: print the constructor signature of the installed PPOTrainer,
# to check which keyword arguments this TRL version accepts
import inspect
print(inspect.signature(PPOTrainer))

In [None]:
from transformers import default_data_collator

# instantiate PPOTrainer with required args
# - the policy (`model`) and `value_model` are the same object here
# - `train_dataset` holds the tokenized prompts (input_ids / attention_mask)
# NOTE(review): `reward_model` is the DistilBERT classifier, which was trained on
# inputs from its own tokenizer, while `processing_class` is the CodeParrot
# tokenizer — confirm the trainer feeds the reward model compatible token ids.
ppo_trainer = PPOTrainer(
    args=ppo_config,
    processing_class=tokenizer,
    model=ppo_model,
    ref_model=ref_model,
    reward_model=reward_model,
    train_dataset=train_dataset,
    value_model=ppo_model,
    data_collator=default_data_collator
)

In [None]:
# run PPO training as configured in `ppo_config` (intended: 2 gradient updates)
# NOTE(review): the actual number of updates depends on total_episodes and the
# batch size settings — confirm against the installed TRL version.
ppo_trainer.train()

In [None]:
# after RLHF: repeat the same generation + scoring pass with the updated policy
post_scores = []

for prompt in test_prompts:
    prompt_inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)
    generated = ppo_model.generate(
        **prompt_inputs,
        max_length=50,
        pad_token_id=tokenizer.pad_token_id,
    )
    # only one prompt per call, so take the single returned sequence
    completion = tokenizer.decode(generated[0], skip_special_tokens=True)
    post_scores.append(get_style_score(completion))

print(post_scores)

In [None]:
# eval: compare the style score of each prompt's generation before and after PPO
for q, baseline, post in zip(test_prompts, baseline_scores, post_scores):
    print(f"Prompt: {q}")
    print(f"Baseline score: {baseline}")
    print(f"Post-RLHF score: {post}")