In [None]:
import xarray as xr
import numpy as np

In [None]:
# Directories holding the speckled input file and the cleaned output file.
input_dir = '/g/data/er4/vd5822/files_to_check/mrnbc_speckling/input/'
output_dir = '/g/data/er4/vd5822/files_to_check/mrnbc_speckling/output/'

# The same file name is used in both directories. Swap the commented line in
# to process the tasmin file instead of the pr file.
file_name_base = 'pr_mrnbc_CNRM-CM5_rcp85.nc4'
#file_name_base = 'tasmin_mrnbc_CNRM-CM5_rcp85.nc4'

# Both directory strings end with '/', so plain concatenation gives the path.
input_file = f'{input_dir}{file_name_base}'
output_file = f'{output_dir}{file_name_base}'

In [None]:
# Open the input netCDF file as an xarray Dataset.
ds = xr.open_dataset(filename_or_obj=input_file)

In [None]:
# Variable to clean and the cut-off used to flag a speck. The two values
# belong together and should match the file selected in file_name_base.
# NOTE(review): 0.009259 looks like 800 mm/day expressed in kg m-2 s-1
# (800 / 86400) -- confirm against the units in the file.
var_name, threshold = 'pr', 0.009259
# Known speckled grid cells for pr CNRM-CM5, as (time,lat,lon) indices:
#   0,6,145
#   0,455,675
#   0,456,675

# Alternative settings for the tasmin file.
# NOTE(review): 333.1 is presumably in kelvin (about 60 degC) -- confirm.
#var_name, threshold = 'tasmin', 333.1
# NOTE(review): the list below was labelled "pr CNRM-CM5" and is identical
# to the pr list above, so it looks copy-pasted -- confirm the actual
# tasmin cells (time,lat,lon):
#   0,6,145
#   0,455,675
#   0,456,675

# This location (0,417,260) is already NaN?

In [None]:
# Remove the specks.
# Only the variable named by var_name is thresholded: the cut-off is specific
# to that variable, so any other data variables in the file (for example
# coordinate bounds) are carried over untouched instead of being masked too.
# where() keeps everything below the threshold; anything equal to or above
# it becomes NaN. Cells that are already NaN in the input stay NaN.
specks_removed = ds.assign(
    {var_name: ds[var_name].where(ds[var_name] < threshold)}
)

In [None]:
# Now that the specks have been replaced with NaNs, interpolate_na() can
# fill them in from the surrounding cells.
# interpolate_na() only works along a single dimension, so to approximate
# bilinear filtering we interpolate along lat and along lon separately and
# average the two results.
# keep_attrs=True makes sure attributes (units etc.) survive the arithmetic,
# which can otherwise drop them depending on xarray's keep_attrs setting,
# so they are still present when the result is written out.
with xr.set_options(keep_attrs=True):
    interpolated_on_lat = specks_removed.interpolate_na(dim='lat')
    interpolated_on_lon = specks_removed.interpolate_na(dim='lon')
    interpolated = (interpolated_on_lat + interpolated_on_lon) * 0.5

    # interpolate_na() fills every NaN gap it can, not only the specks.
    # Put NaN back wherever the input was already missing, so that only the
    # speck cells differ from the original data.
    interpolated = interpolated.where(ds.notnull())

In [None]:
# Save the dataset holding the interpolated cells to the output netCDF file.
interpolated.to_netcdf(path=output_file)

In [None]:
# Every cell in `interpolated` went through the lat/lon averaging. Cells
# without a speck should come out with their original value (doubled, then
# halved), but interpolate_na() may also have filled cells that were already
# NaN in the input.
# To apply the interpolated values to the speck cells only, start from the
# original data and swap in `interpolated` just where the original value is
# at or above the threshold. Cells that are NaN in the input stay NaN
# (NaN >= threshold is False), and variables other than var_name are passed
# through unchanged.
output = ds.assign(
    {var_name: ds[var_name].where(~(ds[var_name] >= threshold),
                                  interpolated[var_name])}
)

In [None]:
def naive_method():
    """Patch speck cells in ``ds`` one grid cell at a time.

    This is the naive, cell-by-cell counterpart of the vectorised
    where()/interpolate_na() approach. Every cell whose value at time=0 is
    at or above ``threshold`` is replaced, for all time steps, by the mean
    of a linear interpolation along lat and one along lon.

    Reads the notebook globals ``ds``, ``specks_removed``, ``var_name`` and
    ``threshold``, and modifies ``ds[var_name]`` in place. Nothing is
    returned. It's placed in a function for the sole purpose of avoiding
    being run when "run all cells" is used.
    """
    # Positional indexing below follows the (time, lat, lon) order.
    cleaned = specks_removed[var_name]

    # Find the offending cells with one vectorised comparison rather than
    # indexing the DataArray once per grid cell. ">=" matches exactly the
    # cells that where(... < threshold) turned into NaN.
    first_step = ds[var_name][0].values
    speck_cells = np.argwhere(first_step >= threshold)

    for lat_idx, lon_idx in speck_cells:
        # specks_removed is NaN at the speck itself, and interp() evaluated
        # exactly on that grid point just hands the NaN back. Fill the gap
        # along each axis with interpolate_na() and average the two instead.
        along_lat = cleaned[:, :, lon_idx].interpolate_na(dim='lat')[:, lat_idx]
        along_lon = cleaned[:, lat_idx, :].interpolate_na(dim='lon')[:, lon_idx]
        ds[var_name][:, lat_idx, lon_idx] = ((along_lat + along_lon) * 0.5).values