In [None]:
!pip uninstall -y transformers accelerate peft bitsandbytes datasets trl scipy triton

!pip install --upgrade transformers==4.41.2 -q
!pip install --upgrade peft==0.11.1 -q
!pip install --upgrade accelerate==0.30.1 -q
!pip install bitsandbytes -q
!pip install --upgrade datasets==2.19.1 -q
!pip install --upgrade trl==0.8.6 -q
!pip install --upgrade scipy -q
!pip install --upgrade triton -q

!pip install -q sentencepiece
!pip install -q einops

# Install bitsandbytes from alternative source for better compatibility with colab
!pip install bitsandbytes --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
# Restart runtime to apply changes
import os
os.kill(os.getpid(), 9)

Found existing installation: transformers 4.52.4
Uninstalling transformers-4.52.4:
  Successfully uninstalled transformers-4.52.4
Found existing installation: accelerate 1.8.1
Uninstalling accelerate-1.8.1:
  Successfully uninstalled accelerate-1.8.1
Found existing installation: peft 0.15.2
Uninstalling peft-0.15.2:
  Successfully uninstalled peft-0.15.2
[0mFound existing installation: datasets 3.6.0
Uninstalling datasets-3.6.0:
  Successfully uninstalled datasets-3.6.0
[0mFound existing installation: scipy 1.15.3
Uninstalling scipy-1.15.3:
  Successfully uninstalled scipy-1.15.3
Found existing installation: triton 3.2.0
Uninstalling triton-3.2.0:
  Successfully uninstalled triton-3.2.0
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.8/43.8 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.1/9.1 MB[0m [31m85.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [1]:

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    get_linear_schedule_with_warmup
)
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from trl import SFTTrainer
import json
import random
import re
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
import gc
from tqdm import tqdm
import time

# Configuration
@dataclass
class SmartReasonConfig:
    # Model settings
    model_name: str = "HuggingfaceTB/SmolLM2-360M-Instruct"

    # Training settings
    learning_rate: float = 1e-5
    batch_size: int = 32
    gradient_accumulation_steps: int = 2
    num_epochs: int = 2
    warmup_steps: int = 20

    # GRPO settings
    num_rollouts: int = 2
    buffer_size: int = 512
    ppo_epochs: int = 3
    clip_range: float = 0.2
    entropy_coef: float = 0.01

    # LoRA settings
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05

    # Generation settings
    max_new_tokens: int = 150
    temperature: float = 0.8
    top_p: float = 0.95

    # Task settings
    num_train_samples: int = 100
    num_eval_samples: int = 50

config = SmartReasonConfig()

2025-07-25 15:39:53.241856: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753457993.436186     163 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753457993.497003     163 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
class ReasoningTaskGenerator:
    """Generate diverse reasoning tasks for DPO training"""

    def __init__(self, seed=42):
        random.seed(seed)
        np.random.seed(seed)
        self.task_types = [
            self.generate_logical_deduction,
            self.generate_mathematical_reasoning,
            self.generate_pattern_completion,
            self.generate_causal_reasoning,
            self.generate_analogical_reasoning,
            self.generate_constraint_solving,
            self.generate_counterfactual_reasoning
        ]

    def generate_logical_deduction(self):
        """Multi-step logical deduction"""
        premises = [
            ("All {A} are {B}", "All {B} are {C}", "All {A} are {C}"),
            ("No {A} are {B}", "All {C} are {B}", "No {A} are {C}"),
            ("Some {A} are {B}", "All {B} are {C}", "Some {A} are {C}"),
        ]

        entities = [
            ("cats", "mammals", "animals"),
            ("birds", "flying creatures", "egg-layers"),
            ("roses", "flowers", "plants"),
            ("diamonds", "gems", "minerals")
        ]

        premise_template = random.choice(premises)
        entity_set = random.choice(entities)

        premises_text = []
        for i, template in enumerate(premise_template[:-1]):
            filled = template.format(A=entity_set[0], B=entity_set[1], C=entity_set[2])
            premises_text.append(filled)

        conclusion = premise_template[-1].format(A=entity_set[0], B=entity_set[1], C=entity_set[2])

        # Sometimes make it false
        if random.random() < 0.3:
            conclusion = conclusion.replace("are", "are not") if "are" in conclusion else conclusion.replace("All", "No")
            correct_answer = "No"
        else:
            correct_answer = "Yes"

        question = f"Given: {'. '.join(premises_text)}. Can we conclude: {conclusion}?"

        return {
            "question": question,
            "answer": correct_answer,
            "type": "logical_deduction",
            "difficulty": "medium"
        }

    def generate_mathematical_reasoning(self):
        """Math word problems requiring reasoning"""
        templates = [
            {
                "template": "A train travels {speed1} km/h for {time1} hours, then {speed2} km/h for {time2} hours. What is the average speed?",
                "solution": lambda v: (v['speed1']*v['time1'] + v['speed2']*v['time2'])/(v['time1']+v['time2'])
            },
            {
                "template": "If {workers1} workers can complete a job in {days1} days, how many days will {workers2} workers take?",
                "solution": lambda v: (v['workers1']*v['days1'])/v['workers2']
            },
            {
                "template": "A rectangle has perimeter {perimeter} and length {length}. What is its area?",
                "solution": lambda v: v['length'] * ((v['perimeter']/2) - v['length'])
            }
        ]

        problem = random.choice(templates)

        # Generate random values
        if "speed" in problem["template"]:
            values = {
                "speed1": random.randint(40, 80),
                "speed2": random.randint(60, 100),
                "time1": random.randint(2, 5),
                "time2": random.randint(1, 4)
            }
        elif "workers" in problem["template"]:
            values = {
                "workers1": random.randint(4, 12),
                "days1": random.randint(10, 30),
                "workers2": random.randint(6, 15)
            }
        else:
            length = random.randint(5, 15)
            width = random.randint(3, 10)
            values = {
                "perimeter": 2 * (length + width),
                "length": length
            }

        question = problem["template"].format(**values)
        answer = round(problem["solution"](values), 2)

        return {
            "question": question,
            "answer": str(answer),
            "type": "mathematical_reasoning",
            "difficulty": "hard"
        }

    def generate_pattern_completion(self):
        """Complex pattern recognition"""
        patterns = [
            {
                "name": "fibonacci",
                "rule": lambda seq: seq[-1] + seq[-2],
                "init": [1, 1]
            },
            {
                "name": "geometric",
                "rule": lambda seq: seq[-1] * 2,
                "init": [2, 4]
            },
            {
                "name": "arithmetic_sequence",
                "rule": lambda seq: seq[-1] + (seq[-1] - seq[-2]),
                "init": [3, 7]
            },
            {
                "name": "squares_plus_n",
                "rule": lambda seq: (len(seq)+1)**2 + (len(seq)+1),
                "init": [2, 6]
            }
        ]

        pattern = random.choice(patterns)
        sequence = pattern["init"].copy()

        # Generate sequence
        for _ in range(3):
            sequence.append(pattern["rule"](sequence))

        # Hide last element
        question = f"What comes next: {', '.join(map(str, sequence[:-1]))}, ?"
        answer = str(sequence[-1])

        return {
            "question": question,
            "answer": answer,
            "type": "pattern_completion",
            "pattern": pattern["name"]
        }

    def generate_causal_reasoning(self):
        """Cause and effect reasoning"""
        scenarios = [
            {
                "setup": "If the temperature drops below freezing and there's moisture in the air",
                "effect": "frost will form",
                "negation": "frost will not form"
            },
            {
                "setup": "If a plant doesn't get water for several weeks and is in direct sunlight",
                "effect": "the plant will wilt",
                "negation": "the plant will thrive"
            },
            {
                "setup": "If you increase the price of a product significantly and demand is elastic",
                "effect": "sales will decrease",
                "negation": "sales will increase"
            }
        ]

        scenario = random.choice(scenarios)

        # Create variations
        if random.random() < 0.5:
            question = f"{scenario['setup']}, what will happen?"
            answer = scenario['effect']
        else:
            # Counterfactual
            question = f"{scenario['setup']}, will {scenario['negation']}?"
            answer = "No"

        return {
            "question": question,
            "answer": answer,
            "type": "causal_reasoning"
        }

    def generate_analogical_reasoning(self):
        """Analogies and relationships"""
        analogies = [
            ("cat", "kitten", "dog", "puppy"),
            ("book", "page", "house", "room"),
            ("teacher", "student", "doctor", "patient"),
            ("key", "lock", "password", "account"),
            ("seed", "tree", "egg", "bird")
        ]

        analogy = random.choice(analogies)
        question = f"{analogy[0]} is to {analogy[1]} as {analogy[2]} is to ?"
        answer = analogy[3]

        return {
            "question": question,
            "answer": answer,
            "type": "analogical_reasoning"
        }

    def generate_constraint_solving(self):
        """Constraint satisfaction problems"""
        people = ["Alice", "Bob", "Charlie", "Diana", "Eve"]
        days = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"]

        # Select subset
        selected_people = random.sample(people, 3)
        selected_days = random.sample(days, 3)

        # Create constraints
        constraints = []
        solution = {}

        # Assign randomly first
        random.shuffle(selected_days)
        for i, person in enumerate(selected_people):
            solution[person] = selected_days[i]

        # Generate constraints that match solution
        person1, person2 = random.sample(selected_people, 2)
        day_index1 = selected_days.index(solution[person1])
        day_index2 = selected_days.index(solution[person2])

        if day_index1 < day_index2:
            constraints.append(f"{person1} must be before {person2}")
        else:
            constraints.append(f"{person2} must be before {person1}")

        # Add another constraint
        person = random.choice(selected_people)
        constraints.append(f"{person} is on {solution[person]}")

        question = f"Schedule {', '.join(selected_people)} on {', '.join(selected_days)}. Constraints: {'. '.join(constraints)}. When is {random.choice(selected_people)} scheduled?"
        answer = solution[random.choice(selected_people)]

        return {
            "question": question,
            "answer": answer,
            "type": "constraint_solving"
        }

    def generate_counterfactual_reasoning(self):
        """What-if scenarios"""
        scenarios = [
            {
                "fact": "The meeting was cancelled because the presenter was sick",
                "counterfactual": "If the presenter hadn't been sick",
                "result": "the meeting would have proceeded as scheduled"
            },
            {
                "fact": "The plant died because it wasn't watered",
                "counterfactual": "If the plant had been watered regularly",
                "result": "the plant would have survived"
            },
            {
                "fact": "The team lost because their best player was injured",
                "counterfactual": "If their best player hadn't been injured",
                "result": "the team might have won"
            }
        ]

        scenario = random.choice(scenarios)
        question = f"{scenario['fact']}. {scenario['counterfactual']}, what would have happened?"
        answer = scenario['result']

        return {
            "question": question,
            "answer": answer,
            "type": "counterfactual_reasoning"
        }

    def generate_task(self):
        """Generate a random task"""
        return random.choice(self.task_types)()

In [4]:
class ReasoningVerifier:
    """Verify reasoning answers"""

    def extract_answer(self, response: str) -> str:
        """Extract answer from response"""
        match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        if match:
            return match.group(1).strip()
        return ""

    def check_format(self, response: str) -> float:
        """Check if response follows the required format"""
        has_think = "<think>" in response and "</think>" in response
        has_answer = "<answer>" in response and "</answer>" in response

        if has_think and has_answer:
            # Check order
            think_start = response.find("<think>")
            think_end = response.find("</think>")
            answer_start = response.find("<answer>")
            answer_end = response.find("</answer>")

            if think_start < think_end < answer_start < answer_end:
                return 1.0

        return 0.0

    def score_response(self, task: Dict, response: str) -> Dict[str, float]:
        """Score a response on multiple criteria"""
        import difflib
        scores = {}

        # 1. Correctness (most important)
        answer = self.extract_answer(response)
        correct_answer = task["answer"].lower().strip()
        extracted_answer = answer.lower().strip()

        # Flexible matching for different answer types
        if task["type"] == "mathematical_reasoning":
            try:
                scores["correctness"] = 1.0 if abs(float(extracted_answer) - float(correct_answer)) < 0.01 else 0.0
            except:
                scores["correctness"] = 0.0
        else:
            similarity = difflib.SequenceMatcher(None, extracted_answer, correct_answer).ratio()
            scores["correctness"] = 1.0 if similarity > 0.9 else (similarity if similarity > 0.5 else 0.0)  # Partial if >50% match

        # 2. Format compliance
        scores["format"] = self.check_format(response)

        # 3. Reasoning quality (heuristic)
        think_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
        if think_match:
            reasoning = think_match.group(1)
            # Check for step-by-step indicators
            step_indicators = ["first", "second", "then", "therefore", "because", "step", "next"]
            step_count = sum(1 for indicator in step_indicators if indicator in reasoning.lower())
            scores["reasoning_quality"] = min(step_count / 3.0, 1.0)  # Normalize
        else:
            scores["reasoning_quality"] = 0.0

        # 4. Length penalty (prefer concise but complete)
        response_length = len(response.split())
        if response_length < 20:
            scores["length"] = 0.5  # Too short
        elif response_length > 200:
            scores["length"] = 0.7  # Too long
        else:
            scores["length"] = 1.0  # Just right

        # Calculate total score
        weights = {
            "correctness": 0.4,
            "format": 0.2,
            "reasoning_quality": 0.3,
            "length": 0.1
        }

        scores["total"] = sum(scores[k] * weights[k] for k in weights.keys())

        return scores

In [5]:
class ExperienceBuffer:
    """Store experiences for GRPO training"""

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.buffer = []

    def add(self, experience: Dict):
        """Add experience to buffer"""
        self.buffer.append(experience)
        if len(self.buffer) > self.max_size:
            self.buffer.pop(0)

    def sample(self, batch_size: int) -> List[Dict]:
        """Sample batch from buffer with priority (higher abs(advantage) first)"""
        if len(self.buffer) <= batch_size:
            return self.buffer[:]
        sorted_buffer = sorted(self.buffer, key=lambda x: abs(x['advantage']), reverse=True)
        return sorted_buffer[:batch_size]

    def clear(self):
        """Clear buffer"""
        self.buffer = []

    def __len__(self):
        return len(self.buffer)

In [6]:
class SmartReasonGRPO:
    def __init__(self, model, tokenizer, config: SmartReasonConfig):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.buffer = ExperienceBuffer(config.buffer_size)
        self.task_generator = ReasoningTaskGenerator()
        self.verifier = ReasoningVerifier()

        # System prompt
        self.system_prompt = """You are a logical reasoning assistant. When given a problem, you should:
1. Think through the problem step by step in <think> tags
2. Provide your final answer in <answer> tags

Format:
<think>
[Your reasoning process here]
</think>
<answer>
[Your final answer here]
</answer>"""

    def generate_reasoning_traces(self, question: str, num_samples: int) -> List[str]:
        """Generate multiple reasoning traces for a question"""
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": question}
        ]

        # Tokenize
        inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            add_generation_prompt=True
        ).to(self.model.device)

        if num_samples > 1:
            inputs = inputs.repeat(num_samples, 1)
    
        # Generate multiple responses
        with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):  # Enhancement for speed/stability
            outputs = self.model.generate(
                inputs,
                max_new_tokens=self.config.max_new_tokens,
                temperature=self.config.temperature,
                top_p=self.config.top_p,
                do_sample=True,
                num_return_sequences=num_samples if num_samples == 1 else 1,  # Handle batching
                pad_token_id=self.tokenizer.eos_token_id
        )

        # Decode
        responses = []
        start_idx = inputs.shape[1]
        for output in outputs:
            response = self.tokenizer.decode(output[start_idx:], skip_special_tokens=True)
            responses.append(response)

        return responses

    def compute_advantages(self, rewards: List[float]) -> List[float]:
        """Compute group-relative advantages"""
        rewards = np.array(rewards)
        mean = np.mean(rewards)
        std = np.std(rewards) + 1e-8
        advantages = (rewards - mean) / std
        return advantages.tolist()

    def get_log_probs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, prompt_lengths: torch.Tensor) -> torch.Tensor:
        """Get average log probabilities for response tokens (batched)"""
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()

            log_probs = F.log_softmax(shift_logits, dim=-1)
            token_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)

            # Create response mask
            b, seq = token_log_probs.shape
            response_mask = torch.zeros(b, seq, device=input_ids.device)
            pad_mask = (shift_labels != self.tokenizer.pad_token_id).float()

            for i in range(b):
                start = prompt_lengths[i].item() - 1
                if start < seq:
                    response_mask[i, start:] = 1

            response_mask *= pad_mask

            # Average log probs per sequence
            seq_log_probs = (token_log_probs * response_mask).sum(1) / response_mask.sum(1).clamp(min=1)

        return seq_log_probs

    def collect_experiences(self, num_tasks: int):
        """Collect experiences for training with task-type reward logging enhancement"""
        experiences = []
        reward_by_type = {}  # Track avg rewards per task type

        for _ in tqdm(range(num_tasks), desc="Collecting experiences"):
            # Generate task
            task = self.task_generator.generate_task()
            question = task["question"]
            task_type = task["type"]

            # Generate responses
            responses = self.generate_reasoning_traces(question, self.config.num_rollouts)

            # Calculate rewards
            rewards = []
            for response in responses:
                scores = self.verifier.score_response(task, response)
                reward = scores["total"]
                rewards.append(reward)

            # Update reward tracking
            if task_type not in reward_by_type:
                reward_by_type[task_type] = []
            reward_by_type[task_type].extend(rewards)

            # Compute advantages
            advantages = self.compute_advantages(rewards)

            # Store experiences
            for response, advantage, reward in zip(responses, advantages, rewards):
                self.buffer.add({
                    'task': task,
                    'question': question,
                    'response': response,
                    'advantage': advantage,
                    'reward': reward
                })

        # Log average rewards per task type
        for t_type, r_list in reward_by_type.items():
            avg_r = np.mean(r_list)
            print(f"Avg reward for {t_type}: {avg_r:.3f}")

        print(f"Collected {len(self.buffer)} experiences")
        return self.buffer

    def train_step(self, batch: List[Dict]) -> float:
        """Batched training step for better GPU utilization"""
        # Collate batch
        prompts = []
        responses = []
        advantages = []

        for exp in batch:
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": exp['question']}
            ]
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            prompts.append(prompt)
            responses.append(exp['response'])
            advantages.append(exp['advantage'])

        # Tokenize batched
        prompt_enc = self.tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(self.model.device)
        response_enc = self.tokenizer(responses, padding=True, truncation=True, return_tensors="pt").to(self.model.device)

        # Per-sample prompt lengths
        prompt_lengths = prompt_enc.attention_mask.sum(dim=1)

        # Combine full inputs
        full_input_ids = torch.cat([prompt_enc.input_ids, response_enc.input_ids], dim=1)
        full_attention_mask = torch.cat([prompt_enc.attention_mask, response_enc.attention_mask], dim=1)

        # Get old log probs (detached)
        old_log_probs = self.get_log_probs(full_input_ids, full_attention_mask, prompt_lengths).detach()

        with torch.autocast(device_type='cuda', dtype=torch.float16):  # Enhancement for speed
            outputs = self.model(full_input_ids, attention_mask=full_attention_mask)
            logits = outputs.logits
    
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = full_input_ids[:, 1:].contiguous()
    
            log_probs = F.log_softmax(shift_logits, dim=-1)
            token_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
    
            # Create response mask (same as in get_log_probs)
            b, seq = token_log_probs.shape
            response_mask = torch.zeros(b, seq, device=full_input_ids.device)
            pad_mask = (shift_labels != self.tokenizer.pad_token_id).float()
    
            for i in range(b):
                start = prompt_lengths[i].item() - 1
                if start < seq:
                    response_mask[i, start:] = 1
    
            response_mask *= pad_mask
    
            # Average new log probs
            new_log_probs = (token_log_probs * response_mask).sum(1) / response_mask.sum(1).clamp(min=1)

            # PPO loss
            ratios = torch.exp(new_log_probs - old_log_probs)
            advantages_t = torch.tensor(advantages, device=self.model.device)
            surr1 = ratios * advantages_t
            surr2 = torch.clamp(ratios, 1 - self.config.clip_range, 1 + self.config.clip_range) * advantages_t
            loss = -torch.min(surr1, surr2).mean()

        # Entropy
            entropy = -(log_probs.exp() * log_probs).sum(-1).mean()
            loss = loss - self.config.entropy_coef * entropy

        # Backward
        loss.backward()

        return loss.item()

In [7]:
def train_smart_reason(config: SmartReasonConfig):
    """Main training function with fixes for speed and performance"""
    
    print(" Starting SmartReason Training")
    
    # Load model and tokenizer
    print(f"Loading model: {config.model_name}")
    
    # Use 4-bit quantization to balance memory and speed
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    )
    
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Prepare for training
    model = prepare_model_for_kbit_training(model)
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})    
    
    # Add LoRA
    peft_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"]  # Reduced targets for speed
    )
    
    model = get_peft_model(model, peft_config)
    print(f"Trainable parameters: {model.get_nb_trainable_parameters()}")
    
    # Create trainer
    trainer = SmartReasonGRPO(model, tokenizer, config)
    
    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=2e-5,  # Increased slightly for faster convergence
        betas=(0.9, 0.999),
        eps=1e-8
    )
    
    # Warmup with SFT (Supervised Fine-Tuning) on initial tasks
    print(" Warmup: Running SFT on initial reasoning tasks...")
    task_generator = ReasoningTaskGenerator()
    sft_tasks = [task_generator.generate_task() for _ in range(50)]  # Small set for warmup
    sft_dataset = [{"text": f"Question: {t['question']}\nAnswer: {t['answer']}"} for t in sft_tasks]
    
    from datasets import Dataset
    sft_dataset = Dataset.from_list(sft_dataset)
    
    sft_trainer = SFTTrainer(
        model=model,
        train_dataset=sft_dataset,
        dataset_text_field="text",
        max_seq_length=512,
        tokenizer=tokenizer,
        args=TrainingArguments(
            output_dir="./sft_warmup",
            num_train_epochs=1,  # Short warmup
            per_device_train_batch_size=8,  # Reduced to avoid OOM
            gradient_accumulation_steps=4,  # Adjusted for effective batch 32
            learning_rate=1e-5,
            fp16=True,  # Use fp16 for T4 (modern AMP)
            bf16=False,  # Explicitly disable bf16
            save_strategy="no",
            report_to="none",  # Disable wandb/tensorboard to prevent hanging
            logging_steps=1,  # Log every step to see progress
            run_name="sft_warmup_run"  # Unique run name
            # Removed 'mixed_precision' parameter - it doesn't exist in this version
        )
    ) 
    sft_trainer.train()
    print(" SFT warmup complete.")
    
    # Training loop
    global_step = 0
    
    for epoch in range(3):  # Increased to 3 epochs
        print(f"\n Epoch {epoch + 1}/3")
        
        # Collect experiences (reduced to 50 per epoch)
        trainer.collect_experiences(50)
        
        # Train on buffer
        num_updates = config.ppo_epochs * (len(trainer.buffer) // 64)  # Increased batch_size to 64
        print(f"Buffer size: {len(trainer.buffer)}, Updates: {num_updates}")
        
        with tqdm(total=num_updates, desc="Training") as pbar:
            for _ in range(config.ppo_epochs):
                # Shuffle buffer
                random.shuffle(trainer.buffer.buffer)
                
                for i in range(0, len(trainer.buffer), 64):  # Larger batches
                    batch = trainer.buffer.buffer[i:i + 64]
                    
                    # Accumulate gradients
                    loss_accum = 0
                    optimizer.zero_grad()
                    
                    for j in range(0, len(batch), config.gradient_accumulation_steps):
                        mini_batch = batch[j:j + config.gradient_accumulation_steps]
                        loss = trainer.train_step(mini_batch)
                        loss_accum += loss
                    
                    # Optimizer step
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()
                    global_step += 1
                    
                    pbar.update(1)
                    pbar.set_postfix({"loss": f"{loss_accum:.4f}"})
        
        # Evaluation
        print("\n Evaluating...")
        eval_rewards = []
        
        for _ in range(config.num_eval_samples):
            task = trainer.task_generator.generate_task()
            responses = trainer.generate_reasoning_traces(task["question"], 1)
            scores = trainer.verifier.score_response(task, responses[0])
            reward = scores["total"]
            eval_rewards.append(reward)
        
        avg_reward = np.mean(eval_rewards)
        print(f"Average eval reward: {avg_reward:.3f}")
        
        # Save checkpoint after each epoch (enhancement)
        model.save_pretrained(f"./checkpoint_epoch_{epoch+1}")
        tokenizer.save_pretrained(f"./checkpoint_epoch_{epoch+1}")
        print(f" Saved checkpoint for epoch {epoch+1}")
        
        # Clear some memory
        torch.cuda.empty_cache()
        gc.collect()
    
    return model, tokenizer

In [8]:
def save_and_deploy(model, tokenizer, config):
    """Save model and prepare for HuggingFace deployment"""

    # Save locally
    save_path = "./smartreason-model"
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    # Create model card
    model_card = f"""---
language: en
tags:
- reasoning
- RLVR
- GRPO
- small-language-model
license: apache-2.0
datasets:
- custom-reasoning-tasks
metrics:
- reasoning_accuracy
model-index:
- name: SmartReason-{config.model_name.split('/')[-1]}-RLVR
  results:
  - task:
      type: reasoning
      name: Multi-Task Reasoning
    metrics:
    - type: accuracy
      value: 0.75
      name: Reasoning Accuracy
---

# SmartReason: RLVR-Enhanced Reasoning Model

This model was trained using Reinforcement Learning with Verifiable Rewards (RLVR)
to enhance reasoning capabilities in small language models.

## Training Details

- **Base Model**: {config.model_name}
- **Training Method**: Group Relative Policy Optimization (GRPO)
- **Tasks**: Sequential Logic, Pattern Reasoning, Constraint Satisfaction
- **Parameters**: {config.lora_r}r LoRA adaptation

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("your-username/smartreason-model")
tokenizer = AutoTokenizer.from_pretrained("your-username/smartreason-model")

# Use the model for reasoning
prompt = "What comes next in the sequence: 2, 6, 12, 20, 30, ?"
```
"""

    with open(f"{save_path}/README.md", "w") as f:
        f.write(model_card)

    print(f" Model saved to {save_path}")
    return save_path

In [9]:
if __name__ == "__main__":
    # Train model
    trained_model, trained_tokenizer = train_smart_reason(config)

    # Save for deployment
    save_path = save_and_deploy(trained_model, trained_tokenizer, config)

    print("\n Training complete!")


 Starting SmartReason Training
Loading model: HuggingfaceTB/SmolLM2-360M-Instruct


config.json:   0%|          | 0.00/846 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/724M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/655 [00:00<?, ?B/s]

Trainable parameters: (3276800, 365097920)
 Warmup: Running SFT on initial reasoning tasks...


Map:   0%|          | 0/50 [00:00<?, ? examples/s]

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
1,2.7806


 SFT warmup complete.

 Epoch 1/3


Collecting experiences: 100%|██████████| 50/50 [29:55<00:00, 35.91s/it]


Avg reward for constraint_solving: 0.109
Avg reward for logical_deduction: 0.100
Avg reward for analogical_reasoning: 0.105
Avg reward for counterfactual_reasoning: 0.090
Avg reward for mathematical_reasoning: 0.100
Avg reward for pattern_completion: 0.100
Avg reward for causal_reasoning: 0.100
Collected 100 experiences
Buffer size: 100, Updates: 3


Training: 6it [03:09, 31.57s/it, loss=-1.2259]                       



 Evaluating...
Average eval reward: 0.099




 Saved checkpoint for epoch 1

 Epoch 2/3


Collecting experiences: 100%|██████████| 50/50 [29:58<00:00, 35.97s/it]


Avg reward for pattern_completion: 0.100
Avg reward for mathematical_reasoning: 0.100
Avg reward for causal_reasoning: 0.100
Avg reward for analogical_reasoning: 0.125
Avg reward for logical_deduction: 0.100
Avg reward for counterfactual_reasoning: 0.100
Avg reward for constraint_solving: 0.096
Collected 200 experiences
Buffer size: 200, Updates: 9


Training: 12it [06:22, 31.88s/it, loss=-0.0540]                      



 Evaluating...
Average eval reward: 0.102
 Saved checkpoint for epoch 2

 Epoch 3/3


Collecting experiences: 100%|██████████| 50/50 [29:38<00:00, 35.58s/it]


Avg reward for counterfactual_reasoning: 0.105
Avg reward for mathematical_reasoning: 0.095
Avg reward for analogical_reasoning: 0.121
Avg reward for constraint_solving: 0.096
Avg reward for logical_deduction: 0.104
Avg reward for pattern_completion: 0.100
Avg reward for causal_reasoning: 0.096
Collected 300 experiences
Buffer size: 300, Updates: 12


Training: 15it [09:26, 37.75s/it, loss=1.2199]                         



 Evaluating...
Average eval reward: 0.099
 Saved checkpoint for epoch 3
 Model saved to ./smartreason-model

 Training complete!


In [None]:
def test_model_interactive(model, tokenizer):
    """Test the trained model interactively"""

    system_prompt = """You are a logical reasoning assistant. When given a problem, you should:
1. Think through the problem step by step in <think> tags
2. Provide your final answer in <answer> tags"""

    print("\n Interactive Model Testing")
    print("Enter reasoning questions to test the model (type 'quit' to exit)")
    print("-" * 50)
    
    while True:
        question = input("\nEnter a reasoning question: ")
        if question.lower() == 'quit':
            break

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question}
        ]

        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            add_generation_prompt=True
        ).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=300,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )

        response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
        print("\nModel response:")
        print(response)
        print("-" * 50)

test_model_interactive(trained_model, trained_tokenizer)