In [18]:


import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import os
import glob
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import pandas as pd
import iris
import iris.quickplot as qplt
import iris.plot as iplt
import json
import cftime
from itertools import product
from cftime import DatetimeNoLeap
import esmvalcore.preprocessor
import xesmf as xe
import warnings
%matplotlib inline
import seaborn as sns
from scipy import stats
from scipy.interpolate import interp1d
from sklearn import metrics
import dask
from tqdm import tqdm
from nc_processing import calc_spatial_mean
from mpl_axes_aligner import align
#sns.set()

from xmip.preprocessing import rename_cmip6
import matplotlib.path as mpath

def read_in(dir, ocean = False):
    """Open all netCDF files in one directory as a single CMIP6-renamed dataset.

    Parameters
    ----------
    dir : str
        Directory containing the files to open (trailing slash optional).
    ocean : bool, optional
        Unused; kept so existing calls that pass it keep working.

    Returns
    -------
    xarray.Dataset
        The files combined by ``xr.open_mfdataset`` (cftime decoding, netcdf4
        engine) and passed through ``rename_cmip6``.

    Raises
    ------
    FileNotFoundError
        If ``dir`` contains no ``.nc`` files (a subclass of the ``OSError``
        that ``open_mfdataset`` raises for an empty file list).
    """
    # os.path.join works with or without a trailing slash on ``dir``; only
    # netCDF files are opened, and sorting makes the file order reproducible.
    files = sorted(
        os.path.join(dir, name)
        for name in os.listdir(dir)
        if name.endswith('.nc')
    )
    if not files:
        raise FileNotFoundError("no .nc files found in {!r}".format(dir))
    with dask.config.set(**{'array.slicing.split_large_chunks': True}):
        ds = rename_cmip6(xr.open_mfdataset(files, use_cftime=True, engine='netcdf4'))
    return ds

def read_in_ens_mean(dirs, ocean = False, zonal_mean=False, max_ens_members=False):
    """Open several ensemble-member directories and return their ensemble mean.

    Each directory in ``dirs`` is treated as one ensemble member. All of its
    ``.nc`` files are concatenated along ``time``. This handles members whose
    output is split into several time-chunk files. The members are then stacked
    along a new ``ensemble`` dimension and averaged.

    Parameters
    ----------
    dirs : sequence of str
        One directory per ensemble member.
    ocean : bool, optional
        Unused; kept so existing calls that pass it keep working.
    zonal_mean : bool, optional
        If True, additionally average over the ``x`` dimension.
    max_ens_members : int or False, optional
        If truthy, only the first ``max_ens_members`` entries of ``dirs`` are used.

    Returns
    -------
    xarray.Dataset
        The ensemble-mean dataset (after ``rename_cmip6``). The number of members
        averaged is stored in the variable ``number_ens_mems_meaned``; only one
        object is returned.

    Raises
    ------
    FileNotFoundError
        If none of the selected directories contains a ``.nc`` file.
    """
    if max_ens_members:
        dirs = dirs[0:max_ens_members]

    # One inner list per member. Files are sorted so time chunks (CMIP6 names end
    # in _YYYYMM-YYYYMM.nc) are concatenated in chronological order.
    member_files = []
    for member_dir in dirs:
        nc_files = sorted(
            os.path.join(member_dir, name)
            for name in os.listdir(member_dir)
            if name.endswith('.nc')
        )
        if nc_files:  # an empty inner list would make the nested combine fail
            member_files.append(nc_files)
    if not member_files:
        raise FileNotFoundError("no .nc files found in any of: {}".format(list(dirs)))

    with dask.config.set(**{'array.slicing.split_large_chunks': True}):
        # Nested combine: each inner list is joined along 'time', and the outer
        # list along a new 'ensemble' dimension. Without this, a member split over
        # several files would be counted as several ensemble members.
        ds = rename_cmip6(xr.open_mfdataset(member_files, use_cftime=True,
                                            concat_dim=['ensemble', 'time'],
                                            combine='nested'))
        n_ens = len(ds.ensemble)
        ds = ds.mean(dim='ensemble')
        if zonal_mean:
            ds = ds.mean(dim='x')
        ds['number_ens_mems_meaned'] = n_ens
    return ds

def get_gmst(ds):
    """Spatial mean of the time-averaged ``tas`` field of ``ds``.

    The time mean is taken first. ``calc_spatial_mean`` then reduces over the
    ``x``/``y`` coordinates, which are passed as its lon/lat names. The result is
    returned through ``.values``.
    """
    tas_time_mean = ds.tas.mean(dim="time")
    gmst = calc_spatial_mean(tas_time_mean, lon_name="x", lat_name="y")
    return gmst.values

def get_dirs(var, model, scenario, table='Amon'):
    """Return the archive directories holding ``var`` for one model and scenario.

    Parameters
    ----------
    var : str
        CMIP variable name, e.g. ``'tas'``.
    model : str
        Model (source_id) name. Ignored for ``scenario == 'ARISE'``, whose search
        path is fixed to UKESM1-0-LL.
    scenario : str
        Experiment name, or ``'ARISE'`` for the ARISE-SAI-1.5 deposited data.
    table : str, optional
        CMIP table id (default ``'Amon'``).

    Returns
    -------
    list of str
        Matching directories as returned by ``glob.glob``. For UKESM1-0-LL,
        ``rlds``/``rlus`` results are trimmed to avoid directories that failed to
        open.
    """
    if scenario == 'ARISE':
        pattern = '/badc/deposited2022/arise/data/ARISE/MOHC/UKESM1-0-LL/arise-sai-1p5/*/{t}/{v}/*/*/'.format(
            t=table, v=var)
    else:
        pattern = '/badc/cmip6/data/CMIP6/*/*/{m}/{s}/*/{t}/{v}/*/latest/'.format(
            m=model, s=scenario, t=table, v=var)
    dirs = glob.glob(pattern)

    # Weird error on opening several UKESM1-0-LL files, perhaps corrupted.
    # NOTE(review): glob.glob returns directories in filesystem order, not sorted
    # order, so these positional slices exclude whichever directory happens to
    # come first/last. Confirm the excluded member and consider selecting by name.
    if model == 'UKESM1-0-LL':
        if var == 'rlds':
            dirs = dirs[0:2]
        elif var == 'rlus':
            dirs = dirs[1:]
    return dirs


def get_all_vars_zonal_monthly(vars, model, scenario,
                                min_year="2080", max_year="2100",
                                max_ens_members=5, table='Amon'):
    """Monthly climatology of zonal-mean, ensemble-mean fields for several variables.

    For each variable, member directories are located with ``get_dirs`` and
    averaged with ``read_in_ens_mean`` (ensemble mean, then mean over ``x``).
    The per-variable datasets are merged, restricted to ``min_year``..``max_year``
    (label-based, inclusive time slice), and averaged by calendar month.

    Parameters
    ----------
    vars : iterable of str
        Variable names to load (must not be empty).
    model, scenario : str
        Passed through to ``get_dirs``.
    min_year, max_year : str, optional
        Bounds of the time slice, e.g. ``"2080"``, ``"2100"``.
    max_ens_members : int, optional
        Maximum number of ensemble members per variable (default 5).
    table : str, optional
        CMIP table id passed to ``get_dirs`` (default ``'Amon'``).

    Returns
    -------
    xarray.Dataset
        Dataset with a ``month`` dimension in place of ``time``.

    Raises
    ------
    ValueError
        If ``vars`` is empty.
    """
    var_names = list(vars)
    if not var_names:
        raise ValueError("`vars` must contain at least one variable name")

    ds_list = [
        read_in_ens_mean(get_dirs(var, model, scenario, table=table),
                         ocean=False, zonal_mean=True,
                         max_ens_members=max_ens_members)
        for var in tqdm(var_names)
    ]

    # compat='override' keeps the first dataset's copy of variables that appear
    # in several inputs (e.g. bounds) instead of checking them for equality.
    DS = xr.merge(ds_list, compat='override')

    # Time-slice, then average all years within each calendar month.
    DS_stmean = DS.sel(time=slice(min_year, max_year)).groupby("time.month").mean(dim="time")

    return DS_stmean

def get_surface_area_north_of_lat(min_lat,
                                  path='/badc/cmip6/data/CMIP6/CMIP/MOHC/UKESM1-0-LL/piControl/r1i1p1f2/fx/areacella/gn/latest/areacella_fx_UKESM1-0-LL_piControl_r1i1p1f2_gn.nc'):
    """Sum of ``areacella`` over all grid cells with latitude ``y`` >= ``min_lat``.

    Parameters
    ----------
    min_lat : float
        Southern bound of the region (degrees); the northern bound is 90.1 so the
        pole row is included.
    path : str, optional
        areacella file to read; defaults to the UKESM1-0-LL piControl file.

    Returns
    -------
    numpy scalar
        Total cell area in the file's areacella units (presumably m^2 per CMIP
        convention; confirm).
    """
    # Context manager closes the file handle; .values forces the sum to be
    # computed before the file is closed.
    with xr.open_dataset(path) as raw:
        areacella = rename_cmip6(raw)
        # NOTE(review): slice(min_lat, 90.1) assumes y is in ascending order.
        # ``* 1`` turns the 0-d array from .values into a NumPy scalar.
        out = areacella.sel(y=slice(min_lat, 90.1)).sum(dim=['x', 'y']).areacella.values * 1
    return out


# ------------------------------------------------------------------ SETTINGS
# Model read by the single-model cells.
model = 'UKESM1-0-LL'

# Scenarios to preprocess. 'ARISE' is also understood by get_dirs but is
# currently left out.
scenarios = [
    'G6sulfur',
    'ssp245',
    'ssp585',
]

# Year window (string bounds of an inclusive time slice) used for the
# monthly climatologies.
min_year_late_cent = "2050"
max_year_late_cent = "2070"

# Set to False to skip the cell that writes the preprocessed netCDF files.
do_preprocess = True

In [11]:
# Models and variables for the multi-model preprocessing loop.
models = [
    'CESM2-WACCM',
    'UKESM1-0-LL',
    'CNRM-ESM2-1',
    'IPSL-CM6A-LR',
    'MPI-ESM1-2-HR',
    'MPI-ESM1-2-LR',
]
mm_vars = ['rsds', 'tas']

In [None]:

if do_preprocess:
    # One file per (scenario, model) holding the zonal-mean monthly fields from
    # get_all_vars_zonal_monthly. Create the directory so to_netcdf does not fail
    # on a fresh checkout.
    preprocessed_dir = 'Preprocessed_data/rsds'
    os.makedirs(preprocessed_dir, exist_ok=True)
    for scen in scenarios:
        print(scen)
        ds_list = []  # this scenario's datasets, one per model
        # The loop variable is `model_name`, not `model`, so the module-level
        # `model` setting is not overwritten with the last entry of `models`.
        for model_name in models:
            print(model_name)
            ds = get_all_vars_zonal_monthly(vars=mm_vars, model=model_name,
                                            scenario=scen,
                                            min_year=min_year_late_cent,
                                            max_year=max_year_late_cent)
            ds = ds.compute()
            ds['Model'] = model_name

            ds.to_netcdf(os.path.join(preprocessed_dir,
                                      '{s}_{m}.nc'.format(s=scen, m=model_name)))
            ds_list.append(ds)


ssp245
CESM2-WACCM


  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
100%|██████████| 2/2 [00:06<00:00,  3.26s/it]


UKESM1-0-LL


100%|██████████| 2/2 [00:08<00:00,  4.13s/it]


CNRM-ESM2-1


100%|██████████| 2/2 [00:04<00:00,  2.45s/it]


IPSL-CM6A-LR


100%|██████████| 2/2 [00:05<00:00,  2.70s/it]


MPI-ESM1-2-HR


100%|██████████| 2/2 [00:14<00:00,  7.04s/it]


MPI-ESM1-2-LR


100%|██████████| 2/2 [00:11<00:00,  5.57s/it]


ssp585
CESM2-WACCM


  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
100%|██████████| 2/2 [00:10<00:00,  5.46s/it]


UKESM1-0-LL


100%|██████████| 2/2 [00:11<00:00,  5.69s/it]


CNRM-ESM2-1


100%|██████████| 2/2 [00:05<00:00,  2.64s/it]


IPSL-CM6A-LR


100%|██████████| 2/2 [00:19<00:00,  9.87s/it]


In [None]:
## Plot the month-averaged zonal energy-budget terms, one panel per entry of ds_dict.

plot_dict = {0:'ssp585-ssp245',
             1:'G6sulfur-ssp585',
             2:'G6sulfur-ssp245'}

# NOTE(review): `ds_dict` (panel title -> dataset with a `month` dimension and
# Q/rsds/rsdscs variables) is not defined anywhere in this notebook section;
# it must be created before this cell runs.
# squeeze=False keeps `axs` 2-D, so indexing also works when there is one panel.
fig, axs = plt.subplots(1, len(ds_dict), figsize=(20, 6), sharey=True, squeeze=False)

for ax, (title, ds_panel) in zip(axs[0], ds_dict.items()):
    ds_to_plot = ds_panel.mean(dim='month')
    lat = ds_to_plot.y.values

    ax.plot(lat, ds_to_plot.Q.values, label='AHT convergence')
    ax.plot(lat, ds_to_plot.rsds.values, label='rsds')
    ax.plot(lat, ds_to_plot.rsdscs.values, label='rsdscs')

    ax.axhline(0, ls='--', color='gray')
    ax.legend()
    ax.set_xlim(-90, 90)
    ax.set_xlabel('Latitude')
    ax.set_title(title)

# Save once, after every panel is drawn (previously re-saved inside the loop).
os.makedirs('Figures/zonal_energy_budget', exist_ok=True)
fig.savefig('Figures/zonal_energy_budget/eg1_multi_scen.png', dpi=350)

In [None]:
## Read the preprocessed zonal energy-budget file for each scenario.
# NOTE(review): this uses the module-level `model`; if an earlier loop reused
# that name as its loop variable, files for that loop's last model are loaded.
zonal_budget_path = 'Preprocessed_data/zonal_energy_budget/{s}_{m}.nc'
ds_g6sulfur, ds_g6solar, ds_ssp245, ds_ssp585 = [
    xr.open_dataset(zonal_budget_path.format(s=scen_name, m=model))
    for scen_name in ('G6sulfur', 'G6solar', 'ssp245', 'ssp585')
]


In [16]:
# Show the datasets held in `ds_list` (the bare last expression is rendered as the cell output).
ds_list

[<xarray.Dataset>
 Dimensions:                 (y: 192, month: 12, nbnd: 2)
 Coordinates:
   * y                       (y) float64 -90.0 -89.06 -88.12 ... 88.12 89.06 90.0
   * month                   (month) int64 1 2 3 4 5 6 7 8 9 10 11 12
 Dimensions without coordinates: nbnd
 Data variables:
     rsds                    (month, y) float32 363.7 361.9 361.5 ... 0.0 0.0 0.0
     lat_bounds              (month, y, nbnd) float64 -90.0 -89.53 ... 89.53 90.0
     lon_bounds              (month, nbnd) float64 178.8 180.0 ... 178.8 180.0
     number_ens_mems_meaned  (month) float64 4.0 4.0 4.0 4.0 ... 4.0 4.0 4.0 4.0
     tas                     (month, y) float32 249.1 248.9 249.2 ... 254.2 254.2
     Model                   <U11 'CESM2-WACCM',
 <xarray.Dataset>
 Dimensions:                 (y: 144, month: 12, bnds: 2)
 Coordinates:
   * y                       (y) float64 -89.38 -88.12 -86.88 ... 88.12 89.38
   * month                   (month) int64 1 2 3 4 5 6 7 8 9 10 11 12
 Dimension