# Evaluation Examples
Once the model has been trained, we need to do some proper evaluation!

Doing so involves a few steps, the first of which is retrieving our model after we've trained it!

**IMPORTANT: DO NOT run any of the cells below, as they should run on the GPU! The notebook is used purely for organizational purposes!**

## Acquiring Model and Data (post-training)
We do this using MLFlow, so make sure that you have set it up and trained your model in a run.
All of the Python scripts in the `runs/` directory should contain code at the bottom showing how they used MLFlow during training.
Feel free to use them as a reference!

To get our model and data, we'll need to find the ID of the run itself.
This can be found by opening the run in the MLFlow UI and copying the field titled "run ID".

For example, I have a run that trained on CoSMIR-H data to predict temperature, and I know its ID is: `0aded8e7916b43a889ea076604236b13`.
From here, I can run the following code to retrieve the model and the data module used to train it:

```python
import mlflow
from hympi_ml.utils import mlf
from hympi_ml.model import MLPModel
from hympi_ml.data.batches import RawDataModule

# MLFlow setup
# NOTE: this tracking URI may be specific to where YOUR runs are stored! Modify accordingly!
tracking_uri = "/explore/nobackup/people/dgershm1/mlruns"
mlflow.set_tracking_uri(tracking_uri)

# set the ID of the run we want and get the path to its best checkpoint
run_id = "166a8b9bb76840769b63be777a0eae30"
cpath = mlf.get_checkpoint_path(run_id)

# load the model and training data module from the checkpoint
model = MLPModel.load_from_checkpoint(cpath)
raw_data = RawDataModule.load_from_checkpoint(cpath)
```

Quite simple! As you can see above, we first need to define where our runs are stored (specifically, the path to the `mlruns/` directory) and the run ID described earlier.

From there, `mlf.get_checkpoint_path` returns the path to the best-performing checkpoint for that run ID by default. You don't need to track this yourself, since checkpoints are saved automatically after each epoch during model training.

Then, that checkpoint path is used to load both the model used for that run (`MLPModel`) and its data (`RawDataModule`). More details on what these are can be found in the data and model example notebook.
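`mlf.get_checkpoint_path` chooses the best checkpoint for you. The sketch below makes "best performing" concrete. It is a minimal, hypothetical illustration and not how `mlf` is implemented. It picks the checkpoint with the lowest validation loss, assuming files are named like `epoch=1-val_loss=0.0268.ckpt`. PyTorch Lightning's `ModelCheckpoint` produces names of that form when its `filename` template includes `{epoch}` and `{val_loss}`. The helper `pick_best_checkpoint` and the file names are illustrative only.

```python
import re
from pathlib import PurePath

# ASSUMPTION: checkpoint filenames look like "epoch=<int>-val_loss=<float>.ckpt".
# Your checkpoints may be named differently; adjust the pattern accordingly.
CKPT_PATTERN = re.compile(r"epoch=(\d+)-val_loss=([0-9.]+)\.ckpt$")


def pick_best_checkpoint(paths: list[str]) -> str:
    """Hypothetical helper: return the path whose filename encodes the lowest val_loss."""
    scored = []
    for path in paths:
        match = CKPT_PATTERN.search(PurePath(path).name)
        if match is None:
            continue  # skip files that don't follow the assumed naming scheme
        scored.append((float(match.group(2)), path))
    if not scored:
        raise ValueError("no checkpoints matched the expected naming pattern")
    # min() over (val_loss, path) tuples selects the lowest validation loss
    return min(scored)[1]


checkpoints = [
    "checkpoints/epoch=0-val_loss=0.0412.ckpt",
    "checkpoints/epoch=1-val_loss=0.0268.ckpt",
    "checkpoints/epoch=2-val_loss=0.0291.ckpt",
]
print(pick_best_checkpoint(checkpoints))
```

Here the epoch 1 checkpoint is selected even though it is not the most recent. That is the point of keeping the "best" checkpoint rather than the last one.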

## Defining New Test Metrics
Now that we have the model and data, we can do pretty much whatever we want, but we'll start by calculating some new metrics:

```python
from torchmetrics import MetricCollection
from torchmetrics.regression import MeanAbsoluteError

model.log_metrics = False  # important to avoid logging vector (non-scalar) metrics
model.unscale_metrics = True  # automatically unscales values before computing metrics

# redefine the test metrics for the model
model.test_metrics = {
    "TEMPERATURE": MetricCollection(
        {
            "mae_profile": MeanAbsoluteError(num_outputs=72),
        }
    )
}
```

In the above code, we do a bit of setup so that our test metrics are calculated in the proper manner. Unlike train or validation metrics, test metrics can be redefined after training, so we're free to swap in new ones for further analysis. We also want to unscale the values passed to the metrics if our targets are scaled. In this case, the temperature target is scaled using a MinMaxScaler over the range 175 to 325 K, so metrics computed on the scaled values aren't very useful to us directly. It's therefore a good idea to tell the model to unscale the metrics by setting the `unscale_metrics` boolean to `True`; it's `False` by default for training and validation.
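Here is a quick worked example of why unscaling matters. It is a minimal sketch that assumes the MinMaxScaler maps the 175 to 325 K range onto [0, 1]. Min-max scaling is linear, so an absolute error in kelvin is simply the scaled error multiplied by the range, 150 K. The functions `scale`/`unscale` and the sample temperatures are made up for illustration.

```python
# ASSUMPTION: temperature is min-max scaled from [175 K, 325 K] onto [0, 1].
T_MIN, T_MAX = 175.0, 325.0


def scale(temp_k: float) -> float:
    """Min-max scale a temperature in kelvin into [0, 1]."""
    return (temp_k - T_MIN) / (T_MAX - T_MIN)


def unscale(scaled: float) -> float:
    """Invert the min-max scaling back to kelvin."""
    return scaled * (T_MAX - T_MIN) + T_MIN


# Illustrative truth and predictions in kelvin
true_k = [250.0, 280.0, 300.0]
pred_k = [251.5, 277.0, 300.0]

true_scaled = [scale(t) for t in true_k]
pred_scaled = [scale(p) for p in pred_k]

# MAE on the scaled values (what the model sees) vs. on unscaled values (kelvin)
mae_scaled = sum(abs(t - p) for t, p in zip(true_scaled, pred_scaled)) / len(true_k)
mae_kelvin = sum(abs(unscale(t) - unscale(p)) for t, p in zip(true_scaled, pred_scaled)) / len(true_k)

print(f"MAE (scaled): {mae_scaled:.4f}")
print(f"MAE (kelvin): {mae_kelvin:.4f}")
```

A scaled MAE of 0.01 corresponds to 1.5 K here, which is far easier to interpret physically.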

Note: We also disable metric logging, as we will be handling this ourselves. This is because our new metric, `mae_profile`, is a vector metric of 72 values, and MLFlow and PyTorch Lightning only accept scalar metrics for logging.
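The sketch below shows why such a metric can't be logged directly. It is a small pure-Python per-level MAE, the same idea as `MeanAbsoluteError(num_outputs=72)` shrunk to 4 levels. It also shows two hypothetical ways to turn the vector into scalars: averaging it, or splitting it into one scalar per level. The `mae_per_level` helper and the `mae_profile_level_XX` key names are illustrative choices, not project conventions.

```python
def mae_per_level(true_rows, pred_rows):
    """Mean absolute error computed separately for each vertical level (column)."""
    n_samples = len(true_rows)
    n_levels = len(true_rows[0])
    totals = [0.0] * n_levels
    for t_row, p_row in zip(true_rows, pred_rows):
        for level, (t, p) in enumerate(zip(t_row, p_row)):
            totals[level] += abs(t - p)
    return [total / n_samples for total in totals]


# Toy data: 2 samples x 4 levels (the real temperature profile has 72 levels)
true_rows = [[1.0, 2.0, 3.0, 4.0], [2.0, 2.0, 2.0, 2.0]]
pred_rows = [[1.5, 2.0, 2.0, 4.0], [2.0, 3.0, 2.0, 1.0]]

profile = mae_per_level(true_rows, pred_rows)  # a vector: one value per level

# Option 1: collapse the vector into a single scalar (loses per-level detail)
mean_mae = sum(profile) / len(profile)

# Option 2: one scalar per level, each under its own (hypothetical) key
scalar_metrics = {f"mae_profile_level_{i:02d}": value for i, value in enumerate(profile)}

print(profile)
print(mean_mae)
print(scalar_metrics)
```

A dictionary of scalars like `scalar_metrics` is the shape that `mlflow.log_metrics` accepts. For 72 levels, though, plotting the profile as a figure is often easier to read.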

## Running the Test
Now, we simply create a PyTorch Lightning trainer and use its `test(...)` method to compute our metrics on the test dataset defined in our data module!

```python
import lightning as L

# set up the trainer and run the test!
trainer = L.Trainer(enable_progress_bar=True)
trainer.test(model, dataloaders=raw_data.test_dataloader())
```

Since we've disabled logging, running the test won't log anything once it completes. Its only effect is to update the values accumulated in the metrics we defined above.
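Running the test "only" updates the metrics because torchmetrics metrics are stateful. `update()` accumulates statistics batch by batch, and `compute()` produces the final value over everything seen so far. The class below is a simplified, hypothetical stand-in for that pattern, not the torchmetrics implementation. It shows why the result covers the whole test set rather than averaging per-batch means.

```python
class RunningMAE:
    """Simplified stand-in for a stateful metric (NOT the torchmetrics API itself)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # clear accumulated state
        self.abs_error_sum = 0.0
        self.count = 0

    def update(self, preds, targets):
        # accumulate the sum of absolute errors and the number of samples seen
        for p, t in zip(preds, targets):
            self.abs_error_sum += abs(p - t)
            self.count += 1

    def compute(self):
        # final value over ALL samples seen since the last reset
        if self.count == 0:
            raise ValueError("compute() called before any update()")
        return self.abs_error_sum / self.count


metric = RunningMAE()
batches = [([1.0, 2.0], [1.5, 2.0]), ([3.0], [2.0])]  # (preds, targets) per batch
for preds, targets in batches:
    metric.update(preds, targets)

print(metric.compute())
```

The batch means here are 0.25 and 1.0. Their average (0.625) differs from the dataset-wide MAE of 0.5 that `compute()` returns. torchmetrics metrics also provide a `reset()` method for clearing accumulated state. Whether you need to call it before re-running a test depends on how the model manages its metrics.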

## Saving Figures for Profile Metrics
Since nothing is logged automatically, we'll need to compute and extract the values from each metric ourselves if we'd like to, for example, log a figure.

The following code does just that for the "mae_profile" temperature metric:

```python
import matplotlib.pyplot as plt
import mlflow
from hympi_ml.evaluation import figs

# NOTE: The code below assumes that all new metrics can be turned into figures based on sigma levels!

# Iterate through the metric collections in our model's test metrics
for target_name, collection in model.test_metrics.items():
    computed_metrics = collection.compute()  # compute the metrics in the collection

    # Iterate through each computed metric in this target's collection
    for metric_name, metric in computed_metrics.items():
        # Use the figure style defined in 'figs' to plot the value from the metric
        figs.plot_profiles({metric_name: metric.cpu()}, value_axis=metric_name)

        # Save the figure to a temporary path
        local_path = f"/tmp/{target_name}_{metric_name}.png"
        plt.savefig(local_path)

        # Close the current figure so figures don't pile up across iterations
        plt.close()

        # Log the temporarily saved figure as an artifact in the run defined above (from run_id)
        mlflow.log_artifact(local_path, "metric_figures", run_id=run_id)
```

With the above code, we can take the metrics we've computed, plot them, and log the figures directly to our MLFlow run for analysis!

## Working with Extras: Filtering
Every `ModelDataSpec` has room to define an optional set of 'extras', that is, extra information passed alongside the features and targets that is meant for exactly what we're doing here: evaluation! One way to use them is to filter our test dataset so that the tests above run on a smaller subset of the data.

Since filters defined on the extras are applied to the features and targets as well, we can use this to our advantage to limit our dataset, even if our model didn't use the extras during training!

Here's how we can do it for a PBLH filter:

```python
from hympi_ml.data import NRSpec
from hympi_ml.data.filter import SimpleRangeFilter

extras = {
    "PBLH": NRSpec(
        dataset="PBLH",
        filter=SimpleRangeFilter(minimum=200, maximum=500),
    ),
}

raw_data.spec.extras = extras
model.spec.extras = extras
```

It's quite simple! We just define a new set of extras as a familiar dictionary of `DataSpec`s, making sure to set the filter for the `PBLH` data spec.
This time, it's a simple range filter from 200 to 500 meters.

After that, we set the extras on *both* the data module's spec and the model's spec. This way, the filter actually gets applied when we use the trainer to test again.

Now, if we re-run our test, we'll be checking the Mean Absolute Error across the temperature profile, but only for samples with relatively low PBLH between 200 and 500 meters! The great thing is that you can easily define more extras beyond just PBLH and combine multiple extras to filter for specific locations or atmospheric conditions!
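Conceptually, each filter on an extra produces a keep/drop mask over the samples, and that same mask is applied to the features, targets, and extras. The sketch below illustrates this with plain Python lists, combining several range filters with a logical AND. `RangeFilter` and `apply_extra_filters` are hypothetical helpers, not `hympi_ml`'s `SimpleRangeFilter`. The inclusive bounds are also an assumption.

```python
from dataclasses import dataclass


@dataclass
class RangeFilter:
    """Hypothetical range filter; bounds are ASSUMED to be inclusive."""
    minimum: float
    maximum: float

    def mask(self, values):
        return [self.minimum <= v <= self.maximum for v in values]


def apply_extra_filters(features, targets, extras, filters):
    """Keep only samples that pass every filter (logical AND across extras)."""
    keep = [True] * len(features)
    for name, extra_filter in filters.items():
        keep = [k and m for k, m in zip(keep, extra_filter.mask(extras[name]))]
    kept = [i for i, k in enumerate(keep) if k]
    # the same indices are applied to features, targets, and every extra
    return (
        [features[i] for i in kept],
        [targets[i] for i in kept],
        {name: [values[i] for i in kept] for name, values in extras.items()},
    )


toy_features = ["f0", "f1", "f2", "f3", "f4"]
toy_targets = ["t0", "t1", "t2", "t3", "t4"]
toy_extras = {
    "PBLH": [150.0, 250.0, 450.0, 600.0, 300.0],  # meters
    "LATITUDE": [10.0, 35.0, 40.0, 45.0, 60.0],  # degrees
}
filters = {
    "PBLH": RangeFilter(minimum=200, maximum=500),
    "LATITUDE": RangeFilter(minimum=30, maximum=50),
}

kept_features, kept_targets, kept_extras = apply_extra_filters(
    toy_features, toy_targets, toy_extras, filters
)
print(kept_features, kept_targets, kept_extras)
```

Only samples 1 and 2 pass both the PBLH and latitude ranges. This is how stacking extras narrows the test set to specific conditions.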

## Working with Extras: Iterative Analysis
What if you'd like to iterate through each of the values in the test dataset and perform some kind of metric or other analysis calculation?

For example, say you'd like to find the error for each batch so that you can map it onto a geographic plot of Earth using latitude and longitude.

Let's see an example of how you would set it up (just before actually mapping it):

```python
import mlflow
import torch

from hympi_ml.data import NRSpec
from hympi_ml.data.filter import ResolutionFilter
from hympi_ml.utils import mlf
from hympi_ml.model import MLPModel
from hympi_ml.data.batches import RawDataModule

# MLFlow setup
# NOTE: this tracking URI may be specific to where YOUR runs are stored! Modify accordingly!
tracking_uri = "/explore/nobackup/people/dgershm1/mlruns"
mlflow.set_tracking_uri(tracking_uri)

# set the ID of the run we want and get the path to its best checkpoint
run_id = "35056b02f1da4a6ca7e01d83f248a7bb"
cpath = mlf.get_checkpoint_path(run_id)

# load the model and training data module from the checkpoint
model = MLPModel.load_from_checkpoint(cpath)
raw_data = RawDataModule.load_from_checkpoint(cpath)

# put the model in evaluation mode (disables dropout, etc.)
model.eval()

# Make the extras with latitude and longitude sampled at a lower resolution of 0.01 degrees
extras = {
    "LATITUDE": NRSpec(
        dataset="LATITUDE",
        filter=ResolutionFilter(resolution=0.01),
    ),
    "LONGITUDE": NRSpec(
        dataset="LONGITUDE",
        filter=ResolutionFilter(resolution=0.01),
    ),
}

raw_data.spec.extras = extras
model.spec.extras = extras

# gradients aren't needed for inference
with torch.no_grad():
    for raw_batch in raw_data.test_dataloader():
        # transform the loaded raw batch data (filters and scales)
        features, targets, batch_extras = raw_data.spec.transform_batch(raw_batch)

        # run the model, then unscale both the truth and the predictions
        unscaled_true = raw_data.spec.unscale_targets(targets)
        pred = model(features)
        unscaled_pred = raw_data.spec.unscale_targets(pred)

        pblh_error = unscaled_true["PBLH"] - unscaled_pred["PBLH"]

        lat = batch_extras["LATITUDE"]
        lon = batch_extras["LONGITUDE"]

        # We have the error, latitude, and longitude for this entire batch; now we can plot it on a map or whatever we'd like!
        # How you do this is up to you!
```

As you can see, the setup is the same as described above, except that we don't use a PyTorch Lightning `Trainer` to run the test; instead, we iterate through each batch ourselves. This means we need to transform each batch (apply filtering and scaling to the raw data), run it through the model, *and* unscale the truth and predictions to calculate the error. The trainer handles this automatically when computing metrics, but when iterating directly like we are here, we need to do it manually.

In exchange, we have full control over what data is passed around, allowing us to do all sorts of things, such as placing the error on a map.
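One hypothetical next step is to average the per-sample errors into latitude/longitude grid cells across all batches before plotting. The sketch below does this in plain Python with 1-degree cells, computing the mean (signed) error per cell. `accumulate_grid` and `grid_means` are illustrative names, not part of `hympi_ml`. In the loop above, you'd first convert each batch's tensors to plain lists, e.g. with `.flatten().tolist()`.

```python
import math
from collections import defaultdict


def accumulate_grid(lats, lons, errors, sums, counts, cell_deg=1.0):
    """Add one batch of errors to running per-cell sums and counts."""
    for lat, lon, err in zip(lats, lons, errors):
        # snap each point to the south-west corner of its grid cell
        cell = (math.floor(lat / cell_deg) * cell_deg, math.floor(lon / cell_deg) * cell_deg)
        sums[cell] += err
        counts[cell] += 1


def grid_means(sums, counts):
    """Mean (signed) error per grid cell."""
    return {cell: sums[cell] / counts[cell] for cell in sums}


sums = defaultdict(float)
counts = defaultdict(int)

# Toy batches of (lats, lons, errors); real values would come from the test loop
batches = [
    ([10.2, 10.7], [20.1, 20.9], [2.0, 4.0]),
    ([-0.5], [179.9], [-1.0]),
]
for lats, lons, errors in batches:
    accumulate_grid(lats, lons, errors, sums, counts)

means = grid_means(sums, counts)
print(means)
```

The running sums and counts persist across batches. As a result, each cell's mean reflects every sample that fell into it, not just the last batch.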