# PLOTS

### Class to center a colormap on a chosen midpoint

In [4]:
import matplotlib as mpl
import numpy as np

class MidpointNormalize(mpl.colors.Normalize):
    """Normalize data so that ``midpoint`` always maps to the colormap centre (0.5).

    Both sides of ``midpoint`` share the same colour-per-unit slope: the side
    farther from ``midpoint`` spans a full half of the colormap, the nearer
    side is shortened proportionally. For example, with ``vmin=-1, vmax=3,
    midpoint=0`` the data values ``[-1, 0, 3]`` map to ``[1/3, 0.5, 1]``.

    Parameters
    ----------
    vmin, vmax : float
        Data limits of the normalization.
    midpoint : float, default 0
        Data value mapped to 0.5.
    clip : bool, default False
        If True, values outside ``[vmin, vmax]`` are clamped to the end
        positions. If False they are extrapolated to -inf / +inf so that the
        colormap's "under"/"over" colours are used (e.g. with ``extend=...``).
    """

    def __init__(self, vmin, vmax, midpoint=0, clip=False):
        self.midpoint = midpoint
        mpl.colors.Normalize.__init__(self, vmin, vmax, clip)

    def _anchors(self):
        """Return ``(data_points, normalized_points)`` for piecewise-linear interpolation."""
        below = abs(self.midpoint - self.vmin)
        above = abs(self.vmax - self.midpoint)
        # A zero-width side would divide by zero; its ratio is then infinite,
        # which the max/min below would turn into 0 and 1 respectively.
        normalized_min = max(0, 0.5 * (1 - below / above)) if above else 0
        normalized_max = min(1, 0.5 * (1 + above / below)) if below else 1
        return ([self.vmin, self.midpoint, self.vmax],
                [normalized_min, 0.5, normalized_max])

    def __call__(self, value, clip=None):
        if clip is None:
            clip = self.clip
        x, y = self._anchors()
        if clip:
            # np.interp clamps to the end values by default.
            result = np.interp(value, x, y)
        else:
            # Extrapolate beyond vmin/vmax so out-of-range data fall outside
            # [0, 1] and pick up the colormap's under/over colours.
            result = np.interp(value, x, y, left=-np.inf, right=np.inf)
        # Keep the input's mask (np.interp works on the raw data only).
        return np.ma.masked_array(result, mask=np.ma.getmask(value))

    def inverse(self, value):
        """Map normalized values back to data values (used e.g. by colorbars)."""
        x, y = self._anchors()
        return np.interp(value, y, x)

### Function to build a standardized map of the Indian Ocean

In [2]:
def plot_seasons(
    wrap_lon, ds, wrap_data,
    var,
    labeled, title, levels,
    inst='', source='', date='',
    season='', anomalie="", pmip='',
    mask_land="", mask_ocean="",
    vmin=-4e-7, vmax=4e-7,
    base_dir='/home/bchaigneau/Stage_LSCE/plot'):

    '''
    Draw a filled-contour map of ``wrap_data`` on a PlateCarree projection,
    save it as a 300-dpi JPEG, show it, then close the figure.

    Colours use MidpointNormalize centred on 0 between ``vmin`` and ``vmax``.
    The image is written to
    ``<base_dir>/<var>/mean/<date>/<season>/<anomalie>/<inst>_<source>_<var>_<date>_<anomalie>_<pmip>.jpeg``
    and the directory is created if it does not exist yet (empty arguments
    simply drop out of the path).

    wrap_lon : longitudes of wrap_data (presumably already wrapped with a cyclic point — confirm)
    ds : dataset providing the latitudes (``ds.lat``)
    wrap_data : field to plot on the (ds.lat, wrap_lon) grid
    var : string of the variable to observe (output path and file name)
    labeled : string of the colorbar legend
    title : string of the figure title
    levels : contour levels of the studied quantity (e.g. np.arange(-0.000005, 0.000075, 0.0000025))
    inst : institute which provides the model (file name)
    source : source/model ID (file name)
    date : period of the simulation (output path and file name)
    season : string naming the season (output sub-directory)
    anomalie : tag marking that an anomaly of the variable is studied (output path and file name)
    pmip : tag appended to the file name
    mask_land, mask_ocean : if True, paint land / ocean white with a black outline
    vmin, vmax : limits given to MidpointNormalize (defaults are the former hard-coded values)
    base_dir : root output directory
    '''

    import os

    import cartopy.crs as ccrs
    import cartopy.feature
    import matplotlib.pyplot as plt

    # One directory for both creation and saving, so they can never disagree;
    # the season level keeps seasonal plots from overwriting each other.
    out_dir = os.path.join(
        base_dir, str(var), 'mean', str(date), str(season), str(anomalie)
    )
    os.makedirs(out_dir, exist_ok=True)
    out_file = os.path.join(
        out_dir,
        '{}_{}_{}_{}_{}_{}.jpeg'.format(inst, source, var, date, anomalie, pmip),
    )

    fig = plt.figure(figsize=(12, 7))
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax.coastlines()
    norm = MidpointNormalize(vmin=vmin, vmax=vmax, midpoint=0)

    m = plt.contourf(
        wrap_lon,
        ds.lat,
        wrap_data,
        transform=ccrs.PlateCarree(),
        cmap='twilight_shifted',
        alpha=0.9,
        extend='both',
        norm=norm,
        levels=levels,
    )
    # The flags default to "" (falsy); only an explicit True enables a mask.
    for enabled, feature_name in ((mask_land, 'land'), (mask_ocean, 'ocean')):
        if enabled == True:
            ax.add_feature(cartopy.feature.NaturalEarthFeature(
                'physical',
                feature_name,
                '110m',
                edgecolor='black',
                facecolor='white',
            ))
    plt.colorbar(m, label=labeled)
    plt.title(title, size=15)
    plt.savefig(out_file, dpi=300)
    plt.show()
    # Release the figure so repeated calls in a loop do not accumulate figures.
    plt.close(fig)

# SEARCH

### Function to find the finest grid resolution among a list of xarray datasets

In [None]:
def res_min(
    models):

    '''
    Return the finest (smallest) longitude and latitude grid spacing among ``models``.

    The spacing of each dataset is taken from its first two coordinate
    values. Its absolute value is used, because many grids store latitude
    from north to south, which gives a negative step.

    models : list of datasets you want to compare resolution with no time dimension

    Returns (res_lon, res_lat); (1000, 1000) when ``models`` is empty.
    '''

    import numpy

    res_lat = 1000
    res_lon = 1000
    for model in models:
        res_lon = min(res_lon, abs(numpy.diff(model["lon"])[0]))
        res_lat = min(res_lat, abs(numpy.diff(model["lat"])[0]))

    return res_lon, res_lat

### Function to regrid a dataset

In [11]:
def regrid(
    ds,
    res_lat, res_lon,
    lat_min, lat_max, lon_min, lon_max,
    ignore_degenerate=False):

    '''
    Bilinearly interpolate ``ds`` onto a regular lat/lon grid (xESMF, periodic source).

    The target axes are ``numpy.arange(lat_min, lat_max, res_lat)`` and
    ``numpy.arange(lon_min, lon_max, res_lon)``, so the upper limits are excluded.

    ds : Dataset to regrid
    res_lat, res_lon : grid spacing wanted on the target grid
    lat_min, lat_max, lon_min, lon_max : limits of the target coordinates
    ignore_degenerate : forwarded to xesmf.Regridder

    Returns the regridded dataset.
    '''

    import xarray
    import numpy
    import xesmf

    target_lat = numpy.arange(lat_min, lat_max, res_lat)
    target_lon = numpy.arange(lon_min, lon_max, res_lon)
    target_grid = xarray.Dataset({
        "lat": (["lat"], target_lat),
        "lon": (["lon"], target_lon),
    })

    regridder = xesmf.Regridder(
        ds,
        target_grid,
        "bilinear",
        ignore_degenerate=ignore_degenerate,
        periodic=True,
    )
    return regridder(ds)

### Function that finds the piControl run associated with the studied model and regrids it

In [1]:
# Calendar months averaged for each accepted ``season`` of import_piControl
# ('' means the annual mean: no month selection).
_PICONTROL_SEASON_MONTHS = {
    'summer': [10, 11, 12, 1, 2, 3],   # October-March
    'winter': [4, 5, 6, 7, 8, 9],      # April-September
    '': None,
}


def _open_picontrol_mean(CMIP, inst, v, source, tbl, months):
    '''Open the r1i1p1f1 piControl files matching the query and return their time mean.

    months : calendar months to keep before averaging, or None for all months.
    '''
    import xarray

    query = CMIP.search(
        institution_id=inst,
        variable_id=v,
        source_id=source,
        table_id=tbl,
        experiment_id='piControl',
        latest=True,
        member_id="r1i1p1f1",
    )
    pi_ds = xarray.open_mfdataset(
        list(query.df["path"]),
        chunks=500,
        use_cftime=True,
        decode_cf=True,
    )
    if months is not None:
        # Select the wanted time steps directly instead of NaN-masking with where().
        pi_ds = pi_ds.sel(time=pi_ds['time.month'].isin(months))
    return pi_ds.mean("time")


def _regrid_picontrol(ds, res_lat, res_lon, min_lat, max_lat, min_lon, max_lon):
    '''Bilinearly regrid ``ds`` onto arange(min_lat, max_lat, res_lat) x arange(min_lon, max_lon, res_lon).'''
    import numpy
    import xarray
    import xesmf

    out_grid = xarray.Dataset({
        "lat": (["lat"], numpy.arange(min_lat, max_lat, res_lat)),
        "lon": (["lon"], numpy.arange(min_lon, max_lon, res_lon)),
    })
    regridder = xesmf.Regridder(ds, out_grid, "bilinear", periodic=True)
    return regridder(ds)


def import_piControl(
    CMIP,
    inst, v, source,
    tbl, reg,
    res_lat, res_lon,
    min_lat, max_lat, min_lon, max_lon,
    season = ""):

    '''
    Load the piControl run associated with a model, average it, optionally regrid it, and compute it.

    CMIP : CMIP catalog to use (object with a ``search`` method whose result has ``.df["path"]``)
    inst : string of the institute ID
    v : string of the variable to search
    source : string of the source_ID
    tbl : string of the table_ID
    reg : indicate if you need to regrid (put True or False)
    res_lat, res_lon : resolution you want for the regrid
    min_lat, max_lat, min_lon, max_lon : limits of the regridded coordinates
    season : 'summer' (October-March mean), 'winter' (April-September mean)
             or '' (annual mean). Seasonal means are additionally cropped to
             lat < 25 and lon < 180.

    Returns the computed (in-memory) dataset.
    Raises ValueError for any other ``season`` value.
    '''

    from dask.diagnostics import ProgressBar

    if season not in _PICONTROL_SEASON_MONTHS:
        raise ValueError(
            "Wrong season passed. Ask 'winter', 'summer' or '' (annual mean)"
        )

    pi_moy = _open_picontrol_mean(
        CMIP, inst, v, source, tbl, _PICONTROL_SEASON_MONTHS[season]
    )

    # Regrid :
    if reg == True:
        pi_moy = _regrid_picontrol(
            pi_moy, res_lat, res_lon, min_lat, max_lat, min_lon, max_lon
        )

    if season:
        pi_moy = pi_moy.where(
            (pi_moy.coords['lat'] < 25) & (pi_moy.coords['lon'] < 180),
            drop=True,
        )

    print('Creation of the df : {} - {} - piControl - {} :'.format(inst, v, season))
    with ProgressBar():
        pi_moy = pi_moy.compute()

    return pi_moy


# TRANSFORM

### Function to apply a continental mask to any xarray dataset

In [7]:
def mask(
    used_ds,
    land_ds,
    used_var):

    '''
    Set ``used_ds[used_var]`` to NaN on land points so that only ocean values remain.

    Cell (i, j) counts as land when ``land_ds['ra'][i, j]`` is not NaN
    (presumably 'ra' is a land-only field that is NaN over the ocean — confirm).
    The match is positional, so both datasets must share the same grid, with
    lat/lon as the first two dimensions of the arrays involved.

    used_ds : xarray dataset to restrict only the oceans (modified in place)
    land_ds : xarray dataset containing only land data (variable 'ra')
    used_var : string of the variable which is considered to be masked

    Returns used_ds (the same object, after masking).
    Raises ValueError if land_ds['ra'] is not 2-D or is smaller than the lat/lon grid of used_ds.
    '''

    import numpy

    n_lat = len(used_ds['lat'])
    n_lon = len(used_ds['lon'])

    # Read the land field once instead of building a DataArray per cell.
    land = numpy.asarray(land_ds['ra'].values, dtype=float)
    if land.ndim != 2 or land.shape[0] < n_lat or land.shape[1] < n_lon:
        raise ValueError(
            "land_ds['ra'] must be 2-D and cover the {}x{} lat/lon grid of used_ds, "
            "got shape {}".format(n_lat, n_lon, land.shape)
        )
    is_land = ~numpy.isnan(land[:n_lat, :n_lon])

    # Positional assignment on the variable writes into used_ds itself.
    target = used_ds[used_var]
    for i, j in zip(*numpy.nonzero(is_land)):
        target[int(i), int(j)] = numpy.nan

    return used_ds

In [3]:
# Could use 'ra' for high resolution or 'gpp' for low resolution

def mask2(
    used_ds,
    land_ds,
    used_var):

    '''
    Set land points of ``used_ds`` to NaN using a land field that may be on another grid.

    For every (lat, lon) cell of used_ds, find the nearest land_ds latitude and
    longitude (smallest absolute difference, separately per axis). If
    ``land_ds['gpp']`` is not NaN there, the cell is set to NaN.

    used_ds : xarray Dataset (then ``used_ds[used_var]`` is masked) or DataArray
              (masked directly); modified in place
    land_ds : xarray dataset whose 'gpp' variable is (lat, lon) and defined only over land
    used_var : name of the variable to mask when used_ds is a Dataset; ignored for a DataArray

    Returns used_ds (the same object, after masking).
    Raises ValueError if land_ds['gpp'] is not 2-D.
    '''

    import numpy

    used_lat = numpy.asarray(used_ds['lat'].values, dtype=float)
    used_lon = numpy.asarray(used_ds['lon'].values, dtype=float)
    land_lat = numpy.asarray(land_ds['lat'].values, dtype=float)
    land_lon = numpy.asarray(land_ds['lon'].values, dtype=float)

    # The nearest land index depends only on the row (lat) or the column (lon),
    # so compute it once per axis instead of once per cell.
    lat_idx = numpy.abs(used_lat[:, None] - land_lat[None, :]).argmin(axis=1)
    lon_idx = numpy.abs(used_lon[:, None] - land_lon[None, :]).argmin(axis=1)

    land_field = numpy.asarray(land_ds['gpp'].values, dtype=float)
    if land_field.ndim != 2:
        raise ValueError(
            "land_ds['gpp'] must be 2-D (lat, lon), got shape {}".format(land_field.shape)
        )
    is_land = ~numpy.isnan(land_field[numpy.ix_(lat_idx, lon_idx)])

    # On a Dataset the variable must be selected first: indexing the Dataset
    # itself with (i, j) would add a new variable instead of masking.
    target = used_ds[used_var] if hasattr(used_ds, 'data_vars') else used_ds
    for i, j in zip(*numpy.nonzero(is_land)):
        target[int(i), int(j)] = numpy.nan

    return used_ds

# PMIP3

In [1]:
def PMIP3_dic():
    '''
    Return the PMIP3 root directory and the catalogue of PMIP3 models in use.

    Returns (dir_pmip3, pmip3):
      dir_pmip3 : root directory of the PMIP3 output tree
      pmip3 : model name -> {'institute': str, 'period': list of experiments,
              'rr': ensemble-member tag ('r1i1p1', except 'r1i1p150' for GISS-E2-R)}
    '''
    all_periods = ('lgm', 'midHolocene', 'piControl')
    table = (
        ('CNRM-CM5', 'CNRM-CERFACS', all_periods),
        ('COSMOS-ASO', 'FUB', ('lgm', 'piControl')),
        ('IPSL-CM5A-LR', 'IPSL', all_periods),
        ('FGOALS-g2', 'LASG-CESS', all_periods),
        ('MIROC-ESM', 'MIROC', all_periods),
        ('MPI-ESM-P', 'MPI-M', all_periods),
        ('MRI-CGCM3', 'MRI', all_periods),
        ('GISS-E2-R', 'NASA-GISS', all_periods),
        ('CCSM4', 'NCAR', all_periods),
    )
    member_overrides = {'GISS-E2-R': 'r1i1p150'}

    pmip3 = {
        name: {
            'institute': institute,
            'period': list(periods),
            'rr': member_overrides.get(name, 'r1i1p1'),
        }
        for name, institute, periods in table
    }

    # Models not included (kept for reference):
    #   'bcc-csm1-1'     BCC          ['midHolocene', 'piControl']
    #   'FGOALS-s2'      LASG-IAP     ['midHolocene', 'piControl']
    #   'CSIRO-Mk3-6-0'  CSIRO-QCCCE  ['midHolocene', 'piControl']
    #   'KCM1-2-2'       CAU-GEOMAR   ['1pctCO2', 'midHolocene', 'piControl']   (no Aclim, no lgm)
    #   'HadCM3'         UOED         ['past1000', 'piControl']                 (no Aclim, no lgm)
    #   'EC-EARTH-2-2'   ICHEC        ['midHolocene', 'piControl']              (no Aclim, no lgm)
    #   'CSIRO-Mk3L-1-2' UNSW         ['1pctCO2', 'midHolocene', 'past1000', 'piControl']

    return '/bdd/PMIP3/output/', pmip3



def var_dic():
    '''
    Return the realm and climatology table of each supported variable.

    Returns a dict: variable name -> {'realm': ..., 'clim': ...}, where the
    atmospheric variables use ('atmos', 'Aclim'), 'tos' uses ('ocean', 'Oclim')
    and 'gpp' uses ('land', 'Lclim'). Every entry is its own dict.
    '''
    atmos_entry = {'realm': 'atmos', 'clim': 'Aclim'}
    var_def = {}
    for name in ('ta', 'tas', 'ts', 'hurs', 'evspsbl', 'pr', 'ps'):
        var_def[name] = dict(atmos_entry)
    var_def['tos'] = {'realm': 'ocean', 'clim': 'Oclim'}
    var_def['gpp'] = {'realm': 'land', 'clim': 'Lclim'}
    return var_def



def read_model(model, period, var_def, var, dir_pmip3, pmip3=None):
    '''
    Build the directory and file glob of a PMIP3 monthly climatology (nothing is read from disk).

    model : model name, key of ``pmip3``
    period : experiment name (e.g. 'lgm', 'midHolocene', 'piControl')
    var_def : variable table {var: {'realm': ..., 'clim': ...}}
    var : variable name, key of ``var_def``
    dir_pmip3 : root directory of the PMIP3 tree (expected to end with '/')
    pmip3 : model catalogue {model: {'institute': ..., 'rr': ...}}. If None,
            a global named ``pmip3`` is used when defined, otherwise the
            catalogue returned by PMIP3_dic().

    Returns (filedir, filename), where filename is a glob pattern ending in '_*.nc'.
    piControl always uses the member 'r1i1p1'; other periods use the model's 'rr'.
    '''
    if pmip3 is None:
        # Backward compatibility: callers used to rely on a global ``pmip3``.
        pmip3 = globals().get('pmip3')
        if pmip3 is None:
            _, pmip3 = PMIP3_dic()

    print('Reading ' + var + ' ' + model + ' for ' + period)
    rr = 'r1i1p1' if period == 'piControl' else pmip3[model]['rr']
    institute = pmip3[model]['institute']
    realm, clim = var_def[var]['realm'], var_def[var]['clim']
    filedir = dir_pmip3 + '/'.join(
        [institute, model, period, 'monClim', realm, clim, rr, 'latest', var]
    ) + '/'
    filename = filedir + '_'.join([var, clim, model, period]) + '_*.nc'

    return filedir, filename



def zonal_reduction(ds_tmp, lon_point, lat_point, lon_radius, lat_radius):
    '''
    Crop ``ds_tmp`` to the open box centred on (lon_point, lat_point).

    Points with lat_point - lat_radius < lat < lat_point + lat_radius and
    lon_point - lon_radius < lon < lon_point + lon_radius are kept; everything
    else is dropped with ``where(..., drop=True)`` (strict inequalities, so
    the box edges are excluded).
    '''
    lat = ds_tmp['lat']
    lon = ds_tmp['lon']
    inside_lat = (lat > lat_point - lat_radius) & (lat < lat_point + lat_radius)
    inside_lon = (lon > lon_point - lon_radius) & (lon < lon_point + lon_radius)
    return ds_tmp.where(inside_lat & inside_lon, drop=True)

### Pour créer un catalogue personnalisé

import intake, intake_esm

# Build the CMIP and PMIP sub-catalogues of the local CMIP6 ESM datastore
# and save each one to disk under its own name.
CMIP6Cat = intake.open_esm_datastore("/modfs/catalogs/CMIP6.json")

CMIPCat = CMIP6Cat.search(activity_id="CMIP")
CMIPCat.serialize("CMIPCat")

PMIPCat = CMIP6Cat.search(activity_id="PMIP")
# Distinct name: serializing as "CMIPCat" would overwrite the CMIP catalogue above.
PMIPCat.serialize("PMIPCat")

### Pour enregistrer un répertoire dans le pc

# Archive the plot/ directory recursively into example.zip (IPython shell escape; runs only in IPython/Jupyter).
!zip -r example.zip plot