# A-Agent Training - Llama-3.2-3B with GRPO

Model: Llama-3.2-3B-Instruct (3 billion parameters)

Task: Answer multiple-choice logical reasoning questions (blood relations and seating arrangements)

Training: SFT followed by GRPO (rewards for correct answers, five-step reasoning, and output format)

NO emojis

In [None]:
import os
import json
import torch
import random
import re
import gc
from pathlib import Path
from datasets import Dataset
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template, standardize_sharegpt
from trl import SFTConfig, SFTTrainer, GRPOConfig, GRPOTrainer
from transformers import DataCollatorForSeq2Seq

print("Imports successful")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

In [None]:
CURATED_DIR = Path("MAIN_CURATED_JSON")
MODELS_DIR = Path("FINAL_MODELS")

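MODELS_DIR.mkdir(exist_ok=True)

# Fail early with a clear message instead of an IndexError in the loading cell
if not CURATED_DIR.is_dir():
    raise FileNotFoundError(f"Curated data directory not found: {CURATED_DIR}")

print("Directories ready")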

In [None]:
# Load all curated questions
all_questions = []
curated_files = sorted(CURATED_DIR.glob("*.json"))

for file_path in curated_files:
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            questions = json.load(f)
            for q in questions:
                if len(q.get('choices', [])) == 4 and q.get('answer', '') in ['A', 'B', 'C', 'D']:
                    all_questions.append(q)
    except Exception as e:
        print(f"Error loading {file_path.name}: {e}")

print(f"\nLoaded {len(all_questions)} valid questions")
print(f"\nSample:")
print(json.dumps(all_questions[0], indent=2))
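
The loading cell keeps only records with exactly four choices and an answer letter in A-D. The sketch below isolates that check as a standalone function so it can be tried on a single record. The helper name `is_valid_question` and the sample record are hypothetical, not taken from the curated files; the key names (`question`, `choices`, `answer`, `reasoning`) are the ones the notebook reads.

```python
def is_valid_question(q):
    """Return True if q is a dict with exactly four choices and an answer in A-D."""
    if not isinstance(q, dict):
        return False
    choices = q.get("choices") or []  # treat a missing or null field as empty
    return len(choices) == 4 and q.get("answer", "") in ("A", "B", "C", "D")


# Hypothetical record, for illustration only
sample = {
    "question": "Pointing to a boy, Meera said, 'He is the only son of my mother.' "
                "How is the boy related to Meera?",
    "choices": ["A) Brother", "B) Father", "C) Uncle", "D) Cousin"],
    "answer": "A",
    "reasoning": "Step 1: Meera is the speaker. Step 2: The only son of her mother is her brother.",
}

print(is_valid_question(sample))
print(is_valid_question({**sample, "answer": "E"}))
```

Note that the check does not look at `question` or `reasoning`, so records with an empty reasoning string still pass.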

In [None]:
A_AGENT_SYSTEM_PROMPT = """You are A-Agent, an expert logical reasoning solver for the AMD AI Dev Day Hackathon.

Your task is to analyze multiple choice questions on blood relations and seating arrangements, then provide the correct answer with detailed step-by-step reasoning.

BLOOD RELATIONS SOLVING STRATEGY:

1. IDENTIFY THE SPEAKER PERSPECTIVE:
   - Determine whose perspective the question is from
   - \"my\" always refers to the speaker
   - Fix the reference point before analyzing relationships

2. RESOLVE EACH RELATION STEP-BY-STEP:
   - \"my grandfather's only son\" = my father
   - \"my mother's daughter\" = my sister (or me if female)
   - \"my father's father\" = my grandfather
   - \"my uncle's wife\" = my aunt
   - Break down complex chains into simple parent-child or sibling links

3. MAINTAIN GENERATION LEVELS:
   Track the generational hierarchy:
   - Grandparent level: -2 (grandfather, grandmother)
   - Parent level: -1 (father, mother, uncle, aunt)
   - Self level: 0 (brother, sister, cousin)
   - Child level: +1 (son, daughter, nephew, niece)
   - Grandchild level: +2 (grandson, granddaughter)
   Never confuse upward relations (toward ancestors) with downward relations (toward descendants)

4. GENDER ASSUMPTIONS:
   - If gender is not explicitly stated, assume the speaker is male
   - Pay attention to gender-specific terms: he/she, his/her, husband/wife, brother/sister

5. FINALIZE THE RELATION:
   - After mapping all relationships, determine the final connection
   - Use standard family terms: father, mother, son, daughter, brother, sister, uncle, aunt, cousin, nephew, niece, grandfather, grandmother, brother-in-law, sister-in-law
   - Select the choice that matches this relationship

SEATING ARRANGEMENT SOLVING STRATEGY:

1. IDENTIFY THE ARRANGEMENT TYPE:
   - Linear: People sit in a straight row (left to right or numbered positions)
   - Circular: People sit around a table (clockwise/anticlockwise direction matters)

2. DETERMINE FACING DIRECTION (for circular only):
   - Facing center: \"left\" means clockwise direction, \"right\" means anticlockwise
   - Facing outward: \"left\" means anticlockwise direction, \"right\" means clockwise
   - This is critical for determining relative positions

3. LIST ALL CONSTRAINTS:
   - Write down every constraint given in the problem
   - Constraints might be: \"A is left of B\", \"C is opposite D\", \"E is between F and G\"
   - Number the constraints for systematic application

4. APPLY CONSTRAINTS STEP BY STEP:
   - Start with the most definite constraint (exact positions or opposites)
   - Build the arrangement progressively
   - Use process of elimination for remaining positions

5. VERIFY THE FINAL ARRANGEMENT:
   - Check that all constraints are satisfied
   - Ensure no contradictions exist
   - Answer the specific question asked (who sits where, who is next to whom, etc.)

CONSISTENCY CHECK:

Before providing your final answer:
- Ensure your reasoning logically leads to the answer you select
- If reasoning concludes \"X is the uncle\", your answer should match the choice that says \"uncle\"
- If reasoning concludes \"E sits at position 5\", your answer should match that choice
- Do NOT contradict your own logical deduction
- Double-check that the answer letter (A/B/C/D) corresponds to the correct option

OUTPUT FORMAT:

Provide your response in exactly this format:

answer: \"A\" or \"B\" or \"C\" or \"D\"
reasoning: \"Step 1: [first logical step]. Step 2: [second logical step]. Step 3: [third logical step]. Step 4: [fourth logical step]. Step 5: [final conclusion].\"

CRITICAL FORMAT REQUIREMENTS:
- answer line must have: answer: followed by the letter in quotes
- Only use capital letters A, B, C, or D
- reasoning line must have: reasoning: followed by exactly 5 steps in a single string
- Each step must start with \"Step N:\" where N is 1, 2, 3, 4, or 5
- All 5 steps must be in ONE continuous string, separated by periods
- Do NOT use an array or list format
- Do NOT add any other fields or text"""

print(f"A-Agent system prompt: {len(A_AGENT_SYSTEM_PROMPT)} characters")
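
The seating strategy in the prompt (fix the facing direction, list the constraints, apply them, verify) can be checked mechanically on small puzzles. The sketch below brute-forces a hypothetical five-person circular puzzle using the prompt's convention that, facing the center, "left" is the next seat clockwise. The puzzle, the names, and the helpers are assumptions made for illustration; they are not part of the training data. The puzzle: Q is to the immediate left of P, S is not adjacent to P, and R is to the immediate right of T.

```python
from itertools import permutations

PEOPLE = ["P", "Q", "R", "S", "T"]
N = len(PEOPLE)


def left_of(seat):
    # Facing the center, "left" is the next seat clockwise
    return (seat + 1) % N


def right_of(seat):
    # Facing the center, "right" is the next seat anticlockwise
    return (seat - 1) % N


def satisfies(pos):
    """pos maps each name to a seat index, counted clockwise."""
    return (
        pos["Q"] == left_of(pos["P"])
        and pos["S"] not in (left_of(pos["P"]), right_of(pos["P"]))
        and pos["R"] == right_of(pos["T"])
    )


solutions = []
for rest in permutations(PEOPLE[1:]):
    order = ("P",) + rest  # pin P to seat 0 so rotations are not counted twice
    pos = {name: seat for seat, name in enumerate(order)}
    if satisfies(pos):
        solutions.append(order)

print(solutions)
print("Immediate right of P:", solutions[0][right_of(0)])
```

A puzzle is well-posed when exactly one arrangement survives, which makes this kind of check useful for validating curated questions before training on them.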

In [None]:
# Helper functions
def format_question_with_choices(question, choices):
    formatted = f"{question}\n\nChoices:\n"
    for choice in choices:
        formatted += f"{choice}\n"
    return formatted.strip()

def format_answer_simple(answer, reasoning):
    response = f'answer: "{answer}"\n'
    response += f'reasoning: "{reasoning}"'
    return response

# Create A-Agent training examples
a_agent_examples = []

for q in all_questions:
    conversation = [
        {
            "role": "system",
            "content": A_AGENT_SYSTEM_PROMPT
        },
        {
            "role": "user",
            "content": format_question_with_choices(
                q.get('question', ''),
                q.get('choices', [])
            )
        },
        {
            "role": "assistant",
            "content": format_answer_simple(
                q.get('answer', ''),
                q.get('reasoning', '')
            )
        }
    ]
    
    a_agent_examples.append({"conversations": conversation})

print(f"Created {len(a_agent_examples)} A-Agent training examples")

# Shuffle and split
random.seed(42)
random.shuffle(a_agent_examples)
split_idx = int(len(a_agent_examples) * 0.9)

a_train = a_agent_examples[:split_idx]
a_val = a_agent_examples[split_idx:]

print(f"\nSplit:")
print(f"  Train: {len(a_train)}")
print(f"  Val: {len(a_val)}")

# Show sample
print(f"\nSample A-Agent output:")
print(a_train[0]['conversations'][2]['content'])
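
The SFT targets built here are parsed again later by the regular expressions in the GRPO reward cell, so it is worth checking that a target survives the round trip. The sketch below copies the target format and the two patterns from this notebook; `round_trips` and both sample reasoning strings are hypothetical. It shows that a reasoning string containing a double quote is cut short by the `[^"]+` pattern.

```python
import re

# Same patterns as the GRPO reward cell
answer_pattern = re.compile(r'answer:\s*"([A-D])"', re.IGNORECASE)
reasoning_pattern = re.compile(r'reasoning:\s*"([^"]+)"', re.IGNORECASE)


def format_answer_simple(answer, reasoning):
    # Same target format as the helper defined in this cell
    return f'answer: "{answer}"\nreasoning: "{reasoning}"'


def round_trips(answer, reasoning):
    """True if the formatted target parses back to the same answer and reasoning."""
    text = format_answer_simple(answer, reasoning)
    a = answer_pattern.search(text)
    r = reasoning_pattern.search(text)
    return bool(a and r and a.group(1) == answer and r.group(1) == reasoning)


clean = "Step 1: Fix the speaker. Step 2: The only son of my grandfather is my father."
quoted = 'Step 1: "my" refers to the speaker. Step 2: The only son of my grandfather is my father.'

print(round_trips("A", clean))
print(round_trips("A", quoted))
```

If curated reasoning strings contain double quotes, one option is to replace them with single quotes before building the examples, so that the SFT targets and the reward parser agree.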

In [None]:
# Configuration
A_AGENT_CONFIG = {
    "model_name": "unsloth/Llama-3.2-3B-Instruct",
    "max_seq_length": 1536,
    "lora_r": 32,
    "lora_alpha": 32,
    "batch_size": 8,
    "gradient_accumulation": 2,
    "learning_rate": 2e-4,
    "num_epochs": 3,
    "warmup_steps": 10,
}

print("A-Agent Model Configuration:")
for key, value in A_AGENT_CONFIG.items():
    print(f"  {key}: {value}")
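
The batch size and gradient accumulation settings together determine how many examples contribute to each optimizer step. The sketch below works through that arithmetic for a single GPU. The example count of 900 is a hypothetical placeholder, and the step counts are estimates: the trainer's handling of the final partial batch can shift the result by a step per epoch depending on the library version.

```python
import math


def sft_schedule(num_examples, batch_size, gradient_accumulation, num_epochs):
    """Estimate effective batch size and optimizer steps on one device."""
    effective_batch = batch_size * gradient_accumulation
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return effective_batch, steps_per_epoch, steps_per_epoch * num_epochs


# 900 training examples is an assumed figure, not the size of the curated set
effective, per_epoch, total = sft_schedule(
    num_examples=900, batch_size=8, gradient_accumulation=2, num_epochs=3
)
print(f"Effective batch: {effective}, steps/epoch: {per_epoch}, total steps: {total}")
```

With only 10 warmup steps, the warmup phase covers a small fraction of a run of this length.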

In [None]:
# Load A-Agent model
print("\nLoading A-Agent base model...")

a_model, a_tokenizer = FastLanguageModel.from_pretrained(
    model_name=A_AGENT_CONFIG["model_name"],
    max_seq_length=A_AGENT_CONFIG["max_seq_length"],
    dtype=torch.bfloat16,
    load_in_4bit=False,
)

print(f"Loaded {A_AGENT_CONFIG['model_name']}")

# Add LoRA adapters
print("\nAdding LoRA adapters...")

a_model = FastLanguageModel.get_peft_model(
    a_model,
    r=A_AGENT_CONFIG["lora_r"],
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                   "gate_proj", "up_proj", "down_proj"],
    lora_alpha=A_AGENT_CONFIG["lora_alpha"],
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
)

print(f"LoRA adapters added (r={A_AGENT_CONFIG['lora_r']})")
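
To get a feel for how many parameters these adapters add, the count can be estimated by hand: each adapted linear layer of shape (in, out) gains r * (in + out) weights. The sketch below assumes the published Llama-3.2-3B dimensions (hidden size 3072, MLP size 8192, 28 layers, 8 key/value heads of dimension 128). These are assumptions to verify against `a_model.config`; the authoritative figure is the trainable-parameter count reported when the adapters are attached.

```python
# Assumed Llama-3.2-3B shapes; verify against a_model.config before relying on them
HIDDEN = 3072
INTERMEDIATE = 8192
NUM_LAYERS = 28
KV_DIM = 8 * 128  # 8 key/value heads x head dimension 128

# (in_features, out_features) for each targeted projection
SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_param_count(r):
    """LoRA adds an (in x r) and an (r x out) matrix per adapted layer."""
    per_layer = sum(r * (fan_in + fan_out) for fan_in, fan_out in SHAPES.values())
    return per_layer * NUM_LAYERS


r, alpha = 32, 32
trainable = lora_param_count(r)
scaling = alpha / r  # factor applied to the low-rank update

print(f"Trainable LoRA parameters: {trainable:,}")
print(f"LoRA scaling (alpha / r): {scaling}")
```

Under these assumptions the adapters come to roughly 49 million parameters, a small fraction of the 3 billion in the base model, and `lora_alpha` equal to `r` gives a scaling factor of 1.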

In [None]:
# Prepare datasets
print("\nLoading A-Agent datasets...")

a_train_dataset = Dataset.from_list(a_train)
a_val_dataset = Dataset.from_list(a_val)

print(f"Train: {len(a_train_dataset)} examples")
print(f"Val: {len(a_val_dataset)} examples")

# Set chat template for Llama
a_tokenizer = get_chat_template(a_tokenizer, chat_template="llama-3.1")

if a_tokenizer.pad_token is None:
    a_tokenizer.pad_token = a_tokenizer.eos_token
    a_tokenizer.pad_token_id = a_tokenizer.eos_token_id

# Formatting function
def formatting_prompts_func_a(examples):
    convos = examples["conversations"]
    texts = []
    for convo in convos:
        if isinstance(convo, list):
            text = a_tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
            texts.append(text)
    return {"text": texts}

print("\nFormatting datasets...")

a_train_dataset = standardize_sharegpt(a_train_dataset)
a_train_dataset = a_train_dataset.map(formatting_prompts_func_a, batched=True, remove_columns=a_train_dataset.column_names)
a_train_dataset = a_train_dataset.filter(lambda x: len(x["text"].strip()) > 0)

a_val_dataset = standardize_sharegpt(a_val_dataset)
a_val_dataset = a_val_dataset.map(formatting_prompts_func_a, batched=True, remove_columns=a_val_dataset.column_names)
a_val_dataset = a_val_dataset.filter(lambda x: len(x["text"].strip()) > 0)

print(f"Formatted {len(a_train_dataset)} train + {len(a_val_dataset)} val examples")

In [None]:
# Setup SFT trainer
print("\nSetting up A-Agent SFT trainer...")

a_output_dir = MODELS_DIR / "a_agent_llama"
a_output_dir.mkdir(exist_ok=True)

a_trainer = SFTTrainer(
    model=a_model,
    tokenizer=a_tokenizer,
    train_dataset=a_train_dataset,
    eval_dataset=a_val_dataset,
    dataset_text_field="text",
    max_seq_length=A_AGENT_CONFIG["max_seq_length"],
    data_collator=DataCollatorForSeq2Seq(tokenizer=a_tokenizer, padding=True),
    packing=False,
    args=SFTConfig(
        per_device_train_batch_size=A_AGENT_CONFIG["batch_size"],
        per_device_eval_batch_size=A_AGENT_CONFIG["batch_size"],
        gradient_accumulation_steps=A_AGENT_CONFIG["gradient_accumulation"],
        warmup_steps=A_AGENT_CONFIG["warmup_steps"],
        num_train_epochs=A_AGENT_CONFIG["num_epochs"],
        learning_rate=A_AGENT_CONFIG["learning_rate"],
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        seed=3407,
        output_dir=str(a_output_dir / "checkpoints"),
        report_to="none",
        bf16=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
    ),
)

print("Trainer configured")
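
The trainer combines `warmup_steps` with a cosine scheduler: the learning rate rises linearly to its peak and then decays along half a cosine wave. The sketch below reproduces that shape in plain Python for intuition. The function name `lr_at_step` and the total of 171 steps are hypothetical; the real step count depends on the dataset size.

```python
import math


def lr_at_step(step, total_steps, warmup_steps=10, peak_lr=2e-4):
    """Linear warmup to peak_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


TOTAL_STEPS = 171  # assumed for illustration
for step in (0, 5, 10, 90, 171):
    print(f"step {step:3d}: lr = {lr_at_step(step, TOTAL_STEPS):.2e}")
```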

In [None]:
# Train A-Agent Model (SFT Stage)
print("\nSTAGE 1: SFT TRAINING A-AGENT MODEL...\n")

FastLanguageModel.for_training(a_model)
a_trainer.train()

print("\nA-Agent SFT Training Complete")

In [None]:
# Save SFT model
print("\nSaving A-Agent SFT model...")

a_model.save_pretrained(str(a_output_dir / "sft_lora"))
a_tokenizer.save_pretrained(str(a_output_dir / "sft_lora"))
print(f"LoRA adapters saved: {a_output_dir / 'sft_lora'}")

a_model.save_pretrained_merged(str(a_output_dir / "sft_merged_16bit"), a_tokenizer, save_method="merged_16bit")
print(f"Merged 16bit saved: {a_output_dir / 'sft_merged_16bit'}")

In [None]:
# Prepare GRPO dataset
print("\nPreparing GRPO dataset...")

grpo_dataset = []

for q in all_questions:
    grpo_dataset.append({
        "prompt": [
            {"role": "system", "content": A_AGENT_SYSTEM_PROMPT},
            {"role": "user", "content": format_question_with_choices(q.get('question', ''), q['choices'])}
        ],
        "ground_truth_answer": q['answer'],
        # 'question' and 'reasoning' are not validated at load time, so use .get() as in the SFT cell
        "ground_truth_reasoning": q.get('reasoning', '')
    })

grpo_dataset = Dataset.from_list(grpo_dataset)
print(f"GRPO dataset: {len(grpo_dataset)} examples")

In [None]:
# GRPO Reward Functions for A-Agent
print("\nDefining GRPO reward functions...")

# Regex patterns for parsing
answer_pattern = re.compile(r'answer:\s*\"([A-D])\"', re.IGNORECASE)
reasoning_pattern = re.compile(r'reasoning:\s*\"([^\"]+)\"', re.IGNORECASE)
step_pattern = re.compile(r'Step\s+(\d+):', re.IGNORECASE)

# Reward Function 1: Answer Correctness
def reward_answer_correctness(prompts, completions, ground_truth_answer, **kwargs):
    """Reward if predicted answer matches ground truth"""
    scores = []
    
    for completion, true_answer in zip(completions, ground_truth_answer):
        response = completion[0]["content"]
        
        # Extract answer
        match = answer_pattern.search(response)
        if match:
            predicted_answer = match.group(1).upper()
            # High reward for correct answer
            score = 3.0 if predicted_answer == true_answer else -1.5
        else:
            # Penalty for not following format
            score = -2.0
        
        scores.append(score)
    
    return scores

# Reward Function 2: Reasoning Quality (Step 1-5)
def reward_reasoning_quality(prompts, completions, **kwargs):
    """Reward if reasoning contains Step 1 through Step 5"""
    scores = []
    
    for completion in completions:
        response = completion[0]["content"]
        score = 0.0
        
        # Extract reasoning
        match = reasoning_pattern.search(response)
        if match:
            reasoning = match.group(1)
            
            # Find all steps
            steps_found = step_pattern.findall(reasoning)
            unique_steps = set(steps_found)
            
            # Check that steps 1-5 are all present (extra step numbers are not penalized)
            expected_steps = {'1', '2', '3', '4', '5'}
            
            if expected_steps.issubset(unique_steps):
                # Perfect - all 5 steps present
                score = 2.0
            elif len(unique_steps) >= 3:
                # At least 3 steps, partial credit
                score = 0.5
            else:
                # Too few steps
                score = -1.0
        else:
            # No reasoning found
            score = -1.5
        
        scores.append(score)
    
    return scores

# Reward Function 3: Format Adherence
def reward_format_adherence(prompts, completions, **kwargs):
    """Reward if response follows the exact format"""
    scores = []
    
    for completion in completions:
        response = completion[0]["content"]
        score = 0.0
        
        # Check if both answer and reasoning are present
        has_answer = answer_pattern.search(response) is not None
        has_reasoning = reasoning_pattern.search(response) is not None
        
        if has_answer and has_reasoning:
            score = 1.0
        elif has_answer or has_reasoning:
            score = 0.3
        else:
            score = -1.0
        
        scores.append(score)
    
    return scores

print("GRPO reward functions defined:")
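print("  1. Answer Correctness (+3.0 correct, -1.5 wrong, -2.0 unparseable)")
print("  2. Reasoning Quality - Step 1-5 (+2.0 all five, +0.5 three or more, -1.0 fewer, -1.5 missing)")
print("  3. Format Adherence (+1.0 both fields, +0.3 one field, -1.0 neither)")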
print("  Total max reward: 6.0 points")
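
GRPO does not use these scores in absolute terms. For each prompt it samples a group of completions and rewards each one relative to the rest of its group. The sketch below shows the usual group normalization (subtract the group mean, divide by the group standard deviation) in plain Python. It is a simplified illustration, not TRL's implementation; the `eps` value and the use of the sample standard deviation are assumptions, and library versions may differ in the details. The example rewards are summed totals from the three functions above.

```python
import statistics


def group_advantages(rewards, eps=1e-4):
    """Normalize the rewards of one prompt's completions within their group."""
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards) if len(rewards) > 1 else 0.0
    return [(r - mean) / (std + eps) for r in rewards]


# Summed rewards for four hypothetical completions of the same prompt:
#   6.0  correct answer, all five steps, both fields present
#   4.5  correct answer, three or four steps, both fields present
#   1.5  wrong answer, all five steps, both fields present
#  -4.5  unparseable output
group = [6.0, 4.5, 1.5, -4.5]
advantages = group_advantages(group)

for reward, adv in zip(group, advantages):
    print(f"reward {reward:5.1f} -> advantage {adv:+.3f}")
```

When every completion in a group earns the same reward, all advantages are zero and that prompt contributes no learning signal, which is one reason to sample several generations per prompt.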

In [None]:
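# Get max prompt length for GRPO
max_seq_length = A_AGENT_CONFIG["max_seq_length"]

# Tokenize every prompt with the chat template, then record its token count
tokenized_prompts = grpo_dataset.map(
    lambda x: {"tokens": a_tokenizer.apply_chat_template(x["prompt"], add_generation_prompt=True, tokenize=True)},
    batched=True,
)
tokenized_prompts = tokenized_prompts.map(lambda x: {"length": len(x["tokens"])})
max_prompt_len = max(tokenized_prompts["length"])

max_prompt_length = max_prompt_len + 10
max_completion_length = max_seq_length - max_prompt_length

# The long system prompt uses most of the budget, so make sure room is left for the completion
assert max_completion_length > 0, "Prompts fill the whole context; increase max_seq_length"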

print(f"Max prompt length: {max_prompt_length}")
print(f"Max completion length: {max_completion_length}")

In [None]:
# GRPO configuration
grpo_output_dir = MODELS_DIR / "a_agent_llama_grpo"
grpo_output_dir.mkdir(exist_ok=True)

grpo_config = GRPOConfig(
    learning_rate=5e-6,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    logging_steps=1,
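    # TRL checks that the effective batch is divisible by num_generations; with a
    # per-device batch of 1, Unsloth may raise it to num_generations (check the startup log)
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=12,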
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    max_steps=500,
    save_steps=250,
    max_grad_norm=1.0,
    report_to="none",
    output_dir=str(grpo_output_dir / "checkpoints"),
)

# Create GRPO trainer
grpo_trainer = GRPOTrainer(
    model=a_model,
    processing_class=a_tokenizer,
    reward_funcs=[
        reward_answer_correctness,
        reward_reasoning_quality,
        reward_format_adherence,
    ],
    args=grpo_config,
    train_dataset=grpo_dataset,
)

print("GRPO trainer configured")

In [None]:
# Train with GRPO
print("\nSTAGE 2: GRPO TRAINING...\n")
print("Watch the reward column increase over steps\n")

grpo_trainer.train()

print("\nGRPO Training Complete")

In [None]:
# Save final model
print("\nSaving final A-Agent model...")

a_model.save_pretrained(str(grpo_output_dir / "lora"))
a_tokenizer.save_pretrained(str(grpo_output_dir / "lora"))
print(f"LoRA adapters saved: {grpo_output_dir / 'lora'}")

a_model.save_pretrained_merged(str(grpo_output_dir / "merged_16bit"), a_tokenizer, save_method="merged_16bit")
print(f"Merged 16bit saved: {grpo_output_dir / 'merged_16bit'}")

print(f"\nA-Agent model ready: {grpo_output_dir}")

In [None]:
# Test A-Agent
print("\nTesting A-Agent Model...\n")

FastLanguageModel.for_inference(a_model)

test_questions = [
    {
        "topic": "blood_relations",
        "question": "Pointing to a woman, a man said, 'She is the daughter of my grandfather's only son.' How is the woman related to the man?",
        "choices": [
            "A) Sister",
            "B) Cousin",
            "C) Aunt",
            "D) Mother"
        ],
        "correct_answer": "A"
    },
    {
        "topic": "seating_arrangement",
        "question": "Six friends A, B, C, D, E, F are sitting in a circle facing the center. A is to the immediate right of B. C is between D and E. F is not next to B. Who is to the immediate right of A?",
        "choices": [
            "A) C",
            "B) D",
            "C) E",
            "D) F"
        ],
        # B is A's left-hand neighbour by the first clue; D-C-E must start next to B, which leaves F beside A
        "correct_answer": "D"
    }
]

for i, test_q in enumerate(test_questions, 1):
    print(f"\n{'='*60}")
    print(f"Test Question {i}")
    print('='*60)

    question_text = format_question_with_choices(test_q["question"], test_q["choices"])

    messages = [
        {"role": "system", "content": A_AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": question_text}
    ]

    prompt = a_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = a_tokenizer(prompt, return_tensors="pt").to(a_model.device)

    outputs = a_model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.3,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.2,
        pad_token_id=a_tokenizer.eos_token_id
    )

    response = a_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

    print(f"\nQuestion: {test_q['question']}")
    print(f"\nGenerated Answer:\n")
    print(response)

    # Extract answer and reasoning
    answer_match = re.search(r'answer:\s*\"([A-D])\"', response)
    reasoning_match = re.search(r'reasoning:\s*\"([^\"]+)\"', response)
    
    if answer_match:
        predicted = answer_match.group(1)
        correct = test_q["correct_answer"]
        status = "MATCH" if predicted == correct else "WRONG"
        print(f"\nPredicted: {predicted} | Correct: {correct} | {status}")
    else:
        print("\nCould not extract answer")
    
    if reasoning_match:
        print(f"Has reasoning: Yes")
    else:
        print("Could not extract reasoning")

print(f"\n{'='*60}")
print("A-Agent Testing Complete")
print('='*60)
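
Two hand-written questions give only a spot check. A fuller picture comes from scoring many responses at once, for example generations for the held-out validation questions. The sketch below is a minimal scorer that reports accuracy and the share of responses whose answer line could be parsed. The function name `score_predictions` and the sample responses are hypothetical; the regular expression is the same answer pattern used in the test loop above.

```python
import re

ANSWER_RE = re.compile(r'answer:\s*"([A-D])"')


def score_predictions(responses, gold_answers):
    """Return accuracy and parse rate for a batch of raw model responses."""
    correct = 0
    parsed = 0
    for response, gold in zip(responses, gold_answers):
        match = ANSWER_RE.search(response)
        if match:
            parsed += 1
            if match.group(1) == gold:
                correct += 1
    total = len(gold_answers)
    if total == 0:
        return {"accuracy": 0.0, "parse_rate": 0.0}
    return {"accuracy": correct / total, "parse_rate": parsed / total}


# Hypothetical responses, for illustration only
responses = [
    'answer: "A"\nreasoning: "Step 1: The only son of the grandfather is the father."',
    'answer: "C"\nreasoning: "Step 1: Fix the facing direction."',
    "I think the answer is B.",
]
gold = ["A", "B", "B"]

print(score_predictions(responses, gold))
```

Reporting the parse rate separately helps tell format failures apart from reasoning errors.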