# Introduction
This notebook provides a comprehensive guide to fine-tuning Florence-2 with a custom task called **`<WebAgent>`**, designed to detect and identify webpage elements based on natural language descriptions. The goal of this fine-tuning process is to enhance the model’s ability to recognize UI components, such as buttons, links, and images.

# Install required dependencies
Before running the next steps, install the necessary dependencies:

* **matplotlib** for visualization

* **transformers** for model loading and fine-tuning

* **datasets** for loading training data

* **peft** for LoRA fine-tuning

In [None]:
!pip install -q matplotlib
!pip install -q transformers
!pip install -q datasets

# Model
We will use Florence-2 as the base of our web agent. The model should be able to detect any item on a webpage given a description, for example, "A button that contains the text 'Login'".


> Read more about Florence-2 [here](https://www.microsoft.com/en-us/research/publication/florence-2-advancing-a-unified-representation-for-a-variety-of-vision-tasks/)



We will fine-tune the Florence-2 model from Hugging Face on our dataset.


In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

# Prefer the first CUDA GPU; fall back to the CPU when none is present.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Base checkpoint on the Hugging Face Hub that we fine-tune.
model_name = 'microsoft/Florence-2-base-ft'

# Task token that prefixes every prompt for the new web-element task.
custom_task = "<WebAgent>"

def load_model(model_path=None):
    """
    Load the Florence-2 model and its processor.

    Args:
        model_path (str, optional): Path to a locally saved model directory.
            If None, the checkpoint named by the global ``model_name`` is
            loaded from Hugging Face.

    Returns:
        tuple: ``(processor, model)`` where ``processor`` is the
        ``AutoProcessor`` handling text and images, and ``model`` is the
        ``AutoModelForCausalLM`` moved to the global ``device``.

    NOTE(review): checkpoints written by ``train_model`` come from
    ``save_pretrained`` on the PEFT-wrapped model, which presumably stores
    only the LoRA adapter; confirm ``AutoModelForCausalLM.from_pretrained``
    loads such a directory in your transformers/peft versions. The custom-task
    entries registered on the processor are presumably not persisted either
    and may need to be re-registered after loading.
    """
    # Hub and local loading differ only in where the checkpoint comes from.
    source = model_name if model_path is None else model_path

    # Florence-2 ships custom modeling/processing code, hence trust_remote_code.
    model = AutoModelForCausalLM.from_pretrained(
        source,
        trust_remote_code=True,
    ).to(device)
    processor = AutoProcessor.from_pretrained(source, trust_remote_code=True)

    return processor, model

# Load the base checkpoint (model_path=None -> the global model_name).
# `processor` and `model` are the globals used by the cells below.
processor, model = load_model()

## Processor modifications
Refer to https://huggingface.co/microsoft/Florence-2-large/blob/main/processing_florence2.py for the processor implementation.

We will add a custom task and train the model on that task. We will modify the following variables in the processor which will allow us to introduce our custom task to the existing processor logic.

In [None]:
# Register the custom task with the processor (see processing_florence2.py above):
# - answers for custom_task are parsed with the 'description_with_bboxes_or_polygons'
#   post-processing type, which presumably yields the 'bboxes'/'bboxes_labels'/
#   'polygons'/'polygons_labels' keys that plot_bbox_and_polygon reads;
# - prompts containing custom_task use this template, with {input} presumably
#   filled by the text that follows the task token (e.g. "<WebAgent>Login button"
#   -> "Locate Login button in the image.").
processor.tasks_answer_post_processing_type[custom_task] = 'description_with_bboxes_or_polygons'
processor.task_prompts_with_input[custom_task] = 'Locate {input} in the image.'

## Helper functions
Below are the functions which will be used to handle model related tasks.

### Plotting bounding boxes and polygons around the detected element

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as patches

def plot_bbox_and_polygon(image, data):
    """
    Plot bounding boxes and polygons on an image and show the figure.

    Args:
        image (PIL.Image): Image to be annotated.
        data (dict): Parsed model output, e.g. ``response[task]`` from
            ``get_model_output``. Reads ``'bboxes'`` / ``'bboxes_labels'``
            (``[x1, y1, x2, y2]`` pixel boxes and their labels) and
            ``'polygons'`` / ``'polygons_labels'`` (each entry is a list of
            flat ``[x1, y1, x2, y2, ...]`` point lists). Any missing key is
            treated as empty, so outputs with only boxes or only polygons
            can still be plotted.
    """
    _, ax = plt.subplots(figsize=(15, 10))
    ax.imshow(image)

    boxes = zip(data.get('bboxes', []), data.get('bboxes_labels', []))
    for bbox, bbox_label in boxes:
        x1, y1, x2, y2 = bbox
        rect = patches.Rectangle((x1, y1),
                                 x2 - x1,
                                 y2 - y1,
                                 linewidth=2,
                                 edgecolor='lime',
                                 facecolor='none')
        ax.add_patch(rect)
        ax.text(x1,
                y1,
                bbox_label,
                color='black',
                fontsize=8,
                bbox=dict(facecolor='lime', alpha=1))

    shapes = zip(data.get('polygons', []), data.get('polygons_labels', []))
    for polygons, polygon_label in shapes:
        for _polygon in polygons:
            # Flat [x1, y1, x2, y2, ...] coordinates -> (N, 2) array of points.
            polygon = np.array(_polygon).reshape(-1, 2)
            ax.plot(polygon[:, 0], polygon[:, 1], marker='o', linestyle='-',
                    color='red', linewidth=2)
            # Semi-transparent fill so the underlying UI stays visible.
            ax.fill(polygon[:, 0], polygon[:, 1], color='red', alpha=0.3)

            # Label the polygon at its first point.
            ax.text(polygon[0, 0], polygon[0, 1], polygon_label, color='white',
                    fontsize=8, bbox=dict(facecolor='red', alpha=0.8))

    ax.axis('off')
    plt.show()

### Output helper function

In [None]:
def get_model_output(model, processor, images, task_prompt, text_input=None):
    """
    Run the model on one image and parse the generated answer.

    Args:
        model (AutoModelForCausalLM): The (optionally fine-tuned) Florence model.
        processor (AutoProcessor): The processor that handles inputs and parses outputs.
        images (PIL.Image): Input image. Its width/height are used to map
            the predicted position tokens back to pixel coordinates.
        task_prompt (str): Task token, e.g. ``custom_task``.
        text_input (str, optional): Text appended to the task token, e.g. the
            description of the element to locate.

    Returns:
        dict: The processor's parsed response, keyed by ``task_prompt``.
    """
    prompt = task_prompt if text_input is None else task_prompt + text_input

    # Use the image that was passed in, not a global, so validation rendering
    # and the manual tests run on the intended image.
    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)

    # Inference only: disable autograd so no computation graph is retained.
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
    # Keep special tokens: the <loc_N> position tokens carry the coordinates.
    generated_text = processor.batch_decode(generated_ids,
                                            skip_special_tokens=False)[0]

    print(generated_text)

    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(images.width, images.height))

    return parsed_answer

### Bounding box helper functions.
Florence-2 represents bounding-box coordinates as position tokens of the form `<loc_N>`, where N is the coordinate quantized to the range 0–999 relative to the image size. The following functions convert pixel bounding boxes into that format.

In [None]:
def convert_bbox_to_relative(box, image):
    """
    Quantize pixel bounding-box coordinates into Florence-2's 0-999 bins.

    The image width/height is split into 1000 equal bins. Each coordinate is
    mapped to the index of the bin it falls in (floor), clamped to [0, 999].
    This is intended to mirror the floor-mode box quantizer in Florence-2's
    processor (see processing_florence2.py), whose de-quantization maps bin
    ``q`` back to the bin centre. The clamp keeps boxes that touch or
    overflow the image border on valid ``<loc_0>``..``<loc_999>`` tokens.

    Args:
        box: ``[x1, y1, x2, y2]`` in pixel coordinates.
        image: Object exposing ``width`` and ``height`` (e.g. ``PIL.Image``).

    Returns:
        list[int]: ``[x1, y1, x2, y2]`` as bin indices in ``0..999``.
    """
    bins = 1000
    sizes = (image.width, image.height, image.width, image.height)
    # Multiply before dividing to avoid float error on exact bin boundaries;
    # int() equals floor here because negative values are clamped to 0 anyway.
    return [
        min(max(int(coord * bins / size), 0), bins - 1)
        for coord, size in zip(box, sizes)
    ]

def convert_relative_to_loc(relative_coordinates):
    """
    Render quantized coordinates as a string of Florence-2 position tokens.

    Args:
        relative_coordinates: Iterable of bin indices, e.g. ``[x1, y1, x2, y2]``.

    Returns:
        str: One ``<loc_N>`` token per coordinate, concatenated in order.
    """
    token_template = "<loc_{}>"
    return "".join(token_template.format(coord) for coord in relative_coordinates)

def convert_bbox_to_loc(box, image):
    """
    Encode a pixel bounding box as Florence-2 position tokens.

    The box is quantized against ``image``'s size by ``convert_bbox_to_relative``
    and the resulting bin indices are rendered by ``convert_relative_to_loc``.
    """
    return convert_relative_to_loc(convert_bbox_to_relative(box, image))

In [None]:
def render_inference_results(model, processor, dataset, count, task):
    """
    Run the model on the first ``count`` samples of ``dataset`` and plot each prediction.

    Used to eyeball validation examples during training.

    Args:
        model: Model passed to ``get_model_output``.
        processor: Processor passed to ``get_model_output``.
        dataset: Indexable dataset whose items are ``(prefix, suffix, image)``
            tuples (e.g. ``WebUI``).
        count (int): Maximum number of samples to render; clipped to ``len(dataset)``.
        task (str): Task token; removed from each prefix to recover the description.
    """
    for idx in range(min(count, len(dataset))):
        prefix, _suffix, image = dataset[idx]
        description = prefix.replace(task, "")
        response = get_model_output(model, processor, image, task, description)
        plot_bbox_and_polygon(image, response[task])

## Run the base model
Here we will try out the model with the **`<WebAgent>`** task that we have introduced into the processor. We will draw a bounding box around the detected object using the helper functions defined above.

In [None]:
import requests
from PIL import Image
from io import BytesIO

# Fetch an example webpage screenshot and ask the base model to locate an element.
http_response = requests.get(
    "https://media.geeksforgeeks.org/wp-content/cdn-uploads/20210401151214/What-is-Website.png",
    timeout=30,
)
# Fail loudly on HTTP errors instead of handing an error page to PIL.
http_response.raise_for_status()
image = Image.open(BytesIO(http_response.content)).convert("RGB")
text_prompt = "Arrows"
response = get_model_output(model, processor, image, custom_task, text_prompt)

print("Processed Response:", response)

In [None]:
# Draw the base model's detections for the custom task on the fetched screenshot.
plot_bbox_and_polygon(image,response[custom_task])

# Dataset
We will use the wave-ui-25k dataset to train our model to detect different elements on a webpage.

**wave-ui-25k** - https://huggingface.co/datasets/agentsea/wave-ui-25k


In [None]:
from datasets import load_dataset

# Download wave-ui-25k from the Hugging Face Hub. It only provides a "train"
# split, which the next cell divides into train/validation/test.
ds = load_dataset("agentsea/wave-ui-25k")

In [None]:
# The dataset only ships a "train" split, so we carve out train / validation /
# test (80% / 10% / 10%) ourselves. Fixed seeds keep the split reproducible.
# NOTE(review): rows are split independently; if the same screenshot appears in
# several rows, the splits may share images — confirm whether grouping by image
# is needed for a leak-free test set.

# First, split into 80% train and 20% temp (validation + test)
temp_split = ds["train"].train_test_split(test_size=0.2, seed=42)

# Now, split temp into 50% validation, 50% test (resulting in 10% each)
final_split = temp_split["test"].train_test_split(test_size=0.5, seed=42)

# Assign datasets
train_split = temp_split["train"]
valid_split = final_split["train"]
test_split = final_split["test"]

# Print sizes
print(f"Train size: {len(train_split)}, Validation size: {len(valid_split)}, Test size: {len(test_split)}")

## Create the custom dataset for training
Here we will create a custom dataset that returns a prefix, suffix and the image.
* **The prefix** contains the custom task label followed by the prompt (description of the element we are trying to find in the page)
* **The suffix** is what the model should output, i.e. the element description followed by its bounding box (or polygon) as position tokens.
If you want to return polygons, you might want to wrap the suffix with a `<poly>` tag. The format would look like - `label<poly><loc_x><loc_y>...</poly>`



In [None]:
import random
from torch.utils.data import Dataset

class WebUI(Dataset):
    """
    Torch dataset that turns wave-ui rows into ``(prefix, suffix, image)`` samples.

    * prefix: ``custom_task`` followed by the element description (model input).
    * suffix: the description followed by the element's bounding box as
      ``<loc_N>`` position tokens (training target).
    * image: the screenshot converted to RGB.
    """

    def __init__(self, dataset):
        # Underlying indexable split; each row is read for its 'description',
        # 'image' and 'bbox' fields.
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        description = item['description']
        # Convert once; RGB conversion does not change width/height, so the
        # same image is used for the box quantization below.
        image = item['image'].convert("RGB")

        prefix = custom_task + description
        # item['bbox'] is assumed to be [x1, y1, x2, y2] in pixels, the format
        # convert_bbox_to_loc expects — confirm against the dataset card.
        suffix = description + convert_bbox_to_loc(item['bbox'], image)
        return prefix, suffix, image

# Fine-tuning

In [None]:
from torch.utils.data import DataLoader

# Samples per batch; tune to fit your GPU memory.
BATCH_SIZE = 4
# 0 -> batches are assembled in the main process (collate_fn also moves them to `device`).
NUM_WORKERS = 0

def collate_fn(batch):
    """
    Turn a list of ``(prefix, suffix, image)`` samples into model inputs.

    Returns:
        tuple: ``(inputs, suffixes)`` — the processor's padded tensors for the
        prefixes and images, moved to ``device``, and the raw target strings
        as a tuple (tokenized later in the training loop).
    """
    prefixes, suffixes, pictures = zip(*batch)
    encoded = processor(
        text=list(prefixes),
        images=list(pictures),
        return_tensors="pt",
        padding=True,
    )
    return encoded.to(device), suffixes

train_dataset = WebUI(train_split)
valid_dataset = WebUI(valid_split)
test_dataset = WebUI(test_split)

# Slice the dataset to only include the first 1000 samples
# train_subset = torch.utils.data.Subset(train_dataset, range(1000))

# To test the training with a subset of the data, uncomment the above line
# and replace train_dataset below with train_subset.
# This will allow you to ensure your training is running as intended
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                          collate_fn=collate_fn, num_workers=NUM_WORKERS,
                          shuffle=True)
# Validation does not need random order: a fixed order keeps the batches (and
# their padding) identical every epoch, so validation runs are reproducible.
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE,
                          collate_fn=collate_fn, num_workers=NUM_WORKERS,
                          shuffle=False)

## LoRA config
We will fine-tune our model using LoRA. Since the original model is already good at OCR and at finding objects in an image, we want to preserve as much of its knowledge as possible. Hence **we use small values for r and alpha**, ensuring our fine-tuned model does not deviate too much from its original state.

In [None]:
from peft import LoraConfig, get_peft_model

# Submodules that receive LoRA adapters. PEFT matches these strings against
# module *names* in the model (per the PEFT docs), not against layer classes.
# NOTE(review): "Conv2d" looks like a class name rather than a module name, so
# it presumably matches nothing — verify with model.named_modules().
TARGET_MODULES = [
    "q_proj", "o_proj", "k_proj", "v_proj",
    "linear", "Conv2d", "lm_head", "fc2"
]

# Small rank / alpha keep the adapters' change to the base model modest
# (see the markdown above).
config = LoraConfig(
    r=4, # adjust as required
    lora_alpha=8, # adjust as required
    target_modules=TARGET_MODULES,
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
    bias="none",
    inference_mode=False,  # adapters are created for training
    use_rslora=True,
    init_lora_weights="gaussian",
)

# Wrap the already-loaded base model with the adapters and report how many
# parameters are trainable.
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()

In [None]:
# Release cached-but-unused GPU memory held by PyTorch's allocator before training.
torch.cuda.empty_cache()

## Training Process

In [None]:
import os
from transformers import (
    AdamW,
    get_scheduler
)
from tqdm import tqdm

def _encode_labels(processor, answers):
    """
    Tokenize target strings into a padded label tensor on ``device``.

    Padding positions are set to -100, the label id Hugging Face models ignore
    in their loss. Without this, the model is trained to predict padding, and
    the loss depends on how samples happen to be batched.
    """
    labels = processor.tokenizer(
        text=list(answers),
        return_tensors="pt",
        padding=True,
        return_token_type_ids=False,
    ).input_ids.to(device)
    pad_id = processor.tokenizer.pad_token_id
    if pad_id is not None:
        labels[labels == pad_id] = -100
    return labels


def train_model(train_loader, val_loader, model, processor, epochs=10, lr=1e-6):
    """
    Fine-tune the Florence model (e.g. the LoRA-wrapped ``peft_model``).

    After every epoch the function:
    * prints the average training and validation losses;
    * plots a few validation predictions;
    * saves the model and processor to ``./model_checkpoints/epoch_<n>``.

    Args:
        train_loader (DataLoader): Yields ``(inputs, answers)`` batches from ``collate_fn``.
        val_loader (DataLoader): Validation batches in the same format. Its
            ``dataset`` is also used to render sample predictions.
        model (AutoModelForCausalLM): The model to be trained, expected on ``device``.
        processor (AutoProcessor): Processor whose tokenizer encodes the target strings.
        epochs (int, optional): Number of training epochs. Default is 10.
        lr (float, optional): Initial learning rate, decayed by a linear
            schedule over training. Default is 1e-6.
    """
    # Optimize only trainable parameters (for a PEFT model, presumably just
    # the LoRA adapter weights).
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    # torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
    # weight_decay are set to transformers.AdamW's defaults (1e-6, 0.0) so the
    # optimization matches.
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, eps=1e-6, weight_decay=0.0)
    num_training_steps = epochs * len(train_loader)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for inputs, answers in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}"):
            labels = _encode_labels(processor, answers)
            outputs = model(input_ids=inputs["input_ids"],
                            pixel_values=inputs["pixel_values"],
                            labels=labels)
            loss = outputs.loss

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            train_loss += loss.item()

        # max(..., 1) guards against an empty loader.
        avg_train_loss = train_loss / max(len(train_loader), 1)
        print(f"Average Training Loss: {avg_train_loss}")

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, answers in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):
                labels = _encode_labels(processor, answers)
                outputs = model(input_ids=inputs["input_ids"],
                                pixel_values=inputs["pixel_values"],
                                labels=labels)
                val_loss += outputs.loss.item()

            avg_val_loss = val_loss / max(len(val_loader), 1)
            print(f"Average Validation Loss: {avg_val_loss}")

            # comment out the line below to avoid displaying images after every epoch
            render_inference_results(model, processor, val_loader.dataset, 4, custom_task)

        # NOTE(review): for a PEFT-wrapped model, save_pretrained presumably
        # writes only the adapter weights/config — confirm before reloading.
        output_dir = f"./model_checkpoints/epoch_{epoch + 1}"
        os.makedirs(output_dir, exist_ok=True)
        model.save_pretrained(output_dir)
        processor.save_pretrained(output_dir)


In [None]:
%%time

# Number of passes over train_loader. train_model writes one checkpoint per
# epoch to ./model_checkpoints/epoch_<n>, so the last one is epoch_<EPOCHS>.
EPOCHS = 5 # adjust as required
# Learning rate passed to train_model's optimizer.
LR = 5e-6

train_model(train_loader, valid_loader, peft_model, processor, epochs=EPOCHS, lr=LR)

## Download the model

In [None]:
!zip -r model_checkpoint.zip ./model_checkpoints/epoch_5 # change as required

In [None]:
from google.colab import files

# Download the zipped checkpoint to the local machine (google.colab is only
# available when running in Colab).
files.download('model_checkpoint.zip')

# Test the model

## Test manually on a webpage screenshot

In [None]:
import requests
from PIL import Image
from io import BytesIO

peft_model.eval()
text_prompt = "Back button"
# Download a labeled webpage screenshot to query the fine-tuned model with.
http_response = requests.get(
    "https://media.gcflearnfree.org/content/55e07807bae0135431cfdcb3_12_17_2013/read_webpage_labeled.jpg",
    timeout=30,
)
# Fail loudly on HTTP errors instead of handing an error page to PIL.
http_response.raise_for_status()
image = Image.open(BytesIO(http_response.content)).convert("RGB")
response = get_model_output(peft_model, processor, image, custom_task, text_prompt)

print("Processed Response:", response)
plot_bbox_and_polygon(image, response[custom_task])