<a href="https://colab.research.google.com/github/DATAGEEKN/SmartVision-Counter/blob/main/Fine_tuning_Gemma_for_Object_Detection_and_Counting_on_a_Custom_Dataset_SmartVision.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import json
import requests
import torch
from PIL import Image
from tqdm import tqdm
from datasets import Dataset
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig, Trainer, TrainingArguments
from peft import get_peft_model, LoraConfig
from huggingface_hub import login
from zipfile import ZipFile
import logging
from getpass import getpass

# Install dependencies
try:
    import googletrans
    from googletrans import Translator, LANGUAGES
except ImportError:
    !pip install googletrans==4.0.0-rc1
    try:
        from googletrans import Translator, LANGUAGES
    except ImportError:
        !pip install translate
        from translate import Translator as FallbackTranslator
        logger.warning("Using fallback translate library due to googletrans installation failure")

# Set up logging for demo appeal
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Step 1: Prompt for Hugging Face token
def get_hf_token():
    """Prompt (without echo) for a Hugging Face access token and return it.

    Surrounding whitespace is removed before validation and before the value
    is returned, so a token pasted with a trailing newline or spaces is still
    usable.

    Returns:
        The stripped, non-empty token string.

    Raises:
        ValueError: If the entered token is empty or whitespace only.
    """
    token = getpass("Enter your Hugging Face access token: ").strip()
    if not token:
        raise ValueError("Hugging Face token cannot be empty. Please provide a valid token.")
    return token

# Authenticate against the Hugging Face Hub; any failure is logged and re-raised.
try:
    HF_TOKEN = get_hf_token()
    login(token=HF_TOKEN, add_to_git_credential=False)
except Exception as login_error:
    logger.error(f"Error during Hugging Face login: {login_error}")
    raise
else:
    logger.info("Hugging Face login successful")

# Step 2: Set up environment
# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")

# Step 3: Create directories
# Working directories for images, the processed dataset and training output.
for work_dir in (
    "/content/coco_images",
    "/content/coco_processed",
    "/content/paligemma_finetuned",
):
    os.makedirs(work_dir, exist_ok=True)

# Step 4: Download and extract COCO 2017 validation annotations
# The archive is fetched only when the extracted validation JSON is missing.
anno_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
anno_zip = "annotations_trainval2017.zip"
anno_path = "/content/coco_annotations/annotations/instances_val2017.json"
if not os.path.exists(anno_path):
    logger.info("Downloading annotations...")
    # (connect, read) timeouts keep a stalled connection from hanging the
    # notebook; the context manager releases the connection when done.
    with requests.get(anno_url, stream=True, timeout=(10, 60)) as response:
        # Fail fast on an HTTP error instead of writing an error page into
        # the zip file and hitting a confusing BadZipFile later.
        response.raise_for_status()
        with open(anno_zip, "wb") as f:
            # 1 MiB chunks: far fewer iterations and progress-bar updates
            # than 1 KiB chunks for an archive of this size.
            for chunk in tqdm(response.iter_content(chunk_size=1024 * 1024), desc="Downloading annotations"):
                f.write(chunk)
    with ZipFile(anno_zip, "r") as zip_ref:
        zip_ref.extractall("/content/coco_annotations")

# Step 5: Load COCO annotations
# Parse the annotation JSON; any failure is logged and re-raised.
try:
    with open(anno_path, "r") as anno_file:
        coco_data = json.load(anno_file)
    logger.info("COCO annotations loaded successfully")
except Exception as load_error:
    logger.error(f"Error loading COCO annotations: {load_error}")
    raise

# Step 6: Map category IDs to names
# Lookup table: COCO category id -> category name.
category_map = {}
for category in coco_data["categories"]:
    category_map[category["id"]] = category["name"]

# Step 7: Process COCO images for PaliGemma
# Builds `paligemma_data`: one {"image", "prefix", "suffix"} record per usable
# image, where the suffix lists every annotated object as
# "<category>;<x1>,<y1>,<x2>,<y2>" joined by ";".
# NOTE(review): the pretrained PaliGemma detection output is documented as
# <locNNNN> tokens followed by the label; this plain-pixel format differs
# from that - confirm it is intentional before training on it.
paligemma_data = []
# The two demo images are always included, followed by the first 50 images
# listed in the annotation file (without duplicates).
image_filenames = ["000000289343.jpg", "000000039769.jpg"]
for img in coco_data["images"][:50]:
    if img["file_name"] not in image_filenames:
        image_filenames.append(img["file_name"])

# Index the annotation file once so each image is a dictionary lookup instead
# of a full scan over every image and annotation record.
known_image_ids = {img["id"] for img in coco_data["images"]}
annotations_by_image = {}
for ann in coco_data["annotations"]:
    annotations_by_image.setdefault(ann["image_id"], []).append(ann)

for filename in tqdm(image_filenames, desc="Processing images"):
    img_path = f"/content/coco_images/{filename}"
    image_url = f"http://images.cocodataset.org/val2017/{filename}"

    if not os.path.exists(img_path):
        try:
            response = requests.get(image_url, timeout=10)
            # Do not save an HTTP error page under an image filename.
            response.raise_for_status()
            with open(img_path, "wb") as f:
                f.write(response.content)
        except Exception as e:
            logger.warning(f"Failed to download {filename}: {e}")
            continue

    try:
        # The context manager closes the file handle that verify() leaves open.
        with Image.open(img_path) as candidate:
            candidate.verify()
    except Exception as e:
        logger.warning(f"Invalid image {filename}: {e}")
        # Remove the corrupt file so a re-run downloads it again instead of
        # failing on the cached copy forever. Best effort: a failed delete
        # only means the file is skipped again next time.
        try:
            os.remove(img_path)
        except OSError:
            pass
        continue

    # COCO val2017 filenames are the zero-padded image id plus ".jpg".
    image_id = int(filename.split(".")[0])
    if image_id not in known_image_ids:
        logger.warning(f"Image {filename} not found in annotations")
        continue

    annotations = annotations_by_image.get(image_id, [])
    if not annotations:
        logger.warning(f"No annotations for {filename}")
        continue

    detections = []
    for ann in annotations:
        category_name = category_map[ann["category_id"]]
        # The code treats bbox as [x, y, width, height] and stores corners.
        x1, y1, width, height = ann["bbox"]
        detections.append(f"{category_name};{x1},{y1},{x1 + width},{y1 + height}")

    paligemma_data.append({
        "image": filename,
        "prefix": "detect object",
        "suffix": ";".join(detections)
    })

# Step 8: Save PaliGemma JSONL
# One JSON object per line; any failure is logged and re-raised.
jsonl_path = "/content/coco_processed/paligemma_dataset.jsonl"
try:
    with open(jsonl_path, "w") as jsonl_file:
        jsonl_file.writelines(json.dumps(entry) + "\n" for entry in paligemma_data)
    logger.info(f"PaliGemma JSONL dataset saved to {jsonl_path}")
except Exception as save_error:
    logger.error(f"Error saving JSONL file: {save_error}")
    raise

# Step 9: Load dataset manually
# Read the JSONL back line by line, skipping blank lines, and wrap the
# records in a `Dataset`.
try:
    dataset_list = []
    with open(jsonl_path, "r") as jsonl_file:
        for raw_line in jsonl_file:
            record_text = raw_line.strip()
            if record_text:
                dataset_list.append(json.loads(record_text))
    dataset = Dataset.from_list(dataset_list)
    logger.info(f"Dataset loaded with {len(dataset)} examples")
except Exception as load_error:
    logger.error(f"Error loading dataset: {load_error}")
    raise

# Step 10: Load PaliGemma model with quantization
model_id = "google/paligemma-3b-mix-448"

# 4-bit NF4 quantization with double quantization and bfloat16 compute.
quantization_settings = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
    "bnb_4bit_use_double_quant": True,
}
bnb_config = BitsAndBytesConfig(**quantization_settings)

model_load_kwargs = {
    "quantization_config": bnb_config,
    "device_map": "auto",
    "torch_dtype": torch.bfloat16,
}

# Load the model first, then its processor; any failure is logged and re-raised.
try:
    model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, **model_load_kwargs)
    processor = AutoProcessor.from_pretrained(model_id)
except Exception as load_error:
    logger.error(f"Error loading model or processor: {load_error}")
    raise
else:
    logger.info("Model and processor loaded successfully")

# Step 11: Enable LoRA
# Adapter settings: rank 16, alpha 32, dropout 0.05, applied to the four
# attention projection modules named below.
lora_settings = {
    "r": 16,
    "lora_alpha": 32,
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "lora_dropout": 0.05,
    "task_type": "CAUSAL_LM",
}
lora_config = LoraConfig(**lora_settings)

# Wrap the loaded model with the adapters and report what is trainable.
model = get_peft_model(model, lora_config)
logger.info("LoRA enabled; trainable parameters:")
model.print_trainable_parameters()

# Step 12: Preprocess dataset
def preprocess_data(example):
    """Convert one dataset record into tensors for PaliGemma training.

    The prompt (`prefix`) and target (`suffix`) are passed to the processor
    together. The processor is expected to append the suffix to the input
    sequence and return `labels` with the same length as `input_ids`;
    tokenising the suffix separately yields labels whose length does not
    match the inputs.

    Args:
        example: Mapping with the keys "image" (filename inside
            /content/coco_images), "prefix" and "suffix".

    Returns:
        A dict with "input_ids", "attention_mask", "pixel_values" and
        "labels" tensors (batch dimension removed), or None when the record
        cannot be processed. The failure is logged as a warning.
    """
    try:
        img_path = f"/content/coco_images/{example['image']}"
        # convert() returns an independent image, so the file handle can be
        # closed as soon as the conversion is done.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        # Process prompt, image and target in a single call.
        inputs = processor(
            text=example["prefix"],
            images=image,
            suffix=example["suffix"],
            return_tensors="pt",
            padding=True
        )

        # Drop only the leading batch dimension; a bare squeeze() would also
        # collapse any other dimension that happens to have size 1.
        features = {
            key: inputs[key][0]
            for key in ("input_ids", "attention_mask", "pixel_values", "labels")
        }

        # Validate tensor types
        if not all(isinstance(value, torch.Tensor) for value in features.values()):
            logger.warning(f"Invalid tensor types for {example['image']}")
            return None

        return features
    except Exception as e:
        logger.warning(f"Error processing {example['image']}: {e}")
        return None

# Preprocess eagerly and keep only records that succeeded. preprocess_data
# signals failure by returning None; filtering Dataset rows with
# `x is not None` cannot drop those, because every row is a dict.
processed_examples = []
for example in dataset:
    features = preprocess_data(example)
    if features is not None:
        processed_examples.append(features)

# Values come back from a Dataset as plain Python lists unless a format is
# set; the "torch" format returns tensors, which the collator requires.
processed_dataset = Dataset.from_list(processed_examples).with_format("torch")
logger.info(f"Processed dataset size: {len(processed_dataset)}")

# Step 13: Set up training arguments
# Batch size 2 with 8 accumulation steps means an optimizer step every
# 16 examples per device.
training_settings = {
    "output_dir": "/content/paligemma_finetuned",
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 8,
    "learning_rate": 5e-5,
    "num_train_epochs": 2,
    "logging_steps": 5,
    "save_strategy": "epoch",
    # The custom collator needs every column, so none may be dropped.
    "remove_unused_columns": False,
    "optim": "paged_adamw_8bit",
    "bf16": True,
    "gradient_checkpointing": True,
    "warmup_ratio": 0.1,
    "report_to": "none",
}
training_args = TrainingArguments(**training_settings)

# Step 14: Custom collator
def data_collator(examples):
    """Assemble a list of preprocessed examples into one padded batch.

    Each example must provide "input_ids", "attention_mask", "labels" and
    "pixel_values". Values may be tensors or anything `torch.as_tensor`
    accepts (for instance the nested lists a Dataset returns when no torch
    format is set). Examples that are None, lack a key or cannot be
    converted are skipped with a warning.

    Args:
        examples: List of example dicts; entries may be None.

    Returns:
        A dict of batched tensors, with sequences right-padded to the longest
        one in the batch (pad token id for input_ids, 0 for attention_mask,
        -100 for labels) and pixel_values stacked. Returns None when no
        usable example remains or batching fails.
    """
    required_keys = ("input_ids", "attention_mask", "labels", "pixel_values")
    sequence_keys = ("input_ids", "attention_mask", "labels")

    if not examples or all(x is None for x in examples):
        logger.warning("Empty or invalid batch in data_collator")
        return None

    columns = {key: [] for key in required_keys}
    for example in examples:
        if example is None:
            continue
        try:
            tensors = {key: torch.as_tensor(example[key]) for key in required_keys}
        except (KeyError, TypeError, ValueError, RuntimeError) as e:
            # Log only the reason: the example itself holds a full image tensor.
            logger.warning(f"Skipping invalid example ({type(e).__name__}: {e})")
            continue
        # pad_sequence needs at least one dimension; flattening also repairs
        # a length-1 sequence that was squeezed down to a 0-d tensor.
        for key in sequence_keys:
            tensors[key] = tensors[key].reshape(-1)
        for key, value in tensors.items():
            columns[key].append(value)

    if not columns["input_ids"]:
        logger.warning("No valid examples in batch")
        return None

    try:
        pad_token_id = processor.tokenizer.pad_token_id
        if pad_token_id is None:
            # Padded positions are masked out by attention_mask, so any valid
            # token id works as a filler.
            pad_token_id = 0
        pad_sequence = torch.nn.utils.rnn.pad_sequence
        return {
            "input_ids": pad_sequence(columns["input_ids"], batch_first=True, padding_value=pad_token_id),
            "attention_mask": pad_sequence(columns["attention_mask"], batch_first=True, padding_value=0),
            # -100 is the padding value used for label positions.
            "labels": pad_sequence(columns["labels"], batch_first=True, padding_value=-100),
            "pixel_values": torch.stack(columns["pixel_values"])
        }
    except Exception as e:
        logger.error(f"Error in data_collator: {e}")
        return None

# Step 15: Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset,
    data_collator=data_collator,
)

# Step 16: Train the model
# Any training failure is logged and re-raised.
try:
    trainer.train()
except Exception as train_error:
    logger.error(f"Error during training: {train_error}")
    raise
else:
    logger.info("Training completed successfully")

# Step 17: Save the fine-tuned model
# Model first, then processor, both into the same directory.
finetuned_model_dir = "/content/paligemma_finetuned_model"
for artifact in (model, processor):
    artifact.save_pretrained(finetuned_model_dir)
logger.info("Model and processor saved successfully")

# Step 18: Multilingual inference and object counting
def _parse_detections(decoded_output):
    """Split decoded model text into one dict per detected object.

    Two layouts are recognised in the ";"-separated text:
      * "<locNNNN>... label"  - location tokens followed by the label
      * "label;x1,y1,x2,y2"   - the label, then its coordinates as the next
        ";"-separated segment (the layout this file writes into its
        training suffixes)

    Returns a list of dicts with the keys "label", "location" and
    "location_first".
    """
    import re  # local import so this block does not depend on file-level imports

    loc_prefixed = re.compile(r"^((?:<loc\d{4}>\s*)+)(.*)$", re.DOTALL)
    bare_coords = re.compile(r"^-?\d+(?:\.\d+)?(?:\s*,\s*-?\d+(?:\.\d+)?)+$")

    detections = []
    for segment in decoded_output.split(";"):
        segment = segment.strip()
        if not segment:
            continue
        loc_match = loc_prefixed.match(segment)
        if loc_match:
            detections.append({
                "label": loc_match.group(2).strip(),
                "location": loc_match.group(1).strip(),
                "location_first": True,
            })
        elif bare_coords.match(segment) and detections and not detections[-1]["location"]:
            # A coordinates-only segment belongs to the label just before it,
            # so it must not be counted as a separate object.
            detections[-1]["location"] = segment
        else:
            detections.append({"label": segment, "location": "", "location_first": False})
    return detections


def _format_detection(detection):
    """Render one parsed detection back into the layout it was parsed from."""
    if not detection["location"]:
        return detection["label"]
    if detection["location_first"]:
        return f"{detection['location']} {detection['label']}".strip()
    return f"{detection['label']};{detection['location']}"


def detect_and_count_objects(image_path, prompt="detect object", input_lang="en", output_lang="en"):
    """Run detection on one image and count the objects the model reports.

    Args:
        image_path: Path of the image to analyse.
        prompt: Detection prompt, written in `input_lang`.
        input_lang: Language code of `prompt`; it is translated to English
            first when this is not "en".
        output_lang: Language code for the returned labels; labels are
            translated from English when this is not "en". Coordinates are
            never translated.

    Returns:
        A `(detections_text, object_count)` tuple. On any error the failure
        is logged and `("", 0)` is returned.
    """
    try:
        # Build a translator only when one is needed, so English-only runs do
        # not depend on a translation library being installed.
        needs_translation = input_lang != "en" or output_lang != "en"
        # Check for the class itself: `from googletrans import Translator`
        # does not bind the name `googletrans`.
        google_translator = Translator() if needs_translation and "Translator" in globals() else None

        def translate_text(text, src, dest):
            if google_translator is not None:
                return google_translator.translate(text, src=src, dest=dest).text
            # The fallback library fixes its language pair at construction,
            # so build it per call with the requested direction.
            return FallbackTranslator(from_lang=src, to_lang=dest).translate(text)

        # Translate prompt to English
        if input_lang != "en":
            prompt = translate_text(prompt, input_lang, "en")
            logger.info(f"Translated prompt from {input_lang} to en: {prompt}")

        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
        prompt_length = inputs["input_ids"].shape[-1]
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=100)
        # generate() returns prompt tokens followed by new tokens; decode only
        # the new ones so the prompt is not parsed as a detection.
        decoded_output = processor.decode(output[0][prompt_length:], skip_special_tokens=True).strip()

        # Parse detections
        detections = _parse_detections(decoded_output)
        object_count = len(detections)

        # Translate detections to output language
        if output_lang != "en":
            for detection in detections:
                if detection["label"]:
                    detection["label"] = translate_text(detection["label"], "en", output_lang)
            decoded_output = ";".join(_format_detection(detection) for detection in detections)
            logger.info(f"Translated detections to {output_lang}")

        return decoded_output, object_count
    except Exception as e:
        logger.error(f"Error during inference: {e}")
        return "", 0

# Step 19: Demo interface for competition
def run_demo():
    """Print detections and object counts for the demo images.

    Each demo image that exists on disk is run through
    `detect_and_count_objects` once per configured prompt (English, Zulu,
    Swahili). Missing images are reported by a log warning and a printed line.
    """
    logger.info("Starting SmartVision Counter demo")
    test_images = ["/content/coco_images/000000289343.jpg", "/content/coco_images/000000039769.jpg"]
    # (prompt, input language code, output language code, display name)
    prompts = [
        ("detect object", "en", "en", "English"),
        ("thola into", "zu", "zu", "Zulu"),
        ("gundua kitu", "sw", "sw", "Swahili")
    ]

    banner_lines = (
        "\n=== SmartVision Counter Demo ===",
        "Automated object counting for African retail, agriculture, and manufacturing",
        "Multilingual support for Zulu and Swahili to empower African communities",
        "Model: google/paligemma-3b-mix-448\n",
    )
    for banner_line in banner_lines:
        print(banner_line)

    for test_image in test_images:
        # Guard clause: report a missing image and move on to the next one.
        if not os.path.exists(test_image):
            logger.warning(f"Test image {test_image} not found")
            print(f"Test image {test_image} not found")
            continue

        logger.info(f"Demo for image: {test_image}")
        image_name = test_image.rsplit("/", 1)[-1]
        print(f"\n=== Processing {image_name} ===")
        for prompt, input_lang, output_lang, lang_name in prompts:
            detections, count = detect_and_count_objects(test_image, prompt, input_lang, output_lang)
            print(f"\nLanguage: {lang_name}")
            print(f"Prompt: '{prompt}'")
            print(f"Detections: {detections}")
            print(f"Object count: {count}")

# Run the demo
run_demo()