In [None]:
import torch
from transformers import AutoModelForImageTextToText, TrainingArguments, Trainer, AutoProcessor,  BitsAndBytesConfig
from datasets import Dataset, load_dataset
from PIL import Image
import json
import os
from torch.utils.data import DataLoader
import numpy as np
from typing import Dict, List, Any
import logging
from data import *

In [None]:
# Setup logging: configure the root logger at INFO level and create the
# module-level `logger` that the cells and functions below write to.
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)



def create_amazigh_alphabet_prompt():
    """Create a system prompt that includes Amazigh alphabet information.

    Returns:
        str: A fixed instruction string that lists the Tifinagh letters and
        mentions the Latin and Arabic script adaptations.

    The Tifinagh letter list contains each letter once. The earlier version
    repeated ⴳ in the slot between ⴼ and ⵀ, which left ⴽ out of the prompt.
    """
    return """You are an expert OCR system specialized in recognizing Amazigh (Berber) text. 
    You can read text in multiple scripts used for Amazigh languages:
    
    1. Tifinagh script: ⴰⴱⴳⴷⴹⴻⴼⴽⵀⵃⵄⵅⵇⵈⵉⵊⵋⵍⵎⵏⵓⵔⵕⵖⵗⵘⵙⵚⵛⵜⵝⵞⵟⵠⵡⵢⵣⵤⵥ
    
    2. Latin script adaptations for Amazigh languages (with special characters like ɣ, ḥ, ṛ, ṣ, ṭ, ẓ)
    
    3. Arabic script adaptations used in some Amazigh communities
    
    Extract all text accurately, preserving the original script and diacritical marks."""

In [None]:
#!/usr/bin/env python3

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

def setup_model_and_processor(model_name: str = "Qwen/Qwen2-VL-2B"):
    """Build the 4-bit quantized, LoRA-wrapped model and its processor.

    Args:
        model_name: Checkpoint identifier handed to both ``from_pretrained`` calls.

    Returns:
        tuple: ``(model, processor)`` where ``model`` is the PEFT-wrapped model.

    Side effects:
        Prints the trainable-parameter summary of the wrapped model.
    """
    processor = AutoProcessor.from_pretrained(model_name)

    # NF4 4-bit weights with double quantization; matmuls run in bfloat16.
    quant_kwargs = {
        "load_in_4bit": True,
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_compute_dtype": torch.bfloat16,
        "bnb_4bit_quant_type": "nf4",
    }
    quantization = BitsAndBytesConfig(**quant_kwargs)

    base_model = AutoModelForImageTextToText.from_pretrained(
        model_name,
        quantization_config=quantization,
        device_map="auto",
        trust_remote_code=True,
    )

    # Checkpointing is switched on before the k-bit preparation step.
    base_model.gradient_checkpointing_enable()
    kbit_model = prepare_model_for_kbit_training(base_model)

    # Rank-8 adapters on the four attention projections only.
    attention_projections = ["q_proj", "v_proj", "k_proj", "o_proj"]
    adapter_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=attention_projections,
        bias="none",
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(kbit_model, adapter_config)
    model.print_trainable_parameters()

    return model, processor

In [None]:
# Build the LoRA-wrapped model and its processor for interactive use.
# NOTE(review): train_amazigh_ocr_model() further down calls
# setup_model_and_processor() and AmazighOCRDataset(...) itself, so the objects
# created in this cell are not what the training call uses — confirm this cell
# is still needed, since it loads the model a second time.
logger.info("Setting up model and processor...")
model, processor = setup_model_and_processor()

# AmazighOCRDataset is not defined in this notebook; presumably it comes from
# `from data import *` — verify against data.py.
logger.info("Loading dataset...")
dataset = AmazighOCRDataset("images", processor)

INFO:__main__:Setting up model and processor...
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
Loading checkpoint shards: 100%|██████████| 2/2 [00:12<00:00,  6.16s/it]
INFO:__main__:Loading dataset...

trainable params: 2,179,072 || all params: 2,211,164,672 || trainable%: 0.0985


In [None]:
# 80/20 train/validation split.
# The generator is seeded so that Restart & Run All yields the same split
# every time; without it random_split draws from the global RNG and the
# validation set changes between runs.
SPLIT_SEED = 42

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [8]:
def create_data_collator(processor):
    """Create a custom data collator for VL models.

    Args:
        processor: Object exposing ``tokenizer.pad_token_id``; that id is the
            fill value used when right-padding ``input_ids``.

    Returns:
        A ``collate_fn(batch)`` callable. ``batch`` is a list of dicts holding
        1-D tensors under ``'input_ids'``, ``'attention_mask'`` and
        ``'labels'`` plus a tensor under ``'pixel_values'``. The callable
        returns a dict with the same four keys:

        * ``input_ids``: right-padded with the pad token id, cast to int64.
        * ``attention_mask``: right-padded with 0, keeping the input dtype.
          (Padding with float zeros previously promoted an integer mask to
          float32.)
        * ``labels``: right-padded with -100 so padded positions are ignored
          by the loss.
        * ``pixel_values``: ``torch.stack`` of the per-item tensors, so all
          items must share one shape.

    NOTE(review): only these four keys are forwarded. If the dataset items
    carry additional model inputs (e.g. an image grid descriptor), they are
    dropped here — confirm against the dataset class.
    """
    pad_sequence = torch.nn.utils.rnn.pad_sequence

    def collate_fn(batch):
        def right_pad(key, fill_value):
            # Pads every sequence to the longest one in this batch.
            return pad_sequence(
                [item[key] for item in batch],
                batch_first=True,
                padding_value=fill_value,
            )

        return {
            # pad_token_id is read per call, so a tokenizer change made after
            # the collator was created is still honoured.
            'input_ids': right_pad('input_ids', processor.tokenizer.pad_token_id).long(),
            'attention_mask': right_pad('attention_mask', 0),
            'pixel_values': torch.stack([item['pixel_values'] for item in batch]),
            'labels': right_pad('labels', -100),  # -100 for ignored tokens
        }

    return collate_fn

def train_amazigh_ocr_model(
    data_path: str,
    output_dir: str = "./amazigh-ocr-qwen2vl",
    num_epochs: int = 3,
    batch_size: int = 2,
    learning_rate: float = 5e-5,
    save_steps: int = 500,
    seed: int = 42,
):
    """Fine-tune the OCR model on an Amazigh dataset and save the result.

    Args:
        data_path: Passed unchanged to ``AmazighOCRDataset``.
        output_dir: Directory for checkpoints, the final model and the processor.
        num_epochs: Number of training epochs.
        batch_size: Per-device batch size for both training and evaluation.
        learning_rate: Learning rate given to ``TrainingArguments``.
        save_steps: Interval (in optimizer steps) for both checkpointing and
            evaluation; the two must coincide because the best checkpoint is
            reloaded at the end.
        seed: Seed for the train/validation split and for ``TrainingArguments``.

    Returns:
        The ``Trainer`` instance after training has finished.

    Raises:
        ValueError: If the dataset is too small to yield a non-empty train
            and validation subset (fewer than two samples).
    """
    logger.info("Setting up model and processor...")
    model, processor = setup_model_and_processor()

    logger.info("Loading dataset...")
    dataset = AmazighOCRDataset(data_path, processor)

    # Split dataset (80/20 train/val)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"Dataset at {data_path!r} has {len(dataset)} sample(s); "
            "at least 2 are required for an 80/20 train/validation split."
        )

    # Seeded generator: the same samples land in the validation set on every
    # run, so eval_loss is comparable across runs.
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(seed),
    )

    # Setup training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=4,  # Effective batch size = batch_size * 4
        warmup_steps=50,
        learning_rate=learning_rate,
        fp16=False,  # fp16 autocast stays off; the quantized layers compute in bfloat16
        logging_steps=50,
        save_steps=save_steps,
        eval_steps=save_steps,
        eval_strategy="steps",
        save_strategy="steps",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        # The string "none" is what disables wandb/tensorboard; passing None
        # falls back to the library default instead of turning reporting off.
        report_to="none",
        dataloader_pin_memory=False,
        # Keep every key the dataset emits; the custom collator picks what it needs.
        remove_unused_columns=False,
        seed=seed,
    )

    # Create data collator
    data_collator = create_data_collator(processor)

    # Initialize trainer. `processing_class` replaces the deprecated
    # `tokenizer` argument of Trainer.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        processing_class=processor.tokenizer,
    )

    logger.info("Starting training...")
    trainer.train()

    logger.info("Saving final model...")
    trainer.save_model()
    # The full processor is saved next to the model so that the output
    # directory can be passed to AutoProcessor.from_pretrained on its own.
    processor.save_pretrained(output_dir)

    return trainer

def prepare_dataset_example():
    """Write a three-record sample file showing the expected dataset layout.

    Each record pairs an ``image_path`` with its ground-truth ``text``; one
    record is given per script (Tifinagh, Latin, Arabic). The records are
    written to ``amazigh_ocr_dataset.json`` in the current working directory
    as UTF-8 JSON with non-ASCII characters left unescaped, and two hint
    lines are printed afterwards.
    """
    # (image path, transcription) pairs, one per script.
    samples = (
        ("/path/to/tifinagh_text_image.jpg", "ⴰⵙⵉⵡⵍ ⵏ ⵜⵎⴰⵣⵉⵖⵜ"),    # Tifinagh
        ("/path/to/latin_amazigh_image.jpg", "tamaziɣt n umaziɣ"),    # Latin script
        ("/path/to/arabic_amazigh_image.jpg", "تامازيغت"),            # Arabic script
    )
    records = [
        {"image_path": image_path, "text": transcription}
        for image_path, transcription in samples
    ]

    # Persist as human-readable JSON.
    with open('amazigh_ocr_dataset.json', 'w', encoding='utf-8') as handle:
        json.dump(records, handle, ensure_ascii=False, indent=2)

    for line in (
        "Example dataset created: amazigh_ocr_dataset.json",
        "Replace with your actual image paths and corresponding text",
    ):
        print(line)

def inference_example(model_path: str, image_path: str):
    """Run OCR on one image with a saved model and return the decoded text.

    Args:
        model_path: Directory (or identifier) from which both the model and
            the processor are loaded.
        image_path: Path of the image file to transcribe.

    Returns:
        str: The text generated for the image, with the prompt tokens and
        special tokens removed.
    """
    # Load fine-tuned model
    model = AutoModelForImageTextToText.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    processor = AutoProcessor.from_pretrained(model_path)

    # Load image. convert() returns an independent in-memory copy, so the
    # file handle can be released as soon as the conversion is done.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    # Prepare input
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Extract all text from this image, including Amazigh text in Tifinagh, Latin, or Arabic scripts:"}
            ]
        }
    ]

    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = processor(
        text=[text],
        images=[image],
        padding=True,
        return_tensors="pt"
    ).to(model.device)

    # Generate. Greedy decoding (do_sample=False) keeps the output
    # deterministic; no temperature is passed because it only applies to
    # sampling and a value of 0.0 is not a valid sampling temperature.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False
        )

    # Decode output: slice off the prompt so only newly generated tokens remain.
    prompt_length = inputs['input_ids'].shape[1]
    generated_text = processor.batch_decode(
        output[:, prompt_length:],
        skip_special_tokens=True
    )[0]

    return generated_text



In [None]:
# Launch fine-tuning. output_dir must be the same directory the inference
# cell passes as model_path ("./amazigh-ocr-qwen2vl-finetuned"); a stray space
# in the name previously sent the checkpoint to a different folder.
print("Starting training...")
trainer = train_amazigh_ocr_model(
    data_path="images",
    output_dir="./amazigh-ocr-qwen2vl-finetuned",
    num_epochs=5,
    batch_size=1,  # Smallest per-device batch; gradients are accumulated inside the training function
    learning_rate=1e-5
)

INFO:__main__:Setting up model and processor...


Starting training...


INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Run OCR on a single image using the saved fine-tuned model directory.
# "path/to/test_image.jpg" is a placeholder — replace it with a real image
# path before running. The decoded text is kept in `result`; this cell does
# not display it.
result = inference_example(
    model_path="./amazigh-ocr-qwen2vl-finetuned",
    image_path="path/to/test_image.jpg"
)

In [None]:
if __name__ == "__main__":
    # Example usage.
    # NOTE(review): the logger output in this notebook is tagged "__main__",
    # so this guard is true inside the kernel and the cell body runs when
    # executed — presumably carried over from a script version; confirm.
    
    # 1. Prepare example dataset structure (writes amazigh_ocr_dataset.json)
    print("Creating example dataset structure...")
    prepare_dataset_example()
    
    # 2. Train the model (uncomment when you have real data)
    # print("Starting training...")
    # trainer = train_amazigh_ocr_model(
    #     data_path="amazigh_ocr_dataset.json",
    #     output_dir="./amazigh-ocr-qwen2vl-finetuned",
    #     num_epochs=5,
    #     batch_size=1,  # Reduce if memory issues
    #     learning_rate=1e-5
    # )
    
    # 3. Test inference (after training)
    # result = inference_example(model_path=..., image_path=...)
    # print(f"OCR Result: {result}")
    
    # Closing checklist for the user.
    print("\nSetup complete! Follow these steps:")
    print("1. Prepare your Amazigh text images and corresponding text files")
    print("2. Update the dataset paths in the training function")
    print("3. Run the training with: python amazigh_ocr_finetune.py")
    print("4. Use the trained model for OCR inference")