In [None]:
# Load the nb_black IPython extension (presumably auto-formats code cells with black — confirm).
%load_ext nb_black

In [1]:
# Import some python libraries
%matplotlib inline

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import intake



In [None]:
# Setup a dask cluster
from dask.distributed import Client
from dask_kubernetes import KubeCluster

# Create the Kubernetes-backed cluster; no explicit worker specification is passed here.
cluster = KubeCluster()
# Let the cluster scale adaptively between 16 and 64 workers.
# NOTE(review): interval='2s' and wait_count=3 presumably tune how often and how
# eagerly the cluster rescales — confirm against the dask adaptive documentation.
cluster.adapt(minimum=16, maximum=64, interval='2s', wait_count=3)

In [None]:
# Attach a distributed client to the cluster created above; the bare `client`
# expression makes its summary the cell output.
client = Client(cluster)
client

# Set up `intake-esm`
Documentation: https://intake-esm.readthedocs.io/

In [None]:
# Open the Pangeo CMIP6 collection published in the NCAR intake-esm-datastore repository.
col = intake.open_esm_datastore(
    "https://raw.githubusercontent.com/NCAR/intake-esm-datastore/"
    "master/catalogs/pangeo-cmip6.json"
)

In [None]:
# Selection shared by the catalog searches in this notebook.
variable_id = 'tas'  # variable whose forecast skill is evaluated
table_id = 'Amon'  # CMIP table queried in every search
source_id = 'CESM1-1-CAM5-CMIP5'  # model used for the hindcast and assimilation searches

# Get hindcast
Boer, G. J., D. M. Smith, C. Cassou, F. Doblas-Reyes, G. Danabasoglu, B. Kirtman, Y. Kushnir, et al. “The Decadal Climate Prediction Project (DCPP) Contribution to CMIP6.” Geosci. Model Dev. 9, no. 10 (October 25, 2016): 3751–77. https://doi.org/10/f89qdf.

In [None]:
# Find every DCPP hindcast entry for the chosen model, table and variable.
cat_cmip = col.search(
    experiment_id=['dcppA-hindcast'],
    table_id=table_id,
    variable_id=variable_id,
    source_id=source_id,
)
# Store the initialization years as integers.
cat_cmip.df['dcpp_init_year'] = cat_cmip.df['dcpp_init_year'].astype(int)

In [None]:
# The first ten member ids (in sorted order) returned by the unrestricted search.
members = sorted(cat_cmip.df.member_id.unique())[:10]
# Initialization years 1970 through 2014 (the upper bound 2015 is exclusive).
inits = list(np.arange(1970, 2015))

In [None]:
# Repeat the hindcast search, now restricted to the chosen members and initialization years.
cat_cmip = col.search(
    experiment_id=['dcppA-hindcast'],
    table_id=table_id,
    variable_id=variable_id,
    source_id=source_id,
    member_id=members,
    dcpp_init_year=inits,
)
# Store the initialization years as integers.
cat_cmip.df['dcpp_init_year'] = cat_cmip.df['dcpp_init_year'].astype(int)

In [None]:
def pre(ds,var=variable_id):
    """Preprocess callback handed to ``to_dataset_dict``.

    Overwrites the ``time`` coordinate of ``ds`` with the integers 1..N, where N
    is the number of time steps, and returns ``var`` as a single-variable
    dataset after squeezing it.

    Notes
    -----
    ``ds`` is modified in place. The default for ``var`` is whatever
    ``variable_id`` was when this function was defined.
    """
    n_steps = ds.time.size
    ds['time'] = np.arange(1, n_steps + 1)
    selected = ds[var].squeeze()
    return selected.to_dataset(name=var)

In [None]:
# Load the selected hindcasts into a dict of datasets, running `pre` as the
# preprocess callback, then list the resulting keys.
dset_dict = cat_cmip.to_dataset_dict(
    zarr_kwargs={'consolidated': True},
    preprocess=pre,
    cdf_kwargs={'decode_times': False, 'chunks': {'time': -1}},
)
list(dset_dict.keys())

In [None]:
# Take one dataset out of the dict (popitem also removes it from dset_dict).
_, hind = dset_dict.popitem()
# Rename the dimensions to member / init / lead.
hind = hind.rename({'member_id':'member','dcpp_init_year':'init','time':'lead'})
hind = hind[variable_id].squeeze()
# throw away first two leads and create annual means
hind=hind.isel(lead=slice(2,None))
# Temporarily label the remaining leads with a synthetic month-start calendar
# beginning in 2000 so that groupby('lead.year') averages over calendar-year blocks.
hind['lead']=xr.cftime_range(start='2000',freq='MS',periods=hind.lead.size)
hind_ym = hind.groupby('lead.year').mean().rename({'year':'lead'})
# Relabel the annual means as lead 1..N.
hind_ym['lead']=np.arange(1,1+hind_ym.lead.size)
# Show the underlying array as the cell output.
hind_ym.data

## Get historical

In [None]:
# Search for the uninitialized runs, taken from a different model (CESM2) than the hindcasts.
# NOTE(review): 'ssp45' does not look like a CMIP6 experiment_id (scenario runs are
# named like 'ssp245'); if so, only 'historical' can match here. Changing it would
# add a second dataset to the dict that the later `popitem()` reads — confirm intent.
cat_cmip_hist = col.search(experiment_id=['historical','ssp45'],
                 table_id=table_id,       
                 variable_id=variable_id,
                 source_id='CESM2', # not exactly CESM-LE
                 )

In [None]:
# Load the search result as a dict of datasets (this rebinds `dset_dict`) and show its keys.
dset_dict = cat_cmip_hist.to_dataset_dict(zarr_kwargs={'consolidated': True})
dset_dict.keys()

In [None]:
# Take one dataset out of the dict, keep only the chosen variable, rename the
# member dimension and restrict the record to 1960-2015.
_, hist = dset_dict.popitem()
hist = (
    hist[variable_id]
    .squeeze()
    .rename({'member_id': 'member'})
    .sel(time=slice('1960', '2015'))
)
# Annual means; the yearly axis is named `time` again.
hist_ym = hist.groupby('time.year').mean().rename({'year': 'time'})
hist_ym.data

## Get assimilation

In [None]:
# Search for the DCPP assimilation run of the hindcast model.
# Note: this rebinds `cat_cmip_hist`, which previously held the historical search.
cat_cmip_hist = col.search(
    experiment_id=['dcppA-assim'],
    table_id=table_id,
    variable_id=variable_id,
    source_id=source_id,
)

In [None]:
# Build annual-mean assimilation data when the search matched anything;
# otherwise just report that nothing was found.
if len(cat_cmip_hist.df.source_id.unique()) < 1:
    print(f'no assimilation found for {source_id} {table_id} {variable_id}')
else:
    dset_dict = cat_cmip_hist.to_dataset_dict(zarr_kwargs={'consolidated': True})
    print(dset_dict.keys())
    _, assim = dset_dict.popitem()
    assim = (
        assim[variable_id]
        .squeeze()
        .rename({'member_id': 'member'})
        .sel(time=slice('1960', '2015'))
    )
    assim_ym = assim.groupby('time.year').mean().rename({'year': 'time'})
    display(assim_ym.data)

## Get observations

In [3]:
!wget https://crudata.uea.ac.uk/cru/data/temperature/HadCRUT.4.6.0.0.median.nc

--2020-02-18 21:22:07--  https://crudata.uea.ac.uk/cru/data/temperature/HadCRUT.4.6.0.0.median.nc
Auflösen des Hostnamens crudata.uea.ac.uk (crudata.uea.ac.uk)... 139.222.133.100
Verbindungsaufbau zu crudata.uea.ac.uk (crudata.uea.ac.uk)|139.222.133.100|:443 ... verbunden.
HTTP-Anforderung gesendet, auf Antwort wird gewartet ... 200 OK
Länge: 21468796 (20M) [application/x-netcdf]
Wird in »HadCRUT.4.6.0.0.median.nc« gespeichert.

dCRUT.4.6.0.0.media   6%[>                   ]   1.33M   339KB/s    ETA 66s    ^C


In [None]:
# weird outcomes
!wget https://www.metoffice.gov.uk/hadobs/hadisst/data/HadISST_sst.nc.gz
!gunzip -k HadISST_sst.nc.gz

In [None]:
# Open the observational product matching `variable_id` and rename its
# coordinates and data variable to lat / lon / <variable_id>.
if variable_id == 'tas':
    obs = xr.open_dataset('HadCRUT.4.6.0.0.median.nc')
    obs = obs.rename({'latitude':'lat','longitude':'lon','temperature_anomaly':variable_id})[variable_id]
elif variable_id == 'tos':
    obs = xr.open_dataset('HadISST_sst.nc')
    # HadISST stores its field as `sst`; it has no `temperature_anomaly` variable,
    # so the previous rename raised an error on this branch.
    obs = obs.rename({'latitude':'lat','longitude':'lon','sst':variable_id})[variable_id]
else:
    # Fail here rather than with a NameError on `obs` in a later cell.
    raise ValueError(f'no observations configured for variable_id={variable_id!r}')

In [None]:
# Keep observations from 1960 onwards, then average them into annual means
# whose yearly axis is named `time`.
obs = obs.sel(time=slice('1960',None))
obs_ym = obs.groupby('time.year').mean().rename({'year':'time'})

# detrend

In [None]:
from climpred.stats import rm_poly
# Polynomial order for the trend removal — presumably meant for rm_poly; confirm.
order = 2
# Switch for the detrending cell; with False the `if detrend:` body is skipped.
detrend = False

In [None]:
# Optionally remove a polynomial trend of degree `order` from every input.
# The imported helper is `rm_poly` (`rm_pol` was undefined and raised NameError),
# `order` has to be passed explicitly, and only the hindcast has an `init`
# dimension — the historical run and the observations are detrended along `time`.
# NOTE(review): assumes the climpred 1.x signature rm_poly(ds, order, dim=...) — confirm.
if detrend:
    hind_ym = rm_poly(hind_ym, order, dim='init')
    hist_ym = rm_poly(hist_ym, order, dim='time')
    obs_ym = rm_poly(obs_ym, order, dim='time')

# regrid

In [None]:
import xesmf as xe

def regrid(ds, deg=5):
    """Bilinearly regrid ``ds`` onto a global ``deg`` x ``deg`` grid.

    Parameters
    ----------
    ds : xarray object
        Data on its native grid; used both as the source grid of the
        ``xe.Regridder`` and as the data it is applied to.
    deg : number, default 5
        Value passed twice to ``xe.util.grid_global`` to build the target grid.

    Returns
    -------
    The result of applying the regridder to ``ds``.

    Notes
    -----
    The ``Regridder`` instance is local to this function and is not returned.
    """
    target_grid = xe.util.grid_global(deg, deg)
    return xe.Regridder(ds, target_grid, 'bilinear')(ds)

In [None]:
# fails when lazy
hind_ym_regridded = regrid(hind_ym.load())

In [None]:
# fails when lazy
hist_ym_regridded = regrid(hist_ym.load())

In [None]:
# Regrid the observed annual means with the same `regrid` helper (default deg=5).
obs_ym_regridded = regrid(obs_ym)

#### check the inputs 👀

In [None]:
# Spot check: map the regridded hindcast for the first member, third lead, third init.
hind_ym_regridded.isel(member=0,lead=2,init=2).plot()

In [None]:
# Spot check: map the regridded historical run for the first member, third year.
hist_ym_regridded.isel(member=0,time=2).plot()

In [None]:
# Spot check: map the regridded observations for the third year.
obs_ym_regridded.isel(time=2).plot()

## Skill

In [None]:
from climpred.prediction import compute_hindcast

In [None]:
# Keyword arguments shared by the compute_hindcast and bootstrap_hindcast calls.
cp_kwargs = dict(metric='acc', comparison='e2r')

In [None]:
# Time the skill computation of the regridded hindcast against the regridded observations.
%time skill = compute_hindcast(hind_ym_regridded, obs_ym_regridded, **cp_kwargs)

In [None]:
import cartopy.crs as ccrs

def plot_skill(skill,map_proj=ccrs.PlateCarree(),**plot_kwargs):
    """Plot ``skill`` as a grid of maps, one panel per ``lead``, five per row.

    Parameters
    ----------
    skill : xarray object
        Must expose ``lon`` and ``lat`` (used for the panel aspect ratio) and a
        ``lead`` dimension to facet over.
    map_proj : cartopy projection, optional
        Projection of the map axes; the default instance is created once, when
        the function is defined.
    **plot_kwargs
        Forwarded to ``skill.plot`` in addition to the fixed options below.

    Returns
    -------
    Whatever ``skill.plot`` returns.
    """
    facet_options = dict(
        col='lead',
        col_wrap=5,
        robust=True,
        transform=ccrs.PlateCarree(),  # the data's projection
        aspect=skill["lon"].size / skill["lat"].size,  # for a sensible figsize
        subplot_kws={"projection": map_proj},  # the plot's projection
    )
    grid = skill.plot(**facet_options, **plot_kwargs)
    # Coastline drawing is currently disabled:
    # for ax in grid.axes.flat:
    #     ax.coastlines()
    return grid

In [None]:
# Map the skill for every lead with the default projection.
plot_skill(skill)

In [None]:
import dask
# Repeat the skill computation lazily, but only when the regridded hindcast is a dask collection.
# NOTE(review): the hindcast was `.load()`-ed before regridding, so this branch is
# probably skipped at this point in the notebook — confirm.
if dask.is_dask_collection(hind_ym_regridded):
    %time skill = compute_hindcast(hind_ym_regridded, obs_ym_regridded, **cp_kwargs)
    display(skill.data)
    # Time the actual computation separately from building the lazy result.
    %time skillc = skill.compute()

## Bootstrap significant skill

In [None]:
from climpred.bootstrap import bootstrap_hindcast

In [None]:
# Value passed as the `bootstrap` argument of bootstrap_hindcast.
bootstrap=100

In [None]:
# Time the bootstrapped skill using the hindcast, the historical run and the observations.
%time bskill = bootstrap_hindcast(hind_ym_regridded, hist_ym_regridded, obs_ym_regridded, bootstrap=bootstrap, **cp_kwargs)

In [None]:
# Keep the `init` skill only where the `uninit` p-value is at most 0.05, then map it.
improved_by_init = bskill.sel(results='skill',kind='init').where(bskill.sel(results='p',kind='uninit') <= 0.05)
plot_skill(improved_by_init)

#### correct for FDR

In [None]:
from esmtools.testing import multipletests
# Adjust the `uninit` p-values with method "fdr_bh" at alpha 0.05; the first
# return value is discarded and only the second is kept.
_, bskill_fdr_corr_p = multipletests(
    bskill.sel(kind="uninit", results="p"), method="fdr_bh", alpha=.05
)

# Same mask as before, but based on the adjusted p-values.
improved_by_init_corr = bskill.sel(results='skill',kind='init').where(bskill_fdr_corr_p <= .05)
plot_skill(improved_by_init_corr)

#### lazily with `dask`
When chunked by `lead`, each chunk is roughly 100 MB.

In [None]:
# Convert the hindcast to a dask array with a single chunk along `lead`, persist it
# on the cluster, and show the resulting array (this rebinds hind_ym_regridded).
hind_ym_regridded = hind_ym_regridded.chunk({'lead':-1}).persist()
hind_ym_regridded.data

In [None]:
# Likewise persist the historical run (single chunk along `member`) and the
# observations (single chunk along `time`).
hist_ym_regridded = hist_ym_regridded.chunk({'member':-1}).persist()
obs_ym_regridded = obs_ym_regridded.chunk({'time':-1}).persist()

In [None]:
# Time the same bootstrap call on the chunked, persisted inputs (this rebinds bskill).
%time bskill = bootstrap_hindcast(hind_ym_regridded, hist_ym_regridded, obs_ym_regridded, bootstrap=bootstrap, **cp_kwargs)

In [None]:
# Show the array backing the bootstrap result.
bskill.data

In [None]:
# Time the evaluation of the bootstrap result into `bskillc`.
%time bskillc = bskill.compute()

#### Close down

In [None]:
# Remove the weight files written while regridding.
# The Regridder objects only ever existed as locals inside `regrid`, so there is no
# notebook-level `regridder`, and `regridder.clean_weight_file()` raised NameError.
# NOTE(review): assumes xESMF's default weight-file naming
# (`bilinear_<ny>x<nx>_<ny>x<nx>[_peri].nc` in the working directory) — confirm.
from pathlib import Path

for weight_file in Path('.').glob('bilinear_*x*_*x*.nc'):
    weight_file.unlink()

In [None]:
# Disconnect the client first, then shut the cluster down.
client.close()
cluster.close()