In [1]:
import os
import sys
import yaml

# ---------- #
# Numerics
import xarray as xr
import numpy as np

# ---------- #
# AI libs
import torch
from torchvision import transforms

# ---------- #
# credit
from credit.data404 import CONUS404Dataset
from credit.transforms404 import ToTensor, NormalizeState


In [2]:
## Load the experiment configuration, then override the sample window and the
## prediction period used for this formatting test.
config = "/glade/work/mcginnis/ML/GWC/miles-credit/config/test.save.conus404.yml"

with open(config) as cf:
    conf = yaml.full_load(cf)

conf["data"].update(history_len=1, forecast_len=1)
conf["predict"].update(start="2017-11-01 00:00:00", finish="2017-11-01 23:00:00")

In [3]:
## Sample transform: only ToTensor is active. Re-enabling NormalizeState is a
## quick way to confirm that the data really does get transformed.
transform = transforms.Compose([
    # NormalizeState(conf),
    ToTensor(conf),
])

## Dataset over the prediction window, using the history/forecast lengths from the config.
ds = CONUS404Dataset(
    varnames=conf["data"]["variables"],
    transform=transform,
    history_len=conf["data"]["history_len"],
    forecast_len=conf["data"]["forecast_len"],
    start=conf["predict"]["start"],
    finish=conf["predict"]["finish"],
)


In [4]:
# Display the dataset repr to check that the config overrides (history/forecast length, start/finish) took effect.
ds

CONUS404Dataset(zarrpath='/glade/campaign/ral/risc/DATA/conus404/zarr', varnames=['PREC_ACC_NC', 'PSFC', 'Q2', 'T2', 'TD2'], history_len=1, forecast_len=1, transform=Compose(
    <credit.transforms404.ToTensor object at 0x14fbb70efed0>
), seed=22, skip_periods=None, one_shot=False, start='2017-11-01 00:00:00', finish='2017-11-01 23:00:00')

In [5]:
# Number of samples in the prediction window.
# NOTE(review): start..finish covers 24 time steps (00:00-23:00) but this reports 23 samples;
# presumably each sample needs history_len + forecast_len consecutive steps. Confirm in CONUS404Dataset.
len(ds)

23

In [6]:
## Stack every sample's input tensor along a new leading time axis.
## (This stands in for what a model rollout over consecutive input samples would return.)

# Intended axis names of the stacked tensor.
# NOTE(review): each ds[i]['x'] comes out as (5, 1, 512, 512) here; confirm that the
# singleton axis really is "z" and not the history-time axis.
outdims = ["time", "vars", "z", "y", "x"]

tensorlist = [ds[index]['x'].unsqueeze(0) for index in range(len(ds))]

outtensor = torch.cat(tensorlist, dim=0)
print(outtensor.size())

torch.Size([23, 5, 1, 512, 512])


In [7]:
## Open one raw CONUS404 2-D WRF output file to use as a metadata template.
c404_path = "/glade/campaign/collections/rda/data/ds559.0/"
template_path = "wy1980/197910/wrf2d_d01_1979-10-01_00:00:00.nc"
# template_path = "wy1980/197910/wrf3d_d01_1979-10-01_00:00:00.nc"
raw_template = xr.open_dataset(os.path.join(c404_path, template_path))

## Spatial subset of the model domain. These offsets duplicate values in
## transforms404 and should eventually come from the config instead.
x0, xsize = 120, 512
y0, ysize = 300, 512

xs = slice(x0, x0 + xsize)
ys = slice(y0, y0 + ysize)

# NOTE(review): the staggered dims get the same 512-point slice as the mass grid.
# If the staggered grids are meant to bracket the mass points, they may need
# xsize+1 / ysize+1. Confirm before relying on the *_stag coordinates.
template = raw_template.isel(
    west_east=xs, west_east_stag=xs,
    south_north=ys, south_north_stag=ys,
)

print(template)

<xarray.Dataset> Size: 228MB
Dimensions:                 (Time: 1, south_north: 512, west_east: 512,
                             west_east_stag: 512, south_north_stag: 512,
                             soil_layers_stag: 4, snow_layers_stag: 3,
                             snso_layers_stag: 7)
Coordinates:
  * Time                    (Time) datetime64[ns] 8B 1979-10-01
    XLAT                    (south_north, west_east) float32 1MB ...
    XLONG                   (south_north, west_east) float32 1MB ...
    XLAT_U                  (south_north, west_east_stag) float32 1MB ...
    XLONG_U                 (south_north, west_east_stag) float32 1MB ...
    XLAT_V                  (south_north_stag, west_east) float32 1MB ...
    XLONG_V                 (south_north_stag, west_east) float32 1MB ...
    XTIME                   (Time) datetime64[ns] 8B ...
Dimensions without coordinates: south_north, west_east, west_east_stag,
                                south_north_stag, soil_layers_sta

In [8]:
## Convert output tensor to list of xarray dataarrays

## We have to split it into separate xarrays for each variable because CONUS404 still has winds
## on staggered grids, and to attach the 3 different sets of 2D lat/lon coordinates appropriately,
## they can't all be merged into the same dataset.

## HOWEVER, this only matters for 3-D wind variables, not surface variables. For 3-D variables
## I'd have to split the varnames back into variable + hPa level. I think the right approach is to
## follow what the new data pipeline does and keep 3-D variables as 3-D instead of pulling out the
## levels we want. That means either changing my zarrification workflow and redoing all of those
## files, or (better, hopefully) putting kerchunk / VirtualiZarr indexing on the raw CONUS404
## netCDF files and not zarrifying them at all. That's still a ways down the road, so handling
## multiple 2-D lat/lon coordinate sets for the different staggers is completely untested so far.

xrlist = []

## Loop on output variables: build one DataArray per variable, with dims, attrs
## and (non-time) coordinates copied from the WRF template.

vnames = conf["data"]["variables"]

for i, v in enumerate(vnames):
    ## Pull the corresponding slice from the tensor.
    ## Tensor dimensions: time, var, z, y, x  ->  after selecting var: time, z, y, x.
    ## Squeeze ONLY the z axis (dim 1 of the slice). A bare .squeeze() would also
    ## drop the time axis whenever there is a single time step, and the data would
    ## then no longer line up with the template's dims.
    vdata = outtensor[:, i].squeeze(1)

    ## xarray needs a host-side array. detach()/cpu() make this work for tensors
    ## that come straight from a model rollout (requires_grad=True or on a GPU).
    vdata = vdata.detach().cpu().numpy()

    ## convert to xarray & name dims according to template
    vxra = xr.DataArray(vdata, dims=template[v].dims)

    ## copy metadata. Copy the dict so that later edits to the output's attrs
    ## do not write through to the template.
    vxra.attrs = dict(template[v].attrs)

    ## Copy coordinates, skipping any whose name contains "time" (e.g. XTIME):
    ## the template's time is not carried over to the output.
    tvc = template[v].coords
    cdict = {}
    for k in tvc:
        if "time" in k.lower():
            continue
        coord = tvc[k]
        ## Tag 2-D lat/lon coordinates with a CF standard_name, based on their
        ## WRF "description". Coordinates without a description are copied as-is.
        desc = str(coord.attrs.get("description", "")).lower()
        for L in ("latitude", "longitude"):
            if L in desc:
                ## assign_attrs returns a new object, so the template stays untouched.
                coord = coord.assign_attrs(standard_name=L)
        cdict[k] = coord

    ## and now we can assign those coordinates to the dataarray
    vxra = vxra.assign_coords(cdict)

    ## add to list
    xrlist.append({v: vxra})

#print(xrlist)

In [9]:
## Combine the per-variable DataArrays into a single Dataset.
dsout = xr.merge(xrlist)

## The WRF time coordinate lacks CF attributes, so build one by hand as
## hour offsets from the prediction start.
nt = outtensor.shape[0]  # time is the leading tensor dimension
timevals = list(map(float, range(nt)))
timeatts = dict(
    standard_name="time",
    long_name="time",
    units="hours since " + conf["predict"]["start"],
    calendar="proleptic_gregorian",
)
time = xr.DataArray(data=timevals, dims=ds.tdimname, attrs=timeatts)
dsout = dsout.assign_coords({ds.tdimname: time})

## Global (file-level) metadata.
dsout.attrs.update({"Conventions": "CF-1.11", "frequency": "1hr"})


In [10]:
## Add a CF grid-mapping variable describing the WRF Lambert conformal projection.

tatt = template.attrs

## Only Lambert conformal is handled so far. Use .get so that a missing
## attribute produces the same clear error instead of a bare KeyError.
map_proj = tatt.get("MAP_PROJ_CHAR")
if map_proj != "Lambert Conformal":
    raise ValueError(
        "WRF map projection is not Lambert Conformal; don't know how to deal with others yet"
        f" (MAP_PROJ_CHAR={map_proj!r})"
    )

## The grid-mapping variable only carries attributes; its value is a placeholder.
## Use an explicit integer scalar instead of None: None would create an
## object-dtype array, leaving xarray to pick a dtype for it at write time.
proj = xr.DataArray(np.int32(0), attrs={
    "grid_mapping_name": "lambert_conformal_conic",
    "earth_radius": 6370000.0,
    "standard_parallel": [tatt["TRUELAT1"], tatt["TRUELAT2"]],
    "longitude_of_central_meridian": tatt["STAND_LON"],
    "latitude_of_projection_origin": tatt["MOAD_CEN_LAT"],
})

pname = "LCC"

dsout[pname] = proj


In [11]:
## CF standard names for the CONUS404 variables we know how to describe.

stdname = {
    # "ACSWDNLSM": "", divide by 60 minutes to get downwelling_shortwave_flux_in_air
    # "COSZEN": "",  # solar_zenith_angle exists, but not cosine thereof
    "PREC_ACC_NC": "lwe_thickness_of_precipitation_amount",
    # CONUS404 is convection-resolving, so grid-scale precip is all of it
    # units of length makes it lwe (liquid water equivalent)
    "PSFC": "surface_air_pressure",
    "Q2": "humidity_mixing_ratio",
    "SNOW": "snowfall_amount",
    "TD2": "dew_point_temperature",
    "T2": "air_temperature",
    "totalVap": "atmosphere_mass_content_of_water_vapor",
}

## Absolute (on-scale) temperatures. CF-1.11 units_metadata distinguishes these
## from temperature differences. This must be a real tuple: ("T2") without a
## comma is just the string "T2", and `in` would then do a substring test.
temperature_vars = ("T2", "TD2")

for v in vnames:
    if v in stdname:
        dsout[v].attrs["standard_name"] = stdname[v]
    if v in temperature_vars:
        # CF-1.11 allowed values: "temperature: on_scale" | "temperature: difference" | "temperature: unknown"
        dsout[v].attrs["units_metadata"] = "temperature: on_scale"

# Possible TODO: sort attributes alphabetically

In [None]:
## TODO: add coordinate values (1:N x 4) to x & y axes, units = 'km'

In [12]:
## TODO: add scalar 2-m height coordinate for appropriate variables
## it appears this needs to be done using xarray coordinates, not just appending to attribute

# Variables that should eventually get a scalar 2-m height coordinate.
# NOTE(review): nothing else in this notebook uses this yet; it is a placeholder for the TODO above.
# NOTE(review): a scalar coordinate added at Dataset level may be attached to *every*
# variable when written; verify that only these variables end up referencing it.
two_meter = ("Q2", "TD2", "T2")  # vars that need 2-m scalar height


In [13]:
## write to netcdf

## Per-variable encoding: fill/missing value plus the link to the grid-mapping
## variable. Use pname rather than repeating the name, so the two cannot drift apart.
commonatts = {"_FillValue": 1e20,
              "missing_value": 1e20,
              "grid_mapping": pname}

## Give each variable its own copy so no encoding dict is shared between variables.
encodict = {v: dict(commonatts) for v in vnames}

## Suppress _FillValue on the time coordinate. CF does not allow missing values
## in coordinate variables, and xarray otherwise writes a NaN fill value for
## float variables by default.
encodict[ds.tdimname] = {"_FillValue": None}

dsout.to_netcdf("/glade/work/mcginnis/ML/GWC/format-test.nc",
                unlimited_dims=ds.tdimname,
                encoding=encodict)