### This notebook includes all the code related to the paper titled `"Fine-Tuning Image-to-Text Models on Liechtenstein Tourist Attractions Using Microsoft GIT and Florence-2 models: A Transfer Learning Approach with Model Tracking in Weights & Biases"`

* `Pejman Ebrahimi` & `Johannes Schneider`
* Department of Information Systems & Computer Science, University of Liechtenstein, Liechtenstein, emails: `pejman.ebrahimi@uni.li` & `johannes.schneider@uni.li`

## 1. Fine-tune GIT on a custom dataset for image captioning

In [None]:
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q git+https://github.com/huggingface/accelerate.git
!pip install -q datasets
!pip install pillow

In [None]:
from huggingface_hub import notebook_login

notebook_login()

### 1.1. Load dataset from Hugging Face
* link to dataset: https://huggingface.co/datasets/arad1367/Liechtenstein_tourist_attractions

In [None]:
import transformers
from datasets import load_dataset

# Load the dataset
ds = load_dataset("arad1367/Liechtenstein_tourist_attractions", split="train[:99%]")
ds

In [None]:
# Check an example image and description from dataset
example = ds[0]
image = example["image"]
width, height = image.size
display(image.resize((int(0.3*width), int(0.3*height))))

example["description"]

### 1.2. Create PyTorch Dataset
> Next, we create a standard PyTorch dataset. Each item of the dataset returns the expected inputs for the model, in this case `input_ids`, `attention_mask` and `pixel_values`.

> We use `GitProcessor` to turn each (image, text) pair into the expected inputs. Basically, the text gets turned into `input_ids` and `attention_mask`, and the image gets turned into `pixel_values`.
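
To make the roles of `input_ids` and `attention_mask` concrete, here is a minimal pure-Python sketch of what `padding="max_length"` does conceptually. The helper `pad_to_max_length`, the token IDs, and the pad ID `0` are hypothetical and chosen only for illustration. The real processor uses the model's own tokenizer, vocabulary, and maximum length.

```python
def pad_to_max_length(token_ids, max_length, pad_id=0):
    """Right-pad a list of token IDs and build the matching attention mask.

    Hypothetical helper: 1 in the mask marks a real token, 0 marks padding.
    """
    ids = list(token_ids)
    num_pad = max(0, max_length - len(ids))
    input_ids = ids + [pad_id] * num_pad
    attention_mask = [1] * len(ids) + [0] * num_pad
    return input_ids, attention_mask


# Made-up token IDs standing in for a tokenized description
input_ids, attention_mask = pad_to_max_length([101, 7592, 2088, 102], max_length=6)
print(input_ids)       # [101, 7592, 2088, 102, 0, 0]
print(attention_mask)  # [1, 1, 1, 1, 0, 0]
```

Every item then has the same length, which is what allows the `DataLoader` to stack items into a batch tensor.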

In [None]:
from torch.utils.data import Dataset

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        encoding = self.processor(images=item["image"], text=item["description"], padding="max_length", return_tensors="pt")

        # remove batch dimension
        encoding = {k:v.squeeze() for k,v in encoding.items()}

        return encoding

### 1.3. Processor & Make train dataset

In [None]:
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("microsoft/git-base")
train_dataset = ImageCaptioningDataset(ds, processor)

In [None]:
item = train_dataset[0]
for k,v in item.items():
  print(k,v.shape)

### 1.4. Create PyTorch DataLoader
Next, we create a corresponding `PyTorch DataLoader`, which allows us to get batches of data from the dataset.

We need this as neural networks `(like GIT)` are trained on batches of data, using stochastic gradient descent `(SGD)`.
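
As a rough illustration of what batching with `shuffle=True` amounts to, the sketch below groups shuffled item indices into batches of six. The function `iter_batches`, the seed, and the dataset size of 20 are assumptions made for the example. This is not the `DataLoader` implementation, which additionally collates the items into tensors.

```python
import random


def iter_batches(num_items, batch_size, shuffle=True, seed=0):
    """Yield lists of item indices, one list per batch (hypothetical helper)."""
    indices = list(range(num_items))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, num_items, batch_size):
        # The last batch may be smaller when num_items is not a multiple of batch_size
        yield indices[start:start + batch_size]


batches = list(iter_batches(num_items=20, batch_size=6))
print([len(b) for b in batches])  # [6, 6, 6, 2]
```

One pass over all batches is one epoch. The number of batches is what `len(train_dataloader)` returns in the training loops later in this notebook.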

In [None]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=6)

In [None]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)

### 1.5. Define model
Next, we instantiate a model. We start from the `pre-trained GIT-base model` (which was already pre-trained on 4 million image-text pairs by Microsoft).

Feel free to start fine-tuning another GIT model from the hub.

In [None]:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")
# processor = AutoProcessor.from_pretrained("microsoft/git-base")

### 1.5. Dummy forward pass
It's always good to check the initial loss on a batch. See also the blog above.
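
As a reference point for reading that number: when `labels` are passed, the returned loss is a cross-entropy over the vocabulary, averaged over token positions. The sketch below computes that quantity for a single position in pure Python. The function name and the four-entry "vocabulary" are hypothetical and only serve the illustration.

```python
import math


def cross_entropy(logits, target_index):
    """Cross-entropy of one position: log-sum-exp of the logits minus the target logit."""
    max_logit = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[target_index]


# A model with no preference among 4 tokens scores ln(4) on any target
uniform_loss = cross_entropy([0.0, 0.0, 0.0, 0.0], target_index=2)
# A model that strongly prefers the correct token scores close to zero
confident_loss = cross_entropy([0.0, 0.0, 10.0, 0.0], target_index=2)

print(round(uniform_loss, 4))  # 1.3863
print(confident_loss < 0.01)   # True
```

A loss that falls steadily toward zero during fine-tuning therefore means the model assigns almost all probability to the reference tokens.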

In [None]:
outputs = model(input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                pixel_values=batch["pixel_values"],
                labels=batch["input_ids"])
outputs.loss # Loss before fine tuning

### 1.6. Train the model & track it with Weights & Biases
Next, let's train the model! We use native PyTorch here.

Because the dataset is small, we let the model overfit it. If the model is capable of overfitting (i.e. it reaches a loss close to zero), that is a good sign that everything is working properly. See also the blog above.

In [None]:
!pip install wandb
!pip install matplotlib
!pip install nltk
!pip install rouge-score
!pip install rouge
!pip install pycocoevalcap

In [None]:
import wandb
wandb.login()  # Prompts for your W&B API key

### 1.7. Fine-tune the Microsoft GIT model --> loss and BLEU

In [None]:
import transformers
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
import wandb
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu

# Initialize WandB
wandb.init(project="image2text-FineTune-loss-BLEU")  # Replace with your WandB project name and entity

# Load the dataset
ds = load_dataset("arad1367/Liechtenstein_tourist_attractions", split="train[:99%]")
print(ds)

# Define the dataset class
class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], text=item["description"], padding="max_length", return_tensors="pt")
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        return encoding

# Load processor and dataset
processor = AutoProcessor.from_pretrained("microsoft/git-base")
train_dataset = ImageCaptioningDataset(ds, processor)

# DataLoader
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=6)

# Load model
model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")

# Set up optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Training loop
model.train()
num_epochs = 10
for epoch in range(num_epochs):
    print(f"Epoch: {epoch + 1}/{num_epochs}")
    total_loss = 0
    predictions = []
    references = []

    for idx, batch in enumerate(tqdm(train_dataloader)):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)

        # Forward pass
        outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=input_ids)
        loss = outputs.loss
        total_loss += loss.item()

        # Backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

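        # Store teacher-forced predictions and references for evaluation.
        # GIT's logits can also cover the image-patch positions placed before the text,
        # so keep only the last `input_ids.shape[1]` positions (a no-op if they already match).
        text_logits = outputs.logits[:, -input_ids.shape[1]:, :]
        pred_ids = torch.argmax(text_logits, dim=-1)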
        predictions.append(pred_ids.cpu().numpy())
        references.append(input_ids.cpu().numpy())

    # Calculate average loss
    avg_loss = total_loss / len(train_dataloader)
    wandb.log({"loss": avg_loss})

    # Calculate BLEU score (for evaluation)
    # Flatten the predictions and references
    flat_predictions = [pred.flatten() for sublist in predictions for pred in sublist]
    flat_references = [[ref.flatten()] for sublist in references for ref in sublist]

    # Calculate BLEU score
    bleu_score = corpus_bleu(flat_references, flat_predictions)
    wandb.log({"bleu_score": bleu_score})

    print(f"Average Loss: {avg_loss:.4f}, BLEU Score: {bleu_score:.4f}")

# Finish WandB run
wandb.finish()
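
For intuition about what the BLEU value measures, the sketch below computes the clipped (modified) n-gram precision that BLEU is built on, in pure Python. It is not the full metric: `corpus_bleu` additionally combines several n-gram orders with a geometric mean and applies a brevity penalty. The function names and the two example sentences are made up for illustration.

```python
from collections import Counter


def ngrams(tokens, n):
    """Return all contiguous n-grams of a token list as tuples."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def modified_precision(reference, hypothesis, n):
    """Fraction of hypothesis n-grams found in the reference, with clipped counts."""
    hyp_counts = Counter(ngrams(hypothesis, n))
    ref_counts = Counter(ngrams(reference, n))
    clipped = sum(min(count, ref_counts[gram]) for gram, count in hyp_counts.items())
    total = sum(hyp_counts.values())
    return clipped / total if total else 0.0


reference = "the castle stands above vaduz".split()
hypothesis = "the castle stands in vaduz".split()
print(modified_precision(reference, hypothesis, 1))  # 0.8
print(modified_precision(reference, hypothesis, 2))  # 0.5
```

The same computation works on lists of token IDs instead of words, which is how the cell above feeds `corpus_bleu`.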

### 1.8. Fine-tune the Microsoft GIT model --> All criteria

In [None]:
import transformers
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
import wandb
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rouge import Rouge
from pycocoevalcap.cider.cider import Cider
import nltk

# Download WordNet data for NLTK (only needs to be done once; the BLEU, ROUGE, and CIDEr calls below do not depend on it)
nltk.download('wordnet')

# Initialize WandB
wandb.init(project="image2text-FineTune-metrics-allCriteria")  # Replace with your WandB project name and entity

# Load the dataset
ds = load_dataset("arad1367/Liechtenstein_tourist_attractions", split="train[:99%]")
print(ds)

# Define the dataset class
class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Process images and text together
        encoding = self.processor(images=item["image"], text=item["description"], padding="max_length", return_tensors="pt")
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        return encoding

# Load processor and dataset
processor = AutoProcessor.from_pretrained("microsoft/git-base")
train_dataset = ImageCaptioningDataset(ds, processor)

# DataLoader
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=6)

# Load model
model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")

# Set up optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Initialize ROUGE and CIDEr
rouge = Rouge()
cider = Cider()

# Training loop
model.train()
num_epochs = 10
for epoch in range(num_epochs):
    print(f"Epoch: {epoch + 1}/{num_epochs}")
    total_loss = 0
    predictions = []
    references = []

    for idx, batch in enumerate(tqdm(train_dataloader)):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)

        # Forward pass
        outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=input_ids)
        loss = outputs.loss
        total_loss += loss.item()

        # Backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

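        # Store teacher-forced predictions and references for evaluation.
        # GIT's logits can also cover the image-patch positions placed before the text,
        # so keep only the last `input_ids.shape[1]` positions (a no-op if they already match).
        text_logits = outputs.logits[:, -input_ids.shape[1]:, :]
        pred_ids = torch.argmax(text_logits, dim=-1)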
        predictions.append(pred_ids.cpu().numpy())
        references.append(input_ids.cpu().numpy())

    # Calculate average loss
    avg_loss = total_loss / len(train_dataloader)
    wandb.log({"loss": avg_loss})

    # Flatten predictions and references for evaluation
    flat_predictions = [pred.flatten() for sublist in predictions for pred in sublist]
    flat_references = [ref.flatten() for sublist in references for ref in sublist]

    # Decode predictions and references into strings
    decoded_predictions = [processor.decode(pred, skip_special_tokens=True) for pred in flat_predictions]
    decoded_references = [processor.decode(ref, skip_special_tokens=True) for ref in flat_references]

    # Tokenize the references and predictions for BLEU
    tokenized_references = [[ref.split()] for ref in decoded_references]  # Tokenized reference
    tokenized_predictions = [pred.split() for pred in decoded_predictions]  # Tokenized prediction

    # Calculate BLEU score with smoothing
    smoothing_function = SmoothingFunction().method1
    bleu_score = corpus_bleu(tokenized_references, tokenized_predictions, smoothing_function=smoothing_function)
    wandb.log({"bleu_score": bleu_score})

    # Prepare CIDEr format
    gts = {}
    res = {}
    for i in range(len(decoded_references)):
        img_id = f'img_{i}'  # You can customize this id based on your dataset
        gts[img_id] = [decoded_references[i]]
        res[img_id] = [decoded_predictions[i]]

    # Calculate CIDEr score
    cider_score, _ = cider.compute_score(gts, res)
    wandb.log({"cider_score": cider_score})

    # Calculate ROUGE score
    rouge_scores = rouge.get_scores(decoded_predictions, decoded_references, avg=True)
    wandb.log({"rouge-1": rouge_scores['rouge-1']['f'],
                "rouge-2": rouge_scores['rouge-2']['f'],
                "rouge-l": rouge_scores['rouge-l']['f']})

    print(f"Average Loss: {avg_loss:.4f}, BLEU Score: {bleu_score:.4f}, CIDEr Score: {cider_score:.4f}, ROUGE Scores: {rouge_scores}")

# Finish WandB run
wandb.finish()
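
To make the ROUGE-L value logged above easier to interpret, here is a minimal pure-Python sketch of the idea behind it: the longest common subsequence (LCS) between reference and prediction, turned into precision, recall, and an F1 score. This is a simplified assumption of how the metric works. The `rouge` package may differ in tokenization and in details of the F-score, and the sentences are invented.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(reference, hypothesis):
    """Simplified ROUGE-L F1 on whitespace tokens (hypothetical helper)."""
    ref_tokens, hyp_tokens = reference.split(), hypothesis.split()
    lcs = lcs_length(ref_tokens, hyp_tokens)
    if lcs == 0:
        return 0.0
    precision = lcs / len(hyp_tokens)
    recall = lcs / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


score = rouge_l_f1("the castle stands above vaduz", "the castle stands in vaduz")
print(round(score, 2))  # 0.8
```

Unlike the n-gram overlap used by ROUGE-1 and ROUGE-2, the LCS rewards words that appear in the same order even when they are not adjacent.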

### 1.9. Push model to HF

In [None]:
model.push_to_hub("arad1367/Microsoft-git-base-Liechtenstein-TA") # replace with your own namespace and model name
processor.push_to_hub("arad1367/Microsoft-git-base-Liechtenstein-TA")

### 1.10. Predict with the fine-tuned model on custom images

In [None]:
from PIL import Image
import torch
import matplotlib.pyplot as plt

# Load your custom images
image_paths = [
    "/content/1.jpg",  # Change these to your custom image paths
    "/content/2.jpg",
    "/content/3.jpg",
    "/content/4.jpg"
]

# Function to preprocess and generate caption for a single image
def generate_caption(image_path, model, processor, device):
    custom_image = Image.open(image_path).convert("RGB")  # ensure a 3-channel image
    encoding = processor(images=custom_image, return_tensors="pt")
    pixel_values = encoding['pixel_values'].to(device)

    with torch.no_grad():
        outputs = model.generate(pixel_values=pixel_values, max_length=50)

    predicted_caption = processor.decode(outputs[0], skip_special_tokens=True)
    return custom_image, predicted_caption

# Set the model to evaluation mode
model.eval()

# Move model to the correct device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Generate captions for all images
images_captions = [generate_caption(image_path, model, processor, device) for image_path in image_paths]

# Function to split the caption into multiple lines
def split_caption(caption, max_length=20):
    words = caption.split()
    lines = []
    current_line = []
    current_length = 0

    for word in words:
        if current_length + len(word) + 1 <= max_length:
            current_line.append(word)
            current_length += len(word) + 1
        else:
            lines.append(' '.join(current_line))
            current_line = [word]
            current_length = len(word) + 1

    if current_line:
        lines.append(' '.join(current_line))

    return lines

# Display images in a 2x2 grid with captions
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
axes = axes.flatten()

for ax, (image, caption) in zip(axes, images_captions):
    # Resize the image to 300x300 pixels
    image = image.resize((300, 300))
    ax.imshow(image)
    ax.axis('off')

    # Split the caption into multiple lines
    caption_lines = split_caption(caption)

    # Add a text box below the image
    for i, line in enumerate(caption_lines):
        ax.text(0.5, -0.1 - i * 0.05, line, fontsize=10, ha='center', transform=ax.transAxes,
                bbox=dict(facecolor='white', alpha=0.8))

plt.tight_layout()
plt.subplots_adjust(wspace=0.3, hspace=0.3)
plt.show()
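
As a side note, the standard library's `textwrap.wrap` performs a similar (not identical) wrapping to the `split_caption` helper above, breaking text at word boundaries so that each line stays within a width limit. The sketch below is an optional alternative. The caption string is invented for the example.

```python
import textwrap

caption = "Vaduz Castle overlooks the capital of Liechtenstein"

# Break the caption at whitespace so that no line is longer than 20 characters
caption_lines = textwrap.wrap(caption, width=20)

for line in caption_lines:
    print(line)
```

The two approaches count the line budget slightly differently, so line breaks can fall in different places for the same caption.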

## 2. Fine-tune Florence-2 on a custom dataset for image captioning

In [None]:
!pip install -q datasets flash_attn timm einops

### 2.1. Load dataset from Hugging Face datasets

In [None]:
from datasets import load_dataset, DatasetDict

# Load the dataset
data = load_dataset("arad1367/Liechtenstein_tourist_attractions_VQA") # same as the previous dataset, with an added question column

# Split the 'train' dataset into 'train' and 'validation'
# e.g., 80% for training, 20% for validation
split_data = data['train'].train_test_split(test_size=0.2)

# Reconstruct the DatasetDict to include both 'train' and 'validation'
data = DatasetDict({
    'train': split_data['train'],
    'validation': split_data['test']
})

# Now you can check the new data dictionary
data['train'], data['validation']
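
Note that the split above is random, so the validation set changes between runs unless a seed is fixed (`train_test_split` accepts a `seed` argument for this). The pure-Python sketch below illustrates the idea of a seeded 80/20 split on indices. The helper `split_indices`, the seed value, and the rounding rule are assumptions for illustration, not the `datasets` implementation.

```python
import random


def split_indices(num_items, test_fraction=0.2, seed=42):
    """Shuffle indices with a fixed seed and hold out a fraction for validation."""
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)
    num_test = round(num_items * test_fraction)
    return indices[num_test:], indices[:num_test]


train_idx, val_idx = split_indices(num_items=100)
print(len(train_idx), len(val_idx))  # 80 20

# The same seed reproduces the same split
print(split_indices(num_items=100) == (train_idx, val_idx))  # True
```

A fixed split makes validation metrics comparable across runs tracked in Weights & Biases.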

### 2.2. Load model & Processor
* We can load the model with `AutoModelForCausalLM` and the processor with `AutoProcessor` from the transformers library. Note that we need to pass `trust_remote_code=True` because the modeling code for this checkpoint is loaded from its Hub repository rather than from the transformers library itself.

In [None]:
from transformers import AutoModelForCausalLM, AutoProcessor
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True, revision='refs/pr/6').to(device)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True, revision='refs/pr/6')
model

In [None]:
torch.cuda.empty_cache()

### 2.3. Function to run example image

In [None]:
# Function to run the model on an example
def run_example(task_prompt, text_input, image):
    prompt = task_prompt + text_input

    # Ensure the image is in RGB mode
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
    return parsed_answer

In [None]:
# Test our function
# Identify the place in the image and provide a brief description
for idx in range(3):
  print(run_example("FLVQA", 'Identify the place in the image and provide a brief description', data['train'][idx]['image']))
  display(data['train'][idx]['image'].resize([350, 350]))

### 2.4. Create PyTorch Dataset

In [None]:
# We need to construct our dataset. A task prefix such as <FLVQA> can be prepended to the question when building the prompt; in __getitem__ below the prefix string is empty, so the question is used as-is.
from torch.utils.data import Dataset

class DocVQADataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        question = "" + example['question']
        first_answer = example['description']
        # first_answer = example['description'][0]
        image = example['image']
        if image.mode != "RGB":
            image = image.convert("RGB")
        return question, first_answer, image

### 2.5. DataLoader & Batch size

In [None]:
import os
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor, get_scheduler  # AdamW is not used in this cell

def collate_fn(batch):
    questions, answers, images = zip(*batch)
    inputs = processor(text=list(questions), images=list(images), return_tensors="pt", padding=True).to(device)
    return inputs, answers

# Create datasets
train_dataset = DocVQADataset(data['train'])
val_dataset = DocVQADataset(data['validation'])

# Create DataLoader
batch_size = 6
num_workers = 0

train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)

### 2.6. Fine-tune Florence-2 (All criteria) & Model tracking with wandb

In [None]:
import os
import wandb
import torch
from tqdm import tqdm
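# transformers.AdamW is deprecated; use the PyTorch optimizer instead.
# Note: torch.optim.AdamW defaults to weight_decay=0.01, whereas transformers.AdamW defaulted to 0.0.
from torch.optim import AdamW
from transformers import get_scheduler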
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rouge import Rouge
from pycocoevalcap.cider.cider import Cider
import nltk

# Download NLTK data (if not already installed)
nltk.download('wordnet')

# Initialize wandb
wandb.init(project="florence2_metrics_all28Sep")  # Change project name and entity accordingly

# Initialize ROUGE and CIDEr
rouge = Rouge()
cider = Cider()

# Training function with added metrics (BLEU, ROUGE, CIDEr)
def train_model(train_loader, val_loader, model, processor, epochs=10, lr=1e-6):
    optimizer = AdamW(model.parameters(), lr=lr)
    num_training_steps = epochs * len(train_loader)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    # Watch the model in wandb (optional, to track gradients)
    wandb.watch(model, log="all", log_freq=100)

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        predictions = []
        references = []

        # Training loop
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}"):
            inputs, answers = batch

            input_ids = inputs["input_ids"].to(device)
            pixel_values = inputs["pixel_values"].to(device)
            labels = processor.tokenizer(text=answers, return_tensors="pt", padding=True, return_token_type_ids=False).input_ids.to(device)

            outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

            # Backpropagation
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            train_loss += loss.item()

            # Store predictions and references for evaluation
            pred_ids = torch.argmax(outputs.logits, dim=-1)
            predictions.append(pred_ids.cpu().numpy())
            references.append(labels.cpu().numpy())

        avg_train_loss = train_loss / len(train_loader)
        print(f"Average Training Loss: {avg_train_loss}")

        # Log training loss to wandb
        wandb.log({"Training Loss": avg_train_loss, "Epoch": epoch + 1})

        # Calculate metrics (BLEU, ROUGE, CIDEr) for training data
        flat_predictions = [pred.flatten() for sublist in predictions for pred in sublist]
        flat_references = [ref.flatten() for sublist in references for ref in sublist]

        decoded_predictions = [processor.decode(pred, skip_special_tokens=True) for pred in flat_predictions]
        decoded_references = [processor.decode(ref, skip_special_tokens=True) for ref in flat_references]

        tokenized_references = [[ref.split()] for ref in decoded_references]  # Tokenized reference
        tokenized_predictions = [pred.split() for pred in decoded_predictions]  # Tokenized prediction

        # BLEU score
        smoothing_function = SmoothingFunction().method1
        bleu_score = corpus_bleu(tokenized_references, tokenized_predictions, smoothing_function=smoothing_function)
        wandb.log({"BLEU Score (Train)": bleu_score})

        # Prepare CIDEr format
        gts = {}
        res = {}
        for i in range(len(decoded_references)):
            img_id = f'img_{i}'  # You can customize this id based on your dataset
            gts[img_id] = [decoded_references[i]]
            res[img_id] = [decoded_predictions[i]]

        # CIDEr score
        cider_score, _ = cider.compute_score(gts, res)
        wandb.log({"CIDEr Score (Train)": cider_score})

        # ROUGE score
        rouge_scores = rouge.get_scores(decoded_predictions, decoded_references, avg=True)
        wandb.log({
            "ROUGE-1 (Train)": rouge_scores['rouge-1']['f'],
            "ROUGE-2 (Train)": rouge_scores['rouge-2']['f'],
            "ROUGE-L (Train)": rouge_scores['rouge-l']['f']
        })

        # Validation phase
        model.eval()
        val_loss = 0
        val_predictions = []
        val_references = []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):
                inputs, answers = batch

                input_ids = inputs["input_ids"].to(device)
                pixel_values = inputs["pixel_values"].to(device)
                labels = processor.tokenizer(text=answers, return_tensors="pt", padding=True, return_token_type_ids=False).input_ids.to(device)

                outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
                loss = outputs.loss

                val_loss += loss.item()

                # Store predictions and references for validation evaluation
                pred_ids = torch.argmax(outputs.logits, dim=-1)
                val_predictions.append(pred_ids.cpu().numpy())
                val_references.append(labels.cpu().numpy())

        avg_val_loss = val_loss / len(val_loader)
        print(f"Average Validation Loss: {avg_val_loss}")

        # Log validation loss to wandb
        wandb.log({"Validation Loss": avg_val_loss, "Epoch": epoch + 1})

        # Calculate metrics (BLEU, ROUGE, CIDEr) for validation data
        flat_val_predictions = [pred.flatten() for sublist in val_predictions for pred in sublist]
        flat_val_references = [ref.flatten() for sublist in val_references for ref in sublist]

        decoded_val_predictions = [processor.decode(pred, skip_special_tokens=True) for pred in flat_val_predictions]
        decoded_val_references = [processor.decode(ref, skip_special_tokens=True) for ref in flat_val_references]

        tokenized_val_references = [[ref.split()] for ref in decoded_val_references]  # Tokenized reference
        tokenized_val_predictions = [pred.split() for pred in decoded_val_predictions]  # Tokenized prediction

        # BLEU score for validation
        bleu_val_score = corpus_bleu(tokenized_val_references, tokenized_val_predictions, smoothing_function=smoothing_function)
        wandb.log({"BLEU Score (Validation)": bleu_val_score})

        # Prepare CIDEr format for validation
        val_gts = {}
        val_res = {}
        for i in range(len(decoded_val_references)):
            img_id = f'img_{i}'  # Customize based on dataset
            val_gts[img_id] = [decoded_val_references[i]]
            val_res[img_id] = [decoded_val_predictions[i]]

        # CIDEr score for validation
        cider_val_score, _ = cider.compute_score(val_gts, val_res)
        wandb.log({"CIDEr Score (Validation)": cider_val_score})

        # ROUGE score for validation
        rouge_val_scores = rouge.get_scores(decoded_val_predictions, decoded_val_references, avg=True)
        wandb.log({
            "ROUGE-1 (Validation)": rouge_val_scores['rouge-1']['f'],
            "ROUGE-2 (Validation)": rouge_val_scores['rouge-2']['f'],
            "ROUGE-L (Validation)": rouge_val_scores['rouge-l']['f']
        })

        # Save model checkpoint
        output_dir = f"./model_checkpoints/epoch_{epoch + 1}"
        os.makedirs(output_dir, exist_ok=True)
        model.save_pretrained(output_dir)
        processor.save_pretrained(output_dir)

    # Finish wandb session
    wandb.finish()

# Example for freezing the vision tower parameters
for param in model.vision_tower.parameters():
    param.requires_grad = False

# Call train_model
train_model(train_loader, val_loader, model, processor, epochs=10)
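
The `"linear"` schedule used in `train_model` decays the learning rate from its initial value to zero over `num_training_steps`, after an optional warmup. The pure-Python sketch below reproduces that shape for illustration. The function `linear_lr` is a hypothetical helper, and `steps_per_epoch = 15` is an assumed value, since the real number depends on `len(train_loader)`.

```python
def linear_lr(step, base_lr, num_training_steps, num_warmup_steps=0):
    """Learning rate at a given step: linear warmup, then linear decay to zero."""
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    decay_steps = max(1, num_training_steps - num_warmup_steps)
    return base_lr * max(0.0, remaining / decay_steps)


steps_per_epoch = 15                # assumption for the example
total_steps = 10 * steps_per_epoch  # epochs * len(train_loader)

# Full rate at the start, half of it midway, zero at the end
for step in (0, total_steps // 2, total_steps):
    print(step, linear_lr(step, base_lr=1e-6, num_training_steps=total_steps))
```

Because the scheduler is stepped once per batch, the learning rate already shrinks within the first epoch rather than only between epochs.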


### 2.7. Push model to HF
* `Important note`: We need to modify `config.json` after pushing to the Hub. Otherwise, loading the model raises an error. Link to the correct config.json: https://huggingface.co/arad1367/Florence-2-Liechtenstein-TA-OCR-VQA-modified/blob/main/config.json

In [None]:
# model.push_to_hub("arad1367/Florence-2-Liechtenstein-TA-OCR-VQA-modified")
# processor.push_to_hub("arad1367/Florence-2-Liechtenstein-TA-OCR-VQA-modified")

### 2.8. Predict on custom images with the fine-tuned Florence-2

In [None]:
from PIL import Image
import torch
import matplotlib.pyplot as plt

# Custom image paths
image_paths = [
    "/content/Uni.jpg",
    "/content/bridge.jpg",
    "/content/ski.jpg",
    "/content/rrr.jpg"
]

# Function to preprocess and generate caption for a single image with the new script
def run_example_with_caption(image_path, task_prompt, text_input):
    # Open the custom image from the file path
    image = Image.open(image_path)

    # Combine task prompt and text input
    prompt = f"{task_prompt} {text_input}".strip()

    # Ensure the image is in RGB mode
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess inputs (text + image)
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

    # Generate output from the model (keep input_ids in Long type)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],  # Use input IDs as they are (Long tensor)
        pixel_values=inputs["pixel_values"].float(),  # Ensure pixel values are float32
        max_new_tokens=256,  # Adjust to allow longer outputs
        num_beams=5,         # Wider beam search: keeps more candidate sequences at each step
        early_stopping=True   # Stop when all beams finish
    )

    # Decode the generated output
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    # Clean up the generated text
    cleaned_text = generated_text.replace("</s>", "").replace("<s>", "").strip()

    return image, cleaned_text

# Function to split the caption into multiple lines for better display
def split_caption(caption, max_length=20):
    words = caption.split()
    lines = []
    current_line = []
    current_length = 0

    for word in words:
        if current_length + len(word) + 1 <= max_length:
            current_line.append(word)
            current_length += len(word) + 1
        else:
            lines.append(' '.join(current_line))
            current_line = [word]
            current_length = len(word) + 1

    if current_line:
        lines.append(' '.join(current_line))

    return lines

# Example usage with your custom images
task_prompt = "FLVQA:"
text_input = "Name of the place in the image and description"

# Generate captions for all custom images
images_captions = [run_example_with_caption(image_path, task_prompt, text_input) for image_path in image_paths]

# Display images in a 2x2 grid with captions
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
axes = axes.flatten()

for ax, (image, caption) in zip(axes, images_captions):
    # Resize the image to 300x300 pixels
    image = image.resize((300, 300))
    ax.imshow(image)
    ax.axis('off')

    # Split the caption into multiple lines
    caption_lines = split_caption(caption)

    # Add a text box below the image
    for i, line in enumerate(caption_lines):
        ax.text(0.5, -0.1 - i * 0.05, line, fontsize=10, ha='center', transform=ax.transAxes,
                bbox=dict(facecolor='white', alpha=0.8))

plt.tight_layout()
plt.subplots_adjust(wspace=0.3, hspace=0.3)
plt.show()