## Unit Testing for the Evaluator 

- Test that the `.update()` method works
- Test that the `.finalize()` method works

- Test the individual metrics:
    * MSE
    * Object matching
    * PMM storm structure
    * Spectra


In [1]:
import sys, os
package_path = os.path.dirname(os.path.dirname(os.getcwd())) 
sys.path.insert(0, package_path)

from wofscast.evaluate.metrics import (MSE,
                                       ObjectBasedContingencyStats,
                                       PowerSpectra,
                                       FractionsSkillScore,
                                       PMMStormStructure,
                                       )
import numpy as np
import xarray as xr

### MSE Test

In [2]:
# Test the MSE class with batch size of 2
def test_mse_with_batches():
    """Check that ``MSE`` accumulates over two ``update`` calls.

    Each batch member is passed to ``update`` separately (``isel(batch=i)``
    drops the ``batch`` dimension), mimicking a batched evaluation loop.
    The ``var1_rmse`` returned by ``finalize`` is compared with the RMSE
    of the pooled squared errors of both batch members.
    """
    # Dimensions: ('batch', 'time', 'lat', 'lon')
    batch_size = 2
    time = np.arange(5)  # 5 time steps
    lat = np.arange(2)  # 2 lat points
    lon = np.arange(2)  # 2 lon points

    # One (lat, lon) field per batch member, repeated unchanged for every
    # time step -> arrays of shape (batch, time, lat, lon).
    forecast_fields = np.array([[[1, 2], [3, 4]],
                                [[2, 3], [4, 5]]])
    truth_fields = np.array([[[1, 1], [3, 3]],
                             [[2, 2], [4, 4]]])
    forecast_data = np.repeat(forecast_fields[:, np.newaxis], len(time), axis=1)
    truth_data = np.repeat(truth_fields[:, np.newaxis], len(time), axis=1)

    # Create xarray Datasets for forecast and truth with the 'batch' dimension
    dims = ['batch', 'time', 'lat', 'lon']
    coords = {'batch': np.arange(batch_size), 'time': time, 'lat': lat, 'lon': lon}
    forecast = xr.Dataset({'var1': (dims, forecast_data)}, coords=coords)
    truth = xr.Dataset({'var1': (dims, truth_data)}, coords=coords)

    mse_calculator = MSE()

    # Feed the batch members one at a time, as an evaluation loop would.
    for i in range(batch_size):
        mse_calculator.update(forecast.isel(batch=i), truth.isel(batch=i))

    rmse_results = mse_calculator.finalize()

    print("Final RMSE Results:")
    print(rmse_results)

    # Both batch members contribute the same number of points, so the RMSE
    # over the concatenated errors equals the RMSE over the full 4-D arrays
    # (sqrt(0.5) for this data).
    expected_rmse = np.sqrt(np.mean((forecast_data - truth_data) ** 2))

    np.testing.assert_almost_equal(rmse_results['var1_rmse'].values, expected_rmse, decimal=5)
    print("Test passed: RMSE with batch dimension is correct.")


# Run the test
test_mse_with_batches()


AttributeError: 'MSE' object has no attribute 'results_'

### Object Matching Test

In [None]:
# Test the ObjectBasedContingencyStats class with a larger area (10x10)
def test_object_based_metrics_larger_area():
    """Check object counts and derived scores from ``ObjectBasedContingencyStats``.

    Builds labelled storm fields for two batch members, each identical at
    every time step. The members are fed to ``update`` one at a time. The
    test then checks the hit, false-alarm and miss totals, plus POD, SR,
    CSI and frequency bias (FB) at the first time step.

    Returns
    -------
    forecast, truth : xr.Dataset
        The synthetic datasets, so later cells can plot them.
    """
    # Dimensions: ('batch', 'time', 'lat', 'lon')
    batch_size = 2
    time = np.arange(3)  # 3 time steps
    lat = np.arange(10)  # 10 lat points
    lon = np.arange(10)  # 10 lon points
    # Derive the per-member field shape from the coordinates so the arrays
    # cannot drift out of sync with them.
    field_shape = (len(time), len(lat), len(lon))

    # Each batch member is designed to give, at every time step,
    # 1 hit, 1 false alarm and 2 misses.
    n_expected_hits = len(time) * batch_size
    n_expected_false_alarms = len(time) * batch_size
    n_expected_misses = len(time) * 2 * batch_size

    # Batch 0 (non-zero values are object labels): 1 hit, 1 false alarm, 2 misses.
    forecast_data_batch_0 = np.zeros(field_shape)
    truth_data_batch_0 = np.zeros(field_shape)
    forecast_data_batch_0[:, 1:4, 1:4] = 1  # label 1: hit
    forecast_data_batch_0[:, 6:8, 6:8] = 2  # label 2: false alarm
    truth_data_batch_0[:, 1:4, 1:4] = 1     # label 1: hit
    truth_data_batch_0[:, 7:9, 1:3] = 3     # label 3: miss
    # NOTE(review): label 9 (row 9) is directly adjacent to label 3 (rows 7-8)
    # in the same columns. It only counts as a separate miss if objects are
    # identified by label value rather than by connectivity — confirm.
    truth_data_batch_0[:, 9:, 1:3] = 9      # label 9: miss

    # Batch 1: different layout, same counts (1 hit, 1 false alarm, 2 misses).
    forecast_data_batch_1 = np.zeros(field_shape)
    truth_data_batch_1 = np.zeros(field_shape)
    forecast_data_batch_1[:, 2:5, 2:5] = 4  # label 4: hit
    forecast_data_batch_1[:, 5:7, 7:9] = 5  # label 5: false alarm
    truth_data_batch_1[:, 2:5, 2:5] = 4     # label 4: hit
    truth_data_batch_1[:, 6:8, 1:4] = 6     # label 6: miss
    truth_data_batch_1[:, 9:, 1:3] = 7      # label 7: miss

    # Create xarray Datasets for forecast and truth with the 'batch' dimension
    dims = ['batch', 'time', 'lat', 'lon']
    coords = {'batch': np.arange(batch_size), 'time': time, 'lat': lat, 'lon': lon}
    forecast = xr.Dataset(
        {'storms': (dims, np.stack([forecast_data_batch_0, forecast_data_batch_1]))},
        coords=coords)
    truth = xr.Dataset(
        {'storms': (dims, np.stack([truth_data_batch_0, truth_data_batch_1]))},
        coords=coords)

    obj_stats = ObjectBasedContingencyStats(matching_dist=1)

    # Feed the batch members one at a time, as an evaluation loop would.
    for i in range(batch_size):
        obj_stats.update(forecast.isel(batch=i), truth.isel(batch=i))

    results = obj_stats.finalize()

    print("Final Object-Based Metrics Results:")
    print(results)

    # Raw contingency counts, summed over all remaining dimensions.
    assert results['wofscast_vs_wofs_hits'].sum() == n_expected_hits, f"Number of hits should be {n_expected_hits}"
    assert results['wofscast_vs_wofs_false_alarms'].sum() == n_expected_false_alarms, f"Number of false alarms should be {n_expected_false_alarms}"
    assert results['wofscast_vs_wofs_misses'].sum() == n_expected_misses, f"Number of misses should be {n_expected_misses}"

    # Derived scores at the first time step. The data are identical at every
    # time step, so the per-time ratios equal the ratios of the totals.
    expected_pod = n_expected_hits / (n_expected_hits + n_expected_misses)
    expected_sr = n_expected_hits / (n_expected_hits + n_expected_false_alarms)
    expected_csi = n_expected_hits / (n_expected_hits + n_expected_false_alarms + n_expected_misses)
    # FB = POD / SR = (hits + false alarms) / (hits + misses)
    expected_fb = expected_pod / expected_sr
    expected_scores = {
        'pod': expected_pod,
        'sr': expected_sr,
        'csi': expected_csi,
        'fb': expected_fb,
    }
    for score, expected in expected_scores.items():
        actual = results[f'wofscast_vs_wofs_{score}'].isel(time=0).values
        np.testing.assert_almost_equal(actual, expected, decimal=5,
                                       err_msg=f"{score.upper()} mismatch")

    print("Test passed: Object-based metrics are correct for larger area.")

    return forecast, truth


# Run the test
forecast, truth = test_object_based_metrics_larger_area()

In [None]:
# Truth storm labels: batch member 0, first time step.
(truth.storms
      .isel(time=0, batch=0)
      .plot())

In [None]:
# Forecast storm labels: batch member 1, first time step.
(forecast.storms
         .isel(time=0, batch=1)
         .plot())

### Power Spectra

In [None]:
def test_power_spectra_with_constant_data():
    """Check ``PowerSpectra`` on spatially constant forecast and truth fields.

    Forecast and truth are identical all-ones arrays with dimensions
    ('batch', 'time', 'level', 'lat', 'lon'). A single ``update`` is made
    with batch member 0, and the spectra returned by ``finalize`` are checked
    for flatness and for forecast/truth agreement.
    """
    batch_size = 1
    time_steps = 3
    levels = 1
    lat_points = 10
    lon_points = 10

    dims = ['batch', 'time', 'level', 'lat', 'lon']
    coords = {'batch': np.arange(batch_size), 'time': np.arange(time_steps),
              'level': np.arange(levels), 'lat': np.arange(lat_points),
              'lon': np.arange(lon_points)}
    shape = (batch_size, time_steps, levels, lat_points, lon_points)

    # Identical constant fields, built as separate arrays so the two
    # datasets never share a buffer.
    forecast = xr.Dataset({'var1': (dims, np.ones(shape))}, coords=coords)
    truth = xr.Dataset({'var1': (dims, np.ones(shape))}, coords=coords)

    spectra_calculator = PowerSpectra(grid_spacing_in_km=3.0, variables=['var1'], level=0)

    # Call update for the first batch
    spectra_calculator.update(forecast.isel(batch=0), truth.isel(batch=0))

    results = spectra_calculator.finalize()

    forecast_spectra = results['var1_forecast_spectra'].values
    truth_spectra = results['var1_truth_spectra'].values

    # NOTE(review): "flat" means every entry equals the first entry along the
    # leading axis. A constant field's power is entirely in the
    # zero-wavenumber mode, so this holds only if PowerSpectra removes or
    # excludes the mean — confirm against its implementation.
    assert np.allclose(forecast_spectra, forecast_spectra[0]), "Forecast spectra should be flat"
    assert np.allclose(truth_spectra, truth_spectra[0]), "Truth spectra should be flat"

    # Forecast and truth inputs are identical, so their spectra must match.
    np.testing.assert_allclose(
        forecast_spectra, truth_spectra,
        err_msg="Forecast and truth spectra should match for identical inputs")
    print("Test passed: Constant data spectra are correct.")


# Run the test
test_power_spectra_with_constant_data()
