In [None]:
# Autosave the notebook every 60 seconds.
%autosave 60
# Reload imported modules before each cell executes, so edits to project code
# (e.g. `geoscreens.*`) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
import json
import logging
import os
import sys
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Tuple, Union, cast

import cv2
import decord
import matplotlib as plt
import numpy as np
import pandas as pd
import PIL
import PIL.Image as pil_img
import seaborn as sns
import sklearn as skl
from icevision import models, tfms
from icevision.all import *
from icevision.data import Dataset, DataSplitter, RandomSplitter
from icevision.parsers.coco_parser import COCOBBoxParser
from IPython.display import Image, display
from matplotlib.patches import Rectangle
from matplotlib_inline.backend_inline import set_matplotlib_formats
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from tqdm.contrib import tenumerate, tmap, tzip
from tqdm.contrib.bells import tqdm, trange

from geoscreens.geo_data import GeoScreensDataModule
from geoscreens.models import get_model, load_model_from_path
from geoscreens.modules import LightModelTorch, build_module
from geoscreens.utils import batchify, load_json, timeit_context

In [None]:
pd.set_option("display.max_colwidth", None)
pd.set_option("display.max_columns", 15)
pd.set_option("display.max_rows", 50)
# Suitable default display for floats
pd.options.display.float_format = "{:,.2f}".format
plt.rcParams["figure.figsize"] = (12, 10)

# This one is optional -- change graphs to SVG only use if you don't have a
# lot of points/lines in your graphs. Can also just use ['retina'] if you
# don't want SVG.
%config InlineBackend.figure_formats = ["retina"]
set_matplotlib_formats("pdf", "png")

In [None]:
# Publication-style matplotlib defaults.
# `set_matplotlib_formats` here is the matplotlib_inline version imported in the imports
# cell. `IPython.display.set_matplotlib_formats` is deprecated (IPython >= 7.23), and
# re-importing it would shadow the supported function for the rest of the notebook.
set_matplotlib_formats("pdf", "png")

plt.rcParams.update(
    {
        "savefig.dpi": 75,
        "figure.autolayout": False,
        "figure.figsize": (10, 6),
        "axes.labelsize": 18,
        "axes.titlesize": 20,
        "font.size": 16,
        "lines.linewidth": 2.0,
        "lines.markersize": 8,
        "legend.fontsize": 14,
        # Renders all text through LaTeX; requires a working LaTeX installation.
        "text.usetex": True,
        "font.family": "serif",
        # "cm" is not a family name matplotlib/usetex recognises; the Computer Modern
        # serif face is spelled out in full.
        "font.serif": ["Computer Modern Roman"],
    }
)
# plt.rcParams["text.latex.preamble"] = "\\usepackage{subdepth}, \\usepackage{type1cm}"

## Load Data and Build Model

In [None]:
seed_everything(42, workers=True)
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# MODEL_PATH = "/shared/gbiamby/geo/models/geoscreens_009-resnest50_fpn-with_augs/"
MODEL_PATH = "/home/gbiamby/proj/geoscreens/tools/output/keep/gs_012_extra_augs_more_epochs--geoscreens_012-model_faster_rcnn-bb_resnest50_fpn-36e514692a/"
config, module, model, light_model = load_model_from_path(MODEL_PATH, device=DEVICE)
# Inference only: put both wrappers in eval mode.
model, light_model = model.eval(), light_model.eval()
geoscreens_data = GeoScreensDataModule(config, module)

## Show Some Training Samples

In [None]:
# Training split of the data module (used below to preview augmented samples).
train_ds = geoscreens_data.train_ds

In [None]:
# Show an element of the train_ds with augmentation transformations applied
# Index the same training example several times; each access goes through the training
# transforms, so (presumably, with random augmentations) each view looks different.
sample_idx, n_views = 10, 3
samples = [train_ds[sample_idx] for _ in range(n_views)]
show_samples(samples, ncols=n_views)

### Show some validation set samples

In [None]:
# Take the first batch of the validation loader and plot it in a 4-column grid.
val_batch = first(geoscreens_data.val_dataloader())
module.show_batch(val_batch, ncols=4)

### Show some predictions

In [None]:
num_samples = 10
size = 30
# Figure height scales with the number of samples shown.
pred_figsize = (size, (size * num_samples) / 2)
module.show_results(
    light_model,
    geoscreens_data.valid_ds,
    num_samples=num_samples,
    detection_threshold=0.5,
    device=DEVICE,
    figsize=pred_figsize,
)

---

# Prediction Testing Dataloader and Batching

In [None]:
from icevision.core import ClassMap
from icevision.core.record import BaseRecord
from icevision.core.record_components import ClassMapRecordComponent, ImageRecordComponent
from icevision.tfms import Transform
from PIL import Image


class GeoscreensInferenceDataset(object):
    """
    Only usable for inference.

    Provides a dataset over a folder with video frames in form::

        <video_id_1>/
            frame_<frame_idx>-<seconds>s.jpg
        <video_id_2>/
            frame_<frame_idx>-<seconds>s.jpg

    If no video_id is specified, the dataset loops over all <video_id>
    subfolders and includes all frames in each.

    Args:
        frames_path: Root folder containing one sub-folder of frames per video.
        class_map: icevision ClassMap attached to every record (needed by
            icevision's `convert_raw_prediction`).
        video_ids: A single video id, a list of video ids, or None for every
            sub-folder of `frames_path`.
        tfm: Optional icevision transform applied to each record on access.
    """

    def __init__(
        self,
        frames_path: Union[str, Path],
        class_map: ClassMap,
        video_ids: Optional[Union[str, List[str]]] = None,
        tfm: Optional[Transform] = None,
    ):
        # Imported locally: the notebook later rebinds the global name `tasks` (to a
        # Label Studio JSON export), which would break `tasks.detection` on re-runs.
        from icevision.core import tasks as icevision_tasks

        self.frames_path = Path(frames_path).resolve()
        assert self.frames_path.exists(), f"Frames path not found: {self.frames_path}"
        assert self.frames_path.is_dir(), f"Frames path is not a directory: {self.frames_path}"
        if isinstance(video_ids, str):
            video_ids = [video_ids] if video_ids else []
        elif video_ids is None:
            # Documented behaviour: use every <video_id> sub-folder.
            video_ids = sorted(p.name for p in self.frames_path.iterdir() if p.is_dir())
        self.tfm = tfm
        self.class_map = class_map
        self.frames: List[dict] = []
        record_id = 0
        for video_id in video_ids:
            frames = sorted((self.frames_path / video_id).glob("*.jpg"))
            print(f"video_id: {video_id}, num frames found: {len(frames)}")
            for f in frames:
                record = BaseRecord((ImageRecordComponent(),))
                record.set_record_id(record_id)
                # TODO, HACK: adding class map because of `convert_raw_prediction`
                record.add_component(ClassMapRecordComponent(task=icevision_tasks.detection))
                if class_map is not None:
                    record.detection.set_class_map(class_map)
                frame_idx, seconds = self._parse_frame_name(f)
                self.frames.append(
                    {
                        "video_id": video_id,
                        "file_path": f,
                        "frame_idx": frame_idx,
                        "seconds": seconds,
                        "record": record,
                    }
                )
                record_id += 1

    @staticmethod
    def _parse_frame_name(frame_file: Path) -> Tuple[int, float]:
        """Parse `frame_<idx>-<seconds>s.jpg` into (frame_idx, seconds rounded to 2 dp)."""
        parts = frame_file.stem.replace("frame_", "").replace("s", "").split("-")
        return int(parts[0]), round(float(parts[1]), 2)

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, i: int):
        meta = self.frames[i]
        # Work on a copy so the decoded image is not kept on the cached record; otherwise
        # every frame ever accessed would stay in memory.
        record = deepcopy(meta["record"])
        # `pil_img` (PIL.Image) rather than the global `Image`, which other cells rebind
        # (IPython.display.Image vs PIL.Image). The context manager closes the file.
        with pil_img.open(str(meta["file_path"])) as pil_frame:
            img = np.array(pil_frame)
        record.set_img(img)
        record.load()
        if self.tfm is not None:
            record = self.tfm(record)
        return record

    def __repr__(self):
        return f"<{self.__class__.__name__} with {len(self.frames)} items>"

In [None]:
# video_path = Path("/shared/gbiamby/geo/video_frames/pF9OA332DPk.mp4")
frames_path = "/shared/gbiamby/geo/video_frames"
# NOTE(review): `config` must be the model config from the model-loading cell (it needs
# `dataset_config`); later cells rebind `config` to a small DictConfig without it.
infer_tfms = tfms.A.Adapter(
    [*tfms.A.resize_and_pad(config.dataset_config.img_size), tfms.A.Normalize()]
)
infer_ds = GeoscreensInferenceDataset(
    frames_path,
    geoscreens_data.parser.class_map,
    video_ids="pF9OA332DPk",
    tfm=infer_tfms,
)
infer_dl = module.infer_dl(infer_ds, batch_size=8, shuffle=False, num_workers=16)

print("len ds: ", len(infer_ds))
preds = module.predict_from_dl(model, infer_dl, detection_threshold=0.5)
preds

In [None]:
# Inspect the `pred` field of the first prediction.
preds[0].pred

In [None]:
f_name = "frame_00039798-001326.600s.jpg"
# Strip the "frame_" prefix and ".jpg" suffix, leaving "<frame_idx>-<seconds>s".
frame_idx_str, seconds_str = f_name.replace("frame_", "").replace(".jpg", "").split("-")
frame_idx = int(frame_idx_str)
seconds = round(float(seconds_str.rstrip("s")), 2)
frame_idx, seconds

In [None]:
def get_detections_from_generator(
    frame=None,
    frame_counter: Optional[int] = None,
    detections: Optional[Dict[int, dict]] = None,
    detection_threshold: float = 0.5,
) -> Optional[dict]:
    """
    Run the detector on a single video frame and (optionally) record its detections.

    Relies on the notebook globals `infer_tfms`, `module`, `model` and `geoscreens_data`
    defined in earlier cells.

    Args:
        frame: Decoded frame; anything `np.array` can convert.
        frame_counter: Key under which the result is stored in `detections`.
        detections: Optional dict that receives the result under `frame_counter`
            (mutated in place).
        detection_threshold: Passed through to `module.predict`.

    Returns:
        A dict with "label_ids", "scores" and "bboxes" for the frame, or None when
        `module.predict` returned no predictions.

    Raises:
        ValueError: if `frame` is not given.
    """
    if frame is None:
        raise ValueError("`frame` is required.")
    raw_frames = [np.array(frame)]
    infer_ds = Dataset.from_images(
        raw_frames, infer_tfms, class_map=geoscreens_data.parser.class_map
    )
    preds = module.predict(model, infer_ds, detection_threshold=detection_threshold)
    if not preds:
        return None
    assert len(preds) == 1, "Expected list of size 1."
    pred = preds[0]
    result = {
        "label_ids": [int(l) for l in pred.detection.label_ids],
        "scores": pred.detection.scores.tolist(),
        "bboxes": [
            {
                "xmin": float(box.xmin),
                "ymin": float(box.ymin),
                "xmax": float(box.xmax),
                "ymax": float(box.ymax),
            }
            for box in pred.detection.bboxes
        ],
    }
    if detections is not None:
        detections[frame_counter] = result
    return result


@timeit_context("")
def get_frames_wrapper(fn, config, video_path):
    """Materialize everything yielded by `fn(config, video_path)` into a list (timed)."""
    return list(fn(config, video_path))


def get_indices_to_sample(config, total_frames: int, fps: float) -> List[int]:
    """
    Pick evenly spaced frame indices so a `fps` video is sampled at
    `config.frame_sample_rate_fps`.

    Returns `int(total_frames * frame_sample_rate_fps / fps)` indices spread over
    [0, total_frames) (endpoint excluded), each truncated to a Python int.
    """
    num_samples = int(total_frames * (config.frame_sample_rate_fps / fps))
    positions = np.linspace(0.0, total_frames, num=num_samples, endpoint=False)
    return [int(pos) for pos in positions]

In [None]:
# def get_frames_generator_opencv(
#     config: DictConfig,
#     video_path: Path,
# ):
#     print("Segmenting video: ", video_path)
#     error_state = False
#     cap = cv2.VideoCapture(str(video_path))
#     if not cap.isOpened():
#         print("Error opening input video: {}".format(video_path))
#         return

#     num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
#     fps = cap.get(cv2.CAP_PROP_FPS)
#     sample_indices = get_indices_to_sample(config, num_frames, fps)
#     print(f"total_frames: {num_frames:,}, num_to_sample: {len(sample_indices):,}, fps: {fps}")
#     print("config.frame_sample_rate_fps: ", config.frame_sample_rate_fps)
#     for frame_counter in tqdm(range(len(sample_indices)), total=len(sample_indices)):
#         frame_idx = sample_indices[frame_counter]
#         if config.fast_debug and frame_counter >= config.debug_max_frames:
#             break
#         seconds = frame_idx / fps
#         cap.set(cv2.CAP_PROP_POS_MSEC, (seconds * 1000))
#         ret, frame = cap.read()
#         if not ret:
#             raise Error(f"Error while processing video_id: {video_path} (ret:{ret}")
#             break
#         yield (seconds, frame_idx, frame)


# video_path = Path("/home/gbiamby/proj/geoscreens/data/videos/pF9OA332DPk.mp4")
# config = DictConfig(
#     {
#         "frame_sample_rate_fps": 4.0,
#         "fast_debug": False,
#         "debug_max_frames": 300,
#     }
# )

# frames_cv = get_frames_wrapper(get_frames_generator_opencv, config, video_path)
# print("num_frames sampled: ", len(frames_cv))

In [None]:
from decord import VideoReader, cpu, gpu


def get_frames_generator_decord(config, video_path):
    """
    Yield `(frame_idx, seconds, frame)` for frames sampled at `config.frame_sample_rate_fps`.

    Frames are read one at a time from a CPU decord `VideoReader`. `seconds` is
    `frame_idx / average fps` rounded to 2 decimals. When `config.fast_debug` is set,
    iteration stops after `config.debug_max_frames` frames.
    """
    reader = VideoReader(str(video_path), ctx=cpu(0))
    avg_fps = reader.get_avg_fps()
    n_frames = len(reader)
    sample_indices = get_indices_to_sample(config, n_frames, avg_fps)
    print(f"num_frames: {n_frames:,}, num_to_sample: {len(sample_indices):,}, fps: {avg_fps}")
    print("config.frame_sample_rate: ", config.frame_sample_rate_fps)
    progress = tqdm(range(len(sample_indices)), total=len(sample_indices))
    for sample_idx in progress:
        if config.fast_debug and sample_idx >= config.debug_max_frames:
            break
        frame_idx = sample_indices[sample_idx]
        frame = reader[frame_idx]
        yield (frame_idx, round(frame_idx / avg_fps, 2), frame)


# video_path = Path("/home/gbiamby/proj/geoscreens/data/videos/pF9OA332DPk.mp4")
# config = DictConfig(
#     {
#         "frame_sample_rate_fps": 4.0,
#         "fast_debug": True,
#         "debug_max_frames": 30,
#         "video_frames_path": "/home/gbiamby/proj/geoscreens/data/video_frames",
#     }
# )

# frames_decord = get_frames_wrapper(get_frames_generator_decord, config, video_path)
# print("num_frames sampled: ", len(frames_decord))
# frames_decord[:10], frames_decord[:-10]

In [None]:
from typing import Callable


@timeit_context("extract_frames")
def extract_frames(config: DictConfig, video_path: Path, get_frames_fn: Callable):
    """
    Decode sampled frames from `video_path` and save them as JPEGs.

    Files are written to
    `<config.video_frames_path>/<video stem>/frame_<idx:08>-<seconds:010.3f>s.jpg`.

    Args:
        config: Needs `video_frames_path`, plus whatever `get_frames_fn` reads.
        video_path: Path (or str) of the video file.
        get_frames_fn: Generator function `(config, video_path)` yielding
            `(frame_idx, seconds, frame)`. Each frame must expose `.asnumpy()`
            returning an RGB array (as decord frames do).

    Raises:
        IOError: if OpenCV fails to write a frame.
    """
    video_path = Path(video_path)
    frames_path = Path(config.video_frames_path) / video_path.stem
    frames_path.mkdir(exist_ok=True, parents=True)
    print("Saving frames to: ", frames_path)
    for frame_idx, seconds, frame in get_frames_fn(config, video_path):
        frame_out_path = frames_path / f"frame_{frame_idx:08}-{seconds:010.3f}s.jpg"
        # Frames arrive as RGB; OpenCV expects BGR when writing.
        bgr_frame = cv2.cvtColor(frame.asnumpy(), cv2.COLOR_RGB2BGR)
        # cv2.imwrite reports failure (bad path, full disk, ...) by returning False
        # rather than raising, so check it explicitly.
        if not cv2.imwrite(str(frame_out_path), bgr_frame):
            raise IOError(f"Failed to write frame to {frame_out_path}")


# Extract sampled frames for a single video with the decord-based frame generator.
extraction_settings = {
    "frame_sample_rate_fps": 4.0,
    "fast_debug": False,
    "debug_max_frames": 30,
    "video_frames_path": "/shared/gbiamby/geo/video_frames",
}
video_path = Path("/shared/g-luo/geoguessr/videos/pF9OA332DPk.mp4")
config = DictConfig(extraction_settings)
extract_frames(config, video_path, get_frames_generator_decord)

In [None]:
from multiprocessing import Pool


def extract_frames_fake(config: DictConfig, video_path: Path, get_frames_fn: Callable):
    """
    Dry run of `extract_frames`: only creates the per-video output folder and reports it.

    `get_frames_fn` is accepted for signature compatibility but not used.
    """
    out_dir = Path(config.video_frames_path).joinpath(video_path.stem)
    out_dir.mkdir(parents=True, exist_ok=True)
    print("Saving frames to: ", out_dir)


def process_videos_muli_cpu(
    config: DictConfig,
    extract_fn: Optional[Callable] = None,
    get_frames_fn: Optional[Callable] = None,
):
    """
    Run `extract_fn(config, video_path, get_frames_fn)` on every `*.mp4` in
    `config.videos_path`, using a process pool.

    Args:
        config: Needs `videos_path`. `num_workers` (default 4) sets the pool size.
            The whole config is passed through to `extract_fn`.
        extract_fn: Per-video worker; defaults to `extract_frames_fake` (dry run).
        get_frames_fn: Frame generator forwarded to `extract_fn`.

    Returns:
        List of `extract_fn` return values, one per video, in sorted file order.
    """
    if extract_fn is None:
        extract_fn = extract_frames_fake
    files = sorted(Path(config.videos_path).glob("*.mp4"))
    print(len(files))

    # One (config, video_path, get_frames_fn) tuple per video; starmap unpacks each
    # tuple into positional arguments of `extract_fn`.
    jobs = [(config, video_path, get_frames_fn) for video_path in files]
    with Pool(processes=config.get("num_workers", 4)) as pool:
        results = pool.starmap(extract_fn, jobs)
    return results


# Dry-run the multi-process frame extraction over every video in `videos_path`.
config = DictConfig(
    dict(
        frame_sample_rate_fps=4.0,
        fast_debug=False,
        debug_max_frames=30,
        video_frames_path="/shared/gbiamby/geo/video_frames",
        videos_path="/shared/g-luo/geoguessr/videos",
        num_workers=4,
    )
)
process_videos_muli_cpu(config)

In [None]:
# from geoscreens.utils import timeit_context


# # Using the decord batching is somehow slower than just using the VideoReader indexing, i.e,
# # get_frames_generator_decord().
# @timeit_context("get_frames_generator_decord_batched")
# def get_frames_generator_decord_batched(config, video_path):
#     vr = VideoReader(str(video_path), ctx=cpu(0))
#     indices = get_indices_to_sample(config, len(vr), vr.get_avg_fps())

#     print(f"num_frames: {len(vr):,}, fps: {vr.get_avg_fps()}")
#     print("config.frame_sample_rate: ", config.frame_sample_rate_fps)

#     if config.fast_debug and len(indices) > config.debug_max_frames:
#         indices = indices[: config.debug_max_frames]
#     frames = vr.get_batch(indices).asnumpy()
#     yield from frames


# video_path = Path("/shared/g-luo/geoguessr/videos/pF9OA332DPk.mp4")
# config = DictConfig(
#     {
#         "frame_sample_rate_fps": 4.0,
#         "fast_debug": True,
#         "debug_max_frames": 10000,
#     }
# )

# frames = get_frames_wrapper(get_frames_generator_decord_batched, config, video_path)
# print("num_frames sampled: ", len(frames))

In [None]:
# decord VideoReader demo. The frame generators above only create `vr` as a local, so
# open a reader explicitly for the video selected in the frame-extraction cell.
vr = VideoReader(str(video_path), ctx=cpu(0))

# To get multiple frames at once, use get_batch
# this is the efficient way to obtain a long list of frames
frames = vr.get_batch([1, 3, 5, 7, 9])
print(frames.shape)
# e.g. (5, 240, 320, 3); height/width depend on the video
# duplicate frame indices will be accepted and handled internally to avoid duplicate decoding
frames2 = vr.get_batch([1, 2, 3, 2, 3, 4, 3, 4, 5]).asnumpy()
print(frames2.shape)
# e.g. (9, 240, 320, 3)

# 2. you can do cv2 style reading as well
# skip 100 frames
vr.skip_frames(100)
# seek to start
vr.seek(0)
batch = vr.next()
print("frame shape:", batch.shape)
print("numpy frames:", batch.asnumpy())

In [None]:
# from torchvision import transforms as t
# from torchvision.datasets.folder import make_dataset


# def get_samples(root, extensions=(".mp4", ".avi")):
#     _, class_to_idx = _find_classes(root)
#     return make_dataset(root, class_to_idx, extensions=extensions)


# class RandomDataset(torch.utils.data.IterableDataset):
#     def __init__(
#         self, root, epoch_size=None, frame_transform=None, video_transform=None, clip_len=16,
#         video_id: str =
#     ):
#         super(RandomDataset).__init__()

#         self.samples = []

#         # Allow for temporal jittering
#         if epoch_size is None:
#             epoch_size = len(self.samples)
#         self.epoch_size = epoch_size

#         self.clip_len = clip_len
#         self.frame_transform = frame_transform
#         self.video_transform = video_transform

#     def __iter__(self):
#         for i in range(self.epoch_size):
#             # Get random sample
#             path, target = random.choice(self.samples)
#             # Get video object
#             vid = torchvision.io.VideoReader(path, "video")
#             metadata = vid.get_metadata()
#             video_frames = []  # video frame buffer

#             # Seek and return frames
#             max_seek = metadata["video"]["duration"][0] - (
#                 self.clip_len / metadata["video"]["fps"][0]
#             )
#             start = random.uniform(0.0, max_seek)
#             for frame in itertools.islice(vid.seek(start), self.clip_len):
#                 video_frames.append(self.frame_transform(frame["data"]))
#                 current_pts = frame["pts"]
#             # Stack it into a tensor
#             video = torch.stack(video_frames, 0)
#             if self.video_transform:
#                 video = self.video_transform(video)
#             output = {
#                 "path": path,
#                 "video": video,
#                 "target": target,
#                 "start": start,
#                 "end": current_pts,
#             }
#             yield output

---

## Naive Detection of Bad Ground-Truth Labels

In [None]:
# NOTE(review): this rebinds the global `tasks`, which the icevision star import also
# provides (e.g. `tasks.detection`); any later use of that module through the global
# name will fail after this cell runs.
tasks_export_path = Path("/shared/gbiamby/geo/exports/geoscreens_009-from_proj_id_58.json")
with open(tasks_export_path, "r", encoding="utf-8") as fp:
    tasks = json.load(fp)


def _first_annotation_labels(task: dict) -> List[str]:
    """Rectangle labels of the task's first annotation; empty if it has no annotations."""
    annotations = task.get("annotations") or []
    if not annotations:
        return []
    return [result["value"]["rectanglelabels"][0] for result in annotations[0]["result"]]


# Flag tasks whose first annotation uses the same label more than once.
mistakes = []
for t in tqdm(tasks, total=len(tasks)):
    labels = _first_annotation_labels(t)
    if len(labels) != len(set(labels)):
        mistakes.append(t)

In [None]:
# Number of tasks whose first annotation repeats a label.
len(mistakes)

In [None]:
# `data` payload (image path etc.) of every flagged task.
[mistake["data"] for mistake in mistakes]

In [None]:
# NOTE(review): same output as the previous cell; likely a leftover duplicate.
[flagged["data"] for flagged in mistakes]

In [None]:
# Dump the annotations and labels of every task whose image path contains this frame.
target_frame = "aob8sh6l-6M/frame_00000221"
for t in tqdm(tasks, total=len(tasks)):
    if target_frame not in t["data"]["image"]:
        continue
    print("")
    print(t["id"], t["data"]["image"])
    anns_results = [ann["result"] for ann in t["annotations"]]
    print("anns_results: ", anns_results, len(anns_results))
    labels = [ann["value"]["rectanglelabels"][0] for ann in anns_results[0]]
    print("labels: ", labels)

---

## Scratch / Junk

### Find/Filter Duplicates

In [None]:
# Explicit import: `defaultdict` is otherwise only reachable through a star import.
from collections import defaultdict

# Group tasks by image path; any path with more than one task is a duplicate.
path_to_task = defaultdict(list)
for t in tasks:
    path_to_task[t["data"]["full_path"]].append(t)
print(len(tasks), len(path_to_task))

# Dict insertion order keeps duplicates in order of first appearance in `tasks`.
dupes = [path for path, path_tasks in path_to_task.items() if len(path_tasks) > 1]

print("total dupes: ", len(dupes))
to_remove = []
for path in dupes:
    print("")
    print("=" * 100)
    ann_ids = [t["id"] for t in path_to_task[path]]
    # Keep the task with the highest id for each path; mark the others for removal.
    max_id = max(ann_ids)
    print("ann_ids: ", ann_ids)
    print("Removing: ")
    for t in path_to_task[path]:
        if t["id"] != max_id:
            print("Removing task_id: ", t["id"])
            to_remove.append((t["id"], path))

to_remove

In [None]:
# Set lookup keeps the filter linear instead of scanning `to_remove` for every task.
remove_keys = set(to_remove)
tasks_filtered = [
    t for t in tasks if (t["id"], t["data"]["full_path"]) not in remove_keys
]

print(len(tasks), len(tasks_filtered))

### Save

In [None]:
# NOTE(review): the tasks were loaded from `geoscreens_009-from_proj_id_58.json`, but this
# output name says `geoscreens_004`; confirm the target before running (file is overwritten).
out_path = Path("/shared/gbiamby/geo/geoscreens_004_tasks_with_preds.json")
# Context manager guarantees the file is flushed and closed after writing.
with open(out_path, "w", encoding="utf-8") as fp:
    json.dump(tasks_filtered, fp, indent=4, sort_keys=True)

---

---

In [None]:
# (remainder, quotient) of 213 / 10.
quotient, remainder = divmod(213, 10)
remainder, quotient