In [1]:
import os
# Force unsloth to use the local GPU memory efficiently
# os.environ["UNSLOTH_RETURN_LOGITS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

In [2]:
# run `setup_grpo_transformers.sh` first

# OR Install required packages (needed after Space restarts)
%pip install -q huggingface_hub transformers trl peft accelerate bitsandbytes datasets tensorboard dotenv
%pip install httpx==0.27.2

Note: you may need to restart the kernel to use updated packages.
Collecting httpx==0.27.2
  Downloading httpx-0.27.2-py3-none-any.whl (76 kB)
[K     |‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 76 kB 18.8 MB/s eta 0:00:01
Collecting sniffio
  Downloading sniffio-1.3.1-py3-none-any.whl (10 kB)
Installing collected packages: sniffio, httpx
  Attempting uninstall: httpx
    Found existing installation: httpx 0.28.1
    Uninstalling httpx-0.28.1:
      Successfully uninstalled httpx-0.28.1
Successfully installed httpx-0.27.2 sniffio-1.3.1
Note: you may need to restart the kernel to use updated packages.


### Local login, not for use with spaces

### Server-Side HF Login

In [3]:
# Set remote HF_TOKEN from local .env
import os
from dotenv import load_dotenv

load_dotenv()
hf_token = os.getenv('HF_TOKEN')

# ssh -i ~/.ssh/id_ed25519 dataimaginations-heirarchical-reasoning@ssh.hf.space "echo 'export HF_TOKEN={hf_token}' >> ~/.bashrc"
print("‚úÖ Token set! Restart remote shell to activate.")

‚úÖ Token set! Restart remote shell to activate.


In [4]:
import os
from huggingface_hub import login

# Login using your HF token
hf_token = os.getenv('HF_TOKEN')  # Try environment variable first

if hf_token:
    # login(token=hf_token)
    login()
    print("‚úÖ Logged in with HF_TOKEN environment variable")
else:
    # If no env var, prompt for token (you'll need to paste it)
    login()
    print("‚úÖ Logged in interactively")

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv‚Ä¶

‚úÖ Logged in with HF_TOKEN environment variable


In [5]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType
from trl import GRPOConfig, GRPOTrainer

# --- CONFIGURATION ---
MODEL_NAME = "meta-llama/Llama-3.2-1B" # "google/gemma-2-2b-it"
output_dir = "llama-1b-reasoning-v1"

print("‚è≥ Loading model in 4-bit...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    # No attn_implementation - let it auto-select
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Fix padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id

print(f"‚úÖ Loaded {MODEL_NAME}")# Changed from flash_attention_2


  from trl import GRPOConfig, GRPOTrainer


‚è≥ Loading model in 4-bit...


config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]

‚úÖ Loaded meta-llama/Llama-3.2-1B


In [6]:
# Force the input embeddings to require grads (connects the frozen model to the trainable adapters)
model.enable_input_require_grads()

# Prepare Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.padding_side = "left" # CRITICAL for reasoning/generation steps!
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [7]:
# Fix tokenizer padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id

print(f"‚úÖ Pad token set to: {tokenizer.pad_token}")
print(f"‚úÖ Vocab size: {len(tokenizer)}")

# Gemma3 has a nested config structure (vision + text)
if hasattr(model.config, 'text_config'):
    print(f"‚úÖ Model vocab size: {model.config.text_config.vocab_size}")
else:
    print(f"‚úÖ Model vocab size: {model.config.vocab_size}")

‚úÖ Pad token set to: <|end_of_text|>
‚úÖ Vocab size: 128256
‚úÖ Model vocab size: 128256


In [8]:
from trl import GRPOTrainer
import torch
import torch.nn as nn

class HICRAGRPOTrainer(GRPOTrainer):
    def __init__(self, hicra_manager, tokenizer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hicra_manager = hicra_manager
        self.tokenizer = tokenizer

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Overridden to inject HICRA Token-Level Advantage Modification.
        """
        # 1. Standard GRPO Forward Pass (to get completions and per-token KL)
        # We assume the parent class handles the rollout generation and padding
        # This part relies on TRL's internal structure. 
        # Since TRL 0.8+ is complex, we will hook into the advantage calculation step 
        # if possible, or post-process the advantages.
        
        # NOTE: For stability in a Notebook, we will use the standard parent compute_loss
        # but we will use a PyTorch hook to modify the advantages on the fly 
        # if the library supports it. 
        
        # However, since simpler is better for crashing notebooks, let's try a 
        # cleaner approach: The "HICRA Loss Wrapper".
        
        return super().compute_loss(model, inputs, return_outputs)

    # --- THE MAGIC METHOD ---
    # TRL's GRPOTrainer usually computes advantages internally. 
    # To implement HICRA properly without breaking the library versions,
    # we can pass the HICRA logic as a "Reward Function" that returns
    # token-level scores, but GRPO expects scalar rewards.
    
    # STRATEGY CHANGE:
    # Since subclassing the massive compute_loss is prone to breaking with version updates,
    # we will use the HICRA_Manager you have, but apply it as a "Weighted Mask" 
    # on the loss, which is mathematically very similar.
    
    pass 

In [9]:
# Cell 3: The HICRA Logic (Strategic Grams)
import torch
import re
from tqdm import tqdm
from torch.optim import AdamW
import torch.nn.functional as F

# Global for compatibility with the trainer's reward_funcs if needed
STRATEGIC_GRAMS = [
    "first i need to", "let's look at", "alternatively", "wait", 
    "but i'm not sure", "let's see if", "notice that", 
    "the final answer is", "let's assume", "we can canclude",
    "implies that", "to solve this", "break it down", 
    "suppose that", "checking the", "recall that"
]

class HICRA_Manager:
    def __init__(self, alpha=0.2):
        """
        Args:
            alpha (float): The amplification factor (paper uses 0.2)[cite: 574].
        """
        self.alpha = alpha
        
        # A subset of "Strategic Grams" from Listing 1 in the paper.
        # These function as 'planning tokens'.
        self.strategic_grams = [
            "first i need to", "let's look at", "alternatively", "wait", 
            "but i'm not sure", "let's see if", "notice that", 
            "the final answer is", "let's assume", "we can canclude",
            "implies that", "to solve this", "break it down", 
            "suppose that", "checking the", "recall that"
        ]
        
    def identify_planning_mask(self, input_ids, tokenizer):
        """
        Identifies which tokens in a sequence belong to a 'Strategic Gram'.
        Returns a boolean mask of shape (batch_size, seq_len).
        """
        batch_size, seq_len = input_ids.shape
        planning_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        
        # Decode to text to find phrases (simplified approach)
        texts = tokenizer.batch_decode(input_ids, skip_special_tokens=False)
        
        for b_idx, text in enumerate(texts):
            # Simple heuristic: Check if the text contains the gram
            for gram in self.strategic_grams:
                # Find all occurrences of the gram in the text
                for match in re.finditer(re.escape(gram), text, re.IGNORECASE):
                    start_char, end_char = match.span()
                    
                    # Convert char positions to token positions using the tokenizer
                    encoding = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
                    offsets = encoding['offset_mapping']
                    
                    for t_idx, (t_start, t_end) in enumerate(offsets):
                        # If token falls within the matched gram
                        if t_end > start_char and t_start < end_char:
                            # Handle truncation/padding indices safely
                            if t_idx < seq_len:
                                planning_mask[b_idx, t_idx] = True
                                
        return planning_mask

    def compute_hicra_advantages(self, advantages, planning_mask):
        """
        Applies the HICRA modification to the advantages.
        """
        hicra_adjustment = self.alpha * advantages.abs()
        
        # Apply mask: only adjust planning tokens
        hicra_adjustment = hicra_adjustment * planning_mask.float()
        
        # Final HICRA advantages
        hicra_advantages = advantages + hicra_adjustment
        
        return hicra_advantages

# Keep helper functions for potential backward compatibility
def correctness_reward_func(prompts, completions, answer, **kwargs):
    rewards = []
    for completion in completions:
        reward = 1.0 if str(answer).strip() in completion.strip() else 0.0
        rewards.append(reward)
    return rewards

def hicra_planning_reward_func(prompts, completions, **kwargs):
    rewards = []
    for completion in completions:
        score = 0.0
        completion_lower = completion.lower()
        for gram in STRATEGIC_GRAMS:
            if gram in completion_lower:
                score += 0.1 
        rewards.append(min(score, 0.5)) 
    return rewards


In [10]:
# 1. Load BOTH Datasets
from datasets import load_dataset

dataset_train = load_dataset("json", data_files="reasoning_dataset_v2_train.json", split="train", download_mode="force_redownload")
dataset_test = load_dataset("json", data_files="reasoning_dataset_v2_test.json", split="train", download_mode="force_redownload") # Load as 'train' split then rename if needed, but 'train' split works fine for just passing data
# For later: nvidia/Nemotron-Post-Training-Dataset-v1

# GRPO expects a specific format. We don't need a system prompt for simple math.
# It just needs 'prompt' and 'answer' (which we generated).
print(dataset_train[0])

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

{'prompt': 'In a survey of 200 students, it was found that x students like mathematics, y students like physics, and z students like chemistry. The number of students who like exactly two subjects is 45, and the number who like all three subjects is 12. If 25 students like both math and physics but not chemistry, 18 students like both physics and chemistry but not math, and 22 students like both math and chemistry but not physics, find the value of x + y + z given that exactly 30 students like none of the three subjects.', 'answer': '312'}


In [11]:
# Test your dataset format
print("Sample data:")
print(dataset_train[0])
print("\nKeys:", dataset_train[0].keys())
print("\nPrompt type:", type(dataset_train[0]['prompt']))
print("Answer type:", type(dataset_train[0]['answer']))

Sample data:
{'prompt': 'In a survey of 200 students, it was found that x students like mathematics, y students like physics, and z students like chemistry. The number of students who like exactly two subjects is 45, and the number who like all three subjects is 12. If 25 students like both math and physics but not chemistry, 18 students like both physics and chemistry but not math, and 22 students like both math and chemistry but not physics, find the value of x + y + z given that exactly 30 students like none of the three subjects.', 'answer': '312'}

Keys: dict_keys(['prompt', 'answer'])

Prompt type: <class 'str'>
Answer type: <class 'str'>


In [12]:
# Attach LoRA Adapters (PEFT)
print("üîó Attaching LoRA adapters...")
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
    bias="none",
)
# We wrap the model manually so GRPO knows it's a PEFT model
model = get_peft_model(model, peft_config)

üîó Attaching LoRA adapters...


In [13]:
# --- MEMORY OPTIMIZATION MONKEY PATCH ---
# This fixes the OOM error by processing log-probs in mini-batches
import trl.trainer.grpo_trainer
import torch

def get_per_token_logps_chunked(model, input_ids, num_logits_to_keep, mini_batch_size=1):
    per_token_logps = []
    batch_size = input_ids.size(0)
    for i in range(0, batch_size, mini_batch_size):
        batch_end = min(i + mini_batch_size, batch_size)
        mini_input_ids = input_ids[i:batch_end]
        # The standard implementation calculates all logits at once, causing OOM
        # We do it in chunks:
        mini_logits = model(mini_input_ids, num_logits_to_keep=num_logits_to_keep + 1).logits
        mini_logits = mini_logits[:, :-1, :] # exclude last logit

        # Compute log probs
        log_probs = mini_logits.log_softmax(dim=-1)
        labels = mini_input_ids[:, -num_logits_to_keep:].unsqueeze(2)
        token_log_prob = torch.gather(log_probs, dim=2, index=labels).squeeze(2)
        per_token_logps.append(token_log_prob)
    return torch.cat(per_token_logps, dim=0)

# Apply the patch to the library function
# We set mini_batch_size=1 to be extremely safe with memory
trl.trainer.grpo_trainer.get_per_token_logps = lambda model, input_ids, num_logits_to_keep: get_per_token_logps_chunked(model, input_ids, num_logits_to_keep, mini_batch_size=1)
print("‚úÖ Applied memory optimization patch to pooling layer!")


‚úÖ Applied memory optimization patch to pooling layer!


In [14]:
# Load the TensorBoard extension
# %load_ext tensorboard
# %tensorboard --logdir gemma-2-2b-reasoning-v2 --port 6006 --bind_all
# Start TensorBoard pointing to your output directory
# (Make sure 'gemma-3-reasoning-output' matches the 'output_dir' in your GRPOConfig!)
# %tensorboard --logdir gemma-2-2b-reasoning-output

# Define Training Arguments (GRPO)
training_args = GRPOConfig(
    output_dir="llama-1b-reasoning-v1",
    learning_rate=5e-6,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=2, # Reverted to 2 (must be divisible by num_generations)
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    max_prompt_length=256,
    max_completion_length=200, # Reduced to save memory
    num_generations=2, # Reduced to 2 to prevent OOM
    
    # --- STEP MATH SETTINGS ---
    max_steps=330, 
    warmup_steps=30, # ~10% of total steps
    
    # --- VALIDATION SETTINGS ---
    eval_strategy="steps",      # Check validation set every X steps
    eval_steps=10,              # Check every 50 steps
    save_steps=10,              # Save a checkpoint every 50 steps
    logging_steps=2,
    
    fp16=False,
    bf16=True,
    report_to="tensorboard"
)
# Force the input embeddings to require grads (connects the frozen model to the trainable adapters)
# model.enable_input_require_grads()

In [15]:
# 5. Initialize Trainer
# Note: We assume you still have your 'dataset' and reward functions from before
print("üöÄ Starting GRPO Trainer...")
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[correctness_reward_func, hicra_planning_reward_func],
    args=training_args,
    train_dataset=dataset_train, # Your 630 questions
    eval_dataset=dataset_test,   # Your 50 test questions <--- ADDED THIS
)

üöÄ Starting GRPO Trainer...


In [None]:
# 1. Setup HICRA Manager (Using your code)
import gc
hicra = HICRA_Manager(alpha=0.2) # Alpha from paper 

# 2. Optimizer
optimizer = AdamW(model.parameters(), lr=5e-6)

# 3. Manual Training Loop (Replaces trainer.train())
print("üöÄ Starting Manual HICRA Training Loop (Updated V2)...")

# CRITICAL STABILIZATION FIXES:
model.gradient_checkpointing_disable() # Llama-1B on L40S doesn't need this, and it causes crashes
model.train()

EPOCHS = 2
GRAD_ACCUM_STEPS = 4

for epoch in range(EPOCHS):
    for step, batch in enumerate(tqdm(dataset_train)):
        
        # A. Format Input
        prompt = batch['prompt']
        answer = batch['answer']
        
        # B. Generate Rollouts
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        
        try:
            with torch.no_grad():
                # Enable autocast for mixed precision (bf16) to avoid type mismatches
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=300,
                        do_sample=True,
                        temperature=0.9,
                        num_return_sequences=4,
                        pad_token_id=tokenizer.pad_token_id
                    )
            
            # C. Score the Outputs
            rewards = []
            generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            
            for text in generated_texts:
                r = 1.0 if str(answer) in text else -1.0
                rewards.append(r)
            
            rewards = torch.tensor(rewards, device=model.device)
            
            # D. Compute Standard GRPO Advantages
            mean_r = rewards.mean()
            std_r = rewards.std() + 1e-8
            advantages = (rewards - mean_r) / std_r
            
            # --- E. APPLY HICRA (Updated Method) ---
            # 1. Identify Planning Tokens
            planning_mask = hicra.identify_planning_mask(outputs, tokenizer).to(model.device)
            
            # 2. Modify Advantages
            token_advantages = advantages.view(-1, 1).expand_as(planning_mask).clone()
            
            # Use the class method for advantage adjustment
            final_advantages = hicra.compute_hicra_advantages(token_advantages, planning_mask)
            
            # F. Compute Policy Gradient Loss
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = model(outputs).logits
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = outputs[..., 1:].contiguous()
                log_probs = F.log_softmax(shift_logits, dim=-1)
                token_log_probs = torch.gather(log_probs, 2, shift_labels.unsqueeze(2)).squeeze(2)
                
                prompt_len = inputs.input_ids.shape[1]
                loss_mask = torch.ones_like(token_log_probs)
                loss_mask[:, :prompt_len] = 0
                
                # Mask alignment
                current_advantages = final_advantages[:, 1:]
                min_len = min(current_advantages.shape[1], token_log_probs.shape[1])
                current_advantages = current_advantages[:, :min_len]
                token_log_probs = token_log_probs[:, :min_len]
                loss_mask = loss_mask[:, :min_len]
                
                loss = - (current_advantages * token_log_probs * loss_mask).sum() / loss_mask.sum()
            
            # G. Optimization Step
            loss = loss / GRAD_ACCUM_STEPS
            loss.backward()
            
            if (step + 1) % GRAD_ACCUM_STEPS == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
                print(f"Step {step}: Loss={loss.item()*GRAD_ACCUM_STEPS:.4f}, Mean Reward={mean_r:.2f}")

        except Exception as e:
            print(f"Error at step {step}: {e}")
            # Don't break, try to clean up and continue (optional, but good for debugging)
            continue

        # H. VRAM Cleanup
        del outputs, logits, loss
        gc.collect()
        torch.cuda.empty_cache()


üöÄ Starting Manual HICRA Training Loop (Updated V2)...


  0%|‚ñã                                                                                                                                                                  | 3/729 [00:53<3:37:06, 17.94s/it]

Step 3: Loss=-2.4257, Mean Reward=-0.50


  1%|‚ñà‚ñå                                                                                                                                                                 | 7/729 [01:47<2:46:53, 13.87s/it]

: 

In [None]:
# Continue training from step 60 to step 180
trainer.args.max_steps = 180  # New target

# Resume from the last checkpoint
trainer_stats = trainer.train(resume_from_checkpoint=True)

In [None]:
# Continue training from step 60 to step 180
trainer.args.max_steps = 270  # New target

# Resume from the last checkpoint
trainer_stats = trainer.train(resume_from_checkpoint=True)

Set up the transformers inference API:

In [None]:
import torch
import os
import gc
from huggingface_hub import login

# --- 1. MEMORY CLEANUP (Crucial for Cloud) ---
# RL Training fills VRAM. We need to clear it before the heavy "Merge" step.
print("üßπ Cleaning up VRAM before merging...")
try:
    del trainer
    del batch
    gc.collect()
    torch.cuda.empty_cache()
except:
    pass

# --- 2. RELOAD MODEL FOR MERGING ---
# Sometimes it's safer to reload the base model + adapter freshly to merge
# independent of the messy training state.
from unsloth import FastLanguageModel

print("üîÑ Reloading model for clean merge...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "gemma-2-9b-reasoning-v1", # Your base model
    max_seq_length = 4096,
    dtype = None,
    load_in_4bit = True,
)

# Load the adapters you just trained
# Assuming your GRPOConfig output_dir was "gemma-reasoning-output"
# and the latest checkpoint is saved there.
from peft import PeftModel
model = PeftModel.from_pretrained(model, "gemma-reasoning-output/checkpoint-final") # Update path to your actual checkpoint folder!

# --- 3. LOGIN & PUSH ---
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("‚ö†Ô∏è No HF_TOKEN found! Check your Space 'Settings' -> 'Variables' to add it.")

repo_name = "DataImaginations/Gemma-2-9B-Reasoning-v1" # Repo name might be `DataImaginations/` or `david-barnes`

print(f"‚è≥ Merging to 16-bit and Pushing to: {repo_name}...")

# This takes care of the de-quantization and merging in one go
model.push_to_hub_merged(
    repo_name,
    tokenizer,
    save_method = "merged_16bit", # options: "merged_4bit", "merged_16bit"
    token = hf_token
)

print("‚úÖ Success! Your reasoning model is live.")

# Push Model to hub!

In [None]:
from unsloth import FastLanguageModel
import os
device = "cuda:0"

# 1. CONFIGURATION
# Point this to the exact folder on your disk
checkpoint_path = "outputs/checkpoint-180" 
repo_name = "DataImaginations/ministral-3B-Beancount-v1" # Your Hugging Face repo
hf_token = os.getenv('HF_TOKEN')

# 2. LOAD SPECIFIC CHECKPOINT
# Unsloth is smart: if you point it to a folder, it loads the base model 
# AND applies the adapters from that folder automatically.
print(f"üìÇ Loading checkpoint from {checkpoint_path}...")

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = checkpoint_path, 
    max_seq_length = 2048,
    dtype = None,
    load_in_4bit = True, # Keep True for fast loading (Unsloth handles the merge magic)
)

# 3. MERGE & PUSH
# This will de-quantize the base model, merge your checkpoint-180 adapters, 
# and upload a clean 16-bit model to the Hub.
print(f"üöÄ Merging and pushing to {repo_name}...")

model.push_to_hub_merged(
    repo_name,
    tokenizer,
    save_method = "merged_16bit", # options: "merged_4bit", "merged_16bit"
    token = hf_token
)

print("‚úÖ Done! Your Junior Accountant (Checkpoint 180) is live!")