In [14]:
import os
import tempfile
from pathlib import Path

import mlflow
import pandas as pd
import torch
from azureml.fsspec import AzureMachineLearningFileSystem
from torch import nn

from src.azure_wrap.blob_storage_sdk_v2 import DATASTORE_URI
from src.training.loss_functions import TwoPartLoss
from src.training.transformations import ConcatenateSnapshots
from src.utils.parameters import MAIN_BANDS, S2_BANDS, SNAPSHOTS, TEMPORAL_BANDS, SatelliteID
from src.validation.metrics import FalseMetrics, TrueMetrics
from src.validation.validation_metrics import (
    all_error_analysis_plots,
    data_preparation,
    diff_plots,
    prep_predictions_for_plot,
)

### Load the run parameters and metrics

In [15]:
# The run ID (job name) can be copy-pasted from the job's page in the Azure ML studio.
run_id = "mango_chaconia_xshmtkk1yx"

# Fetch the parameters that were logged for this run.
run = mlflow.get_run(run_id=run_id)
params = run.data.params

BINARY_THRESHOLD = float(params["binary_threshold"])
MSE_MULTIPLIER = float(params["MSE_multiplier"])
# NOTE(review): passed to `mlflow.pytorch.load_model` below, which expects a model URI
# (e.g. "runs:/..." or "models:/..."), not a bare name — confirm what this param holds.
model_identifier = params["model_name"]

# Note: the logged validation dataset params (like the training / test ones) are not the actual
# URIs of the data. AML downloads the data to disk for the job, so the paths get converted to
# local paths; the actual URIs can be recovered from the training config if needed.
# Keys are sorted so that "the first dataset" below does not depend on the order in which
# MLflow happens to return the params.
validation_datasets = [params[key] for key in sorted(params) if "validation_dataset" in key]
if not validation_datasets:
    raise ValueError(f"Run {run_id!r} has no logged parameter containing 'validation_dataset'")

# Grab the parent folder of the first dataset so we can glob all the validation data.
# NOTE(review): assumes every validation dataset lives in that same parent folder, and that the
# value is a plain path — pathlib collapses the "//" of a URI scheme such as "azureml://".
# It is later read through the AML datastore filesystem, which seems at odds with the
# "local path" note above — confirm.
validation_uri = Path(validation_datasets[0]).parent.as_posix()

In [None]:
# Pull the per-crop metrics artifact of the run into a throwaway directory and load it.
# `pd.read_parquet` reads the whole file eagerly, so the frame outlives the directory cleanup.
artifact_name = "metrics_per_crop.parquet"
with tempfile.TemporaryDirectory() as scratch_dir:
    mlflow.artifacts.download_artifacts(
        run_id=run_id,
        artifact_path=artifact_name,
        dst_path=scratch_dir,
    )
    metrics_df = pd.read_parquet(Path(scratch_dir) / artifact_name)

metrics_df

In [17]:
# Filesystem handle on the Azure ML datastore, used to read the validation data.
fs = AzureMachineLearningFileSystem(DATASTORE_URI)

# Loss configured with the thresholds logged for this run; handed to the plotting helper later.
lossFn = TwoPartLoss(binary_threshold=BINARY_THRESHOLD, MSE_multiplier=MSE_MULTIPLIER)

# Band-concatenation transform for the S2 satellite, built from the project-wide band settings.
band_concatenator = ConcatenateSnapshots(
    satellite_id=SatelliteID.S2,
    snapshots=SNAPSHOTS,
    s2_bands=S2_BANDS,
    temporal_bands=TEMPORAL_BANDS,
    main_bands=MAIN_BANDS,
)

validation_dataset = data_preparation(
    validation_uri,
    band_concatenator=band_concatenator,
    filesystem=fs,
)

### Load the model

In [None]:
# Load the trained model on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = mlflow.pytorch.load_model(model_identifier, map_location=device)

# A model logged while wrapped in (Distributed)DataParallel comes back wrapped; both wrappers
# expose the underlying network as `.module`, so unwrap it.
if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
    model = model.module

# The loaded module keeps whatever train/eval mode it was logged in. Switch to eval mode so
# layers such as dropout / batch-norm (if present) behave as they do at inference time.
model = model.to(device).eval()

### Plot the worst-performing crops

In [None]:
# Every accepted value for `sorting_metric` in the plotting cell: TrueMetrics first, then FalseMetrics.
print("Possible metrics to sort by")
[metric.value for metric in (*TrueMetrics, *FalseMetrics)]

In [None]:
# Let's plot some crops (chips): the `num_worst_crops` worst ones according to `sorting_metric`.
num_worst_crops = 10
probability_threshold = 0.25
sorting_metric = "false_negatives"  # one of the values listed by the previous cell

# FalseMetrics are sorted descending and all other metrics ascending, so the worst crops come first.
# Compare against the enum *values*: `sorting_metric` is a plain string, and a plain Enum member
# never equals its string value, which would silently sort every metric ascending.
false_metric_values = {metric.value for metric in FalseMetrics}
ascending = sorting_metric not in false_metric_values

# New name instead of reassigning `metrics_df`, so the frame displayed earlier stays as loaded
# and this cell can be re-run with a different metric. `.head` also copes with fewer rows
# than `num_worst_crops`.
worst_crops_df = metrics_df.sort_values(by=sorting_metric, ascending=ascending).head(num_worst_crops)

# Read the key columns directly rather than via a row Series: a row taken with `.iloc` is upcast
# to a common dtype (e.g. an integer `row` becomes a float when every column is numeric).
for index in zip(worst_crops_df["partition"], worst_crops_df["row"]):
    pred = prep_predictions_for_plot(model, validation_dataset.dataset, index, lossFn, probability_threshold)
    # NOTE(review): the figures are neither displayed explicitly nor closed here. If these helpers
    # return matplotlib figures, the inline backend shows them at the end of the cell; consider
    # closing them to bound memory. If they are plotly figures they will not be shown — confirm.
    fig = all_error_analysis_plots(probability_threshold=probability_threshold, **pred)
    fig_simple = diff_plots(probability_threshold=probability_threshold, **pred)