In [1]:
import os

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import clear_output

from visual_storytelling import story_in_sequence

In [2]:
def display_story(story_list):
    """Render a story sequence inline in the notebook.

    Strings (captions) are printed; PIL images are drawn as small 3x3-inch
    figures with the axes hidden. Elements of any other type are skipped.

    Args:
        story_list: iterable of ``str`` and/or ``PIL.Image.Image`` elements,
            presumably as produced by ``story_in_sequence.create_story_list``
            (TODO confirm the element types it can emit).
    """
    for element in story_list:
        # isinstance (not `type(...) ==`) so subclasses are accepted: Image.open()
        # returns format-specific subclasses (PngImageFile, JpegImageFile, ...)
        # that an exact type comparison would silently skip.
        if isinstance(element, str):
            print(element)
        elif isinstance(element, Image.Image):
            fig, ax = plt.subplots(figsize=(3, 3))
            ax.axis('off')
            ax.imshow(np.array(element))
            plt.show()
            # Release the figure so long annotation sessions don't accumulate them.
            plt.close(fig)

# Annotate Predictions

### Select experiment

In [3]:
# Experiment selection: how many captions each story carries and whether the
# story images were included. Together they form the experiment name, which
# names both the prediction-image folder and the annotations CSV.
num_captions, include_images = 5, True

In [4]:
# Resolve the folder of prediction images for the selected experiment,
# e.g. ./visual_storytelling/images/5_captions_with_images
vist_images_folder = './visual_storytelling/images'

experiment_name = f'{num_captions}_captions_{"with" if include_images else "no"}_images'

experiment_folder = f'{vist_images_folder}/{experiment_name}'

# Fail fast. The annotation loop calls next(os.walk(experiment_folder)). os.walk
# swallows errors for a missing folder (or a plain file), so that call would raise
# an opaque StopIteration far from the real cause.
if not os.path.isdir(experiment_folder):
    raise FileNotFoundError(
        f'ERROR: {experiment_folder} does not exist or is not a directory.')

print(experiment_folder)

./visual_storytelling/images/5_captions_with_images


### Load stories dataframe

In [5]:
# Story table that the annotation loop passes to story_in_sequence.create_story_list.
stories_csv_path = './visual_storytelling/stories.csv'
stories_df = story_in_sequence.load_stories(
    stories_csv_path,
)

### Load annotations

In [6]:
# One annotations CSV per experiment. Story ids already present in it are skipped
# by the annotation loop, so a previous session can be resumed.
vist_annotations_folder = './visual_storytelling/annotations'
vist_annotations_csv_path = f'{vist_annotations_folder}/{experiment_name}.csv'

if not os.path.exists(vist_annotations_csv_path):
    # Fresh start: an empty frame with the same two-column schema as a saved file.
    annotations_df = pd.DataFrame({
        'story_id': pd.Series(dtype='str'),
        'prediction': pd.Series(dtype='int'),
    })
else:
    # Resume: story ids stay strings; predictions are integer verdicts.
    annotations_df = pd.read_csv(
        vist_annotations_csv_path,
        encoding='utf8',
        dtype={'story_id': 'str', 'prediction': 'int'},
    )

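### Annotation loop

Each prediction image whose story is not yet in the annotations file is shown together with its story. Answer `0` (false) or `1` (correct), or `2` to stop.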

In [7]:
# Ask for a 0/1 verdict on every prediction image in the experiment folder whose
# story has not been annotated yet. Progress is written to vist_annotations_csv_path:
#   - after every verdict;
#   - again when the loop exits: normally, via "2 (stop)", or through an
#     interrupt/exception.
# An interrupted session therefore loses nothing and resumes when this cell is re-run.

PREDICTION_SUFFIX = '.png'
VALID_PREDICTIONS = ('0', '1', '2')


def _save_annotations(df, csv_path):
    """Write the annotations frame to csv_path, creating the parent folder if needed."""
    parent = os.path.dirname(csv_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    df.to_csv(csv_path, encoding='utf8', index=False)


def _show_predictions(image_path):
    """Display the prediction image at image_path as a wide (15x5-inch) figure."""
    # `with` closes the underlying file handle once the pixels have been copied.
    with Image.open(image_path) as predictions_image:
        pixels = np.array(predictions_image)
    fig, ax = plt.subplots(figsize=(5 * 3, 5))
    ax.axis('off')
    ax.imshow(pixels)
    plt.show()
    plt.close(fig)


def _ask_prediction():
    """Prompt until the user enters '0', '1' or '2' (surrounding whitespace ignored)."""
    answer = ''
    while answer not in VALID_PREDICTIONS:
        answer = input('prediction: 0 (false) - 1 (correct) - 2 (stop)').strip()
    return answer


# scan experiment folder without recursion
root, dirs, files = next(os.walk(experiment_folder))

# Set lookup instead of scanning the story_id column once per file.
annotated_ids = set(annotations_df['story_id'])

try:
    # os.walk returns files in arbitrary order; sort for a stable order across runs.
    for filename in sorted(files):

        # Ignore non-prediction files (e.g. .DS_Store) that would crash Image.open.
        if not filename.endswith(PREDICTION_SUFFIX):
            continue

        # Strip only the trailing extension (str.replace would also hit '.png' mid-name).
        story_id = filename[:-len(PREDICTION_SUFFIX)]
        image_path = os.path.join(root, filename)

        # skip annotated story
        if story_id in annotated_ids:
            continue

        print('\n', story_id, '\n')

        story_list = story_in_sequence.create_story_list(
            stories_df, story_id, num_captions, include_images, as_prompt=False)

        display_story(story_list)

        print('\n', '=' * 50, '\n')

        _show_predictions(image_path)

        prediction = _ask_prediction()
        clear_output(wait=False)

        if prediction == '2':
            break

        annotations_df.loc[len(annotations_df)] = [story_id, int(prediction)]
        annotated_ids.add(story_id)
        # Persist immediately so a crash or interrupt cannot lose this verdict.
        _save_annotations(annotations_df, vist_annotations_csv_path)
finally:
    _save_annotations(annotations_df, vist_annotations_csv_path)

print('\n', 'Done', '\n')