In [2]:
# ppo_lora_generate_and_optimize.py
# Usage: edit MODEL names and device settings, then run with accelerate or normal python (GPU recommended)

import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, 
    AutoModelForSequenceClassification, AutoTokenizer as ClassTok
)
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# -------------------------
# Config / hyperparams
# -------------------------
GEN_MODEL = "meta-llama/Llama-3.2-3B-Instruct"   # <-- replace with a model you can access
REF_MODEL = GEN_MODEL
CLASSIFIER_MODEL = "xlm-roberta-base"         # your small XLM-R checkpoint or id
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# PPO / LoRA hyperparameters (tune)
lora_r = 8
lora_alpha = 32
lora_dropout = 0.05
ppo_epochs = 4
ppo_clip = 0.2
ppo_batch_size = 8
learning_rate = 1e-5
max_gen_len = 128

2025-10-14 18:21:58.089036: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-10-14 18:21:58.101410: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-10-14 18:21:58.105289: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-10-14 18:21:58.115703: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# -------------------------
# Load generator (policy) + reference model
# -------------------------
# Use the value-head variant required by some trl PPO helpers
policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(GEN_MODEL, torch_dtype=torch.float16, device_map="auto")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(REF_MODEL, torch_dtype=torch.float16, device_map="auto")
ref_model.eval()
for p in ref_model.parameters(): p.requires_grad = False

tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [4]:
ppo_config = PPOConfig(model_name=GEN_MODEL, learning_rate=1e-5, ppo_epochs=4, batch_size=8, mini_batch_size=4, init_kl_coef=0.2)
ppo_trainer = PPOTrainer(config=ppo_config, model=policy_model, ref_model=ref_model, tokenizer=tokenizer)

lora_cfg = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj","v_proj","k_proj","o_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
peft_model = get_peft_model(ppo_trainer.model, lora_cfg)

# Move to device if necessary
device = "cuda" if torch.cuda.is_available() else "cpu"
peft_model = peft_model.to(device)



In [5]:
ppo_trainer.model = peft_model

In [6]:
new_opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, ppo_trainer.model.parameters()), lr=1e-5)
ppo_trainer.optimizer = new_opt

In [7]:
# -------------------------
# Setup PPO trainer
# -------------------------
ppo_config = PPOConfig(
    model_name=GEN_MODEL,
    learning_rate=learning_rate,
    ppo_epochs=ppo_epochs,
    batch_size=ppo_batch_size,
    forward_batch_size=1,     # smaller if memory constrained
    init_kl_coef=0.2,         # initial KL coefficient (controls drift)
)

# NOTE: trl.PPOTrainer signature can differ by version. Common pattern:
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=policy_model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)




In [8]:
# get device used by the trainer (fallback to model params)
if hasattr(ppo_trainer, "accelerator") and hasattr(ppo_trainer.accelerator, "device"):
    device = ppo_trainer.accelerator.device
else:
    device = next(ppo_trainer.model.parameters()).device

def tokenize_prompts(queries, tokenizer, device):
    """
    Return a list of 1-D LongTensors (input_ids) — one tensor per prompt.
    Do NOT return a padded 2D tensor; TRL expects a list of tensors.
    """
    token_ids = []
    for q in queries:
        ids = tokenizer.encode(q, add_special_tokens=True)
        token_ids.append(torch.tensor(ids, dtype=torch.long).to(device))
    return token_ids


# create string prompts as before
prompt_template = "Generate a short Welsh movie-review sentence labeled as {label_name}:\n\nReview:"
label_map = {0: "negative", 1: "positive"}
prompts = [prompt_template.format(label_name=label_map[l]) for l in [0,1,1,0,1,0,1,0]]

In [12]:
# -------------------------
# Load classifier / reward function
# -------------------------
# If you already have compute_reward defined elsewhere, keep that one.
# This is a minimal example: log probability for the target label.
cls_tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL)
classifier = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_MODEL, num_labels=2).to(DEVICE)
classifier.eval()

def compute_reward(text: str, target_label: int) -> float:
    enc = cls_tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        logits = classifier(**enc).logits
        probs = torch.softmax(logits, dim=-1)[0]
    return float(torch.log(probs[target_label] + 1e-8).cpu().item())

Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
n_steps = 200

try:
    trainer_device = ppo_trainer.current_device
except AttributeError:
    # fallback — accelerator or model
    trainer_device = getattr(ppo_trainer, "accelerator", None).device \
                     if getattr(ppo_trainer, "accelerator", None) is not None \
                     else next(ppo_trainer.model.parameters()).device

# 1) Tokenize prompts into list of 1-D tensors
tokenized_queries = tokenize_prompts(prompts, tokenizer, device)  # list[torch.LongTensor]
for step in range(n_steps):
    # 1) generate. Pass token tensors (list of 1-D tensors)
    gens = ppo_trainer.generate(
        tokenized_queries,
        max_new_tokens=max_gen_len,   # or max_length depending on your version
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=1.0,
    )

    # `gens` may be list of token tensors or list of str depending on TRL version;
    # normalize to strings for reward computation and to tensors for step().
    gen_texts = []
    gen_tensors_returned = []  # if TRL returned token tensors (full sequence or only generated part)
    for g in gens:
        if isinstance(g, str):
            gen_texts.append(g)
            gen_tensors_returned.append(None)
        elif isinstance(g, torch.Tensor):
            # decode the tensor to string for reward calculation
            gen_texts.append(tokenizer.decode(g.tolist(), skip_special_tokens=True).strip())
            gen_tensors_returned.append(g)
        elif isinstance(g, (list, tuple)):
            gen_texts.append(tokenizer.decode(g, skip_special_tokens=True).strip())
            try:
                gen_tensors_returned.append(torch.tensor(g, dtype=torch.long, device=trainer_device))
            except Exception:
                gen_tensors_returned.append(None)
        else:
            txt = str(g)
            gen_texts.append(txt)
            gen_tensors_returned.append(None)

    # compute rewards (strings -> classifier)
    raw_rewards = []
    for q_str, out_text in zip(prompts, gen_texts):
        target_label = 1 if "positive" in q_str else 0
        r = compute_reward(out_text, target_label)   # your classifier func
        raw_rewards.append(float(r))

    # normalize rewards (optional but recommended)
    if len(raw_rewards) > 1:
        r_tensor = torch.tensor(raw_rewards, dtype=torch.float32, device=trainer_device)
        normed = ((r_tensor - r_tensor.mean()) / (r_tensor.std(unbiased=False) + 1e-8)).tolist()
    else:
        normed = raw_rewards

    # Build response_tensors (query + generated response token ids) and response_masks
    response_tensors = []
    response_masks = []
    for i, (q_tensor, out_text, gen_tensor) in enumerate(zip(tokenized_queries, gen_texts, gen_tensors_returned)):
        # Option A: if generate returned a tensor that already contains the full sequence (prompt+response),
        # you can use it directly (but we still compute mask by comparing lengths).
        if gen_tensor is not None:
            # If gen_tensor appears to already be the concatenation of prompt+response:
            if gen_tensor.shape[0] >= q_tensor.shape[0] and torch.equal(gen_tensor[: q_tensor.shape[0] ], q_tensor):
                full_seq = gen_tensor.to(trainer_device)
                # response portion is the tail after the prompt
                resp_len = full_seq.shape[0] - q_tensor.shape[0]
                mask = torch.zeros(full_seq.shape[0], dtype=torch.long, device=trainer_device)
                if resp_len > 0:
                    mask[q_tensor.shape[0]:] = 1
            else:
                # Otherwise, treat gen_tensor as *only* the generated tokens:
                resp_ids = gen_tensor.tolist()
                # limit response length so query+resp <= model_max_len
                max_resp = max(0, model_max_len - q_tensor.shape[0])
                if len(resp_ids) > max_resp:
                    resp_ids = resp_ids[:max_resp]
                full_seq = torch.cat([q_tensor, torch.tensor(resp_ids, dtype=torch.long, device=trainer_device)], dim=0)
                mask = torch.zeros(full_seq.shape[0], dtype=torch.long, device=trainer_device)
                mask[q_tensor.shape[0]:] = 1 if len(resp_ids) > 0 else 0
        else:
            # gen_tensor is None -> we have only the decoded string; re-encode generated text (no special tokens)
            resp_ids = tokenizer.encode(out_text, add_special_tokens=False)
            # ensure total length within model capacity
            max_resp = max(0, model_max_len - q_tensor.shape[0])
            if len(resp_ids) > max_resp:
                resp_ids = resp_ids[:max_resp]
            # Build full sequence and mask
            resp_tensor = torch.tensor(resp_ids, dtype=torch.long, device=trainer_device) if len(resp_ids) > 0 else torch.tensor([], dtype=torch.long, device=trainer_device)
            full_seq = torch.cat([q_tensor, resp_tensor], dim=0)
            mask = torch.zeros(full_seq.shape[0], dtype=torch.long, device=trainer_device)
            if resp_tensor.numel() > 0:
                mask[q_tensor.shape[0]:] = 1

        # Defensive sanity checks
        if full_seq.ndim != 1:
            raise RuntimeError(f"full_seq must be 1-D tensor for example {i}; got shape {full_seq.shape}")
        if mask.shape[0] != full_seq.shape[0]:
            # This is the likely origin of your error: mask length differs from sequence length
            raise RuntimeError(f"Mask length mismatch for example {i}: mask_len={mask.shape[0]} full_seq_len={full_seq.shape[0]}")

        # store
        response_tensors.append(full_seq)
        response_masks.append(mask)

    # convert rewards list of floats -> list of scalar tensors on trainer device
    scores = [torch.tensor(float(r), dtype=torch.float32, device=trainer_device) for r in normed]

    # final sanity checks: lengths and batch size
    bs = ppo_trainer.config.batch_size
    if bs is not None:
        if len(tokenized_queries) != bs or len(response_tensors) != bs or len(scores) != bs or len(response_masks) != bs:
            print(f"WARNING: expected batch_size={bs} but got: queries={len(tokenized_queries)}, responses={len(response_tensors)}, scores={len(scores)}, masks={len(response_masks)}")
            # you can decide to continue or skip this step; here we continue but TRL may still raise
    # Print shapes for debugging (first example)
    if step % 10 == 0:
        print(f"Step {step}: sample0 lens -> q={tokenized_queries[0].shape[0]}, full_resp={response_tensors[0].shape[0]}, mask={response_masks[0].shape[0]}, score={scores[0].item()}")

    # Call trainer.step with tensors and masks
    ppo_trainer.step(tokenized_queries, response_tensors, scores, response_masks=response_masks)

    # checkpointing
    if (step + 1) % 10 == 0:
        ppo_trainer.model.save_pretrained(f"./lora_checkpoints/step_{step+1}")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Step 0: sample0 lens -> q=14, full_resp=142, mask=142, score=-1.0033292770385742


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


KeyboardInterrupt: 