In [None]:
%autosave 60
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import json
import os
import pickle
import platform
from collections import OrderedDict
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import PIL.Image as pil_img
import seaborn as sns
from IPython.core.display import HTML, Markdown
from IPython.display import Image, display
from matplotlib_inline.backend_inline import set_matplotlib_formats
from nltk.tokenize.punkt import PunktSentenceTokenizer
from PIL import Image as pil_img
from tqdm.contrib import tenumerate
from tqdm.contrib.bells import tqdm

from geoscreens.data import get_all_geoguessr_split_metadata
from geoscreens.utils import load_json, save_json

In [None]:
# Notebook-wide pandas display limits: no truncation of cell text, at most 15
# columns and 50 rows rendered per frame.
pd.set_option("display.max_colwidth", None)
pd.set_option("display.max_columns", 15)
pd.set_option("display.max_rows", 50)
# Suitable default display for floats
pd.options.display.float_format = "{:,.2f}".format
# Default size (in inches) for every matplotlib figure created below.
plt.rcParams["figure.figsize"] = (12, 10)

# Optional: choose the inline figure format. Only switch graphs to SVG if they
# don't contain a lot of points/lines; use ["retina"] if you don't want SVG.
%config InlineBackend.figure_formats = ["retina"]
# NOTE(review): this selects "pdf" and "png" for the inline backend right after
# the %config line above selected "retina" -- presumably the later call wins;
# confirm which figure format is actually intended.
set_matplotlib_formats("pdf", "png")

---

In [None]:
# Load the pickled in-game frame table. The context manager closes the file
# handle as soon as loading finishes instead of leaving it open.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
with open("/shared/gbiamby/geo/segment/in_game_frames_000.pkl", "rb") as frames_file:
    df_ingame = pickle.load(frames_file)

In [None]:
# Summary per (video_id, img_width, img_height): count of `sec` values and
# number of distinct `round_num` values.
pd.DataFrame(
    df_ingame.groupby(["video_id", "img_width", "img_height"]).agg(
        total_frames=("sec", "count"),
        total_rounds=("round_num", "nunique"),
        # total_frames=("sec", "count"),
    )
)

---

---

## Match up the `ec` captions (the ones Grace generated clue similarities for) with Video Timestamps

In [None]:
def idx_to_keys(caption_mapping):
    """
    Map every key of ``caption_mapping`` to the half-open ``(start, end)``
    character range that its value occupies when all values are concatenated
    in iteration order.
    """
    ranges = {}
    cursor = 0
    for key in caption_mapping:
        stop = cursor + len(caption_mapping[key])
        ranges[key] = (cursor, stop)
        cursor = stop
    return ranges


def intersect(a, b):
    """Return True when the ranges ``a`` and ``b`` overlap by a positive length."""
    latest_start = max(a[0], b[0])
    earliest_end = min(a[1], b[1])
    return earliest_end - latest_start > 0


# def sentence_to_timings(ann):
#     caption = "".join(ann["nemo_caption"].values())
#     mapping = idx_to_keys(ann["nemo_caption"])

#     idx = 0
#     keys = list(mapping.keys())
#     sentences = {}
#     for ent in ann["nemo_caption_entities"]:
#         if ent[2] == "sentence":
#             timings = [k for k, v in mapping.items() if intersect(ent, v)]
#             subcaption = caption[ent[0] : ent[1]]
#             sentences[subcaption] = timings
#     return sentences


# def sentence_to_timings_punkt(ann):
#     caption = "".join(ann["nemo_caption"].values()).strip()
#     mapping = idx_to_keys(ann["nemo_caption"])

#     sentences = {}
#     tokenizer = PunktSentenceTokenizer()
#     subcaptions = list(tokenizer.tokenize(caption))
#     spans = list(tokenizer.span_tokenize(caption))
#     for span, subcaption in zip(spans, subcaptions):
#         timings = [k for k, v in mapping.items() if intersect(span, v)]
#         sentences[subcaption] = timings
#     return sentences


def get_spans(caption: str, sentences: list[str]):
    """
    Locate each sentence inside ``caption`` and return its character span.

    Spans are found by searching ``caption`` from the end of the previous
    sentence, so whitespace that separates sentences in ``caption`` is skipped
    rather than being silently absorbed into the following spans.

    Args:
        caption: The text the sentences were split from.
        sentences: Sentences in the order they occur in ``caption``.

    Returns:
        One half-open ``(start, end)`` tuple per sentence, indexing into
        ``caption``. An empty ``sentences`` list yields an empty list.
    """
    spans = []
    cursor = 0
    for sentence in sentences:
        start = caption.find(sentence, cursor)
        if start < 0:
            # Not a verbatim substring of the remaining text: fall back to
            # placing the sentence directly after the previous one.
            start = cursor
        end = start + len(sentence)
        spans.append((start, end))
        cursor = end
    return spans


def sentence_to_timings_nltk(ann):
    """
    Split a video's caption into sentences and attach timing keys to each.

    ``ann["nemo_caption"]`` is a mapping whose values are text chunks and whose
    keys are converted with ``float()`` (presumably timestamps -- TODO confirm
    the unit). The chunks are concatenated and stripped, split with
    ``nltk.tokenize.sent_tokenize``, and every sentence is matched to the keys
    whose character range intersects the sentence's span.

    Returns:
        An ``OrderedDict`` keyed by sentence text. Each value holds ``times``
        (matching keys as floats), ``idx`` (sentence position), ``span``, and
        ``start`` / ``end`` (min / max of ``times``). Because the result is
        keyed by sentence text, a repeated sentence overwrites the earlier
        entry for the same text.
    """
    nemo_caption = ann["nemo_caption"]
    full_text = "".join(nemo_caption.values()).strip()
    char_ranges = idx_to_keys(nemo_caption)

    subcaptions = list(nltk.tokenize.sent_tokenize(full_text))
    spans = get_spans(full_text, subcaptions)

    timed_sentences = OrderedDict()
    for position, (subcaption, span) in enumerate(zip(subcaptions, spans)):
        hit_times = []
        for time_key, key_range in char_ranges.items():
            if intersect(span, key_range):
                hit_times.append(float(time_key))
        timed_sentences[subcaption] = dict(
            times=hit_times,
            idx=position,
            span=span,
            start=min(hit_times),
            end=max(hit_times),
        )
    return timed_sentences


def get_meta():
    """
    Build a DataFrame of video metadata, one row per entry returned by
    ``get_all_geoguessr_split_metadata`` (with the ``nemo_caption`` and
    ``nemo_caption_entities`` fields force-included). The frame is indexed by
    ``id`` and the same value is mirrored into a ``video_id`` column.
    """
    records = get_all_geoguessr_split_metadata(
        force_include=["nemo_caption", "nemo_caption_entities"]
    ).values()
    meta = pd.DataFrame(records).set_index("id")
    meta["video_id"] = meta.index
    return meta


def load_clue_sims(dataset_type: str):
    """
    Load the clue-similarity narrations of one split and index them.

    Every narration is copied into a new dict with an added ``idx`` field that
    counts sentences from 0 and restarts whenever the narration's ``id``
    differs from the previous narration's ``id``. This assumes all narrations
    of a video are stored contiguously; otherwise ``idx`` restarts mid-video.

    Args:
        dataset_type: Split name used to build the JSON file name.

    Returns:
        ``(clue_sims, clue_sim_lookup)`` where ``clue_sims`` is the list of
        indexed narrations and ``clue_sim_lookup`` maps
        ``(id, text, idx)`` to the narration. Both are empty when the file
        contains no narrations.
    """
    narrations = load_json(
        f"/shared/g-luo/geoguessr/data/data/guidebook/narrations/{dataset_type}.json"
    )["narrations"]

    clue_sims = []
    current_id, idx = None, 0
    for position, narration in enumerate(narrations):
        if position == 0 or narration["id"] != current_id:
            # First narration of a new video: restart the sentence counter.
            current_id, idx = narration["id"], 0
        entry = {"idx": idx, **narration}
        # Re-assign so a pre-existing "idx" field in the narration cannot win.
        entry["idx"] = idx
        clue_sims.append(entry)
        idx += 1

    clue_sim_lookup = {(cs["id"], cs["text"], cs["idx"]): cs for cs in clue_sims}
    return clue_sims, clue_sim_lookup


def get_caption_timings(df_meta: pd.DataFrame):
    """
    Run ``sentence_to_timings_nltk`` on every row of ``df_meta`` (looked up via
    ``df_meta.loc[video_id]``) and return the results keyed by video id.
    """
    video_ids = df_meta.video_id.values
    return {
        video_id: sentence_to_timings_nltk(df_meta.loc[video_id].to_dict())
        for _, video_id in tenumerate(video_ids, desc="get_caption_timings")
    }


def merge_timings_and_clue_sims(
    captions_nltk,
    clue_sims: list[dict],
    clue_sim_lookup: dict[tuple, dict],
    clue_sim_ids: set[str],
):
    """
    Attach the clue type of each matching clue-sim narration to its timed sentence.

    Args:
        captions_nltk: Mapping ``video_id -> {sentence: sentence_info}`` where
            ``sentence_info`` contains at least ``idx``.
        clue_sims: List of narrations. Not read by this function; kept so the
            call signature stays unchanged.
        clue_sim_lookup: Mapping ``(video_id, sentence, idx) -> narration``;
            the narration's ``clue_type`` is copied into the result.
        clue_sim_ids: Video ids to process; all other videos are skipped.

    Returns:
        Mapping ``video_id -> list`` of dicts with ``sentence``, ``clue_type``
        and a deep copy of every ``sentence_info`` field. Videos in
        ``clue_sim_ids`` with no matching sentence map to an empty list.
        Also prints the total number of matches.
    """
    num_matches = 0
    result = {}
    for _, (video_id, sentences) in tenumerate(
        captions_nltk.items(), desc="merge_timings_and_clue_sims"
    ):
        if video_id not in clue_sim_ids:
            continue
        merged = result.setdefault(video_id, [])
        for sentence, sentence_info in sentences.items():
            key = (video_id, sentence, sentence_info["idx"])
            if key not in clue_sim_lookup:
                continue
            num_matches += 1
            merged.append(
                {
                    "sentence": sentence,
                    "clue_type": clue_sim_lookup[key]["clue_type"],
                    **deepcopy(sentence_info),
                }
            )
    print("num_matches: ", num_matches)
    return result


def sort_captions(captions):
    """Replace each video's caption list in ``captions`` with a copy sorted by ``idx``."""
    for video_id in list(captions):
        captions[video_id] = sorted(captions[video_id], key=lambda cap: cap["idx"])

#### Combine ASR/Clue Sims With ASR Timestamps

In [None]:
# Disabled by default (`if False`): merges the NLTK sentence timings with the
# clue-sim narrations of every split.
if False:
    # Reuse the metadata / timings if they already exist in the kernel
    # namespace; otherwise compute them once.
    if "df_meta_original" not in locals() or "captions_nltk_original" not in locals():
        df_meta_original = get_meta()
        captions_nltk_original = get_caption_timings(df_meta_original)
    # Work on copies so the cached originals are never mutated.
    df_meta = deepcopy(df_meta_original)
    captions_nltk = deepcopy(captions_nltk_original)
    print("total videos: ", len(captions_nltk))

    for dataset_type in ["val", "test", "train"]:
        print("\n", "=" * 120, f"\n{dataset_type}")
        clue_sims, clue_sim_lookup = load_clue_sims(dataset_type)
        print(f"Total clue_sims: {len(clue_sims)}, clue_sims_lookup: {len(clue_sim_lookup)}")
        clue_sim_ids = {c["id"] for c in clue_sims}
        print(
            f"Total captions ({dataset_type}): ",
            sum([len(t) for video_id, t in captions_nltk.items() if video_id in clue_sim_ids]),
        )
        result = merge_timings_and_clue_sims(
            captions_nltk, clue_sims, clue_sim_lookup, clue_sim_ids
        )
        print(
            f"Merged sims + timings -- videos: {len(result)}, sentences: {sum([len(s) for s in result.values()])}"
        )
        save_path = Path(f"/shared/gbiamby/geo/captions/{dataset_type}_captions_with_timings.json")
        # The save is commented out, so enabling this cell writes nothing.
        # save_json(save_path, result)

## Visual Inspection of ASR Sentences

### Inspect Sentence Time Windows

In [None]:
# Load the merged captions of one split and order each video's captions by idx.
split = "train"
captions = load_json(f"/shared/gbiamby/geo/captions/{split}_captions_with_timings.json")
sort_captions(captions)

In [None]:
# The keys of paragraphs.json serve as clue-cluster names; below, a caption's
# clue_type[1] is used as a position into this list.
clues_paragraphs = load_json(
    "/shared/g-luo/geoguessr/data/data/guidebook/text/clues/paragraphs.json"
)
clue_clusters = list(clues_paragraphs.keys())

In [None]:
flattened = []
# One list per video, keeping only captions with clue_sim >= 0.4 and a
# strictly positive duration. Fields of the caption itself are merged last.
asr_sentences = [
    [
        {
            "video_id": vid,
            "clue_cluster": clue_clusters[c["clue_type"][1]],
            "clue_sim": c["clue_type"][0],
            "dur": c["end"] - c["start"],
            **c,
        }
        for c in caps
        if c["clue_type"][0] >= 0.4 and c["end"] - c["start"] > 0
    ]
    for vid, caps in captions.items()
]
# Re-index: after filtering, number each video's remaining sentences from 0.
for sentences in asr_sentences:
    for i in range(len(sentences)):
        sentences[i]["idx"] = i

In [None]:
# Flatten the per-video lists into a single list. The list is rebuilt rather
# than extended so that re-running this cell does not append a second copy of
# every sentence.
flattened = [sentence for sentences in asr_sentences for sentence in sentences]
len(flattened)

In [None]:
# One row per retained ASR sentence, ordered by video and sentence position.
df_flat = pd.DataFrame(flattened).sort_values(["video_id", "idx"])
# Bucket clue_sim into 10 bins; `bins` holds the bin edges.
cuts, bins = pd.cut(df_flat.clue_sim, bins=10, retbins=True)
df_flat["clue_bin"] = cuts
# Index by (video_id, idx) while keeping both as columns; verify_integrity
# raises if any (video_id, idx) pair is duplicated.
df_flat = df_flat.set_index(["video_id", "idx"], drop=False, verify_integrity=True)
df_flat.index = df_flat.index.rename(["_video_id", "_idx"])
display(df_flat)

In [None]:
# Bucket caption durations into 100 bins.
# NOTE: this rebinds `bins`, replacing the clue_sim bin edges assigned earlier.
df_flat["dur_bin"], bins = pd.cut(df_flat.dur, bins=100, retbins=True)
# df_flat.dur.plot.bar()

In [None]:
# df_meta = get_meta()

In [None]:
# NOTE(review): this bare `bins` expression is not the last statement of the
# cell, so it displays nothing.
bins
print(df_flat.shape)
# Histogram of `dur`, restricted to rows with dur <= 100.
sns.histplot(df_flat[df_flat.dur <= 100].dur, bins=100)

In [None]:
# df_flat.loc["zyZRvZohmro"]
# df_flat[df_flat.dur > 50]

In [None]:
# NOTE(review): `df_meta` is only assigned inside the disabled `if False:` cell
# and in a commented-out `get_meta()` cell, so on a fresh kernel this line
# presumably raises NameError -- confirm where `df_meta` should come from.
df_meta.loc["-OYDoUERUqA"]

In [None]:
# Show 50 randomly sampled sentences with the row/column display limits lifted.
df_tmp = df_flat.sample(50)
print(df_tmp.shape)
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    # display(df_flat.loc["-OYDoUERUqA"].head(100))
    display(df_tmp)

In [None]:
# Sentence counts per clue cluster, first as a table and then as a bar chart.
# `df_tmp` is only another name for `df_flat` here (no copy is made).
df_tmp = df_flat
display(pd.DataFrame(df_tmp.groupby("clue_cluster").agg(total=("idx", "count"))).reset_index())
print(f"Total clues: {len(df_tmp):,}")
df_tmp.groupby("clue_cluster").agg(total=("idx", "count")).plot.bar(
    title="Counts by Clue Cluster - clue_sim:>=0.4"
)

In [None]:
# Sentence counts per clue_sim bin, as a table and as a bar chart.
df_tmp = pd.DataFrame(df_flat.groupby("clue_bin").agg(total=("idx", "count"))).reset_index()
display(df_tmp)
df_tmp.plot.bar(x="clue_bin", title="ASR Sentence/Clue Similarity Distribution (clue_sim>=0.4)")

### Show Some ASR Sentence Examples (Random Sample Each Time the Cell is Run)

In [None]:
# Ten random rows; no seed is passed, so the selection changes on every run.
display(df_flat.sample(10))

In [None]:
# Read the raw lines of one video's English .vtt subtitle file and display them.
file = f"/shared/g-luo/geoguessr/videos/bRSdHaz57Qk.en.vtt"
# NOTE(review): `current` is not read anywhere in the visible cells.
current = ""
with open(file) as f:
    text = f.readlines()

text


---

---

## Generate Samples to Fine-tune CLIP - Simplest Approach: Map ASR Time Spans to Image Timestamps

Each image should be paired with either zero or one sentence. One sentence can be paired with many images.

In [None]:
# Peek at the first five captions of one video.
captions["-13sRRWmIxY"][:5]

In [None]:
# (Re)load the clue paragraphs and print their container type, size and keys.
clues_paragraphs = load_json(
    "/shared/g-luo/geoguessr/data/data/guidebook/text/clues/paragraphs.json"
)
print(type(clues_paragraphs))
print(len(clues_paragraphs))
print(clues_paragraphs.keys())

In [None]:
def sort_captions(captions):
    """
    Order each video's captions by ``idx``; ``captions`` is updated in place
    and nothing is returned.

    NOTE: a function with this name is also defined earlier in the notebook;
    once this cell runs, this definition is the one in effect.
    """
    for video_id in tuple(captions.keys()):
        ordered = sorted(captions[video_id], key=lambda cap: cap["idx"])
        captions[video_id] = ordered


def filter_captions(
    captions: dict[str, list[dict]],
    min_clue_sim: float = 0.4,
    min_sentence_len: int = 5,
):
    """
    Keep only the captions that are usable as clue sentences.

    A caption is dropped when its sentence contains "welcome back"
    (case-insensitive), is shorter than ``min_sentence_len`` characters, has
    ``end - start <= 0``, or has a clue similarity (``clue_type[0]``) below
    ``min_clue_sim``.

    Args:
        captions: Mapping ``video_id -> list`` of caption dicts with the keys
            ``sentence``, ``clue_type``, ``start``, ``end`` and ``idx``.
        min_clue_sim: Minimum clue similarity (inclusive) required to keep a
            caption. Defaults to the previously hard-coded 0.4.
        min_sentence_len: Minimum sentence length in characters. Defaults to
            the previously hard-coded 5.

    Returns:
        Mapping with the same video ids; each value is the list of kept
        captions sorted by ``idx`` (possibly empty). The caption dicts are the
        original objects, not copies.
    """
    result = {}
    for video_id, caps in captions.items():
        kept = []
        for c in caps:
            sentence = c["sentence"]
            if "welcome back" in sentence.casefold() or len(sentence) < min_sentence_len:
                continue
            if (c["end"] - c["start"]) <= 0:
                continue
            if c["clue_type"][0] >= min_clue_sim:
                kept.append(c)
        result[video_id] = sorted(kept, key=lambda x: x["idx"])
    return result


def ensure_no_time_overlaps(captions):
    """
    Report adjacent captions whose time ranges overlap.

    For every video, each caption is compared with the one that follows it in
    the list; a pair is reported when the first caption's ``end`` is greater
    than the next caption's ``start``. Nothing is modified or raised -- the
    number of overlaps is printed and the offending pairs are returned as
    ``(video_id, caption, next_caption)`` tuples.
    """
    overlaps = [
        (video_id, current, following)
        for video_id, caps in captions.items()
        for current, following in zip(caps, caps[1:])
        if current["end"] > following["start"]
    ]
    print("Num overlaps: ", len(overlaps))
    return overlaps


def to_fixed_time_windows(captions):
    """
    Placeholder, not implemented: the body is empty, so ``captions`` is
    ignored and None is returned.
    """
    pass


def get_clip_samples_simple(df_ingame: pd.DataFrame, captions: dict[str, list[dict[str, Any]]]):
    """
    Pair every caption with frames of its video.

    Positive samples:
        For each caption, the frames whose ``sec`` lies in the caption's
        ``[start, end)`` time range.
    Negative samples:
        For each caption that has positive frames, frames sampled at random
        (no seed, without replacement) from the same video whose ``round_num``
        differs from that of the first positive frame. As many frames as there
        are positives are drawn, or all available ones when fewer exist; when
        none exist, no negative entry is emitted for that caption.

    Args:
        df_ingame: Frame table whose index contains the video ids and which
            has the columns listed in ``frame_columns``.
        captions: Mapping ``video_id -> list`` of caption dicts with ``start``
            and ``end``.

    Returns:
        ``(samples, no_frames)``: ``samples`` is a list of dicts with
        ``video_id``, ``caption_info``, ``frames`` (list of row dicts) and
        ``gt`` (True for positives, False for negatives); ``no_frames`` is the
        set of video ids that are missing from ``df_ingame`` or produced no
        positive sample.
    """
    # fmt: off
    frame_columns = [
        "round_num", "frame_idx", "img_width", "img_height", "sec", "time",
        "labels", "scores", "bboxes", "split", "file_path",
    ]
    # fmt: on
    samples = []
    no_frames = set()
    for video_id, caps in tqdm(
        captions.items(), desc="get_clip_samples_simple", total=len(captions)
    ):
        if video_id not in df_ingame.index:
            no_frames.add(video_id)
            continue
        # A list lookup always yields a DataFrame, even when the video has a
        # single frame row (a scalar lookup would return a Series in that case).
        df = df_ingame.loc[[video_id]]
        video_has_pos_samples = False
        for c in caps:
            pos_samples = df[(c["start"] <= df.sec) & (df.sec < c["end"])]
            if len(pos_samples) == 0:
                continue
            # positive sample(s):
            video_has_pos_samples = True
            samples.append(
                {
                    "video_id": video_id,
                    "caption_info": c,
                    "frames": pos_samples[frame_columns].to_dict("records"),
                    "gt": True,
                }
            )
            # negative sample(s): frames from a different round of this video.
            round_num = pos_samples.iloc[0]["round_num"]
            neg_pool = df[df.round_num != round_num]
            if len(neg_pool) == 0:
                continue
            # Cap the draw at the pool size: sampling more rows than exist
            # (without replacement) raises ValueError.
            neg_samples = neg_pool.sample(min(len(pos_samples), len(neg_pool)))
            samples.append(
                {
                    "video_id": video_id,
                    "caption_info": c,
                    "frames": neg_samples[frame_columns].to_dict("records"),
                    "gt": False,
                }
            )
        if not video_has_pos_samples:
            no_frames.add(video_id)
    return samples, no_frames


def get_clip_samples_fixed_window(
    df_ingame: pd.DataFrame,
    captions: dict[str, list[dict[str, Any]]],
    time_window: float = 5.0,
    lookahead: float = 1.0,
):
    """
    Pair every caption with the frames in a fixed window around its start time.

    Positive samples:
        For each caption, the frames whose ``sec`` lies in
        ``[start - time_window, start + lookahead)``. The caption's ``end`` is
        not used for selection; it is only recorded in ``anchor``.
    Negative samples:
        None are produced by this method.

    Args:
        df_ingame: Frame table whose index contains the video ids and which
            has the columns listed in ``frame_columns``.
        captions: Mapping ``video_id -> list`` of caption dicts with ``start``
            and ``end``.
        time_window: How far before the caption start the window begins.
        lookahead: How far after the caption start the window ends
            (exclusive). Defaults to the previously hard-coded 1.

    Returns:
        ``(samples, no_frames)``: ``samples`` is a list of dicts with
        ``video_id``, ``caption_info``, ``frames`` (list of row dicts),
        ``anchor`` (``(start, end)`` of the caption) and ``gt`` (always True);
        ``no_frames`` is the set of video ids that are missing from
        ``df_ingame`` or produced no sample.
    """
    # fmt: off
    frame_columns = [
        "round_num", "frame_idx", "img_width", "img_height", "sec", "time",
        "labels", "scores", "bboxes", "split", "file_path",
    ]
    # fmt: on
    samples = []
    no_frames = set()
    for video_id, caps in tqdm(
        captions.items(), desc="get_clip_samples_fixed_window", total=len(captions)
    ):
        if video_id not in df_ingame.index:
            no_frames.add(video_id)
            continue
        # A list lookup always yields a DataFrame, even when the video has a
        # single frame row (a scalar lookup would return a Series in that case).
        df = df_ingame.loc[[video_id]]
        video_has_pos_samples = False
        for c in caps:
            anchor = (c["start"], c["end"])
            in_window = (anchor[0] - time_window <= df.sec) & (df.sec < anchor[0] + lookahead)
            pos_samples = df[in_window]
            if len(pos_samples) == 0:
                continue
            # positive sample(s):
            samples.append(
                {
                    "video_id": video_id,
                    "caption_info": c,
                    "frames": pos_samples[frame_columns].to_dict("records"),
                    "anchor": anchor,
                    "gt": True,
                }
            )
            video_has_pos_samples = True
        if not video_has_pos_samples:
            no_frames.add(video_id)
    return samples, no_frames


def get_clip_samples(
    df_ingame: pd.DataFrame, captions: dict[str, list[dict[str, Any]]], method="simple"
):
    """
    Dispatch to the sampling strategy named by ``method``.

    ``"simple"`` calls ``get_clip_samples_simple`` and ``"fixed_window"`` calls
    ``get_clip_samples_fixed_window`` (with its default window); any other
    value raises ``NotImplementedError``.
    """
    if method not in ("simple", "fixed_window"):
        raise NotImplementedError()
    if method == "simple":
        return get_clip_samples_simple(df_ingame, captions)
    return get_clip_samples_fixed_window(df_ingame, captions)

#### Old Code to Create Captions for a Single Split

In [None]:
# if False:
#     overlaps = ensure_no_time_overlaps(captions_nltk)
#     captions_filtered = filter_captions(captions_nltk)
#     print(f"Num videos: {len(captions_filtered):,}")
#     print(f"Num captions: {sum([len(caps) for caps in captions.values()]):,}")
#     print(f"Num filtered captions: {sum([len(caps) for caps in captions_filtered.values()]):,}")
#     pickle.dump(
#         captions_filtered,
#         open(f"/shared/gbiamby/geo/captions/{split}_captions_with_timings_filtered.json", "wb"),
#     )

In [None]:
# captions_filtered = pickle.load(open(f"/shared/gbiamby/geo/captions/{split}_captions_with_timings_filtered.json", "rb"))

In [None]:
# pos, no_frames = get_clip_samples(df_ingame, captions_filtered)

In [None]:
# print(f"Num captions w/ samples: {len(pos):,}")
# print(f"Num videos w/o any sampled frames: {len(no_frames):,}")
# print(f"Num videos w/ samples: {len({s['video_id'] for s in pos}):,}")
# print(f"Num samples: {sum([len(s['frames']) for s in pos]):,}")

In [None]:
# pos[0]

#### Generate CLIP Samples for All Splits

In [None]:
def flatten_samples(samples: list[dict[str, Any]]):
    """
    Expand caption-level samples into one flat record per frame.

    Every output record merges, in this order, the sample's own fields (all
    except ``caption_info`` and ``frames``), the fields of ``caption_info`` and
    the fields of one frame; later fields win on key clashes. The record is
    deep-copied and the ``bboxes``, ``scores`` and ``labels`` keys are removed.
    The sample and record counts are printed before returning the records.
    """
    dropped_keys = ("bboxes", "scores", "labels")
    records = []
    for sample in samples:
        shared = {
            key: value
            for key, value in sample.items()
            if key != "caption_info" and key != "frames"
        }
        shared.update(sample["caption_info"])
        for frame in sample["frames"]:
            record = deepcopy({**shared, **frame})
            for key in dropped_keys:
                record.pop(key, None)
            records.append(record)
    print(f"len(samples): {len(samples):,}")
    print(f"len(flattened): {len(records):,}")
    return records


def generate_samples(split, method: str = "fixed_window"):
    """
    Build the CLIP training samples of one split and write them to disk.

    Steps: load the in-game frame table and the split's filtered captions,
    pair captions with frames via ``get_clip_samples``, flatten to one record
    per frame, then save
      * ``clip_samples_{method}_{split}.json``         -- all flat records,
      * ``clip_samples_{method}_{split}_full.csv``     -- the same as CSV,
      * ``clip_samples_{method}_{split}_openclip.csv`` -- ``file_path`` and
        ``sentence`` of the records whose ``gt`` is truthy.

    The filtered-captions input is a pickle despite its ``.json`` suffix. Per
    the code formerly kept here as comments, it is produced by loading
    ``{split}_captions_with_timings.json``, running ``sort_captions`` and
    ``filter_captions`` on it, and pickling the result.

    Args:
        split: Split name used in the input and output file names.
        method: Sampling method passed to ``get_clip_samples``. Defaults to
            the previously hard-coded "fixed_window".

    Returns:
        The two-column DataFrame written to the ``_openclip.csv`` file.
    """
    print("\n", "=" * 120)
    # Context managers close both pickle files right after loading.
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open("/shared/gbiamby/geo/segment/in_game_frames_000.pkl", "rb") as frames_file:
        df_ingame = pickle.load(frames_file)
    with open(
        f"/shared/gbiamby/geo/captions/{split}_captions_with_timings_filtered.json", "rb"
    ) as captions_file:
        captions_filtered = pickle.load(captions_file)

    clip_samples, no_frames = get_clip_samples(df_ingame, captions_filtered, method=method)
    print(f"Num captions w/ samples: {len(clip_samples):,}")
    print(f"Num videos w/o any sampled frames: {len(no_frames):,}")
    print(f"Num videos w/ samples: {len({s['video_id'] for s in clip_samples}):,}")
    print(f"Num samples: {sum([len(s['frames']) for s in clip_samples]):,}")
    flattened = flatten_samples(clip_samples)
    save_json(f"/shared/gbiamby/geo/captions/clip_samples_{method}_{split}.json", flattened)
    df = pd.DataFrame(flattened)
    print(f"Total frames: {len(df)}, unique frames: {df.file_path.nunique()}")
    df.to_csv(f"/shared/gbiamby/geo/captions/clip_samples_{method}_{split}_full.csv", index=False)
    # Keep only positive pairs, in the two-column layout used for the openclip CSV.
    df = df[df["gt"]][["file_path", "sentence"]]
    df.to_csv(
        f"/shared/gbiamby/geo/captions/clip_samples_{method}_{split}_openclip.csv", index=False
    )
    print("Done!")
    return df


# Generate and save the CLIP samples of every split. This rebinds the
# notebook-level names `split` and `df` on each iteration.
for split in ["val", "test", "train"]:
    # for split in ["val"]:
    df = generate_samples(split)

In [None]:
# Peek at the captions of the video at position `idx`.
# NOTE(review): `captions_filtered` is only assigned inside generate_samples
# and in commented-out cells, so at notebook level this presumably raises
# NameError on a fresh kernel -- confirm where it should come from.
idx = 1
print(list(captions_filtered.items())[idx][0])
list(captions_filtered.items())[idx][1][5:120]

In [None]:
# Print summary counts of the saved "simple" samples of every split.
# NOTE(review): generate_samples as written saves flattened per-frame records,
# which do not carry a "frames" key, and defaults to the "fixed_window" file
# names -- presumably these "simple" files come from an earlier version that
# saved caption-level samples; confirm, otherwise s['frames'] raises KeyError.
for split in ["val", "test", "train"]:
    clip_samples = load_json(f"/shared/gbiamby/geo/captions/clip_samples_simple_{split}.json")
    print("")
    print("=" * 100)
    print(f"Num captions w/ samples: {len(clip_samples):,}")
    print(f"Num videos w/ samples: {len({s['video_id'] for s in clip_samples}):,}")
    print(f"Num samples: {sum([len(s['frames']) for s in clip_samples]):,}")

In [None]:
# Load the flattened fixed-window training samples and attach to every row the
# number of rows sharing its file_path (i.e. how often that frame is reused).
df_clip = pd.read_csv(f"/shared/gbiamby/geo/captions/clip_samples_fixed_window_train_full.csv")
print("total samples: ", len(df_clip), "unique frames: ", df_clip.file_path.nunique())
df_clip["file_count"] = df_clip.join(
    pd.DataFrame(df_clip.groupby("file_path").agg(file_count=("file_path", "count"))),
    on="file_path",
)[["file_count"]]

In [None]:
# Show the first 100 rows with the row/column display limits lifted.
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(df_clip.head(100))

---

## Visualize Some of the CLIP +/- Samples

In [None]:
def plot_grid(images: np.ndarray, max_rows=4, max_cols=2, figsize=(40, 40)):
    """
    Show up to ``max_rows * max_cols`` images in a grid, filled row by row.

    Args:
        images: Sliceable sequence of images accepted by ``Axes.imshow``; only
            the first ``max_rows * max_cols`` are drawn.
        max_rows: Number of grid rows.
        max_cols: Number of grid columns.
        figsize: Figure size in inches. Defaults to the previously hard-coded
            ``(40, 40)``.
    """
    # squeeze=False keeps `axes` two-dimensional even for a single row or
    # column, so axes[row, col] is always valid.
    fig, axes = plt.subplots(nrows=max_rows, ncols=max_cols, figsize=figsize, squeeze=False)
    # Hide the frame/ticks of every cell, including cells that stay empty.
    for ax in axes.flat:
        ax.axis("off")
    for idx, image in enumerate(images[: max_rows * max_cols]):
        row, col = divmod(idx, max_cols)
        axes[row, col].imshow(image, cmap="gray", aspect="auto")
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    plt.show()


def show_asr_sentence_and_video_frame_samples(video_info):
    """
    Print a sample's caption details and plot its frames.

    ``video_info`` is one caption-level sample with the keys ``video_id``,
    ``caption_info`` (containing ``sentence`` and ``clue_type``), ``frames``
    (dicts with ``file_path``) and ``gt``. The cluster name is looked up in the
    notebook-level ``clues_paragraphs`` by position ``clue_type[1]``. The
    frames are opened with PIL and handed to ``plot_grid``.
    """
    video_id = video_info["video_id"]
    print("video_id: ", video_id)
    caption_info = video_info["caption_info"]
    print("Sentence: ", caption_info["sentence"])
    clue_type = caption_info["clue_type"]
    print(clue_type)
    cluster_names = list(clues_paragraphs.keys())
    print("Clue Cluster: ", cluster_names[clue_type[1]])
    frames = video_info["frames"]
    print("Num frames: ", len(frames))
    sample_kind = "Positive" if video_info["gt"] else "Negative"
    print("Sample types: ", sample_kind)
    plot_grid([pil_img.open(frame["file_path"]) for frame in frames])

In [None]:
# Reload the clue paragraphs; show_asr_sentence_and_video_frame_samples reads
# this notebook-level name to turn a cluster index into a cluster name.
clues_paragraphs = load_json(
    "/shared/g-luo/geoguessr/data/data/guidebook/text/clues/paragraphs.json"
)
# show_video_info(np.random.choice(clip_samples, 1)[0])

In [None]:
# Show the first sample belonging to one specific video; the [0] raises
# IndexError when `clip_samples` holds no sample for that video.
show_asr_sentence_and_video_frame_samples(
    [cs for cs in clip_samples if cs["video_id"] == "HEPyfvK-Vhg"][0]
)
# np.random.choice([cs for cs in clip_samples if cs["video_id"]=="HEPyfvK-Vhg"], 1)[0]

In [None]:
# Show one randomly chosen sample; no seed is set in this cell, so the choice
# differs between runs.
show_asr_sentence_and_video_frame_samples(np.random.choice(clip_samples, 1)[0])

In [None]:
# Another random draw (same call as the previous cell).
show_asr_sentence_and_video_frame_samples(np.random.choice(clip_samples, 1)[0])

In [None]:
# Another random draw (same call as the previous cell).
show_asr_sentence_and_video_frame_samples(np.random.choice(clip_samples, 1)[0])

In [None]:
# Another random draw (same call as the previous cell).
show_asr_sentence_and_video_frame_samples(np.random.choice(clip_samples, 1)[0])

---

---

## Debugging Why Sentence counts aren't matching up with the clue-sim counts

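Update: **__Solved__**. The issue was instantiating `PunktSentenceTokenizer` explicitly instead of calling `nltk.tokenize.sent_tokenize` (which also uses Punkt). The two methods give slightly different sentence splits, so the sentence counts did not line up with the clue-sim counts.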

In [None]:
# if "captions_guide" not in locals() or "captions_guide_lookup" not in locals():
#     guide = load_json(f"/shared/g-luo/geoguessr/data/data/guidebook/narrations/train.json")
#     captions_guide = {}
#     captions_guide_lookup = {}
#     for g in tqdm(guide["narrations"]):
#         if g["id"] not in captions_guide:
#             captions_guide[g["id"]] = []
#         captions_guide[g["id"]].append(deepcopy(g))
#         captions_guide_lookup[(g["id"], g["text"])] = g["clue_type"]
# # guide_lookup = {n["text"]: n for n in guide["narrations"]}
# # type(guide["narrations"]), guide["narrations"][0]

In [None]:
# print("cg video_ids: ", len(captions_guide))
# print("cg[K4GXuDACK40] sentences: ", len(captions_guide["K4GXuDACK40"]))
# print("cg total captions: ", len(captions_guide_lookup))
# print(captions_guide["K4GXuDACK40"][-10:])

In [None]:
# print(len(captions_old["K4GXuDACK40"]))
# print(len(captions_new["K4GXuDACK40"]))
# print(len(captions_nltk["K4GXuDACK40"]))

# list(captions_old["K4GXuDACK40"].items())[:10]
# list(captions_nltk["K4GXuDACK40"].items())[:10]

In [None]:
# if "captions_old" not in locals():
#     captions_old = {}
#     for t, video_id in tqdm(enumerate(df_meta.video_id.values)):
#         captions_old[video_id] = sentence_to_timings(df_meta.loc[video_id].to_dict())

# print(len(captions_old))
# print(
#     "total captions ",
#     sum([len(t) for video_id, t in captions_old.items() if video_id in captions_guide]),
# )

In [None]:
# if "captions_new" not in locals():
#     captions_new = {}
#     for video_id in tqdm(df_meta.video_id.values):
#         captions_new[video_id] = sentence_to_timings_punkt(df_meta.loc[video_id].to_dict())

# print(len(captions_new))
# sum([len(t) for video_id, t in captions_new.items() if video_id in captions_guide])
# # list(captions_new.items())[-10:]

In [None]:
# from geoscreens.consts import FRAMES_METADATA_PATH
# from geoscreens.utils import load_json
# fm = load_json(FRAMES_METADATA_PATH)