In [1]:
%load_ext autoreload
%autoreload 2

import os
from typing import List, Optional

import gradio as gr
import numpy as np




In [8]:
from capit.core.data.datasets import (
    ChallengeSamplesSourceTypes,
    InstagramImageTextMultiModalDatasePyArrow,
    SplitType,
)
from capit.core.data.datasets_old import ToThreeChannels
from torchvision.transforms import Compose, RandomCrop, Resize, ToTensor


def get_image_transforms_instait(image_size=(224, 224)):
    """Build the per-image preprocessing pipeline.

    Applies, in order: ``Resize(image_size)``, ``ToTensor()`` and
    ``ToThreeChannels()``.

    Args:
        image_size: ``(height, width)`` passed to ``Resize``; defaults to the
            224x224 used throughout this notebook.

    Returns:
        A torchvision ``Compose`` of the three transforms.
    """
    return Compose([Resize(image_size), ToTensor(), ToThreeChannels()])


# Dataset root: defaults to the original hard-coded location but can be
# overridden via the CAPIT_DATASET_DIR environment variable, so the notebook
# runs on other machines without editing this cell.
dataset_dir = os.environ.get("CAPIT_DATASET_DIR", "/data_fast/")

# Training split, with challenge images drawn WITHIN_USER; episode/image limits
# are passed straight to the dataset class.
dataset = InstagramImageTextMultiModalDatasePyArrow(
    dataset_dir=dataset_dir,
    set_name=SplitType.TRAIN,
    top_k_percent=100,
    reset_cache=False,
    num_episodes=1000,
    max_num_collection_images_per_episode=100,
    max_num_query_images_per_episode=100,
    challenge_image_source=ChallengeSamplesSourceTypes.WITHIN_USER,
    dummy_batch_mode=False,
    model_name_or_path="openai/clip-vit-base-patch32",
)

In [3]:
# Summarise each attribute of the first sample: arrays/tensors by shape,
# sized containers by length, and anything else (e.g. ints or None, which
# have neither) by its value instead of raising TypeError.
first_sample = dataset[0]
for key, value in vars(first_sample).items():
    if hasattr(value, "shape"):
        summary = value.shape
    elif hasattr(value, "__len__"):
        summary = len(value)
    else:
        summary = value
    print(f"{key}: {summary}")


def rank(prompt, sample_idx):
    """Return a sample's challenge images in random order (random baseline).

    Args:
        prompt: text query; ignored by this random baseline.
        sample_idx: index into the module-level ``dataset``; cast to ``int`` in
            case the UI slider delivers a float.

    Returns:
        A list containing the sample's challenge images, shuffled.
    """
    sample = dataset[int(sample_idx)]
    images = [item["image"] for item in sample.challenge_paths]
    # Permute indices rather than calling np.random.permutation(images): that
    # runs np.asarray over the list first, stacking same-sized images into one
    # array and breaking when the images have different sizes.
    order = np.random.permutation(len(images))
    return [images[i] for i in order]


class RandomModel:
    """Baseline "model" whose ranking is the module-level random ``rank`` helper.

    Exposes the ``rank(prompt, sample_idx)`` method that the demo calls for each
    entry of its model dictionary.
    """

    def rank(self, prompt, sample_idx):
        """Forward ``prompt`` and ``sample_idx`` to the module-level ``rank``."""
        ranked = rank(prompt=prompt, sample_idx=sample_idx)
        return ranked

In [4]:
from collections import defaultdict
from distutils.command.upload import upload
from typing import Any
import gradio as gr
from PIL import Image


def _fit_image_updates(images, num_slots: Optional[int] = None) -> list:
    """Wrap images in gr.update values, optionally truncated/padded to num_slots."""
    updates = [gr.update(value=image) for image in images]
    if num_slots is None:
        return updates
    num_slots = max(num_slots, 0)
    updates = updates[:num_slots]
    # Fill unused slots with empty updates so they are blanked rather than
    # keeping images from a previously displayed sample. Build a fresh dict
    # per slot instead of repeating one shared object.
    updates.extend(gr.update(value=None) for _ in range(num_slots - len(updates)))
    return updates


def load_images(
    sample_idx,
    num_collection_images: Optional[int] = None,
    num_challenge_images: Optional[int] = None,
):
    """Build gradio updates that display a sample's collection and challenge images.

    Args:
        sample_idx: index into the module-level ``dataset``; cast to ``int`` in
            case the UI slider delivers a float.
        num_collection_images: if given, the collection updates are truncated or
            padded with empty images to exactly this many, so they line up with
            a fixed number of ``gr.Image`` outputs. ``None`` (default) emits one
            update per collection image.
        num_challenge_images: same as ``num_collection_images``, for the
            challenge images.

    Returns:
        A flat tuple of ``gr.update`` values: collection images first, then
        challenge images.
    """
    sample = dataset[int(sample_idx)]
    collection_image_data = [item["image"] for item in sample.collection_paths]
    challenge_image_data = [item["image"] for item in sample.challenge_paths]

    return (
        *_fit_image_updates(collection_image_data, num_collection_images),
        *_fit_image_updates(challenge_image_data, num_challenge_images),
    )


def _image_column(num_images: int, label_prefix: str, **image_kwargs) -> List[Any]:
    """Create one gr.Column of num_images gr.Image components and return them."""
    with gr.Column(scale=num_images, min_width=224):
        return [
            gr.Image(label=f"{label_prefix}-{i}", **image_kwargs)
            for i in range(num_images)
        ]


def build_demo(
    model_dict: Any = None,
    collection_num_images: int = 100,
    challenge_num_images: int = 100,
):
    """Build the gradio Blocks UI for browsing samples and comparing rankings.

    Layout: a slider selecting a dataset index, then a prompt textbox and a
    "rank" button. Below them sit one column of collection images (only when
    ``collection_num_images > 0``), one column of challenge images, and one
    column of ranked images per model.

    Moving the slider calls ``load_images`` to fill the collection and challenge
    columns. Clicking "rank" calls ``model.rank(prompt, sample_idx)`` for every
    model and writes the result into that model's ranked column.

    Args:
        model_dict: mapping of model name -> object exposing
            ``rank(prompt, sample_idx)``. ``None`` means no models.
        collection_num_images: number of collection image slots.
        challenge_num_images: number of challenge image slots; also the number
            of ranked slots per model.

    Returns:
        The constructed ``gr.Blocks`` demo (not launched).
    """
    if model_dict is None:
        model_dict = {}

    with gr.Blocks() as demo:
        with gr.Row():
            sample_idx_slider = gr.Slider(
                minimum=0,
                # The slider can select (and randomize to) `maximum` itself,
                # so cap it at the last valid index rather than len(dataset).
                maximum=max(len(dataset) - 1, 0),
                randomize=True,
                step=1,
                interactive=True,
                label="Datapoint idx to sample",
                info="Select the idx to sample",
            )

        with gr.Row():
            with gr.Column(scale=1, min_width=224):
                prompt = gr.Textbox(label="prompt", value="")
            with gr.Column(scale=1, min_width=224):
                rank_status = gr.Button(value="rank", label="rank")

        with gr.Row():
            # Defined unconditionally so the event wiring below does not hit a
            # NameError when the collection column is omitted.
            collection_images = []
            if collection_num_images > 0:
                collection_images = _image_column(
                    collection_num_images, "collection-image"
                )
            challenge_images = _image_column(challenge_num_images, "challenge-image")
            ranked_images_dict = {
                model_name: _image_column(
                    challenge_num_images,
                    f"ranked-image-{model_name}",
                    shape=(224, 224),
                )
                for model_name in model_dict
            }

        # NOTE(review): load_images returns one update per collection image plus
        # one per challenge image; the counts must match these output components
        # (e.g. when collection_num_images == 0 or a sample has fewer images).
        sample_idx_slider.change(
            load_images,
            inputs=[sample_idx_slider],
            outputs=[*collection_images, *challenge_images],
        )

        for model_name, model in model_dict.items():
            rank_status.click(
                fn=model.rank,
                inputs=[prompt, sample_idx_slider],
                outputs=ranked_images_dict[model_name],
            )

    return demo

In [6]:
num_challenge_images = 100
num_collection_images = 100
demo = build_demo(
    collection_num_images=num_collection_images,
    challenge_num_images=num_challenge_images,
    model_dict=dict(random=RandomModel()),
)
# .queue() already enables request queuing; passing enable_queue=True to
# launch() as well is redundant and deprecated in gradio 3.x.
demo.queue(concurrency_count=8)
# NOTE(review): share=True creates a public share link exposing dataset images;
# confirm this is intended before running on non-public data.
demo.launch(share=True, debug=True)