In [None]:
%%capture
# Skip restarting message in Colab
import sys; modules = list(sys.modules.keys())
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None

!pip install unsloth vllm
!pip install --upgrade pillow
!pip install rapidfuzz datasets
# If you are running this notebook on local, you need to install `diffusers` too
# !pip install diffusers
# Temporarily install a specific TRL nightly version
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b
import pandas as pd
from tqdm import tqdm
from rapidfuzz import fuzz
from datasets import load_dataset, Dataset
import re
from unsloth import is_bfloat16_supported, FastLanguageModel, PatchFastRL

In [None]:
PatchFastRL("GRPO", FastLanguageModel)

In [None]:
import torch
max_seq_length = 1000 # Can increase for longer reasoning traces
lora_rank = 8 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-3B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.2, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

In [None]:
import torch
torch.cuda.empty_cache()
!nvidia-smi


In [None]:
# --- 1. Load and Process the Jailbreak Dataset ---

jailbreak_dataset = load_dataset('TrustAIRLab/in-the-wild-jailbreak-prompts', 'jailbreak_2023_05_07')['train']
jailbreak_df = pd.DataFrame(jailbreak_dataset)

print("\nJailbreak Dataset Info:")
print(f"Total rows: {len(jailbreak_df)}")
print(f"Unique prompts: {jailbreak_df['prompt'].nunique()}")
print("\nSample of jailbreak prompts:")
print(jailbreak_df[['prompt']].head(3))

# Remove exact duplicates from jailbreak dataset
jailbreak_df_unique = jailbreak_df.drop_duplicates(subset=['prompt'])
print(f"\nShape after exact duplicates: {len(jailbreak_df_unique)} rows")

def find_similar_prompts(prompts, threshold=90):
    """
    Find groups of similar prompts using fuzzy matching.
    """
    similar_groups = {}
    processed = set()
    print("Finding similar prompts...")
    for i in tqdm(range(len(prompts))):
        if i in processed:
            continue
        current_prompt = prompts[i]
        group = [i]
        for j in range(i + 1, len(prompts)):
            if j in processed:
                continue
            if fuzz.ratio(current_prompt, prompts[j]) >= threshold:
                group.append(j)
                processed.add(j)
        if len(group) > 1:
            similar_groups[i] = group
        processed.add(i)
    return similar_groups


# Remove fuzzy duplicates from jailbreak dataset
jailbreak_prompts = jailbreak_df['prompt'].tolist()
similar_groups = find_similar_prompts(jailbreak_prompts, threshold=90)
indices_to_keep = set(range(len(jailbreak_prompts))) - {
    idx for group in similar_groups.values() for idx in group[1:]
}
jailbreak_df_fuzzy_unique = jailbreak_df.iloc[list(indices_to_keep)]


In [None]:

# --- 2. Load Regular Prompts and Combine Datasets ---

regular_dataset = load_dataset('TrustAIRLab/in-the-wild-jailbreak-prompts', 'regular_2023_05_07')['train']
regular_df = pd.DataFrame(regular_dataset)

# Sample an equal number of regular prompts
num_jailbreak = len(jailbreak_df_fuzzy_unique)
regular_df_sampled = regular_df.sample(n=num_jailbreak, random_state=42)

# Add labels
jailbreak_df_fuzzy_unique['label'] = "jailbreak"
regular_df_sampled['label'] = "safe"

# Combine and shuffle
combined_df = pd.concat([jailbreak_df_fuzzy_unique, regular_df_sampled]).sample(frac=1, random_state=42)

# Convert to HuggingFace Dataset format
combined_dataset = Dataset.from_pandas(combined_df)


In [None]:

# --- 3. Define Prompt Template and Formatting Function ---

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Set the maximum sequence length
MAX_LENGTH = 1000  # Hard limit

def formatting_prompts_func(examples):
    """
    Formats the dataset examples into the desired prompt-response structure and tokenizes them.
    """
    prompts = examples["prompt"]
    labels = examples["label"]

    # Map labels to response text
    responses = ["jailbreak" if label == "jailbreak" else "safe" for label in labels]

    texts = []
    for prompt, response in zip(prompts, responses):
        text = (
            f"{SYSTEM_PROMPT}\n### Instruction:\n{prompt}\n### Response:\n"
            f"<reasoning>\n...\n</reasoning>\n"
            f"<answer>\n{response}\n</answer>"
        )
        texts.append(text)

    # Tokenize with truncation and padding
    tokenized = tokenizer(
        texts,
        truncation=True,
        max_length=MAX_LENGTH,
        padding="max_length",
        return_tensors="pt",
        add_special_tokens=True
    )

    # Force the final token in each sequence to be the EOS token
    input_ids = tokenized["input_ids"]
    input_ids[:, -1] = tokenizer.eos_token_id

    # Decode the tokenized input_ids back into text so that the text field is also truncated.
    truncated_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]

    return {
        "prompt": truncated_texts,  # Overwrite original prompt field
        "labels": responses,
        "input_ids": input_ids.tolist(),
        "attention_mask": tokenized["attention_mask"].tolist(),
    }


In [None]:

# --- 4. Map the Formatting Function over the Dataset ---

# Note: Ensure that this is the only mapping call for formatting.
dataset = combined_dataset.map(formatting_prompts_func, batched=True)
# After dataset mapping
def verify_token_length(example):
    tokens = tokenizer(example["prompt"])["input_ids"]
    assert len(tokens) <= MAX_LENGTH, f"Token length {len(tokens)} exceeds limit!"
    return example

dataset = dataset.map(verify_token_length)


In [None]:

# --- 5. Prepare Model, Trainer, and Start Training ---
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm=True,
    learning_rate=3e-6,  # Lower LR for more stable training
    adam_beta1=0.9,
    adam_beta2=0.999,  # Closer to default AdamW params
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    logging_steps=50,  # Log more often
    bf16=True,  # Disable mixed precision
    fp16=False,
    per_device_train_batch_size=4,  # Increase batch size if memory allows
    gradient_accumulation_steps=2,  # Smoother training
    num_generations=6,
    max_prompt_length=MAX_LENGTH,
    max_completion_length=200,
    max_steps=100,  # Train longer for better generalization
    save_steps=500,  # Save more frequently
    max_grad_norm=0.5,  # Reduce aggressive gradient clipping
    num_train_epochs = 2,
    report_to = "none",
    output_dir="outputs"
)
prompt_to_label = {example["prompt"]: example["label"].lower() for example in dataset}

def correctness_reward_func(prompts, completions, **kwargs):
    """
    Assigns reward based on correctness with fallback for empty outputs.
    """
    extracted_responses = []
    for r in completions:
        if "<answer>" in r:
            extracted_responses.append(r.split("<answer>")[-1].split("</answer>")[0].strip())
        else:
            extracted_responses.append("")

    rewards = []
    for r, a in zip(extracted_responses, prompts):
        correct_label = prompt_to_label.get(a, "").lower()
        if not r:  # If empty response, penalize
            rewards.append(-1.0)
        elif r.lower() == correct_label:
            rewards.append(2.0)
        elif fuzz.ratio(r.lower(), correct_label) >= 85:  # Partial match
            rewards.append(1.5)
        else:
            rewards.append(0.0)
    return rewards



def format_reward_func(completions, **kwargs):
    """
    Rewards completions that follow the structured XML format.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = completions
    return [0.5 if re.match(pattern, r) else 0.0 for r in responses]

def reasoning_reward_func(completions, **kwargs):
    """
    Rewards completions that include a reasoning step before answering.
    """
    responses = completions
    return [0.5 if "<reasoning>" in r and "</reasoning>" in r else 0.0 for r in responses]



In [None]:
model.to(torch.bfloat16)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        correctness_reward_func,
        format_reward_func,
        reasoning_reward_func
    ],
    args=training_args,
    train_dataset=dataset,
)

# Start training
trainer.train()