# Overview

This notebook demonstrates how visual question answering (VQA) is performed with FROMAGe. Specifically, it shows how in-context learning is applied to VQA using the guided-vqa dataset, and that visually augmenting the prompt does not always improve the model's performance on a downstream task. The figure below illustrates the procedure.

&nbsp;

<p align="center">
  <img src="./images_report/gvqa.png" width="1400" height="400" />
</p>

## Import model

In [None]:
import os
from fromage import models
import json 
import numpy as np
from PIL import Image
import torch
from transformers import AutoProcessor, CLIPSegForImageSegmentation, OneFormerProcessor, OneFormerForUniversalSegmentation, AutoTokenizer, AutoModel
from torchvision.transforms.functional import to_pil_image
import torch.nn.functional as F
import matplotlib.pyplot as plt

# FROMAGe: the multimodal model used in the paper, loaded from a local directory.
model_dir = './fromage_model/'
model = models.load_fromage(model_dir)

# Checkpoint ids are named once so each processor/tokenizer and its model
# are guaranteed to come from the same checkpoint.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
_ONEFORMER_CHECKPOINT = "shi-labs/oneformer_ade20k_swin_tiny"
_TEXT_ENCODER_CHECKPOINT = 'sentence-transformers/all-MiniLM-L6-v2'

# CLIPSeg: first segmentation model applied to the query image.
processor_clip = AutoProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
vis_model_clip = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)

# OneFormer: second segmentation model applied to the query image.
processor_oneformer = OneFormerProcessor.from_pretrained(_ONEFORMER_CHECKPOINT)
vis_model_oneformer = OneFormerForUniversalSegmentation.from_pretrained(_ONEFORMER_CHECKPOINT)

# all-MiniLM-L6-v2: language model used to extract text embeddings for scoring.
tokenizer = AutoTokenizer.from_pretrained(_TEXT_ENCODER_CHECKPOINT)
lm = AutoModel.from_pretrained(_TEXT_ENCODER_CHECKPOINT)

## Load data

In [None]:
# Load the guided-VQA episodes. The encoding is pinned to UTF-8 so that any
# non-ASCII text in the file decodes the same way on every platform instead
# of depending on the locale's default encoding.
with open('./guided_vqa/guided_vqa_shots_1_ways_2_all_questions.json', 'r', encoding='utf-8') as f:
    vqa_data = json.load(f)

# Keep only a small slice (entries 20-24) for the demonstration.
vqa_sublist = vqa_data[20:25]
# Directory that the image paths from the JSON are joined onto at inference time.
vqa_images_folder = './guided_vqa'

## Some useful functions

- The display_interleaved_outputs function is used to plot the images generated by the model.
- The clipseg_segment_image and oneformer_segment_image functions implement the segmentation steps described in the figure of the Overview section; specifically, they correspond to the middle and right parts of the figure.
- The compute_score function computes the cosine similarity between the text embeddings of the model's output and those of the answer (i.e. the label), to determine how close the model's output is to the correct answer. The figure below illustrates this procedure.

&nbsp;

<p align="center">
  <img src="./images_report/embeds_cos_sim.png" width="700" height="200" />
</p>

In [None]:
def display_interleaved_outputs(model_outputs, one_img_per_ret=True):
    """Print the text parts and plot the image parts of a model output.

    Args:
        model_outputs: An iterable whose items are strings, lists of images,
            or single PIL images. A bare string or a non-iterable object
            (such as a single PIL image) is treated as a one-item sequence.
        one_img_per_ret: If True, only the first image of each list is
            plotted; otherwise every image of the list is plotted side by
            side, titled 'Retrieval #k'.

    Items of any other type, and empty lists, are skipped.
    """
    # A PIL image is not iterable and iterating a str yields single
    # characters, so wrap a lone output before looping over it.
    if isinstance(model_outputs, str) or not hasattr(model_outputs, '__iter__'):
        model_outputs = [model_outputs]

    for output in model_outputs:
        if isinstance(output, str):
            print(output)
        elif isinstance(output, list):
            if not output:
                # Nothing to draw for an empty retrieval list.
                continue
            if one_img_per_ret:
                plt.figure(figsize=(3, 3))
                plt.imshow(np.array(output[0]))
            else:
                # squeeze=False keeps `axes` two-dimensional even when the
                # list holds a single image, so indexing always works.
                _, axes = plt.subplots(
                    1, len(output), figsize=(3 * len(output), 3), squeeze=False
                )
                for i, image in enumerate(output):
                    axes[0, i].imshow(np.array(image))
                    axes[0, i].set_title(f'Retrieval #{i+1}')
            plt.show()
        elif isinstance(output, Image.Image):
            # isinstance also accepts format-specific subclasses of Image.
            plt.figure(figsize=(3, 3))
            plt.imshow(np.array(output))
            plt.show()

def cos_sim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return the pairwise cosine similarities between the rows of a and b.

    A 1-D input is treated as a single row. Entry [i, j] of the result is
    the cosine similarity between row i of `a` and row j of `b`.
    """
    lhs = a.unsqueeze(0) if a.dim() == 1 else a
    rhs = b.unsqueeze(0) if b.dim() == 1 else b
    # Scale every row to unit L2 norm so the dot products below are cosines.
    lhs_unit = torch.nn.functional.normalize(lhs, p=2, dim=1)
    rhs_unit = torch.nn.functional.normalize(rhs, p=2, dim=1)
    return torch.mm(lhs_unit, rhs_unit.t())


def mean_pooling(model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Return the mask-weighted mean of the token embeddings along dim 1.

    `model_output` is indexed with [0] to obtain the token embeddings, so
    despite the annotation it is used as a sequence whose first element is
    that tensor. `attention_mask` is broadcast over the last dimension of
    the embeddings and used as the per-token weight.
    """
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    weighted_sum = (token_embeddings * mask).sum(1)
    # Clamp the divisor so an all-zero mask does not divide by zero.
    divisor = mask.sum(1).clamp(min=1e-9)
    return weighted_sum / divisor


def clipseg_segment_image(question_image):
    """Segment the query image with CLIPSeg, prompting it with the image itself.

    Args:
        question_image: The image to segment, in a form accepted by
            `processor_clip` (a PIL image in this notebook).

    Returns:
        A 224x224 RGB PIL image of the predicted segmentation mask.
    """
    encoded_image = processor_clip(images=[question_image], return_tensors='pt')
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = vis_model_clip(
            **encoded_image,
            conditional_pixel_values=encoded_image.pixel_values,
        )
    logits = outputs.logits
    # to_pil_image scales float tensors assuming values in [0, 1]; raw logits
    # are unbounded, so squash them to probabilities first. The reshape yields
    # (1, H, W) for this single image whether or not the model has already
    # squeezed the batch dimension.
    mask = torch.sigmoid(logits).reshape(1, *logits.shape[-2:])
    return to_pil_image(mask).resize((224, 224)).convert('RGB')
    

def _oneformer_task_map(question_image, task):
    """Run OneFormer for one task ("semantic", "instance" or "panoptic") and return its map as a PIL image."""
    inputs = processor_oneformer(images=question_image, task_inputs=[task], return_tensors="pt")
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = vis_model_oneformer(**inputs)
    post_process = getattr(processor_oneformer, f"post_process_{task}_segmentation")
    # PIL's size is (width, height); the processor expects (height, width).
    result = post_process(outputs, target_sizes=[question_image.size[::-1]])[0]
    # The semantic result is the map itself; the instance and panoptic
    # results keep the map under the "segmentation" key.
    predicted_map = result if task == "semantic" else result["segmentation"]
    # NOTE(review): to_pil_image scales float tensors assuming values in
    # [0, 1], while these maps hold segment/class ids - confirm the resulting
    # pixel encoding is what the downstream prompt should receive.
    return to_pil_image(predicted_map.float())


def oneformer_segment_image(question_image):
    """Segment the query image with OneFormer in its three task modes.

    Args:
        question_image: The PIL image to segment (its `.size` is used to
            resize the predicted maps back to the input resolution).

    Returns:
        A tuple (semantic_map, instance_map, panoptic_map) of PIL images.
    """
    semantic_map = _oneformer_task_map(question_image, "semantic")
    instance_map = _oneformer_task_map(question_image, "instance")
    panoptic_map = _oneformer_task_map(question_image, "panoptic")

    return semantic_map, instance_map, panoptic_map

def compute_score(model_outputs_original, model_outputs_clip_segment, model_outputs_oneformer_segment, answer):
    """Score three model outputs against the reference answer.

    Each input text is tokenized, encoded with the language model,
    mean-pooled over its tokens and L2-normalised; every model output is
    then compared with the answer through cosine similarity.

    Args:
        model_outputs_original: Output generated from the unaugmented prompt.
        model_outputs_clip_segment: Output generated from the CLIPSeg-augmented prompt.
        model_outputs_oneformer_segment: Output generated from the OneFormer-augmented prompt.
        answer: The reference answer (label) for the question.

    Returns:
        A tuple (clip_score, oneformer_score, unaugmented_score) of cosine
        similarity tensors, in that order.
    """
    def sentence_embedding(text):
        # Tokenize, run the language model without gradients, then pool
        # and normalise the token embeddings into sentence embeddings.
        encoded = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            hidden = lm(**encoded)
        pooled = mean_pooling(hidden, encoded['attention_mask'])
        return F.normalize(pooled, p=2, dim=1)

    original_emb = sentence_embedding(model_outputs_original)
    clip_emb = sentence_embedding(model_outputs_clip_segment)
    oneformer_emb = sentence_embedding(model_outputs_oneformer_segment)
    target_emb = sentence_embedding(answer)

    clip_score = cos_sim(clip_emb, target_emb)
    oneformer_score = cos_sim(oneformer_emb, target_emb)
    original_score = cos_sim(original_emb, target_emb)

    return clip_score, oneformer_score, original_score

## Inference

We cannot pass all the data through the model, because it would take too long. So we just use a small subset (the entries selected into `vqa_sublist` above, at most five) for the demonstration.

In [None]:
# For every selected episode: show the two in-context examples and the query,
# build the three prompts (original, OneFormer-augmented, CLIPSeg-augmented),
# generate an answer for each and score it against the ground-truth answer.
for vqa_dict in vqa_sublist:

    # First in-context example (image + caption).
    image1_path = vqa_dict['image_1']
    image1 = Image.open(os.path.join(vqa_images_folder, image1_path)).resize((224, 224)).convert('RGB')
    caption1 = vqa_dict['caption_1']
    # display_interleaved_outputs iterates over its argument and a PIL image
    # is not iterable, so a single image is passed wrapped in a list.
    display_interleaved_outputs([image1])
    print(caption1)

    # Second in-context example (image + caption).
    image2_path = vqa_dict['image_2']
    image2 = Image.open(os.path.join(vqa_images_folder, image2_path)).resize((224, 224)).convert('RGB')
    caption2 = vqa_dict['caption_2']
    display_interleaved_outputs([image2])
    print(caption2)

    # Query image, question and ground-truth answer.
    question_image_path = vqa_dict['question_image']
    question_image = Image.open(os.path.join(vqa_images_folder, question_image_path)).resize((224, 224)).convert('RGB')
    question = vqa_dict['question']
    answer = vqa_dict['answer']
    display_interleaved_outputs([question_image])
    print(question)

    # ClipSeg - segment query image
    segmented_pil_image = clipseg_segment_image(question_image)

    # Oneformer - segment query image
    semantic_map, instance_map, panoptic_map = oneformer_segment_image(question_image)

    # Generate output using the original query image
    model_input_original = [
        image1, caption1,
        image2, caption2,
        question_image,
        'Q: ' + question,
    ]
    model_outputs_original = model.generate_for_images_and_texts(model_input_original, num_words=15, max_num_rets=0)

    # Generate output using visual augmented prompt (Oneformer):
    # the three segmentation maps are inserted before the query image.
    model_input_oneformer_segment = [
        image1, caption1,
        image2, caption2,
        semantic_map, instance_map, panoptic_map,
        question_image,
        'Q: ' + question,
    ]
    model_outputs_oneformer_segment = model.generate_for_images_and_texts(model_input_oneformer_segment, num_words=15, max_num_rets=0)

    # Generate output using visual augmented prompt (CLIPSeg):
    # the CLIPSeg mask is inserted before the query image.
    model_input_clip_segment = [
        image1, caption1,
        image2, caption2,
        segmented_pil_image, question_image,
        'Q: ' + question,
    ]
    model_outputs_clip_segment = model.generate_for_images_and_texts(model_input_clip_segment, num_words=15, max_num_rets=0)

    # Compute the scores by comparing the text embeddings
    augmented_clip_score, augmented_onerformer_score, unaugmented_score = compute_score(
        model_outputs_original,
        model_outputs_clip_segment,
        model_outputs_oneformer_segment,
        answer,
    )

    print("The question is :", question, " and the answer is :" ,answer)
    print("Cos sim between original output - answer :" ,unaugmented_score)
    print("Cos sim between output using CLIPSeg - answer" ,augmented_clip_score)
    print("Cos sim between output using Oneformer - answer" ,augmented_onerformer_score)
    print("\n")

