In [1]:
# When True, the DEBUG cells below render individual widgets for inspection,
# and the page-level CSS/title tweaks in the next cell are skipped.
DEBUG = False
# IPython config magic (not plain Python): turn off jedi-based autocompletion.
%config Completer.use_jedi = False

In [2]:
if not DEBUG:
    # Page-level tweaks, applied only when DEBUG is False.
    import json
    import os

    from IPython.display import HTML, Javascript
    from IPython.display import display as ipythondisplay

    # Remove the default notebook cell padding/margins and shrink the widget
    # container padding so the survey uses the whole page.
    ipythondisplay(
        HTML(
            """
    <style>
    .jp-Notebook .jp-Cell {
        padding: 0px;
        margin: 0px;
    }

    :root {
        --jp-widgets-container-padding: 3px;
    }
    </style>
    """
        )
    )

    # Set the browser tab title. The value comes from the environment, so it is
    # JSON-encoded: quotes, backslashes or newlines in SURVEY_TITLE can then
    # neither break nor inject into the generated JavaScript.
    survey_title = os.environ.get("SURVEY_TITLE", "survey")
    ipythondisplay(Javascript(f"document.title = {json.dumps(' ' + survey_title)};"))

<IPython.core.display.Javascript object>

In [3]:
%%html
<style>
/* Highlight added to an ImageBox container while its image is selected. */
.img_box_selected {
    border-radius: 20px;
    background-color: #99FFB9;
}

/* Image grid of each question: 4 columns on wide screens. */
.mybox_layout {
    grid-template-columns: 1fr 1fr 1fr 1fr;
}

.hide_overflow {
    overflow: hidden;
}

/* Narrow screens: fall back to a 2-column grid. */
@media screen and (max-width: 550px) {
    .mybox_layout {
        grid-template-columns: 1fr 1fr;
    }
}
</style>

In [4]:
import ipywidgets as ipyw
import numpy as np
import pandas as pd

import huggingface_hub
import os
import pathlib
import sys
import ipyspin
import contextlib
import io
import shutil
import itertools

from collections import namedtuple, OrderedDict
from datetime import datetime
from huggingface_hub import HfApi, snapshot_download, login

In [5]:
# Survey configuration, read from environment variables.
HF_SURVEY_DATASET = os.environ.get("HF_SURVEY_DATASET", None)
CAMPAIGN_NAME = os.environ.get("CAMPAIGN_NAME", "default")
NUM_QUESTIONS_PER_CATEGORY = int(os.environ.get("NUM_QUESTIONS_PER_CATEGORY", "1"))
FLUSH_IMAGES_FROM_STORAGE = int(os.environ.get("FLUSH_IMAGES_FROM_STORAGE", "0"))
USE_SAMPLES_IMAGES = int(os.environ.get("USE_SAMPLES_IMAGES", "1"))

# Local runs read and write relative to the parent folder. When a HF dataset is
# configured, log in (silencing login's stdout) and use /data instead.
FOLDER_DATA = pathlib.Path("..")
HF_API = None

if HF_SURVEY_DATASET:
    with contextlib.redirect_stdout(io.StringIO()):
        login(token=os.environ["HF_TOKEN"])
    HF_API = HfApi()

    FOLDER_DATA = pathlib.Path("/data")

FOLDER_ANSWERS = FOLDER_DATA / "answers" / "methods_comparison" / CAMPAIGN_NAME

# This cell does not download anything: the images must already be present.
FOLDER_IMAGES = FOLDER_DATA / "images" / "methods_comparison" / ("samples" if USE_SAMPLES_IMAGES else "full")
if not FOLDER_IMAGES.is_dir():
    raise FileNotFoundError(f"Image folder not found: {FOLDER_IMAGES}")

# One entry per category sub-folder (e.g. "color", "shape"). Only visible
# directories are kept. Stray files or hidden folders would otherwise become
# categories, and DatasetCategory would fail when listing their contents.
FOLDER_CATEGORIES_DICT = {
    path.name: path
    for path in sorted(FOLDER_IMAGES.iterdir())
    if path.is_dir() and not path.name.startswith(".")
}

In [6]:
class ImageBox:
    """A single selectable image: an ``ipyw.Image`` above a Select/Remove button.

    The answer is stored in ``self.selected``. The UI mirrors it through the
    button label/icon and the ``img_box_selected`` CSS class on the container.
    """

    def __init__(self, path):
        """Create the (still empty) widgets for the PNG file at ``path``.

        The image bytes are only read by :meth:`load`.
        """
        self.path = path
        self.selected = False

        self.img = None
        # ipywidgets ``Layout`` has no per-side traits such as ``padding_left``
        # or ``margin_left`` (unknown kwargs are ignored with a warning), so the
        # CSS shorthand traits ``padding``/``margin`` are used instead.
        self.button = ipyw.Button(
            description="Select",
            style=dict(font_size="15pt"),
            icon="plus",
            layout=ipyw.Layout(width="99%", padding="0px 10px", height="45px"),
        )
        self.button.on_click(self.callback_button)

        self.box = ipyw.VBox(
            [],
            layout=ipyw.Layout(
                align_items="center",
                margin="0px 1px",
                padding="10px",
                border="solid 3px white",
            ),
        )
        self.box.add_class("hide_overflow")

    def _load_img(self, path):
        """Return the raw bytes of the file at ``path``."""
        with open(path, "rb") as fin:
            return fin.read()

    def _apply_selection_style(self):
        """Sync the button label/icon and the highlight class with ``self.selected``.

        The highlight itself is done by the ``img_box_selected`` CSS class.
        """
        if self.selected:
            self.button.description = "Remove"
            self.button.icon = "minus"
            self.box.add_class("img_box_selected")
        else:
            self.button.description = "Select"
            self.button.icon = "plus"
            self.box.remove_class("img_box_selected")

    def load(self):
        """Fill the container (reading the image on first use) and return it.

        The current selection is kept, so a question revisited with the survey's
        "previous" button still holds the answer the user gave. The UI is re-synced
        so it always matches ``self.selected``. Use :meth:`reset` to clear it.
        """
        if self.img is None:
            self.img = ipyw.Image(value=self._load_img(self.path), format="png")
        self.box.children = (self.img, self.button)
        self._apply_selection_style()
        return self.box

    def reset(self):
        """Clear the selection (no-op when the image is not selected)."""
        if self.selected:
            self.callback_button()

    def callback_button(self, *args, **kwargs):
        """Toggle the selection; registered as the button's ``on_click`` handler."""
        self.selected = not self.selected
        self._apply_selection_style()

    def _ipython_display_(self):
        """Display the container, loading the image first if needed."""
        if len(self.box.children) == 0:
            self.load()
        display(self.box)

In [7]:
if DEBUG:
    # Quick visual check: render one ImageBox for a local sample image.
    imagebox = ImageBox(
        path=pathlib.Path("../images/methods_comparison/samples/color/MI/a brown horse and a blue vase_000004.png"),
    )
    display(imagebox)

In [8]:
class ImageGrid:
    """One survey question: a prompt title above a grid of selectable ImageBoxes.

    Parameters
    ----------
    category : str
        Category label, copied into :meth:`to_pandas`.
    prompt : str
        Text prompt shown (lower-cased) above the images.
    path_images : iterable of pathlib.Path
        One image per method. The key of each box is the image's parent folder name.
    ncols : int
        Number of columns (-1: all images in one row). Only used to compute
        ``nrows``; the column count actually rendered comes from the
        ``mybox_layout`` CSS class.
    rng : numpy.random.Generator or None
        If given, used to shuffle the display order of the images.
    """

    def __init__(self, category, prompt, path_images, ncols=-1, rng=None):
        import html

        self.category = category
        self.prompt = prompt
        self.path_images = path_images
        self.rng = rng
        self.ncols = ncols
        # Display order (indices into imgbox_dict), drawn on first load and
        # cleared by reset().
        self._order = None

        self.imgbox_dict = OrderedDict(
            (path.parent.name, ImageBox(path)) for path in self.path_images
        )

        self.nrows = 1
        if ncols == -1:
            self.ncols = len(self.imgbox_dict)
        else:
            self.nrows = int(np.ceil(len(self.imgbox_dict) / self.ncols))

        # Prompts come from file names; escape them before embedding in HTML.
        self.box_prompt = ipyw.HTML(
            f"""<h2 style="text-align:center; padding: 0px; margin: 0px">"{html.escape(self.prompt.lower(), quote=False)}"</h2>""",
            layout=ipyw.Layout(visibility="hidden"),
        )

        self.box_spinner = ipyw.HBox(
            [ipyspin.Spinner(scale=0.3)],
            layout=ipyw.Layout(height="200px", visibility="visible"),
        )

        self.box_images = ipyw.GridBox([], layout=ipyw.Layout(visibility="hidden"))
        self.box_images.add_class("mybox_layout")

        self.box = ipyw.VBox([
            self.box_spinner,
            self.box_prompt,
            self.box_images,
        ])

    def load(self):
        """Show the spinner, fill the grid, then reveal the prompt and the images.

        The shuffled order is drawn once and reused on later loads, so a revisited
        question shows its images in the same positions. ``reset()`` forgets the
        order so that the next session gets a fresh one.

        Returns the outer ``ipyw.VBox`` to display.
        """
        self.box_spinner.layout.visibility = "visible"
        self.box_prompt.layout.visibility = "hidden"
        self.box_images.layout.visibility = "hidden"
        self.box_spinner.layout.height = "200px"
        self.box_prompt.layout.height = "auto"
        self.box_images.layout.height = "auto"

        imgboxes = [imgbox.load() for imgbox in self.imgbox_dict.values()]

        if self.rng is not None:
            if self._order is None:
                self._order = self.rng.permutation(len(imgboxes))
            imgboxes = [imgboxes[idx] for idx in self._order]

        self.box_images.children = imgboxes

        # Collapse the spinner and reveal the content.
        self.box_spinner.layout.visibility = "hidden"
        self.box_prompt.layout.visibility = "visible"
        self.box_images.layout.visibility = "visible"
        self.box_spinner.layout.height = "0px"
        self.box_prompt.layout.height = "auto"
        self.box_images.layout.height = "auto"
        return self.box

    def reset(self):
        """Clear every image's selection and forget the shuffled display order."""
        for imgbox in self.imgbox_dict.values():
            imgbox.reset()
        self._order = None

    def to_pandas(self):
        """Return a one-row DataFrame: category, prompt, then one bool column per method.

        Method columns are sorted by name. The prompt is stored as ``repr(prompt)``
        (i.e. with quotes); this is kept as-is for compatibility with existing
        answer files.
        """
        names = sorted(self.imgbox_dict.keys())
        values = [self.imgbox_dict[name].selected for name in names]
        df = pd.DataFrame([values], columns=names)
        df = df.assign(
            prompt=repr(self.prompt),
            category=self.category,
        )
        return df[["category", "prompt"] + names]

    def _ipython_display_(self):
        """Load the grid and display it."""
        display(self.load())
        # pass

In [9]:
if DEBUG:
    # Quick visual check of an ImageGrid built from the local sample images:
    # one image per method for the same prompt, shuffled with a fixed seed.
    # The prompt label must match the prompt encoded in the image file names.
    debug_prompt = "a brown horse and a blue vase"
    debug_methods = ["AandE", "DPOK", "GORS", "HardNeg", "MI", "S-CFG", "Structured", "VNL"]
    debug_folder = pathlib.Path("../images/methods_comparison/samples/color")
    imggrid = ImageGrid(
        "color",
        debug_prompt,
        [debug_folder / method / f"{debug_prompt}_000004.png" for method in debug_methods],
        ncols=4,
        rng=np.random.default_rng(seed=12345),
    )
    display(imggrid)
    display(imggrid.to_pandas())
    display(imggrid.to_pandas())

In [10]:
# imggrid.load()
# # # imggrid.box_spinner.height
# # imggrid.box_spinner.layout.visibility = "hidden"
# # imggrid.box_spinner.layout.height = "auto"

In [11]:
class DatasetCategory:
    """All prompts of one category, with one image per (method, prompt).

    Layout expected by the parsing below::

        folder/<method>/.../<prompt>_<suffix>.png

    Each visible sub-folder of ``folder`` is a method. Only prompts for which
    every method has an image are kept, because each question shows one image
    per method.
    """

    def __init__(self, category, folder, ncols=-1, rng=None):
        self.category = category
        self.folder = folder
        self.rng = rng
        self.ncols = ncols

        # method folder -> {prompt: png path}
        self.methods_dict = dict()
        for method in sorted(folder.iterdir()):
            # Skip stray files and hidden folders (e.g. .DS_Store,
            # .ipynb_checkpoints). They hold no images and would make every
            # prompt lookup in next() fail.
            if not method.is_dir() or method.name.startswith("."):
                continue
            data = dict()
            # Sorted so that, if a method has several images for one prompt,
            # the one kept is deterministic (the last in sorted order).
            for path in sorted(method.rglob("*.png")):
                prompt = path.stem.rsplit("_", 1)[0]
                data[prompt] = path
            self.methods_dict[method] = data
        self.methods = sorted(self.methods_dict.keys())

        # Prompts available for every method. They are sorted because set
        # iteration order varies between processes, which would make shuffling
        # with a seeded rng non-reproducible.
        prompt_sets = [set(data) for data in self.methods_dict.values()]
        self._prompts = sorted(set.intersection(*prompt_sets)) if prompt_sets else []

        # prompt -> ImageGrid, created lazily by next() and reused afterwards.
        self.imggrid = {prompt: None for prompt in self._prompts}

    def next(self, cycle=True):
        """Yield one ImageGrid per prompt, in rng-shuffled order if an rng is set.

        With ``cycle=True`` the prompts are reshuffled and iterated again
        forever; with ``cycle=False`` each prompt is yielded once. Grids are
        cached in ``self.imggrid``, so ``reset``/``set_rng`` reach them.
        An empty category yields nothing (instead of looping forever).
        """
        if not self._prompts:
            return
        while True:
            if self.rng is not None:
                self.rng.shuffle(self._prompts)
            for prompt in self._prompts:
                grid = self.imggrid[prompt]
                if grid is None:
                    grid = ImageGrid(
                        self.category,
                        prompt,
                        [prompt_data[prompt] for prompt_data in self.methods_dict.values()],
                        self.ncols,
                        self.rng,
                    )
                    self.imggrid[prompt] = grid
                yield grid
            if not cycle:
                break

    def reset(self):
        """Reset every grid created so far."""
        for imggrid in self.imggrid.values():
            if imggrid is not None:
                imggrid.reset()

    def set_rng(self, rng):
        """Use ``rng`` for future shuffles here and in every grid created so far."""
        self.rng = rng
        for grid in self.imggrid.values():
            if grid is not None:
                grid.rng = rng

    def __len__(self):
        """Number of usable prompts in this category."""
        return len(self.imggrid)

In [12]:
if DEBUG:
    # Quick check: build the "shape" category and show its first question.
    dset_category = DatasetCategory(
        category="shape",
        folder=FOLDER_CATEGORIES_DICT["shape"],
        ncols=4,
    )
    display(next(dset_category.next(cycle=True)))

In [13]:
class Dataset:
    """All survey categories, each backed by a DatasetCategory.

    ``next()`` returns one ImageGrid per category. Each category keeps a single
    persistent prompt iterator, so successive calls walk through its shuffled
    prompts without repeating one until all of them have been used.
    """

    def __init__(self, folder_dict, ncols=-1, rng=None):
        self.folder_dict = folder_dict
        self.rng = rng

        self._categories = sorted(self.folder_dict.keys())
        if not self._categories:
            raise ValueError("Dataset needs at least one category folder")
        self.dset_categories = {
            category: DatasetCategory(
                category,
                self.folder_dict[category],
                ncols,
                rng,
            )
            for category in self._categories
        }

        self._methods = self.dset_categories[self._categories[0]].methods
        # category -> persistent generator from DatasetCategory.next(cycle=True)
        self._iterators = {}

    def next(self):
        """Return the next ImageGrid of every category (categories in shuffled order).

        Categories without any prompt are skipped.
        """
        if self.rng is not None:
            self.rng.shuffle(self._categories)
        grids = []
        for category in self._categories:
            iterator = self._iterators.get(category)
            if iterator is None:
                iterator = self.dset_categories[category].next(cycle=True)
                self._iterators[category] = iterator
            grid = next(iterator, None)
            if grid is not None:
                grids.append(grid)
        return grids

    def reset(self):
        """Reset every category's grids."""
        for dset in self.dset_categories.values():
            dset.reset()

    def set_rng(self, rng):
        """Use ``rng`` here and in every category."""
        self.rng = rng
        for dset_category in self.dset_categories.values():
            dset_category.set_rng(rng)

    def __len__(self):
        """Total number of prompts over all categories."""
        return sum(len(dset) for dset in self.dset_categories.values())

In [14]:
if DEBUG:
    # Build the full dataset with a fixed seed for reproducible debugging.
    dset = Dataset(
        folder_dict=FOLDER_CATEGORIES_DICT,
        ncols=4,
        rng=np.random.default_rng(seed=12345),
    )

In [19]:
# Instruction page shown before a session. Formatted with NUM_QUESTIONS and
# NUM_METHODS by Survey.show_instruction_page().
HTML_WELCOME_AND_INSTRUCTIONS = """
<div style="text-align:center; font-size: 15pt">
<h1>Welcome to our survey ✨</h1>

<ul style="text-align:left; line-height:1.8; padding-top:20px; padding-bottom: 10px">
<li>The survey is composed of {NUM_QUESTIONS} rounds.</li>
<li>In each round we show you a text prompt and {NUM_METHODS} images.</li>
<li>Select which images you think better represent each prompt.</li>
<li>There are no right or wrong answers 😊:
    <ul style="font-size:13pt">
        <li>You can select from 0 (no images) to {NUM_METHODS} images (all images).</li>
        <li>Some images may be similar.</li>
        <li>Pay more attention to coherence rather than aesthetics.</li>
    </ul>
</li>
</ul>
</div>
"""

# Privacy note shown under the Start button. A <div> (not a <p>) wraps the
# table: a <p> is implicitly closed before a <table>, so its font-size never applied.
HTML_ETHICS = """
<div style="font-size:10pt; padding-top: 10px">
<table style="border: solid 1px lightgray; border-radius: 10px; padding: 5px">
    <tr>
        <td>🔐</td>
        <td>
        We only collect your anonymized answers.
        No cookies or other user tracking is used in this survey.
        </td>
    </tr>
</table>
</div>
"""
class Survey:
    """Image-preference survey UI built from ipywidgets.

    Flow:
    1. An instruction page.
    2. ``num_questions`` ImageGrid rounds, navigated with prev/next buttons.
    3. The answers are written to a CSV, and also uploaded to
       ``HF_SURVEY_DATASET`` when ``HF_API`` is set.
    4. A thank-you page with a "New session" button.

    Parameters
    ----------
    folder_dict : dict
        Category name -> category folder, forwarded to ``Dataset``.
    folder_answers : pathlib.Path
        Folder where the answer CSVs are written.
    num_questions_per_category : int
        Rounds drawn from each category per session. ``<= 0`` means every
        prompt of every category exactly once.
    ncols : int
        Forwarded to each ImageGrid.
    randomize : bool
        If True, every session gets an rng seeded from OS entropy. Otherwise a
        fixed seed (12345) is used, so sessions are reproducible.
    """

    def __init__(self, folder_dict, folder_answers, num_questions_per_category=-1, ncols=-1, randomize=True):
        self.folder_dict = folder_dict
        self.folder_answers = folder_answers
        self.num_questions_per_category = num_questions_per_category
        self.randomize = randomize
        self.ncols = ncols

        self.rng = None
        self.dset = None
        # Per-session state, (re)initialised by reset().
        self._idx_imggrid = 0
        self._imggrids = []
        self._submitted = False

        LAYOUT_BUTTON = ipyw.Layout(
            height="100px",
            width="100px",
        )
        STYLE_BUTTON = dict(
            button_color="white",
            font_size="60px",
        )

        self.button_prev = ipyw.Button(
            description="",
            layout=LAYOUT_BUTTON,
            style=STYLE_BUTTON,
            icon="chevron-left",
        )
        self.button_next = ipyw.Button(
            description="",
            layout=LAYOUT_BUTTON,
            style=STYLE_BUTTON,
            icon="chevron-right",
        )
        self.button_prev.on_click(self.callback_button_prev)
        self.button_next.on_click(self.callback_button_next)
        # Progress indicator shown below the grid.
        self.label = ipyw.HTML("")

        # [prev] <grid> [next]; the grid is inserted by load_imggrid().
        self.box_content = ipyw.HBox(
            [
                self.button_prev,
                self.button_next,
            ],
            layout=ipyw.Layout(
                align_items="center",
            ),
        )
        self.box_body = ipyw.VBox(
            [
                self.box_content,
                self.label,
            ],
            layout=ipyw.Layout(align_items="center", visibility="visible"),
        )
        self.box_spinner = ipyw.VBox([], layout=ipyw.Layout(align_items="center", visibility="hidden", height="0px"))

        self.box = ipyw.VBox(
            [
                self.box_spinner,
                self.box_body,
            ],
            layout=ipyw.Layout(align_items="center"))
        display(self.box)

        self.reset_rng()

        self.dset = Dataset(
            self.folder_dict,
            self.ncols,
            self.rng,
        )

        self.num_questions = len(self.dset)
        if self.num_questions_per_category > 0:
            self.num_questions = min(self.num_questions, self.num_questions_per_category * len(self.folder_dict))

        self.show_instruction_page()

    def reset_rng(self):
        """Create a fresh numpy Generator for the next session.

        With ``randomize`` the seed comes from OS entropy (not the wall clock),
        so users starting at the same moment get independent question orders.
        """
        seed = None if self.randomize else 12345
        self.rng = np.random.default_rng(seed=seed)

    def _start_session(self):
        """Restore the spinner/body layout, clear the progress label and reset."""
        self.label.value = ""
        self.box.children = (
            self.box_spinner,
            self.box_body,
        )
        self.reset()

    def callback_start(self, *args, **kwargs):
        """'Start' button handler: leave the instruction page and begin a session."""
        self._start_session()

    def show_instruction_page(self):
        """Replace the whole UI with the welcome/instructions page and a Start button."""
        button_start = ipyw.Button(
            description="Start",
            icon="play",
            style=dict(font_size="15pt"),
            layout=ipyw.Layout(height="50px", width="300px")
        )
        button_start.on_click(self.callback_start)
        box_start = ipyw.VBox(
            [
                ipyw.HTML(HTML_WELCOME_AND_INSTRUCTIONS.format(
                    NUM_QUESTIONS=self.num_questions,
                    NUM_METHODS=len(self.dset._methods),
                )),
                button_start,
                ipyw.HTML(HTML_ETHICS),
            ],
            layout=ipyw.Layout(align_items="center")
        )
        self.box.children = (box_start,)

    def _show_box_spinner(self, message):
        """Hide the body and show a spinner with ``message`` underneath.

        Also drops a trailing HTML child of ``self.box`` (the last-question
        warning appended by load_imggrid).
        """
        if self.box.children and isinstance(self.box.children[-1], ipyw.HTML):
            self.box.children = self.box.children[:-1]
        self.box_body.layout.visibility = "hidden"
        self.box_body.layout.height = "0px"
        self.box_spinner.children = (
            ipyspin.Spinner(scale=0.5),
            ipyw.HTML(f"<h2>{message}</h2>"),
        )
        self.box_spinner.layout.height = "300px"
        # Fixed width only for the bare spinner; let messages size themselves.
        self.box_spinner.layout.width = "300px" if message == "" else None
        self.box_spinner.layout.visibility = "visible"

    def _hide_box_spinner(self):
        """Collapse the spinner and show the body again."""
        self.box_spinner.layout.visibility = "hidden"
        self.box_spinner.layout.height = "0px"
        self.box_body.layout.visibility = "visible"
        self.box_body.layout.height = "auto"

    def load_imggrid(self, idx=0):
        """Show question ``idx``: progress label, last-question warning and grid."""
        self._show_box_spinner("")

        self.label.value = f"""<div style="font-size:15pt">Progress: {idx+1}/{self.num_questions}</div>"""
        if idx == self.num_questions - 1:
            # Warning appended below the body on the final question;
            # _show_box_spinner() strips it again on the next load.
            self.box.children = (
                *self.box.children,
                ipyw.HTML(
                    """<div style="color:red; font-size:15pt">This is the last picture. Press the > button once your preference is entered to complete the survey.</div>"""
                ),
            )
        elif len(self.box.children) > 2:
            self.box.children = self.box.children[:-1]

        box_imggrid = self._imggrids[idx].load()
        self.button_prev.disabled = idx == 0

        self._hide_box_spinner()
        self.box_content.children = (
            self.button_prev,
            box_imggrid,
            self.button_next,
        )

    def _build_imggrids(self):
        """Return the shuffled ImageGrids of one session (at most ``num_questions``).

        Capping the list ensures every grid saved by to_csv was shown to the user.
        """
        imggrids = []
        if self.num_questions_per_category > 0:
            for _ in range(self.num_questions_per_category):
                imggrids.extend(self.dset.next())
        else:
            # No per-category limit: every prompt of every category once.
            for dset_category in self.dset.dset_categories.values():
                imggrids.extend(dset_category.next(cycle=False))
        self.rng.shuffle(imggrids)
        return imggrids[: self.num_questions]

    def reset(self):
        """Start a new session: fresh rng, cleared answers, first question shown."""
        self._show_box_spinner("")
        self.reset_rng()
        self.dset.set_rng(self.rng)
        self.dset.reset()

        self._submitted = False
        self._idx_imggrid = 0
        self._imggrids = self._build_imggrids()
        if not self._imggrids:
            raise RuntimeError("No survey questions available: check the image folders")
        # Keep navigation bounds consistent with what was actually built.
        self.num_questions = len(self._imggrids)
        self.load_imggrid(self._idx_imggrid)

    def callback_button_prev(self, *args, **kwargs):
        """'<' button handler: go back one question (not below the first)."""
        self._idx_imggrid = max(0, self._idx_imggrid - 1)
        self.load_imggrid(self._idx_imggrid)

    def callback_button_next(self, *args, **kwargs):
        """'>' button handler: advance, or save and finish on the last question."""
        if self._idx_imggrid + 1 >= self.num_questions:
            # Button clicks are handled one after another, so a double click
            # would otherwise save the same session twice. The flag is only set
            # after a successful save, so a failed save can be retried.
            if not self._submitted:
                self.show_last_page()
                self._submitted = True
        else:
            self._idx_imggrid += 1
            self.load_imggrid(self._idx_imggrid)

    @staticmethod
    def _answers_repo_path(fname):
        """Map a local answer file to its dataset path ("answers/<path below answers>")."""
        parts = fname.parts
        if "answers" in parts:
            return "/".join(parts[parts.index("answers"):])
        return f"answers/{fname.name}"

    def to_csv(self):
        """Save this session's answers to ``folder_answers/answer_<timestamp>.csv``.

        When ``HF_API`` is set, the file is also uploaded to ``HF_SURVEY_DATASET``
        under the same path relative to the ``answers`` folder.
        """
        now = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        fname = self.folder_answers / f"answer_{now}.csv"
        fname.parent.mkdir(parents=True, exist_ok=True)

        df_answers = pd.concat([imggrid.to_pandas() for imggrid in self._imggrids])
        print(f"saving: {fname}")
        df_answers.to_csv(fname, index=False)

        if HF_API is not None:
            self._show_box_spinner("Saving. Do not close this page...")
            try:
                HF_API.upload_file(
                    repo_id=HF_SURVEY_DATASET,
                    repo_type="dataset",
                    path_or_fileobj=fname,
                    path_in_repo=self._answers_repo_path(fname),
                )
            finally:
                # Never leave the user stuck on the spinner if the upload fails.
                self._hide_box_spinner()

    def callback_new_session(self, *args, **kwargs):
        """'New session' button handler: start over with a fresh session."""
        self._start_session()

    def show_last_page(self):
        """Save the answers, then show the thank-you page with a New session button."""
        self.to_csv()

        button_new_session = ipyw.Button(
            description="New session",
            icon="repeat",
            style=dict(font_size="15pt"),
            layout=ipyw.Layout(height="50px", width="300px")
        )
        button_new_session.on_click(self.callback_new_session)
        self.box.children = (
            ipyw.HTML("""
                <div style="text-align: center">
                <h1>Thanks for participating in this survey!</h1>
                <p style="font-size:12pt">To contribute another session click the button below, otherwise simply close this tab.</p>
                </div>
                """,
                # Layout has no ``padding_top`` trait; use the CSS shorthand.
                layout=ipyw.Layout(padding="200px 0px 0px 0px"),
            ),
            button_new_session,
        )

    def _ipython_display_(self):
        """Display the root container."""
        display(self.box)

In [16]:
# Build the survey; its constructor displays the widget and the instruction page.
survey = Survey(
    FOLDER_CATEGORIES_DICT,
    FOLDER_ANSWERS,
    num_questions_per_category=NUM_QUESTIONS_PER_CATEGORY,
    ncols=4,
)

VBox(children=(VBox(layout=Layout(align_items='center', height='0px', visibility='hidden')), VBox(children=(HB…