# Contents

This notebook shows two retrieval examples:
- Retrieving RPLAN floor plans by RPLAN id, using precomputed features
- Computing embeddings for a new image and retrieving the RPLAN images most similar to it

## Visualization code

In [None]:
import matplotlib.pyplot as plt
import datasets

class FloorPlanImages:
    """Look up dataset images by their ``id`` column, decoding them on access.

    ``ds_img`` must support ``ds_img["id"]`` (an iterable of ids) and
    ``ds_img[row]["img"]`` (the stored image for a row position).
    """

    def __init__(self, ds_img) -> None:
        self.ds_img = ds_img

        # id -> row position; if an id occurs more than once, its last row wins.
        self.id_to_index = {row_id: position for position, row_id in enumerate(ds_img["id"])}

    def __getitem__(self, id):
        # NOTE(review): assumes the "img" column holds undecoded image data
        # (e.g. a {"bytes", "path"} dict) accepted by Image.decode_example — confirm.
        image_feature = datasets.Image(decode=True)
        row = self.ds_img[self.id_to_index[id]]
        return image_feature.decode_example(row["img"])
    

class VisualizeRetrievals:
    """Draw a query floor plan followed by its retrieved floor plans in one row of panels."""

    def __init__(self, images: FloorPlanImages):
        # Id-indexed image store (``images[id]`` -> image) used for the retrieved panels.
        self.images = images

    def _plot_row(self, query_img, query_title, retrieved_ids, titles):
        """Render the query panel, then one panel per retrieved id; return the figure."""
        n_panels = len(retrieved_ids) + 1

        # squeeze=False keeps ``axes`` 2-D even for a single panel (no retrievals);
        # otherwise plt.subplots returns a bare Axes that cannot be indexed.
        fig, axes = plt.subplots(
            1,
            n_panels,
            dpi=150,
            figsize=(20 * n_panels / 5 * 0.75, 7.5 * 0.75),
            squeeze=False,
        )
        panels = axes[0]

        panels[0].imshow(query_img)
        panels[0].set_title(query_title)
        panels[0].axis("off")

        for i, retrieved_id in enumerate(retrieved_ids):
            ax = panels[i + 1]
            ax.imshow(self.images[retrieved_id])
            # Same text as f"{id=}": "id=" followed by the repr of the id.
            title = f"id={retrieved_id!r}"
            if titles is not None:
                title += f"\n{titles[i]}"
            ax.set_title(title)
            ax.axis("off")

        # Lay out only after the titles exist, so tight_layout can make room for them.
        fig.tight_layout(pad=1.0)
        return fig

    def visualize_query_by_id(self, query_id, retrieved_ids, titles=None, relevants=None):
        """Plot a dataset query (looked up by id) next to its retrieved images.

        Args:
            query_id: id of the query image in ``self.images``; shown in the first panel.
            retrieved_ids: ids of the retrieved images, one panel each, in the given order.
            titles: optional per-retrieval extra title line; ``titles[i]`` belongs to
                ``retrieved_ids[i]``.
            relevants: accepted for backward compatibility; currently unused.

        Returns:
            The matplotlib figure.
        """
        return self._plot_row(self.images[query_id], f"query_id={query_id!r}", retrieved_ids, titles)

    def visualize_query_by_image(self, retrieved_ids, query_img, titles=None, query_img_title="query"):
        """Plot an arbitrary query image next to its retrieved images.

        Args:
            retrieved_ids: ids of the retrieved images, one panel each, in the given order.
            query_img: image shown in the first panel (anything ``imshow`` accepts).
            titles: optional per-retrieval extra title line; ``titles[i]`` belongs to
                ``retrieved_ids[i]``.
            query_img_title: title of the query panel.

        Returns:
            The matplotlib figure.
        """
        return self._plot_row(query_img, query_img_title, retrieved_ids, titles)


## Code for loading precomputed features

In [None]:
import torch

from pathlib import Path

import os
import typing

def infer_dataset_name(wandb_prefix, wandb_model_ref, features_cache_folder="data/predicted/"):
    """Return the only dataset name that has cached features for a model.

    Cached features live under
    ``<features_cache_folder>/<wandb_prefix>/<wandb_model_ref>/<dataset_name>/``.

    Raises:
        FileNotFoundError: the model folder does not exist (features were never computed).
        RuntimeError: the model folder has no dataset sub-folder, or more than one.
    """
    model_folder = Path(features_cache_folder) / wandb_prefix / wandb_model_ref

    # Without this check os.listdir raises a bare FileNotFoundError that does not
    # tell the user how to produce the missing features.
    if not model_folder.is_dir():
        raise FileNotFoundError(
            f"No features found for {wandb_prefix}/{wandb_model_ref} (missing folder {model_folder}). "
            "First run predict_embeddings.py"
        )

    # Only sub-directories count as datasets; sorted so error messages are deterministic.
    dataset_name_options = sorted(item.name for item in model_folder.iterdir() if item.is_dir())

    if not dataset_name_options:
        raise RuntimeError(f"No features found for {wandb_prefix}/{wandb_model_ref}. First run predict_embeddings.py")

    if len(dataset_name_options) > 1:
        raise RuntimeError(f"More than one dataset_name option was found, specify one of {dataset_name_options}")

    return dataset_name_options[0]


def load_features(wandb_prefix, wandb_model_ref, split, dataset_name, features_cache_folder="data/predicted/", map_location=None) -> typing.Dict[str, torch.Tensor]:
    """Load the cached feature dict for one model / dataset / split.

    Reads ``<features_cache_folder>/<wandb_prefix>/<wandb_model_ref>/<dataset_name>/<split>/feats.pth``.

    Args:
        map_location: forwarded to ``torch.load``; e.g. ``"cpu"`` to load features that
            were saved from GPU tensors on a CPU-only machine. ``None`` keeps torch's default.

    Raises:
        FileNotFoundError: the dataset folder or the requested split has no cached features.
    """
    dataset_folder = Path(features_cache_folder) / wandb_prefix / wandb_model_ref / dataset_name

    if not dataset_folder.is_dir():
        raise FileNotFoundError(
            f"No features found for {wandb_prefix}/{wandb_model_ref}/{dataset_name}. First run predict_embeddings.py"
        )

    split_options = sorted(item.name for item in dataset_folder.iterdir() if item.is_dir())

    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if split not in split_options:
        raise FileNotFoundError(f"Did not find features for {split=}, only for {split_options}")

    features_path = dataset_folder / split / "feats.pth"

    # torch.load unpickles the file: only load feature files produced locally.
    return torch.load(features_path, map_location=map_location)

# Retrieve by the id of an RPLAN floor plan in the dataset

In [None]:
from inference.retrieval_context import EmbeddingsRetrievalContext
from inference.model_context import WandbModelContext

# Sentinel read by the next cell: it stays None if loading the cached features below fails.
retrieval_context = None

# Which trained model (W&B artifact) and which dataset split to search.
wandb_prefix = "emanuel/msc_thesis_models"
wandb_model_ref = "model-uc18eq89:best"
split = "val"

# Locate and load the precomputed features, then build the retrieval index over them.
dataset_name = infer_dataset_name(wandb_prefix, wandb_model_ref)
features = load_features(
    wandb_prefix,
    wandb_model_ref,
    split,
    dataset_name=dataset_name,
)
retrieval_context = EmbeddingsRetrievalContext(features)

In [None]:
# If it failed, first preprocess the dataset and precompute features:

if retrieval_context is None:

    from predict_embeddings import model_huggingface_url_to_preprocessing_style

    model_context = WandbModelContext(wandb_prefix, wandb_model_ref)
    method = model_huggingface_url_to_preprocessing_style(model_context.huggingface_dataset)

    rplan_dataset_path = "/home/emanuel/thesisdata/dataset/floorplan_dataset"

    # Preprocess dataset:
    if not os.path.exists(f"data/processed/{method.value}/{split}"):
        print(f"Run: python run_preprocessing.py method={method.name} rplan_dataset_path={rplan_dataset_path} split={split}")
        !python run_preprocessing.py method={method.name} rplan_dataset_path={rplan_dataset_path} split={split}

    # Precompute features:
    print(f"Run: python predict_embeddings.py wandb_prefix={wandb_prefix} wandb_model_ref={wandb_model_ref} split={split}")
    !python predict_embeddings.py wandb_prefix={wandb_prefix} wandb_model_ref={wandb_model_ref} split={split}

# Now re-execute the feature-loading cell above (the one that builds `retrieval_context`)

In [None]:
# Images come from the preprocessed dataset named by `dataset_name`, which by default is the
# one inferred from the model's cached features. For nicer visualizations you can set
#     dataset_name = "ds_rplanpy_rgb"
# first (assumes both datasets share the same ids — TODO confirm).
ds_images = datasets.load_from_disk(
    f"data/processed/{dataset_name}/{split}/"
)

visualize_retrievals = VisualizeRetrievals(
    FloorPlanImages(ds_images)
)

In [None]:
# Display the name of the preprocessed dataset that the features and images above were loaded from.
dataset_name

In [None]:
# Query settings: which dataset ids to search with, and how many hits to draw per query.
query_ids = [4]
show_top_k = 5

assert retrieval_context is not None

# One row of retrieved ids per query, truncated to the first `show_top_k` columns.
# NOTE(review): whether a query's own id can appear among its hits depends on
# EmbeddingsRetrievalContext.retrieve — confirm.
retrieved_idss = retrieval_context.retrieve(query_ids)[:, :show_top_k]

for query_id, retrieved_ids in zip(query_ids, retrieved_idss):
    plot = visualize_retrievals.visualize_query_by_id(
        query_id,
        retrieved_ids,
    )

# Retrieve based on a new image

In [None]:
from inference.model_context import WandbModelContext

# Same W&B model artifact as the one whose cached features are searched in this notebook.
model_ctx = WandbModelContext(
    "emanuel/msc_thesis_models",
    "model-uc18eq89:best",
)

# Run the embedding model on CPU, in eval (inference) mode.
model_ctx.model.cpu()
model = model_ctx.model.eval()

In [None]:
from PIL import Image
import numpy as np

# Query: an edited example floor-plan image, converted to 3-channel RGB.
QUERY_IMAGE_PATH = "data/example_edited_query_images/rplan-0-squared-version.png"
query_img = Image.open(QUERY_IMAGE_PATH).convert("RGB")

query_img

In [None]:
# The query is fed to the model as-is (no resizing), so this notebook requires a 256x256 RGB image.
# The message reports the actual shape, so a wrong-sized or wrong-mode image is easy to diagnose.
query_shape = np.array(query_img).shape
assert query_shape == (256, 256, 3), f"expected a 256x256 RGB query image, got array shape {query_shape}"

In [None]:
import torch

from torchvision.transforms.functional import to_tensor

# Batch of one: to_tensor gives a (C, H, W) float tensor; unsqueeze(0) adds the batch axis.
query_batch = {"img": to_tensor(query_img).unsqueeze(0)}

# Gradients are not needed to compute an embedding.
with torch.no_grad():
    repr_vector = model_ctx.model(query_batch)

query_embeddings = repr_vector["pred"]

query_embeddings.shape

In [None]:
from inference.retrieval_context import EmbeddingsRetrievalContext

# Repeats the configuration of the first example: same model artifact and split.
wandb_prefix = "emanuel/msc_thesis_models"
wandb_model_ref = "model-uc18eq89:best"
split = "val"

# Retrieval index over the cached features for that model and split.
dataset_name = infer_dataset_name(wandb_prefix, wandb_model_ref)
features = load_features(
    wandb_prefix,
    wandb_model_ref,
    split,
    dataset_name=dataset_name,
)
retrieval_context = EmbeddingsRetrievalContext(features)

# Images for plotting the retrieved ids, read from the same preprocessed dataset.
ds_images = datasets.load_from_disk(f"data/processed/{dataset_name}/{split}/")
visualize_retrievals = VisualizeRetrievals(FloorPlanImages(ds_images))

In [None]:
# Request 100 results for the query embedding; row [0] belongs to the single query in the
# batch, and only its first 10 entries are kept for plotting.
retrieval_results = retrieval_context.retrieve_by_embedding(query_embeddings, top_k=100)
retrieved_ids = retrieval_results[0][:10]

In [None]:
# Query image in the first panel, followed by the retrieved floor plans in the order returned
# by retrieve_by_embedding (presumably best match first).
plot = visualize_retrievals.visualize_query_by_image(retrieved_ids, query_img)