In [1]:
%%capture
# Install Unsloth. Colab is detected by looking for an environment variable
# whose name contains "COLAB_"; everywhere else a plain `pip install` is used.
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth

In [2]:
%%capture
# Upgrade timm (needed only for Gemma 3N) and install the Comet ML client
!pip install --no-deps --upgrade timm # Only for Gemma 3N
!pip install comet-ml

## Load model

In [3]:
from unsloth import FastVisionModel # FastLanguageModel for LLMs
import torch

# Load the Gemma 3N E4B checkpoint; from_pretrained returns the model and its processor.
model, processor = FastVisionModel.from_pretrained(
    "unsloth/gemma-3n-E4B",
    load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


comet_ml is installed but the Comet API Key is not configured. Please set the `COMET_API_KEY` environment variable to enable Comet logging. Check out the documentation for other ways of configuring it: https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#set-the-api-key


🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.8.1: Fast Gemma3N patching. Transformers: 4.54.0.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3N does not support SDPA - switching to eager!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.72G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/1.15G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/191 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/98.0 [00:00<?, ?B/s]

preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.70M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/769 [00:00<?, ?B/s]

In [4]:
# Wrap the loaded model with LoRA adapters (rank 32, alpha 32, no dropout).
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers     = True, # False if not finetuning vision layers
    finetune_language_layers   = True, # False if not finetuning language layers
    finetune_attention_modules = True, # False if not finetuning attention layers
    finetune_mlp_modules       = True, # False if not finetuning MLP layers

    r = 32,                           # The larger, the higher the accuracy, but might overfit
    lora_alpha = 32,                  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,               # We support rank stabilized LoRA
    loftq_config = None,               # And LoftQ
    target_modules = "all-linear",    # Optional now! Can specify a list if needed
    # NOTE(review): presumably asks for lm_head / embed_tokens to be trained in
    # full next to the adapters — confirm this is honoured with 4-bit weights.
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)

Unsloth: Making `model.base_model.model.model.language_model` require gradients


## Data preparation

In [5]:
import re
import requests
from PIL import Image
from io import BytesIO
from typing import Tuple, List, Dict, Any, Optional

def url_to_image(url: str) -> Optional[Image.Image]:
    """
    Download an image and return it as an RGB ``PIL.Image``.

    Args:
        url: Address of the image to fetch.

    Returns:
        The decoded image converted to RGB, or ``None`` when the download or
        the decoding fails (the error is printed in that case).
    """
    try:
        # The request is bounded by a 10 second timeout; HTTP error statuses raise.
        reply = requests.get(url, timeout=10)
        reply.raise_for_status()
        return Image.open(BytesIO(reply.content)).convert("RGB")
    # RequestException derives from IOError, so it has to be handled first.
    except requests.exceptions.RequestException as download_error:
        print(f"Lỗi khi tải ảnh từ URL {url}: {download_error}")
    except IOError as decode_error:
        print(f"Lỗi khi mở file ảnh từ URL {url}: {decode_error}")
    return None

# Markdown image link: ![alt](target); group 1 is everything between the parentheses.
_MD_IMAGE_RE = re.compile(r"!\[.*?\]\((.*?)\)")
# Optional quoted title that Markdown allows after the URL: ![alt](url "title").
_MD_IMAGE_TITLE_RE = re.compile(r"""\s+(?:"[^"]*"|'[^']*')$""")


def extract_and_replace_images(markdown_text: str) -> Tuple[str, List[str]]:
    """
    Split Markdown image links out of a text.

    Every ``![alt](url)`` occurrence is replaced by a single space (no tag is
    inserted into the text) and its URL is collected. Surrounding whitespace
    and an optional quoted title (``![alt](url "title")``) are removed from the
    URL; links with an empty target are dropped from the URL list.

    Args:
        markdown_text: Markdown source text.

    Returns:
        ``(processed_text, image_urls)``: the text without image links and the
        URLs in order of appearance.
    """
    image_urls: List[str] = []

    def _collect(match):
        # Record the cleaned URL and tell re.sub what to put in the link's place.
        target = _MD_IMAGE_TITLE_RE.sub("", match.group(1).strip())
        if target:
            image_urls.append(target)
        return " "

    # One pass both collects the URLs and rewrites the text.
    processed_text = _MD_IMAGE_RE.sub(_collect, markdown_text)
    return processed_text, image_urls

def process_markdown_for_model(markdown_text: str) -> "Tuple[str, List[Image.Image]]":
    """
    Prepare a Markdown string for a multimodal model.

    Image links are removed from the text by ``extract_and_replace_images``
    and every extracted URL is downloaded with ``url_to_image``. URLs that
    cannot be downloaded or decoded are skipped, so the image list can be
    shorter than the number of links in the text.

    Args:
        markdown_text: Input Markdown text.

    Returns:
        A ``(text, images)`` tuple: the text without image links and the list
        of downloaded ``PIL.Image`` objects in order of appearance.
    """
    # Step 1: separate the text from the image URLs
    cleaned_text, image_urls = extract_and_replace_images(markdown_text)

    # Step 2: download the URLs, keeping only the ones that produced an image
    downloaded = (url_to_image(url) for url in image_urls)
    pil_images = [image for image in downloaded if image is not None]

    return cleaned_text, pil_images


In [6]:
from datasets import load_dataset
# Load the "train" and "test" splits of the sixth-grade exam instruction dataset.
train_dataset = load_dataset("ngohongthai/exam-sixth_grade-instruct-dataset", split = "train")
test_dataset = load_dataset("ngohongthai/exam-sixth_grade-instruct-dataset", split = "test")

README.md:   0%|          | 0.00/408 [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/245k [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/29.4k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1010 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/113 [00:00<?, ? examples/s]

In [7]:
def process_message(text, role):
  """
  Turn one Markdown string into a chat message for the given role.

  The text is passed through ``process_markdown_for_model``; the message
  content starts with the text entry, followed by one image entry per
  downloaded image.

  Args:
      text: Markdown text of the turn (may contain image links).
      role: Chat role stored in the message, e.g. "user" or "assistant".

  Returns:
      A dict with the keys "role" and "content".
  """
  processed_text, images = process_markdown_for_model(text)
  content = [{"type": "text", "text": processed_text}]
  # process_markdown_for_model returns a list; `or []` also tolerates None.
  content.extend({"type": "image", "image": image} for image in (images or []))
  return {"role": role, "content": content}


In [8]:
def process_conversation(sample):
  """
  Build a two-turn conversation from a dataset row.

  The row's "question" becomes the user turn and its "solution" the
  assistant turn; the result is wrapped under the "conversations" key.
  """
  turns = [
      process_message(sample[field], role)
      for field, role in (("question", "user"), ("solution", "assistant"))
  ]
  return {"conversations": turns}


In [9]:
# Pick one test row to inspect (the output below shows it contains an image link).
test_sample = test_dataset[6]
test_sample

{'question': 'Cho hình chữ nhật ABCD. Trên AB lấy điểm E sao cho AB = AE x 3, DB cắt EC tại G.\n\n![](https://img.loigiaihay.com/picture/2023/0728/18_1.png)\n\nBiết diện tích hình chữ nhật ABCD là 144',
 'solution': 'Nội dung này không có lời giải chi tiết.'}

In [10]:
# Convert the inspected row into the chat format and display the result.
test_conversation = process_conversation(test_sample)
test_conversation

{'conversations': [{'role': 'user',
   'content': [{'type': 'text',
     'text': 'Cho hình chữ nhật ABCD. Trên AB lấy điểm E sao cho AB = AE x 3, DB cắt EC tại G.\n\n \n\nBiết diện tích hình chữ nhật ABCD là 144'},
    {'type': 'image',
     'image': <PIL.Image.Image image mode=RGB size=345x241>}]},
  {'role': 'assistant',
   'content': [{'type': 'text',
     'text': 'Nội dung này không có lời giải chi tiết.'}]}]}

In [11]:
# Convert every training row. Each conversion downloads the images linked in
# that row, one row after another, so this cell performs network requests.
converted_train_dataset = [process_conversation(sample) for sample in train_dataset]
converted_train_dataset[0]

{'conversations': [{'role': 'user',
   'content': [{'type': 'text',
     'text': 'Tìm x biết $x + 3,8 = 3,5 \\times 1,5$'}]},
  {'role': 'assistant',
   'content': [{'type': 'text',
     'text': 'x + 3,8 = 3, 5 x 1, 5\nx + 3,8 = 5, 25\nx = 5, 25 - 3,8\nx = 1, 45'}]}]}

In [12]:
  # Display the second converted sample (it carries an image, see the output).
  converted_train_dataset[1]

{'conversations': [{'role': 'user',
   'content': [{'type': 'text',
     'text': 'Hình dưới có bao nhiêu hình vuông?\n '},
    {'type': 'image',
     'image': <PIL.Image.Image image mode=RGB size=165x132>}]},
  {'role': 'assistant',
   'content': [{'type': 'text',
     'text': 'Số hình vuông được tạo từ một ô vuông là 16 (hình)\nSố hình vuông được tạo từ 4 ô vuông là 9 (hình)\nSố hình vuông được tạo từ 9 ô vuông là 4 (hình)\nSố hình vuông được tại từ 16 ô vuông là 1 (hình)\n$ \\Rightarrow $ Có tất cả 16 + 9 + 4 + 1 = 30 (hình)'}]}]}

In [13]:
from unsloth import get_chat_template

# Rebind `processor` to the result of get_chat_template for the "gemma-3n" template.
processor = get_chat_template(
    processor,
    "gemma-3n"
)

## Training

In [14]:
import os

# Prefer a key that is already exported so the notebook also runs outside Colab
# (the install cell explicitly supports non-Colab environments).
comet_ml_api_key = os.environ.get('COMET_API_KEY')
if not comet_ml_api_key:
    try:
        from google.colab import userdata
    except ImportError as exc:
        raise RuntimeError(
            "COMET_API_KEY is not set and google.colab is unavailable; "
            "export COMET_API_KEY before running this cell."
        ) from exc

    # Load the Comet.ml API key from Colab secrets
    comet_ml_api_key = userdata.get('COMET_API_KEY')
    if not comet_ml_api_key:
        raise RuntimeError("The Colab secret COMET_API_KEY is empty.")

# Set the COMET_API_KEY environment variable
os.environ['COMET_API_KEY'] = comet_ml_api_key

In [15]:
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig

FastVisionModel.for_training(model) # Enable for training!

# NOTE(review): the training run below stops with "ValueError: Invalid input
# type. Must be a single image, ..." — possibly because some converted samples
# carry no image; confirm against the collator's expectations.
trainer = SFTTrainer(
    model=model,
    train_dataset=converted_train_dataset,
    processing_class=processor.tokenizer,
    data_collator=UnslothVisionDataCollator(model, processor),
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        gradient_checkpointing = True,

        # use non-reentrant checkpointing (use_reentrant is False)
        gradient_checkpointing_kwargs = {"use_reentrant": False},
        max_grad_norm = 0.3,              # max gradient norm based on QLoRA paper
        warmup_ratio = 0.03,
        max_steps = 60,
        #num_train_epochs = 2,          # Set this instead of max_steps for full training runs
        learning_rate = 2e-4,
        logging_steps = 1,
        # NOTE(review): no save_steps is given here — confirm the default
        # interval produces a checkpoint within max_steps = 60.
        save_strategy="steps",
        optim = "adamw_torch_fused",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs",
        report_to = "comet_ml",             # Log metrics to Comet ML

        # You MUST put the below items for vision finetuning:
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        max_length = 2048,
    )
)

Unsloth: Model does not have a default image size - using 512


In [16]:
# Run the fine-tuning loop and keep the value returned by train().
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,010 | Num Epochs = 1 | Total steps = 60
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4
 "-____-"     Trainable parameters = 76,840,960 of 7,926,819,152 (0.97% trained)
[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/mathpal/general/19460f5fcaf44a248b80b2ca65a104a5

[1;38;5;39mCOMET INFO:[0m Couldn't find a Git repository in '/content' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.


ValueError: Invalid input type. Must be a single image, a list of images, or a list of batches of images.

In [None]:
# Debug and test candidate fixes for the collator error
import torch
from datasets import Dataset

# 1. Check that the processor looks as expected
print("Processor type:", type(processor))
print("Processor tokenizer:", type(processor.tokenizer))
print("Has image processor:", hasattr(processor, 'image_processor'))

# 2. Test manual collate function
def debug_collate_fn(examples):
    """
    Collate `examples` step by step, printing diagnostics after each stage.

    Uses the notebook-level `processor`. Returns the processor batch with a
    "labels" entry added (a copy of "input_ids" with pad positions set to
    -100), or None when any stage raises — errors are printed, not re-raised.
    """
    print(f"Input examples type: {type(examples)}")
    print(f"Number of examples: {len(examples)}")

    if len(examples) > 0:
        print(f"First example keys: {examples[0].keys()}")

    # Test each step in turn
    try:
        # Step 1: Extract conversations
        conversations_list = [example["conversations"] for example in examples]
        print(f"Extracted {len(conversations_list)} conversations")

        # Step 2: Test apply_chat_template
        texts = []
        images_list = []

        for conv in conversations_list:
            try:
                # Test chat template
                text = processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
                texts.append(text)
                print(f"Chat template success, text length: {len(text)}")

                # Extract images
                images = []
                for message in conv:
                    for content in message.get("content", []):
                        if content.get("type") == "image" and "image" in content:
                            images.append(content["image"])

                images_list.append(images)
                print(f"Extracted {len(images)} images from conversation")

            except Exception as e:
                # A failure on any conversation aborts the whole batch
                print(f"Error processing conversation: {e}")
                return None

        # Step 3: Test processor
        print("Testing processor with texts and images...")
        batch = processor(
            text=texts,
            images=images_list,
            return_tensors="pt",
            padding=True
        )

        print("Processor success!")
        print("Batch keys:", batch.keys())
        for key, value in batch.items():
            if hasattr(value, 'shape'):
                print(f"{key} shape: {value.shape}")

        # Build labels: copy of input_ids with pad tokens masked out
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        return batch

    except Exception as e:
        print(f"Error in collate function: {e}")
        import traceback
        traceback.print_exc()
        return None

# 3. Test with a single sample
print("\n=== Testing with single sample ===")
# Later cells call Dataset methods (e.g. .filter) on this name, so the list is
# converted to a datasets.Dataset once here.
if isinstance(converted_train_dataset, list):
    converted_train_dataset = Dataset.from_list(converted_train_dataset)

test_batch = [converted_train_dataset[0]]
result = debug_collate_fn(test_batch)

# 4. Look through a few more samples to find one that has images
print("\n=== Checking multiple samples for images ===")
for i in range(min(5, len(converted_train_dataset))):
    sample = converted_train_dataset[i]
    image_count = sum(
        1
        for message in sample["conversations"]
        for content in message.get("content", [])
        if content.get("type") == "image"
    )
    print(f"Sample {i}: {image_count} images")

    if image_count > 0:
        print(f"Testing sample {i} with images...")
        try:
            result_with_images = debug_collate_fn([sample])
        except Exception as e:
            print(f"Sample {i} failed: {e}")
            continue
        # debug_collate_fn reports failures by returning None instead of
        # raising, so success must be decided from the return value.
        if result_with_images is None:
            print(f"Sample {i} failed: see the error printed above")
        else:
            print(f"Sample {i} processed successfully!")
            break

print("\n=== Root cause analysis ===")
print("Vấn đề: Processor nhận empty list images_list = [[]] thay vì format hợp lệ")
print("Giải pháp: Cần handle trường hợp empty images đúng cách")


In [None]:
# Fix attempt 1: a custom collate function instead of UnslothVisionDataCollator
def custom_vision_collate_fn(examples):
    """
    Collate chat samples into a processor batch, tolerating text-only samples.

    Uses the notebook-level `processor`. For every example the non-null images
    of its "conversations" are collected. A sample without any image gets one
    white 224x224 placeholder image AND a matching ``{"type": "image"}`` entry
    in its first user turn (first turn if there is no user turn), so that the
    templated text and the image list stay in sync.

    Args:
        examples: Sequence of dicts, each with a "conversations" list of
            messages ({"role": ..., "content": [...]}).

    Returns:
        The processor output with an added "labels" entry: a copy of
        "input_ids" in which pad-token positions are set to -100.

    Raises:
        Any exception from templating or processing, after printing it.
    """
    from PIL import Image  # local import keeps the function self-contained

    def _images_in(conversation):
        """Return the non-null images of a conversation, in order."""
        return [
            content["image"]
            for message in conversation
            for content in message.get("content", [])
            if content.get("type") == "image" and content.get("image") is not None
        ]

    def _with_image_slot(conversation):
        """Return a copy of the conversation with one image entry prepended to one turn."""
        conversation = list(conversation)
        if not conversation:
            return conversation
        user_turns = [i for i, message in enumerate(conversation) if message["role"] == "user"]
        slot = user_turns[0] if user_turns else 0
        target = conversation[slot]
        conversation[slot] = {
            **target,
            "content": [{"type": "image"}] + list(target["content"]),
        }
        return conversation

    try:
        texts = []
        images_list = []

        for example in examples:
            conversation = example["conversations"]
            images = _images_in(conversation)

            if not images:
                # Text-only sample: add a placeholder image together with its
                # slot in the text, otherwise text and images disagree.
                conversation = _with_image_slot(conversation)
                images = [Image.new('RGB', (224, 224), color='white')]

            texts.append(
                processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
            )
            images_list.append(images)

        # NOTE(review): truncation at 2048 tokens could in principle cut into an
        # image placeholder sequence — confirm samples stay below that length.
        batch = processor(
            text=texts,
            images=images_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        )

        # Labels: copy of input_ids with padding excluded from the loss
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        return batch

    except Exception as e:
        print(f"Error in custom collate function: {e}")
        import traceback
        traceback.print_exc()
        raise

# Test custom collate function
print("\n=== Testing custom collate function ===")
try:
    test_result = custom_vision_collate_fn([converted_train_dataset[0]])
    print("Custom collate function works!")
    print("Keys:", test_result.keys())
    for key, value in test_result.items():
        if hasattr(value, 'shape'):
            print(f"{key} shape: {value.shape}")
except Exception as e:
    print(f"Custom collate function failed: {e}")

# Test with multiple samples
print("\n=== Testing custom collate function with multiple samples ===")
try:
    # Test with 2 samples
    test_batch_multi = [converted_train_dataset[0], converted_train_dataset[1]]
    test_result_multi = custom_vision_collate_fn(test_batch_multi)
    print("Custom collate function works with multiple samples!")
    print("Keys:", test_result_multi.keys())
    for key, value in test_result_multi.items():
        if hasattr(value, 'shape'):
            print(f"{key} shape: {value.shape}")
except Exception as e:
    print(f"Custom collate function with multiple samples failed: {e}")
    import traceback
    traceback.print_exc()


In [None]:
# Fix attempt 2: run the trainer with the custom collate function
print("\n=== Testing trainer with custom collate function ===")

# NOTE(review): this smoke test calls train() on the shared `model`, so
# presumably the adapter weights are already updated for 3 steps before any
# later run — confirm that this is acceptable.
try:
    # Create a new trainer that uses the custom collate function
    trainer_custom = SFTTrainer(
        model=model,
        train_dataset=converted_train_dataset,
        processing_class=processor.tokenizer,
        data_collator=custom_vision_collate_fn,  # Replaces UnslothVisionDataCollator
        args = SFTConfig(
            per_device_train_batch_size = 1,
            gradient_accumulation_steps = 4,
            gradient_checkpointing = True,
            gradient_checkpointing_kwargs = {"use_reentrant": False},
            max_grad_norm = 0.3,
            warmup_ratio = 0.03,
            max_steps = 3,  # Only 3 steps for the test
            learning_rate = 2e-4,
            logging_steps = 1,
            save_strategy="no",  # Do not save checkpoints while testing
            optim = "adamw_torch_fused",
            weight_decay = 0.01,
            lr_scheduler_type = "cosine",
            seed = 3407,
            output_dir = "outputs",
            report_to = "none",  # Disable logging while testing

            # Vision finetuning requirements
            remove_unused_columns = False,
            dataset_text_field = "",
            dataset_kwargs = {"skip_prepare_dataset": True},
            max_length = 2048,
        )
    )

    print("Trainer with custom collate function created successfully!")

    # Test training
    print("Starting test training...")
    trainer_stats = trainer_custom.train()
    print("Training with custom collate function successful!")

except Exception as e:
    print(f"Error with custom collate function trainer: {e}")
    import traceback
    traceback.print_exc()


In [None]:
# Fix attempt 3: inspect UnslothVisionDataCollator and try a replacement
print("\n=== Debugging UnslothVisionDataCollator ===")

# Create an UnslothVisionDataCollator instance and test it
try:
    unsloth_collator = UnslothVisionDataCollator(model, processor)
    print("UnslothVisionDataCollator created successfully")

    # Test with 1 sample
    print("Testing UnslothVisionDataCollator with 1 sample...")
    test_batch_unsloth = unsloth_collator([converted_train_dataset[0]])
    print("UnslothVisionDataCollator works!")
    print("Keys:", test_batch_unsloth.keys())

except Exception as e:
    print(f"UnslothVisionDataCollator failed: {e}")
    import traceback
    traceback.print_exc()

    print("\n=== Trying alternative approach ===")
    # If UnslothVisionDataCollator fails, try a different approach.
    # Everything below, including the class definition, only runs in that case.

    # 1. Inspect the processor configuration
    print("Processor configuration:")
    print(f"Tokenizer type: {type(processor.tokenizer)}")
    print(f"Tokenizer vocab size: {processor.tokenizer.vocab_size}")

    if hasattr(processor, 'image_processor'):
        print(f"Image processor type: {type(processor.image_processor)}")
    else:
        print("No image processor found!")

    # 2. Try a manual replacement for UnslothVisionDataCollator
    class FixedUnslothVisionDataCollator:
        """Callable collator that templates conversations and runs the processor on text plus images."""

        def __init__(self, model, processor):
            self.model = model
            self.processor = processor

        def __call__(self, examples):
            """Return the processor batch for `examples` with a "labels" entry added; errors are printed and re-raised."""
            try:
                # Same processing as the custom collate function
                conversations_list = [example["conversations"] for example in examples]

                texts = []
                images_list = []

                for conv in conversations_list:
                    text = self.processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
                    texts.append(text)

                    images = []
                    for message in conv:
                        for content in message.get("content", []):
                            if content.get("type") == "image" and "image" in content:
                                img = content["image"]
                                if img is not None:
                                    # Make sure the image is in the expected mode
                                    if hasattr(img, 'convert'):
                                        img = img.convert('RGB')
                                    images.append(img)

                    # Handle samples without any image
                    # NOTE(review): the placeholder image is added after the text
                    # was templated, so the text has no matching image entry —
                    # confirm the processor/model accept that.
                    if not images:
                        from PIL import Image
                        placeholder = Image.new('RGB', (224, 224), color='white')
                        images = [placeholder]

                    images_list.append(images)

                # Run the processor
                batch = self.processor(
                    text=texts,
                    images=images_list,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=2048
                )

                # Build labels: copy of input_ids with pad tokens set to -100
                labels = batch["input_ids"].clone()
                labels[labels == self.processor.tokenizer.pad_token_id] = -100
                batch["labels"] = labels

                return batch

            except Exception as e:
                print(f"FixedUnslothVisionDataCollator error: {e}")
                raise e

    # Test fixed collator
    print("\nTesting FixedUnslothVisionDataCollator...")
    fixed_collator = FixedUnslothVisionDataCollator(model, processor)
    test_fixed = fixed_collator([converted_train_dataset[0]])
    print("FixedUnslothVisionDataCollator works!")
    print("Keys:", test_fixed.keys())


In [None]:
# FINAL SOLUTION: trainer that uses the custom collate function
print("\n=== FINAL SOLUTION: Training với fixed custom collate function ===")

# Make sure the model is switched to training mode
FastVisionModel.for_training(model)

# Create the trainer with the custom collate function
# NOTE(review): confirm that custom_vision_collate_fn keeps the image entries
# in the templated text in sync with the placeholder images it adds;
# fixed_vision_collate_fn (defined in a later cell) does this explicitly.
trainer_final = SFTTrainer(
    model=model,
    train_dataset=converted_train_dataset,
    processing_class=processor.tokenizer,
    data_collator=custom_vision_collate_fn,  # Custom collate function with placeholder handling
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        gradient_checkpointing = True,
        gradient_checkpointing_kwargs = {"use_reentrant": False},
        max_grad_norm = 0.3,
        warmup_ratio = 0.03,
        max_steps = 60,  # Full (60-step) training run
        learning_rate = 2e-4,
        logging_steps = 1,
        save_strategy="steps",
        save_steps=20,
        optim = "adamw_torch_fused",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs",
        report_to = "comet_ml",

        # Vision finetuning requirements
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        max_length = 2048,
    )
)

print("Final trainer created successfully!")
print("Ready to start training with fixed collate function.")
print("Run: trainer_final.train() to start training")


In [None]:
# FIX: custom collate function that keeps image entries and images in sync
def fixed_vision_collate_fn(examples):
    """
    Collate chat samples so that templated text and images always match.

    Uses the notebook-level `processor`. Images found in a conversation are
    converted to RGB and passed along unchanged. A conversation without any
    image receives one white 224x224 placeholder image and exactly one
    ``{"type": "image"}`` entry, prepended to its first user turn, before the
    chat template is applied.

    Args:
        examples: Sequence of dicts, each with a "conversations" list of
            messages ({"role": ..., "content": [...]}).

    Returns:
        The processor output with an added "labels" entry: a copy of
        "input_ids" in which pad-token positions are set to -100.

    Raises:
        Any exception from templating or processing, after printing it.
    """
    from PIL import Image  # local import keeps the function self-contained

    # String counted by the debug print below. Presumably the processor exposes
    # the placeholder emitted by the chat template as `image_token` — TODO
    # confirm; the literal "<image>" is kept as the fallback.
    image_marker = getattr(processor, "image_token", None) or "<image>"

    try:
        texts = []
        images_list = []

        for example in examples:
            conv = example["conversations"]

            # Collect the images first
            images = []
            for message in conv:
                for content in message.get("content", []):
                    if content.get("type") == "image" and "image" in content:
                        img = content["image"]
                        if img is not None:
                            if hasattr(img, 'convert'):
                                img = img.convert('RGB')
                            images.append(img)

            if not images:
                # Text-only sample: one placeholder image needs exactly one
                # image entry, so only the first user turn is modified.
                images = [Image.new('RGB', (224, 224), color='white')]
                patched_conv = []
                slot_added = False
                for message in conv:
                    if not slot_added and message["role"] == "user":
                        message = {
                            "role": "user",
                            "content": [{"type": "image"}] + message["content"],
                        }
                        slot_added = True
                    patched_conv.append(message)
                conv = patched_conv

            texts.append(
                processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
            )
            images_list.append(images)

        # Debug: number of images versus image markers in the templated text
        for i, text in enumerate(texts):
            print(f"Sample {i}: {len(images_list[i])} images, {text.count(image_marker)} image tokens in text")

        # Run the processor
        batch = processor(
            text=texts,
            images=images_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        )

        # Labels: copy of input_ids with padding excluded from the loss
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        return batch

    except Exception as e:
        print(f"Error in fixed collate function: {e}")
        import traceback
        traceback.print_exc()
        raise

# Test fixed function on the first sample and print the shape of every batch entry that has one
print("=== Testing fixed collate function ===")
try:
    test_fixed = fixed_vision_collate_fn([converted_train_dataset[0]])
    print("Fixed collate function works!")
    print("Keys:", test_fixed.keys())
    for key, value in test_fixed.items():
        if hasattr(value, 'shape'):
            print(f"{key} shape: {value.shape}")
except Exception as e:
    # The failure is only reported; the notebook keeps running
    print(f"Fixed collate function failed: {e}")


In [None]:
# Phương pháp thay thế: Filter dataset chỉ lấy samples có images
print("=== Alternative: Filter dataset to only include samples with images ===")

def has_images(example):
    """Return True when at least one message of the example carries a real image.

    An image counts as real when a content item has ``"type" == "image"`` and a
    non-None ``"image"`` value.

    Args:
        example: Mapping with a ``"conversations"`` entry holding a list of
            message dicts. A message's ``"content"`` is normally a list of
            item dicts; a missing, ``None`` or plain-string (text-only)
            content is treated as containing no image instead of raising.

    Returns:
        bool: True if any message contains a real image, otherwise False.

    Raises:
        KeyError: If ``example`` has no ``"conversations"`` entry.
    """
    for message in example["conversations"] or []:
        content = message.get("content")
        if not isinstance(content, (list, tuple)):
            # Text-only (string) or absent content cannot hold an image item.
            continue
        for item in content:
            if (
                isinstance(item, dict)
                and item.get("type") == "image"
                and item.get("image") is not None
            ):
                return True
    return False

# Keep only the samples for which has_images() returns True
filtered_dataset = converted_train_dataset.filter(has_images)
print(f"Original dataset size: {len(converted_train_dataset)}")
print(f"Filtered dataset size (with images): {len(filtered_dataset)}")

if len(filtered_dataset) == 0:
    print("No samples with images found in dataset!")
else:
    print("Testing with filtered dataset...")

    # Minimal collate function for samples that are known to contain images
    def simple_vision_collate_fn(examples):
        """Collate image-bearing samples into a single processor batch.

        Each sample's "conversations" list is rendered with the chat template,
        its non-None images are gathered (converted to RGB when the object
        offers ``convert``), and both are passed to the processor. The result
        gains a "labels" entry: a copy of "input_ids" with pad positions set
        to -100. Any failure is printed with its traceback and re-raised.
        """
        try:
            conversations = [sample["conversations"] for sample in examples]

            rendered_texts = []
            per_sample_images = []

            for conversation in conversations:
                # Render the conversation to plain text with the chat template
                rendered_texts.append(
                    processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
                )

                # Gather this conversation's images in message order
                per_sample_images.append([
                    part["image"].convert('RGB') if hasattr(part["image"], 'convert') else part["image"]
                    for message in conversation
                    for part in message.get("content", [])
                    if part.get("type") == "image" and "image" in part and part["image"] is not None
                ])

            # Tokenize the texts and preprocess the images together
            batch = processor(
                text=rendered_texts,
                images=per_sample_images,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048
            )

            # Labels mirror the input ids, with padding excluded via -100
            labels = batch["input_ids"].clone()
            labels[labels == processor.tokenizer.pad_token_id] = -100
            batch["labels"] = labels

            return batch

        except Exception as e:
            print(f"Error in simple collate function: {e}")
            import traceback
            traceback.print_exc()
            raise e

    # Smoke test on the first filtered sample
    try:
        test_simple = simple_vision_collate_fn([filtered_dataset[0]])
        print("Simple collate function works with filtered dataset!")
        print("Keys:", test_simple.keys())
        for key, value in test_simple.items():
            if not hasattr(value, 'shape'):
                continue
            print(f"{key} shape: {value.shape}")
    except Exception as e:
        print(f"Simple collate function failed: {e}")


In [None]:
# FINAL TRAINERS - pick ONE of the two approaches below

print("=== CHOOSE YOUR TRAINING APPROACH ===")
print("1. Use fixed_vision_collate_fn (handles all samples, adds placeholders)")
print("2. Use filtered dataset + simple_vision_collate_fn (only image samples)")


def _make_option_sft_config():
    """Build the training arguments shared by option 1 and option 2.

    Both options previously repeated this literal verbatim; keeping it in one
    place prevents the two copies from drifting apart. A new SFTConfig is
    returned on every call so the two trainers never share one instance.
    """
    return SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        gradient_checkpointing = True,
        gradient_checkpointing_kwargs = {"use_reentrant": False},
        max_grad_norm = 0.3,
        warmup_ratio = 0.03,
        max_steps = 60,
        learning_rate = 2e-4,
        logging_steps = 1,
        save_strategy = "steps",
        save_steps = 20,
        optim = "adamw_torch_fused",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs",
        report_to = "comet_ml",
        # The custom collators read the raw "conversations" column themselves,
        # so the trainer must keep all columns and skip dataset preparation.
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        max_length = 2048,
    )


# OPTION 1: trainer with the fixed collate function (handles every sample)
print("\n=== OPTION 1: Trainer with fixed collate function ===")
try:
    trainer_option1 = SFTTrainer(
        model=model,
        train_dataset=converted_train_dataset,  # Full dataset
        processing_class=processor.tokenizer,
        data_collator=fixed_vision_collate_fn,  # Fixed function
        args=_make_option_sft_config(),
    )
    print("Option 1 trainer created successfully!")
except Exception as e:
    print(f"Option 1 failed: {e}")

# OPTION 2: trainer with the filtered dataset (only samples that have images)
if 'filtered_dataset' in locals() and len(filtered_dataset) > 0:
    print("\n=== OPTION 2: Trainer with filtered dataset ===")
    try:
        trainer_option2 = SFTTrainer(
            model=model,
            train_dataset=filtered_dataset,  # Filtered dataset
            processing_class=processor.tokenizer,
            data_collator=simple_vision_collate_fn,  # Simple function
            args=_make_option_sft_config(),
        )
        print("Option 2 trainer created successfully!")
    except Exception as e:
        print(f"Option 2 failed: {e}")

print("\n=== RECOMMENDATIONS ===")
print("- If you want to use ALL data: Use trainer_option1.train()")
print("- If you want ONLY image samples: Use trainer_option2.train()")
print("- Option 2 is recommended if most of your dataset has images")


In [None]:
# ULTIMATE FIX: disable gradient checkpointing and fix the image token issue
print("=== ULTIMATE FIX: Disable gradient checkpointing ===")

# 1. Render one conversation through the chat template to look for the image token
print("Debug chat template:")
sample_conv = converted_train_dataset[0]["conversations"]
template_result = processor.apply_chat_template(
    sample_conv,
    tokenize=False,
    add_generation_prompt=False,
)
print("Template result:", template_result[:200] + "...")
print("Contains <image>:", "<image>" in template_result)

# 2. Show the tokenizer's special tokens, plus the begin/end-of-image tokens
#    when the tokenizer exposes those attributes
print("\nProcessor special tokens:")
print("Special tokens map:", processor.tokenizer.special_tokens_map)
for token_attr, token_label in (("boi_token", "BOI token:"), ("eoi_token", "EOI token:")):
    if hasattr(processor.tokenizer, token_attr):
        print(token_label, getattr(processor.tokenizer, token_attr))

# 3. Tạo final collate function với explicit image token handling
def _collect_rgb_images(conv):
    """Return the non-None images of a conversation, converted to RGB when possible."""
    images = []
    for message in conv:
        content = message.get("content")
        if not isinstance(content, (list, tuple)):
            # String or missing content cannot carry image items.
            continue
        for item in content:
            if item.get("type") == "image" and item.get("image") is not None:
                img = item["image"]
                images.append(img.convert('RGB') if hasattr(img, 'convert') else img)
    return images


def _with_placeholder_image_item(conv):
    """Return a copy of an image-less conversation whose first user turn starts with an image item."""
    if not conv:
        raise ValueError("Cannot attach a placeholder image to an empty conversation")

    # Prefer the first user turn; fall back to the first message when no
    # message has the "user" role.
    target = next(
        (idx for idx, message in enumerate(conv) if message.get("role") == "user"),
        0,
    )

    patched = []
    for idx, message in enumerate(conv):
        content = message.get("content")
        if isinstance(content, str):
            items = [{"type": "text", "text": content}] if idx == target else content
        else:
            # The conversation has no real image, so any image item left in it
            # would render a placeholder that has no matching image.
            items = [item for item in (content or []) if item.get("type") != "image"]
        if idx == target:
            items = [{"type": "image"}] + items
        patched.append({**message, "content": items})
    return patched


def ultimate_vision_collate_fn(examples):
    """Collate conversation samples into one processor batch, padding image-less samples.

    A conversation without a real image receives a white 224x224 placeholder
    so that every sample carries at least one image. The placeholder is
    announced to the chat template as an ``{"type": "image"}`` content item on
    the first user turn, so the template itself emits the image placeholder
    token. This replaces splicing a hard-coded ``"<image>"`` string into the
    rendered text at the first occurrence of the substring "user", which
    depends on the template layout and on the processor's token spelling.

    Args:
        examples: Samples that each provide a ``"conversations"`` list of
            message dicts (``"role"`` plus a ``"content"`` list of items).

    Returns:
        The processor output (``return_tensors="pt"``, padded, truncated to
        2048 tokens) with an added ``"labels"`` entry: a copy of
        ``"input_ids"`` where pad-token positions are set to -100.

    Raises:
        Exception: Any failure is printed together with its traceback and
            then re-raised unchanged.
    """
    try:
        conversations_list = [example["conversations"] for example in examples]

        texts = []
        images_list = []
        placeholder = None  # created lazily, at most once per batch

        for conv in conversations_list:
            images = _collect_rgb_images(conv)

            if not images:
                if placeholder is None:
                    from PIL import Image
                    placeholder = Image.new('RGB', (224, 224), color='white')
                images = [placeholder]
                conv = _with_placeholder_image_item(conv)

            texts.append(
                processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
            )
            images_list.append(images)

        # Tokenize the texts and preprocess the images together
        batch = processor(
            text=texts,
            images=images_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        )

        # Labels mirror the input ids, with padding excluded from the loss.
        # NOTE(review): image placeholder tokens are still part of the labels;
        # confirm whether they should be masked with -100 as well.
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        return batch

    except Exception as e:
        print(f"Error in ultimate collate function: {e}")
        import traceback
        traceback.print_exc()
        raise

# Smoke test: run the ultimate collate function on the first training example
# and report the shape of every returned value that exposes one.
print("\n=== Testing ultimate collate function ===")
try:
    test_ultimate = ultimate_vision_collate_fn([converted_train_dataset[0]])
    print("Ultimate collate function works!")
    print("Keys:", test_ultimate.keys())
    for key, value in test_ultimate.items():
        if not hasattr(value, 'shape'):
            continue
        print(f"{key} shape: {value.shape}")
except Exception as e:
    # Best-effort check: report the failure message and let the notebook continue.
    print(f"Ultimate collate function failed: {e}")


In [None]:
# FINAL WORKING TRAINER: turn gradient checkpointing off to avoid the error
print("=== FINAL WORKING TRAINER ===")

# Unsloth's for_training() takes use_gradient_checkpointing, which defaults to
# True and switches checkpointing back on for the model. Left at its default,
# it would undo the gradient_checkpointing = False setting below, so it is
# disabled explicitly here as well.
# NOTE(review): confirm against the installed Unsloth version; expect higher
# memory use with checkpointing off.
FastVisionModel.for_training(model, use_gradient_checkpointing = False)

# Trainer with gradient checkpointing turned OFF
trainer_final_working = SFTTrainer(
    model=model,
    train_dataset=converted_train_dataset,
    processing_class=processor.tokenizer,
    data_collator=ultimate_vision_collate_fn,  # Ultimate fixed function
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,

        # Never enable gradient checkpointing here: OFF to avoid CheckpointError
        gradient_checkpointing = False,

        max_grad_norm = 0.3,
        warmup_ratio = 0.03,
        max_steps = 60,
        learning_rate = 2e-4,
        logging_steps = 1,
        save_strategy="steps",
        save_steps=20,
        optim = "adamw_torch_fused",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs",
        report_to = "comet_ml",

        # Vision finetuning requirements: the collator reads the raw
        # "conversations" column, so keep all columns and skip dataset prep.
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        max_length = 2048,
    )
)

print("Final working trainer created successfully!")
print("Key changes:")
print("- gradient_checkpointing = False (tránh CheckpointError)")
print("- ultimate_vision_collate_fn (fix image tokens)")
print("- Ready to train: trainer_final_working.train()")


In [None]:
# ALTERNATIVE: if problems persist, try UnslothVisionDataCollator with gradient checkpointing off
print("=== ALTERNATIVE: UnslothVisionDataCollator without gradient checkpointing ===")

# Retry the stock UnslothVisionDataCollator, this time with gradient checkpointing disabled
try:
    # Prefer the image-only subset when an earlier cell produced a non-empty one
    if 'filtered_dataset' in locals() and len(filtered_dataset) > 0:
        alt_train_dataset = filtered_dataset
    else:
        alt_train_dataset = converted_train_dataset

    trainer_unsloth_no_gc = SFTTrainer(
        model=model,
        train_dataset=alt_train_dataset,
        processing_class=processor.tokenizer,
        data_collator=UnslothVisionDataCollator(model, processor),  # Original UnslothVisionDataCollator
        args=SFTConfig(
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,

            # Gradient checkpointing OFF
            gradient_checkpointing=False,

            max_grad_norm=0.3,
            warmup_ratio=0.03,
            max_steps=60,
            learning_rate=2e-4,
            logging_steps=1,
            save_strategy="steps",
            save_steps=20,
            optim="adamw_torch_fused",
            weight_decay=0.01,
            lr_scheduler_type="cosine",
            seed=3407,
            output_dir="outputs",
            report_to="comet_ml",

            # Vision requirements
            remove_unused_columns=False,
            dataset_text_field="",
            dataset_kwargs={"skip_prepare_dataset": True},
            max_length=2048,
        ),
    )
    print("Alternative trainer with UnslothVisionDataCollator created!")
    print("- Uses filtered dataset if available")
    print("- gradient_checkpointing = False")
    print("- Use: trainer_unsloth_no_gc.train()")

except Exception as e:
    print(f"Alternative trainer failed: {e}")

print("\n=== FINAL RECOMMENDATIONS ===")
print("1. FIRST TRY: trainer_final_working.train() - custom collate + no gradient checkpointing")
print("2. IF FAILS: trainer_unsloth_no_gc.train() - original UnslothVisionDataCollator + no gradient checkpointing")
print("3. Key issue was gradient_checkpointing = True causing tensor shape conflicts")
print("4. Memory usage will be higher without gradient checkpointing, but should work")
