In [None]:
import geopandas as gpd
import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import polars as pl
from tqdm.auto import tqdm

from lt_lib.core.matching import POLARS_AP_DF_SCHEMA
from lt_lib.core.train import append_nested_dict_with_0

from lt_lib.data.datasets import POLARS_GTS_SCHEMA
import lt_lib.data.preprocessing as preprocessing

from lt_lib.entrypoints.run import run, RunCliArgs
from lt_lib.entrypoints.optimization import optimization, OptimizationCliArgs

from lt_lib.orchestration.task_orchestrator import TaskOrchestrator

from lt_lib.schemas.config_files_schemas import RunConfig, ModelConfig

from lt_lib.utils.load_and_save import load_pytorch_checkpoint, load_json_as_dict
from lt_lib.utils.dict_utils import flatten_dict
from lt_lib.utils.regex_matcher import get_elements_with_regex
from lt_lib.utils.update_outdated_objects import update_checkpoint_metrics_dict

%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

In [None]:
# Rung schedule of a successive-halving style scheduler: resources grow geometrically by the
# reduction factor `rf` from `min_t` up to `max_t`, and `s` shifts the first rung.
# NOTE(review): presumably mirrors the bracket formula of an ASHA implementation — confirm which one.
min_t, max_t = 1, 4
rf = 2
s = 0

# Number of times `min_t` can be multiplied by `rf` before reaching `max_t`
n_halvings = np.log(max_t / min_t) / np.log(rf)
MAX_RUNGS = int(n_halvings - s + 1)

# (resource level, recorded results) pairs, from the highest rung down to the lowest
[(min_t * rf ** (k + s), {}) for k in range(MAX_RUNGS - 1, -1, -1)]

In [None]:
# Cut-off value at the (1 - 1/rf) quantile of a few example scores, ignoring NaNs:
# only scores above it belong to the top 1/rf fraction.
example_scores = [0.5, 0.7, 0.8, 0.55, 0.6, 0.57, 0.4]
cutoff_percentile = (1 - 1 / rf) * 100
np.nanpercentile(example_scores, cutoff_percentile)

In [None]:
import numpy as np

In [None]:
# Load a saved checkpoint (file name suggests RetinaNet, epoch 7) to inspect its contents.
c = load_pytorch_checkpoint(Path("/path/to/retinanet-E7.tar"))

In [None]:
# Show which keys are stored under metrics -> "val" in the loaded checkpoint.
c["metrics"]["val"].keys()

In [None]:
# NOTE(review): presumably migrates the metrics dict of the checkpoint file to the current
# format — confirm in lt_lib.utils.update_outdated_objects. `c` was loaded before this call,
# so it is not refreshed by it.
update_checkpoint_metrics_dict(Path("/path/to/retinanet-E7.tar"))

In [None]:
# Empty frame built from empty lists only: column names are fixed here but no dtype is given.
b = pl.from_dict({"id": [], "img_name":[], "confidence": [], "label": [], "iou": [], "match": [], "correct_match": []})

In [None]:
# Empty frame whose columns and dtypes come from the project's POLARS_AP_DF_SCHEMA.
a = pl.DataFrame(schema=POLARS_AP_DF_SCHEMA)

In [None]:
import numpy as np
from sklearn.metrics import average_precision_score
import torch
from lt_lib.core.matching import boxes_iou
from scipy.optimize import linear_sum_assignment


In [None]:
# Location of the ground-truth annotations; the label-name mapping sits in the same directory.
gts_path = Path("/content/datasets/dataset_v1/val/annotations/gts.csv")
label_map_path = gts_path.with_name("label_to_label_name.json")

# Ground truths are cast strictly to the project schema; predictions are read as-is.
gts = pl.read_csv(gts_path).cast(POLARS_GTS_SCHEMA, strict=True)
predictions = pl.read_csv("/path/to/predictions.csv")

label_to_label_name_dict = load_json_as_dict(label_map_path)

In [None]:
# Column names of the ground-truth frame.
gts.columns

In [None]:
# Column names of the predictions frame.
predictions.columns

In [None]:
# Number of prediction rows.
len(predictions)

In [None]:
# Toy frame: the numpy array sets the number of rows, the literals are broadcast to that height.
test_df = (
    pl.DataFrame()
    .with_columns(label=np.array([1, 2, 3]))
    .with_columns(img_name=pl.lit("img_test.jpg"))
    .with_columns(confidence=pl.lit(0))
)


In [None]:
# Display the toy frame.
test_df

In [None]:
# Keep only confident detections and re-number them.
predictions = (
    predictions
    # Row-wise maximum over every column whose name matches the regex (any name containing "confidence")
    .with_columns(confidence=pl.max_horizontal(pl.col("^.*confidence.*$")))
    .filter(pl.col("confidence") > 0.3)
    # Replace the old ids with a fresh 0-based row index, so that `id` equals the row position
    .drop("id")
    .with_row_index("id")
)

In [None]:
# Build `AP_df`: one row per prediction, filled image by image with the IoU of the ground truth it
# is assigned to and a flag telling whether that ground truth has the same label. Ground truths
# left without any assigned prediction are appended as extra rows.

# Unique image names that appear in the ground truths
img_name_list = gts["img_name"].unique().to_list()

AP_df = predictions.select(["id", "img_name", "probable_label", "confidence"])
AP_df = AP_df.rename({"probable_label": "label"})
# Default values, kept by every prediction that is never assigned to a ground truth
AP_df = AP_df.with_columns(iou=pl.lit(0.0, dtype=pl.Float32))
AP_df = AP_df.with_columns(matched=pl.lit(0))

# Process the images one by one
for img_name in tqdm(img_name_list):

    # Ground truths and predictions of the current image
    img_gts = gts.filter(pl.col("img_name") == img_name)
    img_predictions = predictions.filter(pl.col("img_name") == img_name)

    # Bounding boxes, labels and prediction ids of the current image as numpy arrays
    img_gts_bboxes = img_gts[["bbox_xmin", "bbox_ymin", "bbox_xmax", "bbox_ymax"]].to_numpy()
    img_gts_labels = img_gts["label"].to_numpy()
    img_predictions_bboxes = img_predictions[["bbox_xmin", "bbox_ymin", "bbox_xmax", "bbox_ymax"]].to_numpy()
    img_predictions_id = img_predictions["id"].to_numpy()
    img_predictions_probable_labels = img_predictions["probable_label"].to_numpy()

    # IoU matrix between detections and ground truths. The indexing below treats rows as
    # predictions and columns as ground truths — assumes boxes_iou returns that orientation (TODO confirm).
    iou_matrix = boxes_iou(img_predictions_bboxes, img_gts_bboxes)

    # Hungarian algorithm: minimising the total (1 - IoU) cost maximises the total IoU.
    # It returns the matched row indexes (predictions) and column indexes (ground truths).
    # NOTE(review): there is no minimum IoU here, so an assigned pair may have an IoU of 0.
    predictions_idx, gts_idx = linear_sum_assignment(1 - iou_matrix)

    # `scatter` writes at the positions given by the prediction ids, so this relies on `id` being
    # the 0-based row position in AP_df — assumes predictions were re-indexed beforehand (TODO confirm).
    AP_df = AP_df.with_columns(
        iou=AP_df["iou"].scatter(img_predictions_id[predictions_idx], iou_matrix[predictions_idx, gts_idx])
    )

    # A pair is flagged as matched when both labels agree; the IoU value is not checked here.
    mask_correct_matched = img_gts_labels[gts_idx] == img_predictions_probable_labels[predictions_idx]
    AP_df = AP_df.with_columns(
        matched=AP_df["matched"].scatter(img_predictions_id[predictions_idx], mask_correct_matched.astype(int))
    )
    
    # Ground truths that were not assigned to any prediction
    mask = np.ones(img_gts_labels.shape, dtype=bool)
    mask[gts_idx] = False
    unmatched_gts_labels = img_gts_labels[mask]
    if len(unmatched_gts_labels) != 0:
        # One extra row per unassigned ground truth, with confidence, IoU and matched all set to 0
        unmatched_gts = pl.DataFrame({"label":unmatched_gts_labels})
        unmatched_gts = unmatched_gts.insert_column(0, pl.Series("img_name", [img_name] * len(unmatched_gts_labels)))
        unmatched_gts = unmatched_gts.with_columns(
            confidence=pl.lit(0.0, dtype=pl.Float32),
            iou=pl.lit(0.0, dtype=pl.Float32),
            matched=pl.lit(0),
        )
        # Ids continue after the current last row so that `id` still equals the row position
        unmatched_gts = unmatched_gts.with_row_index("id", offset=len(AP_df))

        AP_df = pl.concat([AP_df, unmatched_gts], how="vertical_relaxed")


In [None]:
# Mean average precision over the IoU thresholds 0.50, 0.55, ..., 0.95, restricted to labels 1, 2, 3.
# The result is the mean over thresholds of the mean over labels of sklearn's average precision.

# Re-number `id` so that it equals the row position: `Series.scatter` below indexes by position.
AP_df_filtered = AP_df.filter(pl.col("label").is_in([1, 2, 3])).drop("id").with_row_index("id")

AP_per_thresh = []
# Thresholds are visited in increasing order and `AP_df_filtered` is updated cumulatively:
# a match demoted at one threshold stays demoted for all the following (stricter) ones.
for iou_thresh in np.linspace(0.5, 0.95, 10):
    # Matches whose IoU does not exceed the threshold stop counting as true positives.
    # There is no lower bound on purpose: an assigned pair with equal labels but an IoU of
    # exactly 0 must be demoted too (a `(0, iou_thresh]` interval would keep it as a match).
    too_low_iou_ids = AP_df_filtered.filter(
        pl.col("matched") == 1,
        pl.col("iou") <= iou_thresh,
    )["id"]
    AP_df_filtered = AP_df_filtered.with_columns(
        matched=AP_df_filtered["matched"].scatter(too_low_iou_ids, 0)
    )

    # Append one extra row per demoted match, with confidence and IoU reset to 0 and ids
    # continuing after the current last row (so `id` still equals the row position).
    # NOTE(review): these rows carry matched = 0, i.e. they enter average_precision_score as
    # negatives rather than as missed positives — confirm this is the intended way to account
    # for the ground truths that are no longer detected.
    fn_df = (
        AP_df_filtered.filter(pl.col("id").is_in(too_low_iou_ids.to_list()))
        .with_columns(confidence=pl.lit(0.0), iou=pl.lit(0.0))
        .drop("id")
        .with_row_index("id", offset=len(AP_df_filtered))
    )
    AP_df_filtered = pl.concat([AP_df_filtered, fn_df], how="vertical_relaxed")

    # Average precision of each label at this threshold, then their mean
    AP_per_label = []
    for label in AP_df_filtered["label"].unique():
        AP_df_filtered_per_label = AP_df_filtered.filter(pl.col("label") == label)
        AP_per_label.append(
            average_precision_score(AP_df_filtered_per_label["matched"], AP_df_filtered_per_label["confidence"])
        )
    AP_per_thresh.append(np.mean(AP_per_label))

np.mean(AP_per_thresh)

In [None]:
# ROOT_DIR_PATH = Path("/content/datasets/dataset_v0/train")

# global_gdf, minimal_gts = preprocessing.get_all_annotations_from_rareplanes_geojsons(
#     root_dir_path = ROOT_DIR_PATH,
#     tiled_version = True,
#     imgs_extension=".png",
#     save_to_file = True
# )

In [None]:
# gts_csv_path = ROOT_DIR_PATH / "annotations/gts.csv"
# gts = pl.read_csv(gts_csv_path)
# np.unique(gts["label"])

In [None]:
# gts = gts.with_columns(pl.col("label").replace({2:1, 4: 2, 6: 3}))
# gts.cast(POLARS_GTS_SCHEMA, strict=True).write_csv(gts_csv_path)

# minimal_gts_json = {}
# for img_name in np.unique(gts["img_name"].to_list()):
#     sub_img_df = gts.filter(pl.col("img_name") == img_name)
#     minimal_gts_json[img_name] = {
#         "ids": sub_img_df["id"].to_list(),
#         "bboxes": sub_img_df[["bbox_xmin", "bbox_ymin", "bbox_xmax", "bbox_ymax"]].to_numpy().tolist(),
#         "labels": sub_img_df["label"].to_list(),
#     }
# with open(Path("/".join(gts_csv_path.parts)[1:-3] + 'json'), "w") as file:
#     json.dump(minimal_gts_json, file)

In [None]:
# Build a TaskOrchestrator on the test dataset and test outputs directories, using
# "predict_config_test.yaml" and "model_config_test.yaml".
# NOTE(review): resume=False presumably starts from scratch instead of resuming an existing
# run — confirm in TaskOrchestrator.
task_orchestrator = TaskOrchestrator(
    inputs_dir=Path("/content/datasets/dataset_test"),
    outputs_dir=Path("/content/outputs/outputs_test"),
    config_path="predict_config_test.yaml",
    model_config_path="model_config_test.yaml",
    resume=False,
)

In [None]:
# Call the `run` entrypoint from the notebook with the arguments the CLI would receive.
# Note that resume=True is passed here, and that `args` is rebound by the optimization cell below.
args = RunCliArgs(
    inputs_directory="/content/datasets/dataset_test",
    outputs_directory="/content/outputs/outputs_test",
    model_config_path="model_config_test.yaml",
    config_path="full_config_test.yaml",
    resume=True,
)

run(args)

In [None]:
# Call the `optimization` entrypoint from the notebook with the arguments the CLI would receive.
# No model config is passed (the argument is commented out) and restore_dir_path is an empty string.
# NOTE(review): an empty restore_dir_path presumably means "do not restore a previous
# optimization" — confirm in OptimizationCliArgs.
args = OptimizationCliArgs(
    inputs_directory="/content/datasets/dataset_test",
    outputs_directory="/content/outputs/outputs_test",
    # model_config_path="../configs/model_config_test.yaml",
    config_path="/content/master-thesis-draft/configs/thresh_eval_config.yaml",
    optimization_config_path="/content/master-thesis-draft/configs/optimization_config.py",
    restore_dir_path="",
)

optimization(args)

### xView exploration

In [None]:
# Read the xView training annotations (GeoJSON) into a GeoDataFrame and preview the first rows.
gdf = gpd.read_file(Path("/path/to/xView_train.geojson"))
gdf.head()

In [None]:
# Columns of the GeoDataFrame that are dropped before converting it to a polars frame
unused_columns = [
    "geometry",
    "grid_file",
    "feature_id",
    "point_geom",
    "index_right",
    "ingest_time",
    "cat_id",
    "edited_by",
]
df = pl.DataFrame(gdf.drop(columns=unused_columns))

In [None]:
# Rows whose type_id is 18. NOTE(review): which xView class 18 stands for is not stated here — confirm.
df.filter(pl.col("type_id") == 18)

### Synthetic data wingspan exploration

In [None]:
# Ground truths of the sampled synthetic dataset. This rebinds `gts`, which earlier cells use
# for the dataset_v1 validation ground truths.
gts = pl.read_csv("/content/datasets/synthetic_data_sampled_10percent_seed42/annotations/gts.csv")

In [None]:
# Preview the first rows of the ground truths.
gts.head()

In [None]:
# Divide two reference sizes by the image resolution.
# NOTE(review): presumably sizes in metres and a resolution in metres per pixel, which would
# give sizes in pixels — confirm the units.
imgs_resolution = 0.31
for size in (15, 36):
    print(size / imgs_resolution)

In [None]:
# Ground truths whose wingspan lies between 112.5 and 120 (bounds included by default).
# NOTE(review): the range brackets 36 / 0.31 ≈ 116, so the wingspan is presumably in pixels — confirm.
gts.filter(pl.col("wingspan").is_between(112.5, 120))