# Demo notebook: evaluating a pre-trained ML model for tropical cyclone (TC) detection & tracking

First, import the libraries needed to run inference and define the common configuration (paths and the lat/lon domain).

In [None]:
import os
import sys
import warnings
warnings.filterwarnings('ignore')

sys.path.append('../resources/library')
from tropical_cyclone.inference import SingleModelInference, get_observations, get_observed_tracks
from tropical_cyclone.visualize import plot_detections, plot_tracks
from tropical_cyclone.models import *
from tropical_cyclone.mlflow import load_model_from_mlflow, set_mlflow_endpoint, check_backend

In [None]:
# Repository root; the project paths below are built from it with os.path.join,
# so a trailing slash here does not produce "//" in the derived paths.
main_dir = '/home/jovyan/work/ml-tropical-cyclones-detection/'

# Path to the dataset directory (CMIP6 data, if used, must already be in the proper grid format).
dataset_dir = '/home/jovyan/data/ERA5_Cyclones'
# Path to the IBTrACS file used to match the ML model detections.
ibtracs_src = os.path.join(main_dir, 'data', 'ibtracs', 'filtered',
                           'ibtracs_main-tracks_6h_1980-2021_TS-NR-ET-MX-SS-DS.csv')
# Path to the configuration file for the model.
config_file = os.path.join(main_dir, 'src', 'config', 'cnns.toml')

# Latitude / longitude window of the study domain (presumably degrees; longitudes up to
# 320 suggest a 0-360 convention -- confirm against the dataset).
lat_range = (0, 70)
lon_range = (100, 320)

Select the model by entering its MLflow run name when prompted; this downloads the model, the scaler and the provenance document.

In [None]:
# Point the MLflow client at the tracking server defined in the config file.
set_mlflow_endpoint(config_file)

# Prompt for the MLflow run name; stray whitespace from copy/paste is stripped and an
# empty answer is rejected here rather than failing obscurely inside the MLflow lookup.
run_name = input('MLflow run name: ').strip()
if not run_name:
    raise ValueError('An MLflow run name is required to load the model.')

# provenance=True presumably also downloads the provenance document; the second return
# value is presumably the local artifact path -- confirm in tropical_cyclone.mlflow.
registered_model, path = load_model_from_mlflow(run_name, provenance=True)
registered_model

## Inference Workflow

Let's create the ML model object, then choose the time frame for the evaluation (a year and a month, or ALL months) with the widgets below. Make the selection before running the data-loading cell.

In [None]:
# Pick the compute backend available on this machine, then wrap the downloaded model
# together with its configuration into an inference helper.
device = check_backend()
inference = SingleModelInference(
    model=registered_model,
    config_file=config_file,
    device=device,
)

In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# Month selector: values are zero-padded month strings; the 'ALL' entry maps to None.
month = widgets.Dropdown(
    options=[('Jan', '01'), ('Feb', '02'), ('Mar', '03'),
             ('Apr', '04'), ('May', '05'), ('Jun', '06'),
             ('Jul', '07'), ('Aug', '08'), ('Sep', '09'),
             ('Oct', '10'), ('Nov', '11'), ('Dec', '12'),
             ('ALL', None)
            ], value=None, description='Month:', disabled=False,)

# Year selector; its 1980-2021 range matches the span of the IBTrACS file configured above.
year = widgets.IntSlider(
    value=2014, min=1980, max=2021, step=1,
    description='Year:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)

# Seed the selection from the widgets themselves so `y`/`m` always have the same type and
# value the UI shows (m is a month string or None for 'ALL', never a bare int).
y = year.value
m = month.value

ui = widgets.HBox([year, month])

def f(a, b):
    """Record the widget selection: `a` -> global `y` (year), `b` -> global `m` (month string or None)."""
    global y, m
    y = a
    m = b

# Call `f` whenever either widget changes so `y`/`m` track the UI.
out = widgets.interactive_output(f, {'a': year, 'b': month})

# NOTE: choose the year/month before running the next cell; under "Run All" the
# defaults shown here (2014, ALL) are used.
display(ui)

In [None]:
# Load the data for the year/month selected above (m is None when 'ALL' is chosen,
# which presumably loads every month of the year -- confirm in load_dataset).
ds, dates = inference.load_dataset(
    dataset_dir=dataset_dir,
    year=y,
    month=m,
)

We can now detect and localize the TC centers with the ML model.

In [None]:
# Size of the patches the model is applied to (units not visible here -- presumably
# grid cells; confirm against the model's training configuration).
PATCH_SIZE = 40
detections = inference.predict(ds, patch_size=PATCH_SIZE)

Then load the observed TCs (from IBTrACS) for the same dates, to compare against the detections.

In [None]:
# Observed TC positions from the IBTrACS file for the loaded dates and the lat/lon window.
observations = get_observations(
    ibtracs_src=ibtracs_src,
    dates=dates,
    lat_range=lat_range,
    lon_range=lon_range,
)

## Apply Tracking Algorithm
Apply the tracking algorithm to link the detected TC centers over time into tracks; the observed TCs are grouped into tracks as well for comparison.

In [None]:
# Tracking parameters. Units are not visible here: max_distance is presumably the largest
# allowed jump between linked centers (likely km) and min_track_count the minimum number
# of linked points for a track to be kept -- confirm in SingleModelInference.tracking.
MAX_LINK_DISTANCE = 400.0
MIN_TRACK_COUNT = 12
det_tracks = inference.tracking(
    detections,
    max_distance=MAX_LINK_DISTANCE,
    min_track_count=MIN_TRACK_COUNT,
)
# Group the observed IBTrACS points into tracks for comparison with the detected ones.
obs_tracks = get_observed_tracks(observations)

## Plot detections

In [None]:
# Map of model detections against the observed TC positions over the study domain.
plot_detections(
    detections,
    observations,
    lat_range,
    lon_range,
)

## Plot tracks

In [None]:
# Map of detected tracks against the observed IBTrACS tracks over the study domain.
plot_tracks(
    det_tracks,
    obs_tracks,
    lat_range,
    lon_range,
)