In [1]:
!wget "https://github.com/liberationfonts/liberation-fonts.git" -O LiberationSans-Regular.ttf
print("✅ Font downloaded.")

!pip install faker
# python -m pip install Faker
# Step 3: Import all necessary modules
import random
import json
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from faker import Faker
from albumentations import Compose, Rotate, GaussNoise, Blur
from datasets import Dataset
import os
from tqdm.auto import tqdm

# --- Configuration ---
fake = Faker()
Faker.seed(42)
# The generation loop also draws from the stdlib `random` module (task choice,
# lab values, glucose readings), so seed it too for reproducible datasets.
random.seed(42)

# CORRECTED: Use the direct path to the single font we just downloaded.
FONT_PATH = './LiberationSans-Regular.ttf'

# Check that the font file exists AND can actually be parsed as a font.
# Existence alone is not enough: a failed download can leave an HTML page
# saved under the .ttf name, which only blows up later inside the generators.
if not os.path.exists(FONT_PATH):
    print(f"❌ Font file not found at: {FONT_PATH}. There might be a network issue.")
else:
    try:
        ImageFont.truetype(FONT_PATH, 12)
    except OSError as exc:
        print(f"❌ File at {FONT_PATH} is not a loadable font: {exc}")
    else:
        print(f"✅ Font file found at: {FONT_PATH}")

--2025-07-28 09:54:59--  https://github.com/liberationfonts/liberation-fonts.git
Resolving github.com (github.com)... 140.82.121.4
Connecting to github.com (github.com)|140.82.121.4|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://github.com/liberationfonts/liberation-fonts [following]
--2025-07-28 09:54:59--  https://github.com/liberationfonts/liberation-fonts
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/html]
Saving to: ‘LiberationSans-Regular.ttf’

LiberationSans-Regu     [ <=>                ] 273.50K  --.-KB/s    in 0.03s   

2025-07-28 09:55:00 (9.00 MB/s) - ‘LiberationSans-Regular.ttf’ saved [280060]

✅ Font downloaded.
Collecting faker
  Downloading faker-37.4.2-py3-none-any.whl.metadata (15 kB)
Downloading faker-37.4.2-py3-none-any.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m38.2 MB/s[

In [2]:
def generate_text_image(text, font_path, font_size=22, padding=10):
    """Render ``text`` in black on a tightly-cropped white RGB image.

    Args:
        text: String to render; may contain newlines.
        font_path: Path to a font file loadable by ``ImageFont.truetype``.
        font_size: Font size passed to ``ImageFont.truetype`` (default 22).
        padding: Blank margin, in pixels, kept on every side of the text.

    Returns:
        A ``PIL.Image.Image`` sized to the text's bounding box plus padding.
    """
    font = ImageFont.truetype(font_path, font_size)

    # Measure the text on a throwaway 1x1 canvas before allocating the real one.
    probe = ImageDraw.Draw(Image.new('RGB', (1, 1)))
    left, top, right, bottom = probe.textbbox((0, 0), text, font=font)
    text_width = right - left
    text_height = bottom - top

    image = Image.new('RGB', (text_width + 2 * padding, text_height + 2 * padding), 'white')
    # The bounding box is measured relative to the (0, 0) anchor and its
    # left/top edge is generally not at 0.  Shift the anchor by (-left, -top)
    # so the text really gets `padding` pixels of margin on all four sides.
    ImageDraw.Draw(image).text((padding - left, padding - top), text, font=font, fill='black')
    return image

def _figure_to_image(fig):
    """Rasterise a matplotlib figure and return it as an RGB PIL image."""
    fig.canvas.draw()
    # buffer_rgba() exposes the renderer's (height, width, 4) RGBA buffer;
    # drop alpha and copy so the image does not alias the canvas memory.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    return Image.fromarray(np.ascontiguousarray(rgba[..., :3]))

def generate_table_image(data, title):
    """Generates an image of a table using matplotlib.

    Args:
        data: Dict with 'headers' (column labels) and 'values' (list of rows).
        title: Bold title drawn above the table.

    Returns:
        RGB ``PIL.Image.Image`` of the rendered 6x2.5 inch, 100 dpi figure.
    """
    fig, ax = plt.subplots(figsize=(6, 2.5), dpi=100)
    try:
        ax.axis('off')
        ax.set_title(title, fontweight="bold")

        table = ax.table(cellText=data['values'], colLabels=data['headers'], loc='center', cellLoc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(12)
        table.scale(1, 1.8)

        return _figure_to_image(fig)
    finally:
        # Always release the figure, even if rendering raised; this function is
        # called in a loop and unclosed figures accumulate in pyplot.
        plt.close(fig)

def generate_graph_image(data, title, xlabel, ylabel):
    """Generates a line graph image using matplotlib.

    Args:
        data: Dict with 'x' and 'y' sequences of equal length.
        title: Bold title drawn above the plot.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.

    Returns:
        RGB ``PIL.Image.Image`` of the rendered 6x4 inch, 100 dpi figure.
    """
    fig, ax = plt.subplots(figsize=(6, 4), dpi=100)
    try:
        ax.plot(data['x'], data['y'], marker='o', color='b')
        ax.set_title(title, fontweight="bold", fontsize=14)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.6)
        # Lay out this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()

        return _figure_to_image(fig)
    finally:
        plt.close(fig)

# --- Augmentations ---
# Mild, independently-sampled degradations applied to every finished canvas:
# a small rotation, Gaussian noise and a light blur, each with its own
# probability `p`.
augmentations = Compose(
    transforms=[
        Rotate(p=0.6, limit=3),
        GaussNoise(p=0.4),
        Blur(p=0.4, blur_limit=3),
    ],
)

In [3]:
# --- Main Data Generation Loop ---
def _wrap_text_to_width(text, font, max_width):
    """Greedily break ``text`` into lines whose rendered width fits ``max_width`` pixels."""
    lines, current = [], ""
    for word in text.split():
        candidate = f"{current} {word}" if current else word
        if current and font.getlength(candidate) > max_width:
            lines.append(current)
            current = word
        else:
            current = candidate
    if current:
        lines.append(current)
    return "\n".join(lines)

def create_instruction_dataset(num_samples=50):
    """Generate synthetic clinic-record images with an instruction/answer pair each.

    Every sample is an 800x600 white canvas with a fixed header and one
    randomly chosen task: a block of notes text ('handwriting_ocr'), a CBC
    table ('table_extraction') or a glucose line chart ('graph_qa').  The
    finished canvas is passed through the module-level ``augmentations``.

    Args:
        num_samples: Number of samples to generate.

    Returns:
        A ``datasets.Dataset`` with columns ``image``, ``instruction``, ``answer``.
    """
    task_types = ['handwriting_ocr', 'table_extraction', 'graph_qa']
    canvas_width, canvas_height = 800, 600
    header_text = "Community Health Clinic - Patient Record"

    # Fonts do not change between samples, so load them once, not per iteration.
    header_font = ImageFont.truetype(FONT_PATH, 28)
    # 22 is the font size generate_text_image renders at; measuring with the
    # same size keeps the computed wrap width accurate.
    notes_font = ImageFont.truetype(FONT_PATH, 22)
    notes_origin = (50, 120)
    # Available width: canvas minus the left offset, a matching 50 px right
    # margin, and the 10 px padding generate_text_image adds on each side.
    notes_max_width = canvas_width - notes_origin[0] - 50 - 2 * 10

    dataset = []
    print(f"Generating {num_samples} synthetic samples...")
    for _ in range(num_samples):
        task = random.choice(task_types)
        canvas = Image.new('RGB', (canvas_width, canvas_height), 'white')
        instruction, answer = "", ""

        draw = ImageDraw.Draw(canvas)
        draw.text((50, 40), header_text, font=header_font, fill='black')
        draw.line([(50, 80), (750, 80)], fill='black', width=2)

        if task == 'handwriting_ocr':
            paragraph = fake.paragraph(nb_sentences=3)
            # A three-sentence paragraph on a single line is far wider than the
            # canvas, so most of it used to be cropped by paste() while the
            # answer still contained the full text.  Wrap it so that everything
            # in the label is actually visible in the image.
            wrapped = _wrap_text_to_width(paragraph, notes_font, notes_max_width)
            notes = f"Doctor's Notes:\n{wrapped}"
            text_img = generate_text_image(notes, FONT_PATH)
            canvas.paste(text_img, notes_origin)
            instruction = "Transcribe the text under 'Doctor's Notes'."
            # Line breaks are layout only; the target is the original paragraph.
            answer = paragraph.strip()

        elif task == 'table_extraction':
            table_data = {
                'headers': ['Test', 'Result', 'Reference Range'],
                'values': [
                    ['WBC', f"{random.uniform(4.0, 11.0):.1f} x10^9/L", '4.0-11.0'],
                    ['HGB', f"{random.uniform(12.0, 16.0):.1f} g/dL", '12.0-16.0'],
                    ['PLT', f"{random.randint(150, 450)} x10^9/L", '150-450']
                ]
            }
            table_img = generate_table_image(table_data, "Complete Blood Count (CBC)")
            canvas.paste(table_img, (100, 150))
            instruction = "Extract the test results for HGB and PLT in JSON format."
            answer_dict = {
                "HGB": table_data['values'][1][1],
                "PLT": table_data['values'][2][1]
            }
            answer = json.dumps(answer_dict)

        elif task == 'graph_qa':
            dates = [f'07-{15+d}' for d in range(5)]
            glucose_levels = [random.randint(90, 180) for _ in dates]
            graph_data = {'x': dates, 'y': glucose_levels}
            graph_img = generate_graph_image(graph_data, "Fasting Glucose Trend (mg/dL)", "Date (July 2025)", "Glucose Level")
            canvas.paste(graph_img, (100, 120))

            # Ask about one randomly chosen point of the plotted series.
            qa_idx = random.randint(0, len(dates) - 1)
            qa_date = dates[qa_idx]
            qa_glucose = glucose_levels[qa_idx]

            instruction = f"What was the patient's fasting glucose level on {qa_date}?"
            answer = f"{qa_glucose} mg/dL"

        augmented_image_np = augmentations(image=np.array(canvas))['image']
        final_image = Image.fromarray(augmented_image_np)

        dataset.append({"image": final_image, "instruction": instruction, "answer": answer})

    return Dataset.from_list(dataset)

# --- Generate the dataset ---
NUM_SAMPLES = 100  # how many synthetic records to create
my_dataset = create_instruction_dataset(NUM_SAMPLES)
print("\n✅ Dataset generation complete.")

Generating 100 synthetic samples...

✅ Dataset generation complete.


In [4]:
output_dir = "/content/generated_data/"
images_dir, labels_dir = (os.path.join(output_dir, sub) for sub in ("images", "labels"))

for target_dir in (images_dir, labels_dir):
    os.makedirs(target_dir, exist_ok=True)

# --- Write every sample as sample_<n>.png plus a matching sample_<n>.json ---
# Numbering is 1-based so that the first pair is sample_1.*.
for index, record in enumerate(tqdm(my_dataset, desc="Saving samples"), start=1):
    stem = f"sample_{index}"

    # Image goes under images/, its instruction/answer pair under labels/.
    record['image'].save(os.path.join(images_dir, f"{stem}.png"))

    payload = {
        "instruction": record['instruction'],
        "answer": record['answer']
    }
    with open(os.path.join(labels_dir, f"{stem}.json"), 'w') as label_file:
        json.dump(payload, label_file, indent=4)

print(f"\n✅ All {len(my_dataset)} samples have been saved to {output_dir}")

Saving samples:   0%|          | 0/100 [00:00<?, ?it/s]


✅ All 100 samples have been saved to /content/generated_data/


In [5]:
# Uninstall libraries to ensure a clean environment
# (removes the preinstalled torch / transformers stack before pinning versions).
!pip uninstall -y torch torchvision torchaudio transformers accelerate bitsandbytes peft

# Install compatible libraries for your GPU
# torch/torchvision/torchaudio come from the CUDA 12.1 wheel index; the Hugging
# Face libraries are pinned to versions chosen to work together with them.
!pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121
!pip install transformers==4.41.2 accelerate==0.30.1 bitsandbytes==0.43.1 peft==0.11.1

# The runtime must be restarted before importing the freshly installed packages.
# NOTE(review): the AttributeError recorded by the version-check cell further
# down is presumably the result of skipping that restart — confirm.
print("✅ Installation complete. Please restart the runtime now.")

Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Found existing installation: transformers 4.52.4
Uninstalling transformers-4.52.4:
  Successfully uninstalled transformers-4.52.4
Found existing installation: accelerate 1.8.1
Uninstalling accelerate-1.8.1:
  Successfully uninstalled accelerate-1.8.1
[0mFound existing installation: peft 0.15.2
Uninstalling peft-0.15.2:
  Successfully uninstalled peft-0.15.2
Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.3.0
  Downloading https://download.pytorch.org/whl/cu121/torch-2.3.0%2Bcu121-cp311-cp311-linux_x86_64.whl (781.0 MB)
[2K     [90m

In [6]:
# Check which NVIDIA GPU (if any) this runtime exposes; the inference cell at
# the end of the notebook moves its inputs to "cuda".
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


In [7]:
import torch
import transformers
import bitsandbytes
import peft
import accelerate

print("--- Library Versions ---")
# Report each library's version in a fixed display order.
for library_label, library_module in (
    ("Torch", torch),
    ("Transformers", transformers),
    ("Bitsandbytes", bitsandbytes),
    ("PEFT", peft),
    ("Accelerate", accelerate),
):
    print(f"{library_label}: {library_module.__version__}")

AttributeError: LOCAL_NN_MODULE

In [None]:
import sys

# Report the interpreter build the kernel is running on.
python_version = sys.version
print("Python version: " + python_version)

In [None]:
# Replace any existing flash-attn with a pinned release.  --no-build-isolation
# makes pip build against the packages already installed in this environment.
# NOTE(review): the model-loading cell passes attn_implementation="eager", so
# it is unclear whether flash-attn is actually used afterwards — confirm.
!pip uninstall flash-attn -y
!pip install flash-attn==2.5.8 --no-build-isolation # Configured according to CUDA version

In [None]:
# !pip list

In [None]:
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from transformers.modeling_utils import PreTrainedModel

model_id = "microsoft/Florence-2-base"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# 4-bit NF4 quantization with double quantization; computation runs in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the quantized base model.  "eager" selects the standard attention
# implementation.
print("Loading base model with 4-bit quantization...")
model_load_kwargs = dict(
    quantization_config=quantization_config,
    trust_remote_code=True,
    attn_implementation="eager",
)
model = AutoModelForCausalLM.from_pretrained(model_id, **model_load_kwargs)
print("✅ Base model loaded.")

# LoRA adapter settings: rank 16, alpha 16, 5% dropout, no bias terms trained.
# NOTE(review): adapters are attached only to submodules whose names match
# target_modules — verify with print_trainable_parameters() below that every
# listed name ("o_proj", "dense", ...) actually exists in this architecture.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "dense"],
)

# Wrap the base model with the adapter.
print("\nApplying LoRA configuration...")
peft_model = get_peft_model(model, lora_config)
print("✅ LoRA model ready for training.")

# Show how many parameters the adapter leaves trainable.
print("\nTrainable parameters:")
peft_model.print_trainable_parameters()

In [None]:
import os
import json
import glob
from PIL import Image
from sklearn.model_selection import train_test_split

# First, let's examine the structure of your JSON files
def examine_json_structure(data_dir="/content/generated_data"):
    """Print and return the contents of ``labels/sample_1.json`` under ``data_dir``.

    Args:
        data_dir: Dataset root containing a ``labels`` sub-directory.

    Returns:
        The parsed JSON of ``sample_1.json``, or ``None`` when that file (or
        the whole ``labels`` directory) does not exist.
    """
    labels_dir = os.path.join(data_dir, "labels")
    sample_json = os.path.join(labels_dir, "sample_1.json")

    if os.path.exists(sample_json):
        with open(sample_json, 'r', encoding='utf-8') as f:
            sample_data = json.load(f)
        print("Sample JSON structure:")
        print(json.dumps(sample_data, indent=2))
        return sample_data

    # Listing a missing directory would raise FileNotFoundError, which defeats
    # the purpose of a diagnostic helper — report it instead.
    if not os.path.isdir(labels_dir):
        print(f"Labels directory not found: {labels_dir}")
        return None

    print("sample_1.json not found. Available files:")
    print(sorted(os.listdir(labels_dir))[:5])
    return None

# Load and prepare your dataset
def load_florence2_data(data_dir="/content/generated_data", task_type="OCR"):
    """
    Load your dataset for Florence-2 training

    Pairs every ``images/*.png`` with the same-named ``labels/*.json`` and
    turns the label into a ``task_prompt`` / ``target`` pair.

    Args:
        data_dir: Path to your data directory
        task_type: Type of task - "OCR", "OD" (Object Detection), "VQA", etc.
            Any other value falls back to a "<CAPTION>" prompt with the
            stringified label as target.

    Returns:
        List of dicts with keys ``image`` (path), ``task_prompt``, ``target``
        and ``image_name``.  Images without a label file, labels lacking the
        fields the task needs, and labels that fail to load are skipped.
    """
    images_dir = os.path.join(data_dir, "images")
    labels_dir = os.path.join(data_dir, "labels")

    dataset = []

    for image_path in sorted(glob.glob(os.path.join(images_dir, "*.png"))):
        # splitext removes only the final extension; str.replace would also
        # strip a ".png" occurring in the middle of the file name.
        image_name = os.path.splitext(os.path.basename(image_path))[0]
        json_path = os.path.join(labels_dir, f"{image_name}.json")

        if not os.path.exists(json_path):
            continue

        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                label_data = json.load(f)

            if task_type == "OCR":
                # For OCR task using your answer data
                task_prompt = "<OCR>"

                # Prefer the answer, then the instruction, then the raw label.
                if 'answer' in label_data:
                    target = label_data['answer']
                elif 'instruction' in label_data:
                    target = label_data['instruction']
                else:
                    target = str(label_data)

            elif task_type == "VQA":
                # Accept either label schema: 'instruction' (what this notebook
                # writes) or 'question'.  These used to be two separate
                # `elif task_type == "VQA"` branches, the second of which could
                # never run, so 'question'-keyed labels were silently dropped.
                question_key = next(
                    (key for key in ('instruction', 'question') if key in label_data),
                    None,
                )
                if question_key is None or 'answer' not in label_data:
                    continue  # Skip if no question/answer data
                task_prompt = f"<VQA>{label_data[question_key]}"
                target = label_data['answer']

            elif task_type == "OD":
                # For Object Detection
                task_prompt = "<OD>"

                if 'objects' in label_data:
                    target = format_od_target(label_data['objects'])
                elif 'bboxes' in label_data:
                    target = format_bbox_target(label_data)
                else:
                    continue  # Skip if no object data

            else:
                # Generic captioning
                task_prompt = "<CAPTION>"
                target = str(label_data)

            dataset.append({
                'image': image_path,
                'task_prompt': task_prompt,
                'target': target,
                'image_name': image_name
            })

        except Exception as e:
            # Best effort: one unreadable label must not abort the whole load.
            print(f"Error processing {json_path}: {e}")
            continue

    print(f"Loaded {len(dataset)} samples for {task_type} task")
    return dataset

def format_od_target(objects):
    """Serialise detection objects into a space-separated location-token string.

    Each object that has both a 'bbox' and a 'label' becomes
    ``<loc_x1><loc_y1><loc_x2><loc_y2>label`` using the first four bbox
    entries exactly as given; objects missing either key are ignored.

    NOTE(review): the coordinates are interpolated as-is — no normalisation or
    quantisation happens here, so callers must supply values already in the
    form the model expects.  Confirm the token order against the Florence-2
    documentation as well.
    """
    def _render(entry):
        x1, y1, x2, y2 = (entry['bbox'][position] for position in range(4))
        return f"<loc_{x1}><loc_{y1}><loc_{x2}><loc_{y2}>{entry['label']}"

    pieces = [_render(entry) for entry in objects if 'bbox' in entry and 'label' in entry]
    return " ".join(pieces)

def format_bbox_target(label_data):
    """Return the whole label record rendered with ``str()``.

    Placeholder formatter: it does not interpret the bbox fields at all, so
    adapt it to the actual bbox schema before relying on its output.
    """
    rendered = str(label_data)
    return rendered

# Split dataset into train/validation
def split_dataset(dataset, test_size=0.2, random_state=42):
    """Shuffle ``dataset`` and split it into (train, eval) parts.

    Datasets with fewer than two items are not split: the input is returned
    unchanged as the training part together with an empty eval list.
    """
    if len(dataset) <= 1:
        print("Dataset too small to split. Using all data for training.")
        return dataset, []

    train_data, eval_data = train_test_split(
        dataset,
        shuffle=True,
        random_state=random_state,
        test_size=test_size,
    )

    for part_label, part in (("Train", train_data), ("Eval", eval_data)):
        print(f"{part_label} samples: {len(part)}")

    return train_data, eval_data

# Main function to prepare your data
def prepare_training_data(data_dir="/content/generated_data", task_type="OCR"):
    """
    Inspect, load and split the dataset in one call.

    Prints a sample label file, loads all samples for ``task_type``, shows the
    first loaded sample and returns ``(train_data, eval_data)``.  Returns
    ``(None, None)`` when nothing could be loaded.
    """
    print("Examining JSON structure...")
    examine_json_structure(data_dir)

    print(f"\nLoading data for {task_type} task...")
    dataset = load_florence2_data(data_dir, task_type)

    if not dataset:
        print("No data loaded! Please check your JSON structure and task_type.")
        return None, None

    first_sample = dataset[0]
    print("\nSample data point:")
    print(f"Image: {first_sample['image']}")
    print(f"Task prompt: {first_sample['task_prompt']}")
    print(f"Target: {first_sample['target'][:100]}...")  # First 100 chars

    print("\nSplitting dataset...")
    train_data, eval_data = split_dataset(dataset)
    return train_data, eval_data

# Usage example:
if __name__ == "__main__":
    # Peek at one label file before loading anything
    examine_json_structure()

    # Build the OCR train/eval lists (pass "OD", "VQA", etc. for other tasks)
    train_data, eval_data = prepare_training_data("/content/generated_data", "OCR")

In [None]:
# !ls -R /content/generated_data/

In [None]:
# Dependencies of the Tesseract-based labelling cell below: the `tesseract-ocr`
# system package and the `pytesseract` Python wrapper that cell imports.
!apt-get update && apt-get install -y tesseract-ocr
!pip install pytesseract

In [None]:
import os
import json
import glob
from PIL import Image
import pytesseract

def create_ocr_labels_from_images(data_dir="/content/generated_data"):
    """
    Create OCR labels by extracting all text from images using Tesseract

    For every ``images/*.png`` under ``data_dir`` a JSON file with the same
    stem is written to ``ocr_labels/`` containing the image path, the
    whitespace-normalised text and the extraction method.  Images that fail
    are reported and skipped.

    NOTE(review): these labels are Tesseract's reading of the augmented
    images, not the text the generator actually rendered, so they may contain
    recognition errors — confirm this is acceptable as training ground truth.

    Returns:
        Path of the ``ocr_labels`` directory.
    """
    images_dir = os.path.join(data_dir, "images")
    ocr_labels_dir = os.path.join(data_dir, "ocr_labels")

    # Create OCR labels directory
    os.makedirs(ocr_labels_dir, exist_ok=True)

    # Get all image files
    image_files = sorted(glob.glob(os.path.join(images_dir, "*.png")))

    print("Extracting OCR text from images...")

    for i, image_path in enumerate(image_files):
        try:
            # Close each image as soon as it has been read; without the context
            # manager one file handle per image stays open until collected.
            with Image.open(image_path) as image:
                extracted_text = pytesseract.image_to_string(image, config='--psm 6')

            # Collapse every run of whitespace (newlines, tabs, form feeds,
            # repeated spaces) into a single space and trim both ends.
            cleaned_text = " ".join(extracted_text.split())

            # splitext removes only the final extension, unlike str.replace.
            image_name = os.path.splitext(os.path.basename(image_path))[0]
            ocr_label_path = os.path.join(ocr_labels_dir, f"{image_name}.json")

            ocr_data = {
                "image": image_path,
                "ocr_text": cleaned_text,
                "extraction_method": "tesseract"
            }

            with open(ocr_label_path, 'w', encoding='utf-8') as f:
                json.dump(ocr_data, f, indent=2)

            # Progress report every tenth image.
            if i % 10 == 0:
                print(f"Processed {i+1}/{len(image_files)} images")
                print(f"Sample text from {image_name}: {cleaned_text[:100]}...")

        except Exception as e:
            # Best effort: keep going when a single image cannot be processed.
            print(f"Error processing {image_path}: {e}")

    print(f"OCR labels created in {ocr_labels_dir}")
    return ocr_labels_dir

def load_ocr_dataset(data_dir="/content/generated_data", use_generated_ocr=False):
    """
    Build the list of ``<OCR>`` training samples for ``data_dir``.

    Args:
        data_dir: Path to your data directory
        use_generated_ocr: When True, targets come from the ``ocr_text`` field
            of ``ocr_labels/*.json`` (created on demand if that directory is
            missing); when False, from the ``answer`` field of ``labels/*.json``.

    Returns:
        List of dicts with keys ``image``, ``task_prompt``, ``target`` and
        ``image_name``.  Images without a label, without the needed field, or
        with empty/blank text are left out.
    """
    images_dir = os.path.join(data_dir, "images")

    # Pick the label directory and the JSON field holding the text.
    if use_generated_ocr:
        labels_dir = os.path.join(data_dir, "ocr_labels")
        if not os.path.exists(labels_dir):
            print("OCR labels not found. Creating them...")
            create_ocr_labels_from_images(data_dir)
        text_key = 'ocr_text'
    else:
        labels_dir = os.path.join(data_dir, "labels")
        text_key = 'answer'

    samples = []

    for image_path in sorted(glob.glob(os.path.join(images_dir, "*.png"))):
        image_name = os.path.basename(image_path).replace('.png', '')
        json_path = os.path.join(labels_dir, f"{image_name}.json")

        if not os.path.exists(json_path):
            continue

        try:
            with open(json_path, 'r') as f:
                label_data = json.load(f)

            if text_key not in label_data:
                continue
            ocr_text = label_data[text_key]

            # Nothing to learn from an empty or whitespace-only target.
            if not ocr_text or ocr_text.strip() == "":
                continue

            samples.append({
                'image': image_path,
                'task_prompt': '<OCR>',
                'target': ocr_text.strip(),
                'image_name': image_name
            })

        except Exception as e:
            print(f"Error processing {json_path}: {e}")
            continue

    print(f"Loaded {len(samples)} OCR samples")
    return samples

# Install tesseract if needed
def install_tesseract():
    """Install Tesseract OCR if not available.

    Probes pytesseract on a blank image; if the import or the call fails, the
    `tesseract-ocr` system package and the `pytesseract` wrapper are
    installed.  Success is only reported when both commands exit with 0.
    """
    try:
        import pytesseract
        pytesseract.image_to_string(Image.new('RGB', (100, 100), 'white'))
        print("✅ Tesseract already available")
        return
    except Exception:
        # Covers a missing Python package as well as a missing tesseract
        # binary.  A bare `except:` would also swallow KeyboardInterrupt.
        print("Installing Tesseract OCR...")

    import sys

    apt_status = os.system("apt-get update && apt-get install -y tesseract-ocr")
    # Use the kernel's own interpreter so the package lands in the environment
    # this notebook is actually running in.
    pip_status = os.system(f'"{sys.executable}" -m pip install pytesseract')

    if apt_status == 0 and pip_status == 0:
        print("✅ Tesseract installed")
    else:
        print(f"❌ Tesseract installation failed (apt exit status {apt_status}, pip exit status {pip_status})")

In [None]:
# Load the OCR dataset you just created
from sklearn.model_selection import train_test_split

# Load OCR dataset with proper labels (the Tesseract-generated ones)
ocr_dataset = load_ocr_dataset("/content/generated_data", use_generated_ocr=True)

# Hold out 20% of the samples for evaluation, with a fixed seed
train_data, eval_data = train_test_split(ocr_dataset, random_state=42, test_size=0.2)

print(f"OCR Training samples: {len(train_data)}")
print(f"OCR Eval samples: {len(eval_data)}")

# Eyeball up to three training examples
for example_number, example in enumerate(train_data[:3], start=1):
    print(f"\nOCR Example {example_number}:")
    print(f"Image: {os.path.basename(example['image'])}")
    print(f"Task: {example['task_prompt']}")
    print(f"OCR Text: {example['target'][:100]}...")  # First 100 chars

In [None]:
from datasets import Dataset
import pandas as pd
from transformers import TrainingArguments, Trainer
import os
import json
from typing import Any, Dict, List
from PIL import Image

# --- Load your saved dataset ---
def load_saved_dataset(image_dir, label_dir):
    """Rebuild a ``datasets.Dataset`` from the saved PNG/JSON pairs.

    Every ``*.json`` in ``label_dir`` whose same-named ``.png`` exists in
    ``image_dir`` yields one record holding the decoded RGB image together
    with the label's ``instruction`` and ``answer``.
    """
    records = []
    for label_file in sorted(os.listdir(label_dir)):
        if not label_file.endswith('.json'):
            continue

        image_path = os.path.join(image_dir, label_file.replace('.json', '.png'))
        if not os.path.exists(image_path):
            continue  # label without an image

        with open(os.path.join(label_dir, label_file), 'r') as f:
            label_data = json.load(f)

        # Store the decoded PIL image itself rather than its path
        records.append({
            "image": Image.open(image_path).convert("RGB"),
            "instruction": label_data["instruction"],
            "answer": label_data["answer"]
        })

    return Dataset.from_list(records)

# Create a proper train/validation split
images_dir = "/content/generated_data/images"
labels_dir = "/content/generated_data/labels"
full_dataset = load_saved_dataset(images_dir, labels_dir)

# Stored as `dataset_splits`: the previous name `split_dataset` overwrote the
# split_dataset() helper defined earlier in this notebook, which broke
# prepare_training_data() for the rest of the session.  The fixed seed makes
# the held-out 10% the same on every run.
dataset_splits = full_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_splits['train']
eval_dataset = dataset_splits['test']

print(f"✅ Loaded and split dataset: {len(train_dataset)} training samples, {len(eval_dataset)} evaluation samples.")

# --- Custom Data Collator ---
class FlorenceDataCollator:
    """Batch ``{"image", "instruction", "answer"}`` records for seq2seq training.

    The prompt (instruction) and the image form the model input; the answer
    alone is tokenised into ``labels``.  Previously the answer was appended to
    the input text and ``labels`` was a copy of ``input_ids``, which trains an
    encoder-decoder model to echo its own input instead of answering.

    Args:
        processor: Callable as ``processor(text=..., images=..., ...)`` and
            exposing a ``tokenizer`` with a ``pad_token_id``.
        max_length: Truncation length for both prompts and answers.
    """

    # Text placed before every instruction; reuse the same template at inference.
    PROMPT_TEMPLATE = "Analyze the image and respond to the following task.\n{instruction}"

    def __init__(self, processor, max_length=1024):
        self.processor = processor
        self.max_length = max_length

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        prompts = [
            self.PROMPT_TEMPLATE.format(instruction=feature["instruction"])
            for feature in features
        ]
        answers = [feature["answer"] for feature in features]
        images = [feature["image"] for feature in features]

        inputs = self.processor(
            text=prompts,
            images=images,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )

        # Manually cast pixel_values to float16 to prevent dtype errors
        inputs['pixel_values'] = inputs['pixel_values'].to(torch.float16)

        tokenizer = self.processor.tokenizer
        labels = tokenizer(
            answers,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_token_type_ids=False,
        )["input_ids"]
        # -100 is the index ignored by the cross-entropy loss, so padded
        # positions do not contribute to training.
        inputs["labels"] = labels.masked_fill(labels == tokenizer.pad_token_id, -100)

        return inputs

data_collator = FlorenceDataCollator(processor)

# --- Training ---
# Effective batch size is 1 x 8 (per-device batch x gradient accumulation).
# remove_unused_columns=False keeps the raw image/instruction/answer columns,
# which the custom collator needs.
training_args = TrainingArguments(
    output_dir="florence2_healthcare_finetune",
    num_train_epochs=15,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    warmup_steps=10,
    learning_rate=1e-5,
    weight_decay=0.01,
    logging_steps=50, # Log less frequently to see progress bar
    save_total_limit=2,
    do_eval=True, # Enable evaluation
    # eval_steps only takes effect with a step-based strategy; with the default
    # strategy ("no") evaluation never ran during training.
    eval_strategy="steps",
    eval_steps=50, # Evaluate every 50 steps
    fp16=True,
    push_to_hub=False,
    remove_unused_columns=False,
)

# Use the standard transformers.Trainer
trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

print("\nStarting training...")
trainer.train()
print("✅ Training complete.")

# Save the final LoRA adapter
trainer.save_model("florence2_healthcare_final_adapter")
print("✅ Final LoRA adapter saved to 'florence2_healthcare_final_adapter'")

In [None]:
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import pandas as pd
import os
import json
from PIL import Image
import random

# --- Configuration ---
# The original base model ID (same spelling as in the training cell)
model_id = "microsoft/Florence-2-base"

# The path where you saved your fine-tuned adapter
adapter_path = "florence2_healthcare_final_adapter"

# Text the training collator places before every instruction.  Inference must
# build its prompt the same way; the "<doc_ocr>" token used here before never
# appeared in the training data.
TRAINING_TASK_PROMPT = "Analyze the image and respond to the following task."

# --- Load the Fine-Tuned Model for Inference ---
# You only need to load the processor once
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Same quantization settings as the training cell (float16 compute dtype and
# eager attention); this cell previously used bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    trust_remote_code=True,
    attn_implementation="eager",
)

# Attach the fine-tuned LoRA adapter to the base model
finetuned_model = PeftModel.from_pretrained(base_model, adapter_path)
finetuned_model.eval()
print("✅ Fine-tuned model loaded for inference.")


# --- Run a Test Prediction ---
# Select a random sample from the unseen evaluation set
sample = eval_dataset[random.randint(0, len(eval_dataset) - 1)]
image = sample['image']
ground_truth = sample['answer']
# Rebuild the prompt exactly as it was built for training
instruction_prompt = f"{TRAINING_TASK_PROMPT}\n{sample['instruction']}"

# Send the inputs to wherever the model lives instead of hard-coding "cuda",
# and cast pixel_values to float16 as the training collator does.
device = next(finetuned_model.parameters()).device
inputs = processor(text=instruction_prompt, images=image, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
pixel_values = inputs["pixel_values"].to(device, torch.float16)

# --- Generate the Output ---
print("\nGenerating prediction...")
with torch.no_grad():
    generated_ids = finetuned_model.generate(
        input_ids=input_ids,
        pixel_values=pixel_values,
        max_new_tokens=1024,
        num_beams=3,
        early_stopping=True
    )

# The answers are free text, so plain decoding with special tokens removed is
# all the post-processing that is needed.
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
predicted_answer = generated_text.strip()
# Kept as a dict keyed by the prompt, mirroring the previous shape of this variable.
parsed_answer = {instruction_prompt: predicted_answer}


# --- Display Results ---
print("\n--- Inference Result ---")
print(f"Task Prompt: {instruction_prompt}")
display(image)
print(f"\nGround Truth Answer:\n{ground_truth}")
print(f"\nModel's Predicted Answer:\n{predicted_answer}")

The model didn't learn anything — a classic case of underfitting. \
Next step: increase the number of training samples. \
Next step: train on Kaggle.