In [None]:
%%capture
# Install Unsloth and dependencies
import os
if "COLAB_" not in "".join(os.environ.keys()):
    %pip install unsloth
else:
    # Colab environment
    %pip install --no-deps --upgrade timm  # For Gemma 3N
    %pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton
    %pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    %pip install --no-deps unsloth
    %pip install comet-ml


In [None]:
# Unsloth must be imported before transformers / trl / peft so that its
# patches are applied to those libraries.
from unsloth import FastVisionModel, get_chat_template

import os
import re
import io
from typing import Tuple, List, Dict, Any, Optional
from PIL import Image
import requests
import torch
from datasets import load_dataset, Dataset
from transformers import TrainingArguments
from trl import SFTTrainer, SFTConfig
import comet_ml

# (Unsloth prints its own "Will patch your computer..." banner on import, so it
# is not re-printed here.)
print("📦 All dependencies loaded successfully!")


In [None]:
# Configuration - Optimized for Unsloth and Vietnamese Math Tutoring
CONFIG = {
    # Model settings
    "model_name": "unsloth/gemma-3n-E4B",
    "max_seq_length": 2048,  # also the data collator's truncation length
    "load_in_4bit": True,

    # Dataset settings
    "dataset_name": "ngohongthai/exam-sixth_grade-instruct-dataset",
    "train_split": "train",

    # Training settings - Conservative for stability
    "output_dir": "mathpal-gemma3n/optimized_baseline",
    "max_steps": 100,  # Reduced for testing
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 8,
    "learning_rate": 2e-4,
    "warmup_ratio": 0.03,
    "weight_decay": 0.01,
    "logging_steps": 5,
    "save_steps": 25,

    # LoRA settings - Following Unsloth recommendations
    "lora_r": 16,  # Reduced from 32 for stability
    "lora_alpha": 16,
    "lora_dropout": 0.0,  # 0 is optimized for Unsloth

    # System settings
    "use_gradient_checkpointing": "unsloth",  # Use Unsloth's optimized version
    # "none" disables every experiment tracker. Do not use None here: in
    # transformers 4.x, report_to=None means "all installed integrations",
    # which would switch on comet_ml (installed and imported above).
    "report_to": "none",
    "seed": 42,

    # Image processing
    "max_images_per_sample": 3,  # Limit for memory efficiency
    "image_timeout": 5,  # Faster timeout for unreachable images
}

print(f"🔧 Configuration loaded:")
print(f"   Model: {CONFIG['model_name']}")
print(f"   Dataset: {CONFIG['dataset_name']}")
print(f"   Max steps: {CONFIG['max_steps']}")
print(f"   Effective batch size: {CONFIG['per_device_train_batch_size'] * CONFIG['gradient_accumulation_steps']}")


In [None]:
# Simplified Data Processing Functions
def download_image_safe(url: str, timeout: int = 5) -> Optional[Image.Image]:
    """Download an image and return it as an RGB PIL image, or ``None``.

    Best-effort by design: any failure (non-HTTP URL, network/HTTP error,
    non-image content type, undecodable data, image smaller than 10px on a
    side) yields ``None`` instead of raising, so one broken link cannot abort
    dataset preparation.

    Args:
        url: Absolute ``http://`` or ``https://`` image URL.
        timeout: Request timeout in seconds, passed to ``requests.get``.

    Returns:
        The decoded image converted to RGB, or ``None``.
    """
    try:
        if not url or not url.startswith(('http://', 'https://')):
            return None

        # The context manager closes the response and releases its pooled
        # connection on every path, including the early content-type return
        # (with stream=True the body is not consumed until .content is read).
        with requests.get(url, timeout=timeout, stream=True) as response:
            response.raise_for_status()

            # MIME types are case-insensitive.
            content_type = response.headers.get('content-type', '').lower()
            if not content_type.startswith('image/'):
                return None

            payload = response.content

        image = Image.open(io.BytesIO(payload)).convert("RGB")

        # Reject degenerate images (icons, tracking pixels, ...).
        width, height = image.size
        if width < 10 or height < 10:
            return None

        return image

    except Exception:
        return None

def extract_images_from_markdown(
    text: str, max_images: int = 3, timeout: Optional[int] = None
) -> Tuple[str, List[Image.Image]]:
    """Split markdown into cleaned text and the images it links to.

    Every ``![alt](target)`` occurrence is replaced by ``" [IMAGE] "`` in the
    returned text. Only the first ``max_images`` targets are downloaded, and
    targets that fail to download are skipped, so the number of returned
    images can be smaller than the number of ``[IMAGE]`` markers.

    Args:
        text: Markdown source; ``None`` or empty yields ``("", [])``.
        max_images: Maximum number of image links to download.
        timeout: Per-image download timeout in seconds; defaults to
            ``CONFIG["image_timeout"]`` when ``None``.

    Returns:
        ``(cleaned_text, images)``. Line breaks are preserved: only runs of
        other whitespace are collapsed, and 3+ consecutive newlines shrink to
        one blank line.
    """
    if not text:
        return "", []
    if timeout is None:
        timeout = CONFIG["image_timeout"]

    image_pattern = r"!\[.*?\]\((.*?)\)"
    # Limit number of images for memory efficiency
    image_targets = re.findall(image_pattern, text)[:max_images]

    cleaned_text = re.sub(image_pattern, " [IMAGE] ", text)
    # Collapse whitespace but keep line breaks: step-by-step solutions and
    # markdown tables depend on them.
    cleaned_text = re.sub(r"[^\S\n]+", " ", cleaned_text)
    cleaned_text = re.sub(r" ?\n ?", "\n", cleaned_text)
    cleaned_text = re.sub(r"\n{3,}", "\n\n", cleaned_text).strip()

    images = []
    for target in image_targets:
        # A markdown image target may carry a title: ![alt](url "title");
        # only the first token is the URL (optionally wrapped in <...>).
        parts = target.strip().split()
        if not parts:
            continue
        url = parts[0].strip("<>")
        image = download_image_safe(url, timeout=timeout)
        if image is not None:
            images.append(image)

    return cleaned_text, images

def process_sample_to_conversation(sample: Dict[str, str]) -> Dict[str, Any]:
    """Turn one raw record into a two-turn chat conversation.

    The ``question`` field becomes the user turn and ``solution`` the
    assistant turn. Each turn's content is a text item followed by one image
    item per image downloaded from that field's markdown (at most
    ``CONFIG["max_images_per_sample"]``).

    Returns:
        ``{"messages": [user_message, assistant_message]}``.
    """
    image_limit = CONFIG["max_images_per_sample"]

    def _turn_content(markdown: str) -> List[Dict[str, Any]]:
        # Text first, then any images linked from the same markdown field.
        body, pictures = extract_images_from_markdown(markdown, max_images=image_limit)
        return [{"type": "text", "text": body}] + [
            {"type": "image", "image": picture} for picture in pictures
        ]

    user_turn = {"role": "user", "content": _turn_content(sample["question"])}
    assistant_turn = {"role": "assistant", "content": _turn_content(sample["solution"])}
    return {"messages": [user_turn, assistant_turn]}

print("✅ Data processing functions loaded")


In [None]:
# Simplified Data Collator for Vision-Language Models
class OptimizedVisionDataCollator:
    """Collate chat-formatted samples into tensors via a vision-language processor.

    Each example is expected to hold a ``messages`` list of
    ``{"role": ..., "content": [{"type": "text" | "image", ...}, ...]}`` dicts
    (the shape built by ``process_sample_to_conversation``). Samples without
    any usable image get a small blank placeholder image, so every sample in
    a batch goes through the same image pipeline.
    """

    def __init__(self, processor, max_length: int = 2048):
        """Store the processor and the truncation limit.

        Args:
            processor: Multimodal processor providing ``apply_chat_template``
                and ``__call__(text=..., images=..., ...)``.
            max_length: Maximum tokenized length per sample (longer samples are
                truncated by the processor).
        """
        self.processor = processor
        self.max_length = max_length
        self.placeholder_image = None  # created lazily, then reused

    def _get_placeholder_image(self):
        """Return a cached 32x32 light-grey RGB image used for text-only samples."""
        if self.placeholder_image is None:
            self.placeholder_image = Image.new('RGB', (32, 32), color=(240, 240, 240))
        return self.placeholder_image

    def _image_token(self) -> str:
        """Return the text marker that the processor expands into image tokens.

        NOTE(review): assumes Gemma-style processors expose this marker as
        ``boi_token`` (e.g. ``<start_of_image>``) and expand it inside
        ``__call__`` - confirm against the installed transformers version.
        ``"<image>"`` is kept as the fallback for processors without it.
        """
        return getattr(self.processor, "boi_token", None) or "<image>"

    @staticmethod
    def _to_rgb_image(img):
        """Return ``img`` as an RGB PIL image, or ``None`` if it is unusable.

        Accepts live PIL images and ``{"bytes": ..., "path": ...}`` records.
        NOTE(review): ``datasets.Dataset.from_list`` presumably stores PIL
        images nested inside message dicts in that record form, which the
        previous ``hasattr(img, 'convert')`` check silently dropped.
        """
        if img is None:
            return None
        if hasattr(img, 'convert'):
            return img.convert('RGB')
        if isinstance(img, dict):
            if img.get("bytes"):
                return Image.open(io.BytesIO(img["bytes"])).convert('RGB')
            if img.get("path"):
                return Image.open(img["path"]).convert('RGB')
        return None

    @staticmethod
    def _with_placeholder(messages: List[Dict], placeholder) -> List[Dict]:
        """Return a copy of ``messages`` with ``placeholder`` appended to the first user turn.

        Putting the image inside the conversation lets the chat template emit
        the image marker within the user turn. Inputs are not mutated.
        """
        patched = []
        inserted = False
        for message in messages:
            if not inserted and message.get("role") == "user":
                content = list(message.get("content") or [])
                content.append({"type": "image", "image": placeholder})
                message = {**message, "content": content}
                inserted = True
            patched.append(message)
        return patched

    def _format_messages(self, messages: List[Dict]) -> str:
        """Render ``messages`` with the chat template, falling back to ``role: text`` lines."""
        try:
            return self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
        except Exception:
            # Fallback to simple formatting (text items only, one per line).
            lines = []
            for msg in messages:
                role = msg.get("role", "")
                for item in msg.get("content") or []:
                    if item.get("type") == "text":
                        lines.append(f"{role}: {item.get('text') or ''}\n")
            return "".join(lines)

    def _extract_content(self, messages: List[Dict]) -> Tuple[str, List["Image.Image"]]:
        """Return the formatted prompt text and the RGB images of one conversation.

        If the conversation has no usable image, the placeholder image is used
        and an image marker is guaranteed to be present in the text.
        """
        images = []
        for message in messages:
            for content_item in message.get("content") or []:
                if content_item.get("type") == "image":
                    img = self._to_rgb_image(content_item.get("image"))
                    if img is not None:
                        images.append(img)

        used_placeholder = not images
        if used_placeholder:
            images = [self._get_placeholder_image()]
            messages = self._with_placeholder(messages, images[0])

        formatted_text = self._format_messages(messages)

        # The fallback formatter (or a conversation without a user turn) does
        # not emit an image marker; add one so the placeholder gets tokens.
        if used_placeholder:
            image_token = self._image_token()
            if image_token not in formatted_text:
                formatted_text = image_token + "\n" + formatted_text

        return formatted_text, images

    def _ignored_label_ids(self) -> List[int]:
        """Token ids excluded from the loss: padding and (if exposed) the image token."""
        tokenizer = getattr(self.processor, "tokenizer", None)
        candidates = (
            getattr(tokenizer, "pad_token_id", None),
            # NOTE(review): assumes the processor exposes image_token_id as
            # Gemma-style processors do; skipped when absent.
            getattr(self.processor, "image_token_id", None),
        )
        return [token_id for token_id in candidates if token_id is not None]

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate ``examples`` into a padded, truncated batch with ``labels``.

        ``labels`` copies ``input_ids`` with padding / image-token positions
        set to -100 so they do not contribute to the loss.

        Raises:
            Any exception from the processor is logged and re-raised; a dummy
            batch would otherwise be trained on silently.
        """
        batch_texts = []
        batch_images = []
        for example in examples:
            text, images = self._extract_content(example.get("messages", []))
            batch_texts.append(text)
            batch_images.append(images)

        try:
            # Process with the vision processor
            batch = self.processor(
                text=batch_texts,
                images=batch_images,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length
            )
        except Exception as e:
            print(f"❌ Error in data collator: {e}")
            raise

        # Create labels for training
        if "input_ids" in batch:
            labels = batch["input_ids"].clone()
            for token_id in self._ignored_label_ids():
                labels[labels == token_id] = -100
            batch["labels"] = labels

        return batch

print("✅ Optimized data collator loaded")


In [None]:
# Model Setup with FastVisionModel - Optimized for Gemma3N
def setup_gemma3n_model(config: Dict[str, Any]):
    """Load the Gemma 3n vision model, attach LoRA adapters, and set the chat template.

    Args:
        config: Settings dict; reads ``model_name``, ``max_seq_length``,
            ``load_in_4bit``, ``use_gradient_checkpointing``, ``lora_r``,
            ``lora_alpha``, ``lora_dropout`` and ``seed``.

    Returns:
        ``(model, processor)`` - the PEFT-wrapped model and the processor with
        the ``"gemma-3n"`` chat template applied.
    """
    print("🔧 Loading Gemma3N model with FastVisionModel...")

    checkpointing = config["use_gradient_checkpointing"]
    base_model, raw_processor = FastVisionModel.from_pretrained(
        config["model_name"],
        max_seq_length=config["max_seq_length"],
        load_in_4bit=config["load_in_4bit"],
        use_gradient_checkpointing=checkpointing,
    )

    print("🎯 Applying PEFT (LoRA) configuration...")

    lora_settings = dict(
        # Adapt both vision and language towers, attention and MLP alike.
        finetune_vision_layers=True,
        finetune_language_layers=True,
        finetune_attention_modules=True,
        finetune_mlp_modules=True,
        r=config["lora_r"],
        lora_alpha=config["lora_alpha"],
        lora_dropout=config["lora_dropout"],  # 0 is the Unsloth fast path
        bias="none",
        use_gradient_checkpointing=checkpointing,
        random_state=config["seed"],
        use_rslora=False,
        target_modules="all-linear",
        modules_to_save=["lm_head", "embed_tokens"],
    )
    peft_model = FastVisionModel.get_peft_model(base_model, **lora_settings)

    chat_processor = get_chat_template(raw_processor, "gemma-3n")

    print("✅ Model and processor setup complete!")
    for label, obj in (("Model", peft_model), ("Processor", chat_processor)):
        print(f"   {label} type: {type(obj).__name__}")

    return peft_model, chat_processor

# Prepare dataset function
def prepare_dataset_optimized(dataset_name: str, split: str) -> Dataset:
    """Load ``dataset_name``/``split`` and convert every row into chat format.

    Rows whose conversion raises are skipped and counted; the first five such
    errors are printed. Progress is reported every 100 rows.

    Args:
        dataset_name: Hugging Face Hub dataset id passed to ``load_dataset``.
        split: Split name to load.

    Returns:
        A ``Dataset`` with a single ``messages`` column.

    Raises:
        ValueError: if no row could be converted; an empty training set would
            otherwise fail later inside the trainer with a less clear error.
    """
    print(f"📥 Loading dataset: {dataset_name}, split: {split}")
    raw_dataset = load_dataset(dataset_name, split=split)
    total = len(raw_dataset)

    print(f"🔄 Processing {total} samples...")

    processed_data = []
    errors = 0

    for i, sample in enumerate(raw_dataset):
        try:
            processed_data.append(process_sample_to_conversation(sample))
        except Exception as e:
            errors += 1
            if errors <= 5:  # Log first 5 errors
                print(f"⚠️ Error processing sample {i}: {e}")

        # Progress update
        if (i + 1) % 100 == 0:
            print(f"   Processed {i + 1}/{total} samples (errors: {errors})")

    # Guard the percentage against an empty split (ZeroDivisionError).
    success_rate = (len(processed_data) / total) * 100 if total else 0.0
    print(f"✅ Successfully processed {len(processed_data)}/{total} samples ({success_rate:.1f}%)")

    if errors > 0:
        print(f"⚠️ {errors} samples failed to process")

    if not processed_data:
        raise ValueError(
            f"No usable samples in dataset {dataset_name!r}, split {split!r}"
        )

    return Dataset.from_list(processed_data)

print("✅ Model setup and dataset functions loaded")


In [None]:
# Training Setup - Optimized for Unsloth
def create_optimized_trainer(model, processor, train_dataset, config: Dict[str, Any]):
    """Build an ``SFTTrainer`` wired to the custom vision data collator.

    Puts ``model`` into Unsloth training mode, then configures 8-bit AdamW with
    a cosine schedule, gradient checkpointing, and bf16 when the GPU supports
    it (fp16 otherwise). Dataset preparation is skipped because the collator
    does all formatting and tokenization.

    Args:
        model: PEFT-wrapped model from ``setup_gemma3n_model``.
        processor: Matching multimodal processor.
        train_dataset: Dataset of ``{"messages": ...}`` rows.
        config: Settings dict (output dir, steps, batch sizes, optimizer
            hyper-parameters, logging/saving cadence, ``report_to``, ``seed``,
            ``max_seq_length``).

    Returns:
        The configured ``SFTTrainer``.
    """
    print("🔧 Creating optimized trainer...")

    FastVisionModel.for_training(model)

    collator = OptimizedVisionDataCollator(processor, max_length=config["max_seq_length"])

    bf16_available = torch.cuda.is_bf16_supported()
    sft_args = SFTConfig(
        output_dir=config["output_dir"],
        max_steps=config["max_steps"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        gradient_accumulation_steps=config["gradient_accumulation_steps"],
        learning_rate=config["learning_rate"],
        warmup_ratio=config["warmup_ratio"],
        weight_decay=config["weight_decay"],
        optim="adamw_8bit",
        lr_scheduler_type="cosine",
        gradient_checkpointing=True,
        dataloader_pin_memory=False,  # pinned memory can misbehave with image batches
        max_grad_norm=0.3,
        logging_steps=config["logging_steps"],
        save_strategy="steps",
        save_steps=config["save_steps"],
        report_to=config["report_to"],
        # The collator needs the raw "messages" column and does its own
        # formatting, so TRL must neither drop columns nor pre-tokenize.
        remove_unused_columns=False,
        dataset_text_field="",
        dataset_kwargs={"skip_prepare_dataset": True},
        seed=config["seed"],
        fp16=not bf16_available,
        bf16=bf16_available,
    )

    sft_trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        processing_class=processor.tokenizer,
        data_collator=collator,
        args=sft_args,
    )

    print("✅ Optimized trainer created successfully!")
    return sft_trainer

print("✅ Training setup functions loaded")


In [None]:
# Step 1: Make sure the output directory exists, then load model + processor.
output_dir = CONFIG["output_dir"]
os.makedirs(output_dir, exist_ok=True)
print(f"📁 Output directory: {output_dir}")

model, processor = setup_gemma3n_model(CONFIG)


In [None]:
# Step 2: Prepare dataset
train_dataset = prepare_dataset_optimized(CONFIG["dataset_name"], CONFIG["train_split"])

# Dataset statistics
print(f"\n📊 Dataset Statistics:")
print(f"   Total samples: {len(train_dataset)}")

# Estimate the multimodal share from (at most) the first 100 rows for speed.
# Slicing a datasets.Dataset returns a dict of columns ({"messages": [...]}),
# so the rows must be read from the "messages" column; iterating the slice
# itself would yield column-name strings.
sample_size = min(100, len(train_dataset))
sampled_conversations = train_dataset[:sample_size]["messages"] if sample_size else []

multimodal_count = sum(
    1
    for messages in sampled_conversations
    if any(
        content.get("type") == "image"
        for msg in messages
        for content in (msg.get("content") or [])
    )
)

# Guard against an empty dataset (ZeroDivisionError).
estimated_multimodal = (
    (multimodal_count / sample_size) * len(train_dataset) if sample_size else 0
)
print(f"   Estimated multimodal samples: {estimated_multimodal:.0f}")
print(f"   Estimated text-only samples: {len(train_dataset) - estimated_multimodal:.0f}")


In [None]:
# Step 3: Smoke-test the collator on a single sample so data-format problems
# surface before the (much slower) training loop starts.
print("🧪 Testing data collator...")
try:
    test_collator = OptimizedVisionDataCollator(
        processor,
        max_length=CONFIG["max_seq_length"],
    )
    test_batch = test_collator([train_dataset[0]])

    print("✅ Data collator test passed!")
    print(f"   Batch keys: {list(test_batch.keys())}")
    # Only tensor-like entries have a shape worth reporting.
    for key, value in test_batch.items():
        if not hasattr(value, 'shape'):
            continue
        print(f"   {key}: {value.shape}")
except Exception as e:
    print(f"❌ Data collator test failed: {e}")
    print("Please check the data format and try again.")
    raise


In [None]:
# Step 4: Build the trainer, report the run configuration, then train.
import traceback

trainer = create_optimized_trainer(model, processor, train_dataset, CONFIG)

run_summary = (
    ("Model", CONFIG['model_name']),
    ("Dataset", CONFIG['dataset_name']),
    ("Max steps", CONFIG['max_steps']),
    ("Batch size", CONFIG['per_device_train_batch_size']),
    ("Gradient accumulation", CONFIG['gradient_accumulation_steps']),
    ("Effective batch size", CONFIG['per_device_train_batch_size'] * CONFIG['gradient_accumulation_steps']),
    ("Learning rate", CONFIG['learning_rate']),
    ("LoRA rank", CONFIG['lora_r']),
)
print(f"\n🚀 Starting optimized training...")
for label, value in run_summary:
    print(f"   {label}: {value}")

try:
    trainer_stats = trainer.train()

    print("\n🎉 Training completed successfully!")
    print(f"   Final loss: {trainer_stats.training_loss:.4f}")
except Exception as e:
    print(f"\n❌ Training failed: {e}")
    traceback.print_exc()

    # Release cached GPU blocks before surfacing the error.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    raise


In [None]:
# Step 5: Save the trained LoRA adapters together with the full processor.
print("💾 Saving trained model...")

try:
    # Save only the PEFT/LoRA adapter weights (not a merged model).
    model.save_pretrained(CONFIG["output_dir"])
    # Save the whole processor, not just processor.tokenizer: the image
    # preprocessing config and chat template are needed to reload the
    # adapters for multimodal inference.
    processor.save_pretrained(CONFIG["output_dir"])

    print(f"✅ Model saved to {CONFIG['output_dir']}")

    # Optionally push to the Hugging Face Hub as well:
    # model.push_to_hub("your-username/gemma3n-math-tutor", private=True)
    # processor.push_to_hub("your-username/gemma3n-math-tutor", private=True)

except Exception as e:
    print(f"❌ Failed to save model: {e}")
    # Re-raise so the success banner below is never printed for a failed save.
    raise

print("\n🎉 Fine-tuning completed successfully!")
print(f"📁 Model artifacts saved in: {CONFIG['output_dir']}")


In [None]:
# Step 6: Test inference with the trained model
print("🔮 Testing inference...")

try:
    # Enable inference mode
    FastVisionModel.for_inference(model)

    # The training collator substitutes a blank 32x32 placeholder image for
    # text-only samples; use the same placeholder here. It is placed inside the
    # message (not only passed to the processor) so the chat template emits the
    # matching image marker and image features line up with image tokens.
    placeholder_image = Image.new('RGB', (32, 32), color=(240, 240, 240))
    test_messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Giải bài toán sau: Một hình chữ nhật có chiều dài 8cm và chiều rộng 5cm. Tính chu vi của hình chữ nhật."},
                {"type": "image", "image": placeholder_image},
            ]
        }
    ]

    # Template + tokenize + preprocess the image in one call.
    inputs = processor.apply_chat_template(
        test_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.7,
            do_sample=True,
            pad_token_id=processor.tokenizer.eos_token_id
        )

    # generate() returns prompt + continuation; decode only the new tokens.
    prompt_length = inputs["input_ids"].shape[-1]
    response = processor.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    print("✅ Inference test successful!")
    print(f"\n📝 Test Response:")
    print(response)

except Exception as e:
    print(f"❌ Inference test failed: {e}")
    import traceback
    traceback.print_exc()

print("\n🎯 Fine-tuning pipeline completed!")
