# Training Phi-3-mini-128k-instruct on Swift Code (TPU Fixed Version)

This notebook fine-tunes the Microsoft Phi-3-mini-128k-instruct model on Swift code so that the model learns the Swift programming language. It supports training on TPU, GPU, or CPU, with optimizations appropriate to each hardware type. This version fixes TPU compatibility issues.

In [None]:
# Install required libraries with specific versions for compatibility
!pip install transformers==4.38.2 datasets==2.16.1 evaluate==0.4.1 torch==2.1.2 scikit-learn==1.4.0 tqdm==4.66.1 accelerate==0.27.2 peft==0.7.1 bitsandbytes==0.41.3 psutil==5.9.8

# Install TPU-specific libraries if needed
try:
    import torch_xla
    print("PyTorch XLA already installed")
except ImportError:
    print("Installing PyTorch XLA for TPU support...")
    !pip install cloud-tpu-client torch_xla
    
# Verify installations
print("\nVerifying installations:")
!python -c "import torch; print(f'PyTorch version: {torch.__version__}')"
!python -c "import transformers; print(f'Transformers version: {transformers.__version__}')"
!python -c "import peft; print(f'PEFT version: {peft.__version__}')"
!python -c "import bitsandbytes as bnb; print(f'BitsAndBytes version: {bnb.__version__}')"
try:
    !python -c "import torch_xla; print(f'PyTorch XLA version: {torch_xla.__version__}')"
except:
    print("PyTorch XLA not installed (only needed for TPU)")

In [None]:
# Important: Comprehensive imports for Phi-3 training with GPU/TPU support
# System and utility imports
import os
import sys
import json
import time
import gc
import re
import logging
import traceback
import collections
import warnings
import psutil  # For memory monitoring
from typing import Dict, List, Optional, Union, Tuple, Any
from tqdm.auto import tqdm

# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Data processing imports
import random
import numpy as np
import pandas as pd

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler

# HuggingFace imports
from datasets import load_dataset, ClassLabel, Dataset as HFDataset, DatasetDict
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    Trainer, 
    TrainingArguments,
    set_seed,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    BitsAndBytesConfig,
    TrainerCallback,
    TrainerState,
    TrainerControl,
    TrainingArguments,
    logging as transformers_logging
)

# PEFT (Parameter-Efficient Fine-Tuning) imports
from peft import (
    LoraConfig, 
    get_peft_model, 
    prepare_model_for_kbit_training,
    PeftModel,
    PeftConfig
    # TaskType is not imported directly to avoid compatibility issues
)

# BitsAndBytes for quantization
from bitsandbytes.nn import LinearNF4, Linear8bitLt

# Route INFO-and-above log records to stdout so they show up in the cell output.
# Passing `stream=` makes basicConfig create the stdout StreamHandler itself.
logging.basicConfig(
    stream=sys.stdout,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by the helpers defined below.
logger = logging.getLogger(__name__)
# Have the transformers library report at INFO verbosity as well.
transformers_logging.set_verbosity_info()

# Fix the seed up front for reproducibility
set_seed(42)

# Custom callback for detailed training monitoring
class ResourceMonitorCallback(TrainerCallback):
    """Trainer callback that periodically logs CPU, RAM and GPU memory usage.

    The callback keeps its own counter of ``on_step_end`` invocations and
    emits a report through the module logger every ``log_every`` calls.
    """

    def __init__(self, log_every=100):
        # Number of on_step_end calls between two reports.
        self.log_every = log_every
        # Count of on_step_end calls seen so far.
        self.step = 0

    def on_step_end(self, args, state, control, **kwargs):
        """Count the finished step and log resource usage when a report is due."""
        self.step += 1
        if self.step % self.log_every != 0:
            return

        # Memory of this Python process, overall system memory, then CPU load.
        rss_mb = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
        system_memory = psutil.virtual_memory()
        # interval=0.1 makes psutil sample CPU load over a 0.1 s window.
        cpu_load = psutil.cpu_percent(interval=0.1)

        logger.info(f"Step {self.step} - CPU: {cpu_load}% | RAM: {rss_mb:.2f} MB | System RAM: {system_memory.percent}% used")

        # GPU monitoring only applies when CUDA is usable.
        if not torch.cuda.is_available():
            return

        for gpu_index in range(torch.cuda.device_count()):
            allocated_mb = torch.cuda.memory_allocated(gpu_index) / 1024 / 1024
            reserved_mb = torch.cuda.memory_reserved(gpu_index) / 1024 / 1024
            logger.info(f"GPU {gpu_index}: Allocated: {allocated_mb:.2f} MB | Reserved: {reserved_mb:.2f} MB")

# Enhanced memory management function with detailed reporting
def cleanup_memory():
    """Force garbage collection and release cached accelerator memory.

    Runs the Python garbage collector, releases the CUDA caching allocator's
    unused blocks (when CUDA is available) and flushes pending XLA work (when
    torch_xla is usable), printing before/after figures along the way.

    Returns:
        The change in this process's resident memory in MB (before - after).
        The value can be negative if memory grew during the cleanup.
    """
    process = psutil.Process(os.getpid())

    # Get memory usage before cleanup
    before = process.memory_info().rss / 1024 ** 2

    # Perform cleanup
    gc.collect()

    # Clear GPU cache if available
    if torch.cuda.is_available():
        gpu_ids = range(torch.cuda.device_count())
        # empty_cache() hands back memory that the caching allocator has
        # *reserved* but is not using; it does not change the *allocated*
        # amount. Measuring reserved memory is what makes "Freed" meaningful.
        reserved_before = [torch.cuda.memory_reserved(i) / 1024 / 1024 for i in gpu_ids]
        torch.cuda.empty_cache()
        for i, gpu_before in zip(gpu_ids, reserved_before):
            gpu_after = torch.cuda.memory_reserved(i) / 1024 / 1024
            print(f"GPU {i} memory cleaned: {gpu_before:.2f} MB → {gpu_after:.2f} MB, Freed: {gpu_before - gpu_after:.2f} MB")

    # Try to clear TPU cache if available
    try:
        import torch_xla.core.xla_model as xm
        xm.mark_step()
        print("TPU cache cleared with mark_step()")
    except (ImportError, NameError, RuntimeError):
        # torch_xla is missing, or it is installed but no XLA device could be
        # initialised. The cleanup stays best-effort either way.
        pass

    # Get memory usage after cleanup
    after = process.memory_info().rss / 1024 ** 2
    print(f"Process memory cleaned: {before:.2f} MB → {after:.2f} MB, Freed: {before - after:.2f} MB")

    # Print system memory info
    mem = psutil.virtual_memory()
    print(f"System memory: {mem.percent}% used, {mem.available / 1024 / 1024:.2f} MB available")

    return before - after  # Return amount of memory freed

In [None]:
# Check for available hardware accelerators (GPU, TPU)
import torch
import os

# Function to detect and configure the device
def setup_device():
    """Pick the best available accelerator and describe it.

    Preference order is TPU, then CUDA GPU, then CPU.

    Returns:
        A ``(device, device_type)`` tuple where ``device_type`` is one of
        ``"tpu"``, ``"gpu"`` or ``"cpu"``.
    """
    # Check for TPU. torch_xla being importable does not prove a TPU exists:
    # xla_device() may raise RuntimeError, or hand back an XLA device backed
    # by different hardware, so the hardware type is checked explicitly.
    tpu_device = None
    try:
        import torch_xla.core.xla_model as xm

        candidate = xm.xla_device()
        if xm.xla_device_hw(candidate) == "TPU":
            tpu_device = candidate
    except (ImportError, EnvironmentError, RuntimeError):
        tpu_device = None

    if tpu_device is not None:
        print("TPU is available. Setting up TPU device...")
        print(f"Using TPU: {tpu_device}")
        return tpu_device, "tpu"
    print("TPU not available or not properly configured.")

    # Check for GPU
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"Number of GPUs available: {torch.cuda.device_count()}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        return device, "gpu"

    # Fallback to CPU
    print("No GPU or TPU detected. Using CPU - Note: Training will be much slower on CPU")
    return torch.device('cpu'), "cpu"

# Detect the accelerator once; later cells read `device` and `device_type`.
device, device_type = setup_device()

# Seed every random number generator in use so runs are repeatable
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Hardware-specific seeding and performance switches
if device_type == "gpu":
    torch.cuda.manual_seed_all(SEED)

    # NOTE(review): setup_device() has already queried CUDA by this point -
    # confirm that setting CUDA_VISIBLE_DEVICES this late still restricts the
    # visible GPUs as intended.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use first GPU by default
    torch.backends.cudnn.benchmark = True  # Optimize for fixed input sizes
    # Allow TF32 on Ampere GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    print("CUDA optimizations enabled")
elif device_type == "tpu":
    # Enable bfloat16 for better performance on TPU
    os.environ["XLA_USE_BF16"] = "1"
    print("TPU optimizations enabled")

In [None]:
# Model and optimisation hyper-parameters
MODEL_NAME = "microsoft/Phi-3-mini-128k-instruct"  # Checkpoint being fine-tuned
MAX_LENGTH = 4096  # Token length every example is truncated/padded to
BATCH_SIZE = 4  # Training batch size
LEARNING_RATE = 2e-5  # Optimizer learning rate
WEIGHT_DECAY = 0.01  # Regularization strength
NUM_EPOCHS = 3  # Passes over the training set
GRADIENT_ACCUMULATION_STEPS = 8  # Batches accumulated per optimizer step
WARMUP_RATIO = 0.03  # Fraction of training used for learning-rate warmup

# LoRA hyper-parameters
LORA_R = 16  # Rank of the LoRA update matrices
LORA_ALPHA = 32  # LoRA scaling parameter
LORA_DROPOUT = 0.05  # Dropout applied inside the LoRA layers

# Echo the configuration as one banner (a single print, same output as
# printing each line separately)
print("\n".join([
    "\n" + "=" * 50,
    "MODEL CONFIGURATION",
    "=" * 50,
    f"Model: {MODEL_NAME}",
    f"Max Length: {MAX_LENGTH}",
    f"Batch Size: {BATCH_SIZE}",
    f"Learning Rate: {LEARNING_RATE}",
    f"Weight Decay: {WEIGHT_DECAY}",
    f"Epochs: {NUM_EPOCHS}",
    f"Gradient Accumulation Steps: {GRADIENT_ACCUMULATION_STEPS}",
    f"LoRA r: {LORA_R}, alpha: {LORA_ALPHA}, dropout: {LORA_DROPOUT}",
    f"Device: {device_type.upper()}",
    "=" * 50 + "\n",
]))

In [None]:
# Load the dataset, report its structure, and produce train/validation/test
# splits (`train_data`, `val_data`, `test_data`) for the following cells.
print("Loading Swift code dataset...")
try:
    # Load the dataset
    dataset = load_dataset("mvasiliniuc/iva-swift-codeint")
    print(f"Dataset loaded successfully with {len(dataset['train'])} training examples")

    # Print dataset information
    print("\nDataset structure:")
    print(dataset)

    # Get a sample to understand the data structure
    sample = dataset['train'][0]
    print("\nSample data point:")
    for key, value in sample.items():
        # Long string fields (e.g. whole source files) are previewed only
        if isinstance(value, str) and len(value) > 100:
            print(f"{key}: {value[:100]}...")
        else:
            print(f"{key}: {value}")

    # Get label distribution (only when the dataset actually has labels)
    if 'label' in sample:
        # Read the label column directly. Iterating over whole rows would
        # decode every field of every example (including the full source
        # text) just to look at one integer.
        labels = dataset['train']['label']
        label_counts = collections.Counter(labels)
        print("\nLabel distribution:")
        for label, count in label_counts.items():
            print(f"Label {label}: {count} examples ({count/len(dataset['train'])*100:.2f}%)")

    # Create category names mapping (label id -> human-readable category)
    category_names = {
        0: "Models",
        1: "Views",
        2: "Controllers",
        3: "Utilities",
        4: "Tests",
        5: "Configuration"
    }

    # Split the dataset into train, validation, and test sets
    if 'validation' not in dataset or 'test' not in dataset:
        print("\nSplitting dataset into train, validation, and test sets...")
        # 80% train, 20% held out
        splits = dataset['train'].train_test_split(test_size=0.2, seed=SEED)
        train_data = splits['train']

        # Further split the held-out part evenly into validation and test
        test_splits = splits['test'].train_test_split(test_size=0.5, seed=SEED)
        val_data = test_splits['train']
        test_data = test_splits['test']

        print(f"Train set: {len(train_data)} examples")
        print(f"Validation set: {len(val_data)} examples")
        print(f"Test set: {len(test_data)} examples")
    else:
        train_data = dataset['train']
        val_data = dataset['validation']
        test_data = dataset['test']
        print(f"Using existing splits: Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

    # Clean up memory
    cleanup_memory()

except Exception as e:
    # Report the failure in the cell output, then let it propagate so the
    # notebook does not continue with missing data.
    print(f"Error loading dataset: {e}")
    traceback.print_exc()
    raise

In [None]:
# Load the tokenizer that matches MODEL_NAME and run a quick sanity check
print("\nLoading tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Padding needs a pad token; reuse EOS when the tokenizer defines none
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Summary of the loaded tokenizer (one print, same lines as before)
    print("\n".join([
        f"Tokenizer loaded: {type(tokenizer).__name__}",
        f"Vocabulary size: {len(tokenizer)}",
        f"Model max length: {tokenizer.model_max_length}",
        f"Padding token: {tokenizer.pad_token} (id: {tokenizer.pad_token_id})",
        f"EOS token: {tokenizer.eos_token} (id: {tokenizer.eos_token_id})",
    ]))

    # Tokenize a small Swift snippet to confirm the tokenizer works
    sample_text = "func hello() {\n    print(\"Hello, Swift!\")\n}"
    tokens = tokenizer(sample_text, return_tensors="pt")
    print(f"\nSample text tokenized to {tokens['input_ids'].shape[-1]} tokens")

except Exception as e:
    # Surface the error in the cell output and stop the notebook here
    print(f"Error loading tokenizer: {e}")
    traceback.print_exc()
    raise

In [None]:
# Create instruction-based prompts for the model
# Category used for examples that carry no `label` field.
_UNLABELED_CATEGORY = "Source"

# Heuristic for Swift closure syntax: an opening brace followed on the same
# line by the `in` keyword, e.g. `{ x in`, `{ [weak self] result in`.
_CLOSURE_RE = re.compile(r"\{[^{}\n]*\bin\b")

# Category -> clause completing "This Swift code is a <category> file that ".
_CATEGORY_SUMMARIES = {
    "Models": "defines data structures and model objects. ",
    "Views": "implements user interface components. ",
    "Controllers": "manages application logic and coordinates between models and views. ",
    "Utilities": "provides helper functions and extensions. ",
    "Tests": "contains test cases to verify functionality. ",
    "Configuration": "configures application settings and parameters. ",
}

# Category -> tail completing "...showcases proper syntax for defining ".
_CATEGORY_DETAILS = {
    "Models": "data structures with properties and methods. Swift models typically use structs for value semantics or classes when reference semantics are needed. The code demonstrates Swift's strong typing system and property declarations.",
    "Views": "UI components with layout and interaction logic. Swift views often use UIKit or SwiftUI frameworks, with clear separation of UI elements and their behaviors. The code shows how Swift handles user interface components and event responses.",
    "Controllers": "application logic and coordination between components. Controllers in Swift manage the flow of data between models and views, implementing business logic and handling user interactions. The code demonstrates Swift's approach to application architecture.",
    "Utilities": "helper functions and extensions to enhance functionality. Swift utilities often leverage the language's powerful extension capabilities to add functionality to existing types. The code shows how Swift can be extended and customized through utility functions.",
    "Tests": "test cases with setup, execution, and verification steps. Swift tests typically use XCTest framework with arrange-act-assert pattern. The code demonstrates Swift's approach to unit testing and verification.",
    "Configuration": "application settings and configuration parameters. Swift configuration files often define constants, environment settings, and application parameters. The code shows how Swift handles application configuration and settings management.",
}


def _has_word(code, keyword):
    """Return True when `keyword` occurs in `code` as a whole word.

    A plain substring test would report "class" for "subclass" or "let" for
    "delete"; word boundaries avoid those false positives.
    """
    return re.search(rf"\b{re.escape(keyword)}\b", code) is not None


def _explain_response(code, category):
    """Build the answer for the "explain what this code does" prompt."""
    response = f"This Swift code is a {category.lower()} file that "
    # Unknown categories get a neutral clause so the sentence is never left dangling.
    response += _CATEGORY_SUMMARIES.get(category, "contains general-purpose Swift code. ")
    response += "The code uses Swift syntax with "

    # (keywords that signal the construct, phrase used in the answer)
    constructs = [
        (("class",), "class definitions"),
        (("struct",), "struct definitions"),
        (("func",), "function declarations"),
        (("var",), "variable declarations"),
        (("let",), "constant declarations"),
        (("guard", "if let"), "optional unwrapping"),
        (("extension",), "extensions"),
        (("protocol",), "protocol implementations"),
    ]
    found = [
        phrase
        for keywords, phrase in constructs
        if any(_has_word(code, keyword) for keyword in keywords)
    ]
    response += (", ".join(found) + ".") if found else "various Swift features."
    return response


def _features_response(code):
    """Build the answer for the "identify key language features" prompt."""
    checks = [
        (_has_word(code, "class"), "**Classes**: Swift classes are reference types that support inheritance and reference counting."),
        (_has_word(code, "struct"), "**Structs**: Swift structs are value types that are copied when assigned or passed as arguments."),
        (_has_word(code, "protocol"), "**Protocols**: Similar to interfaces in other languages, protocols define a blueprint of methods, properties, and requirements."),
        (_has_word(code, "extension"), "**Extensions**: Allow adding new functionality to existing types without modifying the original code."),
        (_has_word(code, "enum"), "**Enumerations**: Define a common type for a group of related values and enable you to work with those values in a type-safe way."),
        (_has_word(code, "guard"), "**Guard statements**: Early exit mechanism that improves code readability by handling edge cases early."),
        (_has_word(code, "if let") or _has_word(code, "guard let"), "**Optional binding**: Safely unwraps optional values using if let or guard let syntax."),
        (_has_word(code, "func"), "**Functions**: First-class citizens in Swift that can be passed around and used like any other value."),
        ("closure" in code or _CLOSURE_RE.search(code) is not None, "**Closures**: Self-contained blocks of functionality that can be passed around and used in your code."),
    ]
    features = [text for present, text in checks if present]

    # If no specific features were identified, provide a generic response
    if not features:
        features = [
            "**Swift syntax**: The code demonstrates standard Swift syntax and conventions.",
            "**Type safety**: Swift's strong type system is evident in the variable declarations.",
            "**Object-oriented programming**: The code follows Swift's object-oriented programming principles.",
        ]

    # Number the list as it is rendered instead of patching "1." afterwards.
    numbered = [f"{position}. {text}" for position, text in enumerate(features, start=1)]
    return "This Swift code demonstrates several key language features:\n\n" + "\n".join(numbered)


def _extension_response(code, category):
    """Build the answer for the "complete/extend" and "write a function" prompts."""
    response = f"To extend this Swift {category.lower()} code, I would consider the following approach:\n\n"

    if category == "Models":
        response += "1. Add additional properties to capture more data attributes\n"
        response += "2. Implement Codable protocol for easy JSON serialization\n"
        response += "3. Add validation methods to ensure data integrity\n"
        response += "4. Include computed properties for derived values\n\n"
        response += "Here's an implementation example:\n\n```swift\n"
        if _has_word(code, "struct"):
            response += "// Extension to add Codable conformance\nextension MyStruct: Codable {\n    // Codable implementation\n}\n\n"
            response += "// Add validation method\nextension MyStruct {\n    func validate() -> Bool {\n        // Validation logic\n        return true\n    }\n}\n"
        else:
            response += "// Example extension or additional functionality\n// that would be appropriate for this model\n"
    elif category == "Views":
        response += "1. Add UI customization options\n"
        response += "2. Implement additional user interaction handlers\n"
        response += "3. Add accessibility support\n"
        response += "4. Implement view lifecycle methods\n\n"
        response += "Here's an implementation example:\n\n```swift\n"
        response += "// Example extension or additional functionality\n// that would be appropriate for this view\n"
    else:
        response += "1. Add error handling to make the code more robust\n"
        response += "2. Implement additional helper methods\n"
        response += "3. Add documentation comments to improve code readability\n"
        response += "4. Consider performance optimizations where appropriate\n\n"
        response += "Here's an implementation example:\n\n```swift\n"
        response += "// Example extension or additional functionality\n// that would be appropriate for this code\n"

    return response + "```"


def _generic_response(code, category):
    """Build the answer used by every other prompt type."""
    response = f"This Swift code demonstrates typical patterns used in {category.lower()} files. "
    response += "It follows Swift language conventions and showcases proper syntax for defining "
    # The fallback keeps the sentence complete for categories outside the table.
    response += _CATEGORY_DETAILS.get(category, "types, functions, and properties.")
    return response


def create_instruction_prompt(example):
    """Convert a code example into an instruction-based prompt for language learning.

    Args:
        example: Mapping with a ``content`` entry (Swift source text) and an
            optional ``label`` entry (category id looked up in
            ``category_names``).

    Returns:
        A dict with ``text`` (the full chat-formatted training sample),
        ``prompt``, ``response``, ``label`` (``-1`` when the example has no
        label) and ``category``.
    """
    code = example['content']

    # The label is optional: fall back to a neutral category rather than
    # raising KeyError on unlabeled datasets.
    if 'label' in example and example['label'] is not None:
        label = example['label']
        category = category_names.get(label, f"Unknown-{label}")
    else:
        label = -1
        category = _UNLABELED_CATEGORY

    # Each instruction is paired with the builder of its answer, so the
    # response no longer depends on matching fragments of the prompt text.
    prompt_types = [
        # Explain code functionality
        ("Explain what this Swift code does and how it works:\n\n",
         lambda: _explain_response(code, category)),
        # Identify patterns and features
        ("Identify and explain the key Swift language features used in this code:\n\n",
         lambda: _features_response(code)),
        # Complete or extend code
        ("Complete or extend this Swift code with appropriate functionality:\n\n",
         lambda: _extension_response(code, category)),
        # Fix or improve code
        ("Suggest improvements or best practices for this Swift code:\n\n",
         lambda: _generic_response(code, category)),
        # Understand code structure
        (f"This is a Swift {category.lower()} file. Explain its structure and purpose:\n\n",
         lambda: _generic_response(code, category)),
        # Code generation tasks
        ("Write a Swift function that accomplishes the same task as this code but more efficiently:\n\n",
         lambda: _extension_response(code, category)),
        # Language understanding
        ("Explain the Swift syntax and language features demonstrated in this code:\n\n",
         lambda: _generic_response(code, category)),
        # Learning from examples
        ("Study this Swift code example and explain what you can learn from it:\n\n",
         lambda: _generic_response(code, category)),
    ]

    # Select a random prompt type
    instruction, build_response = random.choice(prompt_types)

    code_section = f"```swift\n{code}\n```\n\n"

    # Create the full prompt
    prompt = instruction + code_section
    response = build_response()

    # Combine prompt and response for instruction tuning. Phi-3's chat format
    # closes every turn with <|end|>; without it the model is never shown
    # where an answer stops.
    full_text = f"<|user|>\n{prompt.rstrip()}<|end|>\n<|assistant|>\n{response}<|end|>\n"

    return {
        "text": full_text,
        "prompt": prompt,
        "response": response,
        "label": label,
        "category": category
    }

# Turn every split into its instruction-tuning counterpart
try:
    print("\nCreating instruction-based prompts for training...")
    # The generator is consumed in order: train first, then validation, then test.
    train_instructions, val_instructions, test_instructions = (
        split.map(create_instruction_prompt)
        for split in (train_data, val_data, test_data)
    )

    # Show one generated sample, framed by rulers, to verify the format
    print("\nExample instruction prompt:")
    print("-" * 80)
    print(train_instructions[0]['text'])
    print("-" * 80)

    # Report how many samples each split produced
    print("\n".join([
        f"Created {len(train_instructions)} training instructions",
        f"Created {len(val_instructions)} validation instructions",
        f"Created {len(test_instructions)} test instructions",
    ]))
except Exception as e:
    # Report in the cell output and stop the notebook here
    print(f"Error creating instruction prompts: {e}")
    traceback.print_exc()
    raise

In [None]:
# Tokenize the datasets
print("\nTokenizing datasets...")


def tokenize_function(examples):
    """Tokenize the ``text`` field of a batch of instruction samples.

    Every sample is truncated or padded to exactly ``MAX_LENGTH`` tokens, and
    the tokenizer output is returned as plain Python lists.
    """
    return tokenizer(
        examples["text"],
        max_length=MAX_LENGTH,
        truncation=True,
        padding="max_length",
        return_tensors=None,  # Python lists rather than PyTorch tensors
    )

try:
    # Apply tokenization to the datasets.
    # Every pre-existing column is dropped by asking each split for its own
    # column names. A hard-coded list fails when a listed column (e.g.
    # "label") is absent and leaves any unlisted column in the tokenized
    # set next to the tokenizer outputs.
    print("Tokenizing training set...")
    tokenized_train = train_instructions.map(
        tokenize_function,
        batched=True,
        remove_columns=train_instructions.column_names
    )

    print("Tokenizing validation set...")
    tokenized_val = val_instructions.map(
        tokenize_function,
        batched=True,
        remove_columns=val_instructions.column_names
    )

    print("Tokenizing test set...")
    tokenized_test = test_instructions.map(
        tokenize_function,
        batched=True,
        remove_columns=test_instructions.column_names
    )

    # Print tokenized dataset information
    print(f"\nTokenized training set: {len(tokenized_train)} examples")
    print(f"Tokenized validation set: {len(tokenized_val)} examples")
    print(f"Tokenized test set: {len(tokenized_test)} examples")

    # Print a sample of tokenized data
    print("\nSample tokenized data:")
    sample_idx = 0
    # Fetch the row once; each indexing call materialises MAX_LENGTH-long lists.
    sample_row = tokenized_train[sample_idx]
    print(f"Input IDs (first 10): {sample_row['input_ids'][:10]}")
    print(f"Attention Mask (first 10): {sample_row['attention_mask'][:10]}")
    # The attention mask is 1 for real tokens and 0 for padding, so its sum
    # is the unpadded length.
    print(f"Total tokens: {sum(sample_row['attention_mask'])}")

    # Clean up memory
    cleanup_memory()

except Exception as e:
    # Report in the cell output and stop the notebook here
    print(f"Error tokenizing datasets: {e}")
    traceback.print_exc()
    raise

In [None]:
try:
    # Load the base model and attach LoRA adapters, configured per device type.
    #
    # Phi-3 fuses the attention query/key/value projections into `qkv_proj`
    # and the MLP gate/up projections into `gate_up_proj`. The LLaMA-style
    # names q_proj/k_proj/v_proj/gate_proj/up_proj do not exist in this
    # model, so listing them would leave LoRA attached only to the modules
    # that happen to match (o_proj, down_proj) without any error.
    phi3_attention_modules = ["qkv_proj", "o_proj"]
    phi3_mlp_modules = ["gate_up_proj", "down_proj"]

    # Configure model loading based on device type
    print("\n" + "=" * 50)
    print(f"HARDWARE CONFIGURATION: {device_type.upper()}")
    print("=" * 50)

    if device_type == "tpu":
        # TPU-specific configuration
        print("\n[TPU SETUP] Configuring model for TPU training...")
        print(f"[TPU SETUP] Using device: {device}")

        # Import TPU-specific modules
        import torch_xla.core.xla_model as xm
        import torch_xla.distributed.parallel_loader as pl

        # Get TPU core count and worker information
        tpu_cores = xm.xrt_world_size()
        tpu_worker = xm.get_ordinal()
        print(f"[TPU SETUP] TPU cores available: {tpu_cores}")
        print(f"[TPU SETUP] Current TPU worker: {tpu_worker}")

        # TPUs work better with bf16 precision
        print("[TPU SETUP] Loading model with bfloat16 precision...")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.bfloat16,  # Use bfloat16 for TPU
            trust_remote_code=True
        )
        print(f"[TPU SETUP] Model loaded: {MODEL_NAME}")

        # Configure LoRA for TPU (attention and MLP projections)
        print("[TPU SETUP] Configuring LoRA parameters...")
        lora_config = LoraConfig(
            r=LORA_R,
            lora_alpha=LORA_ALPHA,
            lora_dropout=LORA_DROPOUT,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=phi3_attention_modules + phi3_mlp_modules
        )
        print(f"[TPU SETUP] LoRA config: rank={LORA_R}, alpha={LORA_ALPHA}, dropout={LORA_DROPOUT}")

        # Apply LoRA to the model
        print("[TPU SETUP] Applying LoRA to model...")
        model = get_peft_model(model, lora_config)

        # Move model to TPU device
        print("[TPU SETUP] Moving model to TPU device...")
        model = model.to(device)

        print("[TPU SETUP] TPU configuration complete")

    elif device_type == "gpu":
        # GPU-specific configuration with quantization
        print("\n[GPU SETUP] Configuring model for GPU training with quantization...")

        # Get GPU information
        gpu_count = torch.cuda.device_count()
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"[GPU SETUP] Using GPU: {gpu_name}")
        print(f"[GPU SETUP] GPU count: {gpu_count}")
        print(f"[GPU SETUP] GPU memory: {gpu_memory:.2f} GB")

        # Configure quantization for efficient GPU training
        print("[GPU SETUP] Configuring 4-bit quantization...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True
        )
        print("[GPU SETUP] Quantization type: NF4 (4-bit)")

        # Load the Phi-3 model for causal language modeling.
        # device_map="auto" places the weights, so no explicit .to(device)
        # follows in this branch.
        print(f"[GPU SETUP] Loading model: {MODEL_NAME}")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True
        )
        print("[GPU SETUP] Model loaded successfully")

        # Prepare the model for training
        print("[GPU SETUP] Preparing model for k-bit training...")
        model = prepare_model_for_kbit_training(model)

        # Configure LoRA (attention and MLP projections)
        print("[GPU SETUP] Configuring LoRA parameters...")
        lora_config = LoraConfig(
            r=LORA_R,
            lora_alpha=LORA_ALPHA,
            lora_dropout=LORA_DROPOUT,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=phi3_attention_modules + phi3_mlp_modules
        )
        print(f"[GPU SETUP] LoRA config: rank={LORA_R}, alpha={LORA_ALPHA}, dropout={LORA_DROPOUT}")

        # Apply LoRA to the model
        print("[GPU SETUP] Applying LoRA to model...")
        model = get_peft_model(model, lora_config)

        print("[GPU SETUP] GPU configuration complete")

    else:
        # CPU fallback - use smaller model or minimal settings
        print("\n[CPU SETUP] Configuring model for CPU training (minimal configuration)...")
        print("[CPU SETUP] WARNING: Training on CPU will be very slow!")

        # Load the model with minimal settings for CPU
        print(f"[CPU SETUP] Loading model: {MODEL_NAME}")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,  # Use float32 for CPU
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )
        print("[CPU SETUP] Model loaded with float32 precision")

        # Configure LoRA with smaller rank for CPU; only the attention
        # projections are adapted to keep the trainable parameter count low.
        print("[CPU SETUP] Configuring LoRA with reduced parameters for CPU...")
        lora_config = LoraConfig(
            r=8,  # Smaller rank for CPU
            lora_alpha=16,
            lora_dropout=LORA_DROPOUT,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=phi3_attention_modules
        )
        print("[CPU SETUP] Using reduced LoRA rank (r=8) for CPU efficiency")

        # Apply LoRA to the model
        print("[CPU SETUP] Applying LoRA to model...")
        model = get_peft_model(model, lora_config)

        # Move model to CPU
        print("[CPU SETUP] Moving model to CPU...")
        model = model.to(device)

        print("[CPU SETUP] CPU configuration complete")

    # Print trainable parameters
    print("\n[MODEL INFO] Trainable parameters:")
    model.print_trainable_parameters()

    # Print model architecture summary
    print("\n[MODEL INFO] Model architecture:")
    print(f"  - Base model: {MODEL_NAME}")
    print(f"  - Model type: {model.__class__.__name__}")
    print(f"  - Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"  - Training device: {device_type.upper()}")

    # Check memory usage
    process = psutil.Process(os.getpid())
    memory_info = process.memory_info()
    print(f"\n[SYSTEM INFO] Current memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")

    # Print success message
    print("\n" + "=" * 50)
    print(f"MODEL SUCCESSFULLY LOADED ON {device_type.upper()}")
    print("=" * 50 + "\n")

except Exception as e:
    # Print as much context as possible, then re-raise so the notebook stops.
    print(f"\n[ERROR] Failed to load model: {e}")
    print("\n[DEBUG INFO] Detailed error information:")
    import traceback
    traceback.print_exc()
    print(f"\n[DEBUG INFO] Device type: {device_type}")
    print(f"[DEBUG INFO] Device: {device}")
    print(f"[DEBUG INFO] Model name: {MODEL_NAME}")
    raise

In [None]:
# Define hardware-specific training arguments.
# This section chooses the numeric precision flags, the dataloader worker
# count and the TPU metrics flag for the detected hardware. The resulting
# names (use_fp16, use_bf16, tpu_metrics_debug, dataloader_workers) and the
# possibly adjusted BATCH_SIZE / GRADIENT_ACCUMULATION_STEPS are consumed when
# the TrainingArguments object is built.
print(f"\n[TRAINING SETUP] Configuring training arguments for {device_type.upper()}...")

# Configure precision based on hardware
if device_type == "tpu":
    print("[TRAINING SETUP] Using TPU-specific training configuration")
    # TPU-specific settings
    use_fp16 = False  # TPUs work better with bf16 than fp16
    use_bf16 = True   # Use bfloat16 for TPU
    tpu_metrics_debug = True  # Enable TPU metrics debugging
    dataloader_workers = 8  # TPUs can handle more workers

    # TPU-specific environment variables
    # NOTE(review): presumably this must be set before torch_xla initialises
    # to have any effect - confirm against the import order in this notebook.
    os.environ["XLA_USE_BF16"] = "1"
    print("[TRAINING SETUP] Using bfloat16 precision for TPU training")

    # Report the effective batch size (per-device batch x accumulation steps)
    effective_batch_size = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
    print(f"[TRAINING SETUP] Effective batch size: {effective_batch_size}")

elif device_type == "gpu":
    print("[TRAINING SETUP] Using GPU-specific training configuration")
    # GPU-specific settings
    use_fp16 = True   # Use fp16 for GPU
    use_bf16 = False  # bf16 not optimal for most GPUs
    tpu_metrics_debug = False
    dataloader_workers = 2  # Fewer workers for GPU to avoid memory issues

    # Check if we have Ampere or newer GPU that supports TF32
    # (compute capability major version >= 8)
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        # Enable TF32 for faster training on Ampere+ GPUs
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print("[TRAINING SETUP] TF32 enabled for Ampere or newer GPU")

    print("[TRAINING SETUP] Using fp16 precision for GPU training")

else:  # CPU
    print("[TRAINING SETUP] Using CPU-specific training configuration")
    # CPU-specific settings
    use_fp16 = False  # No fp16 for CPU
    use_bf16 = False  # No bf16 for CPU
    tpu_metrics_debug = False
    dataloader_workers = 0  # No multiprocessing for CPU to avoid overhead

    # Reduce batch size for CPU
    if BATCH_SIZE > 1:
        print(f"[TRAINING SETUP] WARNING: Reducing batch size from {BATCH_SIZE} to 1 for CPU training")
        # Scale accumulation by exactly the factor the batch shrinks by, so
        # BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS (the effective batch size)
        # is unchanged. A fixed multiplier would only compensate correctly
        # for one particular original batch size.
        GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS * BATCH_SIZE
        BATCH_SIZE = 1

    print("[TRAINING SETUP] Using float32 precision for CPU training")

# Build the Trainer configuration. Keyword arguments are grouped by concern;
# the precision flags, worker count and TPU debug flag come from the
# hardware-specific branch earlier in this cell.
training_args = TrainingArguments(
    # --- Run identity, reporting and reproducibility ---
    output_dir="./phi3_swift_results",
    logging_dir="./logs",
    seed=SEED,  # Ensure reproducibility
    report_to="none",
    push_to_hub=False,
    disable_tqdm=False,
    # --- Schedule and optimizer ---
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    warmup_ratio=WARMUP_RATIO,
    weight_decay=WEIGHT_DECAY,
    optim="adamw_torch",  # PyTorch's AdamW implementation
    # --- Batching ---
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    group_by_length=True,  # Batch sequences of similar length together
    remove_unused_columns=False,
    # --- Precision and memory ---
    fp16=use_fp16,
    bf16=use_bf16,
    gradient_checkpointing=True,  # Trade compute for lower memory use
    torch_compile=False,  # Left off: can be unstable
    # --- Evaluation, checkpointing and logging cadence ---
    evaluation_strategy="steps",
    eval_steps=0.1,  # Fractional value: evaluate every 10% of training
    save_strategy="steps",
    save_steps=0.1,  # Fractional value: save every 10% of training
    save_total_limit=3,
    load_best_model_at_end=True,
    logging_steps=50,
    # --- Data loading and distributed behaviour ---
    dataloader_num_workers=dataloader_workers,
    dataloader_pin_memory=True,
    ddp_find_unused_parameters=False,
    tpu_metrics_debug=tpu_metrics_debug,  # True only on the TPU path
)

# Echo the effective configuration, one line per item.
print(
    f"[TRAINING SETUP] Training arguments configured for {device_type.upper()}",
    f"[TRAINING SETUP] Batch size: {BATCH_SIZE}, Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}",
    f"[TRAINING SETUP] Learning rate: {LEARNING_RATE}, Epochs: {NUM_EPOCHS}",
    f"[TRAINING SETUP] FP16: {use_fp16}, BF16: {use_bf16}",
    sep="\n",
)

# Wire up callbacks, the batch collator and the Trainer itself.
print("[TRAINING SETUP] Setting up training callbacks...")

# Early stopping: patience of 3 evaluations, improvement threshold of 0.01.
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01)

# Resource usage logging, configured with log_every=50.
resource_monitor = ResourceMonitorCallback(log_every=50)

# Collator for causal language modeling (mlm=False disables masked-LM mode).
print("[TRAINING SETUP] Creating data collator for language modeling...")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# The Trainer ties model, arguments, datasets, collator and callbacks together.
print("[TRAINING SETUP] Creating trainer...")
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    callbacks=[early_stopping_callback, resource_monitor],
)

# TPU only: import the XLA helpers and synchronise at a named rendezvous point.
if device_type == "tpu":
    print("[TRAINING SETUP] Configuring TPU-specific training settings...")
    from torch_xla.core import xla_model as xm
    from torch_xla.distributed import parallel_loader as pl

    xm.rendezvous("trainer_setup")  # Barrier named "trainer_setup"
    print("[TRAINING SETUP] TPU cores synchronized")

print("\n[TRAINING SETUP] Training setup complete and ready to start", "=" * 50, sep="\n")

In [None]:
# Function to monitor system resources during training
def monitor_resources():
    """Print a snapshot of CPU, memory and accelerator usage to stdout.

    Always reports CPU utilisation, this process's resident memory and
    system-wide memory via ``psutil``. When the module-level ``device_type``
    is ``"gpu"`` it additionally prints allocated/reserved CUDA memory for
    every visible device; when it is ``"tpu"`` it prints the torch_xla
    metrics report, or a fallback message if that report cannot be produced.

    Returns:
        None.
    """
    process = psutil.Process(os.getpid())
    memory_info = process.memory_info()
    mem = psutil.virtual_memory()
    # Called with a 0.1 s sampling interval.
    cpu_percent = psutil.cpu_percent(interval=0.1)

    bytes_per_mb = 1024 * 1024

    print("\n[SYSTEM RESOURCES]")
    print(f"CPU Usage: {cpu_percent}%")
    print(f"Process Memory: {memory_info.rss / 1024 / 1024:.2f} MB")
    print(f"System Memory: {mem.percent}% used, {mem.available / 1024 / 1024:.2f} MB available")

    # GPU-specific monitoring
    if device_type == "gpu":
        for i in range(torch.cuda.device_count()):
            gpu_mem_alloc = torch.cuda.memory_allocated(i) / bytes_per_mb
            gpu_mem_reserved = torch.cuda.memory_reserved(i) / bytes_per_mb
            print(f"GPU {i}: Allocated: {gpu_mem_alloc:.2f} MB | Reserved: {gpu_mem_reserved:.2f} MB")

    # TPU-specific monitoring
    if device_type == "tpu":
        try:
            import torch_xla.debug.metrics as met
            print("TPU Memory Stats:")
            print(met.metrics_report())
        except Exception:
            # Best-effort only: a missing or failing torch_xla must not break
            # monitoring. Catching Exception (not a bare except) keeps
            # KeyboardInterrupt/SystemExit propagating.
            print("TPU metrics not available")

# Run training with more detailed monitoring.
# On success: reports timing and metrics, then saves the model and tokenizer
# to ./phi3_swift_model. On failure: prints diagnostics, makes a best-effort
# attempt to save a recovery checkpoint, and re-raises the original error.
try:
    print("\n" + "=" * 50)
    print("STARTING TRAINING")
    print("=" * 50)

    # Monitor resources before training
    print("\n[PRE-TRAINING] Resources before training:")
    monitor_resources()

    # Record the wall-clock start time (no timeout is enforced on training)
    start_time = time.time()

    # Run training
    print("\n[TRAINING] Starting training process...")
    train_result = trainer.train()

    # Calculate training time and split it into h/m/s for display
    training_time = time.time() - start_time
    hours, remainder = divmod(training_time, 3600)
    minutes, seconds = divmod(remainder, 60)

    # Monitor resources after training
    print("\n[POST-TRAINING] Resources after training:")
    monitor_resources()

    # Print training results
    print("\n" + "=" * 50)
    print("TRAINING COMPLETED")
    print("=" * 50)
    print(f"Training completed in {int(hours)}h {int(minutes)}m {seconds:.2f}s")
    print(f"Training loss: {train_result.metrics['train_loss']:.4f}")

    # Print all metrics
    print("\n[TRAINING METRICS]")
    for key, value in train_result.metrics.items():
        print(f"{key}: {value}")

    # Save the model
    print("\n[SAVING MODEL] Saving model to ./phi3_swift_model")
    trainer.save_model("./phi3_swift_model")
    print("Model saved successfully")

    # Save tokenizer alongside the model
    print("[SAVING MODEL] Saving tokenizer")
    tokenizer.save_pretrained("./phi3_swift_model")
    print("Tokenizer saved successfully")

    # Clean up memory
    print("\n[CLEANUP] Cleaning up memory...")
    cleanup_memory()

except Exception as e:
    print(f"\n[ERROR] Error during training: {e}")

    # Print stack trace for debugging
    print("\n[DEBUG INFO] Detailed error information:")
    traceback.print_exc()

    # Monitor resources after error
    print("\n[DEBUG INFO] Resources after error:")
    monitor_resources()

    # Try to save checkpoint if possible. This is best-effort: a failure here
    # is reported but must not mask the original training error re-raised
    # below. Only Exception is caught so an interrupt still stops the cell.
    try:
        print("\n[RECOVERY] Attempting to save checkpoint after error...")
        trainer.save_model("./phi3_swift_model_error_recovery")
        print("Recovery checkpoint saved to ./phi3_swift_model_error_recovery")
    except Exception as save_error:
        print(f"Could not save recovery checkpoint: {save_error}")

    raise

In [None]:
# Test the model with Swift code examples.
# Errors in this cell are reported but deliberately not re-raised, so a
# failed smoke test does not abort the rest of the notebook.
try:
    print("\n" + "=" * 50)
    print("TESTING THE MODEL")
    print("=" * 50)

    # Function to generate responses for test examples
    def generate_response(prompt):
        """Sample a completion for ``prompt`` from the fine-tuned model.

        Args:
            prompt: Chat-formatted prompt text ending with the
                ``<|assistant|>`` tag.

        Returns:
            The decoded text of the newly generated tokens only (the prompt
            is not echoed back), with surrounding whitespace stripped.
        """
        print("[TESTING] Generating response for prompt...")

        # Switch to inference mode in case training left the model in train
        # mode, so dropout (including LoRA dropout) does not perturb sampling.
        model.eval()

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_length = inputs.input_ids.shape[1]

        # The same call works on every backend; no hardware-specific branch
        # is needed for generation itself.
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=200,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # Keep only the tokens produced after the prompt and move them to the
        # host. A plain .cpu() transfer is sufficient on TPU as well;
        # a cross-core reduction would return a list of per-core copies
        # rather than a tensor of token ids.
        generated_ids = outputs[0][prompt_length:].cpu()

        response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        # Defensive: if the tag survives decoding as plain text, keep only
        # what follows its last occurrence.
        if "<|assistant|>" in response:
            response = response.split("<|assistant|>")[-1].strip()
        return response

    # Test prompts for different Swift language tasks.
    # Inner double quotes and Swift's "\(" interpolation marker are escaped
    # so each prompt is a valid Python string literal.
    test_prompts = [
        # Explain Swift syntax
        (
            "<|user|>\n"
            "Explain the key features of Swift's optional unwrapping syntax:\n\n"
            "```swift\n"
            "func processName(_ name: String?) {\n"
            "    guard let unwrappedName = name else {\n"
            "        print(\"No name provided\")\n"
            "        return\n"
            "    }\n"
            "    print(\"Hello, \\(unwrappedName)!\")\n"
            "}\n"
            "```\n"
            "<|assistant|>"
        ),

        # Code completion
        (
            "<|user|>\n"
            "Complete this Swift function that calculates the factorial of a number:\n\n"
            "```swift\n"
            "func factorial(_ n: Int) -> Int {\n"
            "    // Add implementation here\n"
            "}\n"
            "```\n"
            "<|assistant|>"
        ),

        # Debugging help
        (
            "<|user|>\n"
            "What's wrong with this Swift code and how can I fix it?\n\n"
            "```swift\n"
            "class Person {\n"
            "    var name: String\n"
            "    var age: Int\n"
            "    \n"
            "    func greet() {\n"
            "        print(\"Hello, my name is \\(name) and I am \\(age) years old.\")\n"
            "    }\n"
            "}\n\n"
            "let person = Person()\n"
            "person.greet()\n"
            "```\n"
            "<|assistant|>"
        ),

        # Swift best practices
        "<|user|>\nExplain Swift best practices for error handling:\n<|assistant|>"
    ]

    # Generate and print responses
    for i, prompt in enumerate(test_prompts):
        print(f"\n[TEST {i+1}]\n{'-'*40}")
        print(f"Prompt: {prompt.split('<|assistant|>')[0].replace('<|user|>', '')}")
        response = generate_response(prompt)
        print(f"\nResponse:\n{response}\n")

        # Clean up between tests
        if device_type == "tpu":
            import torch_xla.core.xla_model as xm
            xm.mark_step()

    print("\n[TESTING] Testing complete")
except Exception as e:
    print(f"\n[ERROR] Error during testing: {e}")
    traceback.print_exc()

In [None]:
# Final cleanup and summary: echo the key settings used for this run.
print("\n" + "=" * 50)
print("TRAINING SUMMARY")
print("=" * 50)

# Report the LoRA rank that was actually applied. The CPU setup path builds
# its LoRA config with its own reduced rank rather than LORA_R, so prefer the
# rank stored on lora_config and fall back to LORA_R only if that name was
# never created in this session.
applied_lora_rank = lora_config.r if "lora_config" in globals() else LORA_R

print(f"Model: {MODEL_NAME}")
print(f"Training device: {device_type.upper()}")
print("Dataset: mvasiliniuc/iva-swift-codeint")
print(f"Training examples: {len(tokenized_train)}")
print(f"Validation examples: {len(tokenized_val)}")
print(f"Test examples: {len(tokenized_test)}")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"LoRA rank: {applied_lora_rank}")

print("\nModel saved to: ./phi3_swift_model")
print("\nTraining complete!")