# Unsloth GRPO Training for Hierarchical Reasoning

This notebook uses Unsloth's optimized kernels with TRL's GRPOTrainer for stable training.

**Key features:**
- ~50% less VRAM usage than standard transformers
- vLLM fast inference for generation
- HICRA-inspired reward functions for reasoning

In [15]:
# Cell 1: Environment Setup
import os
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"  # Extra 30% context lengths
os.environ["fix_mistral_regex"] = "True"
# os.environ["OMP_NUM_THREADS"] = "1"

# Install dependencies (run this if not already installed)
# !pip install unsloth vllm
# !pip install transformers==4.56.2
# !pip install --no-deps trl==0.22.2

In [2]:
# Cell 2: HuggingFace Login
import os
from huggingface_hub import login
from dotenv import load_dotenv

load_dotenv()
hf_token = os.getenv('HF_TOKEN')

if hf_token:
    login(token=hf_token)
    print("‚úÖ Logged in with HF_TOKEN")
else:
    login()
    print("‚úÖ Logged in interactively")

  from .autonotebook import tqdm as notebook_tqdm
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


‚úÖ Logged in with HF_TOKEN


In [3]:
from unsloth import FastLanguageModel
import torch
# Configuration
max_seq_length = 1024  # Can increase for longer reasoning traces
lora_rank = 32  # Larger rank = smarter, but slower
print("‚è≥ Loading model with Unsloth...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    fast_inference=False,  # Disabled - requires CUDA toolkit for vLLM
)
print("üîó Attaching LoRA adapters...")
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # Optimized gradient checkpointing
    random_state=3407,
)
print("‚úÖ Model loaded successfully!")

ü¶• Unsloth: Will patch your computer to enable 2x faster free finetuning.
ü¶• Unsloth Zoo will now patch everything to make training faster!
‚è≥ Loading model with Unsloth...
==((====))==  Unsloth 2025.12.8: Fast Llama patching. Transformers: 4.57.3. vLLM: 0.13.0.
   \\   /|    NVIDIA GeForce RTX 4070 SUPER. Num GPUs = 1. Max memory: 11.594 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
üîó Attaching LoRA adapters...


Unsloth 2025.12.8 patched 16 layers with 16 QKV layers, 16 O layers and 16 MLP layers.


‚úÖ Model loaded successfully!


In [4]:
# Cell 4: Load Dataset
from datasets import load_dataset

# System prompt for reasoning format
SYSTEM_PROMPT = """
You are a mathematical reasoning assistant. Think through problems step by step.
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

def format_prompt(example):
    """Format dataset for GRPO training with chat template."""
    return {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT.strip()},
            {'role': 'user', 'content': example['prompt']}
        ],
        'answer': str(example['answer'])  # Ensure string type
    }

# Load datasets
print("üìÇ Loading datasets...")
dataset_train = load_dataset(
    "json", 
    data_files="reasoning_dataset_v2_train.json", 
    split="train"
)
dataset_test = load_dataset(
    "json", 
    data_files="reasoning_dataset_v2_test.json", 
    split="train"
)

# Format for GRPO
dataset_train = dataset_train.map(format_prompt)
dataset_test = dataset_test.map(format_prompt)

print(f"‚úÖ Loaded {len(dataset_train)} training examples")
print(f"‚úÖ Loaded {len(dataset_test)} test examples")
print(f"\nSample prompt format:")
print(dataset_train[0]['prompt'])

üìÇ Loading datasets...
‚úÖ Loaded 729 training examples
‚úÖ Loaded 36 test examples

Sample prompt format:
[{'content': 'You are a mathematical reasoning assistant. Think through problems step by step.\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>', 'role': 'system'}, {'content': 'In a survey of 200 students, it was found that x students like mathematics, y students like physics, and z students like chemistry. The number of students who like exactly two subjects is 45, and the number who like all three subjects is 12. If 25 students like both math and physics but not chemistry, 18 students like both physics and chemistry but not math, and 22 students like both math and chemistry but not physics, find the value of x + y + z given that exactly 30 students like none of the three subjects.', 'role': 'user'}]


In [12]:
# Cell 5: Reward Functions
import re

# Strategic reasoning phrases (from HICRA paper)
STRATEGIC_GRAMS = [
    "first i need to", "let's look at", "alternatively", "wait",
    "but i'm not sure", "let's see if", "notice that",
    "the final answer is", "let's assume", "we can conclude",
    "implies that", "to solve this", "break it down",
    "suppose that", "checking the", "recall that",
    "step 1", "step 2", "therefore", "thus"
]

def extract_xml_answer(text: str) -> str:
    """Extract answer from <answer> tags."""
    if "<answer>" not in text:
        return text.strip()
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """
    Check if the model's answer matches the expected answer.
    Returns 2.0 for correct, 0.0 for incorrect.
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted = [extract_xml_answer(r) for r in responses]
    
    # Debug output (first item only)
    q = prompts[0][-1]['content'][:100]  # First 100 chars of question
    print(f"---\nQ: {q}...\nExpected: {answer[0]}\nExtracted: {extracted[0][:50]}...")
    
    rewards = []
    for ext, ans in zip(extracted, answer):
        # Check if answer appears in extracted text
        if str(ans).strip() in ext:
            rewards.append(2.0)
        else:
            rewards.append(0.0)
    return rewards

def reasoning_reward_func(completions, **kwargs) -> list[float]:
    """
    HICRA-inspired reward for reasoning structure.
    Gives bonus for using strategic reasoning phrases.
    """
    responses = [completion[0]['content'] for completion in completions]
    rewards = []
    
    for response in responses:
        score = 0.0
        response_lower = response.lower()
        
        # Check for strategic grams
        for gram in STRATEGIC_GRAMS:
            if gram in response_lower:
                score += 0.05
        
        # Bonus for using reasoning tags
        if "<reasoning>" in response and "</reasoning>" in response:
            score += 0.2
        if "<answer>" in response and "</answer>" in response:
            score += 0.1
        
        # Cap the reward
        rewards.append(min(score, 0.5))
    
    return rewards

def format_reward_func(completions, **kwargs) -> list[float]:
    """Reward for correct XML format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]['content'] for completion in completions]
    return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]

print("‚úÖ Reward functions defined")

‚úÖ Reward functions defined


### Chat Template (Save for Base models)

```
# Set Llama 3 chat template (required for GRPO with conversational data)
tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'system' %}<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{{ message['content'] }}<|eot_id|>{% elif message['role'] == 'user' %}<|start_header_id|>user<|end_header_id|}
{{ message['content'] }}<|eot_id|>{% elif message['role'] == 'assistant' %}<|start_header_id|>assistant<|end_header_id|>
{{ message['content'] }}<|eot_id|>{% endif %}{% endfor %}{% if add_generation_prompt %}<|start_header_id|>assistant<|end_header_id|>
{% endif %}"""
print("‚úÖ Chat template set!")
```

In [6]:
# Cell 6: Training Configuration
from trl import GRPOConfig, GRPOTrainer

max_prompt_length = 256

training_args = GRPOConfig(
    # Optimization
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",  # Memory efficient optimizer
    
    # Batch settings
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=4,  # Reduce to 2 if OOM
    
    # Sequence lengths
    max_prompt_length=max_prompt_length,
    max_completion_length=max_seq_length - max_prompt_length,
    
    # Training duration
    max_steps=250,
    
    # Stability
    max_grad_norm=0.1,  # Aggressive gradient clipping
    
    # Logging & Saving
    logging_steps=1,
    save_steps=50,
    output_dir="llama-1b-reasoning-unsloth",
    report_to="none",  # Change to "tensorboard" or "wandb" if desired
)

print("‚úÖ Training configuration set")

‚úÖ Training configuration set


In [7]:
# Cell 7: Initialize Trainer
print("üöÄ Initializing GRPO Trainer...")

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        correctness_reward_func,
        reasoning_reward_func,
        format_reward_func,
    ],
    args=training_args,
    train_dataset=dataset_train,
)

print("‚úÖ Trainer initialized!")

üöÄ Initializing GRPO Trainer...
‚úÖ Trainer initialized!


In [8]:
# Cell 8: Run Training!
print("üèãÔ∏è Starting training...")
print("Note: First ~100 steps may show 0 reward. Be patient!")
print("="*50)

trainer_stats = trainer.train()

print("="*50)
print("‚úÖ Training complete!")

The model is already on multiple devices. Skipping the move to device specified in `args`.


üèãÔ∏è Starting training...
Note: First ~100 steps may show 0 reward. Be patient!


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 729 | Num Epochs = 1 | Total steps = 250
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4
 "-____-"     Trainable parameters = 22,544,384 of 1,258,358,784 (1.79% trained)


Unsloth: Will smartly offload gradients to save VRAM!
---
Q: A buffet has 24 plates. Fast eaters take 2 plates each and finish in 10 minutes, while average eater...
Expected: 10
Extracted: To find out how many plates are being used by the ...


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,sampling / sampling_logp_difference / mean,sampling / sampling_logp_difference / max,sampling / importance_sampling_ratio / min,sampling / importance_sampling_ratio / mean,sampling / importance_sampling_ratio / max,kl,rewards / correctness_reward_func / mean,rewards / correctness_reward_func / std,rewards / reasoning_reward_func / mean,rewards / reasoning_reward_func / std,rewards / format_reward_func / mean,rewards / format_reward_func / std
1,0.0,1.575,0.984463,456.0,236.0,768.0,0.25,352.0,236.0,569.0,0,0,0,0,0,0.000767,1.5,1.0,0.075,0.05,0.0,0.0
2,0.0,0.0,0.0,649.75,440.0,768.0,0.25,610.333374,440.0,739.0,No Log,No Log,No Log,No Log,No Log,0.000404,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.5625,0.991947,548.0,227.0,768.0,0.5,328.0,227.0,429.0,No Log,No Log,No Log,No Log,No Log,0.000627,0.5,1.0,0.0625,0.025,0.0,0.0
4,0.0,0.0375,0.047871,507.75,210.0,768.0,0.25,421.0,210.0,741.0,No Log,No Log,No Log,No Log,No Log,0.000752,0.0,0.0,0.0375,0.047871,0.0,0.0
5,0.0,0.5125,0.991947,482.75,312.0,761.0,0.0,482.75,312.0,761.0,No Log,No Log,No Log,No Log,No Log,0.000681,0.5,1.0,0.0125,0.025,0.0,0.0
6,0.0,0.5375,1.008609,574.5,330.0,768.0,0.5,381.0,330.0,432.0,No Log,No Log,No Log,No Log,No Log,0.000683,0.5,1.0,0.0375,0.025,0.0,0.0
7,0.0,0.5,1.0,768.0,768.0,768.0,1.0,0.0,0.0,0.0,No Log,No Log,No Log,No Log,No Log,0.000281,0.5,1.0,0.0,0.0,0.0,0.0
8,0.0,0.05,0.0,686.5,532.0,768.0,0.5,605.0,532.0,678.0,No Log,No Log,No Log,No Log,No Log,0.000815,0.0,0.0,0.05,0.0,0.0,0.0
9,0.0,1.075,1.155061,621.75,402.0,768.0,0.5,475.5,402.0,549.0,No Log,No Log,No Log,No Log,No Log,0.000567,1.0,1.154701,0.075,0.028868,0.0,0.0
10,0.0,0.0625,0.025,386.75,202.0,768.0,0.25,259.666687,202.0,353.0,No Log,No Log,No Log,No Log,No Log,0.000627,0.0,0.0,0.0625,0.025,0.0,0.0


---
Q: Consider the set S = {n : n = k^3 + k^2 + k + 1 for k ‚àà {2, 3, 4, 5, 6}}. Four elements of S share a...
Expected: 1297
Extracted: To find the element that is the outlier, we first ...
---
Q: A cloud computing cluster has three server nodes (N1, N2, N3) that handle two job classes: Priority ...
Expected: 300
Extracted: Reasoning:
To find the number of Standard jobs pro...
---
Q: On a coordinate plane, a line segment connects points A(0,0) and B(10,0). A point P is chosen random...
Expected: 33.33
Extracted: To find the expected area of the square, we need t...
---
Q: In a quantum logic circuit, the state of a system is represented by a vector in R^3. Three possible ...
Expected: -2
Extracted: Since the system's state is represented by a vecto...
---
Q: In a two-player zero-sum game, Player 1 chooses a rate r‚ÇÅ ‚àà {1, 2, 3, 4} and Player 2 simultaneously...
Expected: 19
Extracted: To find p, where \(p = \frac{a}{b}\) indicates the...
---
Q: In the 'Bidding War' tournament, the

In [9]:
# Cell 9: Save Model
import os

# Option 1: Save locally
output_path = "llama-1b-reasoning-unsloth-HICRA-v1"
model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)
print(f"‚úÖ Model saved to {output_path}")

# Option 2: Push to HuggingFace Hub (uncomment to use)
# repo_name = "DataImaginations/Llama-1B-Reasoning-v1"
# hf_token = os.getenv('HF_TOKEN')
# 
# print(f"‚è≥ Pushing to {repo_name}...")
# model.push_to_hub_merged(
#     repo_name,
#     tokenizer,
#     save_method="merged_16bit",
#     token=hf_token
# )
# print("‚úÖ Model pushed to Hub!")

‚úÖ Model saved to llama-1b-reasoning-unsloth-HICRA-v1


In [1]:
# Cell: Merge LoRA adapters and save for evaluation
from unsloth import FastLanguageModel

# Load the adapter model
model, tokenizer = FastLanguageModel.from_pretrained(
    "llama-1b-reasoning-unsloth-HICRA-v1",
    max_seq_length=1024,
    load_in_4bit=True,
)

# Merge and save in 16-bit
print("‚è≥ Merging adapters...")
model.save_pretrained_merged(
    "llama-1b-reasoning-merged",  # New path for merged model
    tokenizer,
    save_method="merged_16bit",  # Full precision merged weights
)
print("‚úÖ Merged model saved!")

ü¶• Unsloth: Will patch your computer to enable 2x faster free finetuning.


  from .autonotebook import tqdm as notebook_tqdm


ü¶• Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.12.8: Fast Llama patching. Transformers: 4.57.3. vLLM: 0.13.0.
   \\   /|    NVIDIA GeForce RTX 4070 SUPER. Num GPUs = 1. Max memory: 11.594 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2025.12.8 patched 16 layers with 16 QKV layers, 16 O layers and 16 MLP layers.


‚è≥ Merging adapters...
Found HuggingFace hub cache directory: /home/david-barnes/.cache/huggingface/hub
Checking cache directory for required files...
Cache check failed: model.safetensors not found in local cache.
Not all required files found in cache. Will proceed with downloading.
Checking cache directory for required files...
Cache check failed: tokenizer.model not found in local cache.
Not all required files found in cache. Will proceed with downloading.


Unsloth: Preparing safetensor model files: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:33<00:00, 33.81s/it]


Note: tokenizer.model not found (this is OK for non-SentencePiece models)


Unsloth: Merging weights into 16bit: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:05<00:00,  5.09s/it]


Unsloth: Merge process complete. Saved to `/home/david-barnes/Documents/Projects/heirarch/llama-1b-reasoning-merged`
‚úÖ Merged model saved!


## Test the Trained Model

In [11]:
# Cell 10: Test Inference
from unsloth import FastLanguageModel

# Put model in inference mode
FastLanguageModel.for_inference(model)

# Test question
test_question = "A loan is repaid with 20 equal annual payments. The interest portion of the 16th payment is 400 and the interest portion of the 11th payment is 600. Find the interest portion of the 1st payment."

messages = [
    {"role": "system", "content": SYSTEM_PROMPT.strip()},
    {"role": "user", "content": test_question}
]

inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

outputs = model.generate(
    input_ids=inputs,
    max_new_tokens=256,
    temperature=0.7,
)

response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Question:", test_question)
print("\nResponse:")
print(response)

Question: A loan is repaid with 20 equal annual payments. The interest portion of the 16th payment is 400 and the interest portion of the 11th payment is 600. Find the interest portion of the 1st payment.

Response:
system

Cutting Knowledge Date: December 2023
Today Date: 26 Dec 2025

You are a mathematical reasoning assistant. Think through problems step by step.
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>user

A loan is repaid with 20 equal annual payments. The interest portion of the 16th payment is 400 and the interest portion of the 11th payment is 600. Find the interest portion of the 1st payment.assistant

To find the interest portion of the 1st payment, we first need to determine the interest portion of the 15th payment. Since the loan is repaid with 20 equal annual payments, we can use the formula for the total amount repaid:

Total Repaid = Principal + Interest

We know the interest portion of the 16th payment is 400 and the interest 

# Bench for unsloth_HICRA

In [14]:
import lm_eval
from lm_eval.models.huggingface import HFLM

# 1. Load your local model
# If you just saved your model to "llama-1b-reasoning-final", point to that folder.
print("‚è≥ Loading model for evaluation...")

# We wrap the model in the Harness's HFLM wrapper
# 'pretrained' can be a local path OR a Hub ID (e.g., "david-barnes/my-model")
llm = HFLM(
    pretrained="llama-1b-reasoning-merged",  # Use the merged model
    batch_size=1,
    trust_remote_code=True,
    dtype="bfloat16"
)

# 2. Define the tasks you want
# These key names correspond to the harness registry.
# Note: "minerva_math" is often split by subject (algebra, etc), 
# so we usually run the main "math" group or specific subtasks.
task_list = [
    "aime24",          # AIME 2024
    "minerva_math",    # Minerva Math (covers multiple subjects)
    # "math_500",        # The 'easy' 500 questions from MATH
    "leaderboard_gpqa_main",   # leaderboard_math_hard      
]

# 3. Run the Eval
print(f"üöÄ Running evaluation on: {task_list}...")
results = lm_eval.simple_evaluate(
    model=llm,
    tasks=task_list,
    num_fewshot=0,        # Reasoning models often prefer 0-shot (Instruction)
    limit=None,           # Set to e.g., 50 to test quickly before full run!
    log_samples=True,    # Set True if you want to see exactly what it got wrong
)

# 4. Print a Pretty Table
from lm_eval.utils import make_table
print(make_table(results))

# 5. Save detailed results to JSON (Crucial for your blog!)
import json
with open("llama_1b_unsloth_HICRA_v1_benchmark_results.json", "w") as f:
    json.dump(results, f, indent=2)

[2025-12-26 10:43:01] INFO huggingface.py:158: Using device 'cuda'


‚è≥ Loading model for evaluation...


The tokenizer you are loading from 'llama-1b-reasoning-merged' with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue.
[2025-12-26 10:43:02] INFO huggingface.py:420: Model parallel was set to False, max memory was not set, and device map was set to {'': 'cuda'}
[2025-12-26 10:43:02] INFO evaluator.py:202: Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
[2025-12-26 10:43:02] INFO evaluator.py:258: Using pre-initialized model


üöÄ Running evaluation on: ['aime24', 'minerva_math', 'leaderboard_gpqa_main']...


Generating train split: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 448/448 [00:00<00:00, 9875.17 examples/s]
Map: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 448/448 [00:00<00:00, 2259.56 examples/s]
[2025-12-26 10:43:15] INFO __init__.py:695: Selected tasks:
[2025-12-26 10:43:15] INFO __init__.py:686: Task: leaderboard_gpqa_main (leaderboard/gpqa/gpqa_main_zeroshot.yaml)
[2025-12-26 10:43:15] INFO __init__.py:698: Group: minerva_math
[2025-12-26 10:43:15] INFO __init__.py:712: ConfigurableGroup(group=minerva_math,group_alias=None): {'minerva_math_algebra': ConfigurableTask(task_name=minerva_math_algebra,output_type=generate_until,num_fewshot=4,num_samples=1187), 'minerva_math_counting_and_prob': ConfigurableTask(task_name=minerva_math_counting_and_prob,output_type=generate_until,num_fewshot=4,num_samples=474), 'minerva_math_geometry': ConfigurableTask(task_name=minerva_math_geometry,output_type=generate_until,num_fewshot=4,num_samples=479), 'minerva_math_intermediate_algebra': ConfigurableTask(tas

KeyboardInterrupt: 

# Soft VRAM clear

In [13]:
import torch
import gc

# 1. Delete the Python variables holding the model
# (Wrap in try/except so it doesn't crash if they are already gone)
try:
    del model
    del tokenizer
    del trainer
except NameError:
    print("Variables already deleted or not defined.")

# 2. Python Garbage Collection (Clears CPU RAM)
gc.collect()

# 3. PyTorch Cache Clearing (The most important step for VRAM)
torch.cuda.empty_cache()

# Verify: Print current memory usage
print(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"GPU Memory Reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

Variables already deleted or not defined.
GPU Memory Allocated: 2.30 GB
GPU Memory Reserved:  3.69 GB
