In [12]:
# Set to True to render the standalone widget previews in the `if DEBUG:` cells below.
DEBUG = False

In [13]:
%%html
<style>
/* Highlight added to / removed from an image tile by ImageBox.callback_button. */
.img_box_selected{
    border-radius: 20px;
    background-color: #99FFB9;
}

/* Applied to the row of three images and to the prev/content/next row. */
.mybox_layout {
    flex-direction: row;
}

/* Applied to each ImageBox container. */
.hide_overflow {
    overflow: hidden;
}

/* On narrow screens (at most 550px wide) stack the row vertically instead. */
@media screen and (max-width: 550px) {
    .mybox_layout {
        flex-direction: column;
    }
}

</style>

In [14]:
import ipywidgets as ipyw
import numpy as np

import huggingface_hub
import os
import pathlib
import sys
import ipyspin
import contextlib
import io
import shutil

from collections import namedtuple
from datetime import datetime
from huggingface_hub import HfApi, snapshot_download, login

In [15]:
# ---------------------------------------------------------------------------
# Settings, each overridable through an environment variable of the same name.
# ---------------------------------------------------------------------------
# Hub dataset repo id; when unset/empty the notebook works on local files only.
HF_SURVEY_DATASET = os.environ.get("HF_SURVEY_DATASET")
# Sub-folder of "answers" that collected CSV files are written to.
CAMPAIGN_NAME = os.environ.get("CAMPAIGN_NAME", "default")
# Number of prompts shown per session.
NUM_QUESTIONS = int(os.environ.get("NUM_QUESTIONS", "5"))
# Integer flags: 1 enables the option, anything else disables it.
FLUSH_IMAGES_FROM_STORAGE = int(os.environ.get("FLUSH_IMAGES_FROM_STORAGE", "0"))
USE_SAMPLES_IMAGES = int(os.environ.get("USE_SAMPLES_IMAGES", "1"))

# Local defaults; both are replaced below when a Hub dataset is configured.
HF_API = None
FOLDER_DATA = pathlib.Path("..")

if HF_SURVEY_DATASET:
    # Swallow whatever `login` prints so it does not show up in the cell output.
    with contextlib.redirect_stdout(io.StringIO()) as f:
        login(token=os.environ["HF_TOKEN"])
    HF_API = HfApi()
    FOLDER_DATA = pathlib.Path("/data")
    # Optionally drop the stored images; Survey._install_images only downloads
    # when the image folder is missing.
    if FLUSH_IMAGES_FROM_STORAGE == 1 and (FOLDER_DATA / "images").exists():
        shutil.rmtree(str(FOLDER_DATA / "images"), ignore_errors=True)

FOLDER_ANSWERS = FOLDER_DATA / "answers" / CAMPAIGN_NAME

images_subfolder = "samples" if USE_SAMPLES_IMAGES == 1 else "full"
# One image folder per metric, all under the same parent.
FOLDER_MI, FOLDER_BLIP, FOLDER_HPS = (
    FOLDER_DATA / "images" / "metrics_comparison" / images_subfolder / metric
    for metric in ("mi", "blip", "hps")
)

In [16]:
class ImageBox:
    """A selectable image tile: one PNG with a select/remove toggle button below it.

    Attributes:
        path: location of the PNG file shown in the tile.
        selected: True while the participant has this image selected.
        img: the `ipyw.Image` widget, created by `load()` (None before that).
        button: the toggle button.
        box: the container widget holding image and button.
    """

    def __init__(self, path):
        """Create the widgets; the image file itself is only read by `load()`."""
        self.path = path
        self.selected = False

        self.img = None
        self.button = ipyw.Button(
            description="Click to select",
            style=dict(font_size="15pt"),
            icon="plus",
            layout=ipyw.Layout(width="99%", padding_left="10px", padding_right="10px", height="50px"),
        )
        self.button.on_click(self.callback_button)

        self.box = ipyw.VBox(
            [],
            layout=ipyw.Layout(
                align_items="center",
                margin_left="1px",
                margin_right="1px",
                padding="15px",
                border="solid 10px white",
            ),
        )
        self.box.add_class("hide_overflow")

    def _load_img(self, path):
        """Return the raw bytes of the file at `path`."""
        with open(path, "rb") as fin:
            return fin.read()

    def load(self):
        """Read the image from disk, populate the container and return it.

        The selection state is left untouched. This method runs every time a
        question is (re)displayed, while the button label and CSS class live on
        widgets that persist between calls; clearing `selected` here would make
        the recorded answer disagree with what the participant sees after
        navigating back to a question. Use `reset()` to clear a selection.
        """
        self.img = ipyw.Image(value=self._load_img(self.path), format="png")
        self.box.children = (self.img, self.button)
        return self.box

    def reset(self):
        """Unselect the tile (state and visuals) if it is currently selected."""
        if self.selected:
            self.callback_button()

    def callback_button(self, *args, **kwargs):
        """Toggle the selection and update button label, icon and highlight class.

        Accepts and ignores any arguments so it can serve as an `on_click` handler.
        """
        self.selected = not self.selected
        # NOTE(review): `background_color` does not look like a VBox trait in
        # current ipywidgets, so these assignments are presumably no-ops and the
        # visible highlight comes from the CSS class — confirm before removing.
        if self.selected:
            self.box.background_color = "blue"
            self.button.description = "Click to remove"
            self.box.add_class("img_box_selected")
            self.button.icon = "minus"
        else:
            self.box.background_color = "white"
            self.button.description = "Click to select"
            self.box.remove_class("img_box_selected")
            self.button.icon = "plus"

    def _ipython_display_(self):
        """Display the tile, loading the image first if the container is still empty."""
        if len(self.box.children) == 0:
            self.load()
        display(self.box)

In [17]:
# Manual preview of a single ImageBox; only runs when DEBUG is True.
# NOTE(review): this path has no "metrics_comparison/<subfolder>" component,
# unlike FOLDER_MI built in the configuration cell — confirm it still exists.
if DEBUG:
    imagebox = ImageBox(
        pathlib.Path(
            "../images/samples/mi/a brown giraffe and a red suitcase_0.png"
        )
    )
    display(imagebox)

In [18]:
class ImageTriplet:
    """One survey question: a text prompt plus three candidate images (mi, blip, hps).

    Attributes:
        prompt: the text prompt shown above the images.
        path_mi, path_blip, path_hps: image file for each metric.
        rng: optional generator providing `shuffle`; when set, the on-screen
            order of the three images is shuffled on every `load()`.
        imgbox_mi, imgbox_blip, imgbox_hps: the `ImageBox` of each metric.
        box: container widget holding the prompt and the image row.
    """

    def __init__(self, prompt, path_mi, path_blip, path_hps, rng=None):
        """Create the prompt header and the three image tiles (images not read yet)."""
        self.prompt = prompt
        self.path_mi = path_mi
        self.path_blip = path_blip
        self.path_hps = path_hps
        self.rng = rng

        self.box_prompt = ipyw.HTML(
            f"""<h2 style="text-align:center">"{self.prompt.lower()}"</h2>"""
        )
        self.imgbox_mi = ImageBox(self.path_mi)
        self.imgbox_blip = ImageBox(self.path_blip)
        self.imgbox_hps = ImageBox(self.path_hps)

        self.box = ipyw.VBox([])

    def load(self):
        """Load the three images, arrange them in a row and return the container."""
        imgboxes = [
            self.imgbox_mi.load(),
            self.imgbox_blip.load(),
            self.imgbox_hps.load(),
        ]

        if self.rng is not None:
            # Shuffle positions so the on-screen order does not reveal the metric.
            indices = np.arange(len(imgboxes), dtype=int)
            self.rng.shuffle(indices)
            imgboxes = [imgboxes[idx] for idx in indices]

        self.box_images = ipyw.HBox(imgboxes)
        self.box_images.add_class("mybox_layout")

        self.box.children = (self.box_prompt, self.box_images)
        return self.box

    def reset(self):
        """Unselect all three images."""
        self.imgbox_mi.reset()
        self.imgbox_blip.reset()
        self.imgbox_hps.reset()

    def to_csv_text(self):
        """Return one CSV row: the quoted prompt, then the mi, blip and hps flags.

        Double quotes inside the prompt are doubled, as CSV quoting requires,
        so that a prompt containing `"` cannot break the row apart.
        """
        quoted_prompt = self.prompt.replace('"', '""')
        return (
            f'"{quoted_prompt}",'
            f"{self.imgbox_mi.selected},{self.imgbox_blip.selected},{self.imgbox_hps.selected}"
        )

    def _ipython_display_(self):
        """Load and display the question."""
        display(self.load())

In [19]:
# Manual preview of a full question; only runs when DEBUG is True.
# No rng is passed, so the images are shown in mi, blip, hps order.
# NOTE(review): the prompt text does not match the sample image file names —
# presumably this is only a layout check; confirm.
if DEBUG:
    triplet = ImageTriplet(
        "A bright yellow wall in a bathroom adds appeal to a white tiled floor",
        pathlib.Path(
            "../images/samples/mi/a brown giraffe and a red suitcase_0.png"
        ),
        pathlib.Path(
            "../images/samples/blip/a brown giraffe and a red suitcase_17.png"
        ),
        pathlib.Path(
            "../images/samples/hps/a brown giraffe and a red suitcase_49.png"
        ),
    )
    display(triplet)

In [20]:
class Dataset:
    """All survey questions found on disk, indexed by prompt.

    Every PNG below `folder_mi` defines a prompt (its file stem up to the last
    underscore); the image for the same prompt is then looked up in
    `folder_blip` and `folder_hps`, and the three are bundled into an
    `ImageTriplet` stored in `self.triplets[prompt]`.
    """

    def __init__(self, folder_mi, folder_blip, folder_hps, rng=None):
        """Scan the three folders and build one triplet per prompt.

        Args:
            folder_mi, folder_blip, folder_hps: `pathlib.Path` folders, searched
                recursively for PNG files.
            rng: optional generator handed to every triplet (see `rng` property).

        Raises:
            AssertionError: if a prompt has no image in the blip or hps folder.
        """
        self.folder_mi = folder_mi
        self.folder_blip = folder_blip
        self.folder_hps = folder_hps
        self._rng = rng

        self._images_mi = self._get_png_paths(folder_mi)
        self._images_blip = self._get_png_paths(folder_blip)
        self._images_hps = self._get_png_paths(folder_hps)

        self.triplets = {}
        # Sorted iteration: when several files share a prompt, the one kept no
        # longer depends on the arbitrary iteration order of a set.
        for path_mi in sorted(self._images_mi):
            prompt = self._prompt_from_path(path_mi)
            path_blip = self._image_from_prompt(prompt, self._images_blip)
            path_hps = self._image_from_prompt(prompt, self._images_hps)
            assert path_blip, f"no blip image found for prompt {prompt!r}"
            assert path_hps, f"no hps image found for prompt {prompt!r}"
            self.triplets[prompt] = ImageTriplet(prompt, path_mi, path_blip, path_hps, rng)

    @property
    def rng(self):
        """Generator shared with the triplets; assigning it updates every triplet."""
        return getattr(self, "_rng", None)

    @rng.setter
    def rng(self, value):
        self._rng = value
        # Triplets keep their own reference, so push the new generator to each.
        for triplet in getattr(self, "triplets", {}).values():
            triplet.rng = value

    @staticmethod
    def _prompt_from_path(path):
        """Return the prompt encoded in a file name: the stem up to the last "_"."""
        return path.stem.rsplit("_", 1)[0]

    def _image_from_prompt(self, prompt, images):
        """Return the path in `images` that belongs to `prompt`, or None.

        An exact prompt match wins, so "a cat" is never answered with
        "a cat and a dog_3.png". If there is none, the first path (in sorted
        order) whose stem merely starts with the prompt is returned.
        """
        candidates = sorted(images)
        for path in candidates:
            if self._prompt_from_path(path) == prompt:
                return path
        for path in candidates:
            if path.stem.startswith(prompt):
                return path
        return None

    def reset(self):
        """Unselect every image of every triplet."""
        for triplet in self.triplets.values():
            triplet.reset()

    def _get_png_paths(self, folder):
        """Return the set of PNG paths found recursively below `folder`.

        Paths containing "(1)" are skipped — presumably duplicate copies; confirm.
        """
        return {path for path in folder.rglob("*.png") if "(1)" not in str(path)}

    def __len__(self):
        """Number of prompts (triplets) in the dataset."""
        return len(self.triplets)

In [21]:
# Markup of the landing page. It is rendered with
# `str.format(NUM_QUESTIONS=...)`, so any literal brace added here must be doubled.
# The two lists and the paragraph are closed with proper end tags.
HTML_WELCOME_AND_INSTRUCTIONS = """
<div style="text-align:center; font-size: 15pt">
<h1>Welcome to our survey ✨</h1>

<p>
<ul style="text-align:left; line-height:1.8; padding-top:20px; padding-bottom: 10px">
<li>The survey is composed of {NUM_QUESTIONS} rounds.</li>
<li>In each round we show you a text prompt and 3 images.</li>
<li>Select which images you think better represent each prompt.</li>
<li>There are no right or wrong answers 😊:</li>
    <ul style="font-size:13pt">
        <li>You can select from 0 (no images) to 3 images (all images).</li>
        <li>Some of the 3 images may seem identical (i.e., this is not a bug).</li>
    </ul>
</ul>
</p>
</div>
"""

# Privacy notice shown below the start button on the landing page.
HTML_ETHICS = """
<p style="font-size:10pt; padding-top: 10px">
<table style="border: solid 1px lightgray; border-radius: 10px; padding: 5px">
    <tr>
        <td>🔐</td>
        <td>
        We only collected your anonymized answers.
        No cookies or other user tracking is used in this survey.
        </td>
    </tr>
</table>
</p>
"""
class Survey:
    """The survey application: instruction page, question pages and thank-you page.

    Constructing a `Survey` displays its root widget, downloads the images when
    a Hub dataset is configured and they are not on disk yet, builds the
    `Dataset`, and shows the instruction page. Answers are written to a CSV
    file when the participant moves past the last question.
    """

    def __init__(self, folder_data, folder_images_mi, folder_images_blip, folder_images_hps, folder_answers, num_questions=-1, randomize=True):
        """Build the widget tree, install the images and show the instructions.

        Args:
            folder_data: root folder the Hub snapshot is downloaded into.
            folder_images_mi, folder_images_blip, folder_images_hps: image
                folder of each metric.
            folder_answers: folder the answer CSV files are written to.
            num_questions: questions per session; values <= 0 mean "all", and
                larger values are capped at the dataset size.
            randomize: when False a fixed seed is used for every session.
        """
        self.folder_data = folder_data
        self.folder_images_mi = folder_images_mi
        self.folder_images_blip = folder_images_blip
        self.folder_images_hps = folder_images_hps
        self.folder_answers = folder_answers
        self.num_questions = num_questions
        self.randomize = randomize

        self.rng = None
        self.dset = None
        # Per-session state, (re)filled by `reset()`.
        self.triplets = []
        self._idx_triplet = 0

        LAYOUT_BUTTON = ipyw.Layout(
            height="100px",
            width="100px",
        )
        STYLE_BUTTON = dict(
            button_color="white",
            font_size="60px",
        )

        self.button_prev = ipyw.Button(
            description="",
            layout=LAYOUT_BUTTON,
            style=STYLE_BUTTON,
            icon="chevron-left",
        )
        self.button_next = ipyw.Button(
            description="",
            layout=LAYOUT_BUTTON,
            style=STYLE_BUTTON,
            icon="chevron-right",
        )
        self.button_prev.on_click(self.callback_button_prev)
        self.button_next.on_click(self.callback_button_next)
        self.label = ipyw.HTML("")
        # Banner shown below the body while the last question is on screen.
        self.label_last_question = ipyw.HTML(
            """<div style="color:red; font-size:15pt">This is the last picture. Press the > button once your preference is entered to complete the survey.</div>"""
        )

        self.box_content = ipyw.HBox(
            [
                self.button_prev,
                self.button_next,
            ],
            layout=ipyw.Layout(
                align_items="center",
            ),
        )
        self.box_content.add_class("mybox_layout")
        self.box_body = ipyw.VBox([
            self.box_content,
            self.label
        ], layout=ipyw.Layout(align_items="center", visibility="visible")
        )
        self.box_spinner = ipyw.VBox([], layout=ipyw.Layout(align_items="center", visibility="hidden", height="0px"))

        self.box = ipyw.VBox([
                self.box_spinner,
                self.box_body,
        ], layout=ipyw.Layout(align_items="center"))
        # Displayed before the download below so the spinner is visible meanwhile.
        display(self.box)

        self._install_images()
        self.reset_rng()

        self.dset = Dataset(
            self.folder_images_mi,
            self.folder_images_blip,
            self.folder_images_hps,
            self.rng,
        )

        if self.num_questions <= 0:
            self.num_questions = len(self.dset)
        else:
            self.num_questions = min(len(self.dset), self.num_questions)

        self.show_instruction_page()

    def reset_rng(self):
        """Create a fresh random generator in `self.rng`.

        With `randomize` set, the generator is seeded from OS entropy (seed
        `None`); otherwise the fixed seed 12345 makes sessions repeatable.
        This replaces a seed built with `strftime("%H%s")`, whose `%s`
        directive is a platform-specific extension.
        """
        seed = None if self.randomize else 12345
        self.rng = np.random.default_rng(seed=seed)

    def callback_start(self, *args, **kwargs):
        """Handler of the "Start" button: switch to the question view and begin."""
        self.box.children = (
            self.box_spinner,
            self.box_body,
        )
        self.reset()

    def show_instruction_page(self):
        """Replace the page content with the welcome text, start button and privacy note."""
        button_start = ipyw.Button(
            description="Start",
            icon="play",
            style=dict(font_size="15pt"),
            layout=ipyw.Layout(height="50px", width="300px")
        )
        button_start.on_click(self.callback_start)
        box_start = ipyw.VBox(
            [
                ipyw.HTML(HTML_WELCOME_AND_INSTRUCTIONS.format(NUM_QUESTIONS=self.num_questions)),
                button_start,
                ipyw.HTML(HTML_ETHICS),
            ],
            layout=ipyw.Layout(align_items="center")
        )
        self.box.children = (box_start,)

    def _show_box_spinner(self, message):
        """Hide the body and show a spinner with `message` in its place."""
        self.box_spinner.children = (
            ipyspin.Spinner(scale=0.5),
            ipyw.HTML(f"<h2>{message}</h2>"),
        )
        self.box_body.layout.visibility = "hidden"
        self.box_spinner.layout.height = "300px"
        self.box_spinner.layout.visibility = "visible"

    def _hide_box_spinner(self):
        """Collapse the spinner area and make the body visible again."""
        self.box_spinner.layout.visibility = "hidden"
        self.box_spinner.layout.height = "0px"
        self.box_body.layout.visibility = "visible"

    def _install_images(self):
        """Download the images from the Hub dataset if they are not on disk yet.

        Does nothing when no Hub API is configured or when this survey's mi
        image folder already exists.
        """
        if HF_API is None or self.folder_images_mi.exists():
            return

        self._show_box_spinner("Downloading data. This might take a few minutes...")
        try:
            self.folder_data.mkdir(parents=True, exist_ok=True)
            snapshot_download(
                repo_id=HF_SURVEY_DATASET,
                repo_type="dataset",
                allow_patterns="images/*.png",
                local_dir=self.folder_data,
                force_download=True,
            )
        finally:
            # Never leave the page stuck behind the spinner if the download fails.
            self._hide_box_spinner()

    def load_triplet(self, idx=0):
        """Show question number `idx` (0-based) of the current session."""
        self.label.value = f"""<div style="font-size:15pt">Progress: {idx+1}/{self.num_questions}</div>"""

        # Show the "last picture" banner only on the final question. Setting
        # the children explicitly keeps at most one banner on the page.
        if idx == self.num_questions - 1:
            self.box.children = (self.box_spinner, self.box_body, self.label_last_question)
        else:
            self.box.children = (self.box_spinner, self.box_body)

        # Placeholder spinner while the images are read from disk.
        self.box_content.children = (
            self.button_prev,
            ipyspin.Spinner(scale=0.3),
            self.button_next,
        )

        box_triplet = self.triplets[idx].load()

        self.button_prev.disabled = idx == 0

        self.box_content.children = (
            self.button_prev,
            box_triplet,
            self.button_next,
        )

    def reset(self):
        """Start a new session: clear all selections, draw new questions, show the first."""
        self.reset_rng()
        self.dset.rng = self.rng
        self.dset.reset()

        self._idx_triplet = 0
        # Sort before shuffling so the draw depends only on the generator state.
        triplets_available = sorted(
            self.dset.triplets.values(), key=lambda triplet: triplet.prompt
        )

        self.rng.shuffle(triplets_available)
        self.triplets = triplets_available[: self.num_questions]
        self.load_triplet(self._idx_triplet)

    def callback_button_prev(self, *args, **kwargs):
        """Handler of the "<" button: go back one question (never below the first)."""
        self._idx_triplet = max(0, self._idx_triplet - 1)
        self.load_triplet(self._idx_triplet)

    def callback_button_next(self, *args, **kwargs):
        """Handler of the ">" button: next question, or finish after the last one."""
        if self._idx_triplet + 1 == self.num_questions:
            self.show_last_page()
        else:
            self._idx_triplet = min(self.num_questions - 1, self._idx_triplet + 1)
            self.load_triplet(self._idx_triplet)

    def to_csv(self):
        """Write the session's answers to a timestamped CSV file and upload it.

        The file is `answer_<timestamp>.csv` inside `folder_answers`, with the
        header `prompt,mi,blip,hps` and one row per question, encoded as
        UTF-8. When a Hub API is configured the file is also uploaded to the
        dataset repo under `answers/<campaign folder>/`.
        """
        now = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        fname = self.folder_answers / f"answer_{now}.csv"
        fname.parent.mkdir(parents=True, exist_ok=True)

        print(f"saving: {fname}")
        with open(fname, "w", encoding="utf-8") as fout:
            fout.write("prompt,mi,blip,hps\n")
            fout.write("\n".join(triplet.to_csv_text() for triplet in self.triplets))

        if HF_API is not None:
            self._show_box_spinner("Saving data. Do not close this page yet...")
            try:
                HF_API.upload_file(
                    repo_id=HF_SURVEY_DATASET,
                    repo_type="dataset",
                    path_or_fileobj=fname,
                    path_in_repo=f"answers/{fname.parent.name}/{fname.name}",
                )
            finally:
                # Restore the page even when the upload raises.
                self._hide_box_spinner()

    def callback_new_session(self, *args, **kwargs):
        """Handler of the "New session" button: return to the question view and restart."""
        self.label.value = ""
        self.box.children = (
            self.box_spinner,
            self.box_body,
        )
        self.reset()

    def show_last_page(self):
        """Save the answers, then show the thank-you page with a "New session" button."""
        self.to_csv()

        button_new_session = ipyw.Button(
            description="New session",
            icon="repeat",
            style=dict(font_size="15pt"),
            layout=ipyw.Layout(height="50px", width="300px")
        )
        button_new_session.on_click(self.callback_new_session)
        self.box.children = (
            ipyw.HTML("""
                <div style="text-align: center">
                <h1>Thanks for participating in this survey!</h1>
                <p style="font-size:12pt">To contribute another session click the button below, otherwise simply close this tab.</p>
                </div>
                """,
                layout=ipyw.Layout(padding_top="200px"),
            ),
            button_new_session,
        )

    def _ipython_display_(self):
        """Display the root widget."""
        display(self.box)

In [22]:
# Build and show the survey. The constructor itself displays the widget tree,
# downloads the images when needed and opens the instruction page.
# `randomize` is left at its default (True).
survey = Survey(
    FOLDER_DATA,
    FOLDER_MI,
    FOLDER_BLIP,
    FOLDER_HPS,
    FOLDER_ANSWERS,
    NUM_QUESTIONS,
)

VBox(children=(VBox(layout=Layout(align_items='center', height='0px', visibility='hidden')), VBox(children=(HB…