In [1]:
import xarray as xr
from dask.diagnostics import ProgressBar
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.patches as mpatches
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
import cartopy.crs as ccrs
import cartopy.feature as cf
from cartopy.util import add_cyclic_point
import intake, intake_esm
import os
import sys
from pathlib import Path
from glob import glob
import itertools
import time
import seaborn as sns
import math
import scipy.stats as st
import requests
import pickle

import warnings
warnings.filterwarnings('ignore')

%run functions.ipynb

In [2]:
'''
The files are stored in two catalogues:
    - PMIPCat.json : paleoclimate models
    - CMIPCat.json : present-day models (notably with piControl)
'''

# Open the PMIP catalogue; it is queried below with PMIP.search(...) and PMIP.df.
PMIP = intake.open_esm_datastore("PMIPCat.json")

In [None]:
# Open the CMIP catalogue; it is queried below for the piControl experiment.
CMIP = intake.open_esm_datastore('CMIPCat.json')

# Données Masque

In [None]:
# Catalogue query selecting the single simulation used to build the mask.
mask_query = {
    'institution_id': 'INM',
    'source_id': 'INM-CM4-8',
    'table_id': 'Lmon',
    'experiment_id': 'lig127k',
    'member_id': "r1i1p1f1",
    'latest': True,
    'variable_id': 'ra',
}
land = PMIP.search(**mask_query)

# Open every matching file lazily and average over time; the actual
# computation is triggered below under a progress bar.
land_ds = xr.open_mfdataset(
    list(land.df["path"]),
    chunks = 200,
    use_cftime = True,
    decode_cf = True,
).mean('time')

with ProgressBar():
    land_ds = land_ds.compute()


# Optional regional subset, currently disabled:
#land_ds = land_ds.where((land_ds.coords['lon'] < 150) & (land_ds.coords['lon'] > 20) & (land_ds.coords['lat'] <30) & (land_ds.coords['lat'] > -70), drop = True)

# Keep only the cells whose time-mean 'ra' is strictly positive; every other
# cell is replaced by NaN. The result is passed to mask2() in later cells.
has_positive_ra = land_ds['ra'] > 0
land_ds = land_ds.where(has_positive_ra)

In [None]:
# Quick visual check of the masked 'ra' field.
land_ds['ra'].plot()

# PMIP3 Data

In [None]:
dir_pmip3, pmip3 = PMIP3_dic()
var_def = var_dic()
period = {
    'ds_pi':'piControl',
    'ds_lgm': 'lgm'
    }
grid = ['gn', 'gr1', 'gr']
res_lat = 1.25
res_lon = 1.875

var = ['gpp', 'pr']

# Layout of the result:
# ds_pmip3 = {model: {'gpp': {'institute': <institute>,
#                             'piControl': <DataArray>,
#                             'lgm': <DataArray>},
#                     'pr':  {'institute': <institute>,
#                             'piControl': <DataArray>,
#                             'lgm': <DataArray>}}}


def _load_pmip3_field(model, experiment, variable):
    """Return the time-mean, masked and rescaled field of one PMIP3 variable.

    Parameters
    ----------
    model : key of the `pmip3` dictionary.
    experiment : experiment name forwarded to read_model ('piControl' or 'lgm').
    variable : variable name ('gpp' or 'pr'), forwarded to read_model and
        used to select the data variable in the opened dataset.

    Returns
    -------
    The field after: loading, averaging over 'time', mask2() against
    `land_ds`, discarding values outside (0, 1e30) and multiplying by
    86400 * 365.25/12 (seconds in an average month).
    """
    _, filename = read_model(
        model,
        experiment,
        var_def,
        variable,
        dir_pmip3
        )
    # Load the data inside the context manager so the underlying files are
    # closed as soon as the values are in memory.
    with xr.open_mfdataset(filename, chunks = 200) as opened:
        field = opened[variable].compute()
    field = field.mean(['time'])
    field = mask2(field, land_ds, variable)
    return field.where((field > 0) & (field < 1e30)) * 86400 * 365.25/12


ds_pmip3 = {}

for model in pmip3:
    # A model is processed only when its LGM gpp directory exists.
    file_dir, _ = read_model(
        model,
        'lgm',
        var_def,
        'gpp',
        dir_pmip3
        )

    if not os.path.exists(file_dir):
        continue

    fields = {name: {'institute': pmip3[model]['institute']} for name in var}
    # Same read order as before: for each experiment, gpp first, then pr.
    for val in period.values():
        for name in var:
            fields[name][val] = _load_pmip3_field(model, val, name)

    ds_pmip3[model] = fields

In [None]:
# Save the PMIP3 fields as a pickle file. The output directory is created
# if needed, and the file is closed (and flushed) when the block exits.
os.makedirs("datas", exist_ok=True)
with open("datas/pmip3_gpp_pr.p", "wb") as pickle_file:
    pickle.dump(ds_pmip3, pickle_file)

# To reload the saved fields instead of recomputing them:
# with open("datas/pmip3_gpp_pr.p", "rb") as pickle_file:
#     ds_pmip3 = pickle.load(pickle_file)

# PMIP4 Data

In [None]:
inst = PMIP.df['institution_id'].unique()
source = PMIP.df['source_id'].unique()
grid = ['gn', 'gr1', 'gr']
res_lat = 1.25
res_lon = 1.875
dates = ['lgm', 'lig127k']
# Not read in this cell; kept in case other cells rely on it.
period = {
    'ds_lig':'lig127k',
    'ds_lgm': 'lgm'
    }
var = 'gpp'
table_id = 'Lmon'

# (variable, table) pairs loaded for every model, in this order.
fields_to_load = ((var, table_id), ('pr', 'Amon'))


def _cmip6_latest_dir(project, institution, model, experiment, table, variable, grid_label):
    """Build the '/bdd/CMIP6/.../latest' directory path of one r1i1p1f1 dataset."""
    return '/bdd/CMIP6/{}/{}/{}/{}/r1i1p1f1/{}/{}/{}/latest'.format(
        project, institution, model, experiment, table, variable, grid_label
        )


def _load_time_mean(catalog, message, institution, model, experiment,
                    table, variable, grid_label = None):
    """Search `catalog` and return the processed time-mean field of `variable`.

    Parameters
    ----------
    catalog : catalogue object to query (PMIP or CMIP).
    message : text printed just before the computation starts.
    institution, model, experiment, table, variable : search criteria
        (member 'r1i1p1f1', latest version).
    grid_label : optional grid label added to the search when not None.

    Returns
    -------
    The field after: averaging over 'time', regrid() to
    (res_lat, res_lon) over -90..90 / 0..360, mask2() against `land_ds`,
    discarding values outside (0, 1e30) and multiplying by
    86400 * 365.25/12 (seconds in an average month).
    """
    query = {
        'institution_id': '{}'.format(institution),
        'variable_id': variable,
        'source_id': '{}'.format(model),
        'table_id': table,
        'experiment_id': experiment,
        'latest': True,
        'member_id': 'r1i1p1f1',
    }
    if grid_label is not None:
        query['grid_label'] = grid_label
    found = catalog.search(**query)

    # Compute inside the context manager so the files are closed once the
    # time mean is in memory.
    with xr.open_mfdataset(
        list(found.df['path']),
        chunks = 200,
        use_cftime = True,
        decode_cf = True
        ) as opened:
        field = opened[variable].mean('time')
        print(message)
        with ProgressBar():
            field = field.compute()

    field = regrid(
        field,
        res_lat,
        res_lon,
        -90, 90, 0, 360
        )
    field = mask2(field, land_ds, variable)
    return field.where((field > 0) & (field < 1e30)) * 86400 * 365.25/12


ds_pmip4 = {}

for s in source :
    for i in inst :
        for g in grid :
            # A (model, institution, grid) combination is processed only when
            # the piControl gpp directory exists together with the gpp
            # directory of at least one paleo experiment.
            has_pi = os.path.exists(
                _cmip6_latest_dir('CMIP', i, s, 'piControl', 'Lmon', 'gpp', g))
            has_paleo = any(
                os.path.exists(_cmip6_latest_dir('PMIP', i, s, d, 'Lmon', 'gpp', g))
                for d in dates
                )
            if not (has_pi and has_paleo):
                continue

            model_fields = {}
            for name, table in fields_to_load:
                # Pre-industrial control.
                # NOTE(review): as before, this search is not restricted to
                # grid label `g`; confirm whether piControl should be.
                per_experiment = {
                    'institute': i,
                    'piControl': _load_time_mean(
                        CMIP,
                        'Computing piControl for : {} - {} - {}'.format(i, s, name),
                        i, s, 'piControl', table, name
                        )
                    }
                # Official paleo experiments, when available on this grid.
                for d in dates :
                    if os.path.exists(_cmip6_latest_dir('PMIP', i, s, d, table, name, g)) :
                        per_experiment[d] = _load_time_mean(
                            PMIP,
                            'Computing official {} for : {} - {} - {}'.format(d, i, s, name),
                            i, s, d, table, name,
                            grid_label = g
                            )
                model_fields[name] = per_experiment

            # Keyed by model only: a later matching grid label overwrites
            # an earlier one for the same model.
            ds_pmip4[s] = model_fields


In [None]:
# Save the PMIP4 fields as a pickle file. The output directory is created
# if needed, and the file is closed (and flushed) when the block exits.
os.makedirs("datas", exist_ok=True)
with open("datas/pmip4_gpp_pr.p", "wb") as pickle_file:
    pickle.dump(ds_pmip4, pickle_file)

# Plot data

In [None]:
# Reload the fields saved by the two loading cells.
# NOTE: pickle.load can execute arbitrary code; only load files that were
# written by this notebook.
with open("datas/pmip3_gpp_pr.p", "rb") as pickle_file:
    ds_pmip3 = pickle.load(pickle_file)
with open("datas/pmip4_gpp_pr.p", "rb") as pickle_file:
    ds_pmip4 = pickle.load(pickle_file)

### PMIP3

In [None]:
# Levels passed to plot_seasons.
# NOTE(review): np.arange excludes the stop value, so the levels end just
# below 4e-7 and are not symmetric around zero -- confirm this is intended.
levels = np.arange(-4e-7, 4e-7, 1e-8)

for k, val in ds_pmip3.items():

    lgm_gpp = val['gpp']['lgm']
    pi_gpp = val['gpp']['piControl']
    lgm_pr = val['pr']['lgm']
    pi_pr = val['pr']['piControl']

    # Round the latitude coordinate in place (as before, this updates the
    # arrays stored in ds_pmip3), presumably so that the fields share
    # identical latitude values when they are combined below.
    for field in (lgm_gpp, pi_gpp, lgm_pr, pi_pr):
        field['lat'] = field['lat'].round(3)

    # GPP multiplied by the precipitation share of each cell, normalised by
    # the total GPP of the same experiment.
    weight_lgm = lgm_gpp * (lgm_pr/lgm_pr.sum()) / lgm_gpp.sum()
    weight_pi = pi_gpp * (pi_pr/pi_pr.sum()) / pi_gpp.sum()

    ds_anom = weight_lgm - weight_pi

    plot_seasons(
        ds_anom['lon'],
        ds_anom,
        ds_anom,
        'gpp_weighted',
        '',
        'Anomalie of weighted GPP (by precipitation), LGM - piControl\nsimulated by {} - {} [PMIP3]'.format(k, val['gpp']['institute']),
        levels,
        inst = val['gpp']['institute'],
        source = k,
        date = 'lgm',
        anomalie = 'anomalie',
        pmip = 'PMIP3'
        )


### PMIP4

In [None]:
# Levels passed to plot_seasons.
# NOTE(review): np.arange excludes the stop value, so the levels end just
# below 4e-7 and are not symmetric around zero -- confirm this is intended.
levels = np.arange(-4e-7, 4e-7, 1e-8)

# (experiment key in ds_pmip4, label used in the figure title)
experiments = (('lgm', 'LGM'), ('lig127k', 'LIG127k'))

for k, val in ds_pmip4.items():
    for date, label in experiments:
        if date not in val['gpp'].keys() :
            continue
        # gpp and pr availability are checked independently when the data are
        # loaded, so an experiment may have gpp but no pr.
        if date not in val['pr'].keys() :
            print('Skipping {} for {} : no pr field available'.format(date, k))
            continue

        exp_gpp = val['gpp'][date]
        pi_gpp = val['gpp']['piControl']
        exp_pr = val['pr'][date]
        pi_pr = val['pr']['piControl']

        # Round the latitude coordinate in place on the four arrays that are
        # actually combined below. The previous version looked the arrays up
        # by name through globals() and always rounded the LGM ones, even in
        # the lig127k case.
        for field in (exp_gpp, pi_gpp, exp_pr, pi_pr):
            field['lat'] = field['lat'].round(3)

        # GPP multiplied by the precipitation share of each cell, normalised
        # by the total GPP of the same experiment.
        weight_exp = exp_gpp * (exp_pr/exp_pr.sum()) / exp_gpp.sum()
        weight_pi = pi_gpp * (pi_pr/pi_pr.sum()) / pi_gpp.sum()

        ds_anom = weight_exp - weight_pi

        plot_seasons(
            ds_anom['lon'],
            ds_anom,
            ds_anom,
            'gpp_weighted',
            '',
            'Anomalie of weighted GPP (by precipitation), {} - piControl\nsimulated by {} - {} [PMIP4]'.format(label, k, val['gpp']['institute']),
            levels,
            inst = val['gpp']['institute'],
            source = k,
            date = date,
            anomalie = 'anomalie',
            pmip = 'PMIP4'
            )