In [190]:
import numpy as np
import random
import torch
import re
import importlib
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from rich import print
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
import reasoning_gym
from reasoning_gym import get_score_answer_fn
from functools import partial

In [191]:
# --- Configuration ---
FORMAT_REWARD_WEIGHT = 0.15       # weight of the tag-format reward in the total reward
CORRECTNESS_REWARD_WEIGHT = 0.85  # weight of the task-correctness reward
MAX_NEW_TOKENS = 150              # generation budget per rollout
N_ROLLOUTS = 3                    # completions sampled per prompt (the GRPO group size)
BATCH_SIZE = 2                    # experiences per optimizer step in train_loop
BUFFER_SIZE = 15                  # train once at least this many experiences are buffered
EPOCHS = 2  # Reduced for demo
LR = 1e-5   # Lower LR is usually better for PPO/GRPO
SEED = 42   # seeds Python, NumPy and torch RNGs so sampling/shuffling are reproducible

# Rollout sampling, DataLoader shuffling and buffer shuffling are all stochastic;
# seed every RNG up front so a Restart & Run All reproduces the run.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Initialize Accelerator
accelerator = Accelerator()

In [192]:
def extract_answer(answer):
    """Return the stripped text of the first <answer>...</answer> block in ``answer``.

    The match is non-greedy and spans newlines. If no complete block is found,
    ``answer`` is returned unchanged.
    """
    found = re.search(r"<answer>(.*?)</answer>", answer, re.DOTALL)
    return found.group(1).strip() if found else answer

def calculate_format_reward(response):
    """Score whether ``response`` follows the <think>...</think><answer>...</answer> layout.

    Returns -0.5 when any of the four tags is missing, or when the first occurrences
    of the tags are not in think-open < think-close < answer-open < answer-close order.
    Otherwise it returns 0.5 (0.2 for correct format plus 0.3 for the answer block).
    """
    tags = ("<think>", "</think>", "<answer>", "</answer>")
    if not all(tag in response for tag in tags):
        return -0.5

    # Index of the first occurrence of each tag, in the required order.
    positions = [response.find(tag) for tag in tags]
    if not (0 <= positions[0] < positions[1] < positions[2] < positions[3]):
        return -0.5

    # correct format + answer block exists
    return 0.0 + 0.2 + 0.3

def correctness_reward(response, validation_object):
    """Grade ``response`` with the reasoning_gym scorer of the sample's source dataset.

    Args:
        response: the (already extracted) answer text to grade.
        validation_object: a reasoning_gym sample. Its
            ``['metadata']['source_dataset']`` entry selects the scoring function,
            and the whole object is passed to that function.

    Returns:
        The score as a float, or 0.0 if scoring fails for any reason.
    """
    import warnings

    # Scoring stays best-effort so that one malformed sample cannot abort a rollout.
    # The failure is reported, though: silently returning 0.0 would hide problems such
    # as a wrong metadata key and leave training driven by the format reward alone.
    try:
        score_fn = get_score_answer_fn(validation_object['metadata']['source_dataset'])
        # float() also rejects non-numeric scores (e.g. None) that would otherwise
        # end up as an object-dtype entry in the reward array.
        return float(score_fn(response, validation_object))
    except Exception as exc:
        warnings.warn(
            f"correctness_reward: scoring failed ({type(exc).__name__}: {exc}); using 0.0"
        )
        return 0.0

def calculate_rewards(batch_response, validation_objects):
    """Combine the weighted format and correctness rewards for a batch of completions.

    Args:
        batch_response: decoded completion strings.
        validation_objects: one reasoning_gym sample per completion, in the same order.

    Returns:
        A NumPy array with
        ``FORMAT_REWARD_WEIGHT * format + CORRECTNESS_REWARD_WEIGHT * correctness``
        for each completion.
    """
    fmt_scores = np.array(list(map(calculate_format_reward, batch_response)))
    # Correctness is graded on the extracted <answer> text only.
    correct_scores = np.array([
        correctness_reward(extract_answer(resp), obj)
        for resp, obj in zip(batch_response, validation_objects)
    ])
    weighted_fmt = FORMAT_REWARD_WEIGHT * fmt_scores
    weighted_correct = CORRECTNESS_REWARD_WEIGHT * correct_scores
    return weighted_fmt + weighted_correct

# --- Model & Math Utils ---

def calculate_logits(model, input_ids, attention_mask):
    """
    Return the log-probability the model assigns to each next token of ``input_ids``.

    Despite the name, the result holds log-probs, not logits. Entry ``[b, t]`` is
    log p(input_ids[b, t + 1] | input_ids[b, :t + 1]), so the output has one column
    fewer than ``input_ids``. The computation runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        # Position t predicts token t + 1: drop the last prediction and the first target.
        predicted = logits[:, :-1, :]
        targets = input_ids[:, 1:].unsqueeze(-1)
        return predicted.log_softmax(dim=-1).gather(2, targets).squeeze(-1)

def left_pad(sequences, pad_value):
    """Left-pad a list of 1D tensors into a single ``[len(sequences), max_len]`` tensor.

    Each sequence is right-aligned in its row, and the free slots on the left are
    filled with ``pad_value``. dtype and device are taken from ``sequences[0]``.
    Zero-length sequences are allowed and yield an all-padding row.

    Raises:
        ValueError: if ``sequences`` is empty.
    """
    if len(sequences) == 0:
        raise ValueError("left_pad requires at least one sequence")

    max_len = max(len(s) for s in sequences)
    first = sequences[0]
    padded_batch = torch.full(
        (len(sequences), max_len), fill_value=pad_value, dtype=first.dtype, device=first.device
    )

    for i, seq in enumerate(sequences):
        # Use an explicit start index: the slice `-len(seq):` turns into `0:` (the whole
        # row) when len(seq) == 0, and that assignment then fails with a shape mismatch.
        padded_batch[i, max_len - len(seq):] = seq

    return padded_batch

# --- Data Loading ---

class GymDataset(Dataset):
    """Map-style torch ``Dataset`` over an indexable collection of reasoning_gym samples."""

    def __init__(self, data):
        # Wrapped collection; both length and item access are delegated to it.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        return sample

def _build_prompt(question):
    """Wrap one question in the User/Assistant template that asks for <think>/<answer> tags."""
    # calculate_format_reward and extract_answer look for these exact tags, so the
    # prompt must request them. Without the instruction, a model that was not SFT-ed
    # on the format gives nearly every rollout the same -0.5 format reward, and the
    # group-relative advantages carry almost no signal.
    return (
        "A conversation between User and Assistant. The user asks a question, and the Assistant solves it.\n"
        "The assistant first thinks about the reasoning process in the mind and then provides the user "
        "with the answer. The reasoning process and answer are enclosed within <think> </think> and "
        "<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> "
        "<answer> answer here </answer>.\n"
        f"User: {question}.\nAssistant:"
    )


def collate_fn(batch, tokenizer):
    """Turn a list of reasoning_gym samples into a tokenized prompt batch.

    Args:
        batch: list of samples, each a mapping with a ``'question'`` entry.
        tokenizer: HF-style tokenizer. Padding follows its own ``padding_side``
            (the main cell sets it to "left" so generation continues from the prompt end).

    Returns:
        dict with the tokenizer's ``input_ids`` and ``attention_mask`` tensors, plus
        ``raw_batch``, the untouched samples kept for answer validation.
    """
    prompts = [_build_prompt(sample['question']) for sample in batch]

    # Basic tokenization of prompts
    input_tokenized = tokenizer(prompts, padding=True, add_special_tokens=False, return_tensors='pt')

    return {
        'input_ids': input_tokenized['input_ids'],
        'attention_mask': input_tokenized['attention_mask'],
        'raw_batch': batch,  # Keep original objects for validation
    }

def get_dataloader(dataset_name, tokenizer, batch_size=2, seed=42, size=100, shuffle=True):
    """Build a DataLoader over a freshly generated reasoning_gym dataset.

    Args:
        dataset_name: reasoning_gym dataset name, passed to ``reasoning_gym.create_dataset``.
        tokenizer: tokenizer bound into ``collate_fn`` for prompt tokenization.
        batch_size: prompts per batch.
        seed: generation seed for the dataset (default 42, as before).
        size: number of generated samples (default 100, as before).
        shuffle: whether the DataLoader shuffles the samples each pass.

    Returns:
        A ``DataLoader`` yielding ``collate_fn`` dicts.
    """
    # Note: 'syllogism' might need to be 'logic/syllogism' or similar depending on reasoning_gym version
    gym_data = reasoning_gym.create_dataset(dataset_name, seed=seed, size=size)
    return DataLoader(
        GymDataset(gym_data),
        batch_size=batch_size,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
        shuffle=shuffle,
    )

In [193]:
def collect_exp(model, tokenizer, input_ids, raw_batch, attention_mask):
    """Sample N_ROLLOUTS completions per prompt, score them and package GRPO experiences.

    Args:
        model: causal LM used to sample completions and to compute their old log-probs.
        tokenizer: provides ``pad_token_id`` / ``eos_token_id`` and decodes completions.
        input_ids: ``[B, P]`` prompt token ids, padded to a common length P.
        raw_batch: the B original samples, passed to the reward functions.
        attention_mask: ``[B, P]`` mask matching ``input_ids``.

    Returns:
        A list of ``B * N_ROLLOUTS`` dicts, one per completion, with:
        ``input_ids`` [S], ``attention_mask`` [S], ``response_mask`` [S] (1.0 on the
        completion tokens that enter the loss, 0.0 elsewhere), ``old_log_probs`` [S - 1]
        and ``advantages`` (0-d tensor). S = P + number of generated tokens.
    """
    model.eval()  # Ensure eval mode for generation and scoring

    with torch.no_grad():
        # Shape: [B * N_ROLLOUTS, S]. The grouping below assumes the N rollouts of each
        # prompt come back as consecutive rows (HF num_return_sequences behaviour).
        full_response = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            temperature=1.0,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            top_p=0.95,
            num_return_sequences=N_ROLLOUTS,
            pad_token_id=tokenizer.pad_token_id,
        )
        device = full_response.device
        pad_id = tokenizer.pad_token_id

        # Everything after the (padded) prompt columns is generated text.
        prompt_len = input_ids.shape[1]
        completion_ids = full_response[:, prompt_len:]
        full_attention_mask = (full_response != pad_id).long()

        # Old-policy log-probs; column t scores token t + 1. Shape: [B * N_ROLLOUTS, S - 1].
        all_token_log_probs = calculate_logits(model, full_response, full_attention_mask)

        # Rewards: each prompt's sample is repeated once per rollout.
        decoded_responses = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
        expanded_val_objs = [obj for obj in raw_batch for _ in range(N_ROLLOUTS)]
        rewards = calculate_rewards(decoded_responses, expanded_val_objs)

        # Group-relative advantages: normalize each prompt's N rewards by their mean/std.
        grouped = rewards.reshape(-1, N_ROLLOUTS)
        group_mean = grouped.mean(axis=1, keepdims=True)
        group_std = grouped.std(axis=1, keepdims=True) + 1e-8
        advantages = torch.tensor(
            ((grouped - group_mean) / group_std).flatten(), dtype=torch.float32, device=device
        )

        # Completion mask: 1 from the first generated token up to the first pad after the prompt.
        completion_is_pad = (completion_ids == pad_id).long()
        pads_so_far = completion_is_pad.cumsum(dim=1)  # inclusive count of pads seen
        if tokenizer.eos_token_id is not None and pad_id == tokenizer.eos_token_id:
            # pad == eos: the first "pad" is the EOS the model actually emitted. Keep it
            # in the loss (exclusive count) so the policy also learns when to stop.
            pads_so_far = pads_so_far - completion_is_pad
        response_mask = torch.zeros(full_response.shape, dtype=torch.float32, device=device)
        response_mask[:, prompt_len:] = (pads_so_far == 0).float()

    return [
        {
            'input_ids': full_response[i],              # [S]
            'attention_mask': full_attention_mask[i],   # [S]
            'response_mask': response_mask[i],          # [S]
            'old_log_probs': all_token_log_probs[i],    # [S - 1]
            'advantages': advantages[i],                # scalar
        }
        for i in range(full_response.size(0))
    ]

In [194]:
def calculate_grpo_loss(new_log_probs, old_log_probs, response_mask, advantages, clip_epsilon=0.2):
    """
    Clipped-surrogate GRPO policy loss, averaged over response tokens.

    new_log_probs: [B, Seq-1] log-probs under the current policy (carries gradients)
    old_log_probs: [B, Seq-1] log-probs recorded at rollout time
    response_mask: [B, Seq] 1 on tokens that count toward the loss; shifted to [B, Seq-1] here
    advantages: [B, 1] one advantage per sequence, broadcast over its tokens
    clip_epsilon: the probability ratio is clipped to [1 - clip_epsilon, 1 + clip_epsilon]

    Returns the scalar -sum(mask * min(r * A, clip(r) * A)) / sum(mask), where
    r = exp(new - old).
    """
    # Align mask with log-probs: column t scores token t + 1, so drop position 0.
    mask = response_mask[:, 1:]

    # Zero the log-ratio outside the response before exponentiating. Those positions
    # are multiplied by 0 afterwards anyway, but exp() there can overflow to inf, and
    # inf * 0 = NaN would poison the whole loss.
    log_ratio = (new_log_probs - old_log_probs).masked_fill(mask == 0, 0.0)
    ratio = torch.exp(log_ratio)
    clipped_ratio = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon)

    # Pessimistic (element-wise min) surrogate
    min_surr = torch.min(ratio * advantages, clipped_ratio * advantages)

    # Average over valid response tokens only
    masked_loss = min_surr * mask
    return -masked_loss.sum() / (mask.sum() + 1e-8)

def train_step(model, batch, optimizer):
    """Take one optimizer step on the GRPO loss for a mini-batch of experiences.

    Args:
        model: the policy being optimized.
        batch: list of experience dicts as produced by ``collect_exp``.
        optimizer: optimizer over the policy's trainable parameters.

    Returns:
        The loss value as a Python float.

    NOTE: reads the module-level ``tokenizer`` and ``accelerator`` globals.
    """
    model.train()

    def _pad_field(key, pad):
        # Experiences differ in length; left-padding right-aligns every field so that
        # per-token tensors of the same experience stay column-aligned.
        return left_pad([exp[key] for exp in batch], pad)

    input_ids = _pad_field('input_ids', tokenizer.pad_token_id)
    attention_mask = _pad_field('attention_mask', 0)
    response_mask = _pad_field('response_mask', 0)
    old_log_probs = _pad_field('old_log_probs', 0.0)  # float field, padded with 0.0
    advantages = torch.stack([exp['advantages'] for exp in batch]).unsqueeze(-1)  # [B, 1]

    # Recompute per-token log-probs with gradients enabled: position t predicts token t + 1.
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    targets = input_ids[:, 1:].unsqueeze(-1)
    new_token_log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1).gather(2, targets).squeeze(-1)

    loss = calculate_grpo_loss(
        new_log_probs=new_token_log_probs,
        old_log_probs=old_log_probs,
        response_mask=response_mask,
        advantages=advantages,
    )

    optimizer.zero_grad()
    accelerator.backward(loss)
    optimizer.step()
    return loss.item()

In [195]:
def train_loop(model, buffer, optimizer, num_epochs=2):
    """Run ``num_epochs`` passes of mini-batch GRPO updates over a buffer of experiences.

    The experiences are reshuffled at the start of every epoch, so each epoch sees
    different mini-batches. Shuffling happens on a copy, so the caller's ``buffer``
    list is not reordered. Mini-batches hold ``BATCH_SIZE`` experiences; the last
    one may be smaller.

    Returns:
        The mean of the per-epoch average losses, or 0.0 when no epoch ran.
    """
    experiences = list(buffer)
    epoch_means = []

    # TQDM for training epochs
    pbar = tqdm(range(num_epochs), desc="Training Epochs", leave=False)
    for _ in pbar:
        random.shuffle(experiences)
        step_losses = []
        # Slices from range(0, len, BATCH_SIZE) are never empty.
        for start in range(0, len(experiences), BATCH_SIZE):
            step_losses.append(train_step(model, experiences[start:start + BATCH_SIZE], optimizer))

        avg_loss = float(np.mean(step_losses)) if step_losses else 0.0
        epoch_means.append(avg_loss)
        pbar.set_postfix({"Avg Loss": f"{avg_loss:.4f}"})

    return float(np.mean(epoch_means)) if epoch_means else 0.0

In [197]:
if __name__ == "__main__":
    import os

    # Load Tokenizer (module-level: train_step reads the global `tokenizer`)
    base_model_id = "HuggingFaceTB/SmolLM2-360M-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    tokenizer.padding_side = "left"  # Important for generation
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load Model
    print("[bold green]Loading Model...[/bold green]")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.bfloat16,
        # device_map="auto" # Let Accelerator handle device
    )

    # Attach the SFT LoRA adapter and keep it trainable. from_pretrained defaults to
    # is_trainable=False, which freezes every parameter, and it injects the LoRA layers
    # into base_model in place. Reusing `base_model` afterwards therefore still runs
    # the adapter, but leaves nothing that requires grad, so backward() fails.
    adapter_path = os.environ.get("LORA_ADAPTER_PATH", "/Users/peeyushsharma/Downloads/lora_sft_2")
    model = PeftModel.from_pretrained(base_model, adapter_path, is_trainable=True)

    # Optimizer: only the adapter parameters are trainable.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if not trainable_params:
        raise RuntimeError(f"No trainable parameters found; check the LoRA adapter at {adapter_path}")
    optimizer = torch.optim.AdamW(trainable_params, lr=LR)

    # Data Loader
    data_loader = get_dataloader('syllogism', tokenizer)

    # Prepare with Accelerator
    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
    # A distributed wrapper does not expose generate(), so roll out with the unwrapped
    # module. It shares weights with `model`, so rollouts always use the latest policy.
    rollout_model = accelerator.unwrap_model(model)

    buffer = []

    # Main Loop
    print("[bold green]Starting Collection & Training...[/bold green]")

    # TQDM over the dataloader
    main_pbar = tqdm(data_loader, desc="Data Collection", disable=not accelerator.is_local_main_process)

    for sample in main_pbar:
        # Collect Experience
        experiences = collect_exp(
            rollout_model,
            tokenizer,
            sample['input_ids'],
            sample['raw_batch'],
            sample['attention_mask'],
        )
        buffer.extend(experiences)

        # Trigger training when buffer is full
        if len(buffer) >= BUFFER_SIZE:
            main_pbar.set_description("Training...")
            avg_loss = train_loop(model, buffer, optimizer, num_epochs=EPOCHS)

            # Clear buffer and update stats
            buffer = []
            main_pbar.set_description(f"Collecting (Last Loss: {avg_loss:.4f})")

    # Train on the tail of a partially filled buffer instead of discarding those rollouts.
    if buffer:
        train_loop(model, buffer, optimizer, num_epochs=EPOCHS)
        buffer = []

    print("[bold green]Finished![/bold green]")

Data Collection:   0%|          | 0/50 [01:06<?, ?it/s]


KeyboardInterrupt: 