In [17]:
import sys
import os
import numpy as np
import netCDF4 as nc
from scipy import ndimage as nd

def unmask_2d(var, mask, missing_value):
    """Fill flagged cells of a 2-D field with the nearest unflagged value.

    ``var`` is modified in place: every cell where ``mask`` is truthy is
    overwritten with the value of the closest cell (Euclidean distance in
    index space) where ``mask`` is falsy.

    Args:
        var: 2-D array (plain ``numpy.ndarray`` or ``numpy.ma.MaskedArray``)
            to fill in place.
        mask: Array with the same shape as ``var``; truthy entries mark the
            cells to be filled. If ``None``, the mask is derived from
            ``missing_value``.
        missing_value: Value that marks cells to fill when ``mask`` is
            ``None``. Ignored when ``mask`` is given. NaN is supported.

    Raises:
        ValueError: If both ``mask`` and ``missing_value`` are ``None``.
    """
    if mask is None:
        if missing_value is None:
            raise ValueError('missing_value is required when mask is None')

        # getdata() returns the underlying ndarray for masked arrays and the
        # array itself for plain ndarrays, so both input kinds are accepted.
        data = np.ma.getdata(var)
        if missing_value != missing_value:
            # NaN never compares equal to itself, so `==` would find nothing.
            mask = np.isnan(data)
        else:
            mask = data == missing_value

    # For every cell, the index of the nearest cell whose mask is zero;
    # unflagged cells map to themselves.
    ind = nd.distance_transform_edt(mask,
                                    return_distances=False,
                                    return_indices=True)

    var[:, :] = var[tuple(ind)]
    print('2d done', flush=True)

def unmask_3d(v, mask, missing_value):
    """Run ``unmask_2d`` on every slice along the leading axis of ``v``.

    ``mask`` and ``missing_value`` are forwarded unchanged to each call.
    The slices are filled in place, so ``v`` itself is updated.
    """
    n_slices = v.shape[0]
    for k in range(n_slices):
        # Relies on v[k, :] being a view of v (NumPy basic slicing), so the
        # in-place fill done by unmask_2d lands in v.
        plane = v[k, :]
        unmask_2d(plane, mask, missing_value)

def unmask_4d(v, mask, missing_value):
    """Run ``unmask_3d`` on every sub-array along the leading axis of ``v``.

    ``mask`` and ``missing_value`` are forwarded unchanged to each call.
    The sub-arrays are filled in place, so ``v`` itself is updated.
    """
    n_blocks = v.shape[0]
    for k in range(n_blocks):
        # Relies on v[k, :] being a view of v (NumPy basic slicing).
        block = v[k, :]
        unmask_3d(block, mask, missing_value)

def unmask_file(filename, mask=None, missing_value=None, skip_vars=None):
    """Fill masked cells of every 2-D/3-D/4-D variable in a netCDF file.

    The file is opened in ``'r+'`` mode and updated in place. Variables
    listed in ``skip_vars`` or whose name starts with ``'atm'`` are left
    untouched, as are variables whose rank is not 2, 3 or 4 (a warning is
    printed for those).

    Args:
        filename: Path of the netCDF file to modify.
        mask: Optional 2-D array; truthy entries mark the cells to fill.
            Forwarded to the ``unmask_*`` helpers.
        missing_value: Value marking cells to fill when ``mask`` is ``None``.
            If both are ``None``, each variable's own ``fill_value`` is used
            (this assumes the variable is read as a masked array).
        skip_vars: Iterable of variable names to skip. Defaults to none.

    Returns:
        Always ``0``.
    """
    skip_vars = () if skip_vars is None else skip_vars

    # Only these ranks can be processed by the slice-wise helpers.
    unmaskers = {2: unmask_2d, 3: unmask_3d, 4: unmask_4d}

    # Resolved once; the original dereferenced mask.shape unconditionally and
    # crashed whenever the default mask=None was used.
    mask_shape = None if mask is None else np.shape(mask)

    with nc.Dataset(filename, 'r+') as f:
        for v in f.variables:
            if v in skip_vars or v.startswith('atm'):
                continue

            print(f"Unmasking variable: {v}")
            var = f.variables[v][:]
            print(f"Unmasking 2D: var shape = {var.shape}, mask shape = {mask_shape}")

            # Keep the fallback local to this variable so one variable's
            # fill_value is never reused for the following ones.
            var_missing = missing_value
            if mask is None and var_missing is None:
                var_missing = var.fill_value

            unmask = unmaskers.get(len(var.shape))
            if unmask is None:
                print('WARNING: not unmasking {} because it is 1d'.format(v))
                continue

            unmask(var, mask, var_missing)
            f.variables[v][:] = var[:]

    return 0

def apply_mask_2d(v, landmask, mask_val):
    """Set ``v`` to ``mask_val`` wherever ``landmask`` is truthy, in place."""
    flagged = np.nonzero(landmask)
    v[flagged] = mask_val

def apply_mask_3d(v, landmask, mask_val):
    """Run ``apply_mask_2d`` on every slice along the leading axis of ``v``.

    ``landmask`` and ``mask_val`` are forwarded unchanged to each call.
    """
    n_slices = v.shape[0]
    for k in range(n_slices):
        # Relies on v[k, :] being a view of v (NumPy basic slicing), so the
        # in-place assignment done by apply_mask_2d lands in v.
        plane = v[k, :]
        apply_mask_2d(plane, landmask, mask_val)

def apply_mask_4d(v, landmask, mask_val):
    """Run ``apply_mask_3d`` on every sub-array along the leading axis of ``v``.

    ``landmask`` and ``mask_val`` are forwarded unchanged to each call.
    """
    n_blocks = v.shape[0]
    for k in range(n_blocks):
        # Relies on v[k, :] being a view of v (NumPy basic slicing).
        block = v[k, :]
        apply_mask_3d(block, landmask, mask_val)

def apply_mask_file(filename, mask, mask_val=0.0, skip_vars=None):
    """Overwrite masked cells of every 2-D/3-D/4-D variable in a netCDF file.

    The file is opened in ``'r+'`` mode and updated in place. Variables
    listed in ``skip_vars`` or whose name starts with ``'atm'`` are left
    untouched. Variables whose rank is not 2, 3 or 4 trigger a warning and
    are not written back.

    Args:
        filename: Path of the netCDF file to modify.
        mask: 2-D array; truthy entries mark the cells to overwrite.
        mask_val: Value written into the masked cells.
        skip_vars: Iterable of variable names to skip. Defaults to none.

    Returns:
        Always ``0``.
    """
    skip_vars = () if skip_vars is None else skip_vars

    # Only these ranks can be processed by the slice-wise helpers.
    maskers = {2: apply_mask_2d, 3: apply_mask_3d, 4: apply_mask_4d}

    with nc.Dataset(filename, 'r+') as f:
        for v in f.variables:
            if v in skip_vars or v.startswith('atm'):
                continue

            var = f.variables[v][:]

            apply_mask = maskers.get(len(var.shape))
            if apply_mask is None:
                print('WARNING: not applying mask {} because it is 1d'.format(v))
                # Nothing was modified, so skip the pointless write-back.
                continue

            apply_mask(var, mask, mask_val)
            f.variables[v][:] = var[:]
    return 0

def main():
    """Unmask, then re-mask, a hard-coded restart file using the ``kmt`` field.

    The mask passed to both steps is the logical negation of the ``kmt``
    variable cast to bool, i.e. it is True where ``kmt`` is zero. Every
    variable not in the skip list is first filled by ``unmask_file`` and
    then has the masked cells set to 0.0 by ``apply_mask_file``.
    """
    restart_path = '/scratch/tm70/ek4684/access-om3/archive/test_restart_nan/restart000/access-om3.cpl.r.1945-01-01-00000.nc'
    kmt_path = '/g/data/tm70/ek4684/topog_edits_Persian_Gulf_red_sea/kmt.nc'
    kmt_name = 'kmt'
    fill = 1e30

    # Variables that must not be touched by either pass.
    excluded = [
        'time',
        'time_bnds',
        'start_ymd',
        'start_tod',
        'curr_ymd',
        'curr_tod',
        'ocnExpAccum_cnt',
    ]

    with nc.Dataset(kmt_path) as ds:
        kmt_nonzero = np.array(ds.variables[kmt_name][:], dtype=bool)
    mask = np.logical_not(kmt_nonzero)

    unmask_file(restart_path, mask, fill, skip_vars=excluded)
    apply_mask_file(restart_path, mask, mask_val=0.0, skip_vars=excluded)

# Script entry point: run main() and hand its return value to sys.exit().
# Note that sys.exit() raises SystemExit, which an interactive session
# (e.g. a notebook cell) reports instead of terminating.
if __name__ == '__main__':
    exit_status = main()
    sys.exit(exit_status)

Unmasking 2D: var shape = (1, 1142, 1440), mask shape = (1142, 1440)

2d done
Unmasking 2D: var shape = (1, 1142, 1440), mask shape = (1142, 1440)

2d done
Unmasking 2D: var shape = (1, 1142, 1440), mask shape = (1142, 1440)

2d done
Unmasking 2D: var shape = (1, 1142, 1440), mask shape = (1142, 1440)

2d done
Unmasking 2D: var shape = (1, 1142, 1440), mask shape = (1142, 1440)

2d done
Unmasking 2D: var shape = (1, 1142, 1440), mask shape = (1142, 1440)

2d done
Unmasking 2D: var shape = (1, 1142, 1440), mask shape = (1142, 1440)

2d done
Unmasking 2D: var shape = (1, 1142, 1440), mask shape = (1142, 1440)

2d done
Unmasking 2D: var shape = (1, 1142, 1440), mask shape = (1142, 1440)

2d done
Unmasking 2D: var shape = (1, 1142, 1440), mask shape = (1142, 1440)

2d done
Unmasking 2D: var shape = (1, 1142, 1440), mask shape = (1142, 1440)

2d done
Unmasking 2D: var shape = (1, 1142, 1440), mask shape = (1142, 1440)

2d done
Unmasking 2D: var shape = (1, 1142, 1440), mask shape = (1142, 1

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
