In [None]:
from typing import List
import gradio as gr
import numpy as np


def upload_file(files):
    """Return the ``name`` attribute of every uploaded file object, in order.

    Args:
        files: Iterable of objects exposing a ``name`` attribute.

    Returns:
        A list with one ``name`` value per element of ``files``.
    """
    collected_names = []
    for uploaded in files:
        collected_names.append(uploaded.name)
    return collected_names


def check_if_folder_has_images(file_paths):
    """Tell whether ``file_paths`` contains at least one entry.

    Args:
        file_paths: Any sized collection.

    Returns:
        ``True`` when the collection is non-empty, ``False`` otherwise.
    """
    entry_count = len(file_paths)
    return entry_count >= 1


def show_state_of_files(file_paths):
    """Print the ``value`` attribute of ``file_paths`` to stdout (debug helper).

    Args:
        file_paths: Object exposing a ``value`` attribute.
    """
    current_value = file_paths.value
    print(current_value)


def rank(images, prompt):
    """Placeholder ranker: return the images' names in a random order.

    Args:
        images: Iterable of objects exposing a ``name`` attribute.
        prompt: Accepted for signature compatibility; not used.

    Returns:
        A NumPy array holding a random permutation of the ``name`` values.
    """
    names = []
    for item in images:
        names.append(item.name)
    # Shuffling is driven by NumPy's global random state.
    shuffled = np.random.permutation(names)
    return shuffled

In [2]:
import os
from getpass import getpass

from huggingface_hub import login

# Never commit an access token to a notebook: anyone with read access to the
# file can use it. The token is taken from the HF_TOKEN environment variable
# when set; otherwise the user is prompted and the input is not echoed.
# NOTE(review): the token that was previously hardcoded in this cell must be
# treated as leaked and revoked in the Hugging Face account settings.
login(
    token=os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: "),
    add_to_git_credential=True,
)

Token is valid.
Your token has been saved in your configured git credential helpers (!f()).
Your token has been saved to /root/.huggingface/token
Login successful


In [None]:
%load_ext autoreload
%autoreload 2
from typing import Union
import torch.nn as nn
import torch
from capit.models import *
import torchvision.transforms as transforms
from PIL import Image
from capit.core.models import (
    CLIPImageTextModel,
    CLIPWithPostProcessingImageTextModel,
    CAPCLIPImageTextModel,
)
from capit.core.data.datasets import ImageTextRetrievalInput


class Ranker(nn.Module):
    """Wrap an image-text model and rank uploaded images against a text prompt.

    The wrapped model is constructed from ``model_type``, built against a
    sample ``batch``, moved to the current CUDA device (or the CPU when CUDA
    is unavailable) and then has its weights loaded via ``load_from_repo``.
    """

    def __init__(
        self,
        model_type: Union[
            CLIPImageTextModel,
            CLIPWithPostProcessingImageTextModel,
            CAPCLIPImageTextModel,
        ],
        model_name_or_path: str,
        repo_path: str,
        model_name: str,
        batch: ImageTextRetrievalInput,
        cache_path: str = ".cache/",
        pretrained: bool = True,
        backbone_fine_tunable: bool = True,
    ):
        """Instantiate, build and load the underlying model.

        Args:
            model_type: The model class to instantiate (the class itself, not
                an instance).
            model_name_or_path: Forwarded to the model constructor.
            repo_path: Forwarded to ``load_from_repo``.
            model_name: Forwarded to ``load_from_repo``.
            batch: Sample batch handed to ``model.build``.
            cache_path: Forwarded to ``load_from_repo``.
            pretrained: Forwarded to the model constructor.
            backbone_fine_tunable: Only forwarded when ``model_type`` is
                ``CAPCLIPImageTextModel``.
        """
        super().__init__()
        if model_type != CAPCLIPImageTextModel:
            self.model = model_type(
                pretrained=pretrained, model_name_or_path=model_name_or_path
            )
        else:
            self.model = model_type(
                pretrained=pretrained,
                model_name_or_path=model_name_or_path,
                backbone_fine_tunable=backbone_fine_tunable,
            )
        self.model.build(batch=batch)
        # torch.cuda.current_device() raises on hosts without CUDA, so fall
        # back to the CPU instead of failing at construction time.
        if torch.cuda.is_available():
            target_device = torch.device("cuda", torch.cuda.current_device())
        else:
            target_device = torch.device("cpu")
        self.model = self.model.to(target_device)
        self.model_name_or_path = model_name_or_path
        self.pretrained = pretrained
        self.model.load_from_repo(
            repo_path=repo_path, model_name=model_name, cache_path=cache_path
        )

    @staticmethod
    def _load_images(files):
        """Read each file's ``name`` path into a 3-channel float tensor."""
        to_tensor = transforms.ToTensor()
        tensors = []
        for file in files:
            # The context manager releases the file handle; convert("RGB")
            # normalises grayscale/RGBA/palette uploads to three channels.
            with Image.open(file.name) as image:
                tensors.append(to_tensor(image.convert("RGB")))
        return tensors

    def rank(
        self, prompt_text, challenge_image_paths, collection_image_paths=None
    ):
        """Order the challenge images by the model's similarity output.

        Args:
            prompt_text: Text prompt forwarded to the model.
            challenge_image_paths: Sequence of file objects exposing a
                ``name`` path attribute; these are the images to rank.
            collection_image_paths: Optional sequence of the same kind of
                file objects; ``None`` is forwarded to the model as ``None``.

        Returns:
            The ``name`` of each challenge image, ordered by descending
            value of the first row of ``logits_per_image``. An empty list is
            returned when no challenge images were supplied.
        """
        # Nothing uploaded yet (None or empty): there is nothing to rank.
        if not challenge_image_paths:
            return []

        # NOTE(review): the model is never switched to eval mode here;
        # confirm whether that is handled inside the model itself.
        with torch.no_grad():
            challenge_images = self._load_images(challenge_image_paths)
            if collection_image_paths is not None:
                collection_images = self._load_images(collection_image_paths)
            else:
                collection_images = None

            similarities = self.model.forward(
                challenge_images=challenge_images,
                collection_images=collection_images,
                prompt_text=prompt_text,
            )
            # NOTE(review): assumes row 0 of logits_per_image holds one score
            # per challenge image -- confirm against the model's output.
            rank_similarities_args = torch.argsort(
                similarities.logits_per_image, descending=True
            )[0]
            return [
                challenge_image_paths[i].name
                for i in rank_similarities_args.tolist()
            ]

In [None]:
# Sanity check: does CLIPImageTextModel expose the `load_from_repo` attribute
# that Ranker.__init__ calls on the instantiated model?
print(hasattr(CLIPImageTextModel, "load_from_repo"))

In [None]:
# Random placeholder batch; it is passed as `batch` when the rankers are
# constructed, which hand it to `model.build`.
# NOTE(review): both path lists hold 10 entries while the image tensors hold
# 50 and 15 images respectively -- confirm the paths are not read during build.
_placeholder_paths = list("abcdefghij")
dummy_inputs = ImageTextRetrievalInput(
    target_image=torch.rand(1, 3, 224, 224),
    challenge_images=torch.rand(1, 50, 3, 224, 224),
    challenge_paths=list(_placeholder_paths),
    target_text=["a picture of a cat"],
    collection_images=torch.rand(1, 15, 3, 224, 224),
    collection_paths=list(_placeholder_paths),
)

In [None]:
# Keyword arguments that are identical for all three rankers.
_shared_ranker_kwargs = dict(
    model_name_or_path="openai/clip-vit-large-patch14",
    model_name="latest.pt",
    backbone_fine_tunable=True,
    batch=dummy_inputs,
)

# One ranker per model variant; each loads its weights from its own repo.
baseline = Ranker(
    model_type=CLIPImageTextModel,
    repo_path="evolvingfungus/capit-v2-clip-baseline-True-2e-05-1e-05",
    **_shared_ranker_kwargs,
)
baseline_pp = Ranker(
    model_type=CLIPWithPostProcessingImageTextModel,
    repo_path="evolvingfungus/capit-v2-clip-with-post-processing-baseline-False-2e-05-1e-05-True",
    **_shared_ranker_kwargs,
)
cap = Ranker(
    model_type=CAPCLIPImageTextModel,
    repo_path="evolvingfungus/capit-v2-v1.1-cap-False-2e-05-1e-05-True",
    **_shared_ranker_kwargs,
)

# Short display key -> ranker, consumed when the demo is built.
model_dict = dict(zip(("clip", "clip-pp", "cap"), (baseline, baseline_pp, cap)))

In [None]:
from collections import defaultdict
from distutils.command.upload import upload
import gradio as gr


def build_demo(
    model_dict, collection_num_images: int = 0, challenge_num_images: int = 5
):
    """Build the Gradio Blocks UI for comparing rankers side by side.

    The layout has a prompt textbox, a "rank" button, an optional column of
    collection image slots, a column of challenge image slots, one column of
    ranked image slots per entry of ``model_dict``, and two upload buttons.

    Args:
        model_dict: Mapping of display key -> object with a
            ``rank(prompt, challenge_files, collection_files)`` method.
        collection_num_images: Number of collection image slots to show;
            ``0`` shows none (the collection upload button is still created
            so collection files can be passed to the rankers).
        challenge_num_images: Number of challenge image slots, and of ranked
            image slots per model.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """

    def _fit_to_slots(values, num_slots):
        # Gradio needs exactly one return value per output component, so pad
        # with None (empty slot) or truncate to the number of slots.
        values = [] if values is None else list(values)
        return values[:num_slots] + [None] * (num_slots - len(values))

    def _make_upload_handler(num_slots):
        # Show uploaded files in a fixed number of image slots.
        def handler(files):
            return _fit_to_slots(upload_file(files), num_slots)

        return handler

    def _make_rank_handler(rank_fn):
        # Bind rank_fn per model (avoids late binding inside the loop below).
        def handler(prompt_text, challenge_files, collection_files):
            ranked = rank_fn(prompt_text, challenge_files, collection_files)
            return _fit_to_slots(ranked, challenge_num_images)

        return handler

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=1, min_width=224):
                prompt = gr.Textbox(label="prompt", value="")
        with gr.Row():
            with gr.Column(scale=1, min_width=224):
                rank_status = gr.Button(value="rank", label="rank")
        with gr.Row():
            # Always bound, so the wiring below also works when no collection
            # slots are requested (previously an UnboundLocalError).
            collection_images = []
            if collection_num_images > 0:
                with gr.Column(scale=collection_num_images, min_width=224):
                    for i in range(collection_num_images):
                        collection_images.append(
                            gr.Image(
                                shape=(224, 224), label=f"collection-image-{i}"
                            )
                        )
            with gr.Column(scale=challenge_num_images, min_width=224):
                challenge_images = []
                for i in range(challenge_num_images):
                    challenge_images.append(
                        gr.Image(shape=(224, 224), label=f"challenge-image-{i}")
                    )
            ranked_images_dict = defaultdict(list)
            for key in model_dict:
                with gr.Column(scale=challenge_num_images, min_width=224):
                    for i in range(challenge_num_images):
                        ranked_images_dict[key].append(
                            gr.Image(
                                shape=(224, 224),
                                label=f"ranked-image-{key}-{i}",
                            )
                        )
        with gr.Row():
            with gr.Column(scale=2, min_width=224):
                challenge_upload_button = gr.UploadButton(
                    "Browse to select a folder with images for challenge images",
                    file_types=["image"],
                    file_count="multiple",
                )
                challenge_upload_button.upload(
                    _make_upload_handler(challenge_num_images),
                    challenge_upload_button,
                    challenge_images,
                )
                collection_upload_button = gr.UploadButton(
                    "Browse to select a folder with images for collection images",
                    file_types=["image"],
                    file_count="multiple",
                )
                # Only mirror collection uploads into image slots when such
                # slots exist; the button itself still feeds the rankers.
                if collection_images:
                    collection_upload_button.upload(
                        _make_upload_handler(collection_num_images),
                        collection_upload_button,
                        collection_images,
                    )
                for model_name, model in model_dict.items():
                    rank_status.click(
                        fn=_make_rank_handler(model.rank),
                        inputs=[
                            prompt,
                            challenge_upload_button,
                            collection_upload_button,
                        ],
                        outputs=ranked_images_dict[model_name],
                    )

    return demo

In [None]:
# Slot counts for the demo; they match the 50 challenge / 15 collection
# images used in the dummy batch.
num_challenge_images, num_collection_images = 50, 15

demo = build_demo(
    model_dict=model_dict,
    challenge_num_images=num_challenge_images,
    collection_num_images=num_collection_images,
)

# Enable request queuing, then start the server. `share=True` asks Gradio for
# a public share link and `debug=True` keeps the cell attached to the server.
demo.queue()
demo.launch(share=True, debug=True)