# 🚀 Enhanced TRL Training with Kubeflow SDK and Advanced Checkpointing

This notebook demonstrates how to use the **Kubeflow Trainer SDK** to create and manage TrainJobs with enhanced checkpointing capabilities, following the same pattern as `trl-trainjob-clean.yaml`.

## 🎯 Features Demonstrated

- ✅ **Kubeflow SDK Integration**: Programmatic TrainJob creation and management
- ✅ **Enhanced Checkpointing**: Controller-managed progress tracking and model checkpoints
- ✅ **TRL SFTTrainer**: Advanced supervised fine-tuning with LoRA
- ✅ **Distributed Training**: Multi-node coordination with operator-injected variables
- ✅ **Real-Time Monitoring**: Live progress tracking via Kubernetes API
- ✅ **Production Ready**: Fault tolerance and graceful shutdown handling

## 📚 References
- [Kubeflow Trainer SDK](https://github.com/kubeflow/sdk)
- [TRL Documentation](https://huggingface.co/docs/trl/)
- [PEFT Documentation](https://huggingface.co/docs/peft/)


## 📦 Install Dependencies

In [None]:
# Install the Kubeflow SDK from a locally built wheel (replace this path with your own build)
%pip install /Users/abdhumal/Dev/RedHatDev/sdk/dist/kubeflow-0.1.0-py3-none-any.whl

# Install additional dependencies for training
%pip install "transformers[torch]" "trl" "peft" "datasets" "accelerate" "torch" "numpy"
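
The notebook installs these packages without version pins, and the TrainJob pods install their own copies through `packages_to_install`. Recording the versions resolved locally makes it easier to pin the same ones for the pods. The sketch below uses the standard library's `importlib.metadata`. `installed_versions` is a hypothetical helper, and `kubeflow` is assumed to be the distribution name of the wheel installed above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Hypothetical helper: map each distribution name to its installed version (None if missing)."""
    versions: Dict[str, Optional[str]] = {}
    for package in packages:
        try:
            versions[package] = version(package)
        except PackageNotFoundError:
            versions[package] = None
    return versions


for name, ver in installed_versions(
    ["kubeflow", "transformers", "trl", "peft", "datasets", "accelerate", "torch", "numpy"]
).items():
    print(f"{name}: {ver or 'not installed'}")
```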


## 🔧 Setup and Configuration

In [None]:
import json
import time
from datetime import datetime
from kubeflow.trainer import TrainerClient, CustomTrainer, HuggingFaceDatasetInitializer, HuggingFaceModelInitializer, Initializer

# Configuration
NAMESPACE = "abdhumal-test"

# Initialize the Kubeflow TrainerClient for that namespace
trainer_client = TrainerClient(namespace=NAMESPACE)
print(f"🎮 Kubeflow TrainerClient initialized for namespace: {NAMESPACE}")


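## 🐍 Define TRL Training Function (Identical to trl-trainjob.yaml)

This function implements the same training logic as the script embedded in `trl-trainjob.yaml`, with:
- A minimal progress-file writer for the controller (no redundant callbacks)
- Distributed SIGTERM coordination across all ranks
- Controller-managed progress tracking
- SIGTERM-triggered model checkpointing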
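
The training function in the cell below derives PyTorch's standard rank variables from the operator-injected `PET_*` variables in `setup_distributed()`. The arithmetic can be sketched as a pure function. `derive_ranks` is a hypothetical helper for illustration, and it assumes `PET_NPROC_PER_NODE` holds an integer, as the `int()` call in `setup_distributed()` requires.

```python
def derive_ranks(node_rank: int, num_nodes: int, nproc_per_node: int, local_rank: int):
    """Hypothetical helper mirroring setup_distributed(): returns (global_rank, world_size)."""
    if not 0 <= node_rank < num_nodes:
        raise ValueError("node_rank must be in [0, num_nodes)")
    if not 0 <= local_rank < nproc_per_node:
        raise ValueError("local_rank must be in [0, nproc_per_node)")
    # Every node contributes nproc_per_node processes.
    world_size = num_nodes * nproc_per_node
    # Ranks are laid out node by node: node 0 owns [0, nproc), node 1 owns [nproc, 2 * nproc), ...
    global_rank = node_rank * nproc_per_node + local_rank
    return global_rank, world_size


# Two nodes with one process each, as in the num_nodes=2 TrainJob created later.
for node in range(2):
    print(derive_ranks(node_rank=node, num_nodes=2, nproc_per_node=1, local_rank=0))
# (0, 2)
# (1, 2)
```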


In [None]:
# Define the TRL training function that matches trl-trainjob.yaml
def train_gpt2_with_trl_checkpointing(args):
    """
    Advanced TRL training script with controller integration and distributed coordination.
    This matches the exact same training logic as embedded in trl-trainjob.yaml.
    """
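    import os

    # func_args is passed in as the single `args` dict, but the code below reads its
    # configuration from os.environ. Export the dict first (without overriding variables
    # already set on the pod, e.g. by the controller) and before importing
    # transformers/datasets, which read HF_HOME and cache paths at import time.
    for key, value in (args or {}).items():
        os.environ.setdefault(key, str(value))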
    import json
    import time
    import signal
    import torch
    import random
    import numpy
    from numpy.core.multiarray import _reconstruct
    import torch.serialization
    import torch.distributed as dist
    import logging
    from datetime import datetime
    from pathlib import Path
    from datasets import load_dataset, load_from_disk
    from transformers import (
        AutoTokenizer,
        TrainingArguments,
        TrainerState,
        TrainerControl,
        TrainerCallback,
        set_seed,
    )
    from transformers.trainer_utils import get_last_checkpoint
    from trl import (
        ModelConfig,
        ScriptArguments,
        SFTConfig,
        SFTTrainer,
        TrlParser,
        get_peft_config,
        get_quantization_config,
        get_kbit_device_map,
    )
    
    # Safe tensor loading configuration
    torch.serialization.add_safe_globals([_reconstruct, numpy.ndarray, numpy.dtype, numpy.dtypes.UInt32DType])
    
    # Patch torch.load to handle weights_only parameter and device mapping
    original_torch_load = torch.load
    def patched_torch_load(*args, **kwargs):
        if 'weights_only' not in kwargs:
            kwargs['weights_only'] = False
        if 'map_location' not in kwargs:
            if torch.cuda.is_available():
                kwargs['map_location'] = 'cuda'
            else:
                kwargs['map_location'] = 'cpu'
        return original_torch_load(*args, **kwargs)
    torch.load = patched_torch_load
    
    class AdvancedDistributedCheckpointCallback(TrainerCallback):
        """
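        Production-grade distributed SIGTERM handling with tensor-based coordination.
        Any rank that receives SIGTERM sets a flag tensor; an all-reduce (MAX) on every
        step propagates it, so all ranks save a checkpoint and stop together. Rank 0 also
        writes the progress file consumed by the controller.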
        """
        def __init__(self, output_dir: str):
            self.output_dir = output_dir
            self.checkpoint_requested = False
            self.save_triggered = False
            self.checkpoint_stream = None
            self.sigterm_tensor = None
            
            # Use controller-injected checkpoint configuration
            self.checkpoint_enabled = os.environ.get('CHECKPOINT_ENABLED', 'false').lower() == 'true'
            self.checkpoint_uri = os.environ.get('CHECKPOINT_URI', '/workspace/checkpoints')
            
            # Controller progress file (simple)
            self.progress_file = os.environ.get('TRAINING_PROGRESS_FILE', '/workspace/training_progress.json')

        def _log_message(self, message: str):
            """Helper to print messages with a timestamp."""
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            print(f"[{timestamp}] {message}")
        
        def _write_progress(self, state: TrainerState):
            """Simple progress file writer for controller consumption"""
            # Only write from rank 0
            rank = int(os.environ.get('RANK', '0'))
            if rank != 0:
                return
                
            try:
                # Extract metrics from trainer state
                latest_loss = 0.0
                latest_lr = 0.0
                if state.log_history:
                    latest_log = state.log_history[-1]
                    latest_loss = latest_log.get('loss', latest_log.get('train_loss', latest_log.get('training_loss', 0.0)))
                    latest_lr = latest_log.get('learning_rate', latest_log.get('lr', latest_log.get('train_lr', 0.0)))
                
                progress_data = {
                    "epoch": int(state.epoch) if state.epoch else 1,
                    "totalEpochs": int(state.num_train_epochs) if state.num_train_epochs else 1,
                    "step": state.global_step,
                    "totalSteps": state.max_steps,
                    "loss": f"{latest_loss:.4f}",
                    "learningRate": f"{latest_lr:.6f}",
                    "percentComplete": f"{(state.global_step / state.max_steps * 100):.1f}" if state.max_steps > 0 else "0.0",
                    "lastUpdateTime": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
                }
                
                # Atomic write
                temp_file = self.progress_file + '.tmp'
                with open(temp_file, 'w') as f:
                    json.dump(progress_data, f, indent=2)
                os.rename(temp_file, self.progress_file)
                os.chmod(self.progress_file, 0o644)
                
            except Exception as e:
                pass  # Silent fail - controller handles missing files gracefully

        def _init_distributed_signal_tensor(self):
            """Initialize tensor for distributed SIGTERM signaling."""
            try:
                if dist.is_initialized():
                    device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device('cpu')
                    self.sigterm_tensor = torch.zeros(1, dtype=torch.float32, device=device)
                    self._log_message(f"Initialized distributed SIGTERM tensor on device: {device}")
                else:
                    self._log_message("Distributed training not initialized - using local SIGTERM handling only")
            except Exception as e:
                self._log_message(f"Failed to initialize distributed SIGTERM tensor: {e}. Using local handling only.")

        def _check_distributed_sigterm(self):
            """Check if any rank has received SIGTERM."""
            try:
                if dist.is_initialized() and self.sigterm_tensor is not None:
                    dist.all_reduce(self.sigterm_tensor, op=dist.ReduceOp.MAX)
                    return self.sigterm_tensor.item() > 0.5
            except Exception as e:
                self._log_message(f"Distributed SIGTERM check failed: {e}. Using local signal only.")
            return self.checkpoint_requested

        def _sigterm_handler(self, signum, frame):
            """Sets a flag and updates the tensor to indicate that a SIGTERM signal was received."""
            rank = os.environ.get("RANK", "-1")
            self._log_message(f"Rank {rank}: SIGTERM received, flagging for distributed checkpoint.")
            self.checkpoint_requested = True
            if self.sigterm_tensor is not None:
                self.sigterm_tensor.fill_(1.0)

        def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
            rank = os.environ.get("RANK", "-1")
            os.makedirs(self.output_dir, exist_ok=True)
            self._init_distributed_signal_tensor()
            
            if torch.cuda.is_available():
                self.checkpoint_stream = torch.cuda.Stream()
                self._log_message(f"Rank {rank}: Created dedicated CUDA stream for checkpointing.")

            signal.signal(signal.SIGTERM, self._sigterm_handler)
            self._log_message(f"Rank {rank}: Advanced distributed SIGTERM handler registered.")

            try:
                if dist.is_initialized():
                    dist.barrier()
                    self._log_message(f"Rank {rank}: Distributed coordination setup synchronized across all ranks")
            except Exception as e:
                self._log_message(f"Rank {rank}: Failed to synchronize distributed setup: {e}")

        def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
            # Write progress for controller (every logging step)
            if state.global_step % args.logging_steps == 0:
                self._write_progress(state)
                
            if self._check_distributed_sigterm() and not self.save_triggered:
                rank = os.environ.get("RANK", "-1")
                self._log_message(f"Rank {rank}: Distributed SIGTERM detected, initiating checkpoint at step {state.global_step}.")
                self.save_triggered = True
                control.should_save = True
                control.should_training_stop = True

        def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
            # Write final progress
            self._write_progress(state)
            
            rank = os.environ.get("RANK", "-1")
            if rank == "0" and self.checkpoint_requested:
                self._log_message(f"Rank {rank}: Training ended due to distributed SIGTERM checkpoint request.")

        def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
            rank = os.environ.get("RANK", "-1")
            if rank == "0":
                self._log_message(f"Rank {rank}: Checkpoint save completed.")
                if self.checkpoint_requested:
                    self._log_message(f"Rank {rank}: Distributed SIGTERM-triggered checkpoint save finished successfully.")
    
    def setup_distributed():
        """Initialize distributed training using operator-injected PET environment variables"""
        # Use PET_* environment variables injected by the training operator
        node_rank = int(os.getenv('PET_NODE_RANK', '0'))
        num_nodes = int(os.getenv('PET_NNODES', '1'))
        nproc_per_node = int(os.getenv('PET_NPROC_PER_NODE', '1'))
        master_addr = os.getenv('PET_MASTER_ADDR', 'localhost')
        master_port = os.getenv('PET_MASTER_PORT', '29500')
        
        # Calculate standard PyTorch distributed variables
        local_rank = int(os.getenv('LOCAL_RANK', '0'))
        world_size = num_nodes * nproc_per_node
        global_rank = node_rank * nproc_per_node + local_rank
        
        # Set standard PyTorch environment variables for compatibility
        os.environ['RANK'] = str(global_rank)
        os.environ['WORLD_SIZE'] = str(world_size)
        os.environ['LOCAL_RANK'] = str(local_rank)
        os.environ['MASTER_ADDR'] = master_addr
        os.environ['MASTER_PORT'] = master_port
        
        # Initialize distributed training if world_size > 1
        if world_size > 1:
            try:
                torch.distributed.init_process_group(
                    backend='gloo',
                    rank=global_rank,
                    world_size=world_size
                )
                torch.distributed.barrier()
            except Exception as e:
                print(f"Warning: Failed to initialize distributed training: {e}")
        
        return local_rank, global_rank, world_size
    
    def load_dataset_from_initializer():
        """Load dataset from V2 initializer or fallback to download"""
        dataset_dir = Path("/workspace/dataset")
        
        if dataset_dir.exists() and any(dataset_dir.iterdir()):
            try:
                full_dataset = load_from_disk(str(dataset_dir))
                if isinstance(full_dataset, dict):
                    train_dataset = full_dataset.get('train', full_dataset.get('train[:100]'))
                    test_dataset = full_dataset.get('test', full_dataset.get('test[:20]'))
                else:
                    # Split dataset if it's not already split
                    train_size = min(100, len(full_dataset) - 20)
                    train_dataset = full_dataset.select(range(train_size))
                    test_dataset = full_dataset.select(range(train_size, min(train_size + 20, len(full_dataset))))
                
                return train_dataset, test_dataset
            except Exception as e:
                print(f"Failed to load from initializer: {e}")
        
        # Fallback to direct download
        dataset_name = os.getenv('DATASET_NAME', 'tatsu-lab/alpaca')
        train_split = os.getenv('DATASET_TRAIN_SPLIT', 'train[:100]')
        test_split = os.getenv('DATASET_TEST_SPLIT', 'train[100:120]')
        
        train_dataset = load_dataset(dataset_name, split=train_split)
        test_dataset = load_dataset(dataset_name, split=test_split)
        
        return train_dataset, test_dataset
    
    def load_model_from_initializer():
        """Load model and tokenizer from V2 initializer or fallback to download"""
        model_dir = Path("/workspace/model")
        
        if model_dir.exists() and any(model_dir.iterdir()):
            model_path = str(model_dir)
        else:
            model_path = os.getenv('MODEL_NAME', 'gpt2')
        
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            
            # Set up chat template for instruction following
            if tokenizer.chat_template is None:
                tokenizer.chat_template = (
                    "{% for message in messages %}"
                    "{% if message['role'] == 'user' %}"
                    "### Instruction:\n{{ message['content'] }}\n"
                    "{% elif message['role'] == 'assistant' %}"
                    "### Response:\n{{ message['content'] }}{{ eos_token }}\n"
                    "{% endif %}"
                    "{% endfor %}"
                )
            
            return model_path, tokenizer
            
        except Exception as e:
            print(f"Error loading model: {e}")
            # Fallback to gpt2
            model_path = 'gpt2'
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            return model_path, tokenizer
    
    def prepare_datasets(train_dataset, test_dataset, tokenizer):
        """Prepare datasets for training with proper formatting"""
        def template_dataset(sample):
            # Handle different dataset formats
            if 'instruction' in sample and 'output' in sample:
                # Alpaca format
                messages = [
                    {"role": "user", "content": sample['instruction']},
                    {"role": "assistant", "content": sample['output']},
                ]
            elif 'question' in sample and 'answer' in sample:
                # GSM8K format
                messages = [
                    {"role": "user", "content": sample['question']},
                    {"role": "assistant", "content": sample['answer']},
                ]
            else:
                # Fallback format
                content = str(sample.get('text', sample.get('content', 'Sample text')))
                messages = [
                    {"role": "user", "content": "Complete this text:"},
                    {"role": "assistant", "content": content},
                ]
            
            return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}
        
        # Drop the original columns; an existing 'text' column is overwritten by the templated output
        train_columns = [col for col in train_dataset.column_names if col != 'text']
        train_dataset = train_dataset.map(template_dataset, remove_columns=train_columns)
        
        if test_dataset is not None:
            test_columns = [col for col in test_dataset.column_names if col != 'text']
            test_dataset = test_dataset.map(template_dataset, remove_columns=test_columns)
        
        return train_dataset, test_dataset
    
    def get_training_parameters():
        """Get training parameters from controller and environment variables"""
        # Use controller-injected checkpoint configuration
        checkpoint_dir = Path(os.getenv('CHECKPOINT_URI', '/workspace/checkpoints'))
        checkpoint_enabled = os.getenv('CHECKPOINT_ENABLED', 'false').lower() == 'true'
        checkpoint_interval = os.getenv('CHECKPOINT_INTERVAL', '30s')
        max_checkpoints = int(os.getenv('CHECKPOINT_MAX_RETAIN', '5'))
        
        # Training hyperparameters from environment (with sensible defaults)
        parameters = {
            'model_name_or_path': os.getenv('MODEL_NAME', 'gpt2'),
            'model_revision': 'main',
            'torch_dtype': 'bfloat16',
            'use_peft': True,
            'lora_r': int(os.getenv('LORA_R', '16')),
            'lora_alpha': int(os.getenv('LORA_ALPHA', '32')),
            'lora_dropout': float(os.getenv('LORA_DROPOUT', '0.1')),
            'lora_target_modules': ['c_attn', 'c_proj'],  # GPT-2 specific
            'dataset_name': os.getenv('DATASET_NAME', 'tatsu-lab/alpaca'),
            'dataset_config': 'main',
            'dataset_train_split': os.getenv('DATASET_TRAIN_SPLIT', 'train[:100]'),
            'dataset_test_split': os.getenv('DATASET_TEST_SPLIT', 'train[100:120]'),
            'max_seq_length': int(os.getenv('MAX_SEQ_LENGTH', '512')),
            'num_train_epochs': int(os.getenv('MAX_EPOCHS', '3')),
            'per_device_train_batch_size': int(os.getenv('BATCH_SIZE', '2')),
            'per_device_eval_batch_size': int(os.getenv('BATCH_SIZE', '2')),
            'eval_strategy': 'steps',
            'eval_steps': int(os.getenv('EVAL_STEPS', '25')),
            'bf16': torch.cuda.is_available(),  # Only use bf16 if CUDA is available
            'fp16': not torch.cuda.is_available(),  # Use fp16 for CPU training
            'learning_rate': float(os.getenv('LEARNING_RATE', '5e-5')),
            'warmup_steps': int(os.getenv('WARMUP_STEPS', '10')),
            'lr_scheduler_type': 'cosine',
            'optim': 'adamw_torch',
            'max_grad_norm': 1.0,
            'seed': 42,
            'gradient_accumulation_steps': int(os.getenv('GRADIENT_ACCUMULATION_STEPS', '4')),
            'save_strategy': 'steps',
            'save_steps': int(os.getenv('SAVE_STEPS', '20')),
            'save_total_limit': max_checkpoints if checkpoint_enabled else None,
            'logging_strategy': 'steps',
            'logging_steps': int(os.getenv('LOGGING_STEPS', '5')),
            'report_to': [],
            'output_dir': str(checkpoint_dir),
        }
        
        return parameters
    
    # Main training flow with controller integration and distributed coordination
    
    # Setup distributed training
    local_rank, global_rank, world_size = setup_distributed()
    
    # Create necessary directories
    os.makedirs("/workspace/cache/transformers", exist_ok=True)
    os.makedirs("/workspace/cache", exist_ok=True)
    os.makedirs("/workspace/cache/datasets", exist_ok=True)
    
    # Get training parameters
    parameters = get_training_parameters()
    checkpoint_dir = Path(parameters['output_dir'])
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Parse configuration using TrlParser for robust handling
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_dict(parameters)
    
    set_seed(training_args.seed)
    
    # Load components using V2 initializers
    model_path, tokenizer = load_model_from_initializer()
    train_dataset, test_dataset = load_dataset_from_initializer()
    train_dataset, test_dataset = prepare_datasets(train_dataset, test_dataset, tokenizer)
    
    # Initialize trainer with advanced callbacks
    callbacks = [
        AdvancedDistributedCheckpointCallback(str(checkpoint_dir))  # Advanced distributed coordination
    ]
    
    # Initialize SFTTrainer with controller integration
    trainer = SFTTrainer(
        model=model_path,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        peft_config=get_peft_config(model_args),
        processing_class=tokenizer,
        callbacks=callbacks,
    )
    
    # Print trainable parameters info
    if trainer.accelerator.is_main_process and hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()
    
    # Check for resume from checkpoint (controller-managed)
    checkpoint = get_last_checkpoint(training_args.output_dir)
    if checkpoint is not None:
        # Validate checkpoint compatibility before resuming
        try:
            checkpoint_files = os.listdir(checkpoint)
            # Check if this is a valid checkpoint directory
            if 'trainer_state.json' not in checkpoint_files:
                checkpoint = None
        except Exception as e:
            print(f"Checkpoint validation failed: {e}")
            checkpoint = None
    
    # Start training
    try:
        trainer.train(resume_from_checkpoint=checkpoint)
    except Exception as e:
        print(f"Training failed: {e}")
        # If checkpoint loading failed, try without checkpoint
        if checkpoint is not None:
            try:
                trainer.train(resume_from_checkpoint=None)
            except Exception as retry_e:
                print(f"Training failed even from scratch: {retry_e}")
                raise retry_e
        else:
            raise
    
    # Save final model
    trainer.save_model(training_args.output_dir)

print("✅ TRL training function defined (matches trl-trainjob.yaml logic)")
print("📋 Available function: train_gpt2_with_trl_checkpointing(args)")
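
The callback's `_write_progress` writes to a temporary file and then renames it, so the controller never reads a half-written JSON file. A standalone sketch of that pattern follows. `write_progress_atomically` is a hypothetical helper that uses a subset of the callback's field names, and it swaps in `os.replace` for `os.rename` because `os.replace` also overwrites an existing target on Windows.

```python
import json
import os
import tempfile
import time


def write_progress_atomically(path: str, step: int, total_steps: int, loss: float, lr: float) -> dict:
    """Hypothetical helper using a subset of the callback's progress-file fields."""
    progress = {
        "step": step,
        "totalSteps": total_steps,
        "loss": f"{loss:.4f}",
        "learningRate": f"{lr:.6f}",
        "percentComplete": f"{step / total_steps * 100:.1f}" if total_steps > 0 else "0.0",
        "lastUpdateTime": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
    # Create the temp file next to the target so the rename never crosses filesystems.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(progress, f, indent=2)
        # Readers see either the old file or the complete new one, never a partial write.
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the temp file if writing or renaming failed.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise
    # mkstemp creates files as 0o600; match the callback and make the file world-readable.
    os.chmod(path, 0o644)
    return progress


with tempfile.TemporaryDirectory() as workdir:
    progress_path = os.path.join(workdir, "training_progress.json")
    write_progress_atomically(progress_path, step=50, total_steps=200, loss=1.23456, lr=5e-5)
    with open(progress_path) as f:
        print(json.load(f)["percentComplete"])  # 25.0
```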


## 🐍 Load Exact Training Script from trl-trainjob.yaml

This cell extracts the **identical** training script embedded in `trl-trainjob.yaml` for comparison with the function defined above.


In [None]:
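# Extract the exact training script from trl-trainjob.yaml
import subprocess
import textwrap

# Local checkout of the training-operator repository (adjust to your environment)
TRAINING_OPERATOR_DIR = "/Users/abdhumal/Dev/RedHatDev/training-operator"

def extract_training_script_from_yaml():
    """Extract the embedded Python script from trl-trainjob.yaml."""
    cmd = ["awk", "/advanced_trl_training.py: \\|/{flag=1; next} /^---$/{flag=0} flag",
           "examples/checkpointing/trl-trainjob.yaml"]
    # check=True surfaces a missing file instead of silently returning empty output
    result = subprocess.run(cmd, capture_output=True, text=True, check=True, cwd=TRAINING_OPERATOR_DIR)
    # The script is stored as an indented YAML block scalar; strip the common indentation
    return textwrap.dedent(result.stdout)

# Get the exact script content
script_content = extract_training_script_from_yaml()
if not script_content.strip():
    raise RuntimeError("No script extracted: check the YAML path and the advanced_trl_training.py key")
print(f"✅ Extracted {len(script_content)} characters from trl-trainjob.yaml embedded script")

# Syntax-check the script instead of exec()-ing it: exec() would run the training script
# inside this notebook kernel, while the TrainJob below uses train_gpt2_with_trl_checkpointing
compile(script_content, "advanced_trl_training.py", "exec")
print("✅ Embedded script compiles")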
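
If `awk` is unavailable, the embedded script can also be pulled out with plain Python. This is a sketch under assumptions: the script is stored as a YAML literal block scalar under the `advanced_trl_training.py: |` key that the `awk` pattern above looks for, and its body is every following line indented deeper than that key. `extract_block_scalar` is a hypothetical helper and the YAML below is a made-up sample; a real YAML parser such as PyYAML would be more robust for unusual layouts.

```python
import textwrap


def extract_block_scalar(yaml_text: str, key: str) -> str:
    """Hypothetical helper: return the dedented body of `key: |` from YAML text."""
    lines = yaml_text.splitlines()
    for i, line in enumerate(lines):
        if line.strip() == f"{key}: |":
            key_indent = len(line) - len(line.lstrip())
            body = []
            for candidate in lines[i + 1:]:
                # Blank lines belong to the block; a line indented no deeper than the key ends it.
                if candidate.strip() and len(candidate) - len(candidate.lstrip()) <= key_indent:
                    break
                body.append(candidate)
            # Drop trailing blank lines, then strip the common indentation.
            while body and not body[-1].strip():
                body.pop()
            return textwrap.dedent("\n".join(body)) + "\n"
    raise KeyError(f"{key!r} not found as a literal block scalar")


sample = """apiVersion: v1
kind: ConfigMap
data:
  advanced_trl_training.py: |
    import os

    def main():
        print(os.getcwd())
---
kind: TrainJob
"""
script = extract_block_scalar(sample, "advanced_trl_training.py")
compile(script, "advanced_trl_training.py", "exec")  # syntax check without running it
print(script)
```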


## 🚀 Create TrainJob Using Kubeflow SDK (Matching trl-trainjob.yaml)

Now we'll use the **Kubeflow SDK** to create a TrainJob that matches the same procedure as `trl-trainjob.yaml`:
1. **CustomTrainer** with the TRL training function
2. **Initializer** for dataset and model (V2 initializers)
3. **TrainJob** with identical configuration and checkpointing


In [None]:
# Step 1: Configure the CustomTrainer with TRL training function
training_args = {
    # Training hyperparameters (matching trl-trainjob.yaml)
    "LEARNING_RATE": "5e-5",
    "BATCH_SIZE": "1",
    "MAX_EPOCHS": "10",
    "WARMUP_STEPS": "5",
    "EVAL_STEPS": "10",
    "SAVE_STEPS": "5",
    "LOGGING_STEPS": "2",
    "GRADIENT_ACCUMULATION_STEPS": "2",
    
    # Model configuration
    "MODEL_NAME": "gpt2",
    "LORA_R": "16",
    "LORA_ALPHA": "32",
    "LORA_DROPOUT": "0.1",
    "MAX_SEQ_LENGTH": "512",
    
    # Dataset configuration
    "DATASET_NAME": "tatsu-lab/alpaca",
    "DATASET_TRAIN_SPLIT": "train[:500]",
    "DATASET_TEST_SPLIT": "train[500:520]",
    
    # Checkpointing configuration (will be injected by controller)
    "CHECKPOINT_URI": "/workspace/checkpoints",
    
    # Cache directories
    "PYTHONUNBUFFERED": "1",
    "TRANSFORMERS_CACHE": "/workspace/cache/transformers",
    "HF_HOME": "/workspace/cache",
    "HF_DATASETS_CACHE": "/workspace/cache/datasets",
    
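    # Distributed training debug (NCCL_* settings only take effect with the NCCL backend;
    # setup_distributed() initializes the process group with gloo)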
    "NCCL_DEBUG": "INFO",
    "NCCL_DEBUG_SUBSYS": "ALL",
    "NCCL_SOCKET_IFNAME": "eth0",
    "NCCL_IB_DISABLE": "1",
    "NCCL_P2P_DISABLE": "1",
    "NCCL_TREE_THRESHOLD": "0",
    "TORCH_DISTRIBUTED_DEBUG": "INFO",
    "TORCH_SHOW_CPP_STACKTRACES": "1",
}

# Create CustomTrainer configuration
custom_trainer = CustomTrainer(
    func=train_gpt2_with_trl_checkpointing,
    func_args=training_args,
    num_nodes=2,  # Distributed training across 2 nodes (matching YAML)
    resources_per_node={
        "cpu": "2",
        "memory": "4Gi",
        # Uncomment for GPU training:
        # "nvidia.com/gpu": "1",
    },
    packages_to_install=[
        "transformers[torch]",
        "trl", 
        "peft", 
        "datasets", 
        "accelerate",
        "torch",
        "numpy"
    ],
    env={
        "PYTHONUNBUFFERED": "1",
        "NCCL_DEBUG": "INFO",
        "TORCH_DISTRIBUTED_DEBUG": "INFO",
    }
)

print("✅ CustomTrainer configured with TRL training function")
print(f"🎯 Distributed training: {custom_trainer.num_nodes} nodes")
print(f"🔧 Training arguments (func_args): {len(training_args)} configured")
print(f"📦 Packages to install: {len(custom_trainer.packages_to_install)} packages")
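
The training function reads its configuration from `os.environ`, so the `func_args` dict only takes effect once its values are exported inside the function. This assumes the SDK passes `func_args` as the function's single positional argument, which is what its one-parameter `args` signature suggests. A minimal sketch with a hypothetical `export_func_args` helper that leaves variables already set on the pod (for example by the controller) untouched; the `/mnt/checkpoints` value is purely illustrative.

```python
import os
from typing import Mapping, MutableMapping, Optional


def export_func_args(
    args: Optional[Mapping[str, object]],
    environ: MutableMapping[str, str] = os.environ,
) -> None:
    """Hypothetical helper: copy func_args into the environment as strings.

    Values already present in `environ` take precedence over func_args.
    """
    for key, value in (args or {}).items():
        environ.setdefault(key, str(value))


# Simulate a pod environment where the controller already set CHECKPOINT_URI.
fake_env = {"CHECKPOINT_URI": "/mnt/checkpoints"}
export_func_args({"MAX_EPOCHS": "10", "CHECKPOINT_URI": "/workspace/checkpoints"}, fake_env)
print(fake_env)
# {'CHECKPOINT_URI': '/mnt/checkpoints', 'MAX_EPOCHS': '10'}
```

Using `setdefault` rather than plain assignment is a design choice: it lets controller-injected values win, while notebook-supplied values fill in everything else.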


In [None]:
# Step 2: Configure Initializers (matching trl-trainjob.yaml V2 initializers)
initializer = Initializer(
    dataset=HuggingFaceDatasetInitializer(
        storage_uri="hf://tatsu-lab/alpaca"
    ),
    model=HuggingFaceModelInitializer(
        storage_uri="hf://gpt2"
    )
)

print("✅ Initializer configured with V2 dataset and model initializers")
print(f"📊 Dataset: {initializer.dataset.storage_uri}")
print(f"🤖 Model: {initializer.model.storage_uri}")


In [None]:
# Step 3: Create TrainJob using Kubeflow SDK (matching trl-trainjob.yaml procedure)
job_name = trainer_client.train(
    trainer=custom_trainer,
    initializer=initializer,
    labels={
        "app.kubernetes.io/name": "trl-demo",
        "app.kubernetes.io/component": "training",
        "experiment": "advanced-controller-checkpointing"
    },
    annotations={
        "training.kubeflow.org/description": "TRL GPT-2 fine-tuning with advanced checkpointing using Kubeflow SDK"
    }
)

print("✅ TrainJob created successfully using the Kubeflow SDK!")
print(f"📋 Job Name: {job_name}")
print("🎯 This TrainJob follows the same procedure as trl-trainjob.yaml:")
print("   - V2 initializers for the dataset and model")
print(f"   - Distributed training across {custom_trainer.num_nodes} nodes")
print("   - Advanced TRL training with checkpointing")
print(f"   - Controller-managed progress tracking")
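The `labels` passed to `train()` are presumably applied as Kubernetes labels on the TrainJob. If so, they must follow Kubernetes label syntax. A key is an optional DNS-subdomain prefix plus `/`, followed by a name of at most 63 characters. The name must start and end with an alphanumeric character and may contain `-`, `_` and `.` in between. Values follow the same rule and may also be empty. Annotation values, by contrast, are free-form text.

The sketch below checks the labels used above. The validators are hypothetical helpers, and the Kubernetes API server remains the final authority.

```python
import re

# Name segment / value syntax: at most 63 chars, alphanumeric at both ends,
# with '-', '_' or '.' allowed in between.
_NAME = re.compile(r"^[A-Za-z0-9]([A-Za-z0-9_.-]{0,61}[A-Za-z0-9])?$")
# Optional key prefix: a DNS subdomain made of lowercase DNS labels separated by dots.
_DNS_LABEL = re.compile(r"^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?$")


def is_valid_label_key(key: str) -> bool:
    prefix, _, name = key.rpartition("/")
    if prefix:
        if len(prefix) > 253 or not all(_DNS_LABEL.match(p) for p in prefix.split(".")):
            return False
    return bool(_NAME.match(name))


def is_valid_label_value(value: str) -> bool:
    # Empty values are allowed for labels.
    return value == "" or bool(_NAME.match(value))


labels = {
    "app.kubernetes.io/name": "trl-demo",
    "app.kubernetes.io/component": "training",
    "experiment": "advanced-controller-checkpointing",
}
for key, value in labels.items():
    ok = is_valid_label_key(key) and is_valid_label_value(value)
    print(f"{key}={value}: {'ok' if ok else 'INVALID'}")
```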


## 📊 Monitor Training Progress Using Kubeflow SDK

Monitor the TrainJob progress with the Kubeflow SDK. This reports the same kind of status information as `kubectl get trainjob`.


In [None]:
def monitor_training_with_sdk(duration_minutes=10, interval_seconds=30):
    """Monitor training progress using Kubeflow SDK"""
    from datetime import datetime, timedelta
    
    end_time = datetime.now() + timedelta(minutes=duration_minutes)
    
    print(f"🔍 Monitoring TrainJob '{job_name}' for {duration_minutes} minutes")
    print(f"🔄 Checking every {interval_seconds} seconds using Kubeflow SDK")
    print("=" * 60)
    
    while datetime.now() < end_time:
        print(f"\n⏰ {datetime.now().strftime('%H:%M:%S')} - Checking status...")
        
        try:
            # Get TrainJob status using SDK
            trainjob = trainer_client.get_job(job_name)
            
            print(f"🎯 TrainJob Status:")
            print(f"   Name: {trainjob.name}")
            print(f"   Status: {trainjob.status}")
            print(f"   Nodes: {trainjob.num_nodes}")
            print(f"   Runtime: {trainjob.runtime.name}")
            
            if trainjob.steps:
                print(f"\n📋 Steps:")
                for step in trainjob.steps:
                    print(f"   - {step.name}: {step.status} ({step.device})")
            
            # Check if training is complete
            if trainjob.status in ['Complete', 'Failed']:
                print(f"\n🏁 Training finished with status: {trainjob.status}")
                break
                
        except Exception as e:
            print(f"⚠️  Error getting job status: {e}")
        
        print(f"\n⏳ Waiting {interval_seconds} seconds...")
        time.sleep(interval_seconds)
    
    print("\n✅ Monitoring completed")

# Start monitoring
monitor_training_with_sdk(duration_minutes=15, interval_seconds=30)
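The monitoring loop above follows a generic "poll until a terminal status or a timeout" pattern. If you factor it out and make the status function and clock injectable (both hypothetical parameters here), you can test the loop without a cluster. In the notebook, `get_status` could be `lambda: trainer_client.get_job(job_name).status`.

```python
import time
from typing import Callable, Iterable, Optional


def poll_until(
    get_status: Callable[[], str],
    terminal: Iterable[str] = ("Complete", "Failed"),
    timeout_s: float = 900.0,
    interval_s: float = 30.0,
    clock: Callable[[], float] = time.monotonic,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[str]:
    """Poll get_status() until it returns a terminal status; return None on timeout."""
    terminal_set = set(terminal)
    deadline = clock() + timeout_s
    while clock() < deadline:
        try:
            status = get_status()
        except Exception as exc:  # e.g. a transient API error: report it and keep polling
            print(f"⚠️  Error getting job status: {exc}")
        else:
            if status in terminal_set:
                return status
        sleep(interval_s)
    return None


class FakeClock:
    """Deterministic replacement for time.monotonic/time.sleep in examples and tests."""

    def __init__(self) -> None:
        self.now = 0.0

    def time(self) -> float:
        return self.now

    def sleep(self, seconds: float) -> None:
        self.now += seconds


# A fake job that is Created, then Running, then Complete on the third poll.
statuses = iter(["Created", "Running", "Complete"])
fake = FakeClock()
result = poll_until(lambda: next(statuses), clock=fake.time, sleep=fake.sleep)
print(f"Final status: {result} after {fake.now:.0f} simulated seconds")
```

Using `time.monotonic` rather than `datetime.now()` for the deadline keeps the timeout correct even if the wall clock is adjusted while the loop is running.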


## 📜 View Training Logs Using Kubeflow SDK


In [None]:
# Get training logs using Kubeflow SDK
try:
    print(f"📜 Getting logs for TrainJob: {job_name}")
    
    # Get logs from the training nodes
    logs = trainer_client.get_job_logs(job_name, follow=False)
    
    print("\n" + "="*80)
    print("TRAINING LOGS")
    print("="*80)
    
    # Display logs from all nodes
    for node_name, node_logs in logs.items():
        print(f"\n--- {node_name.upper()} LOGS ---")
        # Display last 50 lines of logs
        log_lines = node_logs.split('\n')
        for line in log_lines[-50:]:
            if line.strip():
                print(line)
    
    print("\n" + "="*80)
    
except Exception as e:
    print(f"❌ Error getting logs: {e}")
    print("Note: Logs may not be available yet if training is still starting up")
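The cell above prints only the tail of each node's log. Below is the same trimming as a small reusable function. Like the cell, it assumes `get_job_logs(..., follow=False)` returns a mapping from node name to log text. Other SDK releases may return an iterator of lines instead, so check the version you have installed.

`tail_logs` is a hypothetical helper. Unlike the cell, it drops blank lines *before* counting, so it always shows up to `max_lines` lines of actual content.

```python
from typing import Dict, List


def tail_logs(logs: Dict[str, str], max_lines: int = 50) -> Dict[str, List[str]]:
    """Keep the last `max_lines` non-empty lines of each node's log text."""
    tails = {}
    for node_name, text in logs.items():
        lines = [line for line in text.splitlines() if line.strip()]
        # Guard max_lines <= 0, because lines[-0:] would return every line.
        tails[node_name] = lines[-max_lines:] if max_lines > 0 else []
    return tails


sample = {"node-0": "step 1\n\nstep 2\nstep 3\n", "node-1": ""}
for node, lines in tail_logs(sample, max_lines=2).items():
    print(f"--- {node.upper()} ({len(lines)} lines) ---")
    print("\n".join(lines))
```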


## 🧹 Cleanup Using Kubeflow SDK


In [None]:
# Clean up the TrainJob when done
def cleanup_trainjob():
    """Clean up the TrainJob using Kubeflow SDK"""
    try:
        trainer_client.delete_job(job_name)
        print(f"✅ TrainJob '{job_name}' deleted successfully")
    except Exception as e:
        print(f"❌ Error deleting TrainJob: {e}")

# Get final job status before cleanup
try:
    final_job = trainer_client.get_job(job_name)
    print(f"📋 Final TrainJob Status:")
    print(f"   Name: {final_job.name}")
    print(f"   Status: {final_job.status}")
    print(f"   Created: {final_job.creation_timestamp}")
    print(f"   Nodes: {final_job.num_nodes}")
    print(f"   Runtime: {final_job.runtime.name}")
    
    if final_job.steps:
        print(f"   Steps:")
        for step in final_job.steps:
            print(f"     - {step.name}: {step.status}")
            
except Exception as e:
    print(f"❌ Error getting final job status: {e}")

print(f"\n💡 To delete the TrainJob, run: cleanup_trainjob()")
print(f"Current TrainJob ID: {job_name}")


## 📋 Summary

This notebook demonstrates how to use the **Kubeflow SDK** to create TrainJobs that follow the same procedure as `trl-trainjob.yaml`:

### 🎯 **Key Features Implemented**:

1. **Kubeflow SDK Integration**: Used `TrainerClient` with `CustomTrainer` instead of direct Kubernetes API calls
2. **TRL-based fine-tuning** with GPT-2 and Alpaca dataset for instruction following
3. **LoRA (Low-Rank Adaptation)** for parameter-efficient fine-tuning  
4. **Advanced checkpointing** with SIGTERM handling and distributed coordination
5. **Distributed training** across multiple nodes with automatic environment setup
6. **V2 Initializers** for reproducible dataset and model preparation using `HuggingFaceDatasetInitializer` and `HuggingFaceModelInitializer`
7. **Real-time monitoring** using SDK methods instead of Kubernetes API
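The SIGTERM coordination in `SigtermCheckpointCallback` relies on a simple property of `all_reduce` with `ReduceOp.MAX`. If any rank has set its flag to 1.0, every rank sees 1.0 after the reduction, so all ranks agree to save and stop at the same step.

The sketch below simulates that agreement in plain Python, without `torch.distributed`. Each rank's flag is a list element, and `all_reduce_max` is a hypothetical stand-in for the real collective.

```python
from typing import List


def all_reduce_max(flags: List[float]) -> List[float]:
    """Simulate dist.all_reduce(t, op=ReduceOp.MAX): every rank receives the maximum."""
    reduced = max(flags)
    return [reduced] * len(flags)


def should_checkpoint(flags: List[float]) -> List[bool]:
    # Mirrors `self.sigterm_tensor.item() > 0.5` in the callback.
    return [value > 0.5 for value in all_reduce_max(flags)]


# Four ranks; only rank 2 received SIGTERM (its handler called fill_(1.0)).
print(should_checkpoint([0.0, 0.0, 1.0, 0.0]))
# No rank received SIGTERM, so training continues everywhere.
print(should_checkpoint([0.0, 0.0, 0.0, 0.0]))
```

Because the real `all_reduce` is a collective, every rank must call it at the same point. The callback does this by running it in `on_step_end` on each step; if one rank skipped the call, the others would block.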

### 🔧 **Technical Implementation**:

- **CustomTrainer**: Encapsulates the TRL training function with all dependencies
- **Initializer**: Configures V2 dataset and model initializers (matching YAML)
- **SDK Methods**: Uses `trainer_client.train()`, `get_job()`, `get_job_logs()`, `delete_job()`
- **Distributed coordination**: SIGTERM tensor for cross-node communication
- **Automatic checkpoint resumption**: Detection and loading of latest checkpoints
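Checkpoint resumption uses `transformers.trainer_utils.get_last_checkpoint`, which selects the `checkpoint-<step>` subdirectory with the highest step number. The function below is a simplified re-implementation for illustration only (`find_last_checkpoint` is hypothetical); use the real function in training code. The example shows why the comparison must be numeric: compared as strings, `checkpoint-50` would sort after `checkpoint-150`.

```python
import os
import re
import tempfile
from typing import Optional

_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint-<N> directory with the largest N, or None."""
    candidates = [
        name
        for name in os.listdir(output_dir)
        if _CHECKPOINT_RE.match(name) and os.path.isdir(os.path.join(output_dir, name))
    ]
    if not candidates:
        return None
    # Compare step numbers as integers, not strings.
    latest = max(candidates, key=lambda name: int(_CHECKPOINT_RE.match(name).group(1)))
    return os.path.join(output_dir, latest)


with tempfile.TemporaryDirectory() as tmp:
    print(find_last_checkpoint(tmp))  # empty directory: nothing to resume from
    for step in (50, 100, 150):
        os.makedirs(os.path.join(tmp, f"checkpoint-{step}"))
    print(os.path.basename(find_last_checkpoint(tmp)))
```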

### 📊 **Advantages of SDK Approach**:

- **Simplified API**: No need to manage Kubernetes manifests directly
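- **Type Safety**: Typed Python objects (`CustomTrainer`, `Initializer`, and the Hugging Face initializer classes) instead of hand-edited YAML
- **Error Handling**: SDK methods raise Python exceptions with descriptive messages that notebook code can catch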
- **Monitoring**: Native support for job status and log retrieval
- **Portability**: Works across different Kubernetes clusters and configurations

### 🚀 **Usage Pattern**:

```python
# 1. Configure training function and parameters
custom_trainer = CustomTrainer(
    func=train_gpt2_with_trl_checkpointing,
    func_args=training_args,
    num_nodes=2,
    resources_per_node={"cpu": "2", "memory": "4Gi"},
    packages_to_install=["transformers[torch]", "trl", "peft", ...]
)

# 2. Configure initializers (matching YAML V2 initializers)
initializer = Initializer(
    dataset=HuggingFaceDatasetInitializer("hf://tatsu-lab/alpaca"),
    model=HuggingFaceModelInitializer("hf://gpt2")
)

# 3. Create TrainJob
job_name = trainer_client.train(
    trainer=custom_trainer,
    initializer=initializer,
    labels={"experiment": "trl-gpt2-checkpointing"}
)

# 4. Monitor and manage
trainjob = trainer_client.get_job(job_name)
logs = trainer_client.get_job_logs(job_name)
trainer_client.delete_job(job_name)
```

This approach provides the **same functionality as trl-trainjob.yaml** but with a more **Pythonic and maintainable interface** through the Kubeflow SDK.


In [None]:
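# Low-level alternative to the SDK: read the TrainJob custom resource with the Kubernetes Python client.
from kubernetes import client, config
from kubernetes.client.rest import ApiException

config.load_kube_config()  # use config.load_incluster_config() when running inside the cluster
custom_objects_api = client.CustomObjectsApi()

def get_trainjob_status():
    """Get TrainJob status and training progress"""
    try:
        trainjob = custom_objects_api.get_namespaced_custom_object(
            group="trainer.kubeflow.org",
            version="v1alpha1",
            namespace=NAMESPACE,
            plural="trainjobs",
            name=job_name
        )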
        
        status = trainjob.get("status", {})
        conditions = status.get("conditions", [])
        training_progress = status.get("trainingProgress", {})
        checkpointing = status.get("checkpointing", {})
        
        print(f"🎯 TrainJob Status:")
        print(f"   Phase: {status.get('phase', 'Unknown')}")
        
        if conditions:
            latest_condition = conditions[-1]
            print(f"   Condition: {latest_condition.get('type')} - {latest_condition.get('status')}")
            if latest_condition.get('message'):
                print(f"   Message: {latest_condition.get('message')}")
        
        if training_progress:
            print(f"\n📈 Training Progress:")
            print(f"   Epoch: {training_progress.get('epoch', 0)}/{training_progress.get('totalEpochs', 0)}")
            print(f"   Step: {training_progress.get('step', 0)}/{training_progress.get('totalSteps', 0)}")
            print(f"   Loss: {training_progress.get('loss', 'N/A')}")
            print(f"   Learning Rate: {training_progress.get('learningRate', 'N/A')}")
            print(f"   Progress: {training_progress.get('percentComplete', '0')}%")
            print(f"   Last Update: {training_progress.get('lastUpdateTime', 'N/A')}")
        
        if checkpointing:
            print(f"\n💾 Checkpointing:")
            print(f"   Enabled: {checkpointing.get('enabled', False)}")
            print(f"   Checkpoints Created: {checkpointing.get('checkpointsCreated', 0)}")
            print(f"   Latest Checkpoint: {checkpointing.get('latestCheckpointTime', 'N/A')}")
        
        return trainjob
        
    except ApiException as e:
        print(f"❌ Failed to get TrainJob status: {e}")
        return None

def monitor_training(duration_minutes=10, interval_seconds=30):
    """Monitor training progress for specified duration"""
    import time
    from datetime import datetime, timedelta
    
    end_time = datetime.now() + timedelta(minutes=duration_minutes)
    
    print(f"🔍 Monitoring training for {duration_minutes} minutes (checking every {interval_seconds}s)")
    print("=" * 60)
    
    while datetime.now() < end_time:
        print(f"\n⏰ {datetime.now().strftime('%H:%M:%S')} - Checking status...")
        trainjob = get_trainjob_status()
        
        if trainjob:
            status = trainjob.get("status", {})
            phase = status.get("phase", "Unknown")
            
            if phase in ["Succeeded", "Failed"]:
                print(f"\n🏁 Training completed with status: {phase}")
                break
        
        print(f"\n⏳ Waiting {interval_seconds} seconds...")
        time.sleep(interval_seconds)
    
    print("\n✅ Monitoring completed")

# Start monitoring
monitor_training(duration_minutes=15, interval_seconds=30)
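The function above reads the `trainingProgress` block that the enhanced controller is expected to write into the TrainJob status. These fields come from a custom controller rather than the upstream TrainJob API, so it pays to parse them defensively.

The sketch below turns such a status dict into a one-line summary. The field names are taken from the cell above; `summarize_progress` and the sample payload are hypothetical.

```python
from typing import Any, Dict


def summarize_progress(status: Dict[str, Any]) -> str:
    """Build a one-line summary from a TrainJob status dict, tolerating missing fields."""
    progress = status.get("trainingProgress") or {}
    step = int(progress.get("step", 0))
    total = int(progress.get("totalSteps", 0))
    # Prefer the controller's own percentage; otherwise derive it from step counts.
    if "percentComplete" in progress:
        percent = float(progress["percentComplete"])
    elif total > 0:
        percent = 100.0 * step / total
    else:
        percent = 0.0
    loss = progress.get("loss", "N/A")
    return f"step {step}/{total} ({percent:.1f}%), loss={loss}"


sample_status = {"trainingProgress": {"step": 30, "totalSteps": 120, "loss": "1.84"}}
print(summarize_progress(sample_status))
print(summarize_progress({}))
```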


## 🧹 Cleanup Resources

Clean up the created resources when done


In [None]:
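from kubernetes import client, config
from kubernetes.client.rest import ApiException

config.load_kube_config()  # use config.load_incluster_config() when running inside the cluster
custom_objects_api = client.CustomObjectsApi()
core_v1 = client.CoreV1Api()

def cleanup_resources():
    """Clean up all created resources"""
    print("🧹 Cleaning up resources...")
    
    # Delete TrainJob
    try:
        custom_objects_api.delete_namespaced_custom_object(
            group="trainer.kubeflow.org",
            version="v1alpha1",
            namespace=NAMESPACE,
            plural="trainjobs",
            name=job_name
        )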
        print("✅ TrainJob deleted")
    except ApiException as e:
        if e.status == 404:
            print("⚠️  TrainJob not found")
        else:
            print(f"❌ Failed to delete TrainJob: {e}")
    
    # Delete TrainingRuntime (only if you created it yourself: other TrainJobs may reference the same runtime)
    try:
        custom_objects_api.delete_namespaced_custom_object(
            group="trainer.kubeflow.org",
            version="v1alpha1",
            namespace=NAMESPACE,
            plural="trainingruntimes",
            name="torch-cuda-251-runtime"
        )
        print("✅ TrainingRuntime deleted")
    except ApiException as e:
        if e.status == 404:
            print("⚠️  TrainingRuntime not found")
        else:
            print(f"❌ Failed to delete TrainingRuntime: {e}")
    
    # Delete ConfigMap
    try:
        core_v1.delete_namespaced_config_map(
            name="advanced-trl-script",
            namespace=NAMESPACE
        )
        print("✅ ConfigMap deleted")
    except ApiException as e:
        if e.status == 404:
            print("⚠️  ConfigMap not found")
        else:
            print(f"❌ Failed to delete ConfigMap: {e}")
    
    # Delete PVC (optional - comment out to preserve data)
    # try:
    #     core_v1.delete_namespaced_persistent_volume_claim(
    #         name="shared-checkpoint-storage",
    #         namespace=NAMESPACE
    #     )
    #     print("✅ PVC deleted")
    # except ApiException as e:
    #     if e.status == 404:
    #         print("⚠️  PVC not found")
    #     else:
    #         print(f"❌ Failed to delete PVC: {e}")
    
    print("🏁 Cleanup completed!")

# Uncomment to run cleanup
# cleanup_resources()
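Each block in `cleanup_resources` repeats the same try/except logic: a 404 means the resource is already gone, and any other error is reported. You can factor this into one helper so that cleanup is idempotent and safe to re-run.

`delete_ignoring_not_found` is hypothetical, and `FakeApiError` stands in for `kubernetes.client.rest.ApiException`, which exposes the HTTP status code as `.status`.

```python
from typing import Callable


class FakeApiError(Exception):
    """Stand-in for kubernetes.client.rest.ApiException (carries a .status code)."""

    def __init__(self, status: int):
        super().__init__(f"HTTP {status}")
        self.status = status


def delete_ignoring_not_found(kind: str, delete: Callable[[], object]) -> bool:
    """Run delete(); return True if deleted, False if the resource was already absent.

    Errors other than 404 are re-raised so real failures are not hidden.
    """
    try:
        delete()
    except Exception as exc:
        if getattr(exc, "status", None) == 404:
            print(f"⚠️  {kind} not found (already deleted)")
            return False
        raise
    print(f"✅ {kind} deleted")
    return True


def missing():
    raise FakeApiError(404)


print(delete_ignoring_not_found("TrainJob", lambda: None))
print(delete_ignoring_not_found("ConfigMap", missing))
```

In the notebook, the `delete` argument would be a lambda that wraps a call such as `custom_objects_api.delete_namespaced_custom_object(...)`. This version re-raises unexpected errors instead of only printing them, so a permissions problem (for example a 403) stops the cleanup instead of being silently skipped.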


## 🎮 Create and Submit TrainJob


In [None]:
# Create the TrainJob using the Kubeflow SDK.
# train() generates the TrainJob name and returns it, so capture the return value.
print("🚀 Creating TrainJob...")

try:
    job_name = trainer_client.train(
        trainer=custom_trainer,
        initializer=initializer,
        # Checkpointing is handled by the enhanced controller, which injects
        # CHECKPOINT_ENABLED, TRAINING_PROGRESS_FILE, etc. into the training pods.
    )
    print(f"✅ TrainJob '{job_name}' created successfully!")
    print("📊 Enhanced checkpointing and progress tracking enabled")
    print("🔄 Controller will poll progress every 15-60 seconds adaptively")
    
except Exception as e:
    print(f"❌ Failed to create TrainJob: {e}")
    raise


## 📊 Monitor Training Progress

Use the enhanced progress tracking capabilities to monitor training in real-time:


In [None]:
# Monitor the TrainJob status and progress
def monitor_training_progress(job_name, max_iterations=60, sleep_interval=10):
    """Monitor training progress using the enhanced checkpointing API"""
    print(f"📊 Monitoring TrainJob: {job_name}")
    print(f"🔄 Checking every {sleep_interval} seconds for up to {max_iterations} iterations")
    print("=" * 80)
    
    for i in range(max_iterations):
        try:
            # Get TrainJob status
            job = trainer_client.get_job(job_name)
            
            print(f"\n[{datetime.now().strftime('%H:%M:%S')}] Iteration {i+1}/{max_iterations}")
            print(f"📋 Status: {job.status}")
            
            # Check for training progress (enhanced checkpointing feature)
            if hasattr(job, 'training_progress') and job.training_progress:
                progress = job.training_progress
                print(f"📈 Progress Details:")
                if hasattr(progress, 'epoch'):
                    print(f"   📅 Epoch: {progress.epoch}/{getattr(progress, 'total_epochs', '?')}")
                if hasattr(progress, 'step'):
                    print(f"   👣 Step: {progress.step}/{getattr(progress, 'total_steps', '?')}")
                if hasattr(progress, 'loss'):
                    print(f"   📉 Loss: {progress.loss}")
                if hasattr(progress, 'learning_rate'):
                    print(f"   🎯 Learning Rate: {progress.learning_rate}")
                if hasattr(progress, 'percent_complete'):
                    print(f"   📊 Progress: {progress.percent_complete}%")
                if hasattr(progress, 'last_update_time'):
                    print(f"   🕒 Last Update: {progress.last_update_time}")
                
                # Check checkpointing status
                if hasattr(progress, 'checkpointing') and progress.checkpointing:
                    checkpoint_status = progress.checkpointing
                    print(f"💾 Checkpointing:")
                    if hasattr(checkpoint_status, 'enabled'):
                        print(f"   ✅ Enabled: {checkpoint_status.enabled}")
                    if hasattr(checkpoint_status, 'latest_checkpoint'):
                        print(f"   📁 Latest: {checkpoint_status.latest_checkpoint}")
                    if hasattr(checkpoint_status, 'checkpoints_created'):
                        print(f"   📊 Created: {checkpoint_status.checkpoints_created}")
            else:
                print("📊 No progress data available yet (controller may still be initializing)")
            
            # Check if training is complete
            if job.status in ['Succeeded', 'Failed', 'Complete']:
                print(f"\n🎯 Training finished with status: {job.status}")
                break
                
        except Exception as e:
            print(f"⚠️  Error getting job status: {e}")
        
        if i < max_iterations - 1:  # Don't sleep on the last iteration
            time.sleep(sleep_interval)
    
    print("\n" + "="*80)
    print("📊 Monitoring complete")

# Start monitoring
monitor_training_progress(job_name)


## 📋 Get Detailed Job Information


In [None]:
# Get comprehensive job information
try:
    job = trainer_client.get_job(job_name)
    
    print(f"📋 TrainJob Details: {job_name}")
    print("="*50)
    print(f"Status: {job.status}")
    print(f"Created: {job.creation_timestamp}")
    
    if hasattr(job, 'start_time') and job.start_time:
        print(f"Started: {job.start_time}")
    
    if hasattr(job, 'completion_time') and job.completion_time:
        print(f"Completed: {job.completion_time}")
    
    # Enhanced progress information
    if hasattr(job, 'training_progress') and job.training_progress:
        print("\n📊 Enhanced Progress Tracking:")
        progress = job.training_progress
        
        for attr in ['epoch', 'total_epochs', 'step', 'total_steps', 'loss', 
                     'accuracy', 'learning_rate', 'percent_complete', 'last_update_time']:
            if hasattr(progress, attr):
                value = getattr(progress, attr)
                if value is not None:
                    print(f"  {attr.replace('_', ' ').title()}: {value}")
    
    print("\n" + "="*50)
    
except Exception as e:
    print(f"❌ Error getting job details: {e}")


## 📜 View Training Logs


In [None]:
# Get training logs
try:
    print(f"📜 Getting logs for TrainJob: {job_name}")
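    logs = trainer_client.get_job_logs(job_name, follow=False)
    
    print("\n" + "="*80)
    print("TRAINING LOGS")
    print("="*80)
    
    # As in the earlier logs cell, get_job_logs() returns a dict of node name -> log text.
    # Display the last 50 lines for each node.
    for node_name, node_logs in logs.items():
        print(f"\n--- {node_name.upper()} LOGS ---")
        for line in node_logs.split('\n')[-50:]:
            if line.strip():
                print(line)
    
    print("\n" + "="*80)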
    
except Exception as e:
    print(f"❌ Error getting logs: {e}")


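## 🐍 Define the Training Function (GPT-2 on GSM8K)

The function below fine-tunes GPT-2 on GSM8K with TRL's `SFTTrainer` and LoRA, and saves a checkpoint when a pod receives SIGTERM. To delete a TrainJob when you are done, use `cleanup_trainjob()` from the cleanup section above.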


In [None]:
def train_gpt2_with_checkpointing(args):
    import random
    import os
    import torch
    import numpy
    from numpy.core.multiarray import _reconstruct
    import torch.serialization
    torch.serialization.add_safe_globals([_reconstruct, numpy.ndarray, numpy.dtype, numpy.dtypes.UInt32DType])
    from datetime import datetime
    import signal
    import torch.distributed as dist
    import logging
    from pathlib import Path
    from cloudpathlib import CloudPath
    
    from datasets import load_dataset
    from transformers import (
        AutoTokenizer,
        TrainingArguments,
        TrainerState,
        TrainerControl,
        TrainerCallback,
        set_seed,
    )
    from transformers.trainer_utils import get_last_checkpoint
    
    from trl import (
        ModelConfig,
        ScriptArguments,
        SFTConfig,
        SFTTrainer,
        TrlParser,
        get_peft_config,
    )

    class SigtermCheckpointCallback(TrainerCallback):
        """Advanced checkpoint callback with SIGTERM handling and distributed coordination."""
        
        def __init__(self, output_dir: str):
            self.output_dir = output_dir
            self.checkpoint_requested = False
            self.save_triggered = False
            self.checkpoint_stream = None
            self.sigterm_tensor = None

        def _log_message(self, message: str):
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            rank = os.environ.get("RANK", "0")
            print(f"[Rank {rank}] [{timestamp}] {message}")

        def _init_distributed_signal_tensor(self):
            try:
                if dist.is_initialized():
                    device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device('cpu')
                    self.sigterm_tensor = torch.zeros(1, dtype=torch.float32, device=device)
                    self._log_message(f"Initialized distributed SIGTERM tensor on device: {device}")
                else:
                    self._log_message("Distributed training not initialized - using local SIGTERM handling only")
            except Exception as e:
                self._log_message(f"Failed to initialize distributed SIGTERM tensor: {e}. Using local handling only.")

        def _check_distributed_sigterm(self):
            try:
                if dist.is_initialized() and self.sigterm_tensor is not None:
                    dist.all_reduce(self.sigterm_tensor, op=dist.ReduceOp.MAX)
                    return self.sigterm_tensor.item() > 0.5
            except Exception as e:
                self._log_message(f"Distributed SIGTERM check failed: {e}. Using local signal only.")
            return self.checkpoint_requested

        def _sigterm_handler(self, signum, frame):
            self._log_message(f"SIGTERM received, flagging for checkpoint.")
            self.checkpoint_requested = True
            if self.sigterm_tensor is not None:
                self.sigterm_tensor.fill_(1.0)

        def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
            rank = os.environ.get("RANK", "0")
            os.makedirs(self.output_dir, exist_ok=True)
            self._init_distributed_signal_tensor()
            
            if torch.cuda.is_available():
                self.checkpoint_stream = torch.cuda.Stream()
                self._log_message(f"Created dedicated CUDA stream for checkpointing.")

            signal.signal(signal.SIGTERM, self._sigterm_handler)
            self._log_message(f"SIGTERM signal handler registered for distributed coordination.")

            try:
                if dist.is_initialized():
                    dist.barrier()
                    self._log_message(f"Distributed coordination setup synchronized across all ranks")
            except Exception as e:
                self._log_message(f"Failed to synchronize distributed setup: {e}")

        def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
            if self._check_distributed_sigterm() and not self.save_triggered:
                rank = os.environ.get("RANK", "0")
                self._log_message(f"Distributed SIGTERM detected, initiating checkpoint at step {state.global_step}.")
                self.save_triggered = True
                control.should_save = True
                control.should_training_stop = True

        def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
            rank = os.environ.get("RANK", "0")
            if rank != "0":
                return
            self._log_message(f"Checkpoint save completed.")
            if self.checkpoint_requested:
                self._log_message(f"Distributed SIGTERM-triggered checkpoint save finished successfully.")

    # Setup checkpoint directory
    checkpoint_dir = Path(args.get('CHECKPOINT_DIR', '/tmp/checkpoints'))
    
    # TRL configuration parameters
    parameters = {
        'model_name_or_path': args.get('MODEL_NAME', 'gpt2'),
        'model_revision': 'main',
        'torch_dtype': 'bfloat16',
        'use_peft': True,
        'lora_r': int(args.get('LORA_R', '16')),
        'lora_alpha': int(args.get('LORA_ALPHA', '8')),
        'lora_dropout': float(args.get('LORA_DROPOUT', '0.05')),
        'lora_target_modules': ['c_attn', 'c_proj'],
        'dataset_name': args.get('DATASET_NAME', 'gsm8k'),
        'dataset_config': 'main',
        'dataset_train_split': args.get('TRAIN_SPLIT', 'train[:100]'),
        'dataset_test_split': args.get('TEST_SPLIT', 'test[:20]'),
        'max_seq_length': 512,
        'num_train_epochs': int(args.get('MAX_EPOCHS', '1')),
        'per_device_train_batch_size': int(args.get('BATCH_SIZE', '4')),
        'per_device_eval_batch_size': int(args.get('BATCH_SIZE', '4')),
        'eval_strategy': 'steps',
        'eval_steps': int(args.get('EVAL_STEPS', '25')),
        'bf16': True,
        'learning_rate': float(args.get('LEARNING_RATE', '2e-4')),
        'warmup_steps': int(args.get('WARMUP_STEPS', '10')),
        'lr_scheduler_type': 'cosine',
        'optim': 'adamw_torch',
        'max_grad_norm': 1.0,
        'seed': 42,
        'gradient_accumulation_steps': 1,
        'save_strategy': 'steps',
        'save_steps': int(args.get('SAVE_STEPS', '50')),
        'save_total_limit': 5,
        'logging_strategy': 'steps',
        'logging_steps': int(args.get('LOGGING_STEPS', '10')),
        'report_to': [],
        'output_dir': str(checkpoint_dir),
        'dataloader_pin_memory': False,
        'gradient_checkpointing': True,
    }

    # Create cache directories
    cache_dirs = [
        args.get('TRANSFORMERS_CACHE', '/tmp/transformers_cache'),
        args.get('HF_HOME', '/tmp/hf_cache'),
        args.get('HF_DATASETS_CACHE', '/tmp/hf_datasets_cache'),
        str(checkpoint_dir)
    ]
    
    for cache_dir in cache_dirs:
        os.makedirs(cache_dir, exist_ok=True)
    
    rank = os.environ.get("RANK", "0")
    world_size = os.environ.get("WORLD_SIZE", "1")
    local_rank = os.environ.get("LOCAL_RANK", "0")
    
    print(f"[Rank {rank}] Starting TRL-based distributed training with advanced checkpointing")
    print(f"[Rank {rank}] World Size: {world_size}, Local Rank: {local_rank}")
    print(f"[Rank {rank}] Checkpoints will be saved to: {checkpoint_dir}")
    print(f"[Rank {rank}] Model: {parameters['model_name_or_path']}")
    print(f"[Rank {rank}] Dataset: {parameters['dataset_name']} ({parameters['dataset_train_split']})")
    print(f"[Rank {rank}] GPU Available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"[Rank {rank}] GPU Count: {torch.cuda.device_count()}")
        print(f"[Rank {rank}] Current GPU: {torch.cuda.current_device()}")

    # Parse TRL configuration
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_dict(parameters)

    set_seed(training_args.seed)

    print(f"[Rank {rank}] Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, 
        trust_remote_code=model_args.trust_remote_code, 
        use_fast=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Add chat template for GPT-2
    if tokenizer.chat_template is None:
        tokenizer.chat_template = (
            "{% for message in messages %}"
            "{% if message['role'] == 'user' %}"
            "Question: {{ message['content'] }}\n"
            "{% elif message['role'] == 'assistant' %}"
            "Answer: {{ message['content'] }}{{ eos_token }}\n"
            "{% endif %}"
            "{% endfor %}"
        )

    print(f"[Rank {rank}] Loading dataset...")
    train_dataset = load_dataset(
        path=script_args.dataset_name,
        name=script_args.dataset_config,
        split=script_args.dataset_train_split,
    )
    test_dataset = None
    if training_args.eval_strategy != "no":
        test_dataset = load_dataset(
            path=script_args.dataset_name,
            name=script_args.dataset_config,
            split=script_args.dataset_test_split,
        )

    def template_dataset(sample):
        messages = [
            {"role": "user", "content": sample['question']},
            {"role": "assistant", "content": sample['answer']},
        ]
        return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}

    print(f"[Rank {rank}] Preprocessing datasets...")
    train_dataset = train_dataset.map(template_dataset, remove_columns=["question", "answer"])
    if test_dataset is not None:
        test_dataset = test_dataset.map(template_dataset, remove_columns=["question", "answer"])

    print(f"[Rank {rank}] Training samples: {len(train_dataset)}")
    if test_dataset:
        print(f"[Rank {rank}] Evaluation samples: {len(test_dataset)}")

    if rank == "0":
        print(f"[Rank {rank}] Sample training data:")
        for i in random.sample(range(len(train_dataset)), min(2, len(train_dataset))):
            print(f"Sample {i}: {train_dataset[i]['text'][:200]}...")

    print(f"[Rank {rank}] Initializing SFTTrainer with checkpointing...")
    trainer = SFTTrainer(
        model=model_args.model_name_or_path,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        peft_config=get_peft_config(model_args),
        processing_class=tokenizer,
        callbacks=[SigtermCheckpointCallback(str(checkpoint_dir))],
    )

    if trainer.accelerator.is_main_process and hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()

    # Check for existing checkpoints
    checkpoint = get_last_checkpoint(training_args.output_dir)
    if checkpoint is None:
        print(f"[Rank {rank}] No checkpoint found, starting training from scratch.")
    else:
        print(f"[Rank {rank}] Resuming from checkpoint: {checkpoint}")

    print(f"[Rank {rank}] Starting training...")
    trainer.train(resume_from_checkpoint=checkpoint)

    # Save final model
    trainer.save_model(training_args.output_dir)
    print(f"[Rank {rank}] Training completed, model checkpoint written to {training_args.output_dir}")
    
    # Upload to cloud storage if specified
    if args.get("BUCKET") and rank == "0":  # upload once, from the rank that saved the final model
        print(f"[Rank {rank}] Uploading model to {args['BUCKET']}")
        (CloudPath(args["BUCKET"]) / "gpt2-gsm8k-checkpoints").upload_from(str(checkpoint_dir))
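GPT-2 ships without a chat template, so the function adds a Jinja template. It is intended to render each user turn as `Question: ...` and each assistant turn as `Answer: ...` followed by the EOS token, one line per turn.

The plain-Python function below mirrors that layout, so you can inspect formatted samples without loading a tokenizer. `render_qa` is hypothetical and is not a substitute for `tokenizer.apply_chat_template`. It assumes GPT-2's EOS token, `<|endoftext|>`.

```python
from typing import Dict, List

GPT2_EOS = "<|endoftext|>"  # GPT-2's end-of-text token


def render_qa(messages: List[Dict[str, str]], eos_token: str = GPT2_EOS) -> str:
    """Render user/assistant turns in the Question/Answer layout of the chat template."""
    parts = []
    for message in messages:
        if message["role"] == "user":
            parts.append(f"Question: {message['content']}\n")
        elif message["role"] == "assistant":
            parts.append(f"Answer: {message['content']}{eos_token}\n")
        # Other roles (e.g. "system") produce no output, as in the template.
    return "".join(parts)


# A GSM8K-style record: the answer ends with "#### <final answer>".
sample = {"question": "What is 2 + 3?", "answer": "2 + 3 = 5\n#### 5"}
text = render_qa([
    {"role": "user", "content": sample["question"]},
    {"role": "assistant", "content": sample["answer"]},
])
print(text)
```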


## List available training runtimes

Check what training runtimes are available in your cluster:


In [None]:
from kubeflow.trainer import TrainerClient, CustomTrainer

print("Available Training Runtimes:")
for r in TrainerClient().list_runtimes():
    print(f"Name: {r.name}, Framework: {r.trainer.framework.value}, Trainer Type: {r.trainer.trainer_type.value}")
    print(f"Entrypoint: {r.trainer.entrypoint[:3]}")
    print("---")


## Configure training parameters

Set up the training configuration with checkpointing and cloud storage options:


In [None]:
# To upload checkpoints to object storage (S3, GCS or Azure Blob Storage), 
# set the bucket with protocol, e.g., "s3://my-bucket/checkpoints"
BUCKET = None

# Training configuration
MODEL_NAME = "gpt2"
DATASET_NAME = "gsm8k"
CHECKPOINT_DIR = "/tmp/checkpoints"

# Training hyperparameters
args = {
    "BUCKET": BUCKET,
    "MODEL_NAME": MODEL_NAME,
    "DATASET_NAME": DATASET_NAME,
    "CHECKPOINT_DIR": CHECKPOINT_DIR,
    "TRAIN_SPLIT": "train[:100]",
    "TEST_SPLIT": "test[:20]",
    "LEARNING_RATE": "2e-4",
    "BATCH_SIZE": "4",
    "MAX_EPOCHS": "1",
    "WARMUP_STEPS": "10",
    "EVAL_STEPS": "25",
    "SAVE_STEPS": "50",
    "LOGGING_STEPS": "10",
    "LORA_R": "16",
    "LORA_ALPHA": "8",
    "LORA_DROPOUT": "0.05",
    "TRANSFORMERS_CACHE": "/tmp/transformers_cache",
    "HF_HOME": "/tmp/hf_cache",
    "HF_DATASETS_CACHE": "/tmp/hf_datasets_cache",
}

print("Training Configuration:")
for key, value in args.items():
    if value is not None:
        print(f"  {key}: {value}")
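Check that `SAVE_STEPS` and `EVAL_STEPS` can actually be reached. Take `TRAIN_SPLIT="train[:100]"`, `BATCH_SIZE="4"` and one epoch, and assume one training process on each of the 2 nodes requested in the submission cell (world size 2). Each rank then sees 50 samples, which is about 13 optimizer steps. That is below `EVAL_STEPS=25` and `SAVE_STEPS=50`, so the step-based evaluations and periodic checkpoints would likely never trigger. Consider lowering them or enlarging the split.

The estimate below uses a hypothetical helper that ignores `drop_last`, `max_steps`, and version-specific rounding in the `Trainer`.

```python
import math


def estimate_optimizer_steps(num_samples: int, per_device_batch: int, world_size: int,
                             grad_accum: int = 1, epochs: int = 1) -> int:
    """Rough optimizer-step count for data-parallel training.

    Assumes a distributed sampler that pads shards to equal size, no drop_last,
    and no max_steps override. Trainer versions may round differently when
    grad_accum > 1.
    """
    samples_per_rank = math.ceil(num_samples / world_size)
    batches_per_rank = math.ceil(samples_per_rank / per_device_batch)
    return math.ceil(batches_per_rank / grad_accum) * epochs


# Values copied from the `args` cell above ("train[:100]" -> 100 samples).
config = {"BATCH_SIZE": "4", "MAX_EPOCHS": "1",
          "EVAL_STEPS": "25", "SAVE_STEPS": "50", "LOGGING_STEPS": "10"}
steps = estimate_optimizer_steps(
    num_samples=100,
    per_device_batch=int(config["BATCH_SIZE"]),
    world_size=2,  # assumption: 2 nodes with one training process each
    epochs=int(config["MAX_EPOCHS"]),
)
print(f"Estimated optimizer steps: {steps}")
for key in ("EVAL_STEPS", "SAVE_STEPS", "LOGGING_STEPS"):
    every = int(config[key])
    print(f"  {key}={every}: reached {steps // every} time(s)")
```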


## Create and submit the TrainJob

Submit the training job with distributed training and checkpointing enabled:


In [None]:
job_id = TrainerClient().train(
    trainer=CustomTrainer(
        func=train_gpt2_with_checkpointing,
        func_args=args,
        num_nodes=2,  # Distributed training across 2 nodes
        packages_to_install=[
            "datasets", 
            "transformers[torch]", 
            "trl", 
            "peft", 
            "accelerate",
            "cloudpathlib[all]"
        ],
        resources_per_node={
            "cpu": "4",
            "memory": "8Gi",
            # Uncomment this to use GPU nodes for training
            # "nvidia.com/gpu": 1,
        },
        # Environment variables for distributed training debugging
        env={
            "NCCL_DEBUG": "INFO",
            "NCCL_DEBUG_SUBSYS": "ALL",
            "TORCH_DISTRIBUTED_DEBUG": "INFO",
            "PYTHONUNBUFFERED": "1",
        }
    ),
)

print(f"TrainJob submitted with ID: {job_id}")


## Monitor the TrainJob

Check the status and details of the submitted TrainJob:


In [None]:
# Check the TrainJob details
print("Recent TrainJobs:")
for job in TrainerClient().list_jobs():
    print(f"TrainJob: {job.name}, Status: {job.status}, Created at: {job.creation_timestamp}")


In [None]:
# Wait for the job to start running
import time

def wait_for_job_running(job_id, max_wait=300):
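    """Wait for a TrainJob step to reach Running; return False on failure or timeout."""
    for i in range(max_wait // 5):
        try:
            trainjob = TrainerClient().get_job(name=job_id)
            # Stop waiting early if the job has already failed
            if trainjob.status == "Failed":
                print(f"TrainJob {job_id} failed before reaching Running")
                return False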
            for step in trainjob.steps:
                if step.status == "Running":
                    print(f"TrainJob {job_id} is now running!")
                    return True
            print(f"Waiting for TrainJob to start running... ({i*5}s/{max_wait}s)")
            time.sleep(5)
        except Exception as e:
            print(f"Error checking job status: {e}")
            time.sleep(5)
    
    print(f"TrainJob did not start running within {max_wait} seconds")
    return False

wait_for_job_running(job_id)
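
The loop above counts iterations rather than elapsed time, so time spent in API calls is not counted against `max_wait`. A minimal alternative sketch uses a monotonic deadline instead. `poll_until` is a hypothetical helper, not part of the Kubeflow SDK; it assumes you pass a `check` callable that returns `True` when done, `False` to keep waiting, or raises to abort. The commented SDK usage assumes failed jobs report the status `"Failed"`.

```python
import time


def poll_until(check, timeout=300.0, interval=5.0,
               clock=time.monotonic, sleep=time.sleep):
    """Call check() until it returns True or the timeout elapses.

    Uses a monotonic deadline so slow API calls count against the timeout.
    Returns True on success and False on timeout; exceptions from check()
    propagate to the caller.
    """
    deadline = clock() + timeout
    while True:
        if check():
            return True
        remaining = deadline - clock()
        if remaining <= 0:
            return False
        # Never sleep past the deadline.
        sleep(min(interval, remaining))


# Example with the Kubeflow SDK (not executed here):
# def job_is_running():
#     job = TrainerClient().get_job(name=job_id)
#     if job.status == "Failed":
#         raise RuntimeError(f"TrainJob {job_id} failed")
#     return any(step.status == "Running" for step in job.steps)
#
# poll_until(job_is_running, timeout=300, interval=5)
```

Injecting `clock` and `sleep` keeps the helper testable without a cluster.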


In [None]:
# Show detailed job information
try:
    trainjob = TrainerClient().get_job(name=job_id)
    print(f"TrainJob Details for {job_id}:")
    print(f"  Status: {trainjob.status}")
    print(f"  Steps:")
    for step in trainjob.steps:
        print(f"    - Step: {step.name}, Status: {step.status}, Devices: {step.device} x {step.device_count}")
except Exception as e:
    print(f"Error getting job details: {e}")


## View training logs

Monitor the training progress and checkpointing in real-time:


In [None]:
# Show the TrainJob logs with checkpointing information
try:
    print(f"Streaming logs for TrainJob {job_id}...")
    print("Look for:")
    print("  - SIGTERM checkpoint callback initialization")
    print("  - Distributed training coordination")
    print("  - LoRA parameter information")
    print("  - Checkpoint saves every 50 steps")
    print("  - Evaluation every 25 steps")
    print("="*60)
    
    _ = TrainerClient().get_job_logs(name=job_id, follow=True)
except KeyboardInterrupt:
    print("\nLog streaming interrupted by user")
except Exception as e:
    print(f"Error streaming logs: {e}")
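
The shape of the value returned by `get_job_logs` may differ between Kubeflow SDK versions (for example, a dict mapping a step or pod name to log text, or an iterator of log lines). If your version returns an iterator, discarding it as the cell above does may print nothing. The sketch below is a hypothetical helper that prints whichever shape comes back, without relying on any particular SDK signature.

```python
from collections.abc import Iterable, Mapping


def print_job_logs(logs):
    """Print logs given as a mapping, a single string, or an iterable of lines.

    Hypothetical helper: returns the number of lines printed.
    """
    if logs is None:
        return 0
    if isinstance(logs, Mapping):
        # e.g. {"node-0": "line1\nline2"}: prefix each line with its source
        lines = [
            f"[{source}] {line}"
            for source, text in logs.items()
            for line in str(text).splitlines()
        ]
    elif isinstance(logs, str):
        lines = logs.splitlines()
    elif isinstance(logs, Iterable):
        # Consuming a generator is what actually streams the lines
        lines = (str(line).rstrip("\n") for line in logs)
    else:
        lines = [str(logs)]
    count = 0
    for line in lines:
        print(line)
        count += 1
    return count


# Usage (not executed here):
# print_job_logs(TrainerClient().get_job_logs(name=job_id, follow=True))
```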


## Wait for training completion

Use the SDK's wait functionality to monitor training completion:


In [None]:
# Wait for the TrainJob to complete (optional)
# Uncomment to wait for completion
# try:
#     print(f"Waiting for TrainJob {job_id} to complete...")
#     completed_job = TrainerClient().wait_for_job_status(
#         name=job_id,
#         status={"Complete"},
#         timeout=1800,  # 30 minutes
#         polling_interval=30
#     )
#     print(f"TrainJob {job_id} completed successfully!")
#     print(f"Final status: {completed_job.status}")
# except Exception as e:
#     print(f"Error waiting for job completion: {e}")

print("Training job monitoring setup complete. Uncomment above to wait for completion.")


## Test checkpoint resumption with Initializers

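Demonstrate how to pass initializers to the Kubeflow SDK so that the dataset and model are downloaded before the training nodes start. Resuming from a checkpoint in a new TrainJob also requires `CHECKPOINT_DIR` to point at storage that persists between jobs: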


In [None]:
# Example of using initializers with the Kubeflow SDK
from kubeflow.trainer import Initializer, HuggingFaceDatasetInitializer, HuggingFaceModelInitializer

# Create initializer configuration
initializer = Initializer(
    dataset=HuggingFaceDatasetInitializer(
        storage_uri="hf://openai/gsm8k"
    ),
    model=HuggingFaceModelInitializer(
        storage_uri="hf://gpt2"
    )
)

# Example of resuming training with initializers (uncomment to test)
# resume_args = args.copy()
# resume_args["MAX_EPOCHS"] = "2"  # Train for more epochs
# 
# resume_job_id = TrainerClient().train(
#     trainer=CustomTrainer(
#         func=train_gpt2_with_checkpointing,
#         func_args=resume_args,
#         num_nodes=2,
#         packages_to_install=[
#             "datasets", "transformers[torch]", "trl", "peft", "accelerate", "cloudpathlib[all]"
#         ],
#         resources_per_node={
#             "cpu": "4",
#             "memory": "8Gi",
#         },
#         env={
#             "NCCL_DEBUG": "INFO",
#             "TORCH_DISTRIBUTED_DEBUG": "INFO",
#             "PYTHONUNBUFFERED": "1",
#         }
#     ),
#     initializer=initializer,  # Use initializers for reproducible training
#     labels={
#         "experiment": "gpt2-gsm8k-resume",
#         "checkpoint-strategy": "resume"
#     },
#     annotations={
#         "training.kubeflow.org/description": "Resume training from checkpoint with initializers"
#     }
# )
# 
# print(f"Resume TrainJob submitted with ID: {resume_job_id}")

print("Initializer configuration ready. Uncomment to test checkpoint resumption with initializers.")
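
With initializers, the training function can read the pre-downloaded files from disk instead of the Hugging Face Hub. The sketch below assumes the initializers write to `/workspace/model` and `/workspace/dataset` (verify these paths against the TrainingRuntime you use) and falls back to the Hub ID otherwise. `resolve_source` is a hypothetical helper, not part of the notebook's training script.

```python
from pathlib import Path


def resolve_source(local_dir, hub_id):
    """Return local_dir if it exists and is non-empty, otherwise hub_id.

    Hypothetical helper: prefer files pre-downloaded by an initializer and
    fall back to downloading from the Hugging Face Hub.
    """
    path = Path(local_dir)
    if path.is_dir() and any(path.iterdir()):
        return str(path)
    return hub_id


# Assumed initializer mount points; check your TrainingRuntime.
model_source = resolve_source("/workspace/model", "gpt2")
dataset_source = resolve_source("/workspace/dataset", "openai/gsm8k")
print(f"Model source: {model_source}")
print(f"Dataset source: {dataset_source}")
```

`AutoModelForCausalLM.from_pretrained` accepts either a local directory or a Hub ID, so `model_source` can be passed to it unchanged.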


## Checkpoint management utilities

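Utility functions for analyzing checkpoints and testing the fine-tuned LoRA adapter. Both read from a local path, so the checkpoint directory must be reachable from the notebook (for example, copied out of the training pods or mounted from shared storage):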


In [None]:
def analyze_checkpoints(checkpoint_dir):
    """Analyze available checkpoints and their properties."""
    import json
    from pathlib import Path
    
    checkpoint_path = Path(checkpoint_dir)
    
    if not checkpoint_path.exists():
        print(f"Checkpoint directory {checkpoint_dir} does not exist")
        return
    
    print(f"Analyzing checkpoints in {checkpoint_dir}:")
    print("="*50)
    
    # Find all checkpoint-<step> directories, skipping names without a numeric step
    checkpoints = [
        item for item in checkpoint_path.iterdir()
        if item.is_dir()
        and item.name.startswith('checkpoint-')
        and item.name.split('-', 1)[1].isdigit()
    ]
    
    if not checkpoints:
        print("No checkpoints found")
        return
    
    # Sort checkpoints numerically by step number
    checkpoints.sort(key=lambda x: int(x.name.split('-', 1)[1]))
    
    for checkpoint in checkpoints:
        # Collect regular files once and reuse them for size and count
        files = [f for f in checkpoint.rglob('*') if f.is_file()]
        size_mb = sum(f.stat().st_size for f in files) / (1024 * 1024)
        
        # Read epoch/step from the Trainer state, if present
        trainer_state_file = checkpoint / 'trainer_state.json'
        training_info = "N/A"
        
        if trainer_state_file.exists():
            try:
                with open(trainer_state_file, 'r') as f:
                    state = json.load(f)
                epoch = state.get('epoch', 'N/A')
                global_step = state.get('global_step', 'N/A')
                training_info = f"Epoch: {epoch}, Step: {global_step}"
            except Exception as e:
                training_info = f"Error reading state: {e}"
        
        print(f"Checkpoint: {checkpoint.name}")
        print(f"  Size: {size_mb:.2f} MB")
        print(f"  Training Info: {training_info}")
        print(f"  Files: {len(files)}")
        print("-" * 30)

def test_fine_tuned_model(checkpoint_path):
    """Test the fine-tuned GPT-2 model on mathematical reasoning."""
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM
        from peft import PeftModel
        import torch
        
        # Load the base model and tokenizer
        base_model = AutoModelForCausalLM.from_pretrained("gpt2")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        
        # Load the LoRA adapter and switch to inference mode
        model = PeftModel.from_pretrained(base_model, checkpoint_path)
        model.eval()
        
        # Test questions
        test_questions = [
            "Question: If John has 5 apples and gives 2 to Mary, how many apples does John have left?\nAnswer:",
            "Question: A rectangle has a length of 8 meters and a width of 3 meters. What is its area?\nAnswer:",
            "Question: If a train travels 60 miles in 2 hours, what is its average speed?\nAnswer:"
        ]
        
        print("Testing fine-tuned GPT-2 model on mathematical reasoning:")
        print("="*60)
        
        for i, question in enumerate(test_questions, 1):
            inputs = tokenizer(question, return_tensors="pt")
            
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=100,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id
                )
            
            # Decode only the newly generated tokens, not the echoed prompt
            prompt_length = inputs["input_ids"].shape[1]
            response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            print(f"Test {i}:")
            print(f"Input: {question}")
            print(f"Output: {response.strip()}")
            print("-" * 40)
            
    except Exception as e:
        print(f"Error testing model: {e}")
        print("Make sure training has completed and checkpoints are available.")

# Example usage
# analyze_checkpoints("/tmp/checkpoints")
# test_fine_tuned_model("/path/to/checkpoint")
print("Checkpoint analysis and model testing utilities ready")
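
Resuming requires locating the most recent `checkpoint-<step>` directory. The sketch below shows the idea with `pathlib`; it approximates `transformers.trainer_utils.get_last_checkpoint`, which is what you would normally call inside a Trainer-based script. `find_latest_checkpoint` is a hypothetical name.

```python
import re
from pathlib import Path

# Matches "checkpoint-<step>" exactly, e.g. "checkpoint-50"
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_latest_checkpoint(output_dir):
    """Return the checkpoint-<step> subdirectory with the highest step, or None."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    candidates = []
    for child in root.iterdir():
        match = _CHECKPOINT_RE.match(child.name)
        if match and child.is_dir():
            candidates.append((int(match.group(1)), child))
    if not candidates:
        return None
    # Compare steps numerically so checkpoint-100 wins over checkpoint-50
    return str(max(candidates, key=lambda c: c[0])[1])


# Usage inside a Trainer-based script (not executed here):
# trainer.train(resume_from_checkpoint=find_latest_checkpoint(CHECKPOINT_DIR))
```

Passing `None` to `resume_from_checkpoint` starts training from scratch, so the same call works on the first run and on restarts.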


## Clean up

Clean up resources when done with the training job:


In [None]:
# Uncomment to delete the TrainJob
# try:
#     TrainerClient().delete_job(job_id)
#     print(f"TrainJob {job_id} deleted successfully")
# except Exception as e:
#     print(f"Error deleting TrainJob: {e}")

print("To delete the TrainJob, uncomment the lines above")
print(f"Current TrainJob ID: {job_id}")

# Show final job status
try:
    final_job = TrainerClient().get_job(name=job_id)
    print(f"\nFinal TrainJob Status: {final_job.status}")
    print("Steps:")
    for step in final_job.steps:
        print(f"  - {step.name}: {step.status}")
except Exception as e:
    print(f"Error getting final job status: {e}")


## Summary

This notebook demonstrated advanced checkpointing features with TRL-based GPT-2 fine-tuning using the Kubeflow SDK:

### 🎯 **Key Features Implemented**:

1. **TRL-based fine-tuning** with GPT-2 and GSM8K dataset for mathematical reasoning
2. **LoRA (Low-Rank Adaptation)** for parameter-efficient fine-tuning  
3. **Advanced checkpointing** with SIGTERM handling and distributed coordination
4. **Distributed training** across multiple nodes with automatic environment setup
5. **Kubeflow SDK integration** with proper error handling and monitoring
6. **Initializers support** for reproducible dataset and model preparation
7. **Checkpoint management** utilities for analysis and resumption

### 🔧 **Technical Implementation**:

- **SigtermCheckpointCallback**: Custom callback for graceful shutdown and distributed checkpointing
- **Distributed coordination**: SIGTERM tensor for cross-node communication
- **Automatic checkpoint resumption**: Detection and loading of latest checkpoints
- **Cloud storage integration**: Optional upload to S3/GCS/Azure for backup
- **Comprehensive monitoring**: Real-time logs and status tracking
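
The SIGTERM coordination described above can be sketched as follows. This is a hypothetical helper, not the notebook's `SigtermCheckpointCallback`. It records SIGTERM in a signal handler and, when `torch.distributed` is initialized, combines the per-rank flags with a MAX all-reduce so every rank agrees to stop if any rank was signalled. Because `all_reduce` is a collective, every rank must call `any_rank_received()` at the same point, for example at the end of every step.

```python
import signal


class SigtermFlag:
    """Hypothetical helper: record SIGTERM and agree on it across ranks."""

    def __init__(self):
        self.received = False
        # Replace the default SIGTERM action (terminate) with a flag setter.
        signal.signal(signal.SIGTERM, self._handle)

    def _handle(self, signum, frame):
        self.received = True

    def any_rank_received(self):
        """True if this rank, or any rank in the process group, got SIGTERM."""
        try:
            import torch
            import torch.distributed as dist
        except ImportError:
            return self.received
        if not (dist.is_available() and dist.is_initialized()):
            return self.received
        # NCCL only reduces CUDA tensors; gloo works on CPU.
        device = "cuda" if dist.get_backend() == "nccl" else "cpu"
        flag = torch.tensor([1 if self.received else 0], device=device)
        # MAX across ranks: the result is 1 if any rank saw SIGTERM.
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
        return bool(flag.item())
```

In a Hugging Face `TrainerCallback`, `on_step_end` could call `any_rank_received()` and, when it returns `True`, set `control.should_save = True` and `control.should_training_stop = True` so the Trainer writes a checkpoint before exiting. Kubernetes sends SIGKILL after the pod's `terminationGracePeriodSeconds` (30 seconds by default), so the save has to finish within that window.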

### 📊 **Checkpointing Strategy**:

- **Automatic saves** every 50 training steps
- **Evaluation** every 25 steps with metrics tracking
- **SIGTERM signal handling** for graceful shutdown on interruption
- **Distributed coordination** ensures all nodes checkpoint consistently
- **Resume capability** from latest checkpoint on restart

### 🚀 **Production Features**:

- **Fault tolerance** with automatic restart and resume
- **Resource optimization** with gradient checkpointing and memory management
- **Monitoring integration** with comprehensive logging and status tracking
- **Scalability** with distributed training across multiple nodes
- **Reproducibility** with initializers and deterministic training

### 💡 **Usage Patterns**:

```python
# Basic training with checkpointing
job_id = TrainerClient().train(
    trainer=CustomTrainer(func=train_gpt2_with_checkpointing, ...),
    labels={"experiment": "gpt2-gsm8k"},
    annotations={"description": "TRL fine-tuning with checkpointing"}
)

# With initializers for reproducibility
job_id = TrainerClient().train(
    trainer=CustomTrainer(...),
    initializer=Initializer(
        dataset=HuggingFaceDatasetInitializer("hf://openai/gsm8k"),
        model=HuggingFaceModelInitializer("hf://gpt2")
    )
)

# Monitor and wait for completion
TrainerClient().wait_for_job_status(job_id, {"Complete"})
logs = TrainerClient().get_job_logs(job_id, follow=True)
```

This approach makes training **more fault-tolerant**: a job checkpoints when it is interrupted and can resume from the latest checkpoint on restart, provided the checkpoint directory persists between runs. That makes it suitable for production ML workflows.
