In [None]:
%load_ext autoreload
%autoreload 2

### Imports

In [None]:
from upyog.all import *

if True:
    import sys
    sys.path.append("/home/synopsis/git/CinemaNet-Training/")
    sys.path.append("/home/synopsis/git/YOLOX-Custom/")
    sys.path.append("/home/synopsis/git/YOLO-CinemaNet/")
    sys.path.append("/home/synopsis/git/icevision/")
    sys.path.append("/home/synopsis/git/labelling-workflows/")
    sys.path.append("/home/synopsis/git/amalgam/")
    sys.path.append("/home/synopsis/git/cinemanet-multitask-classification/")
    sys.path.append("/home/synopsis/git/Synopsis.py/")

import torch
import open_clip
import wandb
from cinemanet_clip.inference import *
from cinemanet_clip.utils import *
from training.inference import InferenceModel
from cinemanet_clip.inference import get_top_matches
from cinemanet_clip.inference import (
    EVALUATION_PROMPTS, CELEBRITY_PROMPTS, PROP_PROMPTS, EMOTION_PROMPTS)

import ipywidgets as widgets
from ipywidgets import FloatSlider, Dropdown, Layout, Box, HBox, VBox, Layout, HTML
from cinemanet.utils.widgets import *

NOTE: `MODEL_CONFIGS` must be updated MANUALLY. Its keys are human-readable names and are displayed as-is in the UI.

You also need to run inference separately beforehand so that each model's image embeddings are cached on disk (as `*feather*` files inside the `prompt*` dataset folders of each checkpoint's run directory). This part is a bit awkward and will be updated soon.

In [None]:
# Root directory holding the fine-tuned OpenCLIP training runs (one sub-folder per run).
OPEN_CLIP_CKPT_ROOT = Path("/mnt/Data/MODELS/OPEN_CLIP")

# Human-readable model name (shown as-is in the UI) -> how to load that model.
#   arch       : architecture name passed to the inference wrapper (e.g. "ViT-L-14")
#   pretrained : pretrained tag the fine-tune started from
#   ckpt_path  : fine-tuned checkpoint, laid out as `<run>/checkpoints/<epoch>.pt`; the `<run>`
#                folder also holds the cached `prompt*` dataset embeddings used by the UI.
MODEL_CONFIGS = {
    "ConvNeXT (256x256; Laion-Aes) - 5 Epoch": dict(
        arch="convnext_base_w",
        pretrained="laion_aesthetic_s13b_b82k",
        ckpt_path=OPEN_CLIP_CKPT_ROOT / "2023_02_26-13_34_33-model_convnext_base_w-lr_0.0001-b_256-j_8-p_amp_bf16/checkpoints/epoch_5.pt",
    ),
    "ConvNeXT (256x256; Laion-2B) - 14 Epoch": dict(
        arch="convnext_base_w",
        pretrained="laion2b_s13b_b82k",
        ckpt_path=OPEN_CLIP_CKPT_ROOT / "2023_02_25-20_44_13-model_convnext_base_w-lr_0.0001-b_256-j_8-p_amp_bf16/checkpoints/epoch_14.pt",
    ),
    "ViT L-14 (224x224; OpenAI)": dict(
        arch="ViT-L-14",
        pretrained="openai",
        ckpt_path=OPEN_CLIP_CKPT_ROOT / "2023_02_24-14_51_33-model_ViT-L-14-lr_0.0001-b_128-j_8-p_amp_bf16/checkpoints/epoch_5.pt",
    ),
}

# Dropdown options, sorted alphabetically for a stable UI order.
MODEL_NAMES_HR = sorted(MODEL_CONFIGS.keys())
MODEL_NAMES_HR

In [None]:
# Datasets are discovered from the run folder of the *first* configured checkpoint: every entry
# whose name starts with "prompt" becomes an option in the dataset dropdown.
# NOTE: assumes every model in MODEL_CONFIGS was run on the same datasets; consistency across
# models is not checked yet.
ckpt_path = next(iter(MODEL_CONFIGS.values()))["ckpt_path"]
_run_dir = Path(ckpt_path).parent.parent
# Sorted so the dropdown order is stable (directory listing order is arbitrary).
DATASETS = sorted(entry.name for entry in _run_dir.iterdir() if entry.name.startswith("prompt"))
assert DATASETS, f"No `prompt*` dataset folders found in {_run_dir}; run inference for this model first."
DATASETS

In [None]:
class ModelLoader:
    """One model column of the comparison UI.

    Holds the widgets used to pick a model config and fine-tuning alpha, loads the selected
    model (plus its cached image embeddings) when "Load Model" is clicked, and answers
    text-prompt similarity queries against those embeddings.

    Reads the module-level widgets `W_DEVICE` (GPU id) and `W_DATASET` (dataset folder name)
    when loading, so both must exist before the load button is clicked.
    """

    # Passed to `InferenceModelFromDisk` as `experiment_name`.
    # TODO: replace with a meaningful per-run name once its downstream use is settled.
    EXPERIMENT_NAME = "similarity-ui"

    def __init__(self):
        # Populated by `_load_model()`; `None` means the model has not been loaded yet.
        self.inf = None
        self.df = None
        self.embeddings = None
        self.setup_w_model_config()

    def setup_w_model_config(self):
        """Create this loader's widgets and the aggregated `W_MODEL_CFG` view."""
        # Individual widgets. The chosen alpha must be one for which embeddings were cached
        # (it is part of the file pattern searched by `path_embedding`).
        self.w_alpha = widgets.FloatSlider(value=0.0, min=0, max=1.0, step=0.25, description="Alpha: ")
        self.w_model_name = Dropdown(options=MODEL_NAMES_HR)
        self.w_load_model = widgets.Button(description="-- Load Model! --")

        # Events: `_load_model` accepts (and ignores) the clicked button argument.
        self.w_load_model.on_click(self._load_model)

        # Aggregated view: model picker, alpha slider and load button stacked vertically.
        self.W_MODEL_CFG = VBox(children=[self.w_model_name, vspace(10), self.w_alpha, vspace(20), self.w_load_model])

    def _load_model(self, *args):
        """Instantiate the selected model and load its cached image embeddings.

        Extra positional args (e.g. the clicked button) are ignored.

        Raises:
            FileNotFoundError: if no cached embedding file matches the current selection.
        """
        logger.info("Loading model!")
        self.inf = InferenceModelFromDisk(
            # Args from user
            device=W_DEVICE.value,
            alpha=self.alpha,
            # Fixed args (from MODEL_CONFIGS)
            arch=self.arch,
            pretrained=self.pretrained,
            ckpt_path=self.ckpt_path,
            # Cached embeddings matching the selected dataset / arch / pretrained / alpha
            path_embedding=self.path_embedding,
            experiment_name=self.EXPERIMENT_NAME,
        )

        # `df` is expected to carry (at least) `embedding` and `filepath` columns.
        self.df = self.inf.get_image_embeddings()
        # Stack per-row embeddings into one array for the similarity search
        # (assumes each row holds a 1-D vector -> 2-D result; TODO confirm).
        self.embeddings = np.stack(self.df.embedding)

    def get_top_matches(self, prompt, N, *args) -> List[PathLike]:
        """Return the file paths of the `N` images most similar to the text `prompt`.

        Raises:
            RuntimeError: if the model has not been loaded yet.
        """
        if getattr(self, "inf", None) is None:
            raise RuntimeError("Model is not loaded yet -- click '-- Load Model! --' first.")
        # Calls the module-level `get_top_matches` (this method's name does not shadow it here).
        top_matches = get_top_matches(prompt, self.inf.model, self.embeddings, self.inf.tokenizer)
        # NOTE(review): assumes matches are (similarity, row_index) pairs sorted best-first.
        ids = [sim[1] for sim in top_matches[:N]]
        return self.df.iloc[ids].filepath.tolist()

    @property
    def alpha(self):
        """Fine-tuning alpha currently selected on the slider."""
        return self.w_alpha.value

    @property
    def arch(self):
        """Architecture name of the selected model config."""
        return MODEL_CONFIGS[self.w_model_name.value]["arch"]

    @property
    def ckpt_path(self):
        """Checkpoint path of the selected model config."""
        return MODEL_CONFIGS[self.w_model_name.value]["ckpt_path"]

    @property
    def base_dir(self):
        """Run directory of the checkpoint (`<run>/checkpoints/<epoch>.pt` -> `<run>`)."""
        return self.ckpt_path.parent.parent

    @property
    def pretrained(self):
        """Pretrained tag of the selected model config."""
        return MODEL_CONFIGS[self.w_model_name.value]["pretrained"]

    @property
    def path_embedding(self):
        """Cached embedding file for the selected dataset / arch / pretrained / alpha.

        If several files match, the lexicographically first one is returned (deterministic).

        Raises:
            FileNotFoundError: if no matching cached embedding file exists.
        """
        search_dir = self.base_dir / W_DATASET.value
        pattern = f"{self.arch}--{self.pretrained}--finetuned-alpha-{self.alpha}*feather*"
        matches = sorted(search_dir.rglob(pattern))
        if not matches:
            raise FileNotFoundError(
                f"No cached embeddings matching {pattern!r} under {search_dir}; "
                "run inference for this model/alpha first."
            )
        return matches[0]

In [None]:
# Output area where the per-model config widgets are rendered after "Initialise" is clicked.
W_INIT_DISPLAY = widgets.Output(layout=L_center)


def initialise(*args):
    """(Re)create one `ModelLoader` per requested model and show their configs side by side.

    Rebinds the global `MODELS` list to fresh, not-yet-loaded loaders, then replaces the
    contents of `W_INIT_DISPLAY` with one row of config panels, each followed by a spacer.
    Extra positional args (e.g. the clicked button) are ignored.
    """
    global MODELS
    MODELS = []

    requested = W_NUM_MODELS.value
    for _ in range(requested):
        MODELS.append(ModelLoader())

    # Interleave each loader's config panel with a 20px horizontal gap.
    columns = []
    for loader in MODELS:
        columns += [loader.W_MODEL_CFG, hspace(20)]

    with W_INIT_DISPLAY:
        W_INIT_DISPLAY.clear_output()
        display(HBox(children=columns))

In [None]:
# Shared results row: `launch_similarity_ui` fills it with one image grid per model in `MODELS`.
W_OUT_MATCHES = widgets.HBox(children=[], layout=L_center)

In [None]:
def launch_similarity_ui(*args):
    """Display the prompt panel that runs a similarity search on every model in `MODELS`.

    Results go into the shared `W_OUT_MATCHES` row: one image grid per model, each followed
    by a horizontal spacer. Extra positional args (e.g. the clicked button) are ignored.
    """

    def run_similarity_search(*args):
        # Empty the results row first (presumably so stale grids vanish while the search runs).
        W_OUT_MATCHES.children = []
        grids = []
        for model in MODELS:
            fpaths = model.get_top_matches(w_prompt.value, w_num_matches.value)
            grids.append(img_grid(fpaths, w_ncol.value))
            # NOTE(review): a 2000px gap between grids looks unusually wide -- confirm intent.
            grids.append(hspace(2000))
        W_OUT_MATCHES.children = grids

    w_num_matches = widgets.IntText(description="Num. Matches: ", value=30, layout=Layout(width="15%"))
    w_ncol = widgets.IntText(description="Num. Cols: ", value=3, layout=Layout(width="15%"))
    # `Layout` has no `font_size` trait (passing it there was never applied); font size is a
    # style setting on the widget itself.
    w_prompt = widgets.Text(description="Prompt: ", value="A person holding a beer", layout=Layout(width="100%"))
    w_prompt.style.font_size = "18px"

    w_run_search = widgets.Button(description="🚀 Run Similarity Search 🚀", layout=Layout(width="25%"))
    w_run_search.on_click(run_similarity_search)

    prompt_panel = VBox(
        children=[w_prompt, vspace(15), w_num_matches, vspace(10), w_ncol, vspace(15), w_run_search],
        layout=L_center,
    )
    display(VBox(children=[prompt_panel, vspace(100), W_OUT_MATCHES]))

In [None]:
# Clear loader widgets left over from a previous run of this cell.
W_INIT_DISPLAY.clear_output()

W_DATASET = widgets.Dropdown(options=DATASETS, description="Dataset: ", layout=Layout(width="30%"))
W_DEVICE = widgets.Dropdown(options=[0, 1, 2], description="GPU ID: ")
# At least one model column is needed for the similarity UI to show anything.
W_NUM_MODELS = widgets.BoundedIntText(description="Num Models: ", value=1, min=1)

W_INIT_LAUNCH = widgets.Button(description="🚀 Initialise")
# Disabled until initialisation succeeds: the search UI iterates over `MODELS`.
W_INIT_SIMILARITY_UI = widgets.Button(
    description="🚀 Launch Similarity UI", layout=Layout(width="13%"), disabled=True
)


def _initialise_then_enable_search(button):
    """Build the model loaders, then unlock the "Launch Similarity UI" button.

    ipywidgets runs each registered click handler independently (an exception in one does not
    stop the next), so two separate handlers would enable the button even if `initialise`
    failed. Chaining them here enables it only on success.
    """
    initialise(button)
    W_INIT_SIMILARITY_UI.disabled = False


W_INIT_LAUNCH.on_click(_initialise_then_enable_search)
W_INIT_SIMILARITY_UI.on_click(launch_similarity_ui)

W_INIT = VBox(children=[
    W_DATASET, vspace(10), W_DEVICE, vspace(10), W_NUM_MODELS, vspace(20), W_INIT_LAUNCH,
    vspace(50), W_INIT_DISPLAY, vspace(50), W_INIT_SIMILARITY_UI,
], layout=L_center)
W_INIT