# Welcome to Modal notebooks!

Write Python code and collaborate in real time. Your code runs in Modal's
**serverless cloud**, and anyone in the same workspace can join.

This notebook comes with some common Python libraries installed. Run
cells with `Shift+Enter`.

In [1]:
# %% [markdown]
# ## Cell 1: Install Dependencies
# Run this cell first to install all required packages


In [2]:
%uv pip install torch datasets accelerate peft trl bitsandbytes
%uv pip install flash-attn --no-build-isolation  # Optional: for faster attention
%uv pip install --upgrade transformers --no-cache   

[2mUsing Python 3.12.6 environment at: /usr/local[0m
[37m⠋[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mResolving dependencies...                                                     [0m[2K[37m⠋[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mtorch==2.8.0+cu129                                                            [0m[2K[37m⠙[0m [2mdatasets==4.5.0                                                               [0m[2K[37m⠙[0m [2maccelerate==1.10.1                                                            [0m[2K[37m⠙[0m [2mpeft==0.18.1                                                                  [0m[2K[37m⠙[0m [2mtrl==0.27.2                                                                   [0m[2K[37m⠙[0m [2mbitsandbyte

In [3]:
import json
import os
from pathlib import Path
from typing import Optional, List, Dict, Any

import torch
from datasets import Dataset, DatasetDict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig

In [4]:
MODEL_NAME = "arcee-ai/Trinity-Mini"
TRACES_FILE = "reasoning_traces_correct.json"  # Path to filtered correct traces
OUTPUT_DIR = "/mnt/models/trinity-triton-sft"
MAX_LENGTH = 4096  

In [5]:
# Training hyperparameters
LEARNING_RATE = 2e-4
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 8  # Effective batch size = 1 * 8 = 8
NUM_EPOCHS = 3
WARMUP_RATIO = 0.1

# LoRA configuration for efficient finetuning
LORA_R = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05

In [6]:
def load_and_filter_traces(
    traces_file: str,
    require_correctness: bool = False,  # Default: include all (correct + failed)
    require_fast_0: bool = False,  # Default: include all
    require_fast_1: bool = False,  # Default: include all
    require_fast_2: bool = False,  # Default: include all
    min_speedup: Optional[float] = None,
) -> List[Dict[str, Any]]:
    """
    Load reasoning traces from JSON and optionally filter based on quality criteria.

    DEFAULT BEHAVIOR: Includes ALL traces (correct + failed) from the file.
    The reasoning_traces_correct.json is already curated with the desired mix.

    Training on both correct and failed traces helps the model:
    - Learn successful Triton patterns from correct examples
    - Understand common pitfalls and errors from failed examples
    - Develop better reasoning about what works and what doesn't

    Args:
        traces_file: Path to the reasoning_traces_correct.json file
        require_correctness: If True, only include correct traces
        require_fast_0: If True, only include traces where kernel compiled correctly
        require_fast_1: If True, only include traces with speedup > 1.0
        require_fast_2: If True, only include traces with speedup >= 2.0
        min_speedup: Optional minimum speedup threshold

    Returns:
        List of filtered trace dictionaries
    """
    # Load traces
    traces_path = Path(traces_file)
    if not traces_path.exists():
        raise FileNotFoundError(
            f"Traces file not found at {traces_path}. "
            "Please run the orchestrator first to generate traces: "
            "python orchestrator.py"
        )
    
    with open(traces_path, "r") as f:
        traces = json.load(f)
    
    print(f"Loaded {len(traces)} total traces")
    
    # Count correct vs failed for statistics
    correct_count = sum(1 for t in traces if t.get("result", {}).get("correctness", False))
    failed_count = len(traces) - correct_count

    # Apply optional filters
    filtered_traces = []
    filtered_out = 0

    for trace in traces:
        result = trace.get("result", {})

        # Apply filters only if explicitly requested
        if require_correctness and not result.get("correctness", False):
            filtered_out += 1
            continue

        if require_fast_0 and not result.get("fast_0", False):
            filtered_out += 1
            continue

        if require_fast_1 and not result.get("fast_1", False):
            filtered_out += 1
            continue

        if require_fast_2 and not result.get("fast_2", False):
            filtered_out += 1
            continue

        if min_speedup is not None:
            speedup = result.get("speedup", 0.0)
            if speedup < min_speedup:
                filtered_out += 1
                continue

        # Trace passed all filters
        filtered_traces.append(trace)

    stats = {
        "total": len(traces),
        "correct": correct_count,
        "failed": failed_count,
        "filtered_out": filtered_out,
        "final_count": len(filtered_traces),
    }
    
    # Print statistics
    print("\n" + "=" * 60)
    print("DATASET LOADING SUMMARY")
    print("=" * 60)
    print(f"Total traces loaded:       {stats['total']}")
    print(f"  ✓ Correct kernels:       {stats['correct']} ({stats['correct']/stats['total']*100:.1f}%)")
    print(f"  ✗ Failed kernels:        {stats['failed']} ({stats['failed']/stats['total']*100:.1f}%)")

    if stats['filtered_out'] > 0:
        print(f"\nFiltered out:              {stats['filtered_out']} traces")

    print("-" * 60)
    print(f"Final training set:        {stats['final_count']} traces")

    if stats['final_count'] > 0:
        final_correct = sum(1 for t in filtered_traces if t.get("result", {}).get("correctness", False))
        final_failed = stats['final_count'] - final_correct
        print(f"  Training mix:            {final_correct/stats['final_count']*100:.1f}% correct, "
              f"{final_failed/stats['final_count']*100:.1f}% failed")

    print("=" * 60 + "\n")
    
    return filtered_traces

In [7]:
try:
    traces = load_and_filter_traces(TRACES_FILE)  # Include ALL traces by default
    print(f"✓ Successfully loaded {len(traces)} traces for training")
except FileNotFoundError as e:
    print(f"Note: {e}")
    traces = []

Loaded 158 total traces

DATASET LOADING SUMMARY
Total traces loaded:       158
  ✓ Correct kernels:       133 (84.2%)
  ✗ Failed kernels:        25 (15.8%)
------------------------------------------------------------
Final training set:        158 traces
  Training mix:            84.2% correct, 15.8% failed

✓ Successfully loaded 158 traces for training


In [8]:
def format_trace_for_sft(trace: Dict[str, Any]) -> Dict[str, Any]:
    """
    Format a single trace into the SFT conversational format.

    The format follows the pattern:
    User: PyTorch code
    Assistant: <think>reasoning</think><triton>code</triton>
    """
    pytorch_code = trace.get("pytorch_code", "")
    # Prefer model_reasoning (actual reasoning process), fallback to thinking (polished explanation)
    reasoning = trace.get("model_reasoning") or trace.get("thinking", "")
    triton_code = trace.get("triton_code", "")

    # Build the user message
    user_content = f"""Convert the following PyTorch code to an optimized Triton kernel:

```python
{pytorch_code}
```

Generate a complete Triton implementation that produces the same output as the PyTorch code."""

    # Build the assistant message with reasoning and code
    # Always wrap reasoning in <think></think> tags
    if reasoning:
        if not reasoning.strip().startswith("<think>"):
            reasoning = f"<think>\n{reasoning}\n</think>"
    else:
        reasoning = "<think>\n</think>"

    # Wrap triton code in <triton></triton> tags
    if triton_code:
        if not triton_code.strip().startswith("<triton>"):
            triton_code = f"<triton>\n{triton_code}\n</triton>"
    else:
        triton_code = "<triton>\n</triton>"

    assistant_content = f"{reasoning}\n\n{triton_code}"

    # Return in conversational format for SFTTrainer
    return {
        "messages": [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": assistant_content}
        ]
    }


In [9]:
def prepare_dataset(traces: List[Dict[str, Any]], test_size: float = 0.1) -> DatasetDict:
    """
    Prepare a HuggingFace DatasetDict from traces.

    Args:
        traces: List of trace dictionaries
        test_size: Fraction of data to use for validation

    Returns:
        DatasetDict with 'train' and 'test' splits
    """
    if not traces:
        raise ValueError("No traces provided! Cannot create dataset.")

    # Format all traces for SFT
    formatted_data = [format_trace_for_sft(trace) for trace in traces]

    # Create dataset
    dataset = Dataset.from_list(formatted_data)

    # Split into train/test
    if len(dataset) > 1:
        split = dataset.train_test_split(test_size=test_size, seed=42)
        return split
    else:
        # Not enough data for split, use same for both
        return DatasetDict({
            "train": dataset,
            "test": dataset
        })


# Prepare the dataset
dataset = prepare_dataset(traces)
print(f"\nDataset prepared:")
print(f"  Train samples: {len(dataset['train'])}")
print(f"  Test samples:  {len(dataset['test'])}")

# Preview a sample
print("\n" + "=" * 60)
print("SAMPLE TRAINING EXAMPLE")
print("=" * 60)
if len(dataset['train']) > 0:
    sample = dataset['train'][-1]
    print(f"User message (truncated):\n{sample['messages'][0]['content']}...")
    print(f"\nAssistant message (truncated):\n{sample['messages'][1]['content']}...")



Dataset prepared:
  Train samples: 142
  Test samples:  16

SAMPLE TRAINING EXAMPLE
User message (truncated):
Convert the following PyTorch code to an optimized Triton kernel:

```python
import math
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data


def matmul(x, y):
    if x.dim() == y.dim():
        return x @ y
    if x.dim() == y.dim() - 1:
        return (x.unsqueeze(-2) @ y).squeeze(-2)
    return (x @ y.unsqueeze(-2)).squeeze(-2)


class LayerNorm(nn.Module):

    def __init__(self, d_model, eps=1e-06):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta


class ResidualBlock(nn.Module):

    def __init__(self, layer, d_model, dropout_ratio):
        supe

In [10]:
def load_model_and_tokenizer(
    model_name: str = MODEL_NAME,
    use_4bit: bool = True,
    use_flash_attention: bool = False,
):
    """
    Load the Trinity-Mini model with quantization and LoRA for efficient finetuning.
    
    Args:
        model_name: HuggingFace model name
        use_4bit: Whether to use 4-bit quantization (recommended for 26B model)
        use_flash_attention: Whether to use Flash Attention 2
        
    Returns:
        tuple: (model, tokenizer)
    """
    print(f"Loading tokenizer for {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    
    # Set padding token if not present (required for Trinity-Mini)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        print(f"Set pad_token to eos_token: {tokenizer.pad_token}")

    # Load and modify config (required for AFMoE models like Trinity-Mini)
    from transformers import AutoConfig
    print("Loading model config...")
    config = AutoConfig.from_pretrained(model_name)
    config.pad_token_id = tokenizer.pad_token_id
    print(f"Set config.pad_token_id = {config.pad_token_id}")

    # Configure quantization
    if use_4bit:
        print("Configuring 4-bit quantization...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,  # Nested quantization for more memory savings
            bnb_4bit_quant_type="nf4",  # Normal Float 4 quantization
            bnb_4bit_compute_dtype=torch.bfloat16,  # Compute in bfloat16
        )
    else:
        bnb_config = None
    
    # Model loading kwargs
    model_kwargs = {
        "config": config,  # Pass modified config with pad_token_id
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
    }
    
    if bnb_config:
        model_kwargs["quantization_config"] = bnb_config
    
    if use_flash_attention:
        try:
            model_kwargs["attn_implementation"] = "flash_attention_2"
            print("Using Flash Attention 2")
        except Exception:
            print("Flash Attention 2 not available, using default attention")
    
    print(f"Loading model {model_name}...")
    print("This may take a few minutes for a 26B parameter model...")
    
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        **model_kwargs,
        
    )
    
    # Prepare model for k-bit training (required for QLoRA)
    if use_4bit:
        model = prepare_model_for_kbit_training(model)
    
    print(f"Model loaded successfully!")
    print(f"  Model type: {type(model).__name__}")
    print(f"  Device map: {model.hf_device_map if hasattr(model, 'hf_device_map') else 'N/A'}")
    
    return model, tokenizer


In [11]:
%uv pip install --upgrade transformers


[2mUsing Python 3.12.6 environment at: /usr/local[0m
[37m⠋[0m [2mResolving dependencies...                                                     [0m[2K[37m⠋[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mtransformers==5.0.0                                                           [0m[2K[37m⠙[0m [2mfilelock==3.20.3                                                              [0m[2K[37m⠙[0m [2mhuggingface-hub==1.3.7                                                        [0m[2K[37m⠙[0m [2mnumpy==2.4.2                                                                  [0m[2K[37m⠙[0m [2mpackaging==26.0                                                               [0m[2K[37m⠙[0m [2mpyyaml==6.0.3                                                                 [0m[2K[37m⠙[0m [2mregex==2026

In [12]:
import transformers
print(transformers.__version__)

5.0.0


In [13]:
# Weights & Biases configuration
WANDB_PROJECT = "trinity-triton-sft"
WANDB_RUN_NAME = "trinity-sft-run"  # A
%uv pip install wandb

[2mUsing Python 3.12.6 environment at: /usr/local[0m
[2mAudited [1m1 package[0m [2min 5ms[0m[0m
Note: you may need to restart the kernel to use updated packages.


In [14]:
def create_lora_config(
    r: int = LORA_R,
    lora_alpha: int = LORA_ALPHA,
    lora_dropout: float = LORA_DROPOUT,
    target_modules: Optional[List[str]] = None,
) -> LoraConfig:
    """
    Create LoRA configuration for Trinity-Mini.
    
    Trinity-Mini uses an MoE architecture, so we target:
    - Query, Key, Value projections in attention
    - Gate and up/down projections in MLP
    - Optionally the expert layers (if applicable)
    
    Args:
        r: LoRA rank (higher = more parameters to train)
        lora_alpha: LoRA alpha (scaling factor, typically 2*r)
        lora_dropout: Dropout probability for LoRA layers
        target_modules: List of module names to apply LoRA to
        
    Returns:
        LoraConfig instance
    """
    if target_modules is None:
        # Default targets for typical LLM architectures
        # This should work for most transformer models
        target_modules = [
            "q_proj", "k_proj", "v_proj", "o_proj",  # Attention
            "gate_proj", "up_proj", "down_proj",      # MLP
        ]
    
    lora_config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )
    
    print("LoRA Configuration:")
    print(f"  Rank (r):          {r}")
    print(f"  Alpha:             {lora_alpha}")
    print(f"  Dropout:           {lora_dropout}")
    print(f"  Target modules:    {target_modules}")
    
    return lora_config

In [15]:
lora_config = create_lora_config()


LoRA Configuration:
  Rank (r):          64
  Alpha:             128
  Dropout:           0.05
  Target modules:    ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']


In [16]:
import os
# Set your W&B API key (get it from https://wandb.ai/authorize)
os.environ["WANDB_API_KEY"] = os.getenv("WANDB_API_KEY", "")  # Set in environment or .env file

In [17]:
def create_training_config(
    output_dir: str = OUTPUT_DIR,
    learning_rate: float = LEARNING_RATE,
    batch_size: int = BATCH_SIZE,
    gradient_accumulation_steps: int = GRADIENT_ACCUMULATION_STEPS,
    num_epochs: int = NUM_EPOCHS,
    max_seq_length: int = MAX_LENGTH,
    warmup_ratio: float = WARMUP_RATIO,
) -> SFTConfig:
    """
    Create SFTConfig with optimized settings for Trinity-Mini finetuning.
    
    Returns:
        SFTConfig instance
    """
    config = SFTConfig(
        output_dir=output_dir,
        
        # Training parameters
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        
        # Learning rate
        learning_rate=learning_rate,
        lr_scheduler_type="cosine",
        warmup_ratio=warmup_ratio,
        
        # Optimizer
        optim="adamw_torch_fused" if torch.cuda.is_available() else "adamw_torch",
        weight_decay=0.01,
        max_grad_norm=1.0,
        
        # Sequence settings
        max_length=MAX_LENGTH,
        packing=False,  # Don't pack examples (each kernel trace is independent)
        
        # Memory optimization
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        
        # Precision
        bf16=torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
        fp16=not (torch.cuda.is_bf16_supported() if torch.cuda.is_available() else True),
        
        # Logging (Weights & Biases)
        logging_steps=10,
        logging_first_step=True,
        report_to=["wandb"],
        run_name=WANDB_RUN_NAME,  # W&B run name
        
        # Evaluation
        eval_strategy="steps",
        eval_steps=20,  # Changed to match save frequency

        # CHECKPOINT SAVING - FIXED
        save_strategy="steps",
        save_steps=20,  # Save every 20 iterations (changed from 100)
        save_total_limit=10,  # Increased to keep more checkpoints
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        
        # Misc
        seed=42,
        dataloader_num_workers=4,
        remove_unused_columns=True,
    )
    
    print("\nTraining Configuration:")
    print(f"  Output directory:     {output_dir}")
    print(f"  Epochs:               {num_epochs}")
    print(f"  Batch size:           {batch_size}")
    print(f"  Gradient accum:       {gradient_accumulation_steps}")
    print(f"  Effective batch:      {batch_size * gradient_accumulation_steps}")
    print(f"  Learning rate:        {learning_rate}")
    print(f"  Max sequence length:  {max_seq_length}")
    print(f"  Gradient checkpoint:  True")
    print(f"  Logging:              Weights & Biases")
    
    return config


training_config = create_training_config()

warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.



Training Configuration:
  Output directory:     /mnt/models/trinity-triton-sft
  Epochs:               3
  Batch size:           1
  Gradient accum:       8
  Effective batch:      8
  Learning rate:        0.0002
  Max sequence length:  4096
  Gradient checkpoint:  True
  Logging:              Weights & Biases


In [18]:
def train(
    model,
    tokenizer,
    dataset: DatasetDict,
    training_config: SFTConfig,
    lora_config: LoraConfig,
    wandb_project: str = WANDB_PROJECT,
):
    """
    Initialize SFTTrainer and start training with W&B logging.
    """
    print("\\n" + "=" * 60)
    print("INITIALIZING SFT TRAINER")
    print("=" * 60)

    # Initialize Weights & Biases
    import wandb
    wandb.init(
        project=wandb_project,
        name=training_config.run_name,
        config={
            "model": MODEL_NAME,
            "lora_r": lora_config.r,
            "lora_alpha": lora_config.lora_alpha,
            "learning_rate": training_config.learning_rate,
            "batch_size": training_config.per_device_train_batch_size,
            "gradient_accumulation": training_config.gradient_accumulation_steps,
            "epochs": training_config.num_train_epochs,
            "max_length": training_config.max_length,
        }
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        args=training_config,
        peft_config=lora_config,
        processing_class=tokenizer,  # Explicitly pass tokenizer for proper saving
    )

    # Print trainable parameters
    trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in trainer.model.parameters())
    print(f"\\nTrainable parameters: {trainable_params:,}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable %: {100 * trainable_params / total_params:.2f}%")

    print("\\n" + "=" * 60)
    print("STARTING TRAINING")
    print("=" * 60)

    # Train! Checkpoints will be saved automatically every 20 steps
    trainer.train()

    # Save the final model explicitly (in addition to checkpoints)
    print("\\nSaving final model...")
    final_output_dir = os.path.join(training_config.output_dir, "final_model")
    trainer.save_model(final_output_dir)
    
    # Also save just the adapter weights for easier loading
    trainer.model.save_pretrained(final_output_dir)

    # Save tokenizer
    tokenizer.save_pretrained(final_output_dir)
    tokenizer.save_pretrained(training_config.output_dir)

    print(f"\\nTraining complete!")
    print(f"Checkpoints saved to: {training_config.output_dir}/checkpoint-*")
    print(f"Final model saved to: {final_output_dir}")

    # Finish W&B run
    wandb.finish()

    return trainer



In [19]:
def main():
    """
    Main training pipeline.
    
    1. Load and filter reasoning traces
    2. Prepare dataset in SFT format
    3. Load Trinity-Mini with quantization
    4. Apply LoRA for efficient finetuning
    5. Train with SFTTrainer
    6. Save the finetuned model
    """
    print("=" * 60)
    print("TRINITY-MINI FINETUNING PIPELINE")
    print("Finetuning on PyTorch → Triton reasoning traces")
    print("=" * 60)
    
    # Step 1: Load traces (includes ALL traces: correct + failed)
    print("\n[Step 1/5] Loading traces...")
    try:
        traces = load_and_filter_traces(TRACES_FILE)  # No filtering - include all
    except FileNotFoundError:
        print("ERROR: No traces file found!")
        print("Please run the orchestrator first to generate traces:")
        print("  cd .. && python orchestrator.py")
        return

    if len(traces) < 10:
        print(f"WARNING: Only {len(traces)} traces available.")
        print("Consider generating more traces for better training.")
    
    # Step 2: Prepare dataset
    print("\n[Step 2/5] Preparing dataset...")
    dataset = prepare_dataset(traces)
    
    # Step 3: Load model
    print("\n[Step 3/5] Loading Trinity-Mini model...")
    model, tokenizer = load_model_and_tokenizer()
    
    # Step 4: Configure LoRA
    print("\n[Step 4/5] Configuring LoRA...")
    lora_config = create_lora_config()
    
    # Step 5: Train
    print("\n[Step 5/5] Starting training...")
    training_config = create_training_config()
    trainer = train(model, tokenizer, dataset, training_config, lora_config)
    
    print("\n" + "=" * 60)
    print("TRAINING COMPLETE!")
    print("=" * 60)
    print(f"Model saved to: {OUTPUT_DIR}")
    print("\nTo use the finetuned model:")
    print(f"""
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

model = AutoPeftModelForCausalLM.from_pretrained("{OUTPUT_DIR}")
tokenizer = AutoTokenizer.from_pretrained("{OUTPUT_DIR}")
""")
    
    return trainer

In [20]:
if __name__ == "__main__":
    main()



TRINITY-MINI FINETUNING PIPELINE
Finetuning on PyTorch → Triton reasoning traces

[Step 1/5] Loading traces...
Loaded 158 total traces

DATASET LOADING SUMMARY
Total traces loaded:       158
  ✓ Correct kernels:       133 (84.2%)
  ✗ Failed kernels:        25 (15.8%)
------------------------------------------------------------
Final training set:        158 traces
  Training mix:            84.2% correct, 15.8% failed


[Step 2/5] Preparing dataset...

[Step 3/5] Loading Trinity-Mini model...
Loading tokenizer for arcee-ai/Trinity-Mini...
Loading model config...
Set config.pad_token_id = 12
Configuring 4-bit quantization...
Loading model arcee-ai/Trinity-Mini...
This may take a few minutes for a 26B parameter model...


Loading weights:   0%|          | 0/12031 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'min_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.
[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from WANDB_API_KEY.


Model loaded successfully!
  Model type: AfmoeForCausalLM
  Device map: N/A

[Step 4/5] Configuring LoRA...
LoRA Configuration:
  Rank (r):          64
  Alpha:             128
  Dropout:           0.05
  Target modules:    ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']

[Step 5/5] Starting training...

Training Configuration:
  Output directory:     /mnt/models/trinity-triton-sft
  Epochs:               3
  Batch size:           1
  Gradient accum:       8
  Effective batch:      8
  Learning rate:        0.0002
  Max sequence length:  4096
  Gradient checkpoint:  True
  Logging:              Weights & Biases
INITIALIZING SFT TRAINER


[34m[1mwandb[0m: Currently logged in as: [33mppbhatt500[0m ([33mppbhatt500-verizon[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.24.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/root/wandb/run-20260204_060831-fh1jdjwn[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mtrinity-sft-run[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/ppbhatt500-verizon/trinity-triton-sft[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/ppbhatt500-verizon/trinity-triton-sft/runs/fh1jdjwn[0m


Tokenizing train dataset:   0%|          | 0/142 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/142 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/16 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/16 [00:00<?, ? examples/s]

Detected kernel version 4.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 3, 'bos_token_id': 0, 'pad_token_id': 12}.


\nTrainable parameters: 2,333,999,104
Total parameters: 15,806,117,632
Trainable %: 14.77%
STARTING TRAINING


Step,Training Loss,Validation Loss
20,2.421357,0.810319


In [30]:
import shutil
from pathlib import Path

source = Path("/mnt/models/trinity-triton-sft")
dest = Path("/mnt/arcee-vol/models/trinity-triton-sft")

dest.parent.mkdir(parents=True, exist_ok=True)
shutil.copytree(source, dest, dirs_exist_ok=True)

print(f"Checkpoints copied: {list(dest.glob('checkpoint-*'))}")

In [31]:
import shutil
from pathlib import Path

source_cp40 = Path("/mnt/models/trinity-triton-sft/checkpoint-40")
dest_cp40 = Path("/mnt/arcee-vol/models/trinity-triton-sft/checkpoint-40")

if source_cp40.exists():
    print(f"Copying checkpoint-40 ({sum(f.stat().st_size for f in source_cp40.rglob('*') if f.is_file()) / 1e9:.2f} GB)...")
    shutil.copytree(source_cp40, dest_cp40, dirs_exist_ok=True)
    print("✓ Done")
else:
    print("❌ checkpoint-40 not found in original location")

In [21]:
def inference(
    pytorch_code: str,
    model_path: str = OUTPUT_DIR,
    max_new_tokens: int = 2048,
):
    """
    Run inference with the finetuned model.
    
    Args:
        pytorch_code: PyTorch code to convert to Triton
        model_path: Path to the finetuned model
        max_new_tokens: Maximum tokens to generate
        
    Returns:
        Generated Triton kernel with reasoning
    """
    from peft import AutoPeftModelForCausalLM
    
    print("Loading finetuned model...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoPeftModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    
    # Format input as conversation
    messages = [
        {
            "role": "user",
            "content": f"""Convert the following PyTorch code to an optimized Triton kernel:

```python
{pytorch_code}
```

Generate a complete Triton implementation that produces the same output as the PyTorch code."""
        }
    ]
    
    # Apply chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    
    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Generate
    print("Generating Triton kernel...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )
    
    # Decode
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    
    return response