# Florence 2 Object Detection Finetuning

In [1]:
# - Packages
import numpy as np
import torch
import cv2
import os

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import sys
import os
import numpy as np
import torch
import matplotlib.pyplot as plt

import torchvision.transforms.functional as F
from torchvision.ops import masks_to_boxes
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from torch.utils.data import Dataset, DataLoader

import torch
import supervision as sv

from transformers import (
    AdamW,
    AutoModelForCausalLM,
    AutoProcessor,
    get_scheduler
)

from peft import LoraConfig, get_peft_model

In [2]:
# - Global Variables
# Dataset layout: raw images and their segmentation masks live in sibling
# folders underneath data_dir.
data_dir= Path("./snemi/" )
raw_image_dir = data_dir / 'image_pngs'
seg_image_dir = data_dir / 'seg_pngs'
# NOTE(review): itrs is not referenced by any of the visible cells.
itrs = 10000

# Alternative checkpoint / revision pair, kept for reference:
# CHECKPOINT = "microsoft/Florence-2-base-ft"
# REVISION = 'refs/pr/6'

# Hub checkpoint and the revision that both the model and the processor are
# loaded from below.
CHECKPOINT = "microsoft/Florence-2-large-ft"
REVISION = 'refs/pr/19'


# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Upper bound on the number of boxes written into one target string
# (normalize_loc stops once this many have been emitted).
label_num = 170

## - Define Dataloader
BATCH_SIZE = 2
NUM_WORKERS = 0

## - Lora config
# rank -> LoraConfig(r=...), alpha -> LoraConfig(lora_alpha=...).
rank = 8
alpha = 8

# Training schedule passed to train_model(epochs=..., lr=...).
EPOCHS = 10000
LR = 5e-6

## - Model Initialization

# trust_remote_code=True lets transformers run the modelling code shipped in
# the checkpoint repository; REVISION selects which revision is fetched.
model = AutoModelForCausalLM.from_pretrained(
    CHECKPOINT, trust_remote_code=True, revision=REVISION).to(DEVICE)
processor = AutoProcessor.from_pretrained(
    CHECKPOINT, trust_remote_code=True, revision=REVISION)


## Prepare the Data and Initialize the Dataset and DataLoader

In [3]:
# - Finetuning Dataset Processing
## - Prepare dataset
# Pair every raw image with its segmentation mask. The records are built from
# the files that actually match 'image*.png' (sorted, so the order is
# deterministic) instead of from a running index over os.listdir, which
# produced paths to nonexistent files whenever the folder held anything else
# (e.g. hidden files or checkpoint folders).
NUM_TRAIN_IMAGES = 80  # the first 80 records train, the remainder validates

data = []
for image_path in sorted(raw_image_dir.glob('image*.png')):
    # image0007.png -> seg0007.png
    seg_name = image_path.name.replace('image', 'seg', 1)
    data.append({'image': image_path, 'annotation': seg_image_dir / seg_name})
# - split train dataset and validation dataset
valid_data = data[NUM_TRAIN_IMAGES:]
data = data[:NUM_TRAIN_IMAGES]

## - Convert Mask to Bounding Boxes
def convert_mask2box(mask:np.ndarray):
    """Return one bounding box per labelled instance in a segmentation mask.

    Args:
        mask: integer label image; every distinct non-zero value is treated as
            one instance. Label 0 is treated as background
            (assumption carried over from the original code - TODO confirm).

    Returns:
        Array of shape (num_instances, 4) with the boxes produced by
        ``masks_to_boxes`` in (x1, y1, x2, y2) order, or an empty (0, 4)
        float32 array when the mask contains no instance.
    """
    inds = np.unique(mask)
    # Drop the background explicitly. Slicing off the first unique value would
    # silently discard a real instance whenever the mask has no 0 pixels.
    inds = inds[inds != 0]

    if inds.size == 0:
        return np.zeros((0, 4), dtype=np.float32)

    # One boolean mask per instance, stacked along a new leading axis.
    masks = np.stack([mask == ind for ind in inds])
    masks_tensor = torch.from_numpy(masks)

    boxes = masks_to_boxes(masks_tensor)
    return boxes.numpy()

## - normalize location
def normalize_loc(prefix:str, instance_type:str, image_path:str, mask:np.ndarray, input_boxes:np.ndarray, max_labels=None):
    """Turn pixel-space boxes into a location-token target string.

    Args:
        prefix: task prompt stored with the sample (e.g. "<OD>").
        instance_type: class name written in front of every box.
        image_path: path of the image the boxes belong to; passed through.
        mask: label image the boxes were derived from; only its shape is used
            (rows = height, columns = width).
        input_boxes: boxes in (x1, y1, x2, y2) pixel coordinates.
        max_labels: maximum number of boxes to encode. Defaults to the
            notebook-level ``label_num`` so the target stays within the
            sequence-length budget.

    Returns:
        Dict with keys "image", "prefix" and "suffix"; the suffix concatenates
        ``{instance_type}<loc_x1><loc_y1><loc_x2><loc_y2>`` for each box.
    """
    if max_labels is None:
        max_labels = label_num

    # x coordinates run along the columns (width) and y along the rows
    # (height). The previous code divided x by the row count and y2 by the x
    # resolution, which is only harmless for square images.
    height, width = mask.shape[:2]
    scale = np.array([width, height, width, height], dtype=np.float64)

    boxes = np.asarray(input_boxes, dtype=np.float64).reshape(-1, 4)[:max_labels]
    # Quantize to 1000 bins and keep the index inside the 0..999 range of the
    # location tokens.
    bins = np.clip(np.rint(boxes / scale * 1000), 0, 999).astype(int).tolist()

    suffix = ''.join(
        f"{instance_type}<loc_{x1}><loc_{y1}><loc_{x2}><loc_{y2}>"
        for x1, y1, x2, y2 in bins
    )

    return {"image": image_path,"prefix": prefix, "suffix": suffix }

## - Prepare all training dataset and validation dataset
def prepare_dataset(data, instance_type, prefix):
    """Convert image/annotation records into prompt/target training records.

    Args:
        data: iterable of dicts with an 'image' path and an 'annotation'
            (segmentation PNG) path.
        instance_type: class name written in front of every box.
        prefix: task prompt stored with each record.

    Returns:
        List with one ``normalize_loc`` result per input record, in order.
    """
    def _curate(record):
        # Load the label image, derive the boxes, then encode them as text.
        picture_path = record['image']
        label_image = np.array(Image.open(record['annotation']))
        boxes = convert_mask2box(label_image)
        return normalize_loc(prefix, instance_type, picture_path, label_image, boxes)

    return [_curate(record) for record in data]

# Build the prompt/target records for both splits: every instance is labelled
# 'neuron' and every record carries the "<OD>" prompt.
train_dataset = prepare_dataset(data, 'neuron', "<OD>")
val_dataset = prepare_dataset(valid_data, 'neuron', "<OD>")


In [4]:
# - Initialize Dataset and Dataloader

## - Detection Dataset Class (Dataset Preparation)
class DetectionDataset(Dataset):
    """Torch dataset over the records produced by ``prepare_dataset``.

    Each record is a dict with an 'image' path, a 'prefix' prompt and a
    'suffix' target string; the image is read from disk on access.
    """

    def __init__(self, dataset):
        # Records are kept as-is; images are only loaded in __getitem__.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(prefix, suffix, image)`` for record ``idx``.

        Raises:
            FileNotFoundError: if the image cannot be read. ``cv2.imread``
                reports that by returning ``None`` instead of raising, which
                would otherwise surface later as an obscure processor error.
        """
        data = self.dataset[idx]
        image_path = str(data['image'])
        # NOTE(review): cv2.imread yields BGR channel order; irrelevant for
        # grayscale PNGs, but confirm before training on colour images.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"could not read image: {image_path}")
        return data['prefix'], data['suffix'], image

## - Define Dataloader

def collate_fn(batch):
    """Collate ``(prefix, suffix, image)`` samples into one model-ready batch.

    Returns:
        Tuple ``(inputs, answers)``: the processor output for the prompts and
        images, moved to DEVICE, and the untouched tuple of target strings.
    """
    prompts, targets, pictures = zip(*batch)
    encoded = processor(
        text=list(prompts),
        images=list(pictures),
        return_tensors="pt",
        padding=True,
    )
    return encoded.to(DEVICE), targets

# Wrap the record lists in DetectionDataset objects.
# NOTE(review): the names train_dataset / val_dataset are rebound from record
# lists to datasets here, so re-running this cell alone wraps a
# DetectionDataset inside another one; re-run the preparation cell first.
train_dataset = DetectionDataset(train_dataset)
val_dataset = DetectionDataset(val_dataset)

# Both loaders batch through collate_fn and keep the dataset order
# (shuffle=False).
# NOTE(review): the training loader is not shuffled either - confirm this is
# intended.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=NUM_WORKERS, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=NUM_WORKERS, shuffle=False)

## Image Visualization and Finetuning Process

In [5]:
# - Lora Finetuning Configuration

# Module names handed to LoraConfig(target_modules=...); the matching
# sub-modules of the base model receive LoRA adapters.
# NOTE(review): presumably matched against module-name suffixes - confirm
# against the PEFT documentation that every entry actually matches a layer.
TARGET_MODULES = [
    "q_proj", "o_proj", "k_proj", "v_proj", 
    "linear", "Conv2d", "lm_head", "fc2"
]

# rank / alpha come from the global configuration cell.
# NOTE(review): revision=REVISION is also forwarded to LoraConfig - confirm it
# is needed here in addition to the from_pretrained calls.
config = LoraConfig(
    r=rank,
    lora_alpha=alpha,
    target_modules=TARGET_MODULES,
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
    bias="none",
    inference_mode=False,
    use_rslora=True,
    init_lora_weights="gaussian",
    revision=REVISION
)

# Wrap the base model with the adapters and report how many parameters are
# trainable (the recorded output is shown right below this cell).
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()

trainable params: 4,133,576 || all params: 826,827,464 || trainable%: 0.4999


In [6]:
# - Image HTML Visualization
# @title Run inference with pre-trained Florence-2 model on validation dataset
import io
import base64
import html
import json
from IPython.display import HTML
def render_inline(image: "Image.Image", resize=(128, 128)):
    """Encode an image as an inline ``data:image/jpeg;base64,...`` URI.

    Args:
        image: PIL image to embed.
        resize: ``(width, height)`` the image is scaled to before encoding.

    Returns:
        The data-URI string, usable directly as an ``<img src=...>`` value.
    """
    # Image.resize returns a new image and leaves the original untouched; the
    # result has to be kept, otherwise the full-size image is what gets
    # encoded.
    thumbnail = image.resize(resize)
    with io.BytesIO() as buffer:
        thumbnail.save(buffer, format='jpeg')
        image_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{image_b64}"


def render_example(image: Image.Image, response):
    """Build the HTML snippet for one prediction.

    The parsed detections are drawn (boxes and labels) on a copy of the image;
    the snippet shows that image next to the JSON-encoded raw response.

    Args:
        image: PIL image the response refers to.
        response: post-processed Florence-2 output for the image.

    Returns:
        HTML string with the inline image and the escaped response.
    """
    try:
        detections = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, response, resolution_wh=image.size)
        image = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX).annotate(image.copy(), detections)
        image = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX).annotate(image, detections)
    except Exception as exc:
        # Best effort: an unparsable response falls back to the plain image.
        # Catch Exception only, so that interrupting the kernel still works,
        # and say what went wrong.
        print(f'failed to render model response: {exc}')
    return f"""
<div style="display: inline-flex; align-items: center; justify-content: center;">
    <img style="width:256px; height:256px;" src="{render_inline(image, resize=(128, 128))}" />
    <p style="width:512px; margin:10px; font-size:small;">{html.escape(json.dumps(response))}</p>
</div>
"""


def render_inference_results(model, dataset: DetectionDataset, count: int):
    """Run detection on the first ``count`` samples and display the results.

    Args:
        model: model exposing ``generate`` (here the PEFT-wrapped Florence-2).
        dataset: DetectionDataset whose underlying records are read directly.
        count: number of samples to render; capped at ``len(dataset)``.

    Raises:
        FileNotFoundError: if an image cannot be read from disk.
    """
    html_out = ""
    count = min(count, len(dataset))
    for i in range(count):
        data = dataset.dataset[i]
        image_path = str(data['image'])
        pixels = cv2.imread(image_path)
        if pixels is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read image: {image_path}")
        image = Image.fromarray(pixels)
        prefix = data['prefix']
        inputs = processor(text=prefix, images=image, return_tensors="pt").to(DEVICE)
        # Inference only: no autograd graph is needed, including when this is
        # called outside the validation loop.
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3
            )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        answer = processor.post_process_generation(generated_text, task='<OD>', image_size=image.size)
        html_out += render_example(image, answer)

    display(HTML(html_out))

# Baseline: render predictions for the first 4 validation samples before any training step has run.
render_inference_results(peft_model, val_dataset, 4)

In [7]:
# - Training Process
# - define training loop
from tqdm import tqdm
def _batch_loss(model, processor, inputs, answers):
    """Tokenize the target strings and return the model loss for one batch."""
    # NOTE(review): padded label positions are not masked here - confirm the
    # model ignores padding when computing the loss.
    labels = processor.tokenizer(
        text=answers,
        return_tensors="pt",
        padding=True,
        return_token_type_ids=False
    ).input_ids.to(DEVICE)

    outputs = model(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        labels=labels,
    )
    return outputs.loss


def train_model(train_loader, val_loader, model, processor, epochs=10, lr=1e-6,
                checkpoint_every=100, checkpoint_root="./model_checkpoints/large_model"):
    """Fine-tune ``model`` and report the train / validation loss every epoch.

    Args:
        train_loader: loader yielding ``(inputs, answers)`` training batches.
        val_loader: loader yielding ``(inputs, answers)`` validation batches.
        model: model to optimise; also the one used for the preview renders
            and the checkpoints.
        processor: processor whose tokenizer encodes the target strings.
        epochs: number of passes over ``train_loader``.
        lr: initial learning rate; decays linearly to zero without warm-up.
        checkpoint_every: render predictions and save a checkpoint every this
            many epochs; a falsy value disables both.
        checkpoint_root: directory under which ``epoch_<n>`` folders are
            written.
    """
    optimizer = AdamW(model.parameters(), lr=lr)
    num_training_steps = epochs * len(train_loader)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    # Preview before training. Uses the model that was passed in, rather than
    # the notebook-level peft_model, so the function works for any model.
    render_inference_results(model, val_loader.dataset, 6)

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for inputs, answers in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}"):
            loss = _batch_loss(model, processor, inputs, answers)

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Average Training Loss: {avg_train_loss}")

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for inputs, answers in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):
                val_loss += _batch_loss(model, processor, inputs, answers).item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Average Validation Loss: {avg_val_loss}")

        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            render_inference_results(model, val_loader.dataset, 6)

            output_dir = f"{checkpoint_root}/epoch_{epoch+1}"
            os.makedirs(output_dir, exist_ok=True)
            model.save_pretrained(output_dir)
            processor.save_pretrained(output_dir)
            
# Launch fine-tuning. With EPOCHS = 10000 the run is effectively open-ended; the recorded run below was stopped by hand (KeyboardInterrupt) during epoch 17.
train_model(train_loader, val_loader, peft_model, processor, epochs=EPOCHS, lr=LR)

Training Epoch 1/10000: 100%|██████████| 40/40 [00:44<00:00,  1.12s/it]


Average Training Loss: 5.38374844789505


Validation Epoch 1/10000: 100%|██████████| 10/10 [00:05<00:00,  1.68it/s]


Average Validation Loss: 4.256768083572387


Training Epoch 2/10000: 100%|██████████| 40/40 [00:45<00:00,  1.13s/it]


Average Training Loss: 4.172733461856842


Validation Epoch 2/10000: 100%|██████████| 10/10 [00:05<00:00,  1.70it/s]


Average Validation Loss: 3.949125123023987


Training Epoch 3/10000: 100%|██████████| 40/40 [00:45<00:00,  1.14s/it]


Average Training Loss: 3.8831851720809936


Validation Epoch 3/10000: 100%|██████████| 10/10 [00:06<00:00,  1.66it/s]


Average Validation Loss: 3.706639266014099


Training Epoch 4/10000: 100%|██████████| 40/40 [00:45<00:00,  1.14s/it]


Average Training Loss: 3.6284681379795076


Validation Epoch 4/10000: 100%|██████████| 10/10 [00:05<00:00,  1.67it/s]


Average Validation Loss: 3.4201030254364015


Training Epoch 5/10000: 100%|██████████| 40/40 [00:46<00:00,  1.15s/it]


Average Training Loss: 3.4264903604984283


Validation Epoch 5/10000: 100%|██████████| 10/10 [00:06<00:00,  1.66it/s]


Average Validation Loss: 3.2634347677230835


Training Epoch 6/10000: 100%|██████████| 40/40 [00:45<00:00,  1.15s/it]


Average Training Loss: 3.3104514956474302


Validation Epoch 6/10000: 100%|██████████| 10/10 [00:05<00:00,  1.67it/s]


Average Validation Loss: 3.172368884086609


Training Epoch 7/10000: 100%|██████████| 40/40 [00:45<00:00,  1.15s/it]


Average Training Loss: 3.219625008106232


Validation Epoch 7/10000: 100%|██████████| 10/10 [00:06<00:00,  1.66it/s]


Average Validation Loss: 3.09794659614563


Training Epoch 8/10000: 100%|██████████| 40/40 [00:45<00:00,  1.15s/it]


Average Training Loss: 3.1438324749469757


Validation Epoch 8/10000: 100%|██████████| 10/10 [00:05<00:00,  1.67it/s]


Average Validation Loss: 3.0360220193862917


Training Epoch 9/10000: 100%|██████████| 40/40 [00:45<00:00,  1.15s/it]


Average Training Loss: 3.0845939040184023


Validation Epoch 9/10000: 100%|██████████| 10/10 [00:05<00:00,  1.68it/s]


Average Validation Loss: 2.9865982294082642


Training Epoch 10/10000: 100%|██████████| 40/40 [00:46<00:00,  1.15s/it]


Average Training Loss: 3.03839316368103


Validation Epoch 10/10000: 100%|██████████| 10/10 [00:05<00:00,  1.67it/s]


Average Validation Loss: 2.9494217157363893


Training Epoch 11/10000: 100%|██████████| 40/40 [00:46<00:00,  1.15s/it]


Average Training Loss: 2.997178840637207


Validation Epoch 11/10000: 100%|██████████| 10/10 [00:06<00:00,  1.66it/s]


Average Validation Loss: 2.922942280769348


Training Epoch 12/10000: 100%|██████████| 40/40 [00:45<00:00,  1.14s/it]


Average Training Loss: 2.9660776913166047


Validation Epoch 12/10000: 100%|██████████| 10/10 [00:06<00:00,  1.67it/s]


Average Validation Loss: 2.893644094467163


Training Epoch 13/10000: 100%|██████████| 40/40 [00:45<00:00,  1.15s/it]


Average Training Loss: 2.942171984910965


Validation Epoch 13/10000: 100%|██████████| 10/10 [00:06<00:00,  1.66it/s]


Average Validation Loss: 2.8646610975265503


Training Epoch 14/10000: 100%|██████████| 40/40 [00:45<00:00,  1.15s/it]


Average Training Loss: 2.911027139425278


Validation Epoch 14/10000: 100%|██████████| 10/10 [00:05<00:00,  1.69it/s]


Average Validation Loss: 2.842896556854248


Training Epoch 15/10000: 100%|██████████| 40/40 [00:46<00:00,  1.15s/it]


Average Training Loss: 2.8834163427352903


Validation Epoch 15/10000: 100%|██████████| 10/10 [00:06<00:00,  1.67it/s]


Average Validation Loss: 2.8268357038497927


Training Epoch 16/10000: 100%|██████████| 40/40 [00:45<00:00,  1.15s/it]


Average Training Loss: 2.8584540128707885


Validation Epoch 16/10000: 100%|██████████| 10/10 [00:06<00:00,  1.65it/s]


Average Validation Loss: 2.8128738164901734


Training Epoch 17/10000:  70%|███████   | 28/40 [00:32<00:14,  1.18s/it]


KeyboardInterrupt: 