# Training Operation for Llama 3.1 Fine-tuning

This notebook contains the fine-tuning logic for Llama 3.1 models that will be embedded into the KFP pipeline.

## Workflow
1. Edit and test your training code in this notebook
2. Build the pipeline using the build script
3. Submit the generated pipeline YAML to your pipeline orchestrator

## Key Function
The main `trainOp()` function is what gets embedded into the KFP component.

In [None]:
from typing import Dict, List, Any

In [None]:
def trainOp(data_name: str = 'financial_sentiment_data.jsonl',
            data_source: str = '@auto-populate-data-source',
            data_bucket_prefix: str = '@auto-populate-object-prefix',
            model_relative_path: str = 'model',
            model_name: str = 'llama31-financial-sentiment',
            model_id_or_path: str = 'meta-llama/Meta-Llama-3.1-8B-Instruct',
            model_cache_dir: str = '@auto-populate-modelcache-path',
            num_train_epochs: float = 1,
            per_device_train_batch_size: int = 4,
            gradient_accumulation_steps: int = 4,
            learning_rate: float = 2e-4,
            max_length: int = 1024,
            lora_rank: int = 16,
            lora_alpha: int = 32,
            precision: str = 'bf16',
            use_quantization: bool = True,
            full_finetune: bool = False,
            quantization_type: str = 'nf4',
            model_version: str = '@auto-timestamp',
            kfp_output_path: str = None):
    """
    Fine-tune a causal LLM (LoRA or full fine-tuning) on a JSONL dataset and save it.

    The function is meant to be embedded as-is into a KFP component, so every
    import and helper it needs lives inside its own body.

    Args:
        data_name (str): Name of the input JSONL file, relative to the data location,
            default 'financial_sentiment_data.jsonl'
        data_source (str): 'gcs' downloads with google-cloud-storage; any other value
            (including the '@auto-populate-data-source' placeholder) uses the MinIO
            FileManager
        data_bucket_prefix (str): Object storage path prefix (default:
            '@auto-populate-object-prefix' - will be auto-populated by caller)
            Format: gs://{bucket-name}/{project-id}/ or s3://{bucket-name}/{project-id}/
            Usage: data_bucket_prefix + data_name
            (e.g., 'gs://my-bucket/my-project/financial_sentiment_data.jsonl')
        model_relative_path (str): Directory under the home directory used for local
            model output, default 'model'
        model_name (str): Name of the output model directory, default
            'llama31-financial-sentiment'
        model_id_or_path (str): Model identifier or path of the base model, default
            'meta-llama/Meta-Llama-3.1-8B-Instruct'. Absolute paths are used as-is.
        model_cache_dir (str): Directory that is joined with model_id_or_path to find a
            local copy of the base model. When left at the
            '@auto-populate-modelcache-path' placeholder (or empty), model_id_or_path is
            passed to from_pretrained unchanged. The environment variables
            MODEL_CACHE_PATH_<MODEL ID> and MODEL_CACHE_DIR take precedence over both.
        num_train_epochs (float): Number of training epochs, default 1
        per_device_train_batch_size (int): Per-device training batch size, default 4
        gradient_accumulation_steps (int): Number of gradient accumulation steps, default 4
        learning_rate (float): Learning rate, default 2e-4
        max_length (int): Maximum number of tokens per training example; longer examples
            are truncated, default 1024
        lora_rank (int): Lora attention dimension (the "rank"), default 16
        lora_alpha (int): The alpha parameter for Lora scaling, default 32
        precision (str): Training precision (fp32, bf16, or fp16), default 'bf16'.
            Unrecognised values fall back to bf16 with a warning.
        use_quantization (bool): Use 4-bit quantization (useful for smaller GPUs), default True
        full_finetune (bool): Do full fine-tuning instead of LoRA (requires more VRAM),
            default False
        quantization_type (str): Quantization type, default 'nf4' (the only supported value;
            anything else loads the model without quantization)
        model_version (str): Model version (default: '@auto-timestamp' - will be
            auto-generated by caller). Only logged by this function.
        kfp_output_path (str): KFP v2 Output[Model] artifact path (auto-uploaded to GCS on
            Vertex AI). When empty, the model is written under the home directory instead.

    Returns:
        None

    Raises:
        ValueError: If full fine-tuning is combined with 4-bit quantization, or the GCS
            path is not a gs:// URL.
        FileNotFoundError: If a local model path was resolved but does not exist.
    """
    import json
    import logging
    import math
    import os
    import subprocess
    import torch
    from dataclasses import dataclass
    from typing import Any, Dict, List, Optional, Tuple

    home = '/home/jovyan'
    input_dir = os.path.join(home, 'data')

    unset_model_cache_dir = '@auto-populate-modelcache-path'
    max_examples = 10000      # cap on the number of training examples
    min_token_count = 10      # tokenized examples shorter than this are dropped

    # Quantization is only applied for the single supported type.
    quantize = bool(use_quantization) and quantization_type == 'nf4'

    # Create input directory
    os.makedirs(input_dir, exist_ok=True)

    # Set up logging
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

    # Fail before any download or model load: the Trainer refuses to fine-tune a purely
    # quantized model that has no adapters attached.
    if full_finetune and quantize:
        raise ValueError(
            "full_finetune=True cannot be combined with 4-bit quantization; "
            "set use_quantization=False or full_finetune=False"
        )
    if use_quantization and not quantize:
        logger.warning(
            f"Unsupported quantization_type '{quantization_type}' (only 'nf4' is supported); "
            "the model will be loaded without quantization"
        )

    # Use KFP output path if provided (for Vertex AI auto-upload to GCS)
    # Otherwise use local path (for local testing/runs)
    if kfp_output_path:
        output_dir = kfp_output_path
        logger.info(f"Using KFP Output artifact path: {output_dir}")
        logger.info("Model will be automatically uploaded to GCS by Vertex AI")
    else:
        output_dir = os.path.join(home, model_relative_path, model_name)
        logger.info(f"Using local path: {output_dir}")

    def run_nvidia_smi():
        """Run nvidia-smi; return the CompletedProcess, or None if it could not be run."""
        try:
            return subprocess.run(['nvidia-smi'],
                                  capture_output=True,
                                  text=True,
                                  timeout=10)
        except FileNotFoundError:
            logger.warning("nvidia-smi command not found")
        except subprocess.TimeoutExpired:
            logger.warning("nvidia-smi command timed out")
        except Exception as e:
            logger.warning(f"Failed to run nvidia-smi: {e}")
        return None

    def log_gpu_memory(title: str, include_peak: bool) -> None:
        """Log allocated/reserved memory for every visible CUDA device."""
        logger.info(title)
        for i in range(torch.cuda.device_count()):
            logger.info(f"GPU {i} - Allocated: {torch.cuda.memory_allocated(i) / 1024**3:.2f} GB")
            logger.info(f"GPU {i} - Reserved: {torch.cuda.memory_reserved(i) / 1024**3:.2f} GB")
            if include_peak:
                logger.info(f"GPU {i} - Max Allocated: {torch.cuda.max_memory_allocated(i) / 1024**3:.2f} GB")
                logger.info(f"GPU {i} - Max Reserved: {torch.cuda.max_memory_reserved(i) / 1024**3:.2f} GB")

    def log_gpu_diagnostics() -> None:
        """Print/log device nodes, environment variables and CUDA details for debugging."""
        # ========== GPU DEVICE AVAILABILITY CHECK ==========
        print("=" * 80)
        print("CHECKING GPU DEVICE AVAILABILITY")
        print("=" * 80)

        # Check /dev for GPU devices
        try:
            listing = subprocess.run(['ls', '-la', '/dev'], capture_output=True, text=True)
            gpu_devices = [line for line in listing.stdout.split('\n') if 'nvidia' in line.lower()]
            if gpu_devices:
                print(f"✓ Found {len(gpu_devices)} NVIDIA device(s):")
                for dev in gpu_devices:
                    print(f"  {dev}")
            else:
                print("✗ NO NVIDIA devices found in /dev!")
                print("This means GPU is NOT mounted by Vertex AI")
        except Exception as e:
            print(f"Error checking devices: {e}")

        # Check environment
        print("\nNVIDIA environment variables:")
        for key in ['NVIDIA_VISIBLE_DEVICES', 'NVIDIA_DRIVER_CAPABILITIES', 'NVIDIA_REQUIRE_CUDA']:
            print(f"  {key}={os.environ.get(key, 'NOT SET')}")

        # ========== GPU DEBUGGING INFORMATION ==========
        logger.info("=" * 80)
        logger.info("GPU INFORMATION")
        logger.info("=" * 80)

        # Basic CUDA availability
        cuda_available = torch.cuda.is_available()
        logger.info(f"CUDA available: {cuda_available}")

        if cuda_available:
            # Number of GPUs
            gpu_count = torch.cuda.device_count()
            logger.info(f"Number of GPUs detected: {gpu_count}")

            # Current GPU
            logger.info(f"Current CUDA device: {torch.cuda.current_device()}")

            # Detailed info for each GPU
            for i in range(gpu_count):
                logger.info(f"\n--- GPU {i} Details ---")
                logger.info(f"  Name: {torch.cuda.get_device_name(i)}")

                # Get device properties
                props = torch.cuda.get_device_properties(i)
                logger.info(f"  Compute Capability: {props.major}.{props.minor}")
                logger.info(f"  Total Memory: {props.total_memory / 1024**3:.2f} GB")
                logger.info(f"  Multi Processor Count: {props.multi_processor_count}")

                # Memory info
                logger.info(f"  Memory Allocated: {torch.cuda.memory_allocated(i) / 1024**3:.2f} GB")
                logger.info(f"  Memory Reserved: {torch.cuda.memory_reserved(i) / 1024**3:.2f} GB")
                logger.info(f"  Max Memory Allocated: {torch.cuda.max_memory_allocated(i) / 1024**3:.2f} GB")
                logger.info(f"  Max Memory Reserved: {torch.cuda.max_memory_reserved(i) / 1024**3:.2f} GB")

            # CUDA version
            logger.info(f"\nCUDA Version: {torch.version.cuda}")
            logger.info(f"cuDNN Version: {torch.backends.cudnn.version()}")
            logger.info(f"cuDNN Enabled: {torch.backends.cudnn.enabled}")

            # Try running nvidia-smi
            logger.info("\n--- NVIDIA-SMI Output ---")
            smi = run_nvidia_smi()
            if smi is not None:
                if smi.returncode == 0:
                    logger.info(f"\n{smi.stdout}")
                else:
                    logger.warning(f"nvidia-smi failed with code {smi.returncode}: {smi.stderr}")

            # Check GPU visibility from environment variables
            logger.info("\n--- GPU Environment Variables ---")
            for var in ['CUDA_VISIBLE_DEVICES', 'NVIDIA_VISIBLE_DEVICES',
                        'CUDA_DEVICE_ORDER', 'CUDA_LAUNCH_BLOCKING']:
                logger.info(f"  {var}: {os.environ.get(var, 'Not set')}")
        else:
            logger.warning("CUDA is not available! Training will run on CPU (very slow).")
            logger.info("\nPossible issues:")
            logger.info("  1. No GPU present in the container")
            logger.info("  2. NVIDIA drivers not installed on host")
            logger.info("  3. NVIDIA Container Toolkit not configured")
            logger.info("  4. GPU not mounted/passed to container")

            # Check for nvidia-smi anyway to see if drivers are present.
            # Failures to launch it are already reported by run_nvidia_smi().
            smi = run_nvidia_smi()
            if smi is not None:
                if smi.returncode == 0:
                    logger.info("\nnvidia-smi is available but PyTorch can't see CUDA!")
                    logger.info("This usually means PyTorch was not built with CUDA support.")
                    logger.info(f"\n{smi.stdout}")
                else:
                    logger.info("\nnvidia-smi not available or failed")

    log_gpu_diagnostics()

    logger.info("=" * 80)
    logger.info(f"Model ID or path: {model_id_or_path}")
    logger.info(f"Model cache directory: {model_cache_dir}")
    logger.info(f"Model version: {model_version}")
    logger.info(f"Data source: {data_source}")
    logger.info(f"Data bucket prefix: {data_bucket_prefix}")
    logger.info("=" * 80)

    def download_data():
        """Download fine-tuning data from MinIO or GCS into input_dir."""
        if data_source == 'gcs':
            # Download from GCS using google-cloud-storage Python library
            # Construct full GCS path: data_bucket_prefix + data_name
            gcs_full_path = f"{data_bucket_prefix.rstrip('/')}/{data_name}"
            local_path = os.path.join(input_dir, data_name)

            logger.info(f'Downloading data from GCS: {gcs_full_path}')
            logger.info(f'Destination: {local_path}')

            try:
                from google.cloud import storage

                # Parse gs://bucket-name/path/to/file
                if not gcs_full_path.startswith('gs://'):
                    raise ValueError(f"Invalid GCS path: {gcs_full_path}")

                path_parts = gcs_full_path[5:].split('/', 1)  # Remove 'gs://' and split
                bucket_name = path_parts[0]
                blob_name = path_parts[1] if len(path_parts) > 1 else ''

                logger.info(f"Bucket: {bucket_name}, Blob: {blob_name}")

                # data_name may contain sub-directories; the download needs them to exist.
                os.makedirs(os.path.dirname(local_path) or input_dir, exist_ok=True)

                # Initialize GCS client (uses Application Default Credentials)
                client = storage.Client()
                blob = client.bucket(bucket_name).blob(blob_name)

                # Download the file
                blob.download_to_filename(local_path)
                logger.info('Download from GCS completed successfully')

            except ImportError:
                logger.error('google-cloud-storage library not found. Install with: pip install google-cloud-storage')
                raise
            except Exception as e:
                logger.error(f'Failed to download from GCS: {e}')
                raise
        else:
            # Download from MinIO (default)
            from tintin.file.minio import FileManager as FileManager
            logger.info(f'Downloading data from MinIO: {data_name}')
            logger.info(f'Destination: {input_dir}')
            debug = False
            recursive = False
            mgr = FileManager('', debug)
            mgr.download(input_dir, [data_name], recursive)
            logger.info('Download from MinIO completed')

    def load_jsonl_data(file_path: str) -> List[Dict[str, Any]]:
        """Load training examples from a jsonl file, skipping blank lines."""
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if stripped:
                    data.append(json.loads(stripped))
        logger.info(f"Loaded {len(data)} examples from {file_path}")
        return data

    def format_instruction(example: Dict[str, str]) -> str:
        """Render one example as a system/user/assistant conversation in Llama 3.1 format.

        Each turn is '<|start_header_id|>{role}<|end_header_id|>' followed by a blank
        line, the content, and '<|eot_id|>'. Missing or null fields are treated as empty.
        """
        system_prompt = "You are a financial analyst who specializes in detecting emotional subtext in earnings calls."

        def clean(key: str) -> str:
            # A JSON null must become '' rather than the literal text 'None'.
            value = example.get(key)
            return '' if value is None else str(value).strip()

        instruction = clean('instruction')
        input_text = clean('input')
        output_text = clean('output')

        # Join only the parts that are present so an empty input leaves no dangling newlines.
        user_content = "\n\n".join(part for part in (instruction, input_text) if part)

        # DO NOT add <|begin_of_text|> here, tokenizer will add it
        return (
            f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>"
            f"<|start_header_id|>user<|end_header_id|>\n\n{user_content}<|eot_id|>"
            f"<|start_header_id|>assistant<|end_header_id|>\n\n{output_text}<|eot_id|>"
        )

    def resolve_model_path() -> Tuple[str, bool]:
        """Decide where the base model is loaded from.

        Returns:
            (path_or_id, is_local): is_local is False when the value is a model
            identifier that from_pretrained must resolve itself, in which case no
            filesystem checks apply.
        """
        # Get the actual model cache path from environment variable
        # Environment variable pattern: MODEL_CACHE_PATH_META_LLAMA_META_LLAMA_3_1_8B_INSTRUCT
        env_var_name = f"MODEL_CACHE_PATH_{model_id_or_path.upper().replace('/', '_').replace('-', '_').replace('.', '_')}"
        model_cache_path = os.getenv(env_var_name) or os.getenv("MODEL_CACHE_DIR")

        logger.info(f"Checking environment variable: {env_var_name}")
        logger.info(f"MODEL_CACHE_PATH from env: {model_cache_path}")

        if model_cache_path:
            logger.info(f"Using model cache path from environment variable: {model_cache_path}")
            return model_cache_path, True

        # Fallback to old behavior
        if model_cache_dir and model_cache_dir != unset_model_cache_dir:
            if model_id_or_path.startswith('/'):
                resolved = model_id_or_path
            else:
                resolved = os.path.join(model_cache_dir, model_id_or_path)
            logger.info(f"Loading model from cache dir parameter: {resolved}")
            return resolved, True

        # No cache configured: an absolute or existing path is still a local model,
        # anything else is an identifier to download.
        if os.path.isabs(model_id_or_path) or os.path.exists(model_id_or_path):
            logger.info(f"Loading model from local path: {model_id_or_path}")
            return model_id_or_path, True

        # Download from HuggingFace
        logger.info(f"Downloading model from HuggingFace: {model_id_or_path}")
        return model_id_or_path, False

    def log_gcs_mount_diagnostics(model_path: str) -> None:
        """Log what is visible under the /gcs mount along the way to model_path."""
        # List GCS mount root
        gcs_root = '/gcs'
        if os.path.exists(gcs_root):
            logger.info(f"GCS mount root exists: {gcs_root}")
            try:
                buckets = os.listdir(gcs_root)
                logger.info(f"Available buckets: {buckets}")
            except Exception as e:
                logger.warning(f"Cannot list GCS root: {e}")
        else:
            logger.error(f"GCS mount root does not exist: {gcs_root}")

        # List model cache directory
        cache_dir_parts = model_path.split('/')
        if len(cache_dir_parts) < 4:
            return

        bucket_path = '/'.join(cache_dir_parts[:3])  # /gcs/bucket-name
        cache_path = '/'.join(cache_dir_parts[:4])   # /gcs/bucket-name/model-cache

        if os.path.exists(bucket_path):
            logger.info(f"Bucket path exists: {bucket_path}")
            try:
                dirs = os.listdir(bucket_path)
                logger.info(f"Contents of {bucket_path}: {dirs}")
            except Exception as e:
                logger.warning(f"Cannot list {bucket_path}: {e}")

        if os.path.exists(cache_path):
            logger.info(f"Cache path exists: {cache_path}")
            try:
                dirs = os.listdir(cache_path)
                logger.info(f"Contents of {cache_path}: {dirs}")

                # Check meta-llama subdirectory if it exists
                meta_llama_path = os.path.join(cache_path, 'meta-llama')
                if os.path.exists(meta_llama_path):
                    logger.info(f"meta-llama directory exists: {meta_llama_path}")
                    try:
                        meta_llama_contents = os.listdir(meta_llama_path)
                        logger.info(f"Contents of {meta_llama_path}: {meta_llama_contents}")
                    except Exception as e:
                        logger.warning(f"Cannot list {meta_llama_path}: {e}")
                else:
                    logger.warning(f"meta-llama directory does not exist at: {meta_llama_path}")
            except Exception as e:
                logger.warning(f"Cannot list {cache_path}: {e}")

    def validate_local_model_path(model_path: str) -> None:
        """Raise FileNotFoundError if model_path is missing; warn about missing key files."""
        if not os.path.exists(model_path):
            error_msg = f"Model path does not exist: {model_path}"
            if model_path.startswith('/gcs/'):
                error_msg += "\n  Possible issues:"
                error_msg += "\n  1. GCS bucket not mounted (check CSI driver or gcsfuse)"
                error_msg += "\n  2. Model not present in the bucket"
                error_msg += "\n  3. Incorrect path (check model_cache_dir and model_id_or_path)"
            logger.error(error_msg)
            raise FileNotFoundError(error_msg)

        # Check if directory contains model files
        model_files = os.listdir(model_path) if os.path.isdir(model_path) else []
        logger.info(f"Found {len(model_files)} files/directories in model path")
        if model_files:
            logger.info(f"Sample files: {model_files[:5]}")

        # Check for essential model files
        essential_files = ['config.json', 'tokenizer_config.json']
        missing_files = [f for f in essential_files if not os.path.exists(os.path.join(model_path, f))]
        if missing_files:
            logger.warning(f"Missing essential files: {missing_files}")
            logger.warning("Model loading may fail or fall back to downloading from HuggingFace")

    def train():
        """Download and prepare the data, load the model, run the Trainer and save the result."""
        # Log configuration
        logger.info("Configuration:")
        logger.info(f"  Precision: {precision}")
        logger.info(f"  Quantization: {'Yes' if use_quantization else 'No'}")
        logger.info(f"  Training type: {'Full fine-tune' if full_finetune else 'LoRA'}")
        logger.info(f"  Batch size: {per_device_train_batch_size}")
        logger.info(f"  Gradient accumulation: {gradient_accumulation_steps}")

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        download_data()
        input_file = os.path.join(input_dir, data_name)

        # Load data
        logger.info("Loading data...")
        examples = load_jsonl_data(input_file)

        # Filter examples with output
        valid_examples = [ex for ex in examples if ex.get('output')]
        logger.info(f"Found {len(valid_examples)} valid examples with output")

        # Take a subset for testing
        if len(valid_examples) > max_examples:
            valid_examples = valid_examples[:max_examples]
            logger.info(f"Using subset of {len(valid_examples)} examples for testing")

        # Format examples
        logger.info("Formatting examples...")
        formatted_texts = []
        for i, example in enumerate(valid_examples):
            try:
                formatted_texts.append(format_instruction(example))
            except Exception as e:
                logger.warning(f"Failed to format example {i}: {e}")

        logger.info(f"Successfully formatted {len(formatted_texts)} examples")

        # Nothing to train on: stop before paying for the model load.
        if not formatted_texts:
            logger.error("No training examples available after formatting; skipping training.")
            return

        # Show first example
        logger.info(f"First example (first 300 chars):\n{formatted_texts[0][:300]}...")

        # Load tokenizer and model
        logger.info("Loading tokenizer and model...")
        from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

        actual_model_path, is_local_path = resolve_model_path()

        # Filesystem checks only make sense for a local path; a model identifier is
        # resolved by from_pretrained itself.
        if is_local_path:
            if actual_model_path.startswith('/gcs/'):
                # Debug: Check what's available in the parent directories
                log_gcs_mount_diagnostics(actual_model_path)
            validate_local_model_path(actual_model_path)

        tokenizer = AutoTokenizer.from_pretrained(
            actual_model_path,
            trust_remote_code=True
        )
        # The pad id is only a filler value: the collator below builds the attention
        # mask from sequence lengths, so reusing EOS here does not hide real EOS tokens.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        # Determine torch dtype based on precision
        dtype_by_precision = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }
        if precision not in dtype_by_precision:
            logger.warning(f"Unknown precision '{precision}', falling back to bf16")
        dtype = dtype_by_precision.get(precision, torch.bfloat16)
        precision_label = precision.upper() if precision in dtype_by_precision else "BF16"
        logger.info(f"Using {precision_label} precision")

        # Setup model loading arguments
        # NOTE: device_map="auto" causes meta tensor errors with Trainer
        # Only use device_map for quantized models (required for quantization)
        model_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": dtype,
        }

        # Add quantization config if requested
        # Only nf4 quantization is supported (via bitsandbytes)
        if quantize:
            logger.info("Using 4-bit quantization...")
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_use_double_quant=True,
            )
            # device_map is required for quantized models
            model_kwargs["device_map"] = "auto"
            logger.info("Using device_map='auto' for quantized model")
        else:
            logger.info("Loading model without quantization...")
            # For non-quantized models, let Trainer handle device placement
            # Using low_cpu_mem_usage to reduce memory during loading
            model_kwargs["low_cpu_mem_usage"] = True

        model = AutoModelForCausalLM.from_pretrained(
            actual_model_path,
            **model_kwargs
        )

        # The generation cache is not used during training.
        model.config.use_cache = False

        # Setup PEFT (only if not doing full fine-tuning)
        if not full_finetune:
            logger.info("Setting up LoRA...")
            from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, TaskType

            # Prepare for k-bit training if using quantization
            if quantize:
                model = prepare_model_for_kbit_training(model)
            else:
                # For non-quantized models, enable gradient checkpointing
                model.gradient_checkpointing_enable()

            peft_config = LoraConfig(
                r=lora_rank,
                lora_alpha=lora_alpha,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type=TaskType.CAUSAL_LM,
            )

            model = get_peft_model(model, peft_config)
            model.print_trainable_parameters()
        else:
            logger.info("Setting up full fine-tuning...")
            model.gradient_checkpointing_enable()

            # Count total parameters
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            logger.info(f"Total parameters: {total_params:,}")
            logger.info(f"Trainable parameters: {trainable_params:,}")

        # Tokenize data
        logger.info("Tokenizing data...")
        from datasets import Dataset

        def tokenize_function(examples):
            """Tokenize a batch of texts (truncated to max_length) and copy the ids as labels."""
            model_inputs = tokenizer(
                examples["text"],
                truncation=True,
                max_length=max_length,
                padding=False,
                return_tensors=None,
                add_special_tokens=True,
            )

            # For causal language modeling, labels are the same as input_ids
            model_inputs["labels"] = [list(ids) for ids in model_inputs["input_ids"]]
            return model_inputs

        # Create dataset and tokenize
        dataset = Dataset.from_dict({"text": formatted_texts})
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=["text"],
            desc="Tokenizing data"
        )

        # Filter out sequences that are too short (less than 10 tokens)
        tokenized_dataset = tokenized_dataset.filter(
            lambda example: len(example["input_ids"]) >= min_token_count
        )
        logger.info(f"Dataset size after filtering short sequences: {len(tokenized_dataset)}")

        # Check tokenized data
        logger.info(f"Tokenized dataset size: {len(tokenized_dataset)}")
        if len(tokenized_dataset) == 0:
            logger.error("No valid tokenized examples found!")
            return

        # Check sequence lengths
        lengths = [len(example["input_ids"]) for example in tokenized_dataset]
        logger.info(f"Sequence length stats: min={min(lengths)}, max={max(lengths)}, avg={sum(lengths)/len(lengths):.1f}")

        sample_ids = tokenized_dataset[0]['input_ids']
        logger.info(f"Sample tokenized length: {len(sample_ids)}")
        # Check for duplicate begin tokens (would mean the text already contained one)
        bos_id = tokenizer.bos_token_id
        if bos_id is not None and len(sample_ids) > 1 and sample_ids[0] == sample_ids[1] == bos_id:
            logger.warning("Found duplicate <|begin_of_text|> tokens!")

        # Decode to check format
        decoded = tokenizer.decode(sample_ids[:100], skip_special_tokens=False)
        logger.info(f"Sample decoded (first 100 tokens): {decoded}")

        # Setup training
        logger.info("Setting up training...")
        from transformers import Trainer, TrainingArguments
        from transformers.data.data_collator import DataCollatorMixin

        # Configure training arguments based on precision
        training_kwargs = {
            "output_dir": output_dir,
            "num_train_epochs": num_train_epochs,
            "per_device_train_batch_size": per_device_train_batch_size,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "learning_rate": learning_rate,
            "logging_steps": 10,
            "save_steps": 500,
            "gradient_checkpointing": True,
            "report_to": "none",
        }

        # Set precision-specific arguments (FP32 needs no flag)
        if precision == "fp16":
            training_kwargs["fp16"] = True
        elif precision != "fp32":  # bf16, including the fallback for unknown values
            training_kwargs["bf16"] = True

        training_args = TrainingArguments(**training_kwargs)

        @dataclass
        class DataCollatorForCausalLM(DataCollatorMixin):
            """
            Data collator for causal language modeling that right-pads input_ids and labels.

            The attention mask is derived from the original sequence lengths rather than
            by comparing against the pad id, because the pad id may be a token that also
            occurs inside real sequences.
            """
            tokenizer: Any
            pad_to_multiple_of: Optional[int] = None
            return_tensors: str = "pt"

            def torch_call(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
                input_ids = [list(example["input_ids"]) for example in examples]
                labels = [list(example["labels"]) for example in examples]
                real_token_flags = [[1] * len(seq) for seq in input_ids]

                return {
                    "input_ids": self._pad_sequences(input_ids, self.tokenizer.pad_token_id),
                    # -100 is ignored in CrossEntropy loss
                    "labels": self._pad_sequences(labels, -100),
                    "attention_mask": self._pad_sequences(real_token_flags, 0),
                }

            def _pad_sequences(self, sequences: List[List[int]], pad_value: int) -> torch.Tensor:
                """Right-pad sequences to a common length and return them as a long tensor."""
                target_length = max(len(seq) for seq in sequences)

                # Pad to multiple if specified
                if self.pad_to_multiple_of is not None:
                    multiple = self.pad_to_multiple_of
                    target_length = ((target_length + multiple - 1) // multiple) * multiple

                padded_sequences = [
                    seq + [pad_value] * (target_length - len(seq))
                    for seq in sequences
                ]
                return torch.tensor(padded_sequences, dtype=torch.long)

        data_collator = DataCollatorForCausalLM(
            tokenizer=tokenizer,
            pad_to_multiple_of=8,
            return_tensors="pt"
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=data_collator,
        )

        # Check training setup
        num_batches = len(trainer.get_train_dataloader())
        logger.info(f"Training dataloader batches: {num_batches}")

        if num_batches == 0:
            logger.error("No training batches! Check your data and batch size.")
            return

        # One optimizer step consumes gradient_accumulation_steps batches, so the step
        # count is not simply batches * epochs.
        accumulation = max(int(gradient_accumulation_steps), 1)
        steps_per_epoch = max(math.ceil(num_batches / accumulation), 1)
        estimated_steps = math.ceil(num_train_epochs * steps_per_epoch)
        logger.info(
            f"Estimated optimizer steps: {estimated_steps} "
            f"({num_batches} batches per epoch, gradient accumulation {accumulation})"
        )

        # Print memory info if using CUDA
        if torch.cuda.is_available():
            log_gpu_memory("\n--- GPU Memory Before Training ---", include_peak=False)

        # Start training
        logger.info("Starting training...")
        try:
            train_result = trainer.train()
            logger.info("Training completed successfully!")

            # Save model
            trainer.save_model()
            logger.info(f"Model saved to {output_dir}")

            # Print final memory usage
            if torch.cuda.is_available():
                log_gpu_memory("\n--- GPU Memory After Training ---", include_peak=True)

            # Print metrics
            if train_result and hasattr(train_result, 'metrics'):
                logger.info(f"Final metrics: {train_result.metrics}")

        except Exception as e:
            # logger.exception records the traceback alongside the message.
            logger.exception(f"Training failed: {e}")
            raise

    train()
    return None


## Testing

Uncomment and run the cell below to test the trainOp function locally:

In [None]:
# Test trainOp locally (uncomment to run)
# trainOp(
#     data_name='financial_sentiment_data-20250715.jsonl',
#     model_relative_path='model',
#     model_name='llama31-financial-sentiment',
#     model_cache_dir='/mnt/pretrained-models/',
#     num_train_epochs=0.1,  # Use fractional epoch for quick testing
#     per_device_train_batch_size=2,
#     gradient_accumulation_steps=2,
#     learning_rate=2e-4,
#     max_length=512,  # Shorter for testing
#     lora_rank=8,  # Smaller for testing
#     lora_alpha=16,
#     precision='bf16',
#     use_quantization=True,
#     full_finetune=False,
#     quantization_type='nf4'
# )