# Single ESM plot to test everything

In [19]:
## general
import numpy as np
import pandas as pd
import xarray as xr
import copy
import numpy.ma as ma
from itertools import compress
from sklearn.externals import joblib
import os
import regionmask
import time 
import geopandas
import cftime

## plotting
import matplotlib.pyplot as plt
from matplotlib.colors import from_levels_and_colors
from matplotlib.lines import Line2D 
import cartopy.crs as ccrs
from matplotlib.colors import LogNorm
import mplotutils as mpu

## statistics
from statsmodels.nonparametric.smoothers_lowess import lowess # lowess filter
from sklearn.linear_model import LinearRegression
from scipy.stats import pearsonr

# extra stuff for Lea's loading codes
import glob
from datetime import datetime

In [3]:
## Lea's stuff
# Import the loader module that lives in the tools directory by making that
# directory the working directory for the import (this presumably relies on the
# current directory being on sys.path, the IPython default), then switch to the
# plots directory, which stays the working directory for the rest of the notebook.
# NOTE(review): a function with the same name `load_data_single_mod` is defined
# later in this notebook and shadows the one imported here once that cell runs.
os.chdir('/home/tristan/mesmer/tools/')
from loading_all_tristan import load_data_single_mod
os.chdir('/home/tristan/mesmer/plots/')

In [4]:
def norm_cos_wgt(lats):
    """Cosine-of-latitude area weights.

    Parameters
    ----------
    lats : array_like
        Latitudes in degrees, any shape (this notebook passes the 2-D lat
        meshgrid).

    Returns
    -------
    numpy.ndarray
        ``cos(lats)`` (lats converted to radians), same shape as ``lats``.
        Despite the function name the weights are NOT normalised to sum to one;
        weighted means such as ``np.average(..., weights=...)`` normalise them
        implicitly.
    """
    return np.cos(np.deg2rad(lats))

In [13]:
def _prepare_esm_run(ds, start='1870-01-01', end='2101-01-01', lon_shift=72):
    """Trim an ESM dataset to the [start, end] time slice and recentre longitudes on 0.

    The data are rolled by ``lon_shift`` cells along ``lon`` (72 = half of the
    144-point longitude axis of the g025 files). The longitude labels are then
    mapped into [-180, 180) so the grid carries the same labels as the
    land-sea / SREX masks.
    """
    ds = ds.sel(time=slice(start, end))
    # roll_coords=True keeps every longitude label attached to its data column.
    # Newer xarray versions default to roll_coords=False, which would shift the
    # data but not the labels and make the relabelling below wrong.
    ds = ds.roll(lon=lon_shift, roll_coords=True)
    return ds.assign_coords(lon=(((ds.lon + 180) % 360) - 180))


def _member_nr(path):
    """Realisation number k from a CMIP6 file name ending in ``_r<k>i1p1f<f>_<grid>.nc``."""
    ripf = os.path.basename(path).split('_')[-2]  # e.g. 'r3i1p1f1'
    return int(ripf.split('r')[1].split('i')[0])


def _same_coords(a, b):
    """True when two coordinate arrays have the same shape and (float-)equal values."""
    a, b = np.asarray(a), np.asarray(b)
    return a.shape == b.shape and bool(np.allclose(a, b))


def _on_grid(ds, lon, lat):
    """True when ``ds`` has exactly the given lon / lat coordinate values."""
    # Compare raw values: comparing xarray coordinates directly would align them
    # on their index first, hiding any mismatch.
    return _same_coords(ds.lon.values, lon) and _same_coords(ds.lat.values, lat)


def load_data_single_mod(gen,model,Tref_all=True,Tref_start='1870-01-01',Tref_end='1900-01-01',usr_time_res="ann"):
    """Load all initial-condition members (all SSPs + historical) of one CMIP6 model.

    Only land grid points are returned (land = land fraction > 0, Antarctica,
    i.e. everything south of 60S, excluded).

    NOTE(review): this definition shadows the ``load_data_single_mod`` imported
    from ``loading_all_tristan``; calls written for that version's signature
    will not work against this one.

    Parameters
    ----------
    gen : int
        CMIP generation. Only 6 is implemented; any other value prints a
        message and returns None.
    model : str
        Model name as it appears in the file names (e.g. 'MPI-ESM1-2-LR').
    Tref_all : bool, default True
        If True, subtract the reference climatology ``T_ref`` from every run and
        return ``T_ref`` as well; if False, return absolute land values.
    Tref_start, Tref_end : str or datetime
        Bounds of the reference period, applied as a label-based time slice
        (inclusive of both bounds) to every historical run.
    usr_time_res : str, default "ann"
        Time-resolution tag used in the directory and file names ("ann", "mon", ...).

    Returns
    -------
    ``(y, run_nrs)`` if ``Tref_all`` is False, otherwise ``(y, T_ref, run_nrs)``.

    y : dict
        ``y[scen][run]`` -> array (time, n_land) of the land grid points over
        1870-2100. ``scen`` is the scenario tag taken from the file name
        ('ssp...') or 'historical'.
    T_ref : numpy.ndarray
        (lat, lon) reference climatology: the reference-period mean of each
        historical member, averaged over all historical members.
    run_nrs : dict
        ``run_nrs[scen]`` -> list of member numbers k (from ``r<k>i1p1f*``), in
        sorted file-name order.

    Raises
    ------
    FileNotFoundError
        If no historical run is found (the reference climate cannot be built).
    """
    print("start with model", model)

    if gen != 6:
        print("This is only for gen 6 CMIP6 ESMs!")
        return None

    var = 'tas'
    temp_res = usr_time_res
    spatial_res = 'g025'

    # Constant fields: the land mask, plus the SREX grid only to check that all
    # grids agree.
    dir_data = "/home/tristan/mesmer/data/"
    file_ls = "interim_invariant_lsmask_regrid.nc"
    file_srex = "srex-region-masks_20120709.srex_mask_SREX_masks_all.25deg.time-invariant.nc"

    with xr.open_mfdataset(dir_data+file_srex, combine='by_coords', decode_times=False) as srex_raw:
        srex_lon, srex_lat = srex_raw.lon.values, srex_raw.lat.values

    with xr.open_mfdataset(dir_data+file_ls, combine='by_coords', decode_times=False) as ls_raw:
        frac_l = ls_raw.where(ls_raw.lat > -60, 0)  # land fraction set to 0 south of 60S
        idx_l = np.squeeze(frac_l.lsm.values) > 0.0  # land wherever land fraction > 0
        grids_ok = _on_grid(frac_l, srex_lon, srex_lat)

    dir_var = '/home/tristan/mesmer/CMIP6/tas/%s/g025/' % usr_time_res  # select the correct directory
    stem = dir_var + var + '_' + temp_res + '_' + model
    members = '_r*i1p1f*_' + spatial_res + '.nc'
    run_names_list = sorted(glob.glob(stem + '_ssp*' + members))
    run_names_list_historical = sorted(glob.glob(stem + '_historical' + members))
    if not run_names_list_historical:
        raise FileNotFoundError("No historical runs of %s found in %s" % (model, dir_var))

    y = {}
    run_nrs = {}

    # Scenario runs: absolute values on the full grid for now (sea included).
    for run_name in run_names_list:
        scen = os.path.basename(run_name).split('_')[-3]
        run = _member_nr(run_name)
        with xr.open_mfdataset(run_name, combine='by_coords') as raw:
            data = _prepare_esm_run(raw)
            y.setdefault(scen, {})[run] = data.tas.values
            grids_ok = grids_ok and _on_grid(data, srex_lon, srex_lat)
        run_nrs.setdefault(scen, []).append(run)

    # Historical runs. Every member contributes its reference-period mean to
    # T_ref with equal weight.
    y['historical'] = {}
    run_nrs['historical'] = []
    T_ref = np.zeros(idx_l.shape)
    n_hist = len(run_names_list_historical)
    for run_name in run_names_list_historical:
        run = _member_nr(run_name)
        with xr.open_mfdataset(run_name, combine='by_coords') as raw:
            data = _prepare_esm_run(raw)
            y['historical'][run] = data.tas.values
            T_ref += data.tas.sel(time=slice(Tref_start, Tref_end)).mean(dim='time').values / n_hist
            grids_ok = grids_ok and _on_grid(data, srex_lon, srex_lat)
        run_nrs['historical'].append(run)

    # Keep land points only; with Tref_all, convert them to anomalies w.r.t. T_ref.
    T_ref_land = T_ref[idx_l]
    for scen in y:
        print(scen, run_nrs[scen])
        for run in run_nrs[scen]:
            land = y[scen][run][:, idx_l]
            y[scen][run] = land - T_ref_land if Tref_all else land

    if not grids_ok:
        print("There is an error. The grids do not agree.")

    if not Tref_all:
        return y, run_nrs
    return y, T_ref, run_nrs

In [17]:
##load in constant files
# Notebook-level constant fields: land-sea mask, SREX region mask, area weights
# and the lon/lat cell edges needed by pcolormesh.
dir_data = "/home/tristan/mesmer/data/"
file_ls = "interim_invariant_lsmask_regrid.nc"
file_srex = "srex-region-masks_20120709.srex_mask_SREX_masks_all.25deg.time-invariant.nc"
file_srex_shape = "referenceRegions.shp"  # NOTE(review): not read in this cell
        
# SREX names ordered according to SREX mask
srex_names = ['ALA','CGI','WNA','CNA','ENA','CAM','AMZ','NEB','WSA','SSA','NEU','CEU','MED','SAH','WAF','EAF','SAF',
             'NAS','WAS','CAS','TIB','EAS','SAS','SEA','NAU','SAU'] 
        
# srex_raw nrs from 1-26
srex_raw = xr.open_mfdataset(dir_data+file_srex, combine='by_coords',decode_times=False) 
lons, lats = np.meshgrid(srex_raw.lon.values,srex_raw.lat.values)  # full 2-D lon/lat grid of the SREX mask

# apply land mask
frac_l = xr.open_mfdataset(dir_data+file_ls, combine='by_coords',decode_times=False)
# Unmodified land fraction (Antarctica included), used for the land weights below.
# The copy must be taken before frac_l is overwritten on the next line.
frac_l_raw = np.squeeze(copy.deepcopy(frac_l.lsm.values))
frac_l = frac_l.where(frac_l.lat>-60,0)  # land fraction set to 0 south of 60S (drops Antarctica)
idx_l=np.squeeze(frac_l.lsm.values)>0.0  # boolean land mask: any positive land fraction counts as land

wgt = norm_cos_wgt(lats) # area weights of each grid point
wgt_l = (wgt*frac_l_raw)[idx_l] # area weights for land grid points (including taking fraction land into consideration)
lon_pc, lat_pc = mpu.infer_interval_breaks(frac_l.lon, frac_l.lat) # the lon / lat for the plotting with pcolormesh
srex=(np.squeeze(srex_raw.srex_mask.values)-1)[idx_l] # srex indices on land, shifted from 1-26 to 0-25

# NOTE(review): empty containers, not filled in this cell; T_ref here is a plain
# numpy array of zeros on the full (lat, lon) grid.
y={}
T_ref = np.zeros(idx_l.shape)
run_nrs={}

In [40]:
# Observation window to keep (label-based slice, both bounds inclusive).
# Earlier variant used numpy datetimes: np.datetime64('1850-01-31') .. np.datetime64('2018-12-31')
start_time, end_time = '1850-01-01', '2018-01-01'

# Time-axis repair recipe, kept for reference (only needed if the file's own time
# stamps are unusable):
#   x = pd.date_range('1850-1-15','2022-1-15', freq='M').strftime("%Y-%m-%d").to_numpy()
#   time = pd.to_datetime(x)
#   ...then append .assign_coords({"time": time}) to the open call below.

### open data
obs_file = 'obs_data_25.nc'
ds_obs = (
    xr.open_mfdataset(dir_data + obs_file)
    .rename({'temperature': 'tas'})
    .sel(time=slice(start_time, end_time))
)
ds_obs

Unnamed: 0,Array,Chunk
Bytes,81.00 kiB,81.00 kiB
Shape,"(72, 144)","(72, 144)"
Count,2 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 81.00 kiB 81.00 kiB Shape (72, 144) (72, 144) Count 2 Tasks 1 Chunks Type float64 numpy.ndarray",144  72,

Unnamed: 0,Array,Chunk
Bytes,81.00 kiB,81.00 kiB
Shape,"(72, 144)","(72, 144)"
Count,2 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,79.73 MiB,79.73 MiB
Shape,"(2016, 72, 144)","(2016, 72, 144)"
Count,3 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 79.73 MiB 79.73 MiB Shape (2016, 72, 144) (2016, 72, 144) Count 3 Tasks 1 Chunks Type float32 numpy.ndarray",144  72  2016,

Unnamed: 0,Array,Chunk
Bytes,79.73 MiB,79.73 MiB
Shape,"(2016, 72, 144)","(2016, 72, 144)"
Count,3 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,486.00 kiB,486.00 kiB
Shape,"(12, 72, 144)","(12, 72, 144)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 486.00 kiB 486.00 kiB Shape (12, 72, 144) (12, 72, 144) Count 2 Tasks 1 Chunks Type float32 numpy.ndarray",144  72  12,

Unnamed: 0,Array,Chunk
Bytes,486.00 kiB,486.00 kiB
Shape,"(12, 72, 144)","(12, 72, 144)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray


In [41]:
# Single test member: MPI-ESM1-2-LR, SSP5-8.5, r3i1p1f1, monthly, 1870-2100.
dir_esms = '/home/tristan/mesmer/CMIP6/tas/mon/g025/'
esm_fname = 'tas_mon_MPI-ESM1-2-LR_ssp585_r3i1p1f1_g025.nc'

esm_period = slice('1870-01-01', '2101-01-01')
ds_MPI = xr.open_mfdataset(dir_esms + esm_fname).sel(time=esm_period)
ds_MPI

Unnamed: 0,Array,Chunk
Bytes,81.63 MiB,81.63 MiB
Shape,"(1032, 72, 144)","(1032, 72, 144)"
Count,2 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 81.63 MiB 81.63 MiB Shape (1032, 72, 144) (1032, 72, 144) Count 2 Tasks 1 Chunks Type float64 numpy.ndarray",144  72  1032,

Unnamed: 0,Array,Chunk
Bytes,81.63 MiB,81.63 MiB
Shape,"(1032, 72, 144)","(1032, 72, 144)"
Count,2 Tasks,1 Chunks
Type,float64,numpy.ndarray


In [92]:
# Analysis switches
GMST_idx = True          # True: use GMST as the global index (Tanglob_wgt), False: use GSAT
compute_linreg = False

# Input data / output plot directories; defined unconditionally because neither
# path depends on GMST_idx.
dir_in_data = '/home/tristan/mesmer/data/'
dir_out_plot = '/home/tristan/mesmer/plots/'

# Start dates (global / regional) and candidate start years for the analysis windows
start_year_glob = '1980-01-15'
start_year_reg = '1950-01-15'
start_year_list_glob = [1975, 1980, 1985, 1990, 1995, 2000]
start_year_list_reg = [1950, 1960, 1970]

# Reference period 1951-1980: strings presumably for slicing the observations,
# datetimes passed as Tref_start / Tref_end when loading the ESMs.
ref_start_obs = '1951-01-31'
ref_end_obs = '1980-01-31'  # cut-off differently defined than in ESMs, is in fact the same ref period
ref_start = datetime.strptime('1951-01-31', '%Y-%m-%d')
ref_end = datetime.strptime('1980-12-31', '%Y-%m-%d')

## load in the data

In [21]:
dir_data = '/home/tristan/mesmer/data/'

# Land-sea mask: ERA-interim mask regridded by Richard from 73x144 to 72x144.
file_ls = 'interim_invariant_lsmask_regrid.nc'
frac_l = xr.open_mfdataset(dir_data + file_ls)
# Zero the land fraction south of 60S so that Antarctica is excluded.
frac_l = frac_l.where(frac_l.lat > -60, 0)
# Any cell with a positive land fraction counts as land.
idx_l = np.squeeze(frac_l.lsm.values) > 0.0
# Full lon/lat grid, only needed to derive the area weights.
lons, lats = np.meshgrid(frac_l.lon.values, frac_l.lat.values)
wgt = norm_cos_wgt(lats)

# Rebuild the observational time axis. freq='M' yields month-END stamps
# (1850-01-31 ... 2021-12-31) despite the mid-month start string.
x = pd.date_range('1850-1-15', '2022-1-15', freq='M').strftime("%Y-%m-%d").to_numpy()
time = pd.to_datetime(x)

### open data: times are not decoded from the file, the axis above replaces them
obs_file = 'obs_data_25.nc'
ds_obs = (
    xr.open_mfdataset(dir_data + obs_file, decode_times=False)
    .rename({'temperature': 'tas'})
    .assign_coords({"time": time})
)
ds_obs

# Reference code, currently disabled.
# NOTE(review): T_ref below is assumed to be an xarray object (`.values`); verify
# before re-enabling.
# ########## only if blended (ocean plus land) is being used: ####################################
# tas=ds_obs.tas.values-T_ref.values # anomalies, ocean included
# Tblendglob=np.zeros(tas.shape[0])
# for t in np.arange(tas.shape[0]):
#     idx_valid = ~np.isnan(tas[t])
#     Tblendglob[t] = np.average(tas[t,idx_valid],weights=wgt[idx_valid]) #area weighted of available obs -> less data available at beginning        
#
# obs_y=(ds_obs.tas.values-T_ref.values)[:,idx_l]
# obs_time=ds_obs.time.values
# Tblendglob=Tblendglob
# obs_time    
# # obs_idx_t_start_glob = np.where(obs_time==start_year_glob)[0][0]
# # obs_idx_t_start_reg = np.where(obs_time['best']==start_year_reg)[0][0]

Unnamed: 0,Array,Chunk
Bytes,81.00 kiB,81.00 kiB
Shape,"(72, 144)","(72, 144)"
Count,2 Tasks,1 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 81.00 kiB 81.00 kiB Shape (72, 144) (72, 144) Count 2 Tasks 1 Chunks Type float64 numpy.ndarray",144  72,

Unnamed: 0,Array,Chunk
Bytes,81.00 kiB,81.00 kiB
Shape,"(72, 144)","(72, 144)"
Count,2 Tasks,1 Chunks
Type,float64,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,81.63 MiB,81.63 MiB
Shape,"(2064, 72, 144)","(2064, 72, 144)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 81.63 MiB 81.63 MiB Shape (2064, 72, 144) (2064, 72, 144) Count 2 Tasks 1 Chunks Type float32 numpy.ndarray",144  72  2064,

Unnamed: 0,Array,Chunk
Bytes,81.63 MiB,81.63 MiB
Shape,"(2064, 72, 144)","(2064, 72, 144)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,486.00 kiB,486.00 kiB
Shape,"(12, 72, 144)","(12, 72, 144)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 486.00 kiB 486.00 kiB Shape (12, 72, 144) (12, 72, 144) Count 2 Tasks 1 Chunks Type float32 numpy.ndarray",144  72  12,

Unnamed: 0,Array,Chunk
Bytes,486.00 kiB,486.00 kiB
Shape,"(12, 72, 144)","(12, 72, 144)"
Count,2 Tasks,1 Chunks
Type,float32,numpy.ndarray


In [None]:
# Load every CMIP6 (`models`) and CMIP5 (`models_cmip5`) ESM, collecting per model
# the data, time axis and global-mean indices. Then pick, per CMIP6 model and run,
# the global index used downstream as Tanglob_wgt: GMST if GMST_idx is True,
# otherwise GSAT.
#
# NOTE(review): this cell has never been executed (In [None]) and cannot run as is:
#   * `models` and `models_cmip5` are not defined in any visible cell;
#   * the calls pass a scenario positionally plus `GMST_idx=` and unpack 11 (CMIP6)
#     or 10 (CMIP5) values. The `load_data_single_mod` defined in this notebook
#     (which shadows the imported one) takes (gen, model, Tref_all, Tref_start,
#     Tref_end, usr_time_res) and returns 2 or 3 values. These calls presumably
#     target the version from loading_all_tristan; TODO confirm and call it explicitly.
y={}
GMST={}
GSAT={}
time={} # because CAMS-CSM1 has 1 less entry than all other models; NOTE: rebinds `time` (previously the obs time axis)
for model in models:
    y[model],time[model],srex,srex_names,df_srex,lon_pc,lat_pc,idx_l,wgt_l,GSAT[model],GMST[model]=load_data_single_mod(6,model,'ssp585',GMST_idx=True,Tref_all=True,Tref_start=ref_start,Tref_end=ref_end)


# CMIP5 models: only GSAT is unpacked (GMST_idx=False), and they are not included
# in Tanglob_wgt below, which loops over `models` only.
for model in models_cmip5:    
    y[model],time[model],srex,srex_names,df_srex,lon_pc,lat_pc,idx_l,wgt_l,GSAT[model]=load_data_single_mod(5,model,'rcp85',GMST_idx=False,Tref_all=True,Tref_start=ref_start,Tref_end=ref_end)

    
# Global index per CMIP6 model and run, selected by the GMST_idx switch.
Tanglob_wgt={}
for model in models:
    Tanglob_wgt[model]={}
    for run in y[model].keys():
        if GMST_idx==True:
            Tanglob_wgt[model][run]=GMST[model][run]
        else:
            Tanglob_wgt[model][run]=GSAT[model][run]