## Generate open boundary condition (OBC) data for a regional subdomain whose grid nodes coincide with the parent grid's

In [None]:
#### Load required modules
#### MIDAS (`midas.rectgrid`) provides the `supergrid` and `quadmesh` utilities used below to read the regional grid

In [1]:

import xarray as xr
import numpy as np
from midas.rectgrid import *

In [2]:
# Parent model state (the parent run's ocean_hourly output) and the regional
# child grid, read from its MOM6 supergrid file via MIDAS.
parent_state_file = '/archive/Matthew.Harrison/WenHao/NAtl/A01/19920101.ocean_hourly.nc'
regional_hgrid_file = 'NWAtl/ocean_hgrid.nc'

ds = xr.open_dataset(parent_state_file)
sgrid = supergrid(file=regional_hgrid_file)
grid = quadmesh(supergrid=sgrid)

In [3]:
def _first_true_index(mask, what):
    """Return the first index along axis 0 at which *mask* is True.

    Same result as ``np.where(mask)[0][0]``, but raises a ValueError naming
    *what* instead of a bare IndexError when *mask* has no True entry, e.g.
    when the regional grid reaches the edge of the parent grid so that no
    parent q-node satisfies the comparison.
    """
    hits = np.nonzero(mask)[0]
    if hits.size == 0:
        raise ValueError(f'no parent q-node found for {what}')
    return hits[0]


# Parent q-node coordinates.
lonq = ds.xq.load().data
latq = ds.yq.load().data

# i0/j0: first parent q-node at or beyond the regional grid's first q-node.
# i1/j1: first parent q-node strictly beyond the regional grid's last q-node,
#        so when node locations coincide the regional q-nodes are parent
#        indices i0..i1-1 (x) and j0..j1-1 (y).
i0 = _first_true_index(lonq >= grid.lonq[0], 'i0 (start of regional lonq)')
i1 = _first_true_index(lonq > grid.lonq[-1], 'i1 (past end of regional lonq)')
j0 = _first_true_index(latq >= grid.latq[0], 'j0 (start of regional latq)')
j1 = _first_true_index(latq > grid.latq[-1], 'j1 (past end of regional latq)')
print('q-node indices: ', i0, i1, j0, j1)

q-node indices:  0 169 61 259


In [4]:
# Cut the boundary slabs out of the parent state.
# Adjacent pairs of h-points are averaged (keepdims=True keeps a size-1
# cross-boundary dimension), presumably to place tracers/ssh/along-boundary
# velocity on the boundary q-line. The reduced axes assume the parent order
# (time, z_l, y, x) for 4-D fields and (time, y, x) for ssh -- TODO confirm.

# Segment 001 -- southern boundary at parent q-row j0.
# Tracers, ssh and u average h-rows j0-1 and j0; v is taken on q-row j0.
south_h_cols = slice(i0, i1 - 1)
south_h_rows = slice(j0 - 1, j0 + 1)
seg_001_temp = ds['temp'].isel(xh=south_h_cols, yh=south_h_rows).mean(axis=2, keepdims=True)
seg_001_salt = ds['salt'].isel(xh=south_h_cols, yh=south_h_rows).mean(axis=2, keepdims=True)
seg_001_ssh = ds['ssh'].isel(xh=south_h_cols, yh=south_h_rows).mean(axis=1, keepdims=True)
seg_001_u = ds['u'].isel(xq=slice(i0, i1), yh=south_h_rows).mean(axis=2, keepdims=True)
seg_001_v = ds['v'].isel(xh=south_h_cols, yq=slice(j0, j0 + 1))

# Segment 002 -- eastern boundary.
# Tracers, ssh and v average h-columns i1-1 and i1; u is taken on q-column i1.
# NOTE(review): the south u-slice slice(i0, i1) above ends at q-column i1-1,
# which suggests i1-1 is the eastern boundary q-column; here u is taken at
# q-column i1 and tracers straddle it. Confirm the east segment is not
# shifted one column east of the regional grid.
east_h_cols = slice(i1 - 1, i1 + 1)
east_h_rows = slice(j0, j1 - 1)
seg_002_temp = ds['temp'].isel(xh=east_h_cols, yh=east_h_rows).mean(axis=3, keepdims=True)
seg_002_salt = ds['salt'].isel(xh=east_h_cols, yh=east_h_rows).mean(axis=3, keepdims=True)
seg_002_ssh = ds['ssh'].isel(xh=east_h_cols, yh=east_h_rows).mean(axis=2, keepdims=True)
seg_002_u = ds['u'].isel(xq=slice(i1, i1 + 1), yh=east_h_rows)
seg_002_v = ds['v'].isel(xh=east_h_cols, yq=slice(j0, j1)).mean(axis=3, keepdims=True)


In [5]:
# Rename parent grid dimensions to the segment naming used when writing:
# 'pos' runs along the boundary, and the size-1 cross-boundary dimension is
# 'lat' for the south segment and 'lon' for the east segment.

# South segment: x runs along the boundary.
south_tracer_dims = {'yh': 'lat', 'xh': 'pos'}
seg_001_temp = seg_001_temp.rename(south_tracer_dims)
seg_001_salt = seg_001_salt.rename(south_tracer_dims)
seg_001_ssh = seg_001_ssh.rename(south_tracer_dims)
seg_001_u = seg_001_u.rename({'yh': 'lat', 'xq': 'pos'})
seg_001_v = seg_001_v.rename({'yq': 'lat', 'xh': 'pos'})

# East segment: y runs along the boundary.
east_tracer_dims = {'yh': 'pos', 'xh': 'lon'}
seg_002_temp = seg_002_temp.rename(east_tracer_dims)
seg_002_salt = seg_002_salt.rename(east_tracer_dims)
seg_002_ssh = seg_002_ssh.rename(east_tracer_dims)
seg_002_u = seg_002_u.rename({'yh': 'pos', 'xq': 'lon'})
seg_002_v = seg_002_v.rename({'yq': 'pos', 'xh': 'lon'})



In [7]:
# Write open-boundary netCDF files (tracers, u, v, ssh) for each segment.
#
# Each entry in `params` describes one boundary segment:
#   suffix -- appended to every output variable name ('temp_segment_001', ...)
#   dim0   -- axis of the size-1 cross-boundary dimension in the 4-D
#             (time, z_l, ., .) arrays: 2 -> 'lat' (south), 3 -> 'lon' (east)
#   *_in   -- DataArrays prepared in the previous cells
#   *_out  -- output netCDF file names

# _FillValue written for every variable and coordinate.
OBC_FILL_VALUE = 1.e20
# Depth given to the interface below the deepest z_l layer (same units as
# z_l, presumably metres -- TODO confirm it exceeds the deepest layer).
OBC_BOTTOM_DEPTH = 6500.


def _obc_layer_thickness(z_l, bottom=OBC_BOTTOM_DEPTH):
    """Return layer thicknesses for the layer-centre depths *z_l*.

    Interfaces are placed midway between adjacent layer centres, the deepest
    interface is pinned at *bottom*, and the surface is taken to be at 0, so
    the first thickness equals the first interface depth. *z_l* is not
    modified.
    """
    zl = np.asarray(z_l)
    zi = 0.5 * (np.roll(zl, shift=-1) + zl)
    zi[-1] = bottom
    dz = zi.copy()
    dz[1:] = zi[1:] - zi[:-1]
    return dz


def _obc_set_fill(dset, names):
    """Set the netCDF _FillValue encoding of each variable in *names*."""
    for name in names:
        dset[name].encoding['_FillValue'] = OBC_FILL_VALUE


def _obc_tag_axes(dset, dim0):
    """Set fill value and cartesian_axis on a segment's coordinates.

    East segments (dim0 == 3): the size-1 'lon' is X and 'pos' is Y.
    Otherwise (south, dim0 == 2): the size-1 'lat' is Y and 'pos' is X.
    """
    if dim0 == 3:
        dset['lon'].encoding['_FillValue'] = OBC_FILL_VALUE
        dset['lon'].attrs['cartesian_axis'] = 'X'
        dset['pos'].attrs['cartesian_axis'] = 'Y'
    else:
        dset['lat'].encoding['_FillValue'] = OBC_FILL_VALUE
        dset['lat'].attrs['cartesian_axis'] = 'Y'
        dset['pos'].attrs['cartesian_axis'] = 'X'


params = []

params.append({'suffix': '_segment_001', 'dim0': 2,
               'temp_in': seg_001_temp, 'salt_in': seg_001_salt,
               'tr_out': 'obc_ts_south.nc',
               'u_in': seg_001_u, 'v_in': seg_001_v,
               'u_out': 'obc_u_south.nc', 'v_out': 'obc_v_south.nc',
               'ssh_in': seg_001_ssh, 'ssh_out': 'obc_ssh_south.nc'})

params.append({'suffix': '_segment_002', 'dim0': 3,
               'temp_in': seg_002_temp, 'salt_in': seg_002_salt,
               'tr_out': 'obc_ts_east.nc',
               'u_in': seg_002_u, 'v_in': seg_002_v,
               'u_out': 'obc_u_east.nc', 'v_out': 'obc_v_east.nc',
               'ssh_in': seg_002_ssh, 'ssh_out': 'obc_ssh_east.nc'})

for pr in params:
    dim0 = pr['dim0']
    suffix = pr['suffix']
    if dim0 not in (2, 3):
        raise ValueError(f"unsupported dim0={dim0!r} for {suffix}; "
                         "expected 2 (south) or 3 (east)")
    ds_temp = pr['temp_in']
    ds_salt = pr['salt_in']
    ds_u = pr['u_in']
    ds_v = pr['v_in']

    # ---- tracers (temp, salt) and their layer thicknesses ----------------
    # One thickness column from the tracer z_l, reused for u and v below.
    dz_col = _obc_layer_thickness(ds_temp.z_l)
    nt = ds_temp.time.shape[0]
    nx = ds_temp.pos.shape[0]
    dz = np.tile(dz_col[np.newaxis, :, np.newaxis], (nt, 1, nx))

    if dim0 == 3:
        da_dz = xr.DataArray(dz[:, :, :, np.newaxis],
                             coords=[('time', ds_temp.time), ('z_l', ds_temp.z_l),
                                     ('pos', ds_temp.pos), ('lon', ds_u.lon.data)])
    else:
        da_dz = xr.DataArray(dz[:, :, np.newaxis, :],
                             coords=[('time', ds_temp.time), ('z_l', ds_temp.z_l),
                                     ('lat', ds_v.lat.data), ('pos', ds_temp.pos)])

    ds_temp.time.attrs['modulo'] = ' '
    ds_salt.time.attrs['modulo'] = ' '
    # Fill gaps (e.g. land points) by carrying the last valid value along the
    # boundary, then downward; anything still missing becomes 0.
    da_temp = xr.DataArray(ds_temp.ffill(dim='pos', limit=None).ffill(dim='z_l').fillna(0.))
    da_salt = xr.DataArray(ds_salt.ffill(dim='pos', limit=None).ffill(dim='z_l').fillna(0.))
    ds_ = xr.Dataset({'temp' + suffix: da_temp, 'salt' + suffix: da_salt,
                      'dz_temp' + suffix: da_dz, 'dz_salt' + suffix: da_dz})
    _obc_set_fill(ds_, list(ds_.data_vars))
    _obc_set_fill(ds_, ['time', 'pos', 'z_l'])
    _obc_tag_axes(ds_, dim0)
    ds_.to_netcdf(pr['tr_out'], unlimited_dims=('time',))

    # ---- velocities (u, v) and their layer thicknesses -------------------
    ds_u.time.attrs['modulo'] = ' '
    ds_v.time.attrs['modulo'] = ' '
    da_u = xr.DataArray(ds_u.ffill(dim='pos', limit=None).ffill(dim='z_l').fillna(0.))
    da_v = xr.DataArray(ds_v.ffill(dim='pos', limit=None).ffill(dim='z_l').fillna(0.))

    dz_col4 = dz_col[np.newaxis, :, np.newaxis, np.newaxis]
    if dim0 == 3:
        # East: u has nx points along the boundary, v has nx+1.
        dzu = np.tile(dz_col4, (nt, 1, nx, 1))
        dzv = np.tile(dz_col4, (nt, 1, nx + 1, 1))
        da_dzu = xr.DataArray(dzu, coords=[('time', ds_u.time), ('z_l', ds_u.z_l),
                                           ('pos', ds_u.pos), ('lon', da_u.lon)])
        da_dzv = xr.DataArray(dzv, coords=[('time', ds_v.time), ('z_l', ds_v.z_l),
                                           ('pos', ds_v.pos), ('lon', da_u.lon)])
    else:
        # South: u has nx+1 points along the boundary, v has nx.
        dzu = np.tile(dz_col4, (nt, 1, 1, nx + 1))
        dzv = np.tile(dz_col4, (nt, 1, 1, nx))
        da_dzu = xr.DataArray(dzu, coords=[('time', ds_u.time), ('z_l', ds_u.z_l),
                                           ('lat', da_u.lat), ('pos', ds_u.pos)])
        da_dzv = xr.DataArray(dzv, coords=[('time', ds_v.time), ('z_l', ds_v.z_l),
                                           ('lat', da_v.lat), ('pos', ds_v.pos)])
    ds_u = xr.Dataset({'u' + suffix: da_u, 'dz_u' + suffix: da_dzu})
    ds_v = xr.Dataset({'v' + suffix: da_v, 'dz_v' + suffix: da_dzv})

    _obc_set_fill(ds_u, list(ds_u.data_vars))
    _obc_set_fill(ds_v, list(ds_v.data_vars))
    _obc_set_fill(ds_u, ['time', 'pos'])
    _obc_set_fill(ds_v, ['time', 'pos'])
    _obc_tag_axes(ds_u, dim0)
    _obc_tag_axes(ds_v, dim0)
    if dim0 == 3:
        # v was averaged across the boundary; give it u's boundary longitude.
        ds_v = ds_v.assign_coords({'lon': ds_u.lon})
    else:
        # u was averaged across the boundary; give it v's boundary latitude.
        ds_u = ds_u.assign_coords({'lat': ds_v.lat})
    _obc_set_fill(ds_u, ['z_l'])
    _obc_set_fill(ds_v, ['z_l'])

    ds_u.to_netcdf(pr['u_out'], unlimited_dims=('time',))
    ds_v.to_netcdf(pr['v_out'], unlimited_dims=('time',))

    # ---- sea surface height ----------------------------------------------
    # Separate name so the parent dataset `ds` from the loading cell is not
    # overwritten.
    ds_ssh = pr['ssh_in']
    ds_ssh.time.attrs['modulo'] = ' '
    da_ssh = xr.DataArray(ds_ssh.ffill(dim='pos', limit=None).fillna(0.))
    # Insert a size-1 dummy dimension 'dim_0' before the cross-boundary axis.
    # NOTE(review): assuming ssh is (time, lat, pos) for south and
    # (time, pos, lon) for east, this yields (time, dim_0, lat, pos) and
    # (time, pos, dim_0, lon) respectively -- the dummy axis is not at
    # position 1 for the east segment, unlike z_l in the 4-D fields above.
    # Confirm the east layout is what the model reader expects.
    da_ssh = da_ssh.expand_dims('dim_0', axis=dim0 - 1)
    ds_ = xr.Dataset({'ssh' + suffix: da_ssh})
    _obc_set_fill(ds_, list(ds_.data_vars))
    _obc_set_fill(ds_, ['time', 'pos'])
    _obc_tag_axes(ds_, dim0)
    ds_.to_netcdf(pr['ssh_out'], unlimited_dims=('time',))