In [23]:
%load_ext autoreload
%autoreload 2

import os
from typing import List

import gradio as gr
import numpy as np

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
from capit.core.data.datasets import *
from capit.core.data.datasets_old import *
from torchvision.transforms import Compose, RandomCrop, Resize, ToTensor


def get_image_transforms_instait():
    """Return the image preprocessing pipeline used for Instagram-style images.

    Steps, in order: resize to 224x224, convert to a tensor, then apply
    ``ToThreeChannels`` (presumably forces a 3-channel image — confirm in
    capit.core.data).
    """
    pipeline_steps = [
        Resize((224, 224)),
        ToTensor(),
        ToThreeChannels(),
    ]
    return Compose(pipeline_steps)


# Root directory of the dataset. Override it with the CAPIT_DATASET_DIR
# environment variable instead of editing the notebook; the default keeps the
# original location.
dataset_dir = os.environ.get("CAPIT_DATASET_DIR", "/data_fast/")
dataset = InstagramImageTextMultiModalDatasePyArrow(
    dataset_dir=dataset_dir,
    set_name=SplitType.TRAIN,
    top_k_percent=100,
    reset_cache=False,
    num_episodes=1000,
    max_num_collection_images_per_episode=100,
    max_num_query_images_per_episode=100,
    challenge_image_source=ChallengeSamplesSourceTypes.WITHIN_USER,
    dummy_batch_mode=False,
    model_name_or_path="openai/clip-vit-base-patch32",
)

In [25]:
# Quick summary of the first episode's fields: the shape for array-like
# values, the length for sized values, and otherwise the type name. The last
# case covers ints/None, which previously made `len()` raise.
for key, value in vars(dataset[0]).items():
    if hasattr(value, "shape"):
        summary = value.shape
    elif hasattr(value, "__len__"):
        summary = len(value)
    else:
        summary = type(value).__name__
    print(f"{key}: {summary}")


def rank(prompt, sample_idx):
    """Random-order baseline ranker: shuffle one episode's challenge images.

    Args:
        prompt: Text prompt. It is ignored; it is accepted only so the
            signature matches the learned rankers.
        sample_idx: Index into the module-level ``dataset``. It is coerced to
            ``int`` because it may arrive as a number from a Gradio slider.

    Returns:
        A list of the episode's challenge images (the original objects) in a
        uniformly random order.
    """
    sample = dataset[int(sample_idx)]
    images = [item["image"] for item in sample.challenge_paths]
    # Permute indices, not the images. np.random.permutation(images) would
    # first coerce every image into one ndarray. That fails when the images
    # differ in size and may change their type otherwise.
    return [images[i] for i in np.random.permutation(len(images))]


class RandomModel:
    """Chance-level baseline exposing the same ``rank(prompt, sample_idx)``
    interface as the learned rankers used in ``model_dict``."""

    def rank(self, prompt, sample_idx):
        """Delegate to the module-level ``rank`` function and return its result."""
        shuffled_images = rank(prompt, sample_idx)
        return shuffled_images

In [26]:
from typing import Union
from huggingface_hub import login

# Read the Hugging Face token from the environment rather than hard-coding it
# in the notebook. Any token that was ever committed here should be revoked.
# If HF_TOKEN is unset, `login` prompts for the token interactively.
login(token=os.environ.get("HF_TOKEN"), add_to_git_credential=True)

import torch
from capit.core.models import *
import torchvision.transforms as transforms
from PIL import Image
from capit.core.models import (
    CLIPImageTextModel,
    CLIPWithPostProcessingImageTextModel,
    CAPCLIPImageTextModel,
)
from capit.core.data.datasets import ImageTextRetrievalInput

# Notebook-level Accelerator configured for bf16 mixed precision.
accelerator = Accelerator(mixed_precision="bf16")


class Ranker(nn.Module):
    """Rank an episode's challenge images against a text prompt with a trained model.

    The wrapped model is handled in this order:
    1. Instantiated from ``model_type``.
    2. Built with an example batch.
    3. Prepared with a bf16 ``Accelerator``.
    4. Restored from a checkpoint downloaded from the Hugging Face Hub.
    5. Switched to eval mode for inference.

    Args:
        model_type: Model class to instantiate. Only ``CAPCLIPImageTextModel``
            receives ``backbone_fine_tunable``.
        model_name_or_path: Backbone identifier forwarded to the model constructor.
        repo_path: Hugging Face repo holding the checkpoint.
        model_name: Checkpoint name inside ``repo_path`` (e.g. ``"ckpt_0"``).
        batch: Example batch passed to ``model.build`` before weights are loaded.
        cache_path: Local directory the checkpoint is downloaded into.
        pretrained: Forwarded to the model constructor.
        backbone_fine_tunable: Forwarded to ``CAPCLIPImageTextModel`` only.
    """

    def __init__(
        self,
        model_type: Union[
            CLIPImageTextModel,
            CLIPWithPostProcessingImageTextModel,
            CAPCLIPImageTextModel,
        ],
        model_name_or_path: str,
        repo_path: str,
        model_name: str,
        batch: ImageTextRetrievalInput,
        cache_path: str = ".cache/",
        pretrained: bool = True,
        backbone_fine_tunable: bool = True,
    ):
        super().__init__()
        if model_type is CAPCLIPImageTextModel:
            self.model = model_type(
                pretrained=pretrained,
                model_name_or_path=model_name_or_path,
                backbone_fine_tunable=backbone_fine_tunable,
            )
        else:
            self.model = model_type(
                pretrained=pretrained, model_name_or_path=model_name_or_path
            )
        self.model.build(batch=batch)
        self.accelerator = Accelerator(mixed_precision="bf16")
        self.model = self.accelerator.prepare(self.model)
        self.model_name_or_path = model_name_or_path
        self.pretrained = pretrained
        self.load_from_repo(
            hf_repo_path=repo_path,
            model_name=model_name,
            hf_cache_dir=cache_path,
        )
        # This wrapper is only used for inference. Eval mode disables
        # train-time behaviour such as dropout, so rankings are reproducible.
        self.model.eval()

    def rank(self, prompt_text, sample_idx):
        """Return the challenge images of ``dataset[sample_idx]``, ordered by
        descending model similarity to ``prompt_text``.

        Uses the module-level ``dataset`` for the sample and its ``processor``
        for tokenisation.
        """
        # Use this ranker's own accelerator rather than a notebook global.
        device = self.accelerator.device
        sample = dataset[int(sample_idx)]
        with torch.no_grad():
            collection_image_data = sample.collection_images.to(device)
            challenge_image_data = sample.challenge_images.to(device)
            prompt_text_ids = dataset.processor(
                text=prompt_text,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )["input_ids"].to(device)

            similarities = self.model.predict_individual(
                challenge_image_data, prompt_text_ids, collection_image_data
            )
            # Row 0 holds the scores for the single prompt. This assumes
            # `similarities` is (num_prompts, num_challenge_images); TODO confirm.
            # .tolist() copies the indices to the host once instead of
            # converting each element separately.
            ranking = torch.argsort(similarities, descending=True)[0].tolist()
        return [sample.challenge_paths[i]["image"] for i in ranking]

    def load_from_repo(
        self, hf_repo_path: str, model_name: str, hf_cache_dir: str
    ):
        """Download checkpoint ``model_name`` from ``hf_repo_path`` into
        ``hf_cache_dir`` and restore it with ``self.accelerator.load_state``.

        Returns:
            ``pathlib.Path`` of the downloaded checkpoint root.
        """
        download_output = download_model_with_name(
            hf_repo_path,
            hf_cache_dir,
            model_name,
            download_only_if_finished=False,
        )

        # pathlib.Path accepts both str and Path, so no isinstance check is needed.
        checkpoint_path = pathlib.Path(download_output["root_filepath"])

        logger.info(f"Loading checkpoint from {checkpoint_path}")

        self.accelerator.load_state(checkpoint_path)

        return checkpoint_path

Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /home/evolvingfungus/.cache/huggingface/token
Login successful


In [31]:
# from capit.core.data.datasets import dataclass_collate
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=dataclass_collate)
# dummy_batch = next(iter(dataloader))
# One-episode batch. Fields with `.unsqueeze` get a leading batch dimension of
# 1; every other field is passed through unchanged. Ranker hands this to
# `model.build(batch=...)` before restoring each checkpoint.
first_sample = dataset[0]
dummy_dict_batch = {
    key: (value.unsqueeze(0) if hasattr(value, "unsqueeze") else value)
    for key, value in first_sample.__dict__.items()
}
dummy_batch = first_sample.__class__(**dummy_dict_batch)

# Constructor arguments shared by every checkpoint-backed ranker below.
common_ranker_kwargs = dict(
    model_name_or_path="openai/clip-vit-base-patch32",
    backbone_fine_tunable=True,
    batch=dummy_batch,
)

baseline = Ranker(
    model_type=CLIPImageTextModel,
    repo_path="Antreas/baseline-100-100-27",
    model_name="ckpt_0",
    cache_path="/data_fast/models/clip-base/",
    **common_ranker_kwargs,
)

baseline_fine_tuned = Ranker(
    model_type=CLIPImageTextModel,
    repo_path="Antreas/baseline-100-100-27",
    model_name="ckpt_95000",
    cache_path="/data_fast/models/baseline-100-100-27/",
    **common_ranker_kwargs,
)

cap = Ranker(
    model_type=CAPCLIPImageTextModel,
    repo_path="Antreas/cap-100-100-24",
    model_name="ckpt_95000",
    cache_path="/data_fast/models/cap-100-100-24/",
    **common_ranker_kwargs,
)

model_dict = {
    "random": RandomModel(),
    "clip-baseline": baseline,
    # The ckpt_95000 baseline, distinct from the ckpt_0 "clip-baseline" entry.
    "clip-fine-tuned": baseline_fine_tuned,
    "cap": cap,
}



Downloading (…)lve/main/config.yaml:   0%|          | 0.00/3.46k [00:00<?, ?B/s]

Downloading trainer_state.pt:   0%|          | 0.00/76.0k [00:00<?, ?B/s]

Downloading optimizer.bin:   0%|          | 0.00/2.23k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

Downloading random_states_0.pkl:   0%|          | 0.00/14.6k [00:00<?, ?B/s]

In [28]:
from collections import defaultdict
from distutils.command.upload import upload
from typing import Any
import gradio as gr
from PIL import Image


def load_images(sample_idx):
    """Return Gradio updates that display one dataset episode's images.

    Args:
        sample_idx: Index into the module-level ``dataset``. It is coerced to
            ``int`` because it may arrive as a number from a Gradio slider.

    Returns:
        A tuple with one ``gr.update(value=image)`` per collection image,
        followed by one per challenge image, each group in dataset order.
    """
    sample = dataset[int(sample_idx)]
    images = [item["image"] for item in sample.collection_paths]
    images += [item["image"] for item in sample.challenge_paths]
    return tuple(gr.update(value=image) for image in images)


def build_demo(
    model_dict: Any = None,
    collection_num_images: int = 100,
    challenge_num_images: int = 100,
):
    """Build the Gradio Blocks UI for browsing episodes and comparing rankers.

    Layout:
    - a slider selecting the dataset index;
    - a prompt textbox and a "rank" button;
    - a column of collection images and a column of challenge images;
    - one column of ranked images per entry in ``model_dict``.

    Args:
        model_dict: Mapping from display name to an object exposing
            ``rank(prompt, sample_idx)`` that returns the challenge images in
            ranked order. ``None`` is treated as an empty mapping.
        collection_num_images: Number of collection image slots. Values <= 0
            hide the collection column.
        challenge_num_images: Number of challenge slots, and of ranked-image
            slots per model.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` demo.
    """
    model_dict = model_dict or {}
    num_collection_slots = max(collection_num_images, 0)
    num_challenge_slots = max(challenge_num_images, 0)

    def _fit(images, num_slots):
        """Truncate or pad with None (an empty image) to exactly num_slots values."""
        images = list(images)[:num_slots]
        return images + [None] * (num_slots - len(images))

    def _as_outputs(values):
        """Shape callback results: Gradio uses a lone return value as-is
        for a single output component instead of unpacking it."""
        return values[0] if len(values) == 1 else tuple(values)

    def _load_sample(sample_idx):
        """Slider callback: show the selected episode's collection and challenge images."""
        sample = dataset[int(sample_idx)]
        collection = _fit(
            (item["image"] for item in sample.collection_paths),
            num_collection_slots,
        )
        challenge = _fit(
            (item["image"] for item in sample.challenge_paths),
            num_challenge_slots,
        )
        return _as_outputs(collection + challenge)

    def _make_rank_callback(model):
        """Wrap model.rank so it always yields one value per ranked-image slot."""

        def _rank(prompt_text, sample_idx):
            ranked = model.rank(prompt_text, int(sample_idx))
            return _as_outputs(_fit(ranked, num_challenge_slots))

        return _rank

    with gr.Blocks() as demo:
        with gr.Row():
            sample_idx_slider = gr.Slider(
                # Valid indices are 0..len(dataset) - 1.
                maximum=max(len(dataset) - 1, 0),
                randomize=True,
                step=1,
                interactive=True,
                label="Datapoint idx to sample",
                info="Select the idx to sample",
            )

        with gr.Row():
            with gr.Column(scale=1, min_width=224):
                prompt = gr.Textbox(label="prompt", value="")
            with gr.Column(scale=1, min_width=224):
                rank_status = gr.Button(value="rank", label="rank")

        with gr.Row():
            # Defined even when the column is hidden, so wiring below never
            # hits an undefined name.
            collection_images = []
            if collection_num_images > 0:
                with gr.Column(scale=collection_num_images, min_width=224):
                    collection_images = [
                        gr.Image(label=f"collection-image-{i}")
                        for i in range(collection_num_images)
                    ]
            with gr.Column(scale=challenge_num_images, min_width=224):
                challenge_images = [
                    gr.Image(label=f"challenge-image-{i}")
                    for i in range(challenge_num_images)
                ]

            ranked_images_dict = {}
            for key in model_dict:
                with gr.Column(scale=challenge_num_images, min_width=224):
                    ranked_images_dict[key] = [
                        gr.Image(
                            shape=(224, 224),
                            label=f"ranked-image-{key}-{i}",
                        )
                        for i in range(challenge_num_images)
                    ]

        sample_idx_slider.change(
            _load_sample,
            inputs=[sample_idx_slider],
            outputs=[*collection_images, *challenge_images],
        )

        for model_name, model in model_dict.items():
            rank_status.click(
                fn=_make_rank_callback(model),
                inputs=[prompt, sample_idx_slider],
                outputs=ranked_images_dict[model_name],
            )

    return demo

In [29]:
demo = build_demo(
    collection_num_images=dataset.max_num_collection_images_per_episode,
    challenge_num_images=dataset.max_num_query_images_per_episode,
    model_dict=model_dict,
)
# Queueing is enabled here via .queue(). launch()'s `enable_queue` argument is
# deprecated in favour of this call, so it is no longer passed.
demo.queue(concurrency_count=8)
# NOTE(review): share=True publishes a public share link to this demo;
# disable it if the data or models must stay private.
demo.launch(share=True, debug=True)