In [3]:
from parcels import grid

In [None]:
class CurvilinearZGrid(CurvilinearGrid):
    """Curvilinear Z Grid.

    :param lon: 2D array containing the longitude coordinates of the grid
    :param lat: 2D array containing the latitude coordinates of the grid
    :param depth: 3D or 4D numpy array containing the vertical coordinates
           of the grid, shaped [zdim, ydim, xdim] or
           [tdim, zdim, ydim, xdim]. It is stored as float32.
    :param time: Vector containing the time coordinates of the grid
    :param time_origin: Time origin (TimeConverter object) of the time axis
    :param mesh: String indicating the type of mesh coordinates and
           units used during velocity interpolation:

           1. spherical: Lat and lon in degree, with a
              correction for zonal velocity U near the poles.
           2. flat (default): No conversion, lat/lon are assumed to be in m.
    """

    def __init__(self, lon, lat, depth, time=None, time_origin=None, mesh='flat'):
        """Validate ``depth`` against the horizontal/time dimensions and store it.

        :raises AssertionError: if ``depth`` is not a 3D/4D numpy array or its
               dimensions do not match the grid dimensions.
        """
        super(CurvilinearZGrid, self).__init__(lon, lat, time, time_origin, mesh)

        # NOTE(review): the class name and gtype say "Z grid", and this class
        # was previously documented as taking a 1D vector of constant z-levels,
        # yet only a 3D/4D depth array is accepted here. Confirm which of the
        # two contracts is the intended one.
        assert isinstance(depth, np.ndarray) and depth.ndim in (3, 4), \
            'depth is not a 3D or 4D numpy array'

        self.gtype = GridCode.CurvilinearZGrid
        self.depth = depth
        # The vertical axis is always third from the end, for both layouts.
        self.zdim = self.depth.shape[-3]
        # Integer flag: 1 when depth carries a leading time axis, else 0.
        self.z4d = 1 if self.depth.ndim == 4 else 0

        # self.tdim / self.xdim / self.ydim are presumably set by the parent
        # constructor called above -- TODO confirm.
        if self.z4d:
            bad_4d = 'depth dimension has the wrong format. It should be [tdim, zdim, ydim, xdim]'
            # self.depth.shape[0] is 0 for S grids loaded from netcdf file
            assert self.tdim == self.depth.shape[0] or self.depth.shape[0] == 0, bad_4d
            assert self.xdim == self.depth.shape[-1] or self.depth.shape[-1] == 0, bad_4d
            assert self.ydim == self.depth.shape[-2] or self.depth.shape[-2] == 0, bad_4d
        else:
            bad_3d = 'depth dimension has the wrong format. It should be [zdim, ydim, xdim]'
            assert self.xdim == self.depth.shape[-1], bad_3d
            assert self.ydim == self.depth.shape[-2], bad_3d

        # Convert only when needed so a float32 input is kept as-is (no copy).
        if self.depth.dtype != np.float32:
            self.depth = self.depth.astype(np.float32)