In [None]:
# Autosave the notebook every 60 seconds.
%autosave 60
# Automatically reload edited, already-imported modules before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in cell outputs.
%matplotlib inline

In [None]:
import json
import os
import pickle
from collections import Counter, OrderedDict
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL.Image as pil_img
import seaborn as sns
import sklearn as skl
import torch
import torch.nn as nn
from IPython.display import Image, display
from matplotlib.patches import Rectangle
from matplotlib_inline.backend_inline import set_matplotlib_formats
from torch.nn import functional as F
from tqdm.contrib import tenumerate, tmap, tzip
from tqdm.contrib.bells import tqdm, trange

In [None]:
# Never truncate long cell contents (e.g. long caption text or file paths).
pd.set_option("display.max_colwidth", None)
pd.set_option("display.max_columns", 15)
pd.set_option("display.max_rows", 50)
# Show floats with thousands separators and two decimal places.
pd.options.display.float_format = "{:,.2f}".format
plt.rcParams["figure.figsize"] = (12, 10)

# Inline figure output format(s).
# NOTE(review): the `%config` line selects "retina", but
# `set_matplotlib_formats("pdf", "png")` runs right after it and presumably
# overrides that choice -- keep only the one that is intended. (A previous
# note here mentioned switching to SVG, which neither line actually does.)
%config InlineBackend.figure_formats = ["retina"]
set_matplotlib_formats("pdf", "png")

In [None]:
def get_sims(
    a: Union[torch.Tensor, np.ndarray],
    b: Union[torch.Tensor, np.ndarray],
    batch_size: int = 1000,
) -> Dict[int, Tuple[float, int]]:
    """For each row of ``a``, find the row of ``b`` with the largest dot product.

    ``b`` is scored in chunks of ``batch_size`` rows, so the full
    ``len(a) x len(b)`` similarity matrix is never materialized at once. When
    both inputs have L2-normalized rows, the dot product is the cosine
    similarity.

    Args:
        a: Query matrix, shape (n_a, d). A numpy array or a torch tensor.
        b: Target matrix, shape (n_b, d), with the same kind and dtype as ``a``.
        batch_size: Number of rows of ``b`` scored per chunk (must be >= 1).

    Returns:
        A dict mapping each row index of ``a`` to ``(best_score, best_b_index)``.
        Ties keep the earliest index in ``b``. If ``b`` has no rows, every
        entry is ``(-inf, -1)``.

    Raises:
        ValueError: If ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # as_tensor accepts numpy arrays and tensors alike; torch.from_numpy on the
    # product crashed whenever the inputs were already tensors.
    a_t = torch.as_tensor(a)
    b_t = torch.as_tensor(b)

    best_vals = None
    best_idxs = None
    for start in tqdm(range(0, b_t.shape[0], batch_size)):
        values, idxs = (a_t @ b_t[start : start + batch_size].T).max(dim=-1)
        idxs = idxs + start  # chunk-local index -> global row index into b
        if best_vals is None:
            # Take the first chunk as-is (equivalent to starting from -inf) so
            # that negative best similarities are still reported correctly.
            best_vals, best_idxs = values, idxs
        else:
            # Strict '>' keeps the earliest index of b on ties.
            improved = values > best_vals
            best_vals = torch.where(improved, values, best_vals)
            best_idxs = torch.where(improved, idxs, best_idxs)

    if best_vals is None:  # b has no rows: no candidate for any query
        return {i: (float("-inf"), -1) for i in range(a_t.shape[0])}
    return dict(enumerate(zip(best_vals.tolist(), best_idxs.tolist())))


def normalize_rows(mat: torch.Tensor) -> torch.Tensor:
    """L2-normalize the vectors along the last dimension of ``mat``, in place.

    Args:
        mat: Floating-point tensor with at least 2 dimensions (typically
            (n, d)); each slice along the last dimension is normalized.

    Returns:
        The same tensor object, mutated so every row has unit L2 norm. A row
        whose norm is 0 becomes NaN (0 / 0), exactly as before.
    """
    # One vectorized in-place division instead of a Python loop over rows.
    mat /= mat.norm(p=2, dim=-1, keepdim=True)
    return mat

---

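## GeoGuessr In-Game Frames → GPT-J Text Lookup

For every in-game frame in the chosen split, find the GPT-J clue text whose embedding has the highest cosine similarity to the frame's image embedding.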

In [None]:
split = "test"
# Assumption: this runs on the shared cluster filesystem where these paths exist.
CLIP_FT_DIR = Path("/shared/gbiamby/geo/models/clip_ft/vit-b32")
CLUES_JSON = Path("/shared/g-luo/geoguessr/data/data/guidebook/kb/v3/cleaned_clues.json")

# Query image embeddings and candidate text embeddings. Each is a mapping whose
# keys are used later as image file paths / clue texts, and whose values are
# the embedding vectors.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(CLIP_FT_DIR / f"geoframes_clip_samples_fixed_window_{split}_img.pkl", "rb") as fh:
    image_embs = pickle.load(fh)
with open(CLIP_FT_DIR / "gptj_clues_text.pkl", "rb") as fh:
    text_embs = pickle.load(fh)

# Target captions (JSON is read as UTF-8 regardless of the system locale).
with open(CLUES_JSON, encoding="utf-8") as fh:
    gpt_caps = json.load(fh)["clues"]

In [None]:
# Fall back to CPU when no GPU is available instead of failing on .to("cuda").
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Stack each embedding mapping into one matrix (one row per entry, in dict
# order) and L2-normalize the rows, so the matrix products below are cosine
# similarities. as_tensor avoids an extra copy/warning if entries are tensors.
t_image_embs = normalize_rows(
    torch.stack([torch.as_tensor(emb) for emb in image_embs.values()]).to(DEVICE)
)
t_text_embs = normalize_rows(
    torch.stack([torch.as_tensor(emb) for emb in text_embs.values()]).to(DEVICE)
)
assert t_image_embs.shape[-1] == t_text_embs.shape[-1], "image/text embedding dims differ"

print(f"image_embs.shape: {t_image_embs.shape}, text_emb.shape: {t_text_embs.shape}")

In [None]:
# Score every image against every text (rows were L2-normalized above, so
# these are cosine similarities), then keep the best text per image.
sims = t_image_embs @ t_text_embs.T
max_sim_scores, max_sim_idxs = torch.max(sims, dim=1)

In [None]:
texts_unique = list(text_embs.keys())

# Convert each result tensor to Python lists once, instead of doing one
# device->host transfer per image per field inside the loop.
best_scores = max_sim_scores.tolist()
best_idxs = max_sim_idxs.tolist()
sims_rows = sims.tolist()
assert len(best_scores) == len(image_embs), "expected one similarity row per image"

# Per-image record of its best-matching text, keyed by image path (in the same
# order as image_embs, which is also the row order of `sims`).
img_to_text_sims = {
    img_path: {
        "best_match_text": texts_unique[best_idx],
        "best_sim_score": score,
        "best_sim_idx": best_idx,
        "sims_all": sims_row,
        "file_path": img_path,
    }
    for img_path, score, best_idx, sims_row in zip(
        image_embs.keys(), best_scores, best_idxs, sims_rows
    )
}

In [None]:
# Sanity check; expected roughly (n_images, d) (n_texts, d) (n_images, n_texts).
print(*(tensor.shape for tensor in (t_image_embs, t_text_embs, sims)))

In [None]:
# list(img_to_text_sims.items())[:2]

In [None]:
from IPython.core.display import HTML, Markdown


def show_samples(
    img_to_text_sims: Union[Dict[str, Dict[str, Any]], List[Dict[str, Any]]],
    n_samples: int = 20,
) -> None:
    """Display a random sample of image/text match records.

    For each sampled record, shows its fields (without the bulky
    ``"sims_all"`` entry) as a transposed one-column DataFrame, followed by a
    thumbnail of the image at ``record["file_path"]``.

    Args:
        img_to_text_sims: The match records, either as a list of dicts or as a
            dict mapping image path -> record (its values are used). Each
            record must have a ``"file_path"`` key.
        n_samples: Number of records to show, sampled without replacement
            using numpy's global random state. Capped at the number of records.
    """
    records = (
        list(img_to_text_sims.values())
        if isinstance(img_to_text_sims, dict)
        else list(img_to_text_sims)
    )
    n_show = min(n_samples, len(records))
    if n_show <= 0:
        return
    sample_idxs = np.random.choice(len(records), n_show, replace=False)

    for idx in sample_idxs:
        # Shallow copy without "sims_all": avoids deep-copying a potentially
        # very long list of floats just to drop it, and never mutates the
        # caller's record.
        img_row = {k: v for k, v in records[idx].items() if k != "sims_all"}
        print("=" * 180)
        display(pd.DataFrame({k: [v] for k, v in img_row.items()}).T)
        # Context manager closes the underlying file once the thumbnail is shown.
        with pil_img.open(img_row["file_path"]) as img:
            img.thumbnail((1080, 640), pil_img.NEAREST)
            display(img)
        print("\n")

### Images with the Highest Image/Text Similarity Scores

Randomly sample 20 of the 100 images whose best-matching text scored highest.

In [None]:
# Rank all records by their best image->text score (highest first), then
# sample 20 of the top 100 to inspect the most confident matches.
show_samples(
    sorted(
        img_to_text_sims.values(),
        key=lambda record: record["best_sim_score"],
        reverse=True,
    )[:100],
    n_samples=20,
)

### Random Images and Their Best-Matching Text

Sample 20 images uniformly from the whole split and show each one's top-scoring clue.

In [None]:
# Baseline for comparison: 20 records drawn uniformly from all images.
show_samples([*img_to_text_sims.values()], n_samples=20)

---