In [1]:
from typing import List
import gradio as gr
import numpy as np

def upload_file(files):
    """Collect the path of every uploaded file.

    Args:
        files: Iterable of uploaded file objects, each exposing a ``name``
            attribute.

    Returns:
        A list holding the ``name`` of each item, in the order received.
    """
    collected_paths = []
    for uploaded in files:
        collected_paths.append(uploaded.name)
    return collected_paths

def check_if_folder_has_images(file_paths):
    """Report whether ``file_paths`` contains at least one entry.

    Args:
        file_paths: Any sized collection (must support ``len``).

    Returns:
        ``True`` when the collection is non-empty, ``False`` otherwise.
    """
    num_paths = len(file_paths)
    return num_paths != 0

def show_state_of_files(file_paths):
    """Print the ``value`` attribute of ``file_paths`` to stdout.

    Args:
        file_paths: Object exposing a ``value`` attribute.

    Returns:
        None.
    """
    current_value = file_paths.value
    print(current_value)
    
def rank(images, prompt):
    """Return the names of ``images`` in a random order.

    Args:
        images: Iterable of file objects, each exposing a ``name`` attribute.
        prompt: Accepted for signature compatibility; not used.

    Returns:
        The result of ``np.random.permutation`` over the collected names.
    """
    names = []
    for item in images:
        names.append(item.name)
    # Shuffle via NumPy's global RNG.
    shuffled = np.random.permutation(names)
    return shuffled
    

In [2]:
import torch.nn as nn
import torch
from capit.models import *
import torchvision.transforms as transforms
from PIL import Image
class Ranker(nn.Module):
    """Rank uploaded images against a text prompt using ``CLIPImageTextModel``.

    Args:
        model_name_or_path: Forwarded to ``CLIPImageTextModel``.
        pretrained: Forwarded to ``CLIPImageTextModel``.
    """

    def __init__(self, model_name_or_path, pretrained):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.pretrained = pretrained
        self.model = CLIPImageTextModel(pretrained=pretrained,
                                        model_name_or_path=model_name_or_path)
        # Use the current CUDA device when one exists; otherwise stay on CPU
        # instead of crashing inside torch.cuda.current_device().
        if torch.cuda.is_available():
            device = torch.cuda.current_device()
        else:
            device = torch.device("cpu")
        self.model = self.model.to(device)

    def rank(self, image_paths, prompt):
        """Order the given image files by the model's similarity to ``prompt``.

        Args:
            image_paths: Sequence of file objects exposing a ``name`` attribute
                that points at an image on disk. ``None`` or an empty sequence
                yields an empty result.
            prompt: Text passed to the model as ``text``.

        Returns:
            A list of file names taken from ``image_paths``, ordered by
            descending score.
        """
        if not image_paths:
            return []

        # Build the transform once rather than once per image.
        to_tensor = transforms.ToTensor()
        images = []
        for file in image_paths:
            # Close each file handle as soon as its pixels are in a tensor.
            with Image.open(file.name) as opened_image:
                images.append(to_tensor(opened_image))

        with torch.no_grad():
            similarities = self.model.forward(image=images, text=prompt)

        # NOTE(review): argsort works along the last dimension and [0] keeps the
        # first row, so this assumes logits_per_image is laid out with images on
        # the last axis -- confirm against CLIPImageTextModel's output.
        rank_similarities_args = torch.argsort(
            similarities.logits_per_image, descending=True
        )[0]
        return [image_paths[i].name for i in rank_similarities_args]

class CAPITRanker(Ranker):
    """``Ranker`` whose model weights are restored from a checkpoint file.

    Args:
        model_name_or_path: Forwarded to ``Ranker`` (with ``pretrained=False``).
        ckpt_path: Path to a checkpoint readable by ``torch.load`` that holds
            a ``"state_dict"`` entry.
    """

    def __init__(self, model_name_or_path, ckpt_path):
        # Ranker.__init__ already stores model_name_or_path, so it is not
        # assigned again here.
        super().__init__(model_name_or_path, False)

        # NOTE: torch.load unpickles the file, which can execute arbitrary code;
        # only load checkpoints from a trusted source.
        # map_location="cpu" lets a checkpoint saved on a GPU be read on any
        # host; load_state_dict then copies the values into the model.
        model_weight = torch.load(ckpt_path, map_location="cpu")

        # Checkpoint keys carry an extra "model." level compared with the
        # keys self.model expects.
        state_dict = {
            key.replace("model.model.", "model."): value
            for key, value in model_weight["state_dict"].items()
        }
        self.model.load_state_dict(state_dict)

        # Same device choice as a CUDA host had before, with a CPU fallback so
        # construction does not crash when CUDA is unavailable.
        if torch.cuda.is_available():
            device = torch.cuda.current_device()
        else:
            device = torch.device("cpu")
        self.model = self.model.to(device)
        
    

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
from collections import defaultdict
from distutils.command.upload import upload
import gradio as gr

def build_demo(model_dict, collection_num_images: int = 0, challenge_num_images: int = 5):
    """Build the Gradio Blocks UI for comparing image rankers.

    The layout has a prompt textbox, a "rank" button, a column of challenge
    image slots, one column of ranked image slots per entry in ``model_dict``,
    and an upload button that fills the challenge slots.

    Args:
        model_dict: Mapping of display key to an object exposing
            ``rank(files, prompt)`` that returns an ordered sequence of images.
        collection_num_images: Number of "collection" image slots to create;
            the column is skipped when this is 0.
        challenge_num_images: Number of challenge slots, and of ranked slots
            created for each model.

    Returns:
        The constructed ``gr.Blocks`` object (not launched).
    """

    def _fit_to_slots(values):
        """Pad with None / truncate so there is exactly one value per image slot."""
        values = [] if values is None else list(values)
        values = values[:challenge_num_images]
        return values + [None] * (challenge_num_images - len(values))

    def _show_uploads(files):
        """Upload handler: uploaded file paths, sized to the challenge slots."""
        return _fit_to_slots(upload_file(files) if files else [])

    def _make_rank_handler(ranker):
        """Bind one ranker per handler so the loop variable is not captured late."""
        def _ranked(files, prompt_text):
            # Skip the model call entirely when nothing has been uploaded.
            ranked = ranker.rank(files, prompt_text) if files else []
            return _fit_to_slots(ranked)
        return _ranked

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=1, min_width=224):
                prompt = gr.Textbox(label="prompt", value="")
        with gr.Row():
            with gr.Column(scale=1, min_width=224):
                rank_status = gr.Button(value="rank", label="rank")
        with gr.Row():
            if collection_num_images > 0:
                collection_images = []
                with gr.Column(scale=collection_num_images, min_width=224):
                    for i in range(collection_num_images):
                        collection_images.append(gr.Image(shape=(224, 224), label=f"collection-image-{i}"))
            with gr.Column(scale=challenge_num_images, min_width=224):
                challenge_images = []
                for i in range(challenge_num_images):
                    challenge_images.append(gr.Image(shape=(224, 224), label=f"challenge-image-{i}"))
            ranked_images_dict = defaultdict(list)
            for key in model_dict:
                with gr.Column(scale=challenge_num_images, min_width=224):
                    for i in range(challenge_num_images):
                        ranked_images_dict[key].append(gr.Image(shape=(224, 224), label=f"ranked-image-{key}-{i}"))
        with gr.Row():
            with gr.Column(scale=2, min_width=224):
                upload_button = gr.UploadButton("Browse to select a folder with images", file_types=["image"], file_count="multiple")
                # The handlers must return one value per output component, so
                # results are fitted to the fixed number of image slots.
                upload_button.upload(_show_uploads, upload_button, challenge_images)
                for model_name, model in model_dict.items():
                    rank_status.click(fn=_make_rank_handler(model), inputs=[upload_button, prompt], outputs=ranked_images_dict[model_name])
    return demo

In [4]:
# Ranker built with pretrained=False, moved to the current CUDA device.
model_random = Ranker(
    "openai/clip-vit-large-patch14", pretrained=False
).to(torch.cuda.current_device())

In [3]:
# Ranker built with pretrained=True, moved to the current CUDA device.
model_baseline = Ranker(
    "openai/clip-vit-large-patch14", pretrained=True
).to(torch.cuda.current_device())


In [4]:
# Ranker whose weights come from the checkpoint file, moved to the current CUDA device.
model_fine_tuned = CAPITRanker(
    "openai/clip-vit-large-patch14",
    ckpt_path="/workspaces/CAPMultiModal-1/capit-clip-ft/last.ckpt",
).to(torch.cuda.current_device())

In [7]:
# Rankers shown side by side; each key labels that model's ranked-image slots.
model_dict = {"baseline": model_baseline, "fine-tuned": model_fine_tuned}
num_challenge_images = 25
num_collection_images = 0
# Pass both slot counts explicitly so that editing num_collection_images
# actually takes effect (it was previously defined but never used).
demo = build_demo(
    model_dict=model_dict,
    collection_num_images=num_collection_images,
    challenge_num_images=num_challenge_images,
)
demo.queue()
# share=True also serves the demo on a temporary public URL.
demo.launch(share=True, debug=True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://f83dbf85238f7aac.gradio.app

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces
