In [None]:
# Autosave the notebook every 60 seconds.
%autosave 60
# Reload edited modules automatically before each cell executes (mode 2 = all modules).
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
import json
import os
import pickle
from collections import Counter, OrderedDict
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import cv2
import matplotlib as plt
import numpy as np
import pandas as pd
import PIL.Image as pil_img
import seaborn as sns
import sklearn as skl
from IPython.display import Image, display
from matplotlib.patches import Rectangle
from matplotlib_inline.backend_inline import set_matplotlib_formats
from sklearn.cluster import KMeans
from tqdm.contrib import tenumerate, tmap, tzip
from tqdm.contrib.bells import tqdm, trange

from geoscreens.consts import (
    DATASET_PATH,
    DETECTIONS_PATH,
    FRAMES_METADATA_PATH,
    LATEST_DETECTION_MODEL_NAME,
    PROJECT_ROOT,
)
from geoscreens.data import get_all_geoguessr_split_metadata, load_detections
from geoscreens.utils import load_json, save_json

In [None]:
# Pandas display: never truncate column text, cap the visible columns and rows.
pd.set_option("display.max_colwidth", None)
pd.set_option("display.max_columns", 15)
pd.set_option("display.max_rows", 50)
# Suitable default display for floats
pd.options.display.float_format = "{:,.2f}".format
# Default figure size. This key is assigned again in the next cell.
plt.rcParams["figure.figsize"] = (12, 10)

# Optional: request high-resolution ("retina") inline figures.
# NOTE(review): the set_matplotlib_formats() call right below selects "pdf" and
# "png" instead, so the retina setting presumably does not survive -- confirm
# which of the two is intended.
%config InlineBackend.figure_formats = ["retina"]
set_matplotlib_formats("pdf", "png")

In [None]:
# Select PDF and PNG as the inline figure formats.
set_matplotlib_formats("pdf", "png")

# Plot styling defaults, applied as one batch in the original assignment order.
plt.rcParams.update(
    {
        "savefig.dpi": 75,
        "figure.autolayout": False,
        "figure.figsize": (10, 6),
        "axes.labelsize": 18,
        "axes.titlesize": 20,
        "font.size": 16,
        "lines.linewidth": 2.0,
        "lines.markersize": 8,
        "legend.fontsize": 14,
        # Text is typeset through LaTeX; rendering needs a working LaTeX install.
        "text.usetex": True,
    }
)

# plt.rcParams["font.family"] = "serif"
# plt.rcParams["font.serif"] = "cm"
# plt.rcParams["text.latex.preamble"] = "\\usepackage{subdepth}, \\usepackage{type1cm}"

# Load all detections

In [None]:
# Detection model whose outputs are analysed below. The commented line pins a
# specific model name instead of the project-wide latest one.
# model = "gsmoreanch02_012--geoscreens_012-model_faster_rcnn-bb_resnest50_fpn-2b72cbf305"
model = LATEST_DETECTION_MODEL_NAME
# Pickle file used to cache the loaded detections and the derived count vectors.
CACHE_PATH = PROJECT_ROOT / "notebooks" / "cache" / "all_video_dets.pkl"

In [None]:
def get_video_dets(video_ids, df_meta, model_name=None):
    """Load the detected labels of every video that has a metadata row.

    Args:
        video_ids: Iterable of video ids to load detections for.
        df_meta: DataFrame indexed by video id. Its ``split`` value for a video
            is passed to ``load_detections``. Ids missing from the index are
            skipped.
        model_name: Name of the detection model to read. Defaults to the
            notebook-level ``model`` so existing two-argument calls behave as
            before.

    Returns:
        Dict mapping video id to ``{"video_id", "label_ids", "labels"}``, where
        the two label entries are the ``label_ids`` / ``labels`` columns of the
        loaded detections frame converted to plain lists.
    """
    if model_name is None:
        # Fall back to the notebook global only when the caller did not choose.
        model_name = model

    video_dets = {}
    # tenumerate is kept purely for its progress bar; the index is not needed.
    for _, video_id in tenumerate(video_ids):
        if video_id not in df_meta.index:
            continue
        df = load_detections(video_id, df_meta.loc[video_id].split, model=model_name)
        video_dets[video_id] = {
            "video_id": video_id,
            "label_ids": df.label_ids.values.tolist(),
            "labels": df.labels.values.tolist(),
        }
    return video_dets


def count_video_labels(video_dets: dict, ignore=None):
    """Tally how often each category label occurs in every video.

    Args:
        video_dets: Mapping of video id to an info dict whose ``"labels"``
            entry is an iterable of per-entry label collections.
        ignore: Optional iterable of label names whose tally is forced to 0.

    Returns:
        Dict mapping video id to a ``Counter`` that starts with every category
        name from the dataset file at 0 and then accumulates the video's labels.
    """
    categories = load_json(DATASET_PATH)["categories"]
    per_video_counts = {}

    for _, (video_id, vid_info) in tenumerate(video_dets.items()):
        # Seed with all known categories so absent labels still appear as 0.
        tally = Counter(dict.fromkeys((cat["name"] for cat in categories), 0))
        for frame_labels in vid_info["labels"]:
            tally.update(frame_labels)
        for ignored_name in ignore or ():
            tally[ignored_name] = 0
        per_video_counts[video_id] = tally

    return per_video_counts


def get_vectors(label_counts: dict):
    """Turn per-video label tallies into count vectors ordered by category id.

    Args:
        label_counts: Mapping of video id to a mapping of label name -> count.
            Every label name must be a category name in the dataset file.

    Returns:
        List with one ``np.ndarray`` per video (in ``label_counts`` order),
        holding the counts sorted by ascending category id.
    """
    id_by_name = {cat["name"]: cat["id"] for cat in load_json(DATASET_PATH)["categories"]}

    def category_id(name_and_count):
        # Sort key: the category id of a (name, count) pair.
        return id_by_name[name_and_count[0]]

    return [
        np.array([n for _, n in sorted(vid_counts.items(), key=category_id)])
        for vid_counts in label_counts.values()
    ]

In [None]:
# Toggles for this cell. The values reproduce the previous hard-coded behaviour:
# metadata is always rebuilt and the pickle cache is never read.
FORCE_RELOAD_METADATA = True  # False: reuse a `df_meta` already in the kernel
USE_DETECTIONS_CACHE = False  # True: read CACHE_PATH instead of recomputing

if FORCE_RELOAD_METADATA or "df_meta" not in locals():
    df_meta = pd.DataFrame(
        get_all_geoguessr_split_metadata(
            force_include=["nemo_caption", "nemo_caption_entities"]
        ).values()
    ).set_index("id")

# Video ids are recovered from the detection pickle file names of the model.
video_ids = sorted(
    d.stem.replace("df_frame_dets-video_id_", "")
    for d in (DETECTIONS_PATH / model).glob("**/*.pkl")
)

if USE_DETECTIONS_CACHE and CACHE_PATH.exists():
    # NOTE: unpickling can execute arbitrary code; only load a cache file that
    # this notebook wrote itself.
    with open(CACHE_PATH, "rb") as cache_file:
        data = pickle.load(cache_file)
    # The cache holds no `label_counts`, so that name is not defined here.
    video_dets = data["video_dets"]
    count_vectors = data["count_vectors"]
else:
    video_dets = get_video_dets(video_ids, df_meta)
    label_counts = count_video_labels(video_dets, ignore=["video", "url"])
    count_vectors = get_vectors(label_counts)

print("num video_ids: ", len(video_ids))
print("num video_dets: ", len(video_dets))
print("num count_vectors: ", len(count_vectors))

In [None]:
# pickle.dump(
#     {"video_dets": video_dets, "count_vectors": count_vectors},
#     open(CACHE_PATH, "wb"),
# )

---

## Cluster the Detections 

In [None]:
def cluster(count_vectors, num_clusters: int = 4, normalize=True):
    """Run k-means over the per-video label-count vectors.

    Args:
        count_vectors: Sequence of equal-length 1-D count vectors, one per video.
        num_clusters: Number of clusters to fit.
        normalize: If True, scale each vector to unit L2 norm first. All-zero
            vectors are left as zeros, because dividing by their zero norm
            would produce NaNs, which KMeans refuses to fit.

    Returns:
        Dict with:
            "labels": cluster index per input vector (``kmeans.labels_``).
            "clusters": the cluster centres reordered by ascending value of
                their column 1. The indices in "labels" refer to the original
                order in ``kmeans.cluster_centers_``, not to this reordering.
            "kmeans": the fitted ``KMeans`` estimator.
    """
    if normalize:
        unit_vectors = []
        for vector in count_vectors:
            vector = np.asarray(vector, dtype=float)
            norm = np.linalg.norm(vector)
            # Guard the zero-norm case instead of producing a NaN row.
            unit_vectors.append(vector / norm if norm > 0 else vector)
        count_vectors = unit_vectors

    kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(count_vectors)
    clusters = kmeans.cluster_centers_[np.argsort(kmeans.cluster_centers_[:, 1])]
    return {
        "labels": kmeans.labels_,
        "clusters": clusters,
        "kmeans": kmeans,
    }


# Fit k-means for k = 3..10 and print how many videos fall into each cluster.
results = {}
for nc in range(3, 11):
    results[nc] = result = cluster(count_vectors, num_clusters=nc)
    print(nc, Counter(result["kmeans"].labels_))

In [None]:
# Number of videos assigned to each cluster for the 4-cluster fit.
Counter(results[4]["kmeans"].labels_)

In [None]:
import matplotlib.pyplot as plt


def inspect_clustering(kmeans: KMeans):
    """Show one horizontal bar chart per cluster centre, one bar per category.

    Args:
        kmeans: Fitted ``KMeans`` estimator whose ``cluster_centers_`` have one
            component per dataset category.

    Side effects:
        Sets the global ``figure.figsize`` rcParam to (10, 25), displays one
        figure per cluster centre and prints the estimator.
    """
    cats = load_json(DATASET_PATH)["categories"]
    cat_name_to_id = {c["name"]: c["id"] for c in cats}
    # get_vectors() orders vector components by ascending category id, so the
    # bar labels must follow that order rather than the order in the JSON file.
    cat_names = sorted(cat_name_to_id, key=cat_name_to_id.get)

    plt.rcParams["figure.figsize"] = (10, 25)
    for center_idx, center in enumerate(kmeans.cluster_centers_):
        ax = sns.barplot(x=center.tolist(), y=cat_names)
        ax.set(
            title=f"Cluster {center_idx}",
            xlabel="cluster-center value",
            ylabel="category",
        )
        plt.show()
    print(kmeans)


# Plot the cluster centres of the 10-cluster fit.
inspect_clustering(results[10]["kmeans"])

### How Many Videos Are Battle Royale?


In [None]:

# Minimum normalised weight of the battle-royale label for a video to count.
BR_MIN_WEIGHT = 0.05


def _unit_vector(vector):
    """Scale a vector to unit L2 norm; all-zero vectors are returned as zeros."""
    vector = np.asarray(vector, dtype=float)
    norm = np.linalg.norm(vector)
    return vector / norm if norm > 0 else vector


# Rebinds `count_vectors` to its normalised form, as this cell always did.
count_vectors = [_unit_vector(v) for v in count_vectors]
cats = load_json(DATASET_PATH)["categories"]
cat_name_to_id = {c["name"]: c["id"] for c in cats}
# get_vectors() places a category at its rank in ascending-id order, which is
# equal to the raw id only when ids are exactly 0..n-1. Look the position up.
br_index = sorted(cat_name_to_id, key=cat_name_to_id.get).index("br_players_box_white")
sum(1 for v in count_vectors if v[br_index] > BR_MIN_WEIGHT)