In [None]:
import cartopy.crs as ccrs # just for plotting
from glob import glob
import matplotlib.pyplot as plt # just for plotting
import numpy as np
import os
import pandas as pd
from pygris import counties, states
import xarray as xr # also need to install netcdf4 and dask[complete]
import zarr
import fsspec

In [None]:
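# Preprocessing hook used when converting the original WRF NetCDF files
# (kept commented out for reference). It fixes the awkward WRF indexing by
# renaming Time -> time, XLAT -> lat and XLONG -> lon, building a datetime
# 'time' coordinate from the 'Times' strings, and loading the time/space
# coordinates into memory.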
#
# def preprocess(d):
#     d = d.rename_dims({
#         'Time': 'time',
#     }).rename_vars({
#         'XLAT': 'lat',
#         'XLONG': 'lon',
#     })
#     d['time'] = pd.to_datetime(
#         d.Times.load().astype(str).str.replace('_', ' ')
#     )
#     d = d.drop_vars(['Times'])
#     d['lat'] = d.lat.isel(time=0).load()
#     d['lon'] = d.lon.isel(time=0).load()
#     return d


In [None]:
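# I used this one-off loop to convert the original NetCDF files into a single
# Zarr store (kept commented out for reference).
# NOTE: it may be important to process the files in time order, since every
#       file after the first is appended along the 'time' dimension.
# NOTE: it is not known what happens if append_dim is used out of order.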
#
# for i, f in enumerate(sorted(glob('./tgw_wrf_*.nc'))):
#     d = xr.open_mfdataset(f, preprocess=preprocess)
#     if i==0:
#         d.to_zarr('./tgw_wrf_rcp85hotter_hourly_2088_default_chunks.zarr')
#     else:
#         d.to_zarr('./tgw_wrf_rcp85hotter_hourly_2088_default_chunks.zarr', append_dim='time')
#     d.close()


In [None]:
# Build the state and county boundary tables used for spatial subsetting.
# Both tables are reprojected to WGS84 (epsg:4326).

# States/territories that are not part of the contiguous US (CONUS).
_non_conus_names = [
    'Alaska','American Samoa','Puerto Rico','United States Virgin Islands',
    'Hawaii','Guam','Commonwealth of the Northern Mariana Islands',
]

# CONUS states: every 2020 state record except the ones listed above.
_all_states = states(cb=True, year=2020, cache=True).to_crs("epsg:4326")
_is_non_conus = _all_states.NAME.isin(_non_conus_names)
conus_states = _all_states[~_is_non_conus]

# CONUS counties: every 2020 county whose state FIPS code belongs to a CONUS state.
_all_counties = counties(cb=True, year=2020, cache=True).to_crs("epsg:4326")
_in_conus_state = _all_counties.STATEFP.isin(conus_states.STATEFP)
conus_counties = _all_counties[_in_conus_state]


In [None]:
def _clip_to_bounds(d, min_lat=None, max_lat=None, min_lon=None, max_lon=None):
    """Mask `d` to a lat/lon bounding box and drop the labels left empty.

    A bound given as ``None`` is treated as unbounded on that side. The
    comparison is inclusive on every edge. ``drop=True`` removes coordinate
    labels for which the condition is False everywhere.
    """
    # np.inf (lowercase) is the supported spelling; the np.Inf alias was
    # removed in NumPy 2.0.
    south = -np.inf if min_lat is None else min_lat
    north = np.inf if max_lat is None else max_lat
    west = -np.inf if min_lon is None else min_lon
    east = np.inf if max_lon is None else max_lon
    inside = (
        (d.lat >= south) &
        (d.lat <= north) &
        (d.lon >= west) &
        (d.lon <= east)
    )
    return d.where(inside, drop=True)


def get_tgw_subset(
    *,
    start: str, # 'YYYY-MM-DD' or 'YYYY-MM-DDTHH:MM:SS'
    end: str, # 'YYYY-MM-DD' or 'YYYY-MM-DDTHH:MM:SS'
    county_fips: str = None, # county FIPS code to keep, None for all
    state_abbreviation: str = None, # State abbreviation to keep, None for all
    min_lat: float = None, # minimum latitude in WGS84 (epsg:4326)
    max_lat: float = None, # maximum latitude in WGS84 (epsg:4326)
    min_lon: float = None, # minimum longitude in WGS84 (epsg:4326)
    max_lon: float = None, # maximum longitude in WGS84 (epsg:4326)
    variables = None, # list of variables to keep, None for all
    data_store = './data/tgw_wrf_rcp85hotter_hourly_2088_default_chunks.zarr', # path to the zarr
    load = True, # if True, load the data before returning; otherwise return the chunked dask dataset
    write_to_file = False, # if a path, write subset to that path; if False don't
):
    """Open the TGW-WRF zarr store and return a time/space/variable subset.

    All arguments are keyword-only. The filters are applied in this order:
    variables, time, state bounding box, county bounding box, explicit
    lat/lon bounding box. The spatial filters are cumulative (each one is
    applied to the result of the previous one).

    Parameters
    ----------
    start, end : str
        Bounds of the time slice, 'YYYY-MM-DD' or 'YYYY-MM-DDTHH:MM:SS'.
    county_fips : str, optional
        County GEOID to keep; the county's bounding box is used, not its shape.
    state_abbreviation : str, optional
        State postal abbreviation to keep (case-insensitive); the state's
        bounding box is used, not its shape.
    min_lat, max_lat, min_lon, max_lon : float, optional
        Explicit bounding box in WGS84 (epsg:4326). Any bound left as None is
        unbounded on that side.
    variables : list of str, optional
        Variables to keep; None keeps them all.
    data_store : str
        Intended path to the zarr store.
        NOTE(review): currently NOT used -- the store is always opened from
        the hard-coded S3 access point below. TODO wire this argument up.
    load : bool
        If True, load the subset into memory before returning; otherwise
        return the lazily-evaluated (chunked) object.
    write_to_file : str or False
        If truthy, the subset is written with ``to_netcdf`` to that path.

    Returns
    -------
    The subset produced by the filters above (an xarray object; loaded into
    memory when ``load`` is True).

    Raises
    ------
    IndexError
        If ``state_abbreviation`` or ``county_fips`` matches no row of the
        boundary tables (the ``.iloc[0]`` lookup on an empty frame fails).
    """

    # NOTE that certain variables (precipitation, etc) are presented as "cumulative",
    #      meaning that the user may actually need one timestep before the requested
    #      start time in order to fully resolve those variables
    # TODO this is not accounted for in this method

    # NOTE that the WRF data presented in WGS84 (epsg:4326) projection as is the case
    #      here is NOT on a rectilinear grid, which can be confusing to work with, but
    #      the native WRF projection IS on a rectilinear grid but those coordinates are
    #      not provided by default (see the python package salem for more details...)

    # NOTE the data_store must be used to filter by scenario,
    #      but users may benefit from a wrapper for that functionality too

    # s3fs is only needed by this function, so it is imported here rather than
    # being a hard requirement for the whole notebook.
    import s3fs

    # History of approaches that were tried before settling on the one below:
    #   * xr.open_mfdataset(data_store, engine='zarr', parallel=True) on a local path
    #   * explicit key/secret read from AWS_SHARED_CREDENTIALS_FILE together with the
    #     access point's endpoint_url
    #   * s3.get_mapper on the bucket path + open_zarr: worked with admin credentials
    #     but crashed the kernel when doing a 1 month run
    #   * xr.open_mfdataset(mapper, engine='zarr', parallel=True): did not work
    # TODO may be more efficient chunking method than the default...

    # No credentials are passed explicitly to the filesystem object.
    s3 = s3fs.S3FileSystem()

    # The store is addressed through the S3 access point ARN, with the path of
    # the zarr appended after it.
    access_point_arn = 'arn:aws:s3:us-west-2:889772541283:accesspoint/8mg1a-s4774'
    mapper = s3.get_mapper(f'{access_point_arn}/8mg1a-s4774/tgw_wrf_rcp85hotter_hourly_2088_default_chunks.zarr')
    d = xr.open_zarr(mapper, consolidated=True)

    # load the coordinates so they can be used as indexers
    d.lat.load()
    d.lon.load()
    d.time.load()

    # subset by variables
    # TODO may help to ignore or just warn about requested variables that don't
    #      exist, rather than just fail
    if variables is not None:
        d = d[variables]

    # subset by date
    d = d.sel(time=slice(start, end))

    # subset by space
    # NOTE that there could be errors caused by use of -180 to 180 vs 0 to 360 nomenclature
    # TODO may want to build in a buffer to be sure to catch the edges of the shape
    if state_abbreviation is not None:
        state_bounds = conus_states[conus_states.STUSPS == state_abbreviation.upper()].bounds.iloc[0]
        d = _clip_to_bounds(
            d,
            min_lat=state_bounds.miny,
            max_lat=state_bounds.maxy,
            min_lon=state_bounds.minx,
            max_lon=state_bounds.maxx,
        )
    if county_fips is not None:
        county_bounds = conus_counties[conus_counties.GEOID == county_fips].bounds.iloc[0]
        d = _clip_to_bounds(
            d,
            min_lat=county_bounds.miny,
            max_lat=county_bounds.maxy,
            min_lon=county_bounds.minx,
            max_lon=county_bounds.maxx,
        )
    if any(bound is not None for bound in (min_lat, max_lat, min_lon, max_lon)):
        d = _clip_to_bounds(
            d,
            min_lat=min_lat,
            max_lat=max_lat,
            min_lon=min_lon,
            max_lon=max_lon,
        )

    # write the data to file if requested
    if write_to_file:
        d.to_netcdf(write_to_file)

    # loading the data fully into memory takes some time
    # a user skilled with dask may benefit from keeping the data unloaded
    # until the end of their data transformations
    if load:
        return d.load()
    return d
    

In [None]:
%%time
# Resolve the home directory with expanduser rather than os.environ.get("HOME"):
# the latter returns None when HOME is unset, which makes os.path.join raise
# TypeError. When HOME is set the result is the same.
HOME = os.path.expanduser("~")
DATA_DIR = os.path.join(HOME, "data/s3")

# Request one month of hourly T2 over Washington's bounding box, left unloaded
# (load=False) so that nothing is pulled into memory until it is needed.
# The commented-out arguments show the alternative spatial filters.
d = get_tgw_subset(
    start='2088-01-01',
    end='2088-02-01T00:00:00',
    variables=['T2'],
    # county_fips='53033',
    state_abbreviation='wa',
    # min_lat=45.543830,
    # max_lat=49.002405,
    # min_lon=-124.7336,
    # max_lon=-116.9161,
    data_store=f'{DATA_DIR}/tgw_wrf_rcp85hotter_hourly_2088_default_chunks.zarr',
    load=False,
    # write_to_file='./subset.nc',
)

In [None]:
# look at the subset: as the last expression of the cell, `d` is shown with
# the notebook's rich repr
d

In [None]:
%%time
# Draw the first timestep of the T2 field on a plain lon/lat (PlateCarree) map,
# with the Washington county outlines on top for orientation.
lonlat_crs = ccrs.PlateCarree()
fig = plt.figure(figsize=(10.8, 7.2), dpi=150, layout='tight')
ax = plt.axes(projection=lonlat_crs, frameon=False)

# county outlines for Washington only
wa_county_lines = conus_counties[conus_counties.STUSPS == 'WA'].boundary
wa_county_lines.plot(ax=ax, transform=lonlat_crs, linewidth=0.5, color='black')

# the data carries 2-D lat/lon coordinates, so name them explicitly as x/y
first_timestep = d.isel(time=0)
first_timestep.T2.plot(ax=ax, transform=lonlat_crs, x="lon", y="lat", alpha=0.5, cmap='coolwarm')

# blank out the automatically generated title
ax.set_title('');

In [None]:
%%time
# Draw the same first-timestep T2 map, but on axes that use the TGW-WRF native
# Lambert Conformal projection, whose proj string is:
# '+proj=lcc +lat_0=40.0000076293945 +lon_0=-97 +lat_1=30 +lat_2=45 +x_0=0 +y_0=0 +R=6370000 +units=m +no_defs'
# NOTE(review): the proj string specifies a sphere (+R=6370000) while globe=None
#               leaves cartopy's default globe in place -- confirm this is intended.
tgw_crs = ccrs.LambertConformal(
    central_longitude=-97.0,
    central_latitude=40.0000076293945,
    standard_parallels=(30, 45),
    globe=None,
)
lonlat_crs = ccrs.PlateCarree()  # the data and the county shapes are in lon/lat

fig = plt.figure(figsize=(10.8, 7.2), dpi=150, layout='tight')
ax = plt.axes(projection=tgw_crs, frameon=False)

# county outlines for Washington only
wa_county_lines = conus_counties[conus_counties.STUSPS == 'WA'].boundary
wa_county_lines.plot(ax=ax, transform=lonlat_crs, linewidth=0.5, color='black')

# the data carries 2-D lat/lon coordinates, so name them explicitly as x/y
first_timestep = d.isel(time=0)
first_timestep.T2.plot(ax=ax, transform=lonlat_crs, x="lon", y="lat", alpha=0.5, cmap='coolwarm')

# blank out the automatically generated title
ax.set_title('');