In [None]:
# Autosave every 60 s, and reload imported modules before each cell runs so that edits
# to project code (e.g. the `geoscreens` package) are picked up without a kernel restart.
%autosave 60
%load_ext autoreload
%autoreload 2
%matplotlib inline

# NOTE(review): `sys` and `Path` are imported again in the imports cell below; these two
# lines are redundant.
import sys
from pathlib import Path

In [None]:
import json
import logging
import os
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Tuple, Union, cast

import cv2
import matplotlib as plt
import numpy as np
import pandas as pd
import PIL
import PIL.Image as pil_img
import seaborn as sns
import sklearn as skl
from icevision import models, tfms
from icevision.all import *
from icevision.data import Dataset, DataSplitter, RandomSplitter
from icevision.parsers.coco_parser import COCOBBoxParser
from IPython.display import Image, display
from matplotlib.patches import Rectangle
from matplotlib_inline.backend_inline import set_matplotlib_formats
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from tqdm.contrib import tenumerate, tmap, tzip
from tqdm.contrib.bells import tqdm, trange

from geoscreens.geo_data import GeoScreensDataModule
from geoscreens.models import get_model, load_model_from_path
from geoscreens.modules import LightModelTorch, build_module

In [None]:
# pandas display: show full cell contents, cap wide/long frames.
pd.set_option("display.max_colwidth", None)
pd.set_option("display.max_columns", 15)
pd.set_option("display.max_rows", 50)
# Suitable default display for floats: thousands separator, 2 decimals.
pd.set_option("display.float_format", "{:,.2f}".format)
plt.rcParams["figure.figsize"] = (12, 10)

# Inline figures are emitted as PDF (vector, for export) and PNG (for display).
# `set_matplotlib_formats` replaces the whole inline format set, so it is the single
# place where formats are chosen; a separate `%config InlineBackend.figure_formats`
# before it would be overridden.
set_matplotlib_formats("pdf", "png")

In [None]:
# Publication-style matplotlib defaults.
# `set_matplotlib_formats` is the `matplotlib_inline.backend_inline` function from the
# imports cell; the `IPython.display` re-export of it is deprecated.
set_matplotlib_formats("pdf", "png")
plt.rcParams["savefig.dpi"] = 75

# NOTE(review): this overrides the figure size set in the earlier display-options cell.
plt.rcParams["figure.autolayout"] = False
plt.rcParams["figure.figsize"] = 10, 6
plt.rcParams["axes.labelsize"] = 18
plt.rcParams["axes.titlesize"] = 20
plt.rcParams["font.size"] = 16
plt.rcParams["lines.linewidth"] = 2.0
plt.rcParams["lines.markersize"] = 8
plt.rcParams["legend.fontsize"] = 14
# Render all text through LaTeX; this requires a LaTeX installation on the kernel host,
# otherwise every figure in the notebook fails to draw.
plt.rcParams["text.usetex"] = True

plt.rcParams["font.family"] = "serif"
# With usetex, matplotlib only honours serif names it recognizes (case-insensitive) and maps
# them to LaTeX font packages; a bare "cm" is not one of them and was silently ignored.
plt.rcParams["font.serif"] = ["Computer Modern Roman"]
# plt.rcParams["text.latex.preamble"] = "\\usepackage{subdepth}, \\usepackage{type1cm}"

## Load Data and Build Model

In [None]:
# Seed Python/NumPy/torch RNGs (and dataloader workers) so sampling and augmentation repeat.
seed_everything(42, workers=True)

# Checkpoint directory consumed by `load_model_from_path`.
MODEL_DIR = "/shared/gbiamby/geo/models/geoscreens_009-resnest50_fpn-with_augs/"
# Prefer the first GPU, but fall back to CPU instead of crashing on a GPU-less kernel.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
config, module, model, light_model = load_model_from_path(MODEL_DIR, device=DEVICE)
# Inference only: `eval()` disables training-mode behaviour (dropout, batch-norm updates).
model, light_model = model.eval(), light_model.eval()
geoscreens_data = GeoScreensDataModule(config, module)

## Show Some Training Samples

In [None]:
# Training split exposed by the data module (used below to preview augmented samples).
train_ds = geoscreens_data.train_ds

In [None]:
# Fetch one training element several times. Each access goes through the dataset's
# augmentation transforms, so (presumably, if they are random) each copy looks different.
sample_index, n_views = 10, 3
samples = [train_ds[sample_index] for _ in range(n_views)]
show_samples(samples, ncols=n_views)

### Show some validation set samples

In [None]:
# Grab the first validation batch and render it in a 4-column grid.
val_batch = first(geoscreens_data.val_dataloader())
module.show_batch(val_batch, ncols=4)

### Show some predictions

In [None]:
# Plot model predictions (score >= 0.5) on a handful of validation images.
num_samples = 10
size = 30
fig_height = size * num_samples / 2
module.show_results(
    light_model,
    geoscreens_data.valid_ds,
    num_samples=num_samples,
    detection_threshold=0.5,
    device=DEVICE,
    figsize=(size, fig_height),
)

---

## Naive Detection of Bad Ground Truth Labels

In [None]:
# Load the exported annotation tasks; the context manager guarantees the file is closed.
TASKS_PATH = Path("/shared/gbiamby/geo/exports/geoscreens_009-from_proj_id_58.json")
with open(TASKS_PATH, "r", encoding="utf-8") as f:
    tasks = json.load(f)

# Flag tasks whose first annotation uses any rectangle label on more than one box.
mistakes = []
for t in tqdm(tasks, total=len(tasks)):
    anns_results = [ann["result"] for ann in t["annotations"]]
    if not anns_results:
        # Unannotated task: no labels, so nothing can be duplicated (and [0] would raise).
        continue
    labels = [ann["value"]["rectanglelabels"][0] for ann in anns_results[0]]
    if len(labels) != len(set(labels)):
        mistakes.append(t)

In [None]:
# Number of tasks flagged as containing a repeated rectangle label.
len(mistakes)

In [None]:
# Raw `data` payload (image path etc.) of every flagged task.
[flagged["data"] for flagged in mistakes]

In [None]:
# Per flagged task: its id, image, and which labels appear more than once in its first
# annotation, so the offending boxes are easy to locate. Every task in `mistakes` has at
# least one annotation, because it was flagged from `annotations[0]`.
[
    {
        "id": m["id"],
        "image": m["data"]["image"],
        "duplicated_labels": sorted(
            label
            for label, count in Counter(
                r["value"]["rectanglelabels"][0] for r in m["annotations"][0]["result"]
            ).items()
            if count > 1
        ),
    }
    for m in mistakes
]

In [None]:
# Dump the annotations of every task whose image path contains TARGET_FRAME.
TARGET_FRAME = "aob8sh6l-6M/frame_00000221"
for t in tqdm(tasks, total=len(tasks)):
    if TARGET_FRAME not in t["data"]["image"]:
        continue
    print("")
    print(t["id"], t["data"]["image"])
    anns_results = [ann["result"] for ann in t["annotations"]]
    print("anns_results: ", anns_results, len(anns_results))
    if not anns_results:
        # No annotation to pull labels from; indexing [0] would raise.
        continue
    labels = [ann["value"]["rectanglelabels"][0] for ann in anns_results[0]]
    print("labels: ", labels)

---

## Scratch / Junk

### Find/Filter Duplicates

In [None]:
# Group tasks by the frame they annotate; a frame with more than one task is a duplicate.
path_to_task = defaultdict(list)
for t in tasks:
    path_to_task[t["data"]["full_path"]].append(t)
print(len(tasks), len(path_to_task))

# Same first-occurrence order a Counter over the paths would give.
dupes = [path for path, path_tasks in path_to_task.items() if len(path_tasks) > 1]

print("total dupes: ", len(dupes))
# For each duplicated frame keep only the task with the highest id (presumably the most
# recent one — TODO confirm ids are assigned monotonically) and mark the rest for removal.
to_remove = []
for path in dupes:
    print("")
    print("=" * 100)
    task_ids = [t["id"] for t in path_to_task[path]]
    max_id = max(task_ids)
    print("ann_ids: ", task_ids)
    print("Removing: ")
    for t in path_to_task[path]:
        if t["id"] != max_id:
            print("Removing task_id: ", t["id"])
            to_remove.append((t["id"], path))

to_remove

In [None]:
# Drop every task whose (id, frame path) pair was marked for removal. A set makes each
# membership test O(1) instead of scanning the whole `to_remove` list per task.
remove_keys = set(to_remove)
tasks_filtered = [t for t in tasks if (t["id"], t["data"]["full_path"]) not in remove_keys]
# Each entry in `to_remove` was appended for exactly one task, so the counts must agree.
assert len(tasks) - len(tasks_filtered) == len(to_remove), "unexpected number of removed tasks"

print(len(tasks), len(tasks_filtered))

### Save

In [None]:
# Persist the de-duplicated tasks. The context manager guarantees the file is flushed and
# closed (a bare `open(...)` argument relies on garbage collection for that).
# NOTE(review): the output name says "004 ... with_preds", but the input export is
# geoscreens_009 and these tasks were only de-duplicated — confirm the intended filename.
OUTPUT_PATH = Path("/shared/gbiamby/geo/geoscreens_004_tasks_with_preds.json")
with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
    json.dump(tasks_filtered, f, indent=4, sort_keys=True)

---

## Get Detections for a Single Image

In [None]:
# Run the loaded detector on one frame from disk and show its raw detections.
IMG_PATH = Path(
    "/shared/gbiamby/geo/screenshots/screen_samples_auto/-lPrvqk2mqs/frame_00000104.jpg"
)
DETECTION_THRESHOLD = 0.4

# The image is a local file, so read it directly with PIL; `image_from_url` is meant for
# URLs (per its name — confirm) rather than filesystem paths.
with pil_img.open(IMG_PATH) as img:
    imgs = [np.array(img.convert("RGB"))]

# NOTE(review): `infer_tfms` was never defined in this notebook (NameError on a fresh
# kernel). Reuse the validation dataset's transforms; this assumes `valid_ds` is an
# icevision `Dataset` exposing `.tfm` — confirm it matches the intended inference pipeline.
infer_tfms = geoscreens_data.valid_ds.tfm
infer_ds = Dataset.from_images(imgs, infer_tfms)

# Predict with the module returned by `load_model_from_path` (the same one used for
# `show_batch`/`show_results` above) rather than a hard-coded torchvision RetinaNet module,
# so prediction matches the loaded model's architecture.
preds = module.predict(model, infer_ds, detection_threshold=DETECTION_THRESHOLD)
[(p.detection.scores, p.detection.label_ids, p.detection.bboxes) for p in preds]

---