In [None]:
# Notebook housekeeping: autosave every 60 seconds, reload edited modules
# automatically before each cell executes, and render matplotlib figures inline.
%autosave 60
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import json
import os
import pickle
import platform
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import matplotlib as plt
import numpy as np
import pandas as pd
import PIL.Image as pil_img
from IPython.core.display import HTML, Markdown
from IPython.display import Image, display
from matplotlib_inline.backend_inline import set_matplotlib_formats
from PIL import Image as pil_img
from tqdm.contrib import tenumerate
from tqdm.contrib.bells import tqdm

from geoscreens.data import get_all_geoguessr_split_metadata
from geoscreens.utils import load_json, save_json

In [None]:
# pandas display options: no truncation of cell text, at most 15 columns and
# 50 rows shown per frame.
pd.set_option("display.max_colwidth", None)
pd.set_option("display.max_columns", 15)
pd.set_option("display.max_rows", 50)
# Floats are shown with thousands separators and two decimals.
pd.options.display.float_format = "{:,.2f}".format
# NOTE: the imports cell binds `plt` to the top-level `matplotlib` package
# (not `matplotlib.pyplot`); `rcParams` is reachable from it.
plt.rcParams["figure.figsize"] = (12, 10)

# Inline figure formats. The %config line selects "retina"; the call right
# after it passes "pdf" and "png" to set_matplotlib_formats.
# NOTE(review): the second call presumably overrides the first -- confirm
# which of the two settings is actually intended and drop the other.
%config InlineBackend.figure_formats = ["retina"]
set_matplotlib_formats("pdf", "png")

---

In [None]:
import pickle
import platform

import pandas as pd

# Load the pickled in-game frames table. The context manager closes the file
# handle as soon as loading finishes.
# NOTE: pickle can execute arbitrary code while loading -- only use this on
# files from a trusted source.
with open("/shared/gbiamby/geo/segment/in_game_frames_000.pkl", "rb") as pkl_file:
    df_ingame = pickle.load(pkl_file)

In [None]:
# Per (video, resolution) summary: number of frame rows and number of
# distinct rounds.
pd.DataFrame(
    df_ingame.groupby(by=["video_id", "img_width", "img_height"]).agg(
        total_frames=pd.NamedAgg(column="sec", aggfunc="count"),
        total_rounds=pd.NamedAgg(column="round_num", aggfunc="nunique"),
    )
)

In [None]:
# Show the in-game frames table loaded above.
df_ingame

In [None]:
# Metadata for every video, indexed by its "id" column. The frame is rebuilt
# on every run of this cell (there is no effective cache guard).
df_meta = pd.DataFrame(
    get_all_geoguessr_split_metadata(
        force_include=["nemo_caption", "nemo_caption_entities"]
    ).values()
).set_index("id")

In [None]:
# Expose the index (the video id) as a regular column as well.
# NOTE: this mutates `df_meta` in place.
df_meta["video_id"] = df_meta.index

In [None]:
# Preview the last two videos, transposed so each field is a row.
df_meta.tail(2).T

---

---

## Match up the `ec` captions (the ones Grace generated clue similarities for) with Video Timestamps

In [None]:
from collections import OrderedDict

import nltk
from nltk.tokenize.punkt import PunktSentenceTokenizer


def idx_to_keys(caption_mapping):
    """Compute the character range each entry occupies in the joined caption.

    Args:
        caption_mapping: mapping of key -> text fragment, iterated in order.

    Returns:
        dict mapping each key to a half-open ``(start, end)`` offset pair into
        the string obtained by concatenating all fragments in iteration order.
    """
    offsets = {}
    cursor = 0
    for key, fragment in caption_mapping.items():
        stop = cursor + len(fragment)
        offsets[key] = (cursor, stop)
        cursor = stop
    return offsets


def intersect(a, b):
    """Return True when ranges ``a`` and ``b`` overlap by a positive amount.

    Only the first two items of each argument are read (start, end), so
    longer sequences are accepted. Ranges that merely touch do not intersect.
    """
    overlap_start = max(a[0], b[0])
    overlap_end = min(a[1], b[1])
    return overlap_end > overlap_start


def sentence_to_timings(ann):
    """Map every "sentence" entity of a caption to the timing keys it overlaps.

    Args:
        ann: record with a ``"nemo_caption"`` mapping (key -> text fragment)
            and a ``"nemo_caption_entities"`` sequence whose items carry
            ``(start, end, label, ...)``.

    Returns:
        dict of sentence text -> list of ``nemo_caption`` keys whose character
        range intersects the entity's ``(start, end)``. A sentence text that
        occurs more than once keeps only the timings of its last occurrence.
    """
    full_text = "".join(ann["nemo_caption"].values())
    key_spans = idx_to_keys(ann["nemo_caption"])

    by_sentence = {}
    for entity in ann["nemo_caption_entities"]:
        # Only entities labelled as sentences are of interest.
        if entity[2] != "sentence":
            continue
        overlapping = [key for key, key_span in key_spans.items() if intersect(entity, key_span)]
        by_sentence[full_text[entity[0] : entity[1]]] = overlapping
    return by_sentence


def sentence_to_timings_punkt(ann):
    """Split a caption with Punkt and map each sentence to its timing keys.

    Args:
        ann: record with a ``"nemo_caption"`` mapping (key -> text fragment).

    Returns:
        dict of sentence text -> list of ``nemo_caption`` keys whose character
        range intersects the sentence's span. Repeated sentence texts keep
        only the timings of their last occurrence.

    NOTE(review): the key offsets are computed on the unstripped caption
    while the sentence spans refer to the stripped one, so they are shifted
    against each other if the caption starts with whitespace.
    """
    text = "".join(ann["nemo_caption"].values()).strip()
    key_spans = idx_to_keys(ann["nemo_caption"])

    punkt = PunktSentenceTokenizer()
    sentence_texts = list(punkt.tokenize(text))
    sentence_spans = list(punkt.span_tokenize(text))
    return {
        sentence_text: [
            key for key, key_span in key_spans.items() if intersect(char_span, key_span)
        ]
        for char_span, sentence_text in zip(sentence_spans, sentence_texts)
    }


def get_spans(caption: str, sentences: list[str]):
    """Locate each sentence inside ``caption`` and return its character span.

    Sentences are searched left to right, each one starting from the end of
    the previous match, so whitespace between sentences is skipped and
    repeated sentences resolve to successive occurrences.

    Args:
        caption: the text the sentences were extracted from.
        sentences: sentence strings in the order they occur in ``caption``.

    Returns:
        list of half-open ``(start, end)`` offsets into ``caption``, one per
        sentence. An empty ``sentences`` list yields an empty list.
    """
    spans = []
    cursor = 0
    for sentence in sentences:
        found = caption.find(sentence, cursor)
        # If the sentence is not a verbatim substring, fall back to placing it
        # right after the previous one.
        start = found if found >= 0 else cursor
        end = start + len(sentence)
        spans.append((start, end))
        cursor = end
    return spans


def sentence_to_timings_nltk(ann):
    """Split a caption with NLTK and attach timing information to each sentence.

    Args:
        ann: record with a ``"nemo_caption"`` mapping (key -> text fragment).

    Returns:
        OrderedDict of sentence text -> dict with:
            ``"times"``: ``nemo_caption`` keys whose character range
                intersects the sentence span,
            ``"idx"``: position of the sentence within the caption,
            ``"span"``: the sentence's ``(start, end)`` character span,
            ``"start"`` / ``"end"``: smallest / largest entry of ``"times"``,
                or ``None`` when no key overlaps the sentence.
        An empty caption yields an empty OrderedDict. Repeated sentence texts
        keep only the data of their last occurrence.

    NOTE(review): ``min``/``max`` compare the keys as they are; if the keys
    are strings the ordering is lexicographic -- confirm that is intended.
    """
    caption = "".join(ann["nemo_caption"].values()).strip()
    time_to_span = idx_to_keys(ann["nemo_caption"])

    sentences = OrderedDict()
    subcaptions = list(nltk.tokenize.sent_tokenize(caption))
    if not subcaptions:
        # Nothing to map; also avoids handing an empty list to get_spans().
        return sentences

    spans = get_spans(caption, subcaptions)
    for i, (subcaption, span) in enumerate(zip(subcaptions, spans)):
        timings = [time for time, _idx_span in time_to_span.items() if intersect(span, _idx_span)]
        sentences[subcaption] = {
            "times": timings,
            "idx": i,
            "span": span,
            # default=None keeps a sentence without overlapping keys from
            # raising ValueError.
            "start": min(timings, default=None),
            "end": max(timings, default=None),
        }
    return sentences


def get_meta():
    """Build the metadata frame for all videos.

    Returns:
        DataFrame built from the values returned by
        ``get_all_geoguessr_split_metadata`` (with the NeMo caption fields
        force-included), indexed by ``"id"`` and with that index duplicated
        into a ``"video_id"`` column.
    """
    records = get_all_geoguessr_split_metadata(
        force_include=["nemo_caption", "nemo_caption_entities"]
    )
    meta = pd.DataFrame(records.values()).set_index("id")
    meta["video_id"] = meta.index
    return meta


def load_clue_sims(dataset_type: str):
    """Load the guidebook narrations of one split and index them per video.

    Args:
        dataset_type: split name used to build the JSON file name
            (e.g. ``"train"``).

    Returns:
        ``(clue_sims, clue_sim_lookup)`` where ``clue_sims`` is the list of
        narration dicts, each with an added ``"idx"`` giving the narration's
        position within its run of entries sharing the same ``"id"``, and
        ``clue_sim_lookup`` maps ``(id, text, idx)`` to the narration dict.
        A file without narrations yields ``([], {})``.

    NOTE(review): ``idx`` restarts whenever ``"id"`` changes between two
    consecutive entries, so this assumes all narrations of a video are
    contiguous in the file -- confirm.
    """
    narrations = load_json(
        f"/shared/g-luo/geoguessr/data/data/guidebook/narrations/{dataset_type}.json"
    )["narrations"]

    clue_sims = []
    previous_id = object()  # sentinel that never equals a real id
    idx = 0
    for narration in narrations:
        if narration["id"] != previous_id:
            idx = 0
            previous_id = narration["id"]
        # "idx" goes first in the dict; the explicit assignment makes sure a
        # pre-existing "idx" field of the narration cannot win.
        entry = {"idx": idx, **narration}
        entry["idx"] = idx
        clue_sims.append(entry)
        idx += 1

    clue_sim_lookup = {(cs["id"], cs["text"], cs["idx"]): cs for cs in clue_sims}
    return clue_sims, clue_sim_lookup


def get_caption_timings(df_meta: pd.DataFrame):
    """Compute sentence timings for every video in ``df_meta``.

    Args:
        df_meta: metadata frame with a ``video_id`` column whose values are
            also valid index labels for ``df_meta.loc``.

    Returns:
        dict of video id -> result of ``sentence_to_timings_nltk`` for that
        video's row.
    """
    timings_by_video = {}
    for _, vid in tenumerate(df_meta.video_id.values, desc="get_caption_timings"):
        row = df_meta.loc[vid].to_dict()
        timings_by_video[vid] = sentence_to_timings_nltk(row)
    return timings_by_video


def merge_timings_and_clue_sims(
    captions_nltk: dict, clue_sims: list, clue_sim_lookup: dict, clue_sim_ids: set[str]
):
    """Attach clue types to timed sentences for the videos of one split.

    Args:
        captions_nltk: video id -> {sentence text -> timing info dict}; each
            timing info dict must contain ``"idx"``.
        clue_sims: narration list of the split (not read here; kept so the
            call signature stays unchanged).
        clue_sim_lookup: ``(video_id, sentence, idx)`` -> narration dict with
            a ``"clue_type"`` entry.
        clue_sim_ids: ids of the videos that belong to the split.

    Returns:
        dict of video id -> list of ``{"sentence", "clue_type", **timing
        info}`` dicts, one per sentence found in ``clue_sim_lookup``. Videos
        of the split without any match map to an empty list. The number of
        matches is printed.
    """
    num_matches = 0
    result = {}
    for _, (video_id, sentences) in tenumerate(
        captions_nltk.items(), desc="merge_timings_and_clue_sims"
    ):
        if video_id not in clue_sim_ids:
            continue
        merged = result.setdefault(video_id, [])
        for sentence, sentence_info in sentences.items():
            key = (video_id, sentence, sentence_info["idx"])
            if key not in clue_sim_lookup:
                continue
            num_matches += 1
            merged.append(
                {
                    "sentence": sentence,
                    "clue_type": clue_sim_lookup[key]["clue_type"],
                    # Copy so the caller's timing dicts are never shared.
                    **deepcopy(sentence_info),
                }
            )
    print("num_matches: ", num_matches)
    return result

In [None]:
# cs = load_json(f"/shared/g-luo/geoguessr/data/data/guidebook/narrations/train.json")["narrations"]
# print(type(cs))
# list(cs)[:1]

In [None]:
# Metadata and caption timings are computed once per kernel and cached in the
# *_original variables; every run of this cell works on fresh deep copies.
if not ("df_meta_original" in locals() and "captions_nltk_original" in locals()):
    df_meta_original = get_meta()
    captions_nltk_original = get_caption_timings(df_meta_original)
df_meta = deepcopy(df_meta_original)
captions_nltk = deepcopy(captions_nltk_original)
print("total videos: ", len(captions_nltk))

# For each split: load its narrations, merge them with the timed sentences
# and write the result to one JSON file per split.
for dataset_type in ("val", "test", "train"):
    print("\n", "=" * 120, f"\n{dataset_type}")
    clue_sims, clue_sim_lookup = load_clue_sims(dataset_type)
    print(f"Total clue_sims: {len(clue_sims)}, clue_sims_lookup: {len(clue_sim_lookup)}")
    clue_sim_ids = {c["id"] for c in clue_sims}
    print(
        f"Total captions ({dataset_type}): ",
        sum(len(t) for video_id, t in captions_nltk.items() if video_id in clue_sim_ids),
    )
    result = merge_timings_and_clue_sims(captions_nltk, clue_sims, clue_sim_lookup, clue_sim_ids)
    print(
        f"Merged sims + timings -- videos: {len(result)}, sentences: {sum(len(s) for s in result.values())}"
    )
    save_path = Path(f"/shared/gbiamby/geo/captions/{dataset_type}_captions_with_timings.json")
    save_json(save_path, result)

---

---

## Debugging why the sentence counts don't match the clue-sim counts

In [None]:
# Group the train-split narrations by video id and build a
# (video id, sentence text) -> clue type lookup. Both are cached for the
# lifetime of the kernel.
if not ("captions_guide" in locals() and "captions_guide_lookup" in locals()):
    guide = load_json("/shared/g-luo/geoguessr/data/data/guidebook/narrations/train.json")
    captions_guide = {}
    captions_guide_lookup = {}
    for g in tqdm(guide["narrations"]):
        captions_guide.setdefault(g["id"], []).append(deepcopy(g))
        captions_guide_lookup[(g["id"], g["text"])] = g["clue_type"]
# guide_lookup = {n["text"]: n for n in guide["narrations"]}
# type(guide["narrations"]), guide["narrations"][0]

In [None]:
# Sanity checks on the grouped narrations: number of videos, number of
# sentences for one sample video, size of the (id, text) lookup, and the last
# ten narrations of the sample video.
print("cg video_ids: ", len(captions_guide))
print("cg[K4GXuDACK40] sentences: ", len(captions_guide["K4GXuDACK40"]))
print("cg total captions: ", len(captions_guide_lookup))
print(captions_guide["K4GXuDACK40"][-10:])

In [None]:
# Compare, for one sample video, how many sentences each splitting strategy
# produces.
# NOTE: `captions_old` and `captions_new` are built by the two cells that
# follow this one; on a fresh kernel run those first.
print(len(captions_old["K4GXuDACK40"]))
print(len(captions_new["K4GXuDACK40"]))
print(len(captions_nltk["K4GXuDACK40"]))

# Only the last expression of a cell is echoed, so both previews are
# displayed explicitly.
display(list(captions_old["K4GXuDACK40"].items())[:10])
display(list(captions_nltk["K4GXuDACK40"].items())[:10])

In [None]:
# Sentence timings based on the NeMo "sentence" entities, cached for the
# lifetime of the kernel.
if "captions_old" not in locals():
    captions_old = {}
    # Wrap the id array itself (not enumerate(...)) so tqdm can read its
    # length and show a total.
    for video_id in tqdm(df_meta.video_id.values):
        captions_old[video_id] = sentence_to_timings(df_meta.loc[video_id].to_dict())

print(len(captions_old))
print(
    "total captions ",
    sum(len(t) for video_id, t in captions_old.items() if video_id in captions_guide),
)

In [None]:
# Sentence timings based on the Punkt tokenizer, cached for the lifetime of
# the kernel.
if not ("captions_new" in locals()):
    captions_new = {}
    for video_id in tqdm(df_meta.video_id.values):
        meta_row = df_meta.loc[video_id].to_dict()
        captions_new[video_id] = sentence_to_timings_punkt(meta_row)

print(len(captions_new))
# Total sentence count over the videos that also appear in `captions_guide`.
sum(len(t) for video_id, t in captions_new.items() if video_id in captions_guide)
# list(captions_new.items())[-10:]

---

## Simplest Approach: Map ASR Time to Image Time

In [None]:
# NOTE(review): not referenced anywhere else in the visible part of this
# notebook -- presumably a window for the ASR-time-to-image-time mapping this
# section is about; confirm its unit or remove it.
window_size = 1

In [None]:
# Set of ids that occur in the narrations file.
# NOTE(review): `narrations` is assigned in the cell right below this one, so
# on a fresh kernel that cell has to run first.
ec_ids = {n["id"] for n in narrations["narrations"]}

In [None]:
# Raw guidebook narrations file of the train split; the cells below read its
# "narrations" and "clue_types" entries.
narrations = load_json(f"/shared/g-luo/geoguessr/data/data/guidebook/narrations/train.json")

In [None]:
# Number of narration entries and a preview of the first five.
print("length: ", len(narrations["narrations"]))
narrations["narrations"][:5]

In [None]:
# Number of entries under "clue_types" and a preview of the first ten.
print(len(narrations["clue_types"]))
narrations["clue_types"][:10]

---

In [None]:
# Guidebook clues file of the train split.
# NOTE(review): the disabled json.dump in main() near the end of this
# notebook targets the same directory, so this is presumably the
# (similarity, clue index) output of get_labels -- confirm.
clues = load_json("/shared/g-luo/geoguessr/data/data/guidebook/clues/train.json")

In [None]:
# Preview the first ten entries.
clues[:10]

In [None]:
# Total number of entries in the clues file.
len(clues)

In [None]:
import json
import sys

import nltk
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
# NOTE: this rebinds `tqdm`, which the imports cell at the top of the notebook
# took from `tqdm.contrib.bells`; cells run after this one use the plain tqdm.
from tqdm import tqdm

# Fetch the "punkt" resource via NLTK's downloader.
# NOTE(review): earlier cells already call nltk.tokenize.sent_tokenize, so the
# resource is presumably installed before this point -- confirm, or move the
# download to the top of the notebook.
nltk.download("punkt")

In [None]:
# `dataset` is needed by this cell and by the next one, so it is loaded here
# rather than relied on as left-over kernel state. Both files are closed
# right after parsing.
with open("/shared/g-luo/geoguessr/data/data/train.json") as dataset_file:
    dataset = json.load(dataset_file)
print(len(dataset))
with open("/shared/g-luo/geoguessr/data/data/guidebook/narrations/train.json") as narrations_file:
    dataset2 = json.load(narrations_file)
print(len(dataset2))

In [None]:
# Flat list of all sentences of all captions in `dataset`, in order. A nested
# comprehension flattens in linear time (sum(list_of_lists, []) copies the
# accumulated list on every step).
# NOTE: this rebinds `narrations`, which earlier cells use for the parsed
# narrations JSON.
narrations = [
    sentence
    for ann in dataset
    for sentence in nltk.tokenize.sent_tokenize("".join(ann["nemo_caption"].values()))
]

In [None]:
# Total number of sentences across all captions.
len(narrations)

In [None]:
# Sentence-embedding model plus the guidebook clue paragraphs, each split
# into sentences. The JSON file is closed right after parsing.
# NOTE: this rebinds `clues`; here it is a mapping whose values carry a
# "caption" field.
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
with open("/shared/g-luo/geoguessr/data/data/guidebook/text/clues/paragraphs.json") as clues_file:
    clues = json.load(clues_file)
cs = [nltk.tokenize.sent_tokenize(c["caption"]) for c in clues.values()]
# clue_embeddings = np.vstack([np.mean(model.encode(c), axis=0) for c in cs])

In [None]:
# Show the per-clue sentence lists (prints the whole list).
cs

In [None]:
import nltk


def get_inv_norm(x):
    """Return the reciprocal L2 norm of every row of ``x`` as a column vector.

    Args:
        x: 2-D array-like; the norm is taken along axis 1.

    Returns:
        Array of shape ``(rows, 1)`` holding ``1 / norm`` per row, and ``0``
        for rows whose norm is zero.
    """
    norm = np.expand_dims(np.linalg.norm(x, axis=1), 1)
    inv = np.zeros_like(norm)
    # Divide only where the norm is non-zero, so zero rows keep the 0 from
    # `inv` and no divide-by-zero warning is emitted.
    np.divide(1, norm, out=inv, where=norm != 0)
    return inv


def get_labels(clue_embeddings, narration_embeddings):
    """Pick, for every narration, the clue with the highest dot product.

    Args:
        clue_embeddings: NumPy array with one row per clue.
        narration_embeddings: NumPy array with one row per narration.

    Returns:
        list with one ``(best_value, best_clue_index)`` tuple per narration,
        as plain Python numbers.
    """
    similarity = torch.from_numpy(clue_embeddings @ narration_embeddings.T)
    best = torch.max(similarity, dim=0)
    return list(zip(best.values.tolist(), best.indices.tolist()))


def main():
    """Label every narration sentence of a split with its most similar clue.

    The split name is read from ``sys.argv[1]``. All caption sentences of the
    split are embedded in batches, every clue paragraph is embedded as the
    mean of its sentence embeddings, both sets are scaled by the inverse row
    norm, and ``get_labels`` picks the best clue per sentence.

    Returns:
        The list produced by ``get_labels`` (one ``(value, clue index)`` tuple
        per sentence). Writing it to disk is currently disabled.
    """
    dataset_type = sys.argv[1]
    model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    with open(f"/shared/g-luo/geoguessr/data/data/{dataset_type}.json") as dataset_file:
        dataset = json.load(dataset_file)
    # Nested comprehension flattens in linear time.
    narrations = [
        sentence
        for ann in dataset
        for sentence in nltk.tokenize.sent_tokenize("".join(ann["nemo_caption"].values()))
    ]
    batch_size = 100

    narration_embeddings = []
    for i in tqdm(range(0, len(narrations), batch_size)):
        with torch.no_grad():
            narration_embeddings.extend(model.encode(narrations[i : i + batch_size]))
    # Make the list of per-sentence vectors an explicit array before the
    # arithmetic below.
    # NOTE(review): assumes model.encode returns one vector per input
    # sentence -- confirm.
    narration_embeddings = np.asarray(narration_embeddings)

    with open(
        "/shared/g-luo/geoguessr/data/data/guidebook/text/clues/paragraphs.json"
    ) as clues_file:
        clues = json.load(clues_file)
    cs = [nltk.tokenize.sent_tokenize(c["caption"]) for c in clues.values()]
    clue_embeddings = np.vstack([np.mean(model.encode(c), axis=0) for c in cs])

    narration_embeddings = narration_embeddings * get_inv_norm(narration_embeddings)
    clue_embeddings = clue_embeddings * get_inv_norm(clue_embeddings)

    sims = get_labels(clue_embeddings, narration_embeddings)
    # json.dump(
    #     sims, open(f"/shared/g-luo/geoguessr/data/data/guidebook/clues/{dataset_type}.json", "w")
    # )
    return sims

In [None]:
# Load a reference JSON file and show its first entry.
ref_data = load_json("/shared/g-luo/geoguessr/models/clip_zs/placing2014_no_indoor.json")
ref_data[0]

---