In [19]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import gcsfs
from tqdm.autonotebook import tqdm
from scipy import signal

In [20]:
# Experiments to process: the baseline run first, the scenario run second
experiment_ids = ['historical', 'ssp370']

# Seasons to process ('all' stands for the whole year)
seasons = ['all', 'DJF', 'JJA']

# Time slices (future) to process: 30-year windows whose first year
# advances by a decade, from 2011-2040 up to 2071-2100.
# Each entry is a [first_year, last_year] pair of strings.
time_slices = [[str(first_year), str(first_year + 29)]
               for first_year in range(2011, 2072, 10)]

In [21]:
# Read the catalogue of data locations (one row per zarr store)
df = pd.read_csv(
    filepath_or_buffer='https://storage.googleapis.com/pangeo-cmip6/pangeo-cmip6-zarr-consolidated-stores.csv'
)

In [22]:
# Keep only the monthly ('Amon') precipitation ('pr') entries, restricted to
# the "r1i1p1f1" member so that each model contributes a single run for now
df_mon_pr = df.loc[df.table_id.eq('Amon')
                   & df.variable_id.eq('pr')
                   & df.member_id.eq("r1i1p1f1")]

In [23]:
# Collect the names of the models that provide every requested experiment
source_ids = []
for name, group in df_mon_pr.groupby('source_id'):
    available = set(group.experiment_id.values)
    if available.issuperset(experiment_ids):
        source_ids.append(name)

In [24]:
# Function to load precipitation data
def load_pr_data(source_id, expt_id):
    """
    Open the monthly precipitation store for a given model and experiment.

    The store location is looked up in the module-level ``df_mon_pr``
    catalogue subset and opened from Google Cloud Storage with anonymous
    access.

    Parameters
    ----------
    source_id : str
        Model name, matched against the ``source_id`` column.
    expt_id : str
        Experiment name, matched against the ``experiment_id`` column.

    Returns
    -------
    xarray.Dataset
        Dataset returned by ``xr.open_zarr`` for the matching store. If
        several catalogue rows match, the first one is used.

    Raises
    ------
    IndexError
        If the catalogue subset has no row for this model/experiment pair.
    """
    matches = df_mon_pr[(df_mon_pr.source_id == source_id) &
                        (df_mon_pr.experiment_id == expt_id)]
    if matches.empty:
        # Same exception type as indexing an empty array, but with a message
        # that says which combination is missing.
        raise IndexError("No catalogue entry for source_id=" + repr(source_id)
                         + ", experiment_id=" + repr(expt_id))
    uri = matches.zstore.values[0]

    gcs = gcsfs.GCSFileSystem(token='anon')
    ds = xr.open_zarr(gcs.get_mapper(uri), consolidated=True)
    return ds

In [25]:
# Helper: time mean of precipitation restricted to the requested seasons
def _seasonal_mean(pr, seas):
    """Return the mean over 'time' of `pr` for time steps whose season is in `seas`."""
    return pr.sel(time=pr.time.dt.season.isin(seas)).mean('time')


# Helper: variability of the detrended precipitation series
def _detrended_sd(pr, seas):
    """
    Return the standard deviation over 'time' of the detrended `pr` series.

    The full series is detrended first (missing values are filled with 0 for
    the detrending and masked out again afterwards); the season selection is
    applied to the detrended series.
    """
    # Detrend along the actual time axis instead of assuming it is axis 0
    time_axis = pr.get_axis_num('time')
    detrended = xr.apply_ufunc(signal.detrend, pr.fillna(0),
                               kwargs={'axis': time_axis})
    return (detrended.where(pr.notnull())
            .sel(time=pr.time.dt.season.isin(seas))
            .std('time'))


# Helper: latitude-weighted average of the zonal mean
def _weighted_zonal_average(field, weights):
    """Average `field` over 'lon', then average the result using `weights`."""
    return np.average(field.mean('lon').values, weights=weights)


# Function to get mean and variability and their changes
def mean_var_calc(data_tmp,seas='all'):
    """
    Compute precipitation mean, variability and their future/historical ratios.

    Parameters
    ----------
    data_tmp : dict
        Mapping with the keys 'hist' and 'futr'. Each value is a Dataset with
        a ``pr`` variable, a ``time`` dimension and ``lat``/``lon``
        coordinates. ``.load()`` is called on both datasets.
    seas : str or list of str, optional
        Season label(s) to keep ('DJF', 'MAM', 'JJA', 'SON'). The default
        'all' keeps all four seasons.

    Returns
    -------
    xarray.Dataset
        mu_hist, mu_futr : time mean of ``pr`` for each period
        dmu              : mu_futr / mu_hist
        dmuG             : cos(lat)-weighted average of the zonal mean of dmu
        sd_hist, sd_futr : standard deviation of the detrended ``pr`` series
        dsd              : sd_futr / sd_hist
        dsdG             : cos(lat)-weighted average of the zonal mean of dsd
    """
    # Set if season is 'all'
    if seas=="all":
        seas = ['DJF','MAM','JJA','SON']

    # Load both periods once; everything below works on the loaded data
    hist = data_tmp['hist'].load()
    futr = data_tmp['futr'].load()
    pr_hist = hist.pr
    pr_futr = futr.pr

    # Area weighting: use the cosine of the latitudes (converted to radians)
    # as weights for the average
    weights = np.cos(np.deg2rad(hist.lat)).values

    # Calculate mean of raw series
    mu_hist = _seasonal_mean(pr_hist, seas)
    mu_futr = _seasonal_mean(pr_futr, seas)
    dmu = mu_futr/mu_hist
    dmuG = _weighted_zonal_average(dmu, weights)

    # Calculate standard deviation of detrended series
    sd_hist = _detrended_sd(pr_hist, seas)
    sd_futr = _detrended_sd(pr_futr, seas)
    dsd = sd_futr/sd_hist
    dsdG = _weighted_zonal_average(dsd, weights)

    # Out
    outp = xr.Dataset(
        data_vars = {'mu_hist': mu_hist,
                     'mu_futr': mu_futr,
                     'dmu':     dmu,
                     'dmuG':    ([],dmuG),
                     'sd_hist': sd_hist,
                     'sd_futr': sd_futr,
                     'dsd':     dsd,
                     'dsdG':    ([],dsdG)},
        )

    return outp

In [9]:

# Staged input datasets and computed statistics, both nested as
# model -> 't<first year of slice>' (-> season, for results)
data = {}
results = {}

for mod_name in tqdm(source_ids):
    print('\n\nStarting '+mod_name+'\n')
    # Historical baseline: the 30-year period 1976-2005
    ds_hist = load_pr_data(mod_name, experiment_ids[0]).sel(time=slice('1976', '2005'))
    ds_ssp = load_pr_data(mod_name, experiment_ids[1])

    # Last year covered by the scenario run (same for every slice, so
    # compute it once per model)
    last_year = int(ds_ssp.time.max().dt.year)

    data[mod_name] = {}
    results[mod_name] = {}

    for time_slice in time_slices:
        start, end = time_slice
        slice_key = 't'+start
        print('Begin processing time slice '+start+'-'+end)
        # Stop once a slice would end more than one year after the last
        # available year
        if last_year+1 < int(end):
            print("Future time series only goes until "+str(last_year))
            break

        # Get corresponding temporal slice of data and stage it
        ds_ssp_tmp = ds_ssp.sel(time=slice(start, end))
        data[mod_name][slice_key] = {'hist':ds_hist,'futr':ds_ssp_tmp}

        results[mod_name][slice_key] = {}
        for seas in seasons:
            # Calculate means, sds,...
            results[mod_name][slice_key][seas] = mean_var_calc(data[mod_name][slice_key],seas)
            print(seas+' processed!')

        print(start+'-'+end+' processed!')

    print(mod_name+' processed!')

HBox(children=(IntProgress(value=0, max=11), HTML(value='')))



Starting BCC-CSM2-MR

Begin processing time slice 2006-2035
all processed!
DJF processed!
JJA processed!
2006-2035 processed!
Begin processing time slice 2016-2045
all processed!
DJF processed!
JJA processed!
2016-2045 processed!
Begin processing time slice 2026-2055
all processed!
DJF processed!
JJA processed!
2026-2055 processed!
Begin processing time slice 2036-2065
all processed!
DJF processed!
JJA processed!
2036-2065 processed!
Begin processing time slice 2046-2075
all processed!
DJF processed!
JJA processed!
2046-2075 processed!
Begin processing time slice 2056-2085
all processed!
DJF processed!
JJA processed!
2056-2085 processed!
BCC-CSM2-MR processed!


Starting BCC-ESM1

Begin processing time slice 2006-2035
all processed!
DJF processed!
JJA processed!
2006-2035 processed!
Begin processing time slice 2016-2045
all processed!
DJF processed!
JJA processed!
2016-2045 processed!
Begin processing time slice 2026-2055
all processed!
DJF processed!
JJA processed!
2026-2055 process

In [10]:
# Display the nested results: model -> 't<first year of slice>' -> season -> Dataset
results

{'BCC-CSM2-MR': {'t2006': {'all': <xarray.Dataset>
   Dimensions:  (lat: 160, lon: 320)
   Coordinates:
     * lat      (lat) float64 -89.14 -88.03 -86.91 -85.79 ... 86.91 88.03 89.14
     * lon      (lon) float64 0.0 1.125 2.25 3.375 4.5 ... 355.5 356.6 357.8 358.9
   Data variables:
       mu_hist  (lat, lon) float32 2.4228807e-06 2.4099031e-06 ... 5.6194244e-06
       mu_futr  (lat, lon) float32 2.8700958e-06 2.8530749e-06 ... 6.3723937e-06
       dmu      (lat, lon) float32 1.18458 1.1838961 ... 1.1390164 1.133994
       dmuG     float64 1.019
       sd_hist  (lat, lon) float32 1.1514595e-06 1.142341e-06 ... 2.9312864e-06
       sd_futr  (lat, lon) float32 1.3052028e-06 1.2912004e-06 ... 3.6640247e-06
       dsd      (lat, lon) float32 1.1335204 1.1303108 ... 1.2573708 1.2499716
       dsdG     float64 1.03, 'DJF': <xarray.Dataset>
   Dimensions:  (lat: 160, lon: 320)
   Coordinates:
     * lat      (lat) float64 -89.14 -88.03 -86.91 -85.79 ... 86.91 88.03 89.14
     * lon      (lo