In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the project folder used below lives under MyDrive.
drive.mount(mountpoint='/content/drive')

# !mkdir -p "/content/drive/MyDrive/Colab Notebooks/Vee/finetune_hub"
# print("Folder created successfully on Google Drive!")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [16]:
# Work from the project folder on Drive so the relative `finetune_hub/...` paths written by later cells land there.
%cd /content/drive/MyDrive/Colab Notebooks/Vee

/content/drive/MyDrive/Colab Notebooks/Vee


In [None]:
# !pip install -q -U torch torchvision torchaudio transformers datasets peft bitsandbytes accelerate

In [None]:
# !mkdir -p finetune_hub

In [None]:
%%writefile finetune_hub/config.py
from dataclasses import dataclass, field
from typing import Optional, List

@dataclass
class ModelConfig:
    """Every tunable setting for the PaliGemma + LoRA fine-tuning run.

    Fields are grouped by concern. Their order defines the positional
    constructor signature, so new fields should be appended, not inserted.
    """

    # --- Model & data -------------------------------------------------------
    model_id: str = "google/paligemma-3b-pt-224"       # passed to from_pretrained
    dataset_id: str = "deepcopy/MathWriting-human"     # passed to load_dataset
    prompt_text: str = "Convert this handwritten math to LaTeX."  # text sent with every image

    # --- Quantization & LoRA ------------------------------------------------
    use_4bit: bool = True
    lora_rank: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    # Module names LoRA attaches to; default_factory gives each instance its own list.
    target_modules: List[str] = field(
        default_factory=lambda: "q_proj v_proj k_proj o_proj gate_proj up_proj down_proj".split()
    )

    # --- Training hyperparameters --------------------------------------------
    learning_rate: float = 0.0002
    batch_size: int = 4                    # per-device batch size
    gradient_accumulation_steps: int = 4
    num_train_epochs: int = 1
    save_steps: int = 25                   # checkpoint every N optimizer steps

    # --- System -------------------------------------------------------------
    max_seq_length: int = 512              # processor truncation length (tokens)
    output_dir: str = "./math_vlm_adapter"
    logging_steps: int = 5
    task_type: str = "CAUSAL_LM"           # forwarded to peft.LoraConfig

Overwriting finetune_hub/config.py


In [None]:
%%writefile finetune_hub/adapter.py
import torch
from peft import LoraConfig, TaskType
from transformers import BitsAndBytesConfig
from .config import ModelConfig

class AdapterFactory:
    """Builds the quantization and LoRA configuration objects from a ModelConfig."""

    @staticmethod
    def get_qlora_config(cfg: ModelConfig) -> BitsAndBytesConfig:
        """BitsAndBytes settings: NF4 weights, double quantization, float16 compute.

        ``load_in_4bit`` mirrors ``cfg.use_4bit``; every other option is fixed.
        """
        quant_settings = {
            "load_in_4bit": cfg.use_4bit,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.float16,
            "bnb_4bit_use_double_quant": True,
        }
        return BitsAndBytesConfig(**quant_settings)

    @staticmethod
    def get_lora_config(cfg: ModelConfig) -> LoraConfig:
        """LoRA settings taken from the config (rank, alpha, dropout, targets, task type).

        Bias terms are never trained (``bias="none"``).
        """
        lora_settings = {
            "r": cfg.lora_rank,
            "lora_alpha": cfg.lora_alpha,
            "target_modules": cfg.target_modules,
            "lora_dropout": cfg.lora_dropout,
            "bias": "none",
            "task_type": cfg.task_type,
        }
        return LoraConfig(**lora_settings)

Overwriting finetune_hub/adapter.py


In [18]:
%%writefile finetune_hub/engine.py
import torch
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from peft import get_peft_model, prepare_model_for_kbit_training
from .config import ModelConfig
from .adapter import AdapterFactory

class VLMEngine:
    """Loads the PaliGemma base model + processor and attaches a LoRA adapter.

    Intended call order: ``load_model()`` then ``apply_adapter()``. Both read
    their settings from the ``ModelConfig`` passed to the constructor.
    """

    def __init__(self, cfg: ModelConfig):
        self.cfg = cfg
        self.processor = None  # populated by load_model()
        self.model = None  # populated by load_model(), wrapped by apply_adapter()

    def load_model(self):
        """Load ``cfg.model_id`` and its processor, then prepare the model for training.

        If ``cfg.use_4bit`` is True, the weights are loaded with the NF4
        BitsAndBytes config from ``AdapterFactory`` and the model is passed
        through ``prepare_model_for_kbit_training``. Otherwise the model is
        loaded unquantized in float16. Gradient checkpointing is enabled in
        both cases.
        """
        print(f"Loading Base Model: {self.cfg.model_id}...")

        load_kwargs = {"device_map": "auto", "dtype": torch.float16}
        if self.cfg.use_4bit:
            # Attach the bnb config only when 4-bit is requested. transformers picks the
            # 4-bit bitsandbytes quantizer for any BitsAndBytesConfig that is not 8-bit,
            # so a config with load_in_4bit=False would not actually disable quantization.
            load_kwargs["quantization_config"] = AdapterFactory.get_qlora_config(self.cfg)

        self.model = PaliGemmaForConditionalGeneration.from_pretrained(
            self.cfg.model_id, **load_kwargs
        )
        self.processor = AutoProcessor.from_pretrained(self.cfg.model_id)

        # Enable Gradient Checkpointing (Saves VRAM)
        self.model.gradient_checkpointing_enable()
        if self.cfg.use_4bit:
            self.model = prepare_model_for_kbit_training(self.model)
        else:
            # prepare_model_for_kbit_training casts non-quantized fp16/bf16 params to
            # fp32, which would blow up memory for an unquantized model. Keep only what
            # checkpointing needs: make embedding outputs require grad so gradients can
            # flow to the adapter layers once the base weights are frozen.
            self.model.enable_input_require_grads()

    def apply_adapter(self):
        """Wrap ``self.model`` with the LoRA config from ``AdapterFactory``.

        Prints the trainable-parameter summary and returns the wrapped model,
        which also replaces ``self.model``.
        """
        lora_config = AdapterFactory.get_lora_config(self.cfg)
        self.model = get_peft_model(self.model, lora_config)

        print("\n--- Trainable Parameters ---")
        self.model.print_trainable_parameters()
        return self.model

Overwriting finetune_hub/engine.py


In [None]:
%%writefile finetune_hub/data.py
import torch
from datasets import load_dataset
from PIL import Image

class DataProcessor:
    """Loads the configured HF dataset and batches rows for PaliGemma training."""

    def __init__(self, processor, cfg):
        self.processor = processor
        self.cfg = cfg

    def load_data(self, split="train", limit=None):
        """Load split ``split`` of ``cfg.dataset_id``.

        Args:
            split: dataset split name passed to ``load_dataset``.
            limit: if not None, keep only the first ``limit`` rows. A limit larger
                than the split is clamped to the split size instead of raising.

        Returns:
            The (possibly truncated) dataset.
        """
        print(f"Loading dataset: {self.cfg.dataset_id}...")
        ds = load_dataset(self.cfg.dataset_id, split=split)
        if limit is not None:
            # select() rejects out-of-range indices, so never ask for more rows than exist.
            ds = ds.select(range(min(limit, len(ds))))
        return ds

    def collate_fn(self, examples):
        """Turn a list of dataset rows into one processor batch.

        Each row must provide ``"image"`` (an object with ``convert("RGB")``, e.g. a
        PIL image) and ``"latex"`` (the target text). Every row gets the same prompt,
        ``cfg.prompt_text``. The target is passed as ``suffix``; per the transformers
        PaliGemma processor docs, the processor builds ``labels`` from it. Sequences
        are padded to the longest row and truncated to ``cfg.max_seq_length``.
        """
        texts = [self.cfg.prompt_text] * len(examples)
        images = [ex["image"].convert("RGB") for ex in examples]
        labels = [ex["latex"] for ex in examples]

        return self.processor(
            text=texts,
            images=images,
            suffix=labels,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.cfg.max_seq_length
        )

Overwriting finetune_hub/data.py


In [17]:
%%writefile finetune_hub/trainer.py
import transformers
from transformers import Trainer, TrainingArguments
from .config import ModelConfig

class TrainerWrapper:
    """Thin wrapper around ``transformers.Trainer`` for the fine-tuning run.

    ``train()`` builds ``TrainingArguments`` from the config, trains, and then
    writes the trainer state, the model and the processor to ``cfg.output_dir``.
    """

    def __init__(self, model, processor, train_dataset, cfg: ModelConfig, data_collator):
        self.model, self.processor = model, processor
        self.train_dataset, self.data_collator = train_dataset, data_collator
        self.cfg = cfg

    def _build_training_args(self):
        """Translate the config into TrainingArguments (8-bit paged AdamW, fp16, step checkpoints)."""
        cfg = self.cfg
        return TrainingArguments(
            output_dir=cfg.output_dir,
            per_device_train_batch_size=cfg.batch_size,
            gradient_accumulation_steps=cfg.gradient_accumulation_steps,
            warmup_steps=10,
            num_train_epochs=cfg.num_train_epochs,
            learning_rate=cfg.learning_rate,
            logging_steps=cfg.logging_steps,
            optim="paged_adamw_8bit",
            save_strategy="steps",
            save_steps=cfg.save_steps,
            fp16=True,
            report_to="none",
            # The collator needs the raw "image"/"latex" columns, so keep them.
            remove_unused_columns=False
        )

    def train(self):
        """Train, persist state + model + processor, and return the Trainer."""
        out_dir = self.cfg.output_dir
        trainer = Trainer(
            model=self.model,
            args=self._build_training_args(),
            train_dataset=self.train_dataset,
            data_collator=self.data_collator
        )

        print("Starting Training...")
        trainer.train()

        print(f"Saving Trainer State to {out_dir}...")
        trainer.save_state()

        print(f"Saving Adapter to {out_dir}...")
        # If self.model is a PeftModel (as built by VLMEngine.apply_adapter), this
        # presumably writes only the adapter weights - confirm against the peft version.
        for artifact in (self.model, self.processor):
            artifact.save_pretrained(out_dir)

        return trainer

Overwriting finetune_hub/trainer.py


In [None]:
%%writefile finetune_hub/inference.py
import torch
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from peft import PeftModel
from PIL import Image
import os

class InferenceEngine:
    """Generate LaTeX from an image using a PaliGemma base model plus a trained LoRA adapter."""

    def __init__(self, base_model_id, adapter_path):
        """Load the processor and the float16 base model, then merge the adapter into it.

        Args:
            base_model_id: HF Hub id (or local path) of the base PaliGemma model.
            adapter_path: directory containing the saved LoRA adapter.
        """
        self.processor = AutoProcessor.from_pretrained(base_model_id)

        base_model = PaliGemmaForConditionalGeneration.from_pretrained(
            base_model_id,
            device_map="auto",
            torch_dtype=torch.float16
        )
        peft_model = PeftModel.from_pretrained(base_model, adapter_path)
        # merge_and_unload() returns the merged plain model. Keep that instead of the
        # PeftModel wrapper, whose injected LoRA layers have been removed by the call.
        self.model = peft_model.merge_and_unload()
        self.model.eval()

    def generate(self, image_path: str, prompt_text: str, output_file: str = "output.tex",
                 max_new_tokens: int = 100):
        """Greedy-decode LaTeX for the image at ``image_path``.

        Args:
            image_path: path of the input image (converted to RGB).
            prompt_text: instruction text passed to the processor with the image.
            output_file: the decoded text is written here (UTF-8, overwritten).
            max_new_tokens: upper bound on generated tokens (default 100).

        Returns:
            The generated text with the prompt removed and whitespace stripped.
        """
        image = Image.open(image_path).convert("RGB")
        inputs = self.processor(text=prompt_text, images=image, return_tensors="pt").to(self.model.device)
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False
            )

        # generate() returns prompt + continuation. Drop the prompt by token count rather
        # than str.replace, which would also delete any prompt text inside the answer.
        new_tokens = generated_ids[0][prompt_len:]
        cleaned_result = self.processor.decode(new_tokens, skip_special_tokens=True).strip()

        # LaTeX may contain non-ASCII symbols; don't depend on the platform default encoding.
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(cleaned_result)

        return cleaned_result

Overwriting finetune_hub/inference.py


In [None]:
%%writefile finetune_hub/__init__.py
from .config import ModelConfig
from .engine import VLMEngine
from .data import DataProcessor
from .trainer import TrainerWrapper
from .inference import InferenceEngine

# Public API of the finetune_hub package, listed in the same order as the imports above.
__all__ = ["ModelConfig", "VLMEngine", "DataProcessor", "TrainerWrapper", "InferenceEngine"]

Overwriting finetune_hub/__init__.py
