# In [None]:
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
#     "trl @ git+https://github.com/huggingface/trl.git",
#     "Pillow>=9.4.0",
# ]
# ///

"""
pip install pillow

# Tested on 8x H100 GPUs
accelerate launch \
    --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/sft_vlm.py \
    --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
    --model_name_or_path llava-hf/llava-1.5-7b-hf \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 8 \
    --output_dir sft-llava-1.5-7b-hf \
    --bf16 \
    --torch_dtype bfloat16 \
    --gradient_checkpointing

For LLaVA-NeXT, use: (requires transformers>=4.45)
    --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf

For meta-llama/Llama-3.2-11B-Vision-Instruct, use: (requires transformers>=4.45.1)
    --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct
"""

import torch
from datasets import load_from_disk
from transformers import AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration
import os
from PIL import Image
import io
import json
from transformers import EarlyStoppingCallback
import trl 
                             

from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)

# Explicit CPU device handle.
# NOTE(review): `device` is not referenced anywhere else in the visible code - confirm it is still needed.
device = torch.device("cpu")

# Module-level counter for examples that `collate_fn` skips because their
# rendered prompt does not contain the response marker.
error_cnt  = 0 

if __name__ == "__main__":
    # Build the three TRL argument dataclasses.
    # NOTE(review): an empty `args` list is passed explicitly, so command-line
    # flags are presumably not read here (notebook usage?) - confirm.
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config(args=[])

    # Overrides applied on top of whatever the parser produced: the custom
    # collator defined further down works on raw rows, so dataset preparation
    # is skipped and unused columns are kept.
    training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    ################
    # Model, Tokenizer & Processor
    ################
    # "auto" and None are forwarded untouched; any other value is looked up as
    # an attribute of the `torch` module.
    if model_args.torch_dtype in ("auto", None):
        torch_dtype = model_args.torch_dtype
    else:
        torch_dtype = getattr(torch, model_args.torch_dtype)

    quantization_config = get_quantization_config(model_args)
    model_kwargs = {
        "revision": model_args.model_revision,
        "attn_implementation": model_args.attn_implementation,
        "torch_dtype": torch_dtype,
        # A device map is only requested when a quantization config exists.
        "device_map": None if quantization_config is None else get_kbit_device_map(),
        "quantization_config": quantization_config,
    }

    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
    )

    model = AutoModelForVision2Seq.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )

    ################
    # Create a data collator to encode text and image pairs
    ################
    # String prepended to assistant text by the collator so the start of the
    # response can be located in the tokenized sequence.
    RESPONSE_MARKER = "<|startofassistant|>"

    def extract_image(example):
        """Return the first image found in a user turn of ``example``, or ``None``.

        ``example["messages"]`` may be a list of turns or a JSON string that
        decodes to one. The first content item of type ``"image"`` inside a
        ``"user"`` turn decides the result:

        * a dict holding ``"bytes"`` is decoded with PIL and converted to RGB;
        * a string is opened with PIL as a path and converted to RGB;
        * any object exposing a ``size`` attribute is returned as-is;
        * anything else yields ``None``.

        Any exception is printed and turned into ``None``.
        """
        try:
            conversation = example["messages"]
            if isinstance(conversation, str):
                conversation = json.loads(conversation)

            for turn in conversation:
                if turn["role"] != "user":
                    continue
                for part in turn["content"]:
                    if part["type"] != "image":
                        continue
                    payload = part["image"]
                    # Raw encoded bytes (dict with a "bytes" entry).
                    if isinstance(payload, dict) and "bytes" in payload:
                        return Image.open(io.BytesIO(payload["bytes"])).convert("RGB")
                    # A string is treated as something PIL can open directly.
                    if isinstance(payload, str):
                        return Image.open(payload).convert("RGB")
                    # Objects with a `size` attribute are passed through untouched;
                    # the first image item always ends the search.
                    return payload if hasattr(payload, "size") else None
            return None
        except Exception as e:
            print(f"Error extracting image: {e}")
            return None

    def filter_long_examples(example):
        """Dataset predicate: keep ``example`` only if its prompt is at most 2100 tokens.

        The conversation in ``example["messages"]`` is rendered with the
        processor's chat template when the tokenizer has a non-empty one,
        otherwise with a simple ``<|role|> content`` line per message. The
        rendered text is tokenized without truncation or padding and its
        ``input_ids`` length is compared with the limit.

        Returns ``False`` (after printing the error) if anything raises, so
        problematic examples are filtered out rather than aborting the run.
        """
        try:
            messages = deserialize_messages(example["messages"])
            tokenizer = processor.tokenizer

            if getattr(tokenizer, "chat_template", None):
                prompt = processor.apply_chat_template(messages, tokenize=False)
            else:
                prompt = "\n".join(f"<|{m['role']}|> {m['content']}" for m in messages)

            token_ids = tokenizer(prompt, truncation=False, padding=False)["input_ids"]
            return len(token_ids) <= 2100
        except Exception as e:
            print(f"Filter error: {e}")
            return False  # Fail-safe: filter out if any error
        
    def deserialize_messages(messages):
        """Return ``messages`` decoded from JSON when it is a string, else unchanged."""
        return json.loads(messages) if isinstance(messages, str) else messages
        
    def truncate_chat_template(template, max_chars=6000):
        """Return ``template`` cut down to its first ``max_chars`` characters.

        A template that already fits within ``max_chars`` is returned unchanged.
        """
        fits = len(template) <= max_chars
        return template if fits else template[:max_chars]

    def collate_fn(examples):
        """Turn raw dataset rows into a padded model batch that includes ``labels``.

        For every example, assistant turns whose content is a single-item list
        get ``RESPONSE_MARKER`` prepended to their text, and the conversation is
        rendered to one prompt string (chat template if the tokenizer has one,
        otherwise a ``<|role|> content`` line per message). Examples whose
        rendered prompt does not contain the marker are skipped and counted in
        the module-level ``error_cnt``; for the rest, the prompt is paired with
        the image returned by ``extract_image``.

        Labels are a copy of ``input_ids`` in which everything up to and
        including the first occurrence of the marker tokens is set to -100,
        sequences without the marker are fully set to -100, and positions with
        ``attention_mask == 0`` (padding) are set to -100.

        Args:
            examples: Iterable of rows, each with a ``"messages"`` entry (a list
                of chat turns or its JSON string).

        Returns:
            The processor output with ``input_ids``/``attention_mask`` cut to
            8192 columns and a ``"labels"`` tensor added, or ``None`` when every
            example was skipped.
        """
        # `error_cnt` is assigned at module level; without this declaration the
        # `+=` below would raise UnboundLocalError.
        global error_cnt

        use_chat_template = bool(getattr(processor.tokenizer, "chat_template", None))

        messages_batch = []
        images = []
        for example in examples:
            messages = deserialize_messages(example["messages"])
            for m in messages:
                if m["role"] == "assistant" and isinstance(m["content"], list) and len(m["content"]) == 1:
                    m["content"][0]["text"] = RESPONSE_MARKER + m["content"][0]["text"]

            if use_chat_template:
                prompt = processor.apply_chat_template(messages, tokenize=False)
            else:
                prompt = "\n".join(f"<|{m['role']}|> {m['content']}" for m in messages)

            if RESPONSE_MARKER in prompt:
                messages_batch.append(prompt)
                # Collect the image only for kept prompts so text and images
                # stay aligned one-to-one when some examples are skipped.
                images.append(extract_image(example))
            else:
                error_cnt += 1

        if not messages_batch:
            return None

        batch = processor(images=images, text=messages_batch, return_tensors="pt", padding=True, truncation=True)

        # Truncate token-level to max 8192 tokens
        max_length = 8192
        batch["input_ids"] = batch["input_ids"][:, :max_length]
        batch["attention_mask"] = batch["attention_mask"][:, :max_length]

        input_ids = batch["input_ids"]
        labels = input_ids.clone()

        # Token ids of the marker on its own.
        # NOTE(review): this assumes the marker tokenizes to the same ids inside
        # a full prompt as it does standalone - confirm for the tokenizer in use.
        marker_token_ids = processor.tokenizer.encode(RESPONSE_MARKER, add_special_tokens=False)
        marker_len = len(marker_token_ids)
        # Built once instead of once per candidate window.
        marker_tensor = torch.tensor(marker_token_ids, device=input_ids.device)

        for i, ids in enumerate(input_ids):
            marker_found = False
            for j in range(len(ids) - marker_len + 1):
                if torch.equal(ids[j:j + marker_len], marker_tensor):
                    # Mask everything before and including the marker
                    labels[i, :j + marker_len] = -100
                    marker_found = True
                    break

            if not marker_found:
                # If marker not found, mask everything
                labels[i] = -100

        # Padding positions must not contribute to the loss.
        labels[batch["attention_mask"] == 0] = -100

        # Attach the labels; previously they were computed and then discarded.
        batch["labels"] = labels
        return batch


    ################
    # Dataset
    ################
    # Load the dataset from a local directory on disk.
    # NOTE(review): the path is absolute and machine-specific, and
    # `script_args.dataset_name` is not consulted here - confirm this is intended.
    dataset = load_from_disk("/mnt/3de36453-6164-4568-91b5-ae973509273e/Git/EE-Gothic-Script-OCR/src/datasets/gold/block_dataset")

    # Filter the dataset: keep only rows accepted by `filter_long_examples`
    # (rows whose prompt is too long, or that raise while being rendered or
    # tokenized, are dropped).
    dataset = dataset.filter(filter_long_examples)
    # presumably `dataset` maps split names to datasets, since it is indexed by
    # the test split name here - TODO confirm
    eval_dataset_full = dataset[script_args.dataset_test_split]
    # Disabled alternative: evaluate on at most the first 5 rows only.
    #eval_dataset = eval_dataset_full.select(range(min(5, len(eval_dataset_full))))
    eval_dataset = eval_dataset_full

    ################
    # Training
    ################
    # Select the best checkpoint by "eval_loss" and reload it when training ends.
    training_args.metric_for_best_model = "eval_loss"
    training_args.load_best_model_at_end = True

    #trainer = SFTTrainer(
    #    model=model,
    #    args=training_args,
    #    data_collator=collate_fn,
    #    train_dataset=dataset[script_args.dataset_train_split],
    #    eval_dataset=eval_dataset if training_args.eval_strategy != "no" else None,
    #    processing_class=processor,
    #    peft_config=get_peft_config(model_args),
    #    
    #    
    #)
    #
    #trainer.add_callback(EarlyStoppingCallback(
    #early_stopping_patience=3,
    #early_stopping_threshold=0.001))