# Overview

This notebook demonstrates how FROMAGe performs image retrieval under different prompt settings. Specifically, it aims to show that compressing complex text input (here, a caption and a dialog) into a clearer, more compact form improves the model's ability to retrieve an image that fits the context. The figure below illustrates the procedure.

&nbsp;

<img src="images_report/visualdialog_scheme.png" alt="Image" />

# Import Model

In [None]:
from visual_dialog import dialog_utils
from src.image_retrieval_vdialog.scripts import models # changed original code
import numpy as np
import matplotlib.pyplot as plt

import json
import os
import io
import base64
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from fromage.utils import get_image_from_url
import openai

# Load the FROMAGe model used in the paper.
# `model_dir` is the directory handed to `models.load_fromage`; the returned
# `model` object is the one used for inference further down in this notebook.
model_dir = './fromage_model/'
model = models.load_fromage(model_dir)

# Load data

In [None]:
def load_dialogs(stories_csv_path: str) -> pd.DataFrame:
    """Read the dialogs CSV at ``stories_csv_path`` into a DataFrame.

    The file is decoded as UTF-8 and every column is kept as ``str``, so
    numeric-looking fields (ids, round numbers) are not converted by pandas.
    """
    read_options = {'encoding': 'utf8', 'dtype': str}
    dialogs = pd.read_csv(stories_csv_path, **read_options)
    return dialogs

In [None]:

# Load dataframe
# `load_dialogs` reads the CSV with every column typed as `str`, so numeric
# fields in `dialogs_df` have to be converted explicitly before comparing them.
dialogs_csv_path = 'visual_dialog/dialogs.csv'
dialogs_df = load_dialogs(dialogs_csv_path)

# Some useful functions

In [None]:

# Get image file by image id
def get_image(image_id, show_image=False):
    """Load the image belonging to ``image_id`` from the ``images_path`` folder.

    A file matches when its name ends with ``<image_id>.jpg``. For purely
    numeric ids the characters in front of the id may only be zero padding
    preceded by a non-digit (e.g. ``prefix_000000000123.jpg`` for id ``123``),
    so id ``123`` no longer picks up ``prefix_000000456123.jpg``.

    Args:
        image_id: Identifier looked for at the end of the file name; it is
            converted with ``str()`` before matching.
        show_image: When true, the loaded image is also plotted.

    Returns:
        The matching image, resized to 224x224 and converted to RGB.

    Raises:
        ValueError: If no file in ``images_path`` matches ``image_id``.
    """
    id_text = str(image_id)
    suffix = id_text + '.jpg'
    id_is_numeric = id_text.isdigit()

    for file_name in os.listdir(images_path):
        if not file_name.endswith(suffix):
            continue
        if id_is_numeric:
            # Reject names where the id is only the tail of a longer number:
            # after dropping zero padding, the preceding character must not
            # be a digit.
            before_id = file_name[:-len(suffix)].rstrip('0')
            if before_id and before_id[-1].isdigit():
                continue

        image_path = os.path.join(images_path, file_name)
        # The context manager closes the underlying file once the resized
        # RGB copy has been created.
        with Image.open(image_path) as source:
            image = source.resize((224, 224)).convert('RGB')

        if show_image:
            # Plotted inline: inside this function the name `show_image` is
            # the boolean parameter, not the module-level helper.
            plt.figure(figsize=(3, 3))
            plt.axis('off')
            plt.imshow(np.array(image))
            plt.show()

        return image
    raise ValueError(f"Value {image_id} not found in list {images_path}")

def show_image(image):
    """Plot ``image`` in a 3x3 inch matplotlib figure with the axes hidden."""
    plt.figure(figsize=(3, 3))
    plt.axis('off')
    pixels = np.array(image)
    plt.imshow(pixels)
    plt.show()

# Instruction text that gpt_prompt() puts in front of every caption/dialog
# prompt before sending it to the completion endpoint. The line breaks and
# trailing spaces are part of the string that is sent.
instruction = """Transform the following caption 
with a question and answer dialogue about an image 
into a caption as short as possible while capturing 
all the information that is given: """

def gpt_prompt(prompt):
    """Have the OpenAI completion endpoint rewrite ``prompt`` as a short caption.

    The module-level ``instruction`` is prepended to ``prompt``; the combined
    text is echoed to stdout and then sent as a single completion request
    (temperature 0, at most 100 generated tokens, one choice).

    Returns:
        The text of the first returned choice with surrounding whitespace
        stripped.
    """
    full_prompt = instruction + prompt
    print("INPUT PROMPT TO GPT")
    print(full_prompt)

    # Generation is capped at 100 tokens.
    request = dict(
        engine='text-davinci-003',
        prompt=full_prompt,
        temperature=0,
        max_tokens=100,
        n=1,
        stop=None,
    )
    completion = openai.Completion.create(**request)
    return completion.choices[0].text.strip()

# Create prompts from dataframe
# Create prompts from dataframe
def get_prompt_list(dialogs_df, num_rows, prompt_length, ret_img=True, adapt_gpt_prompt=True, include_Q_A=False):
    """Build one (text, image) prompt per dialog from the first ``num_rows`` rows.

    Rows are read in order. A row with ``round == 1`` starts a dialog: its
    caption opens the text and its image is loaded with ``get_image``. A
    dialog ends when the next row has a different ``id`` or when the last
    requested row is reached. Every dialog, including the last one, gets the
    same treatment: trailing separator removed, optional GPT rewrite, then
    appended to the results.

    Args:
        dialogs_df: Table with the columns ``id``, ``round``, ``image_id``,
            ``caption``, ``question`` and ``answer``, indexed 0..n-1.
        num_rows: Number of leading rows to process; values larger than the
            table are clamped to its length.
        prompt_length: Highest round number whose question/answer pair is
            added to the text (only used when ``include_Q_A`` is true).
        ret_img: If true each entry is ``[text, image]``, otherwise
            ``[image, text]``.
        adapt_gpt_prompt: If true the assembled text is rewritten with
            ``gpt_prompt`` and the result is printed.
        include_Q_A: If true the question/answer pairs are appended after
            the caption.

    Returns:
        A tuple ``(input_dialog_list, url_dialog_list)``. Both lists hold one
        two-element list per dialog; the first contains the loaded image, the
        second contains the image id in its place.
    """
    input_dialog_list, url_dialog_list = [], []
    num_rows = min(num_rows, len(dialogs_df))

    text = ""
    image_id, img = None, None
    for i in range(num_rows):
        dialog_round = int(dialogs_df['round'][i])

        # First round of a dialog: remember its image and start with the caption.
        if dialog_round == 1:
            image_id = dialogs_df['image_id'][i]
            caption = dialogs_df['caption'][i]
            text += f"Caption: {caption}. "
            img = get_image(image_id)

        if include_Q_A and dialog_round <= prompt_length:
            text += f"Q: {dialogs_df['question'][i]}, "
            text += f"A: {dialogs_df['answer'][i]}, "

        # Keep accumulating while the next requested row is part of the same dialog.
        is_last_row = i == num_rows - 1
        if not is_last_row and dialogs_df['id'][i + 1] == dialogs_df['id'][i]:
            continue

        # The dialog is complete: drop the trailing ". " / ", " separator.
        text = text[:-2]
        if adapt_gpt_prompt:
            text = gpt_prompt(text)
            print("ADAPTED GPT PROMPT")
            print(text)

        if ret_img:
            dialog, url_dialog = [text, img], [text, image_id]
        else:
            dialog, url_dialog = [img, text], [image_id, text]

        input_dialog_list.append(dialog)
        url_dialog_list.append(url_dialog)
        text = ""

    return input_dialog_list, url_dialog_list

# Display the prompt and retrieve images from their ids
def display_prompt(output_list):
    """Show every entry of ``output_list`` as an image or, failing that, as text.

    Each entry is first passed to ``get_image`` as an image id (which also
    plots the image). If that raises, the entry is printed as text with a
    line break inserted in front of every ``Q: `` marker, so each question
    starts on its own line. Capital Qs inside ordinary words are left alone.

    Args:
        output_list: Iterable of image ids and text strings.
    """
    for output in output_list:
        # Show an image if possible, otherwise display the text
        try:
            get_image(output, show_image=True)
        except Exception:
            # Deliberate best-effort fallback: anything that is not a loadable
            # image id is shown as text. `Exception` (not a bare except) keeps
            # KeyboardInterrupt / SystemExit working.
            print(output.replace('Q: ', '\nQ: '))


# Display the output of the model, retrieve the images by their url
def display_output(story_list):
    """Show the string entries of ``story_list`` as images or as text.

    Non-string entries are skipped. Each string is first treated as an image
    URL: it is fetched with ``get_image_from_url`` and plotted. If that
    raises, the string is printed as text with a line break inserted in front
    of every ``Q: `` marker, so each question starts on its own line. Capital
    Qs inside ordinary words are left alone.

    Args:
        story_list: Iterable of model outputs (URLs, text, or other objects).
    """
    for element in story_list:
        if not isinstance(element, str):
            continue
        # Show an image if possible, otherwise display the text
        try:
            image = get_image_from_url(element)
            plt.figure(figsize=(3, 3))
            plt.axis('off')
            plt.imshow(np.array(image))
            plt.show()
        except Exception:
            # Deliberate best-effort fallback: strings that cannot be shown as
            # an image are printed. `Exception` (not a bare except) keeps
            # KeyboardInterrupt / SystemExit working.
            print(element.replace('Q: ', '\nQ: '))

# Inference

In [None]:
num_tests = 3
ret_img = True

# Prompt-construction flags forwarded to get_prompt_list. They were referenced
# below without being assigned anywhere in this notebook (NameError).
# NOTE(review): both are enabled to match the notebook's stated goal of
# compressing a caption plus dialog with GPT - confirm the intended values.
adapt_prompt_gpt = True
include_Q_A = True

q_a_per_caption = 5
# NOTE(review): the factor 2 presumably reflects dialogs of 2 * q_a_per_caption
# rounds each - confirm against the CSV.
num_rows = min(num_tests * q_a_per_caption * 2, len(dialogs_df))

dialog_list, url_dialog_list = dialog_utils.get_prompt_list(dialogs_df, num_rows, q_a_per_caption, ret_img, adapt_prompt_gpt, include_Q_A)

prompt, prompt_to_save = [], []
prompt_list, model_outputs = [], []
counter = 0
num_pt = 1

# Adjustable parameters
provide_context = False
init_prompt = False

for i, (dialog_entry, url_entry) in enumerate(zip(dialog_list, url_dialog_list)):
    # Provide examples to the model to show it how to handle certain input
    if i % num_pt != num_pt - 1:
        if provide_context:
            # Add the dialogs with url to the prompts that will be stored in a json file
            # Add the dialogs with PIL.Image objects to the prompts such that the model can handle them
            prompt_to_save += url_entry
            prompt += dialog_entry
        continue

    # The number of required examples is met: add the first element of this
    # dialog as the actual prompt that follows the examples.
    prompt += [dialog_entry[0]]
    prompt_to_save += [url_entry[0]]

    if ret_img:
        num_words = 0
        prompt += ['[RET]']
        max_num_rets = 1
    else:
        num_words = 300
        max_num_rets = 0

    # Print the amount of performed tests to be able to keep track of the progress
    print("Test num is", (i+1)/num_pt)

    # Make sure that invalid output will be skipped. Only `Exception` is
    # caught so that interrupting the kernel still stops the loop.
    try:
        image_outputs, output_urls_or_caption = model.generate_for_images_and_texts(list(prompt), max_img_per_ret=3, max_num_rets=max_num_rets, num_words=num_words)
    except Exception as err:
        print(f'Test {(i+1)/num_pt} failed because of invalid model output: {err!r}')
        prompt, prompt_to_save = [], []
        continue

    # Skip if the model did not return images
    if image_outputs is None:
        print(f'Test {(i+1)/num_pt} failed because model returned None.')
        prompt, prompt_to_save = [], []
        continue

    # Add the prompts with the image id to the prompt list and the output containing urls to the model outputs
    prompt_list.append(prompt_to_save)
    model_outputs.append(output_urls_or_caption)
    prompt, prompt_to_save = [], []

    # Stop running when the amount of tests is reached
    counter += 1
    if counter == num_tests:
        break

# Colab