# Overview

This notebook demonstrates how image captioning is performed using FROMAGe. Specifically, it shows how in-context learning is applied to image captioning on the Flickr-8k dataset and how visually augmenting the prompt can increase the model's performance on this downstream task.

&nbsp;

<p align="center">
  <img src="./images_report/Visual_augmentation_of_prompt.png" width="920" height="280" />
</p>

## Import model

In [None]:
import pandas as pd 
from PIL import Image
import os
import numpy as np
from fromage import models
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# FROMAGe checkpoint (the model used in the paper) is read from this directory.
model_dir = './fromage_model/'
model = models.load_fromage(model_dir)

# all-MiniLM-L6-v2 tokenizer and encoder: used later to embed the generated
# captions and the original caption so that they can be compared.
tokenizer, lm = (
    AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2'),
    AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2'),
)

## Load data

In [None]:
def split_dictionary(input_dict: dict, chunk_size: int) -> list:
    """Split a dictionary into consecutive chunks of at most ``chunk_size`` items.

    Items keep the iteration order of ``input_dict``; every chunk except
    possibly the last one holds exactly ``chunk_size`` items.

    Args:
        input_dict: The dictionary to split.
        chunk_size: Maximum number of items per chunk; must be at least 1.

    Returns:
        A list of dictionaries. An empty ``input_dict`` yields an empty list
        (rather than a list holding one empty dictionary).

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be at least 1, got {chunk_size}')

    # Materialise the items once so that they can be sliced by position.
    items = list(input_dict.items())
    return [
        dict(items[start:start + chunk_size])
        for start in range(0, len(items), chunk_size)
    ]

In [None]:
# Expert annotations; only the rows whose `expert1` value equals 4 are kept.
df = pd.read_csv('./Flickr8k_text/ExpertAnnotations.txt', delimiter='\t')
cropped_df = df[df['expert1'] == 4]

# Lookup table from caption id to caption text.
cap_df = pd.read_csv('./Flickr8k_text/Flickr8k.token.txt', delimiter='\t')
cap_dict = pd.Series(cap_df.cap.values, index=cap_df.cap_id).to_dict()

# One caption per image id; when an image id occurs several times in the
# filtered annotations, the last occurrence wins.
data_dict = {
    img_id: cap_dict[cap_id]
    for img_id, cap_id in zip(cropped_df.image_id, cropped_df.caption_id)
}

# One {image id: caption} dictionary per sample.
flickr_data = split_dictionary(data_dict, 1)

## Some useful functions

- The cos_sim function computes the cosine similarity. We need it to compare the text embeddings of the outputs of the model.
- The mean_pooling function is used as a processing tool for the embeddings.
- The display_interleaved_outputs function is used to plot the images generated by the model.
- The compare_embeddings function uses the all-MiniLM-L6-v2 model to generate the text embeddings and then calls the mean_pooling function and the cos_sim function to provide a final score. The following image describes this procedure.

&nbsp;

<p align="center">
  <img src="./images_report/embeds_cos_sim.png" width="700" height="200" />
</p>

In [None]:
def cos_sim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return the pairwise cosine similarity between the rows of ``a`` and ``b``.

    A 1-D input is treated as a single row. Entry ``[i, j]`` of the result is
    the cosine similarity between row ``i`` of ``a`` and row ``j`` of ``b``.
    """
    # Promote plain vectors to one-row matrices so both operands are 2-D.
    left = a.unsqueeze(0) if a.dim() == 1 else a
    right = b.unsqueeze(0) if b.dim() == 1 else b

    # With unit-length rows, the matrix product yields the cosine values.
    left_unit = torch.nn.functional.normalize(left, p=2, dim=1)
    right_unit = torch.nn.functional.normalize(right, p=2, dim=1)
    return torch.mm(left_unit, right_unit.transpose(0, 1))


def mean_pooling(model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average the token embeddings over dimension 1, weighted by the attention mask.

    ``model_output`` is indexed with ``[0]`` to obtain the token embeddings, so
    any indexable whose first element holds them works. The mask is broadcast
    over the last embedding dimension; positions whose mask value is 0 do not
    contribute to the average.
    """
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * mask).sum(dim=1)
    # The lower bound avoids a division by zero when a row is fully masked.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts

def display_interleaved_outputs(model_outputs, one_img_per_ret=True):
    """Print the text items and plot the image items of ``model_outputs`` in order.

    Args:
        model_outputs: Iterable whose items are handled by type: strings are
            printed, lists are treated as lists of images, and single PIL
            images are plotted on their own. Items of any other type are
            ignored.
        one_img_per_ret: When True, only the first image of each list is
            shown; otherwise all images of the list are shown side by side.
    """
    for output in model_outputs:
        if isinstance(output, str):
            print(output)
        elif isinstance(output, list):
            if not output:
                # An empty image list has nothing to draw.
                continue
            if one_img_per_ret:
                plt.figure(figsize=(3, 3))
                plt.imshow(np.array(output[0]))
            else:
                # squeeze=False keeps `axes` two-dimensional even when the
                # list holds a single image, so indexing stays uniform.
                _, axes = plt.subplots(
                    1, len(output), figsize=(3 * len(output), 3), squeeze=False
                )
                for i, (axis, image) in enumerate(zip(axes[0], output)):
                    axis.imshow(np.array(image))
                    axis.set_title(f'Retrieval #{i+1}')
            plt.show()
        # isinstance also accepts PIL subclasses (e.g. the object returned by
        # Image.open), which an exact type comparison would silently skip.
        elif isinstance(output, Image.Image):
            plt.figure(figsize=(3, 3))
            plt.imshow(np.array(output))
            plt.show()


def compare_embeddings(unaugmented_input, augmented_input, answer):
    """Score the captions generated from two prompts against a reference caption.

    Both prompts are passed to the module-level FROMAGe ``model`` to generate
    a caption. The two generated outputs and ``answer`` are then embedded with
    the module-level ``tokenizer`` / ``lm`` pair (mean pooling followed by L2
    normalisation) and compared with ``cos_sim``.

    Args:
        unaugmented_input: Prompt without the visual augmentation.
        augmented_input: Prompt with the visual augmentation.
        answer: Reference caption to compare against.

    Returns:
        A tuple ``(augmented_score, unaugmented_score)`` holding the results
        of ``cos_sim`` for the augmented and the unaugmented generation.
    """

    def _embed(texts):
        # Tokenize, run the encoder without tracking gradients, then pool and
        # normalise the token embeddings.
        encoded = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            encoder_output = lm(**encoded)
        pooled = mean_pooling(encoder_output, encoded['attention_mask'])
        return F.normalize(pooled, p=2, dim=1)

    # Generate one caption per prompt.
    unaugmented_output = model.generate_for_images_and_texts(unaugmented_input, num_words=15, max_num_rets=0)
    augmented_output = model.generate_for_images_and_texts(augmented_input, num_words=15, max_num_rets=0)

    unaugmented_embeddings = _embed(unaugmented_output)
    augmented_embeddings = _embed(augmented_output)
    target_embeddings = _embed(answer)

    return (
        cos_sim(augmented_embeddings, target_embeddings),
        cos_sim(unaugmented_embeddings, target_embeddings),
    )

## Inference

Running the whole dataset through the model would take too long, so we pick just 3 examples for this demonstration.

In [None]:
flickr_images_folder = './Flicker8k_Dataset/'
# NOTE: this rebinds `flickr_data`; re-running this cell without re-running the
# data-loading cell slices the already-sliced list again.
flickr_data = flickr_data[5:8]

for flickr_dict in flickr_data:
    try:
        # Take the first {image file name: caption} pair of the chunk.
        question_image_path, answer = next(iter(flickr_dict.items()))

        # Load query image & caption. The image is read from the dataset
        # folder; the file name is the dictionary key.
        question_image = (
            Image.open(os.path.join(flickr_images_folder, question_image_path))
            .resize((224, 224))
            .convert('RGB')
        )
        print("The original query image is the following :")
        display_interleaved_outputs([question_image])
        question = 'Caption the image.'

        # Generate caption using the unaugmented prompt
        unaugmented_prompt = [question_image, question]
        unaugmented_output = model.generate_for_images_and_texts(unaugmented_prompt, num_words=15, max_num_rets=0)

        # Use query image to retrieve two similar ones
        prompt_for_ret = [question_image, 'Give a similar image to the previous one [RET]']
        augmented_outputs = model.generate_for_images_and_texts(prompt_for_ret, max_img_per_ret=2)

        # Start from a clean slate on every sample so that images retrieved
        # for a previous sample can never leak into this one.
        similar_images = None
        for out in augmented_outputs:
            if isinstance(out, list) and len(out) >= 2:
                similar_images = out[:2]
        if similar_images is None:
            print(f"Skipping {question_image_path}: fewer than two similar images were retrieved.")
            continue

        similar_image1, similar_image2 = similar_images
        print("The first image similar to the query image retrieved by the model is the following :")
        display_interleaved_outputs([similar_image1])
        print("The second image similar to the query image retrieved by the model is the following :")
        display_interleaved_outputs([similar_image2])

        # Generate caption using the visually augmented prompt
        model_augmented_input = [similar_image1, similar_image2, question_image, question]
        augmented_output = model.generate_for_images_and_texts(model_augmented_input, num_words=15, max_num_rets=0)

        # compare_embeddings generates the captions itself, so it must be
        # given the prompts and not the captions generated above.
        augmented_score, unaugmented_score = compare_embeddings(unaugmented_prompt, model_augmented_input, answer)

        print("Caption without using any augmentation", unaugmented_output, "| Cos sim with target :",unaugmented_score.item())
        print("Caption using augmentation", augmented_output, "| Cos sim with target :",augmented_score.item())
        print("Ground truth :", answer)
        print("\n")

    except Exception as err:
        # Keep the demo running for the remaining samples, but report the
        # failure instead of hiding it.
        print(f"Skipping sample because of an error: {err!r}")
        continue