In [None]:
%%capture
# Install Unsloth and its dependencies (pip output is hidden by %%capture).
import os
# Colab detection: look for the substring "COLAB_" anywhere in the environment variable names.
if "COLAB_" not in "".join(os.environ.keys()):
    # Not running in Colab: a plain pip install.
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # Colab: install dependencies individually (several with --no-deps, some version-pinned),
    # then Unsloth itself without pulling in its dependency tree.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth

In [None]:
%%capture
# Extra installs for this notebook: timm (needed only for Gemma 3N) and comet-ml for experiment
# tracking. Transformers itself is not installed or upgraded by this cell.
!pip install --no-deps --upgrade timm # Only for Gemma 3N
!pip install comet-ml

In [None]:
import os
import re
import io
import zipfile
from typing import Tuple, List, Dict, Any, Optional
from PIL import Image
import requests
import comet_ml
from unsloth import FastVisionModel, get_chat_template
import torch
from datasets import load_dataset, Dataset
from transformers import TrainingArguments
from trl import SFTTrainer, SFTConfig

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


comet_ml is installed but the Comet API Key is not configured. Please set the `COMET_API_KEY` environment variable to enable Comet logging. Check out the documentation for other ways of configuring it: https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#set-the-api-key


🦥 Unsloth Zoo will now patch everything to make training faster!


## Comet ML logging config

In [None]:
# Load Comet ML credentials. Inside Colab they are read from the notebook's Secrets panel
# and exported as environment variables. Outside Colab `google.colab` is not importable,
# so any COMET_* variables already present in the environment are used as-is.
try:
    from google.colab import userdata
except ImportError:
    userdata = None

if userdata is not None:
    for _secret_name in ("COMET_API_KEY", "COMET_PROJECT", "COMET_WORKSPACE"):
        os.environ[_secret_name] = userdata.get(_secret_name)

# Comet ML settings shared by the rest of the notebook.
# NOTE(review): this dict holds the API key - avoid displaying it in cell output.
COMET_CONFIG = {
    # API key - REQUIRED
    # Option 1: set the environment variable
    #   export COMET_API_KEY="your-api-key-here"
    # Option 2: hard-code the key here (not recommended for production)
    "api_key": os.getenv("COMET_API_KEY"),  # Or replace with your API key

    # Workspace - REQUIRED
    # Workspace name on Comet ML
    "workspace": os.getenv("COMET_WORKSPACE"),  # Replace with your workspace

    # Project name - REQUIRED
    # Project name on Comet ML
    "project": os.getenv("COMET_PROJECT"),  # Project name can be changed

    # Experiment name - OPTIONAL
    # Specific experiment name. In this notebook it is used to build CONFIG["output_dir"];
    # it is not passed to comet_ml.Experiment by setup_comet_ml.
    "experiment_name": "base_line",  # Or a custom name such as "exp-001"

    # Tags - OPTIONAL
    # Tags used to categorize experiments
    "tags": [
        "gemma3n",
        "vision-language",
        "math-tutor",
        "vietnamese",
        "sixth-grade",
        "fine-tuning"
    ],

    # Additional settings (forwarded to comet_ml.Experiment by setup_comet_ml)
    "auto_metric_logging": True,     # Automatically log metrics
    "auto_param_logging": True,      # Automatically log parameters
    "auto_histogram_weight_logging": True,   # Log weight histograms
    "auto_histogram_gradient_logging": True, # Log gradient histograms
    "auto_histogram_activation_logging": False,  # Disabled to save memory

    # NOTE(review): the keys below are not read by setup_comet_ml in this notebook,
    # so they currently have no effect unless another cell consumes them.
    "auto_output_logging": "default",  # Log output (stdout/stderr)

    # Model logging
    "log_model": True,              # Upload model artifacts
    "log_graph": False,             # Log model graph (can be slow)
    "log_code": True,               # Log source code
    "log_git_metadata": True,       # Log git information
}

In [None]:
def print_comet_info():
    """Show the Comet ML workspace, project and tags that will be used.

    The API key itself is never printed - only whether one is present.
    """
    key_state = '✅ Set' if COMET_CONFIG['api_key'] else '❌ Not set'
    summary = (
        "🔧 Comet ML Configuration:",
        f"   Workspace: {COMET_CONFIG['workspace']}",
        f"   Project: {COMET_CONFIG['project']}",
        f"   API Key: {key_state}",
        f"   Tags: {', '.join(COMET_CONFIG['tags'])}",
    )
    for line in summary:
        print(line)

print_comet_info()

🔧 Comet ML Configuration:
   Workspace: mathpal
   Project: mathpal-gemma3n
   API Key: ✅ Set
   Tags: gemma3n, vision-language, math-tutor, vietnamese, sixth-grade, fine-tuning


## Training config

In [None]:
# Training configuration used throughout the notebook
# (setup_comet_ml also logs it as Comet experiment parameters).
CONFIG = {
    # Model settings
    "model_name": "unsloth/gemma-3n-E4B",
    "max_seq_length": 2048,  # Also the processor truncation length used by the data collator
    "load_in_4bit": True,

    # Dataset settings
    "dataset_name": "ngohongthai/exam-sixth_grade-instruct-dataset",
    "train_split": "train",
    "test_split": "test",

    # Training settings
    # Checkpoints go to "<comet project>/<experiment name>". Fall back to a fixed project
    # name when COMET_PROJECT is unset, so the path does not start with a literal "None".
    "output_dir": f"{COMET_CONFIG['project'] or 'gemma3n-math-tutor'}/{COMET_CONFIG['experiment_name']}",
    "max_steps": 200,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 8,  # Effective batch size = 1 x 8
    "learning_rate": 2e-4,
    "warmup_ratio": 0.03,
    "weight_decay": 0.01,
    "logging_steps": 5,
    "save_steps": 50,

    # LoRA settings
    "lora_r": 32,
    "lora_alpha": 32,
    "lora_dropout": 0.0,

    # System settings
    "use_gradient_checkpointing": False,  # Disabled to avoid CheckpointError
    "report_to": "comet_ml",  # Change to "tensorboard", "wandb" if needed
    "seed": 42,

    # Comet ML settings (not read elsewhere in the visible code; logged with the config)
    "comet_workspace": COMET_CONFIG['workspace'],  # Set your Comet workspace name
    "comet_project": COMET_CONFIG['project'],  # Set your Comet project name
}

## IMAGE PROCESSING UTILITIES

In [None]:
def url_to_image(url: str, timeout: int = 10) -> Optional[Image.Image]:
    """
    Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        url: Image URL
        timeout: Request timeout in seconds

    Returns:
        PIL Image in RGB mode, or None if the request fails, the payload cannot be
        decoded as an image, or Pillow rejects it as a decompression bomb.
        The failure reason is printed.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # convert() returns a new image, so the decoded source can be closed right away.
        with Image.open(io.BytesIO(response.content)) as image:
            return image.convert("RGB")
    except (requests.exceptions.RequestException, OSError, Image.DecompressionBombError) as e:
        # OSError (alias IOError) covers undecodable data (PIL's UnidentifiedImageError).
        # DecompressionBombError is not an OSError, so it is listed explicitly; otherwise an
        # oversized image from a dataset URL would abort the whole sample instead of being skipped.
        print(f"Failed to load image from {url}: {e}")
        return None

def _clean_markdown_image_destination(destination: str) -> str:
    """Strip whitespace, an optional quoted title and optional <angle brackets> from a markdown image destination."""
    dest = destination.strip()
    # ![alt](url "title") / ![alt](url 'title'): drop the trailing quoted title.
    dest = re.sub(r"""\s+(?:"[^"]*"|'[^']*')$""", "", dest)
    # ![alt](<url>): CommonMark allows the destination to be wrapped in angle brackets.
    if dest.startswith("<") and dest.endswith(">"):
        dest = dest[1:-1].strip()
    return dest


def extract_image_urls_from_markdown(text: str) -> Tuple[str, List[str]]:
    """
    Extract image URLs from markdown text and remove the image references.

    Besides the plain form ``![alt](url)``, the destination may carry an optional
    quoted title (``![alt](url "title")``) or be wrapped in angle brackets
    (``![alt](<url>)``). Only the bare URL is returned in either case.

    Args:
        text: Markdown text containing image links

    Returns:
        Tuple of (cleaned_text, list_of_image_urls). ``cleaned_text`` has every image
        reference replaced by a single space and is stripped at both ends. URLs are
        returned in order of appearance.
    """
    # Pattern for markdown images: ![alt](destination) - the destination is captured raw.
    image_pattern = r"!\[.*?\]\((.*?)\)"
    image_urls = [
        _clean_markdown_image_destination(destination)
        for destination in re.findall(image_pattern, text)
    ]

    # Remove image markdown syntax
    cleaned_text = re.sub(image_pattern, " ", text).strip()

    return cleaned_text, image_urls

def process_markdown_for_model(text: str) -> Tuple[str, List[Image.Image]]:
    """
    Split markdown into plain text plus the downloaded images it references.

    Image references are removed from the text by extract_image_urls_from_markdown,
    and each URL is fetched with url_to_image. URLs that fail to load are skipped
    with a warning, so the image list may be shorter than the number of references.

    Args:
        text: Input markdown text

    Returns:
        Tuple of (processed_text, list_of_pil_images), images in order of appearance
    """
    processed_text, urls = extract_image_urls_from_markdown(text)

    loaded_images = []
    for image_url in urls:
        downloaded = url_to_image(image_url)
        if not downloaded:
            print(f"Warning: Failed to load image from {image_url}")
            continue
        loaded_images.append(downloaded)

    return processed_text, loaded_images

## DATASET PROCESSING

In [None]:
def create_conversation_content(text: str, images: List[Image.Image]) -> List[Dict[str, Any]]:
    """
    Build the content list of one chat message.

    The text item always comes first, followed by one ``{"type": "image"}`` item
    per image, in the given order.

    Args:
        text: Text content
        images: List of PIL images

    Returns:
        List of content dictionaries
    """
    text_part = {"type": "text", "text": text}
    image_parts = [{"type": "image", "image": picture} for picture in images]
    return [text_part, *image_parts]

def process_math_sample(sample: Dict[str, str]) -> Dict[str, List[Dict[str, Any]]]:
    """
    Turn one question/solution row into a two-turn chat conversation.

    The question becomes the user turn and the solution the assistant turn. Any
    markdown images in either field are downloaded and attached to that turn.

    Args:
        sample: Dataset row with 'question' and 'solution' markdown strings

    Returns:
        ``{"conversations": [user_message, assistant_message]}``
    """
    def _to_content(markdown: str) -> List[Dict[str, Any]]:
        body, pictures = process_markdown_for_model(markdown)
        return create_conversation_content(body, pictures)

    # The list literal evaluates left to right: the question is processed before the solution.
    turns = [
        {"role": "user", "content": _to_content(sample["question"])},
        {"role": "assistant", "content": _to_content(sample["solution"])},
    ]
    return {"conversations": turns}

def prepare_dataset(dataset_name: str, split: str) -> Dataset:
    """
    Download one split of the math dataset and convert every row to chat format.

    Rows that raise while being converted are reported and skipped (best effort),
    so the result can contain fewer rows than the source split.

    Args:
        dataset_name: HuggingFace dataset name
        split: Dataset split to load

    Returns:
        Dataset built from the converted rows (each row has a "conversations" field)
    """
    print(f"Loading dataset: {dataset_name}, split: {split}")
    raw_dataset = load_dataset(dataset_name, split=split)
    total = len(raw_dataset)

    print(f"Processing {total} samples...")
    converted = []
    for row_number, row in enumerate(raw_dataset, start=1):
        try:
            converted_row = process_math_sample(row)
        except Exception as err:
            print(f"Error processing sample {row_number - 1}: {err}")
        else:
            converted.append(converted_row)
            if row_number % 100 == 0:
                print(f"Processed {row_number}/{total} samples")

    print(f"Successfully processed {len(converted)} samples")
    return Dataset.from_list(converted)

## CUSTOM DATA COLLATOR

In [None]:
class HybridVisionDataCollator:
    """
    Data collator that batches text-only and text+image conversation samples together.

    Each example is expected to be ``{"conversations": [{"role": ..., "content": [...]}, ...]}``,
    where content items are ``{"type": "text", "text": ...}`` or ``{"type": "image", "image": ...}``
    (the format produced by ``process_math_sample``). Every sample is rendered with
    ``processor.apply_chat_template``, and the batch is passed to ``processor(text=..., images=...)``
    with one image list per sample. Text-only samples get a small placeholder image plus an
    inserted image marker, so every sample carries at least one image.

    NOTE(review): token/image syncing counts literal occurrences of ``image_token`` in the
    rendered chat text. The default ``"<image>"`` preserves the original behaviour, but the
    gemma-3n chat template may emit a different image marker - TODO confirm, and pass the
    template's marker via ``image_token``. Otherwise every image sample is seen as having
    0 markers.

    NOTE(review): labels only mask padding (-100). Prompt and image tokens are still part
    of the loss - confirm this is intended.
    """

    # Lower-cased substrings identifying the line that opens the user turn; used when inserting
    # image markers into text-only samples. The last entry matches Gemma-style turn headers.
    _USER_TURN_MARKERS = ('<|user|>', 'user:', 'human:', '<start_of_turn>user')

    # Label value ignored by the loss; used to mask padding positions.
    _IGNORE_INDEX = -100

    def __init__(self, processor, handle_text_only=True, image_token="<image>",
                 max_length=None, verbose=True):
        """
        Args:
            processor: Multimodal processor providing ``apply_chat_template``, ``__call__``
                and ``tokenizer.pad_token_id``.
            handle_text_only: If False, a batch made only of text-only samples raises ValueError.
            image_token: Literal image marker counted in / inserted into the rendered chat text.
            max_length: Truncation length passed to the processor. ``None`` reads
                ``CONFIG["max_seq_length"]`` at collate time (original behaviour).
            verbose: If True, print per-batch and per-sample diagnostics.
                Mismatch warnings and errors are always printed.
        """
        self.processor = processor
        self.handle_text_only = handle_text_only
        self.image_token = image_token
        self.max_length = max_length
        self.verbose = verbose
        self.placeholder_image = None

    def _log(self, message):
        """Print a diagnostic message when verbose output is enabled."""
        if self.verbose:
            print(message)

    def _create_placeholder_image(self):
        """Lazily create (once) and return a 32x32 light-grey RGB placeholder for text-only samples."""
        if self.placeholder_image is None:
            # Small image to minimize memory impact
            self.placeholder_image = Image.new('RGB', (32, 32), color=(245, 245, 245))
        return self.placeholder_image

    def _validate_and_process_image(self, img):
        """Return ``img`` converted to RGB, or None if it is missing, not image-like, empty or fails to convert."""
        if img is None:
            return None

        try:
            if not hasattr(img, 'convert'):
                return None

            img = img.convert('RGB')

            if img.size[0] < 1 or img.size[1] < 1:
                return None

            return img

        except Exception:
            # A broken image is treated as absent rather than failing the whole batch.
            return None

    def _extract_images_from_conversation(self, conv):
        """Collect every valid image (RGB-converted) from all messages of a conversation, in order."""
        images = []
        for message in conv:
            for content in message.get("content", []):
                if content.get("type") == "image" and "image" in content:
                    processed_img = self._validate_and_process_image(content["image"])
                    if processed_img is not None:
                        images.append(processed_img)
        return images

    def _has_real_images(self, conv):
        """Return True if the conversation contains at least one valid image."""
        return len(self._extract_images_from_conversation(conv)) > 0

    def _create_text_only_conversation(self, conv):
        """Copy a conversation keeping only text content items (roles preserved)."""
        text_only_conv = []
        for message in conv:
            text_parts = [c for c in message.get("content", []) if c.get("type") == "text"]
            text_only_conv.append({
                "role": message["role"],
                # Ensure every message keeps at least one (possibly empty) text item.
                "content": text_parts or [{"type": "text", "text": ""}],
            })
        return text_only_conv

    def _insert_image_token_strategically(self, text, num_images=1):
        """
        Insert ``num_images`` image markers, one per line, into rendered chat text.

        Markers go right after the first line containing a user-turn marker; if none is
        found, after the first line. Text already containing the marker is returned unchanged.
        """
        token = self.image_token
        if token in text:
            return text

        lines = text.split('\n')  # always at least one element

        insert_at = 1  # default: right after the first line (usually the role header)
        for i, line in enumerate(lines):
            lowered = line.lower()
            if any(marker in lowered for marker in self._USER_TURN_MARKERS):
                insert_at = i + 1
                break

        lines[insert_at:insert_at] = [token] * num_images
        return '\n'.join(lines)

    def _sync_images_to_tokens(self, images, token_count):
        """
        Return an image list whose length matches ``token_count`` where possible.

        Extra images are dropped and missing ones are padded with the placeholder.
        NOTE(review): when no marker was found, one image is still kept (original logic),
        so the processor receives an image without a matching marker - confirm.
        """
        if token_count == len(images):
            return images
        if token_count < len(images):
            return images[:token_count] if token_count > 0 else images[:1]
        missing = token_count - len(images)
        return list(images) + [self._create_placeholder_image() for _ in range(missing)]

    def _render_image_sample(self, idx, example, label):
        """Render one image sample to chat text and return (text, images synced to the marker count)."""
        conv = example["conversations"]
        images = self._extract_images_from_conversation(conv)

        text = self.processor.apply_chat_template(
            conv, tokenize=False, add_generation_prompt=False
        )

        token_count = text.count(self.image_token)
        self._log(f"{label} {idx}: {len(images)} images, {token_count} tokens")

        return text, self._sync_images_to_tokens(images, token_count)

    def _render_text_only_sample(self, example):
        """Render one text-only sample and add one placeholder image plus its marker."""
        text_only_conv = self._create_text_only_conversation(example["conversations"])

        text = self.processor.apply_chat_template(
            text_only_conv, tokenize=False, add_generation_prompt=False
        )

        placeholder = self._create_placeholder_image()
        return self._insert_image_token_strategically(text, num_images=1), [placeholder]

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Collate a batch that may mix text-only and text+image samples.

        Args:
            examples: List of processed conversation examples

        Returns:
            Processor output (input_ids, attention mask, pixel values, ...) plus "labels"

        Raises:
            ValueError: on an empty batch, or a text-only batch when handle_text_only is False.
        """
        try:
            self._log(f"Processing hybrid batch of {len(examples)} examples...")

            # Classify samples
            image_samples = []
            text_only_samples = []
            for idx, example in enumerate(examples):
                if self._has_real_images(example["conversations"]):
                    image_samples.append((idx, example))
                    self._log(f"Sample {idx}: IMAGE SAMPLE ✅")
                else:
                    text_only_samples.append((idx, example))
                    self._log(f"Sample {idx}: TEXT ONLY 📝")

            self._log(f"Batch composition: {len(image_samples)} image samples, {len(text_only_samples)} text-only")

            if image_samples and text_only_samples:
                return self._process_mixed_batch(image_samples, text_only_samples)
            if image_samples:
                return self._process_image_batch(image_samples)
            # Distinguish the two failure causes instead of reporting them with one ambiguous message.
            if not text_only_samples:
                raise ValueError("Empty batch: no samples to process!")
            if not self.handle_text_only:
                raise ValueError("Batch contains only text-only samples but handle_text_only=False!")
            return self._process_text_only_batch(text_only_samples)

        except Exception as e:
            print(f"Critical error in hybrid data collator: {e}")
            import traceback
            traceback.print_exc()
            raise

    def _process_image_batch(self, image_samples):
        """Build a batch where every sample has real images."""
        self._log("Processing pure image batch...")

        texts, images_list = [], []
        for idx, example in image_samples:
            text, images = self._render_image_sample(idx, example, "Image sample")
            texts.append(text)
            images_list.append(images)

        return self._create_batch(texts, images_list)

    def _process_text_only_batch(self, text_only_samples):
        """Build a batch of text-only samples, each with one placeholder image."""
        self._log("Processing text-only batch with placeholder images...")

        texts, images_list = [], []
        for idx, example in text_only_samples:
            text, images = self._render_text_only_sample(example)
            self._log(f"Text sample {idx}: Added 1 placeholder image and token")
            texts.append(text)
            images_list.append(images)

        return self._create_batch(texts, images_list)

    def _process_mixed_batch(self, image_samples, text_only_samples):
        """Build a batch with image samples first, followed by text-only samples with placeholders."""
        self._log("Processing mixed batch...")

        texts, images_list = [], []

        for idx, example in image_samples:
            text, images = self._render_image_sample(idx, example, "Mixed - Image sample")
            texts.append(text)
            images_list.append(images)

        for idx, example in text_only_samples:
            text, images = self._render_text_only_sample(example)
            self._log(f"Mixed - Text sample {idx}: Added placeholder image and token")
            texts.append(text)
            images_list.append(images)

        return self._create_batch(texts, images_list)

    def _create_batch(self, texts, images_list):
        """Run the processor on the rendered texts/images and attach labels (padding masked to -100)."""
        self._log(f"Creating batch: {len(texts)} texts, {len(images_list)} image lists")

        # Final safety net: make every sample's image count match its marker count.
        for i, (text, imgs) in enumerate(zip(texts, images_list)):
            token_count = text.count(self.image_token)
            if token_count != len(imgs):
                print(f"⚠️  Sample {i}: Token/image mismatch ({token_count} vs {len(imgs)})")
                images_list[i] = self._sync_images_to_tokens(imgs, token_count)

        self._log("Sending to processor...")
        max_length = self.max_length if self.max_length is not None else CONFIG["max_seq_length"]
        batch = self.processor(
            text=texts,
            images=images_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        # Labels = input ids with padding positions ignored by the loss.
        labels = batch["input_ids"].clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = self._IGNORE_INDEX
        batch["labels"] = labels

        self._log(f"Batch created successfully: {batch.keys()}")
        if "pixel_values" in batch:
            self._log(f"pixel_values shape: {batch['pixel_values'].shape}")
        self._log(f"input_ids shape: {batch['input_ids'].shape}")

        return batch

## MODEL SETUP AND TRAINING

In [None]:
def setup_model_and_processor(config: Dict[str, Any]):
    """
    Load the Gemma3N vision model, wrap it with LoRA adapters and attach the chat template.

    Config keys read: model_name, max_seq_length, load_in_4bit, use_gradient_checkpointing,
    lora_r, lora_alpha, lora_dropout, seed.

    Args:
        config: Configuration dictionary

    Returns:
        Tuple of (model, processor); the processor has the "gemma-3n" chat template applied.
    """
    print("Loading Gemma3N model and processor...")

    checkpointing_mode = "unsloth" if config["use_gradient_checkpointing"] else False
    base_model, processor = FastVisionModel.from_pretrained(
        config["model_name"],
        max_seq_length=config["max_seq_length"],
        load_in_4bit=config["load_in_4bit"],
        use_gradient_checkpointing=checkpointing_mode,
    )

    # LoRA on both vision and language towers (attention + MLP).
    # NOTE(review): modules_to_save asks PEFT to train full copies of lm_head/embed_tokens -
    # confirm this is intended given the memory budget.
    lora_settings = dict(
        finetune_vision_layers=True,
        finetune_language_layers=True,
        finetune_attention_modules=True,
        finetune_mlp_modules=True,
        r=config["lora_r"],
        lora_alpha=config["lora_alpha"],
        lora_dropout=config["lora_dropout"],
        bias="none",
        random_state=config["seed"],
        use_rslora=False,
        target_modules="all-linear",
        modules_to_save=["lm_head", "embed_tokens"],
    )
    model = FastVisionModel.get_peft_model(base_model, **lora_settings)

    processor = get_chat_template(processor, "gemma-3n")

    print("Model and processor setup complete!")
    return model, processor

def create_trainer(model, processor, train_dataset, config: Dict[str, Any]):
    """
    Build the SFTTrainer used for vision-language fine-tuning.

    Switches the model to training mode (FastVisionModel.for_training) and attaches a
    HybridVisionDataCollator so text-only and image samples can share a batch. Also
    translates ``config`` into an SFTConfig.

    Optional config keys (defaults reproduce the previously hard-coded values):
        optim: optimizer name, default "adamw_torch_fused"
        lr_scheduler_type: learning-rate schedule, default "cosine"
        max_grad_norm: gradient clipping norm, default 0.3

    Args:
        model: LoRA-prepared model from setup_model_and_processor
        processor: Model processor (its tokenizer is passed as the trainer's processing_class)
        train_dataset: Training dataset with a "conversations" column
        config: Configuration dictionary

    Returns:
        Configured SFTTrainer
    """
    # Enable training
    FastVisionModel.for_training(model)

    data_collator = HybridVisionDataCollator(
        processor,
        handle_text_only=True  # The dataset mixes text-only and image samples
    )

    use_gradient_checkpointing = config["use_gradient_checkpointing"]

    # NOTE(review): no fp16/bf16 flag is set here; the run log shows bfloat16 is unavailable
    # on the T4 - confirm the mixed-precision mode chosen by the trainer.
    training_args = SFTConfig(
        # Basic training settings
        output_dir=config["output_dir"],
        max_steps=config["max_steps"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        gradient_accumulation_steps=config["gradient_accumulation_steps"],

        # Optimization settings
        learning_rate=config["learning_rate"],
        warmup_ratio=config["warmup_ratio"],
        weight_decay=config["weight_decay"],
        optim=config.get("optim", "adamw_torch_fused"),
        lr_scheduler_type=config.get("lr_scheduler_type", "cosine"),

        # Memory optimization
        gradient_checkpointing=use_gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False} if use_gradient_checkpointing else {},
        max_grad_norm=config.get("max_grad_norm", 0.3),

        # Logging and saving
        logging_steps=config["logging_steps"],
        save_strategy="steps",
        save_steps=config["save_steps"],
        report_to=config["report_to"],

        # Vision-specific settings: keep the raw "conversations" column for the custom
        # collator and skip TRL's own dataset preparation/tokenization.
        remove_unused_columns=False,
        dataset_text_field="",
        dataset_kwargs={"skip_prepare_dataset": True},
        max_length=config["max_seq_length"],

        # Reproducibility
        seed=config["seed"],
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        processing_class=processor.tokenizer,
        data_collator=data_collator,
        args=training_args,
    )

    return trainer

## COMET ML SETUP

In [None]:
def setup_comet_ml(config: Dict[str, Any]) -> Optional["comet_ml.Experiment"]:
    """
    Create a Comet ML experiment when ``config["report_to"] == "comet_ml"``.

    Connection and auto-logging options come from the global COMET_CONFIG. On failure the
    function falls back to TensorBoard by setting ``config["report_to"] = "tensorboard"``
    in place, so call it before building the trainer.

    Args:
        config: Training configuration; logged as experiment parameters.

    Returns:
        The created ``comet_ml.Experiment``, or None if Comet is not selected or setup failed.
    """
    if config["report_to"] != "comet_ml":
        return None

    default_project = "gemma3n-math-tutor"
    experiment = None

    def _end_partial_experiment():
        """Best effort: stop an experiment that was created before a later setup step failed."""
        if experiment is not None:
            try:
                experiment.end()
            except Exception:
                pass  # Cleanup must not mask the original failure.

    try:
        experiment_kwargs = {
            "workspace": COMET_CONFIG.get("workspace"),
            "project_name": COMET_CONFIG.get("project"),
            "auto_metric_logging": COMET_CONFIG.get("auto_metric_logging", True),
            "auto_param_logging": COMET_CONFIG.get("auto_param_logging", True),
            "auto_histogram_weight_logging": COMET_CONFIG.get("auto_histogram_weight_logging", True),
            "auto_histogram_gradient_logging": COMET_CONFIG.get("auto_histogram_gradient_logging", True),
            "auto_histogram_activation_logging": COMET_CONFIG.get("auto_histogram_activation_logging", False),
        }
        # Drop None values so comet_ml uses its own defaults for them.
        experiment_kwargs = {k: v for k, v in experiment_kwargs.items() if v is not None}

        experiment = comet_ml.Experiment(**experiment_kwargs)

        # Log configuration
        experiment.log_parameters(config)

        tags = COMET_CONFIG.get("tags", ["gemma3n", "vision-language", "math-tutor"])
        for tag in tags:
            experiment.add_tag(tag)

        # Read dataset/model from config so the metadata cannot drift from what is trained.
        experiment.log_other("dataset", config.get("dataset_name", "ngohongthai/exam-sixth_grade-instruct-dataset"))
        experiment.log_other("model_base", config.get("model_name", "unsloth/gemma-3n-E4B"))
        experiment.log_other("task", "sixth-grade-math-tutoring")
        experiment.log_other("language", "vietnamese")

        # COMET_CONFIG stores None (not a missing key) when an env var is unset, so `.get(key, default)`
        # would still return None. Use `or` for fallbacks; os.environ rejects None values.
        workspace = COMET_CONFIG.get("workspace")
        project = COMET_CONFIG.get("project") or default_project

        print(f"✅ Comet ML experiment initialized")
        print(f"🔗 Experiment URL: {experiment.url}")
        print(f"📊 Workspace: {workspace or 'default'}")
        print(f"📁 Project: {project}")

        # Set environment variables for transformers integration
        os.environ["COMET_PROJECT_NAME"] = project
        if workspace:
            os.environ["COMET_WORKSPACE"] = workspace

        return experiment

    except ImportError:
        _end_partial_experiment()
        print("❌ comet_ml not installed. Please install with: pip install comet-ml")
        print("Falling back to tensorboard logging...")
        config["report_to"] = "tensorboard"
        return None
    except Exception as e:
        _end_partial_experiment()
        print(f"❌ Failed to initialize Comet ML: {e}")
        print("Possible causes:")
        print("- Invalid API key or workspace/project names")
        print("- Network connection issues")
        print("- Missing permissions")
        print("Falling back to tensorboard logging...")
        config["report_to"] = "tensorboard"
        return None

## Training

In [None]:
# Create the checkpoint directory ahead of training; exist_ok makes re-running this cell safe.
os.makedirs(name=CONFIG["output_dir"], exist_ok=True)

In [None]:
# Start Comet tracking. On failure this sets CONFIG["report_to"] to "tensorboard",
# so it must run before create_trainer reads that key.
comet_experiment = setup_comet_ml(CONFIG)

[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/mathpal/mathpal-gemma3n/4e7cd6e71c614c0eb81b9802883202b0

[1;38;5;39mCOMET INFO:[0m Couldn't find a Git repository in '/content' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.


✅ Comet ML experiment initialized
🔗 Experiment URL: https://www.comet.com/mathpal/mathpal-gemma3n/4e7cd6e71c614c0eb81b9802883202b0
📊 Workspace: mathpal
📁 Project: mathpal-gemma3n


In [None]:
# 1. Setup model and processor
# Returns the LoRA-wrapped model and the processor with the "gemma-3n" chat template applied.
model, processor = setup_model_and_processor(CONFIG)

Loading Gemma3N model and processor...
==((====))==  Unsloth 2025.8.1: Fast Gemma3N patching. Transformers: 4.54.0.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3N does not support SDPA - switching to eager!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Unsloth: Making `model.base_model.model.model.language_model` require gradients
Model and processor setup complete!


In [None]:
# 2. Prepare dataset (downloads every referenced image, so this cell is slow)
train_dataset = prepare_dataset(CONFIG["dataset_name"], CONFIG["train_split"])

# Print dataset statistics
print(f"\nDataset Statistics:")
print(f"- Training samples: {len(train_dataset)}")

# A sample counts as "with images" if any message of its conversation contains at
# least one content item of type "image" (each sample is counted at most once).
samples_with_images = sum(
    any(
        content.get("type") == "image"
        for message in sample["conversations"]
        for content in message.get("content", [])
    )
    for sample in train_dataset
)

print(f"- Samples with images: {samples_with_images}")
print(f"- Text-only samples: {len(train_dataset) - samples_with_images}")

Loading dataset: ngohongthai/exam-sixth_grade-instruct-dataset, split: train
Processing 1010 samples...
Processed 100/1010 samples
Processed 200/1010 samples
Processed 300/1010 samples
Processed 400/1010 samples
Processed 500/1010 samples
Processed 600/1010 samples
Processed 700/1010 samples
Processed 800/1010 samples
Processed 900/1010 samples
Processed 1000/1010 samples
Successfully processed 1010 samples

Dataset Statistics:
- Training samples: 1010
- Samples with images: 117
- Text-only samples: 893


In [None]:
# 3. Create trainer
# Reads CONFIG["report_to"], which setup_comet_ml may have switched to "tensorboard".
trainer = create_trainer(model, processor, train_dataset, CONFIG)

In [None]:
 # 4. Start training
print(
    "\n🚀 Starting training...",
    f"- Output directory: {CONFIG['output_dir']}",
    f"- Max steps: {CONFIG['max_steps']}",
    f"- Batch size: {CONFIG['per_device_train_batch_size']}",
    f"- Gradient accumulation: {CONFIG['gradient_accumulation_steps']}",
    f"- Effective batch size: {CONFIG['per_device_train_batch_size'] * CONFIG['gradient_accumulation_steps']}",
    sep="\n",
)

# Run training; the return value of trainer.train() is kept as trainer_stats.
trainer_stats = trainer.train()


🚀 Starting training...
- Output directory: mathpal-gemma3n/base_line
- Max steps: 200
- Batch size: 1
- Gradient accumulation: 8
- Effective batch size: 8


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,010 | Num Epochs = 2 | Total steps = 200
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 8 x 1) = 8
 "-____-"     Trainable parameters = 76,840,960 of 7,926,819,152 (0.97% trained)
[1;38;5;39mCOMET INFO:[0m An experiment with the same configuration options is already running and will be reused.


Processing hybrid batch of 1 examples...
Sample 0: TEXT ONLY 📝
Batch composition: 0 image samples, 1 text-only
Critical error in hybrid data collator: No valid samples to process or text-only handling disabled!


Traceback (most recent call last):
  File "/tmp/ipython-input-270888095.py", line 147, in __call__
    raise ValueError("No valid samples to process or text-only handling disabled!")
ValueError: No valid samples to process or text-only handling disabled!


ValueError: No valid samples to process or text-only handling disabled!

## DEBUGGING

In [None]:
def debug_raw_dataset(
    dataset_name: str = "ngohongthai/exam-sixth_grade-instruct-dataset",
    split: str = "train",
    num_samples: int = 3,
    preview_chars: int = 100,
) -> bool:
    """Print a short preview of the raw dataset before any processing.

    Loading errors and inspection errors are reported separately, so a
    malformed row is not mistaken for a failure to load the dataset.

    Args:
        dataset_name: Hugging Face dataset id passed to ``load_dataset``.
        split: Split expression passed to ``load_dataset``.
        num_samples: Number of leading rows to preview.
        preview_chars: Number of characters of question/solution to show.

    Returns:
        True if the dataset loaded and the preview completed, False otherwise.
    """
    print("🔍 Debugging raw dataset...")

    try:
        raw_dataset = load_dataset(dataset_name, split=split)
    except Exception as e:
        print(f"Error loading raw dataset: {e}")
        return False

    try:
        print(f"Raw dataset size: {len(raw_dataset)}")

        for i in range(min(num_samples, len(raw_dataset))):
            sample = raw_dataset[i]
            # A missing or None field would otherwise raise on slicing and
            # abort the whole preview; show it as empty text instead.
            question = str(sample.get("question") or "")
            solution = str(sample.get("solution") or "")

            print(f"\nSample {i}:")
            print(f"  Keys: {sample.keys()}")
            print(f"  Question: {question[:preview_chars]}...")
            print(f"  Solution: {solution[:preview_chars]}...")
            # '![' is the opening of a markdown image reference.
            print(f"  Question has images: {'![' in question}")
            print(f"  Solution has images: {'![' in solution}")
    except Exception as e:
        print(f"Error inspecting raw dataset: {e}")
        return False

    return True


debug_raw_dataset()

In [None]:
def _print_conversation_summary(conversations):
    """Print each turn's role and a one-line summary of every content item."""
    for turn_idx, turn in enumerate(conversations):
        print(f"  Conv {turn_idx}: role={turn['role']}, content_items={len(turn['content'])}")
        for item_idx, item in enumerate(turn['content']):
            if item['type'] != 'image':
                print(f"    Content {item_idx}: text, length={len(item.get('text', ''))}")
                continue
            picture = item.get('image')
            print(f"    Content {item_idx}: image, type={type(picture)}, size={getattr(picture, 'size', 'unknown')}")


def debug_processed_dataset():
    """Run process_math_sample on the first five training rows and print their structure.

    Per-sample failures are reported with a traceback and the loop continues.

    Returns:
        True if process_math_sample succeeded for at least one sample, False
        otherwise (including when the dataset cannot be loaded).
    """
    import traceback

    print("\n🔍 Debugging processed dataset...")

    try:
        subset = load_dataset("ngohongthai/exam-sixth_grade-instruct-dataset", split="train[:5]")

        processed_count = 0
        for idx, raw_sample in enumerate(subset):
            try:
                converted = process_math_sample(raw_sample)
                # Counted as soon as processing succeeds, before any printing.
                processed_count += 1

                print(f"\nProcessed sample {idx}:")
                print(f"  Conversations: {len(converted['conversations'])}")
                _print_conversation_summary(converted['conversations'])
            except Exception as err:
                print(f"Error processing sample {idx}: {err}")
                traceback.print_exc()

        return processed_count > 0

    except Exception as err:
        print(f"Error in processed dataset debug: {err}")
        traceback.print_exc()
        return False


debug_processed_dataset()

In [None]:
def debug_processor(processor=None) -> bool:
    """Smoke-test the processor on a text-only input and on an image+text input.

    Args:
        processor: An already-loaded processor to check (ideally the one the
            trainer uses). If None, a model/processor pair is loaded with
            ``setup_model_and_processor(CONFIG)``, which loads another copy of
            the model and may exhaust GPU memory when one is already resident.

    Returns:
        True if a processor is available and both smoke tests succeed, else False.
    """
    import traceback

    print("\n🔍 Debugging processor...")

    try:
        if processor is None:
            _, processor = setup_model_and_processor(CONFIG)
            print("✅ Model and processor loaded successfully")
        else:
            print("✅ Using the provided processor")

        print(f"Processor type: {type(processor)}")
        print(f"Tokenizer type: {type(processor.tokenizer)}")
        print(f"Has image processor: {hasattr(processor, 'image_processor')}")

        test_text = "Hello world"
        test_image = Image.new('RGB', (224, 224), color='white')

        # Training failed on a text-only batch, so exercise the text-only
        # path in addition to the image path.
        smoke_tests = {
            "text-only": {"text": [test_text]},
            "image+text": {"text": [test_text], "images": [[test_image]]},
        }

        all_passed = True
        for label, inputs in smoke_tests.items():
            try:
                batch = processor(**inputs, return_tensors="pt", padding=True)
                print(f"✅ Processor test ({label}) successful")
                print(f"Batch keys: {batch.keys()}")
            except Exception as e:
                print(f"❌ Processor test ({label}) failed: {e}")
                all_passed = False

        return all_passed

    except Exception as e:
        print(f"Error in processor debug: {e}")
        traceback.print_exc()
        return False


# Reuse the processor already loaded for training instead of loading a second model.
debug_processor(processor)

In [None]:
def _has_image_content(sample) -> bool:
    """Return True if a processed sample has any content item of type 'image'.

    Malformed samples (non-dict, missing keys) are treated as text-only so the
    count never aborts the debug run.
    """
    if not isinstance(sample, dict):
        return False
    for turn in sample.get("conversations") or []:
        if not isinstance(turn, dict):
            continue
        for item in turn.get("content") or []:
            if isinstance(item, dict) and item.get("type") == "image":
                return True
    return False


def debug_data_collator(
    processor=None,
    handle_text_only: bool = True,
    dataset_name: str = "ngohongthai/exam-sixth_grade-instruct-dataset",
    split: str = "train[:3]",
) -> bool:
    """Run HybridVisionDataCollator on a few freshly processed real samples.

    Args:
        processor: An already-loaded processor (ideally the one the trainer
            uses). If None, one is loaded with ``setup_model_and_processor(CONFIG)``,
            which loads another copy of the model.
        handle_text_only: Forwarded to ``HybridVisionDataCollator``; set it to
            the value used by the trainer's collator (configured in
            ``create_trainer``, not visible here) to reproduce training behavior.
        dataset_name: Hugging Face dataset id to sample from.
        split: Split expression selecting the samples to collate.

    Returns:
        True if the collator produced a batch, False otherwise.
    """
    import traceback

    print("\n🔍 Debugging data collator...")

    try:
        if processor is None:
            # Loads a full model just to obtain its processor; prefer passing one in.
            _, processor = setup_model_and_processor(CONFIG)

        collator = HybridVisionDataCollator(processor, handle_text_only=handle_text_only)

        raw_dataset = load_dataset(dataset_name, split=split)

        processed_samples = []
        num_fallbacks = 0
        for sample in raw_dataset:
            try:
                processed_samples.append(process_math_sample(sample))
            except Exception as e:
                print(f"Error processing sample: {e}")
                # Substitute a minimal text-only sample so the collator is still
                # exercised; tracked separately because it is not real data.
                num_fallbacks += 1
                processed_samples.append({
                    "conversations": [
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": "Test question"}]
                        },
                        {
                            "role": "assistant",
                            "content": [{"type": "text", "text": "Test answer"}]
                        }
                    ]
                })

        num_with_images = sum(_has_image_content(s) for s in processed_samples)
        print(f"Processed {len(processed_samples)} samples")
        print(
            f"  With images: {num_with_images}, "
            f"text-only: {len(processed_samples) - num_with_images}, "
            f"fallbacks: {num_fallbacks}"
        )

        try:
            batch = collator(processed_samples)
            print("✅ Data collator test successful!")
            print(f"Batch keys: {batch.keys()}")
            for key, value in batch.items():
                if hasattr(value, 'shape'):
                    print(f"  {key}: {value.shape}")
            return True

        except Exception as e:
            print(f"❌ Data collator failed: {e}")
            traceback.print_exc()
            return False

    except Exception as e:
        print(f"Error in data collator debug: {e}")
        traceback.print_exc()
        return False


# Reuse the processor already loaded for training instead of loading a second model.
debug_data_collator(processor)

In [None]:
def test_data_collator(train_dataset, processor, num_samples=2):
    """Test the data collator with a few samples to catch issues early."""
    print(f"\n🧪 Testing data collator with {num_samples} samples...")

    try:
        # Create data collator
        collator = HybridVisionDataCollator(processor, handle_text_only=True)

        # Test with a small batch
        test_samples = [train_dataset[i] for i in range(min(num_samples, len(train_dataset)))]

        print(f"Test samples prepared: {len(test_samples)}")

        # Try to collate
        batch = collator(test_samples)

        print("✅ Data collator test passed!")
        print(f"Batch keys: {batch.keys()}")
        for key, value in batch.items():
            if hasattr(value, 'shape'):
                print(f"  {key}: {value.shape}")
            else:
                print(f"  {key}: {type(value)}")

        return True

    except Exception as e:
        print(f"❌ Data collator test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


# Smoke-test collation of the first three processed training samples.
test_data_collator(train_dataset=train_dataset, processor=processor, num_samples=3)