In [28]:
import xarray as xr
import xarray
import dask 
import os

In [2]:
# Rather than a strict MSE, for rainfall, reflectivity, and UH we want to penalize specific behavior.

In [3]:
%%time
in_path = '/work/mflora/wofs-cast-data/train_datasets'
inputs = xr.open_dataset(os.path.join(in_path, 'train_inputs.nc'), chunks={})
inputs = inputs.isel(batch=slice(0, 2))
inputs = dask.compute(inputs)

CPU times: user 17.6 s, sys: 7.48 s, total: 25.1 s
Wall time: 1min 20s


In [6]:
# dask.compute returned a tuple; unwrap the single computed result.
# The isinstance guard keeps this cell idempotent: re-running it after
# `inputs` has already been unwrapped is a no-op instead of an error.
if isinstance(inputs, tuple):
    inputs = inputs[0]

In [11]:
# Variables pulled out of `inputs` for the loss experiment. The last three
# have dedicated threshold settings in `custom_loss`; the rest use plain MSE.
predict_vars = [
    'U',
    'V',
    'W',
    'T',
    'T2',
    'COMPOSITE_REFL_10CM',
    'UP_HELI_MAX',
    'RAINNC',
]

In [17]:
# Stand-in data for exercising the loss: time index 0 plays the role of the
# prediction and time index 1 the role of the target.
prediction, target = (
    inputs[predict_vars].isel(time=step) for step in (0, 1)
)

In [37]:
def _mean_preserving_batch(x: xarray.DataArray) -> xarray.DataArray:
    """Average ``x`` over every dimension except ``'batch'``.

    NaNs are not skipped (``skipna=False``), so a single NaN along the
    reduced dimensions makes the corresponding result NaN.
    """
    reduce_dims = [dim for dim in x.dims if dim != 'batch']
    return x.mean(reduce_dims, skipna=False)


def threshold_tuned_loss(target, prediction, underpredict_threshold, overpredict_threshold, 
                underpredict_penalty, overpredict_penalty):
    """
    Element-wise squared error with extra weight on two specific failure modes.

    The squared error is multiplied by ``underpredict_penalty`` where the target
    exceeds ``underpredict_threshold`` and the prediction is too low, and by
    ``overpredict_penalty`` where the target is below ``overpredict_threshold``
    and the prediction is too high. Everywhere else the plain squared error is
    returned.

    Parameters:
    - target: xarray.DataArray representing the true target values.
    - prediction: xarray.DataArray representing the predicted values.
    - underpredict_threshold: Threshold above which underpredictions are penalized.
    - overpredict_threshold: Threshold below which overpredictions are penalized.
    - underpredict_penalty: Penalty multiplier for underpredictions.
    - overpredict_penalty: Penalty multiplier for overpredictions.

    Returns:
    - loss: The element-wise (unreduced) loss, same layout as ``prediction - target``.
    """
    # Calculate basic error
    error = prediction - target
    squared_error = error ** 2

    # Identify where to apply each penalty. The two masks can never both be
    # True for the same element because one needs error < 0 and the other
    # error > 0.
    underpredict_mask = (target > underpredict_threshold) & (error < 0)
    overpredict_mask = (target < overpredict_threshold) & (error > 0)

    # The earlier implementation summed `error.where(under)` and
    # `error.where(over)`; since `.where` fills non-selected cells with NaN and
    # the masks are disjoint, that sum was NaN at every penalized cell. It also
    # used the signed error, which would have made underpredictions *reduce*
    # the loss. Scaling the squared error avoids both problems.
    # `.where(cond, other)` keeps the value where cond is True, else uses other.
    loss = squared_error.where(~underpredict_mask, squared_error * underpredict_penalty)
    loss = loss.where(~overpredict_mask, squared_error * overpredict_penalty)

    return loss

def custom_loss(predictions, targets): 
    """
    Compute a plain MSE for most variables and a threshold-tuned loss for a few.

    Variables listed in ``custom_loss_params`` are scored with
    ``threshold_tuned_loss``; every other data variable in ``targets`` is
    scored with squared error. Both parts are reduced with
    ``_mean_preserving_batch`` (mean over every dimension except 'batch').

    Parameters:
    - predictions: dataset of predicted fields, indexable by variable name.
    - targets: dataset of true fields with the same variables as ``predictions``.

    Returns:
    - mse_loss: reduced squared error for the variables without custom settings.
    - custom_loss_total: sum of the reduced threshold-tuned losses over the
      configured variables present in ``targets`` (``0`` if none are present).
    """
    custom_loss_params = {'COMPOSITE_REFL_10CM': {'underpredict_threshold' : 30., # dBZ 
                                                   'overpredict_threshold' : 15., # dBZ
                                                   'underpredict_penalty'  : 10, 
                                                   'overpredict_penalty' : 10, 
                                                  }, 
                           'UP_HELI_MAX':     {'underpredict_threshold' : 50., # UH units
                                                   'overpredict_threshold' : 5., # UH units
                                                   'underpredict_penalty'  : 10, 
                                                   'overpredict_penalty' : 10, 
                                                  },
                           'RAINNC':         {'underpredict_threshold' : 5.0, # mm 
                                                   'overpredict_threshold' : 0.25, # mm
                                                   'underpredict_penalty'  : 10, 
                                                   'overpredict_penalty' : 10, 
                                                  },
                   
                         }

    all_data_vars = list(targets.data_vars)

    # Anything without a custom configuration falls back to plain MSE.
    standard_vars = [item for item in all_data_vars if item not in custom_loss_params]

    # Use the function arguments here; the previous version read the notebook
    # globals `prediction` / `target`, silently ignoring what the caller passed.
    mse_loss = _mean_preserving_batch((predictions[standard_vars] - targets[standard_vars])**2)

    custom_loss_total = 0
    for var, params in custom_loss_params.items():
        # Skip configured variables the datasets do not contain instead of
        # failing with a KeyError.
        if var not in all_data_vars:
            continue
        custom_loss_total += _mean_preserving_batch(
            threshold_tuned_loss(targets[var], predictions[var], **params)
        )

    return mse_loss, custom_loss_total
    

In [38]:
# Store the result under a name other than `custom_loss`: assigning to
# `custom_loss` would replace the function, so this cell would fail when re-run.
mse_loss, custom_loss_total = custom_loss(prediction, target)