# 1. Definition
## 1.1 Imports

In [None]:
from pathlib import Path
import polars as pl

from lt_lib.viz.experiment_viz_utils import (
    get_result_grid, 
    filter_df_with_dict, 
    rename_config_params_column_name, 
    keep_non_dominated_points, 
)
from lt_lib.viz.plot_utils import plot_px_line

# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Request high-resolution ("retina") output from the inline figure backend.
# NOTE(review): this setting targets the inline matplotlib backend; confirm it
# is still needed given that the plots below go through plot_px_line.
%config InlineBackend.figure_format='retina'

# 2. Load saved experiment

In [None]:
# Location of the saved experiment on disk.
# NOTE(review): the directory name is spelled "expriement" -- confirm it
# matches the actual folder before running.
experiment_path = Path("expriement")

# Restore the result grid of the experiment from that location.
results = get_result_grid(experiment_path)

# One row per trial: convert the pandas export to polars, add an "id" row
# index column, then pass the frame through the config-column renaming helper.
all_trials_df = rename_config_params_column_name(
    pl.from_pandas(results.get_dataframe()).with_row_index("id")
)

In [None]:
# Show the "custom_metrics.super_metric" column of the metrics dataframe
# belonging to the result selected by get_best_result for that metric in
# "max" mode.
target_metric = "custom_metrics.super_metric"
best_trial = results.get_best_result(target_metric, "max")
best_trial.metrics_dataframe[target_metric]

# 3. Visualization

In [None]:
# Columns shared by every plot below: X and Y are the axes, GROUP both splits
# the lines (line_group) and drives their colour.
X = "level1.recall"
# Precision, so the plotted data matches the "Precision (lvl1)" axis label and
# the precision-recall title used by the plotting cells. It was "level1.f1",
# which is still shown through ADDITIONAL_HOVER_DATA.
Y = "level1.precision"
GROUP = "config/nms_iou_threshold"

# Extra columns appended to the hover tooltip after X, Y and GROUP.
# The commented-out assignments are alternative presets; enable exactly one.
# ADDITIONAL_HOVER_DATA = ["config/threshold"]
ADDITIONAL_HOVER_DATA = ["level1.f1", "custom_metrics.super_metric", "config/nms_iou_threshold"]
# ADDITIONAL_HOVER_DATA = ["level1.f1", "config/plane-civilSmall", "config/plane-civilMedium", "config/plane-civilLarge"]

## 3.1 All-trials viz
### 3.1.1 All points

In [None]:
# Preview the first five trials to sanity-check the loaded columns.
all_trials_df.head(5)

In [None]:
# Plot Y against X for every trial, with one coloured line per GROUP value.
# Rows are ordered by GROUP (descending) and then Y (ascending) before plotting.
plot_px_line(
    all_trials_df.sort([GROUP, Y], descending=[True, False]),
    x=X,
    y=Y,
    line_group=GROUP,
    color=GROUP,
    markers=True,
    hover_data=[X, Y, GROUP, *ADDITIONAL_HOVER_DATA],
    x_label="Recall (lvl1)",
    # NOTE(review): the axis label is hard-coded; confirm that Y is a precision column.
    y_label="Precision (lvl1)",
    title="Precision-recall graph",  # typo fixed: was "Precsion-recall graph"
)

### 3.1.2 Non-dominated points

In [None]:
# Reduce the trials with keep_non_dominated_points over the (X, Y) columns.
# Presumably this keeps only the Pareto-optimal trials -- confirm against the helper.
all_trials_nd_df = keep_non_dominated_points(all_trials_df, [X, Y])

# Same plot as for all trials, restricted to the retained points.
# Rows are ordered by GROUP (descending) and then Y (ascending) before plotting.
plot_px_line(
    all_trials_nd_df.sort([GROUP, Y], descending=[True, False]),
    x=X,
    y=Y,
    line_group=GROUP,
    color=GROUP,
    markers=True,
    hover_data=[X, Y, GROUP, *ADDITIONAL_HOVER_DATA],
    x_label="Recall (lvl1)",
    # NOTE(review): the axis label is hard-coded; confirm that Y is a precision column.
    y_label="Precision (lvl1)",
    title="Precision-recall graph",  # typo fixed: was "Precsion-recall graph"
)

## 3.2 Filtering trials with regression for fine-grain viz
### 3.2.1 Filtering operation

In [None]:
# Threshold per metric column, handed to filter_df_with_dict below.
# Presumably each value is a lower bound on its column (per the constant's
# name) -- confirm against the helper.
MIN_VAL_FILTERING_DICT = {
    "level1.recall": 0.9,
    "level1.precision": 0.9,
}

# Derive a new frame from the full trial table; all_trials_df itself is not reassigned.
filtered_df = filter_df_with_dict(
    all_trials_df,
    MIN_VAL_FILTERING_DICT,
)

In [None]:
# Preview the first five trials that survived the filtering step.
filtered_df.head(5)

### 3.2.2 All points

In [None]:
# Plot Y against X for the filtered trials, with one coloured line per GROUP value.
# Rows are ordered by GROUP (descending) and then Y (ascending) before plotting.
plot_px_line(
    filtered_df.sort([GROUP, Y], descending=[True, False]),
    x=X,
    y=Y,
    line_group=GROUP,
    color=GROUP,
    markers=True,
    hover_data=[X, Y, GROUP, *ADDITIONAL_HOVER_DATA],
    x_label="Recall (lvl1)",
    # NOTE(review): the axis label is hard-coded; confirm that Y is a precision column.
    y_label="Precision (lvl1)",
    title="Precision-recall graph",  # typo fixed: was "Precsion-recall graph"
)

### 3.2.3 Non-dominated points

In [None]:
# Reduce the filtered trials with keep_non_dominated_points over the (X, Y) columns.
# Presumably this keeps only the Pareto-optimal trials -- confirm against the helper.
filtered_nd_df = keep_non_dominated_points(filtered_df, [X, Y])

# Same plot as for the filtered trials, restricted to the retained points.
# Rows are ordered by GROUP (descending) and then Y (ascending) before plotting.
plot_px_line(
    filtered_nd_df.sort([GROUP, Y], descending=[True, False]),
    x=X,
    y=Y,
    line_group=GROUP,
    color=GROUP,
    markers=True,
    hover_data=[X, Y, GROUP, *ADDITIONAL_HOVER_DATA],
    x_label="Recall (lvl1)",
    # NOTE(review): the axis label is hard-coded; confirm that Y is a precision column.
    y_label="Precision (lvl1)",
    title="Precision-recall graph",  # typo fixed: was "Precsion-recall graph"
)