# Q-Agent Training - Llama-3.2-3B with GRPO

Model: Llama-3.2-3B-Instruct (3 billion parameters)

Task: Generate logical reasoning questions

Training: SFT, then GRPO with rewards for valid JSON, required fields, and output format



In [None]:
import os
import json
import torch
import random
import re
import gc
from pathlib import Path
from datasets import Dataset
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template, standardize_sharegpt
from trl import SFTConfig, SFTTrainer, GRPOConfig, GRPOTrainer
from transformers import DataCollatorForSeq2Seq

print("Imports successful")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

In [None]:
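CURATED_DIR = Path("MAIN_CURATED_JSON")
MODELS_DIR = Path("FINAL_MODELS")

# Fail early if the curated data is missing instead of silently loading zero questions
if not CURATED_DIR.is_dir():
    raise FileNotFoundError(f"Curated data directory not found: {CURATED_DIR.resolve()}")
MODELS_DIR.mkdir(parents=True, exist_ok=True)

print("Directories ready")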

In [None]:
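# Load all curated questions, keeping only the six output fields (this drops "difficulty")
REQUIRED_FIELDS = ["topic", "question", "choices", "answer", "explanation", "reasoning"]

all_questions = []
curated_files = sorted(CURATED_DIR.glob("*.json"))

for file_path in curated_files:
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            questions = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        print(f"Error loading {file_path.name}: {e}")
        continue

    skipped = 0
    for q in questions:
        # Skip malformed records individually instead of abandoning the rest of the file
        if not isinstance(q, dict) or any(field not in q for field in REQUIRED_FIELDS):
            skipped += 1
            continue
        if not isinstance(q['choices'], list) or len(q['choices']) != 4 or q['answer'] not in ['A', 'B', 'C', 'D']:
            skipped += 1
            continue
        # Copy only the required fields; any difficulty field is left out
        all_questions.append({field: q[field] for field in REQUIRED_FIELDS})
    if skipped:
        print(f"Skipped {skipped} invalid records in {file_path.name}")

print(f"\nLoaded {len(all_questions)} valid questions")
if not all_questions:
    raise RuntimeError(f"No valid questions found in {CURATED_DIR}")
print("\nSample question (NO difficulty field):")
print(json.dumps(all_questions[0], indent=2))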

In [None]:
Q_AGENT_SYSTEM_PROMPT = """You are Q-Agent, an expert question generator for the AMD AI Dev Day Hackathon.

Your task is to generate high-quality logical reasoning questions for competitive exam preparation.

ALLOWED TOPICS:
Only generate questions on these two topics:
1. blood_relations - Family relationship puzzles and genealogy problems
2. seating_arrangement - Linear row seating OR circular table seating

CRITICAL REQUIREMENTS:

1. ENTITY COUNT:
   - Include 3-6 named entities (people, objects, or positions) per question
   - Use distinct names or letters (A, B, C, D, E or Ram, Sita, John, Mary)
   - Each entity must play a clear role in the problem

2. CONSTRAINTS AND RELATIONS:
   - Provide at least 3 unique relations or constraints
   - Each constraint must add information to solve the problem
   - Constraints should not contradict each other

3. SELF-CONTAINED QUESTIONS:
   - Include ALL information needed to solve the problem within the question text
   - Do NOT assume any external context or prior knowledge beyond basic logic
   - State ALL relationships, positions, and constraints explicitly
   - BAD: \"How is X related to Y?\" (missing context about who X and Y are)
   - GOOD: \"X is the son of Y. Y is the brother of Z. Z is the father of W. How is X related to W?\"

4. NO CODED RELATIONS:
   - Do NOT use mathematical symbols to represent relationships
   - Avoid: *, +, -, ×, %, $, #, @ to indicate relations
   - Use plain English descriptions only
   - BAD: \"K * J × L + M\" to represent family relations
   - GOOD: \"K is the father of J. J is the mother of L. M is the son of L.\"

5. MULTIPLE CHOICE FORMAT:
   - Provide exactly 4 choices labeled A, B, C, D
   - Format: [\"A) option1\", \"B) option2\", \"C) option3\", \"D) option4\"]
   - Each choice must start with the letter, followed by a closing parenthesis and a space
   - Only ONE choice should be correct
   - The other three choices should be plausible but incorrect distractors

6. ANSWER FORMAT:
   - Provide a single-letter answer: \"A\", \"B\", \"C\", or \"D\"
   - Nothing else, just the capital letter

7. EXPLANATION:
   - Provide a brief explanation (under 100 words) of why the answer is correct
   - Focus on the key logical step that leads to the answer

8. REASONING:
   - Provide detailed step-by-step reasoning
   - Format as SINGLE STRING with exactly 5 steps
   - \"Step 1: [description]. Step 2: [description]. Step 3: [description]. Step 4: [description]. Step 5: [description].\"
   - Each step should build on previous steps to reach the answer

BLOOD RELATIONS GUIDELINES:

Use clear family relationship terms:
- Direct relations: father, mother, son, daughter, brother, sister
- Extended relations: grandfather, grandmother, uncle, aunt, cousin, nephew, niece
- In-law relations: brother-in-law, sister-in-law, father-in-law, mother-in-law

Structure:
- Start with 3-5 relationship statements establishing the family tree
- End with a question asking about a specific relationship
- Ensure the chain of relationships logically connects to the answer

Example pattern:
\"A is the son of B. B is the brother of C. C is the father of D. How is A related to D?\"

SEATING ARRANGEMENT GUIDELINES:

Linear Arrangements:
- Format: \"P, Q, R, S, T sit in a row from left to right\"
- Use positional terms: left, right, leftmost, rightmost, between, adjacent
- Example: \"P sits third from the left. Q sits between P and R.\"

Circular Arrangements:
- Format: \"A, B, C, D, E, F sit around a circular table\"
- MUST specify facing direction: \"facing the center\" or \"facing outward\"
- Use terms: clockwise, anticlockwise, opposite, between, next to
- Example: \"A sits to the immediate right of B. C sits opposite to D.\"

Restrictions:
- Use ONLY linear rows OR circular tables
- Do NOT use: rectangular tables, square arrangements, compass directions (North, South, East, West)
- Do NOT use: boxes, grids, or complex 2D layouts

OUTPUT FORMAT:

Generate a single JSON object with exactly 6 fields (NO difficulty field):

{
  \"topic\": \"blood_relations\" or \"seating_arrangement\",
  \"question\": \"Complete self-contained problem statement with all necessary information\",
  \"choices\": [\"A) option1\", \"B) option2\", \"C) option3\", \"D) option4\"],
  \"answer\": \"A\" or \"B\" or \"C\" or \"D\",
  \"explanation\": \"Brief justification under 100 words explaining why the answer is correct\",
  \"reasoning\": \"Step 1: [description]. Step 2: [description]. Step 3: [description]. Step 4: [description]. Step 5: [description].\"
}

IMPORTANT:
- Do NOT include a difficulty field
- Difficulty emerges naturally from entity count and relation complexity
- Return ONLY valid JSON, no markdown code blocks
- No explanatory text before or after the JSON object"""

print(f"Q-Agent system prompt: {len(Q_AGENT_SYSTEM_PROMPT)} characters")
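The OUTPUT FORMAT rules in the system prompt are concrete enough to check mechanically, which is useful both for auditing the curated data and for inspecting generations later. A minimal sketch, assuming a question has already been parsed into a Python dict; `check_question_spec` is a hypothetical helper (not part of this notebook), and the "under 100 words" rule is approximated with a whitespace split.

```python
import re

ALLOWED_TOPICS = {"blood_relations", "seating_arrangement"}
REQUIRED_FIELDS = ["topic", "question", "choices", "answer", "explanation", "reasoning"]


def check_question_spec(obj: dict) -> list[str]:
    """Return a list of rule violations; an empty list means the object passes.

    Hypothetical helper mirroring the OUTPUT FORMAT section of Q_AGENT_SYSTEM_PROMPT.
    """
    problems = []
    if set(obj) != set(REQUIRED_FIELDS):
        problems.append(f"fields must be exactly {REQUIRED_FIELDS}, got {sorted(obj)}")
    if obj.get("topic") not in ALLOWED_TOPICS:
        problems.append(f"unknown topic: {obj.get('topic')!r}")

    choices = obj.get("choices")
    if not (isinstance(choices, list) and len(choices) == 4):
        problems.append("choices must be a list of exactly 4 strings")
    else:
        for choice, letter in zip(choices, "ABCD"):
            if not (isinstance(choice, str) and choice.startswith(f"{letter}) ")):
                problems.append(f"choice {choice!r} should start with '{letter}) '")

    if obj.get("answer") not in ("A", "B", "C", "D"):
        problems.append("answer must be a single capital letter A-D")

    explanation = obj.get("explanation", "")
    if not isinstance(explanation, str) or len(explanation.split()) >= 100:
        problems.append("explanation must be a string under 100 words")

    # Exactly five "Step N:" markers, numbered 1..5 in order
    reasoning = obj.get("reasoning", "")
    steps = re.findall(r"Step (\d+):", reasoning) if isinstance(reasoning, str) else []
    if steps != ["1", "2", "3", "4", "5"]:
        problems.append(f"reasoning must contain Step 1: through Step 5:, found {steps}")
    return problems


sample = {
    "topic": "blood_relations",
    "question": "A is the son of B. B is the brother of C. C is the father of D. How is A related to D?",
    "choices": ["A) Brother", "B) Cousin", "C) Uncle", "D) Nephew"],
    "answer": "B",
    "explanation": "A's father B and D's father C are brothers, so A and D are cousins.",
    "reasoning": "Step 1: A is the son of B. Step 2: B is the brother of C. "
                 "Step 3: C is the father of D. Step 4: A and D are children of brothers. "
                 "Step 5: Therefore A is D's cousin.",
}
print(check_question_spec(sample))  # → []
```

Returning a list of messages rather than a boolean makes it easy to tally which rule fails most often across a batch.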

In [None]:
# Create Q-Agent training examples
q_agent_examples = []

for q in all_questions:
    topic_readable = q['topic'].replace('_', ' ')
    
    conversation = [
        {
            "role": "system",
            "content": Q_AGENT_SYSTEM_PROMPT
        },
        {
            "role": "user",
            "content": f"Generate a {topic_readable} question with 3-6 entities, multiple choice options, and step-by-step reasoning."
        },
        {
            "role": "assistant",
            "content": json.dumps(q, ensure_ascii=False)
        }
    ]
    
    q_agent_examples.append({"conversations": conversation})

print(f"Created {len(q_agent_examples)} Q-Agent training examples")

# Shuffle and split
random.seed(42)
random.shuffle(q_agent_examples)
split_idx = int(len(q_agent_examples) * 0.9)

q_train = q_agent_examples[:split_idx]
q_val = q_agent_examples[split_idx:]

print(f"\nSplit:")
print(f"  Train: {len(q_train)}")
print(f"  Val: {len(q_val)}")

# Show sample
print(f"\nSample Q-Agent output:")
print(q_train[0]['conversations'][2]['content'][:200] + "...")
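The split above shuffles both topics together, so with a small dataset one topic can end up under-represented in the validation set. A hedged alternative: split each topic separately at the same 90/10 ratio. `stratified_split` is a hypothetical helper, and the toy data below stands in for the real examples.

```python
import random
from collections import defaultdict


def stratified_split(examples, key, train_frac=0.9, seed=42):
    """Split examples into (train, val) so that each group keeps about train_frac in train.

    key: function mapping an example to its group label (for example, its topic).
    """
    groups = defaultdict(list)
    for ex in examples:
        groups[key(ex)].append(ex)

    rng = random.Random(seed)  # local RNG so the global random state is untouched
    train, val = [], []
    for label in sorted(groups):  # sorted for a deterministic order
        items = groups[label][:]
        rng.shuffle(items)
        cut = int(len(items) * train_frac)
        train.extend(items[:cut])
        val.extend(items[cut:])
    rng.shuffle(train)
    rng.shuffle(val)
    return train, val


# Toy example: 20 blood_relations items and 10 seating_arrangement items
toy = [{"topic": "blood_relations", "id": i} for i in range(20)] + \
      [{"topic": "seating_arrangement", "id": i} for i in range(10)]
train, val = stratified_split(toy, key=lambda ex: ex["topic"])
print(len(train), len(val))  # → 27 3
```

With the notebook's conversation format, the key could read the topic from the assistant message, e.g. `lambda ex: json.loads(ex['conversations'][2]['content'])['topic']`.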

In [None]:
# Configuration
Q_AGENT_CONFIG = {
    "model_name": "unsloth/Llama-3.2-3B-Instruct",
    "max_seq_length": 2048,
    "lora_r": 32,
    "lora_alpha": 32,
    "batch_size": 4,
    "gradient_accumulation": 4,
    "learning_rate": 2e-4,
    "num_epochs": 3,
    "warmup_steps": 10,
}

print("Q-Agent Model Configuration:")
for key, value in Q_AGENT_CONFIG.items():
    print(f"  {key}: {value}")
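With these values, each optimizer step on a single GPU sees `batch_size × gradient_accumulation` = 16 sequences. The sketch below estimates the resulting number of SFT steps and how much of the run the 10 warmup steps cover. The training-set size of 900 is a made-up number for illustration, and the exact step count can differ slightly between Trainer versions (for example, in how a trailing partial accumulation is handled).

```python
import math

# Values copied from Q_AGENT_CONFIG; n_train is a hypothetical training-set size
config = {"batch_size": 4, "gradient_accumulation": 4, "num_epochs": 3, "warmup_steps": 10}
n_train = 900


def estimate_sft_steps(n_train, cfg, num_gpus=1):
    """Rough optimizer-step estimate; exact counts depend on the Trainer version."""
    effective_batch = cfg["batch_size"] * cfg["gradient_accumulation"] * num_gpus
    steps_per_epoch = math.ceil(n_train / effective_batch)
    total_steps = steps_per_epoch * cfg["num_epochs"]
    warmup_frac = cfg["warmup_steps"] / total_steps
    return effective_batch, steps_per_epoch, total_steps, warmup_frac


eff, per_epoch, total, warm = estimate_sft_steps(n_train, config)
print(eff, per_epoch, total, f"{warm:.1%}")  # → 16 57 171 5.8%
```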

In [None]:
# Load Q-Agent model
print("\nLoading Q-Agent base model...")

q_model, q_tokenizer = FastLanguageModel.from_pretrained(
    model_name=Q_AGENT_CONFIG["model_name"],
    max_seq_length=Q_AGENT_CONFIG["max_seq_length"],
    dtype=torch.bfloat16,
    load_in_4bit=False,
)

print(f"Loaded {Q_AGENT_CONFIG['model_name']}")

# Add LoRA adapters
print("\nAdding LoRA adapters...")

q_model = FastLanguageModel.get_peft_model(
    q_model,
    r=Q_AGENT_CONFIG["lora_r"],
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                   "gate_proj", "up_proj", "down_proj"],
    lora_alpha=Q_AGENT_CONFIG["lora_alpha"],
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
)

print(f"LoRA adapters added (r={Q_AGENT_CONFIG['lora_r']})")
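The adapter size can be estimated by hand. Each adapted linear layer of shape `d_in × d_out` gains `r × (d_in + d_out)` parameters (matrix A is `r × d_in`, matrix B is `d_out × r`), and the low-rank update is scaled by `lora_alpha / r`, which is 32/32 = 1.0 here. The layer dimensions below are assumptions taken from the published Llama-3.2-3B config (hidden size 3072, MLP size 8192, 28 layers, 8 key/value heads of dimension 128). Compare the result against `q_model.config` or the output of `q_model.print_trainable_parameters()` if in doubt.

```python
# Assumed Llama-3.2-3B dimensions (from its public config.json)
HIDDEN = 3072
INTERMEDIATE = 8192
NUM_LAYERS = 28
KV_DIM = 8 * 128  # num_key_value_heads * head_dim (grouped-query attention)

# (d_in, d_out) for each module listed in target_modules
MODULE_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_param_count(r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) to every adapted layer in every block."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in MODULE_SHAPES.values())
    return per_layer * NUM_LAYERS


r, alpha = 32, 32
print(f"trainable LoRA params: {lora_param_count(r):,}")  # → trainable LoRA params: 48,627,712
print(f"scaling alpha/r = {alpha / r}")                  # → scaling alpha/r = 1.0
```

Because the count is linear in `r`, halving `lora_r` to 16 halves the adapter size. Keeping `lora_alpha` equal to `lora_r` keeps the update scale at 1.0.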

In [None]:
# Prepare datasets
print("\nLoading Q-Agent datasets...")

q_train_dataset = Dataset.from_list(q_train)
q_val_dataset = Dataset.from_list(q_val)

print(f"Train: {len(q_train_dataset)} examples")
print(f"Val: {len(q_val_dataset)} examples")

# Set chat template for Llama
q_tokenizer = get_chat_template(q_tokenizer, chat_template="llama-3.1")

if q_tokenizer.pad_token is None:
    q_tokenizer.pad_token = q_tokenizer.eos_token
    q_tokenizer.pad_token_id = q_tokenizer.eos_token_id

# Formatting function
def formatting_prompts_func_q(examples):
    convos = examples["conversations"]
    texts = []
    for convo in convos:
        if isinstance(convo, list):
            text = q_tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
            texts.append(text)
    return {"text": texts}

print("\nFormatting datasets...")

q_train_dataset = standardize_sharegpt(q_train_dataset)
q_train_dataset = q_train_dataset.map(formatting_prompts_func_q, batched=True, remove_columns=q_train_dataset.column_names)
q_train_dataset = q_train_dataset.filter(lambda x: len(x["text"].strip()) > 0)

q_val_dataset = standardize_sharegpt(q_val_dataset)
q_val_dataset = q_val_dataset.map(formatting_prompts_func_q, batched=True, remove_columns=q_val_dataset.column_names)
q_val_dataset = q_val_dataset.filter(lambda x: len(x["text"].strip()) > 0)

print(f"Formatted {len(q_train_dataset)} train + {len(q_val_dataset)} val examples")

In [None]:
# Setup SFT trainer
print("\nSetting up Q-Agent SFT trainer...")

q_output_dir = MODELS_DIR / "q_agent_llama"
q_output_dir.mkdir(exist_ok=True)

q_trainer = SFTTrainer(
    model=q_model,
    tokenizer=q_tokenizer,
    train_dataset=q_train_dataset,
    eval_dataset=q_val_dataset,
    dataset_text_field="text",
    max_seq_length=Q_AGENT_CONFIG["max_seq_length"],
    data_collator=DataCollatorForSeq2Seq(tokenizer=q_tokenizer, padding=True),
    packing=False,
    args=SFTConfig(
        per_device_train_batch_size=Q_AGENT_CONFIG["batch_size"],
        per_device_eval_batch_size=Q_AGENT_CONFIG["batch_size"],
        gradient_accumulation_steps=Q_AGENT_CONFIG["gradient_accumulation"],
        warmup_steps=Q_AGENT_CONFIG["warmup_steps"],
        num_train_epochs=Q_AGENT_CONFIG["num_epochs"],
        learning_rate=Q_AGENT_CONFIG["learning_rate"],
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        seed=3407,
        output_dir=str(q_output_dir / "checkpoints"),
        report_to="none",
        bf16=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
    ),
)

print("Trainer configured")

In [None]:
# Train Q-Agent Model (SFT Stage)
print("\nSTAGE 1: SFT TRAINING Q-AGENT MODEL...\n")

FastLanguageModel.for_training(q_model)
q_trainer.train()

print("\nQ-Agent SFT Training Complete")

In [None]:
# Save SFT model
print("\nSaving Q-Agent SFT model...")

q_model.save_pretrained(str(q_output_dir / "sft_lora"))
q_tokenizer.save_pretrained(str(q_output_dir / "sft_lora"))
print(f"LoRA adapters saved: {q_output_dir / 'sft_lora'}")

q_model.save_pretrained_merged(str(q_output_dir / "sft_merged_16bit"), q_tokenizer, save_method="merged_16bit")
print(f"Merged 16bit saved: {q_output_dir / 'sft_merged_16bit'}")

In [None]:
# Prepare GRPO dataset
print("\nPreparing GRPO dataset...")

grpo_dataset = []

for q in all_questions:
    topic_readable = q['topic'].replace('_', ' ')
    
    grpo_dataset.append({
        "prompt": [
            {"role": "system", "content": Q_AGENT_SYSTEM_PROMPT},
            {"role": "user", "content": f"Generate a {topic_readable} question with 3-6 entities, multiple choice options, and step-by-step reasoning."}
        ],
        "ground_truth": q
    })

grpo_dataset = Dataset.from_list(grpo_dataset)
print(f"GRPO dataset: {len(grpo_dataset)} examples")
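The GRPO dataset carries a `ground_truth` column, but none of the three reward functions defined in the next cell uses it. TRL's `GRPOTrainer` passes dataset columns other than `prompt` to reward functions as keyword arguments, one list entry per completion. That makes a topic check possible as an extra reward. The sketch below is hedged: `reward_topic_match` and its ±1.0 scores are hypothetical, and it compares only the `topic` field, because a generated question is not expected to match the reference question.

```python
import json


def reward_topic_match(prompts, completions, ground_truth=None, **kwargs):
    """+1.0 if the generated topic equals the requested one, -1.0 otherwise (hypothetical)."""
    scores = []
    for i, completion in enumerate(completions):
        response = completion[0]["content"]
        # Same outermost-braces extraction idea as the existing reward functions
        start, end = response.find("{"), response.rfind("}") + 1
        try:
            parsed = json.loads(response[start:end]) if start != -1 else None
        except json.JSONDecodeError:
            parsed = None
        expected = ground_truth[i]["topic"] if ground_truth else None
        if isinstance(parsed, dict) and expected is not None and parsed.get("topic") == expected:
            scores.append(1.0)
        else:
            scores.append(-1.0)
    return scores


fake_completions = [
    [{"role": "assistant", "content": '{"topic": "blood_relations"}'}],
    [{"role": "assistant", "content": 'Here: {"topic": "seating_arrangement"}'}],
    [{"role": "assistant", "content": "no json here"}],
]
fake_truth = [{"topic": "blood_relations"}] * 3
print(reward_topic_match(None, fake_completions, ground_truth=fake_truth))  # → [1.0, -1.0, -1.0]
```

To try it, append `reward_topic_match` to `reward_funcs` when constructing the GRPO trainer.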

In [None]:
# GRPO Reward Functions for Q-Agent
print("\nDefining GRPO reward functions...")

# Reward Function 1: Valid JSON Structure
def reward_json_validity(prompts, completions, **kwargs):
    """Reward if response is valid JSON"""
    scores = []
    
    for completion in completions:
        response = completion[0]["content"]
        
        try:
            parsed = json.loads(response)
            # Valid JSON gets high reward
            score = 3.0
        except json.JSONDecodeError:
            # Try to extract JSON from text
            if '{' in response and '}' in response:
                json_start = response.find('{')
                json_end = response.rfind('}') + 1
                try:
                    parsed = json.loads(response[json_start:json_end])
                    score = 2.0  # Partial credit for extractable JSON
                except:
                    score = -2.0  # Penalty for invalid JSON
            else:
                score = -3.0  # Strong penalty for no JSON
        
        scores.append(score)
    
    return scores

# Reward Function 2: Required Fields Present
def reward_required_fields(prompts, completions, **kwargs):
    """Reward if all required fields are present"""
    scores = []
    required_fields = ['topic', 'question', 'choices', 'answer', 'explanation', 'reasoning']
    
    for completion in completions:
        response = completion[0]["content"]
        
        try:
            # Try direct parse
            parsed = json.loads(response)
        except:
            # Try extraction
            if '{' in response and '}' in response:
                json_start = response.find('{')
                json_end = response.rfind('}') + 1
                try:
                    parsed = json.loads(response[json_start:json_end])
                except:
                    parsed = None
            else:
                parsed = None
        
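        # Require a JSON object: a bare string like "topic question ..." would otherwise pass substring checks
        if isinstance(parsed, dict):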
            # Count present fields
            present = sum(1 for field in required_fields if field in parsed)
            
            if present == 6:
                score = 2.0  # All fields present
            elif present >= 4:
                score = 0.5  # Most fields present
            else:
                score = -1.0  # Too few fields
            
            # Penalty if difficulty field is present (we don't want it)
            if 'difficulty' in parsed:
                score -= 1.0
        else:
            score = -1.5
        
        scores.append(score)
    
    return scores

# Reward Function 3: Format Correctness
def reward_format_correctness(prompts, completions, **kwargs):
    """Reward if choices, answer, and reasoning formats are correct"""
    scores = []
    
    for completion in completions:
        response = completion[0]["content"]
        
        try:
            parsed = json.loads(response)
        except json.JSONDecodeError:
            if '{' in response and '}' in response:
                json_start = response.find('{')
                json_end = response.rfind('}') + 1
                try:
                    parsed = json.loads(response[json_start:json_end])
                except json.JSONDecodeError:
                    parsed = None
            else:
                parsed = None
        
        # Only JSON objects can be checked; lists, strings, and numbers count as failures
        if isinstance(parsed, dict):
            score = 0.0
            
            # Check choices: exactly 4 strings starting with "A) ", "B) ", "C) ", "D) " in order
            choices = parsed.get('choices', [])
            if isinstance(choices, list) and len(choices) == 4 and all(
                isinstance(c, str) and c.startswith(f"{letter}) ")
                for c, letter in zip(choices, 'ABCD')
            ):
                score += 1.0
            
            # Check answer format
            answer = parsed.get('answer', '')
            if isinstance(answer, str) and answer in ['A', 'B', 'C', 'D']:
                score += 1.0
            
            # Check that reasoning is a string containing Step 1 through Step 5
            reasoning = parsed.get('reasoning', '')
            if isinstance(reasoning, str) and 'Step 1' in reasoning and 'Step 5' in reasoning:
                score += 1.0
        else:
            score = -1.5
        
        scores.append(score)
    
    return scores

print("GRPO reward functions defined:")
print("  1. JSON Validity (-3.0 to +3.0 points)")
print("  2. Required Fields (-2.0 to +2.0 points)")
print("  3. Format Correctness (-1.5 to +3.0 points)")
print("  Total max reward: 8.0 points")
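All three reward functions repeat the same parse-then-extract logic. The sketch below factors it into one shared helper using the same outermost-`{...}` heuristic. `parse_json_object` and `reward_json_validity_v2` are hypothetical names. Unlike `reward_json_validity`, the helper accepts only JSON objects, so a completion that parses as a bare list or number is not given the full 3.0.

```python
import json
from typing import Optional, Tuple


def parse_json_object(response: str) -> Tuple[Optional[dict], Optional[str]]:
    """Return (obj, mode), where mode is "direct", "extracted", or None if no object was found."""
    try:
        parsed = json.loads(response)
        if isinstance(parsed, dict):
            return parsed, "direct"
    except json.JSONDecodeError:
        pass
    # Fall back to the span between the first "{" and the last "}"
    start, end = response.find("{"), response.rfind("}") + 1
    if start != -1 and end > start:
        try:
            parsed = json.loads(response[start:end])
            if isinstance(parsed, dict):
                return parsed, "extracted"
        except json.JSONDecodeError:
            pass
    return None, None


def reward_json_validity_v2(prompts, completions, **kwargs):
    """Same scoring scale as reward_json_validity, built on parse_json_object."""
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        _, mode = parse_json_object(response)
        if mode == "direct":
            scores.append(3.0)
        elif mode == "extracted":
            scores.append(2.0)   # partial credit: JSON wrapped in extra text
        elif "{" in response and "}" in response:
            scores.append(-2.0)  # braces present but no parseable object
        else:
            scores.append(-3.0)  # no JSON-like content at all
    return scores


demo = [
    [{"content": '{"topic": "blood_relations"}'}],
    [{"content": 'Sure! {"topic": "seating_arrangement"} Done.'}],
    [{"content": '{"topic": broken}'}],
    [{"content": "no json"}],
]
print(reward_json_validity_v2(None, demo))  # → [3.0, 2.0, -2.0, -3.0]
```

The other two reward functions could call `parse_json_object` in the same way and branch on `obj is None`.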

In [None]:
# Get max prompt length for GRPO
max_seq_length = Q_AGENT_CONFIG["max_seq_length"]  # keep in sync with the loaded model
max_prompt_len = max(grpo_dataset.map(
    lambda x: {"tokens": q_tokenizer.apply_chat_template(x["prompt"], add_generation_prompt=True, tokenize=True)},
    batched=True,
).map(lambda x: {"length": len(x["tokens"])}))["length"])

max_prompt_length = max_prompt_len + 10
max_completion_length = max_seq_length - max_prompt_length

print(f"Max prompt length: {max_prompt_length}")
print(f"Max completion length: {max_completion_length}")

In [None]:
# GRPO configuration
grpo_output_dir = MODELS_DIR / "q_agent_llama_grpo"
grpo_output_dir.mkdir(exist_ok=True)

grpo_config = GRPOConfig(
    learning_rate=5e-6,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    logging_steps=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=8,
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    max_steps=500,
    save_steps=250,
    max_grad_norm=1.0,
    report_to="none",
    output_dir=str(grpo_output_dir / "checkpoints"),
)
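GRPO does not optimize the raw rewards directly. For each prompt it samples `num_generations` completions (8 here), adds up the reward functions for each completion, and normalizes the totals within that group. Each completion is therefore pushed up or down relative to its siblings. The pure-Python sketch below is simplified. TRL works on tensors, and details such as the epsilon and whether to divide by the standard deviation vary by version and config. The reward values are illustrative combinations of the three functions above.

```python
import statistics


def group_advantages(rewards_per_func, eps=1e-4):
    """rewards_per_func: one [r_json, r_fields, r_format] list per completion in the group.

    Returns one advantage per completion: (total - group_mean) / (group_std + eps).
    """
    totals = [sum(r) for r in rewards_per_func]  # equal weight for every reward function
    mean = statistics.fmean(totals)
    std = statistics.stdev(totals) if len(totals) > 1 else 0.0  # sample standard deviation
    return [(t - mean) / (std + eps) for t in totals]


# One group of 8 completions for the same prompt (illustrative reward values)
group = [
    [3.0, 2.0, 3.0],     # perfect: 8.0
    [3.0, 2.0, 3.0],
    [2.0, 2.0, 2.0],     # JSON wrapped in extra text: 6.0
    [3.0, 0.5, 1.0],     # missing fields: 4.5
    [-2.0, -1.5, -1.5],  # broken JSON: -5.0
    [3.0, 2.0, 3.0],
    [-3.0, -1.5, -1.5],  # no JSON at all: -6.0
    [3.0, 1.0, 3.0],     # all fields but also a difficulty field: 7.0
]
for adv in group_advantages(group):
    print(f"{adv:+.2f}")
```

If all 8 completions for a prompt earn the same total (for example, all perfect), every advantage is zero and that prompt contributes no reward-driven gradient. Purely format-based rewards therefore saturate once the model reliably emits valid JSON.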

# Create GRPO trainer
grpo_trainer = GRPOTrainer(
    model=q_model,
    processing_class=q_tokenizer,
    reward_funcs=[
        reward_json_validity,
        reward_required_fields,
        reward_format_correctness,
    ],
    args=grpo_config,
    train_dataset=grpo_dataset,
)

print("GRPO trainer configured")

In [None]:
# Train with GRPO
print("\nSTAGE 2: GRPO TRAINING...\n")
print("Watch the reward column increase over steps\n")

grpo_trainer.train()

print("\nGRPO Training Complete")

In [None]:
# Save final model
print("\nSaving final Q-Agent model...")

q_model.save_pretrained(str(grpo_output_dir / "lora"))
q_tokenizer.save_pretrained(str(grpo_output_dir / "lora"))
print(f"LoRA adapters saved: {grpo_output_dir / 'lora'}")

q_model.save_pretrained_merged(str(grpo_output_dir / "merged_16bit"), q_tokenizer, save_method="merged_16bit")
print(f"Merged 16bit saved: {grpo_output_dir / 'merged_16bit'}")

print(f"\nQ-Agent model ready: {grpo_output_dir}")

In [None]:
# Test Q-Agent
print("\nTesting Q-Agent Model...\n")

FastLanguageModel.for_inference(q_model)

test_topics = [
    "blood relations",
    "seating arrangement"
]

for topic in test_topics:
    print(f"\n{'='*60}")
    print(f"Test: Generate {topic} question")
    print('='*60)

    messages = [
        {"role": "system", "content": Q_AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": f"Generate a {topic} question with 3-6 entities, multiple choice options, and step-by-step reasoning."}
    ]

    prompt = q_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = q_tokenizer(prompt, return_tensors="pt").to(q_model.device)

    outputs = q_model.generate(
        **inputs,
        max_new_tokens=1024,
        temperature=0.3,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.2,
        pad_token_id=q_tokenizer.eos_token_id
    )

    response = q_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

    print(f"\nGenerated JSON:\n")
    print(response[:500])
    
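    # Try to parse JSON; fall back to the outermost {...} span, as the reward functions do
    try:
        try:
            q_json = json.loads(response)
            print("\nValid JSON")
        except json.JSONDecodeError:
            q_json = json.loads(response[response.find('{'):response.rfind('}') + 1])
            print("\nJSON extracted from surrounding text")
        print(f"Topic: {q_json.get('topic')}")
        print(f"Choices: {len(q_json.get('choices', []))}")
        print(f"Answer: {q_json.get('answer')}")
        print(f"Has reasoning: {'reasoning' in q_json}")
        print(f"NO difficulty: {'difficulty' not in q_json}")
    except Exception as e:
        print(f"\nCould not parse JSON: {e}")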

print(f"\n{'='*60}")
print("Q-Agent Testing Complete")
print('='*60)