In [1]:
%%capture
try:
    import xclim
except ModuleNotFoundError:
    ! pip install xclim

In [2]:
%matplotlib inline 
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from cartopy import config
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import os 
import gcsfs 
from matplotlib import cm
import warnings 
import seaborn as sns
import yaml
from tqdm.auto import tqdm
import fsspec
from zarr.errors import GroupNotFoundError, ContainsGroupError

from xclim.indicators import cf
from xclim.indicators import icclim, atmos 

# from science_validation_manual import *
import dc6_functions
from scipy.stats import kstest
import geopandas as gpd
import dask
import dask.dataframe
import dask.array as da
import dask.distributed as dd
import rhg_compute_tools.kubernetes as rhgk
import rhg_compute_tools.utils as rhgu

  from distributed.utils import LoopRunner, format_bytes


In [3]:
# Output directory for paper figures (a /gcs/ path -- presumably a mounted
# bucket; not referenced elsewhere in the visible notebook).
plot_dir = (
    '/gcs/impactlab-data/'
    'climate/downscaling/paper/'
)

In [4]:
# Per-model lookup tables from the project helper module (called in the same
# order as before). Only `models_dict` is used later in this notebook.
(models_dict, ensemble_members, grids, institutions) = (
    dc6_functions.get_cmip6_models(),
    dc6_functions.get_cmip6_ensemble_members(),
    dc6_functions.get_cmip6_grids(),
    dc6_functions.get_cmip6_institutions(),
)

In [5]:
# Pip requirement pinning xclim to this kernel's version; passed as
# `extra_pip_packages` in the (commented-out) cluster setup cells.
EXTRA_PIP_SPEC = 'xclim==' + xclim.__version__
EXTRA_PIP_SPEC

'xclim==0.34.0'

In [6]:
# Nested mapping of input zarr paths; looked up below as
# all_paths['<source_id>-<variable_id>'][experiment_id]['clean'].
with open('../version_specs/dcmip6_all_paths.yaml', mode='r') as f:
    all_paths = yaml.safe_load(stream=f)

In [7]:
# Release / summary versions interpolated into the output paths below.
DC6_VERSION = 'v1.1'
SUMMARY_VERSION = 'v1.0'

# Gridded 21-year summary stores, one per model/experiment/variable.
# Note: there is no {CRS_SUPPORT_BUCKET} field; extra keyword arguments passed
# to str.format are simply ignored.
DC6_SUMMARY_PATT = '/'.join([
    'gs://downscaled-288ec5ac/diagnostics',
    'RELEASE-{dc6_version}',
    '21-year-average-summaries',
    'gridded-clean',
    '{activity}',
    '{institution_id}',
    '{source_id}',
    '{experiment_id}',
    '{member_id}',
    '{table_id}',
    '{variable_id}',
    '{summary_version}.zarr',
])

# Region-aggregated summary stores; {region} is 'admin0' or 'admin1'.
DC6_REGION_PATT = '/'.join([
    'gs://downscaled-288ec5ac/diagnostics',
    'RELEASE-{dc6_version}',
    '21-year-average-summaries',
    'regional-cities-clean',
    '{region}',
    '{activity}',
    '{institution_id}',
    '{source_id}',
    '{experiment_id}',
    '{member_id}',
    '{table_id}',
    '{variable_id}',
    '{summary_version}.zarr',
])

# 30-arcsecond population raster with admin-0/admin-1 codes per pixel.
POP_RASTER_WITH_REGION_ASSIGNMENTS = '/'.join([
    'gs://downscaled-288ec5ac/diagnostics',
    'rasters',
    'gpw-v4-population-count-adjusted-to-2015-unwpp-country-totals-rev11_2020_30_sec',
    'naturalearth_v5.0.0_adm0_and_adm1_assignments.parquet',
])

# Cached population-weight tables; {name} selects the file.
AGGREGATION_WEIGHTS_FILES = '/'.join([
    'gs://downscaled-288ec5ac/diagnostics',
    'aggregation_weights',
    'popwt',
    '{name}.parquet',
])

In [8]:
@rhgu.block_globals
def summarize_tasmax_ann(ds):
    """Build the annual daily-maximum-temperature indicators for ``ds``.

    Returns an ``xr.Dataset`` with counts of days with tasmax above 25 degC,
    90 degF and 95 degF, followed by mean tasmax. Each value comes from an
    xclim indicator called without ``freq``, so it uses xclim's default
    resampling frequency (presumably annual -- confirm for the pinned version).
    """
    hot_day_thresholds = {
        'number_of_days_with_tasmax_above_25C': '25 degC',
        'number_of_days_with_tasmax_above_90F': '90 degF',
        'number_of_days_with_tasmax_above_95F': '95 degF',
    }
    indicators = {
        name: atmos.tx_days_above(ds=ds, thresh=thresh)
        for name, thresh in hot_day_thresholds.items()
    }
    indicators['annual_average_tasmax'] = cf.txmean(ds=ds)
    return xr.Dataset(indicators)

In [9]:
@rhgu.block_globals
def summarize_tasmin_ann(ds):
    """Build the annual daily-minimum-temperature indicators for ``ds``.

    Returns an ``xr.Dataset`` with the number of days with tasmin above
    20 degC, the number below 0 degC, and mean tasmin (xclim indicators at
    their default frequency -- presumably annual).
    """
    indicators = {}
    indicators['number_of_days_with_tasmin_above_20C'] = atmos.tn_days_above(ds=ds, thresh='20 degC')
    indicators['number_of_days_with_tasmin_below_0C'] = atmos.tn_days_below(ds=ds, thresh='0 degC')
    indicators['annual_average_tasmin'] = cf.tnmean(ds=ds)
    return xr.Dataset(indicators)

In [10]:
@rhgu.block_globals
def summarize_pr_ann(ds):
    """Build the annual precipitation indicators for ``ds``.

    Returns an ``xr.Dataset`` with wet-day counts at 1 mm/day and 10 mm/day
    thresholds and total precipitation (icclim PRCPTOT), each at the
    indicator's default frequency (presumably annual).
    """
    wet_day_counts = {
        f'wetdays_above_{amount}mm': atmos.wetdays(ds=ds, thresh=f'{amount} mm/day')
        for amount in (1, 10)
    }
    return xr.Dataset({**wet_day_counts, 'annual_total_precip': icclim.PRCPTOT(ds=ds)})

In [11]:
@rhgu.block_globals
def summarize_tasmax_seas(ds):
    """Seasonal mean tasmax, resampled into quarters starting in December (``QS-DEC``)."""
    return xr.Dataset(
        data_vars={'seasonal_average_tasmax': cf.txmean(ds=ds, freq='QS-DEC')},
    )

In [12]:
@rhgu.block_globals
def summarize_tasmin_seas(ds):
    """Seasonal mean tasmin, resampled into quarters starting in December (``QS-DEC``)."""
    return xr.Dataset(
        data_vars={'seasonal_average_tasmin': cf.tnmean(ds=ds, freq='QS-DEC')},
    )

In [13]:
@rhgu.block_globals
def summarize_pr_seas(ds):
    """Seasonal total precipitation (icclim PRCPTOT) over quarters starting in December (``QS-DEC``)."""
    return xr.Dataset(
        data_vars={'seasonal_total_precip': icclim.PRCPTOT(ds=ds, freq='QS-DEC')},
    )

In [14]:
# Dispatch tables used by compute_period_summaries: variable_id -> function
# returning that variable's annual / seasonal indicator Dataset.
ANNUAL_SUMMARY_FUNCS = dict(
    tasmax=summarize_tasmax_ann,
    tasmin=summarize_tasmin_ann,
    pr=summarize_pr_ann,
)

SEASONAL_SUMMARY_FUNCS = dict(
    tasmax=summarize_tasmax_seas,
    tasmin=summarize_tasmin_seas,
    pr=summarize_pr_seas,
)

In [15]:
@rhgu.block_globals(whitelist=[
    'DC6_VERSION',
    'all_paths',
    'SUMMARY_VERSION',
    'DC6_SUMMARY_PATT',
    'ANNUAL_SUMMARY_FUNCS',
    'SEASONAL_SUMMARY_FUNCS',
])
def compute_period_summaries(
    source_id,
    experiment_id,
    variable_id,
    return_results=False,
):
    """Compute (or load cached) 21-year-mean annual and seasonal summaries.

    For each centre year of the experiment, the indicators produced by
    ``ANNUAL_SUMMARY_FUNCS[variable_id]`` and
    ``SEASONAL_SUMMARY_FUNCS[variable_id]`` are averaged over the years
    ``year - 10`` through ``year + 10``. The seasonal ones are averaged per
    ``time.season``. The result is written to the zarr store described by
    ``DC6_SUMMARY_PATT``. If that store already exists and loads cleanly, it
    is reused instead.

    Parameters
    ----------
    source_id : str
        Model name (key into the ``dc6_functions`` lookups and ``all_paths``).
    experiment_id : str
        ``'historical'`` (centre years 1980, 2004) or a ScenarioMIP
        experiment (centre years 2030, 2050, 2089).
    variable_id : str
        Key of the summary dispatch tables (``'tasmax'``, ``'tasmin'``, ``'pr'``).
    return_results : bool, optional
        If True, return the summary Dataset; otherwise return None.

    Returns
    -------
    xarray.Dataset or None
        Summary with a ``period`` dimension (plus ``season`` for the seasonal
        variables) when ``return_results`` is True.

    Raises
    ------
    ContainsGroupError
        If the output store cannot be written because it is already present.
    """

    fs = fsspec.filesystem(
        'gs',
        timeout=120,
        cache_timeout=120,
        requests_timeout=120,
        read_timeout=120,
        conn_timeout=120,
    )

    activity = ('CMIP' if experiment_id == 'historical' else 'ScenarioMIP')
    institution_id = dc6_functions.get_cmip6_institutions()[source_id]
    member_id = dc6_functions.get_cmip6_ensemble_members()[source_id]

    # The MPI-ESM1-2-HR historical run is filed under a different institution.
    if (source_id == 'MPI-ESM1-2-HR') and (experiment_id == 'historical'):
        institution_id = 'MPI-M'

    # Centre years of the 21-year averaging windows.
    if experiment_id == 'historical':
        periods = [1980, 2004]
    else:
        periods = [2030, 2050, 2089]

    fp = all_paths[source_id + '-' + variable_id][experiment_id]['clean']

    # DC6_SUMMARY_PATT has no {CRS_SUPPORT_BUCKET} field. Reading
    # os.environ['CRS_SUPPORT_BUCKET'] here only made the function raise
    # KeyError whenever that unused variable was unset.
    output_fp = DC6_SUMMARY_PATT.format(
        activity=activity,
        institution_id=institution_id,
        source_id=source_id,
        experiment_id=experiment_id,
        member_id=member_id,
        table_id='21yrroll',
        variable_id=variable_id,
        dc6_version=DC6_VERSION,
        summary_version=SUMMARY_VERSION,
    )

    # Reuse an existing store when it can be read back completely.
    if fs.isdir(output_fp):
        try:
            ds = xr.open_zarr(output_fp, consolidated=True, chunks=None).load()
            if return_results:
                return ds
            return
        except (FileNotFoundError, GroupNotFoundError, KeyError):
            # Unreadable (e.g. partially written) store: delete it so it can
            # be rebuilt. A zarr store is a directory tree, so the removal
            # must be recursive. Deletion is best-effort; if it fails,
            # to_zarr below will most likely refuse to write over the leftovers.
            try:
                fs.rm(output_fp, recursive=True)
            except (IOError, OSError):
                pass

            fs.invalidate_cache(path=None)

    ds = xr.open_zarr(fp)

    annual_summary = ANNUAL_SUMMARY_FUNCS[variable_id](ds)
    seasonal_summary = SEASONAL_SUMMARY_FUNCS[variable_id](ds)

    summary_periods = xr.merge([
        xr.concat(
            [
                annual_summary.sel(time=slice(str(y-10), str(y+10))).mean(dim='time')
                for y in periods
            ],
            dim=pd.Index(periods, name='period'),
        ),
        xr.concat(
            [
                seasonal_summary.sel(time=slice(str(y-10), str(y+10))).groupby('time.season').mean(dim='time')
                for y in periods
            ],
            dim=pd.Index(periods, name='period'),
        ),
    ])

    summary_periods = summary_periods.compute(retries=3)
    summary_periods.attrs.update(ds.attrs)

    summary_periods.attrs.update({
        'frequency': '21yr_mean',
        'frequency_description': '21-year mean of annual summaries',
        'summary_update': pd.Timestamp.now(tz='US/Pacific').strftime('%c (%Z)'),
    })

    try:
        summary_periods.chunk().to_zarr(output_fp, consolidated=True)
    except ContainsGroupError as e:
        raise ContainsGroupError(f'directory not empty: {output_fp}') from e

    if return_results:
        return summary_periods

In [16]:
@rhgu.block_globals(whitelist=['all_paths'])
def get_grid_spec(source_id, experiment_id='historical', variable_id='tasmax'):
    """Read a model's native latitude/longitude centres and cell bounds.

    The coordinates are read from the ``'clean'`` zarr store listed in
    ``all_paths`` for the given model, experiment and variable.

    Parameters
    ----------
    source_id : str
        Model name; combined with ``variable_id`` to key into ``all_paths``.
    experiment_id : str, optional
        Experiment whose store is read (default ``'historical'``).
    variable_id : str, optional
        Variable whose store is read (default ``'tasmax'``).

    Returns
    -------
    tuple of xarray.DataArray
        ``(lat, lon, lat_bnds, lon_bnds)``, all loaded into memory.
    """
    # Only the store path is needed. The institution/member/period
    # bookkeeping that used to be here was never used.
    fp = all_paths[source_id + '-' + variable_id][experiment_id]['clean']
    with xr.open_zarr(fp, chunks=None, consolidated=True) as ds:
        lon = ds['lon'].load()
        lat = ds['lat'].load()
        lon_bnds = ds['lon_bnds'].load()
        lat_bnds = ds['lat_bnds'].load()

    return lat, lon, lat_bnds, lon_bnds

# Produce all summaries

In [17]:
# Optional: start a dask cluster (workers pinned via EXTRA_PIP_SPEC) before
# recomputing the gridded summaries. Left disabled in this run.
# client, cluster = rhgk.get_standard_cluster(extra_pip_packages=EXTRA_PIP_SPEC)
# cluster.scale(40)
# cluster

In [18]:
# Batch computation of every model/experiment/variable summary. Disabled here;
# presumably the outputs already exist, and compute_period_summaries reloads
# cached stores anyway.
# with tqdm(list(dc6_functions.get_cmip6_models().items())) as pbar:
#     for source_id, experiments in pbar:
#         for experiment_id in experiments:
#             for variable_id in ['tasmax', 'tasmin', 'pr']:
#                 pbar.set_postfix(
#                     {'model': source_id, 'scen': experiment_id, 'variable': variable_id}
#                 )

#                 compute_period_summaries(source_id, experiment_id, variable_id)

In [19]:
# Tear down the optional summary cluster (disabled along with its setup cell).
# client.restart()
# cluster.scale(0)
# client.close()
# cluster.close()

# Aggregate to regions

In [17]:
def bin_lat(df, edges, centers):
    """Label each row's latitude ``df.y`` with the centre of its grid cell.

    Bins are right-closed: a value in ``(edges[i], edges[i + 1]]`` is
    labelled ``centers[i]``. Values at or below ``edges[0]`` or above
    ``edges[-1]`` are NaN.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``y`` column (latitude).
    edges : sequence of float
        Monotonic bin edges, one more than ``centers``.
    centers : sequence of float
        Label for each bin.

    Returns
    -------
    pandas.Series
        Categorical series named ``y``.
    """
    latitudes = df['y']
    return pd.cut(latitudes, edges, right=True, labels=centers, include_lowest=False)

In [18]:
# Natural Earth 1:10m admin boundaries used for the regional aggregation.
shapefile_sources = dict(
    admin1='https://naciscdn.org/naturalearth/10m/cultural/ne_10m_admin_1_states_provinces.zip',
    admin0='https://naciscdn.org/naturalearth/10m/cultural/ne_10m_admin_0_countries.zip',
)

In [19]:
admin1 = gpd.read_file(filename=shapefile_sources['admin1'])  # states/provinces

In [20]:
admin0 = gpd.read_file(filename=shapefile_sources['admin0'])  # countries

In [21]:
# Cities to report on: name, latitude, longitude (degrees, EPSG:4326).
city_spec = pd.DataFrame.from_records(
    [
        ('Tokyo', 35.681, 139.767),
        ('Delhi', 28.625, 77.125),
        ('Shanghai', 31.125, 121.375),
        ('Sao Paulo', -23.625, -46.625),
        ('Mexico City', 19.375, -99.125),
        ('Cairo', 30.125, 31.125),
        ('Dhaka', 23.875, 90.375),
        ('New York', 40.625, -74.125),
        ('Buenos Aires', -34.625, -58.375),
        ('Istanbul', 41.125, 28.875),
        ('Lagos', 6.510, 3.370),
        ('Paris', 48.875, 2.375),
        ('Moscow', 55.875, 37.625),
        ('Miami', 25.875, -80.125),
        ('Mumbai', 19.125, 72.875),
        ('Manila', 14.599, 120.984),
        ('London', 51.625, -0.125),
    ],
    columns=['city', 'lat', 'lon'],
).set_index('city')

points = gpd.points_from_xy(x=city_spec['lon'], y=city_spec['lat'], crs='epsg:4326')
city_spec = gpd.GeoDataFrame(data=city_spec, geometry=points)
city_spec

Unnamed: 0_level_0,lat,lon,geometry
city,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
Tokyo,35.681,139.767,POINT (139.76700 35.68100)
Delhi,28.625,77.125,POINT (77.12500 28.62500)
Shanghai,31.125,121.375,POINT (121.37500 31.12500)
Sao Paulo,-23.625,-46.625,POINT (-46.62500 -23.62500)
Mexico City,19.375,-99.125,POINT (-99.12500 19.37500)
Cairo,30.125,31.125,POINT (31.12500 30.12500)
Dhaka,23.875,90.375,POINT (90.37500 23.87500)
New York,40.625,-74.125,POINT (-74.12500 40.62500)
Buenos Aires,-34.625,-58.375,POINT (-58.37500 -34.62500)
Istanbul,41.125,28.875,POINT (28.87500 41.12500)


In [22]:
# Attach the country (ADM0_A3) and state/province (adm1_code) code of the
# polygon containing each city point; unmatched cities get NaN (left join).
# Only the geometry is joined, so re-running this cell does not collide with
# the columns it already added. A point that matches several polygons keeps
# only its first match, so the result aligns one-to-one with city_spec's index.
city_spec['ADM0_A3'] = (
    gpd.sjoin(city_spec[['geometry']], admin0[['ADM0_A3', 'geometry']], how='left')
    .loc[lambda joined: ~joined.index.duplicated(keep='first'), 'ADM0_A3']
)
city_spec['adm1_code'] = (
    gpd.sjoin(city_spec[['geometry']], admin1[['adm1_code', 'geometry']], how='left')
    .loc[lambda joined: ~joined.index.duplicated(keep='first'), 'adm1_code']
)

In [23]:
# Inspect the region codes attached to each city (NaN where no polygon matched).
city_spec

Unnamed: 0_level_0,lat,lon,geometry,ADM0_A3,adm1_code
city,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Tokyo,35.681,139.767,POINT (139.76700 35.68100),JPN,JPN-1860
Delhi,28.625,77.125,POINT (77.12500 28.62500),IND,IND-2428
Shanghai,31.125,121.375,POINT (121.37500 31.12500),CHN,CHN-1819
Sao Paulo,-23.625,-46.625,POINT (-46.62500 -23.62500),BRA,BRA-1311
Mexico City,19.375,-99.125,POINT (-99.12500 19.37500),MEX,MEX-2727
Cairo,30.125,31.125,POINT (31.12500 30.12500),EGY,EGY-1544
Dhaka,23.875,90.375,POINT (90.37500 23.87500),BGD,BGD-1806
New York,40.625,-74.125,POINT (-74.12500 40.62500),USA,USA-3559
Buenos Aires,-34.625,-58.375,POINT (-58.37500 -34.62500),ARG,ARG-5493
Istanbul,41.125,28.875,POINT (28.87500 41.12500),TUR,TUR-2265


In [24]:
# For each model, derive the native-grid description used to bin population
# pixels: longitude resolution/offset, and latitude centres plus bin edges.
grid_specs = {}

for source_id in tqdm(list(dc6_functions.get_cmip6_models().keys())):
    # The historical tasmax summary is the reference grid for this model.
    tasmax_summary = compute_period_summaries(source_id, 'historical', 'tasmax', return_results=True)

    grid_name = tasmax_summary.attrs['grid']

    # Wrap longitudes into [-180, 180) and sort (the convention aggregate() uses).
    tasmax_summary['lon'] = ((tasmax_summary['lon'] + 180) % 360) - 180
    tasmax_summary = tasmax_summary.sortby('lon')
    lonres = np.unique(np.round(tasmax_summary.lon.diff(dim='lon'), 10)).item()
    lonoffset = np.round(tasmax_summary.lon.values[0], 10) / lonres

    # Longitudes must be exactly one regular sequence starting at the first centre.
    np.testing.assert_equal(
        np.arange(lonoffset * lonres, 180, lonres),
        np.sort(np.round(tasmax_summary.lon.values, 10)),
    )

    latres = np.unique(np.round(tasmax_summary.lat.diff(dim='lat'), 10))
    if len(latres) > 1:
        # Irregular latitude spacing (e.g. Gaussian grids): read the native
        # centres and bounds instead of deriving them from a fixed step.
        lat_centers, _, lat_bnds, _ = get_grid_spec(source_id)
        latres = None
        latoffset = None
        lat_centers = lat_centers.values
        # Bin edges = lower bound of the first cell, then the upper bound of
        # every cell. This assumes bnds=0/1 are lower/upper and latitudes
        # ascend; the asserts after this if/else check the result.
        lat_bnds = np.array(
            [lat_bnds.isel(lat=0, bnds=0).item()] + lat_bnds.sel(bnds=1).values.tolist()
        )

    else:
        latres = latres.item()
        latoffset = np.round(tasmax_summary.lat.values[0], 10) / latres
        # Offset of the cell centres from -90, modulo the resolution. Edges
        # sit half a cell either side of each centre. The 0.51 factor (not
        # 0.5) keeps np.arange's exclusive stop from dropping the last edge.
        initial_offset = (latoffset * latres + 90) % latres
        lat_centers = tasmax_summary.lat.values
        lat_bnds = np.arange(
            -90 + initial_offset - 0.5 * latres,
            90 - initial_offset + 0.51 * latres,
            latres,
        )

        try:
            np.testing.assert_allclose(
                np.round(lat_centers, 12),
                np.round(np.arange(latoffset * latres, 90 + latres / 5, latres), 12),
                rtol=0,
                atol=1e-8,
            )
        except AssertionError as e:
            a = np.round(lat_centers, 6)
            b = np.round(np.arange(latoffset * latres, 90 + latres / 5, latres), 6)
            a_in_b = pd.Series(a).isin(b)
            b_in_a = pd.Series(b).isin(a)
            diff = (~a_in_b) | (~b_in_a)
            raise ValueError(str(e) + f'\n diff positions: {np.arange(len(diff))[diff]}\n x diffs: {a[diff]}\n y diffs: {b[diff]}') from e

    # Every centre must fall strictly inside the outermost edges.
    assert len(lat_bnds) == len(lat_centers) + 1
    assert lat_centers[0] > lat_bnds[0]
    assert lat_centers[-1] < lat_bnds[-1]

    # Test to ensure the other scenarios share the historical grid spec.
    for scen in dc6_functions.get_cmip6_models()[source_id][1:]:
        tasmax_summary_scen = compute_period_summaries(
            source_id, scen, 'tasmax', return_results=True,
        )
        scen_lons = np.sort(((tasmax_summary_scen['lon'].values + 180) % 360) - 180)
        np.testing.assert_allclose(
            tasmax_summary.lon.values,
            scen_lons,
            rtol=0,
            atol=1e-8,
        )

        np.testing.assert_allclose(
            lat_centers,
            tasmax_summary_scen.lat.values,
            rtol=0,
            atol=1e-8,
        )

    grid_specs[source_id] = {
        'grid': grid_name,
        'grid_label': tasmax_summary.attrs['grid_label'],
        'spec': {
            'lonres': lonres,
            'lonoffset': lonoffset,
            'latres': latres,
            'latoffset': latoffset,
            'lat_centers': lat_centers,
            'lat_bnds': lat_bnds,
        },
    }

  0%|          | 0/25 [00:00<?, ?it/s]

In [32]:
# Native grid description recorded in the EC-Earth3 summary's `grid` attribute.
grid_specs['EC-Earth3']['grid']

'T255L91'

In [37]:
# Compare the native grids reported by the EC-Earth3 model family.
print('\n'.join(
    '{}: {}'.format(model_name, grid_specs[model_name]['grid'])
    for model_name in [
        'EC-Earth3',
        'EC-Earth3-AerChem',
        'EC-Earth3-CC',
        'EC-Earth3-Veg',
        'EC-Earth3-Veg-LR',
    ]
))

EC-Earth3: T255L91
EC-Earth3-AerChem: T255L91
EC-Earth3-CC: T255L91
EC-Earth3-Veg: T255L91-ORCA1L75
EC-Earth3-Veg-LR: T159L62-ORCA1L75


In [45]:
# Compare EC-Earth3's regular longitude step with its mean latitude spacing
# (its latitude spacing is not constant -- see the per-cell diffs further down).
grid_specs['EC-Earth3']['spec']['lonres'], pd.Series(grid_specs['EC-Earth3']['spec']['lat_centers']).diff().mean()

(0.703125, 0.7016691887731558)

In [53]:
# One line per model: native grid name and approximate resolution.
# Longitude spacing is exactly regular (asserted when grid_specs was built),
# but latitude spacing may vary (e.g. Gaussian grids). The latitude figure is
# therefore the mean spacing, and it is the one marked approximate with "~".
for m in models_dict.keys():
    print('{}: {} (~{:0.4f}º lat x {:0.4f}º lon)'.format(
        m,
        grid_specs[m]['grid'],
        pd.Series(grid_specs[m]['spec']['lat_centers']).diff().mean(),
        grid_specs[m]['spec']['lonres'],
    ))

BCC-CSM2-MR: T106 (1.1213º lat x ~ 1.1250º lon)
FGOALS-g3: native atmosphere area-weighted latxlon grid (80x180 latxlon) (2.2785º lat x ~ 2.0000º lon)
ACCESS-ESM1-5: native atmosphere N96 grid (145x192 latxlon) (1.2500º lat x ~ 1.8750º lon)
ACCESS-CM2: native atmosphere N96 grid (144x192 latxlon) (1.2500º lat x ~ 1.8750º lon)
INM-CM4-8: gs2x1.5 (1.5000º lat x ~ 2.0000º lon)
INM-CM5-0: gs2x1.5 (1.5000º lat x ~ 2.0000º lon)
MIROC-ES2L: native atmosphere T42 Gaussian grid (2.7893º lat x ~ 2.8125º lon)
MIROC6: native atmosphere T85 Gaussian grid (1.4004º lat x ~ 1.4062º lon)
NorESM2-LM: finite-volume grid with 1.9x2.5 degree lat/lon resolution (1.8947º lat x ~ 2.5000º lon)
NorESM2-MM: finite-volume grid with 0.9x1.25 degree lat/lon resolution (0.9424º lat x ~ 1.2500º lon)
GFDL-ESM4: atmos data regridded from Cubed-sphere (c96) to 180,288; interpolation method: conserve_order2 (1.0000º lat x ~ 1.2500º lon)
GFDL-CM4: atmos data regridded from Cubed-sphere (c96) to 180,288; interpolation meth

In [30]:
# Per-cell latitude spacing for EC-Earth3. It varies with latitude, which is
# why this model took the irregular-latitude branch when grid_specs was built.
pd.Series(grid_specs['EC-Earth3']['spec']['lat_centers']).diff()

0           NaN
1      0.695870
2      0.699980
3      0.700908
4      0.701260
         ...   
251    0.701431
252    0.701260
253    0.700908
254    0.699980
255    0.695870
Length: 256, dtype: float64

In [25]:
@rhgu.block_globals(whitelist=['AGGREGATION_WEIGHTS_FILES', 'city_spec'])
def get_aggregators(source_id, grid_specs, pop_raster, admin0, admin1):
    """Load or build population weights linking model grid cells to regions.

    Each population pixel (``x``/``y`` columns of ``pop_raster``) is assigned
    to the model grid cell containing it. Population is then summed per
    (grid cell, region) pair, and pairs with zero population are dropped.
    Tables are cached as parquet under ``AGGREGATION_WEIGHTS_FILES`` and read
    back when present.

    Parameters
    ----------
    source_id : str
        Model whose grid is used.
    grid_specs : dict
        ``grid_specs[source_id]['spec']`` must provide ``lat_bnds``,
        ``lat_centers``, ``lonres`` and ``lonoffset``.
    pop_raster : dask.dataframe.DataFrame
        Needs columns ``x``, ``y``, ``population``, ``ADM0_A3``, ``adm1_code``.
    admin0, admin1 :
        Unused; kept for interface compatibility.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        ``aggregator_adm0`` with columns ``grid_lat, grid_lon, ADM0_A3,
        population``, and ``aggregator_adm1`` with ``grid_lat, grid_lon,
        adm1_code, population``. The admin-1 table covers only the
        ``adm1_code`` values present in ``city_spec``.

    Notes
    -----
    Tables cached before the longitude-binning fix below were built with the
    old binning. Delete them to regenerate.
    """
    binned_lat = pop_raster.map_partitions(
        bin_lat,
        edges=grid_specs[source_id]['spec']['lat_bnds'],
        centers=grid_specs[source_id]['spec']['lat_centers'],
    )

    lon_res = grid_specs[source_id]['spec']['lonres']
    lonoffset = grid_specs[source_id]['spec']['lonoffset']
    # Position of the cell centres modulo the resolution: 0 when a centre lies
    # on 0 deg, lon_res / 2 when cell edges lie on 0 deg.
    lon_cell_offset = ((lonoffset * lon_res) % lon_res)

    # Snap every pixel to its *nearest* cell centre, i.e. the cell
    # [centre - lon_res / 2, centre + lon_res / 2) that contains it. The
    # previous `(x // lon_res) * lon_res + offset` only did this when
    # offset == lon_res / 2; for other grids (e.g. a centre at 0 deg) it
    # misregistered pixels by half a cell. Centres are wrapped back into
    # [-180, 180) so they match the wrapped model longitudes.
    binned_lon = (
        ((pop_raster.x - lon_cell_offset + 0.5 * lon_res) // lon_res) * lon_res
        + lon_cell_offset
    )
    binned_lon = ((binned_lon + 180) % 360) - 180

    this_pop_raster = pop_raster.assign(grid_lat=binned_lat, grid_lon=binned_lon)

    adm0_fp = AGGREGATION_WEIGHTS_FILES.format(
        name=f'model-native-grid-aggregators/{source_id}_adm0_cities_popwt'
    )

    try:
        aggregator_adm0 = pd.read_parquet(adm0_fp)
    except (FileNotFoundError, IOError):
        # Per-partition partial sums, then a final sum across partitions.
        aggregator_adm0 = (
            this_pop_raster
            .map_partitions(lambda df: df.groupby(['grid_lat', 'grid_lon', 'ADM0_A3']).population.sum())
            .compute()
            .groupby(level=['grid_lat', 'grid_lon', 'ADM0_A3'])
            .sum()
            .to_frame()
            .reset_index(drop=False)
        )

        aggregator_adm0 = aggregator_adm0[aggregator_adm0.population > 0]
        aggregator_adm0.to_parquet(adm0_fp)

    adm1_fp = AGGREGATION_WEIGHTS_FILES.format(
        name=f'model-native-grid-aggregators/{source_id}_adm1_cities_popwt'
    )

    try:
        aggregator_adm1 = pd.read_parquet(adm1_fp)
    except (FileNotFoundError, IOError):
        # Only the admin-1 regions containing the cities of interest.
        aggregator_adm1 = (
            this_pop_raster[this_pop_raster.adm1_code.isin(city_spec.adm1_code.values)]
            .map_partitions(lambda df: df.groupby(['grid_lat', 'grid_lon', 'adm1_code']).population.sum())
            .compute()
            .groupby(level=['grid_lat', 'grid_lon', 'adm1_code'])
            .sum()
            .to_frame()
            .reset_index(drop=False)
        )

        aggregator_adm1 = aggregator_adm1[aggregator_adm1.population > 0]
        aggregator_adm1.to_parquet(adm1_fp)

    return aggregator_adm0, aggregator_adm1

# Build all aggregation maps

In [29]:
# Optional: larger dask cluster for building the aggregation weights.
# Left disabled in this run.
# client, cluster = rhgk.get_big_cluster(extra_pip_packages=EXTRA_PIP_SPEC)
# cluster.scale(12)
# cluster

In [26]:
# Populated pixels that lie in a country or state/province containing one of
# the cities in city_spec.
pop_raster = dask.dataframe.read_parquet(POP_RASTER_WITH_REGION_ASSIGNMENTS)
pop_raster = pop_raster[
    (
        pop_raster.adm1_code.isin(city_spec.adm1_code)
        | pop_raster.ADM0_A3.isin(city_spec.ADM0_A3)
    )
    & pop_raster.population.gt(0)
]

In [31]:
# Warm the aggregation-weight cache: get_aggregators reads each model's
# parquet weights if present, otherwise builds and writes them.
with tqdm(list(dc6_functions.get_cmip6_models().items())) as pbar:
    for source_id, _ in pbar:
        aggregator_adm0, aggregator_adm1 = get_aggregators(
            source_id,
            grid_specs=grid_specs,
            pop_raster=pop_raster,
            admin0=admin0,
            admin1=admin1,
        )

  0%|          | 0/25 [00:00<?, ?it/s]

In [32]:
# Tear down the optional aggregation cluster (disabled along with its setup cell).
# client.restart()
# cluster.scale(0)
# client.close()
# cluster.close()

In [33]:
# Shapefile attribute columns copied onto the regional outputs as coordinates.
shp_meta_cols = dict(
    admin0=['NAME', 'SOVEREIGNT'],
    admin1=['name', 'adm0_a3', 'admin'],
)

In [34]:
# Region level name -> GeoDataFrame of its polygons.
shapes = dict(admin0=admin0, admin1=admin1)

In [35]:
import pprint

In [36]:
def isdir(dirname):
    """Return True if ``dirname`` is a directory with at least one entry.

    ``gs://`` paths are listed through fsspec/gcsfs. ``/gcs/<bucket>/...``
    paths are first rewritten to ``gs://<bucket>/...``; passed unchanged,
    gcsfs would treat ``gcs`` as the bucket name. (This assumes ``/gcs/``
    mirrors the GCS buckets, as ``plot_dir`` suggests -- confirm.) Any other
    path is listed on the local filesystem.

    Any ``OSError`` while listing (missing path, not a directory, ...)
    yields False.
    """
    if dirname.startswith('/gcs/'):
        dirname = 'gs://' + dirname[len('/gcs/'):]

    if dirname.startswith('gs://'):
        fs = fsspec.filesystem(
            'gs',
            timeout=120,
            cache_timeout=120,
            requests_timeout=120,
            read_timeout=120,
            conn_timeout=120,
        )
        ls = fs.ls
    else:
        ls = os.listdir

    try:
        return len(ls(dirname)) > 0
    except IOError:
        return False

In [37]:
@rhgu.block_globals(whitelist=['DC6_REGION_PATT', 'DC6_VERSION', 'SUMMARY_VERSION', 'shp_meta_cols', 'shapes', 'shapefile_sources'])
def aggregate(source_id, experiment_id, variable_id, aggregator_adm0, aggregator_adm1):
    """Population-weight a model's period summaries to admin-0 and admin-1 regions.

    The gridded summary comes from ``compute_period_summaries``. Each region
    value is ``sum(cell_value * pop) / sum(pop)`` over the (grid cell, region)
    pairs in the aggregator table. One zarr store per region level is written
    to the path from ``DC6_REGION_PATT``, and a level whose store already
    exists (non-empty) is skipped.

    Parameters
    ----------
    source_id, experiment_id, variable_id : str
        Identify the summary to aggregate.
    aggregator_adm0, aggregator_adm1 : pandas.DataFrame
        Weight tables with columns ``grid_lat``, ``grid_lon``, ``population``
        and the region code (``ADM0_A3`` / ``adm1_code`` respectively).

    Raises
    ------
    ContainsGroupError
        If an output store appears between the existence check and the write.
    """
    activity = ('CMIP' if experiment_id == 'historical' else 'ScenarioMIP')
    institution_id = dc6_functions.get_cmip6_institutions()[source_id]
    member_id = dc6_functions.get_cmip6_ensemble_members()[source_id]

    # The MPI-ESM1-2-HR historical run is filed under a different institution.
    if (source_id == 'MPI-ESM1-2-HR') and (experiment_id == 'historical'):
        institution_id = 'MPI-M'

    # Path fields shared by both region levels. DC6_REGION_PATT has no
    # {CRS_SUPPORT_BUCKET} field, so that (previously read and ignored)
    # environment variable is not required.
    path_fields = dict(
        activity=activity,
        institution_id=institution_id,
        source_id=source_id,
        experiment_id=experiment_id,
        member_id=member_id,
        table_id='21yrroll',
        variable_id=variable_id,
        dc6_version=DC6_VERSION,
        summary_version=SUMMARY_VERSION,
    )
    adm0_out_fp = DC6_REGION_PATT.format(region='admin0', **path_fields)
    adm1_out_fp = DC6_REGION_PATT.format(region='admin1', **path_fields)

    if isdir(adm0_out_fp) and isdir(adm1_out_fp):
        return

    summary = compute_period_summaries(
        source_id,
        experiment_id,
        variable_id,
        return_results=True,
    )

    # Same [-180, 180) longitude convention as the aggregator grid_lon values.
    summary['lon'] = ((summary['lon'] + 180) % 360) - 180
    summary = summary.sortby('lon')

    for agg_kind, aggregator, agg_col, agg_fp in [
        ('admin0', aggregator_adm0, 'ADM0_A3', adm0_out_fp),
        ('admin1', aggregator_adm1, 'adm1_code', adm1_out_fp),
    ]:

        if isdir(agg_fp):
            continue

        # Pick each aggregator row's grid cell (exact match within 1e-6),
        # weight by population, sum per region, divide by regional population.
        summary_by_region = (
            (
                summary
                .sel(
                    lat=aggregator.grid_lat.to_xarray(),
                    lon=aggregator.grid_lon.to_xarray(),
                    method='nearest',
                    tolerance=1e-6,
                )
                * aggregator.population.to_xarray()
            )
            .groupby(aggregator[agg_col].to_xarray()).sum()
            / aggregator.groupby(agg_col).population.sum().to_xarray()
        )

        for v in summary_by_region.data_vars.keys():
            summary_by_region[v].attrs.update(summary[v].attrs)

        for col in shp_meta_cols[agg_kind]:
            summary_by_region.coords[col] = (
                shapes[agg_kind].set_index(agg_col)[col].to_xarray().astype(str)
            )
            summary_by_region.coords[col] = [str(v) for v in summary_by_region.coords[col].values]

        summary_by_region.attrs.update(summary.attrs)

        summary_by_region.attrs['regionagg_updated'] = (
            pd.Timestamp.now(tz='US/Pacific').strftime('%c (%Z)')
        )

        summary_by_region.attrs['regionagg_shape_name'] = agg_kind
        summary_by_region.attrs['regionagg_shape_source'] = shapefile_sources[agg_kind]
        summary_by_region.attrs['regionagg_weighting'] = 'population-weighted mean'

        summary_by_region.attrs['regionagg_weight_source'] = (
            'https://sedac.ciesin.columbia.edu/data/set/gpw-v4-population-density-'
            'adjusted-to-2015-unwpp-country-totals-rev11'
        )

        summary_by_region.attrs['regionagg_method'] = (
            '30 arcsecond population pixels are assigned (based on intersection) to '
            'polygons in the shapefile as well as to the coarse grid, then gridded '
            'summary statistics from the coarse resolution were weighted based on the '
            'fraction of regional population contained in the intersection of the region '
            'and each coarse grid cell, then averaged to produce the regional value'
        )

        # Drop encodings inherited from the source store, and store string
        # coordinates as unicode.
        to_store = summary_by_region.copy()
        for var in to_store.variables:
            to_store[var].encoding.clear()
        for var in to_store.coords.keys():
            to_store.coords[var].encoding.clear()

        for c in to_store.coords.keys():
            if to_store.coords[c].dtype.kind in 'SOU':
                to_store.coords[c] = to_store.coords[c].astype("unicode")

        # No early return after writing: both region levels must be written
        # in a single call (previously admin1 was skipped once admin0 was written).
        try:
            to_store.chunk().to_zarr(agg_fp, consolidated=True)
        except ContainsGroupError as e:
            raise ContainsGroupError(agg_fp) from e

In [None]:
# For every model, load or build its aggregation weights once, then write the
# regional summaries for each experiment and variable (aggregate() skips
# region levels whose outputs already exist).
with tqdm(list(dc6_functions.get_cmip6_models().items())) as pbar:
    for source_id, experiments in pbar:
        aggregator_adm0, aggregator_adm1 = get_aggregators(
            source_id, grid_specs, pop_raster, admin0, admin1,
        )

        for experiment_id in experiments:
            for variable_id in ('tasmax', 'tasmin', 'pr'):
                pbar.set_postfix(
                    dict(model=source_id, scen=experiment_id, variable=variable_id)
                )
                aggregate(source_id, experiment_id, variable_id, aggregator_adm0, aggregator_adm1)