# Multi-model analysis

This notebook shows a simple example of multi-model analysis. 

As input files, we are going to use the CMIP6 ```tas``` (near-surface air temperature) files for the ```ssp585``` experiment (update of RCP8.5 based on SSP5), produced respectively by

- ```CMCC``` from the ```CMCC-CM2-SR5``` global coupled general circulation model
- ```CMCC``` from the ```CMCC-ESM2``` global climate model
- ```NCAR``` from the ```CESM2-WACCM``` global climate model

Let's import the main Python modules and define the list of input file paths.

In [None]:
import xarray as xr
import cftime
import datetime
import pandas as pd
import numpy as np
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline
from os.path import expanduser
home = expanduser("~")

files=[home+"/data/CMIP6/ScenarioMIP/CMCC/CMCC-CM2-SR5/ssp585/r1i1p1f1/Amon/tas/gn/v20200622/tas_Amon_CMCC-CM2-SR5_ssp585_r1i1p1f1_gn_201501-210012.nc", 
    home+"/data/CMIP6/ScenarioMIP/CMCC/CMCC-ESM2/ssp585/r1i1p1f1/Amon/tas/gn/v20210126/tas_Amon_CMCC-ESM2_ssp585_r1i1p1f1_gn_201501-210012.nc",
    home+"/data/CMIP6/ScenarioMIP/NCAR/CESM2-WACCM/ssp585/r1i1p1f1/Amon/tas/gn/v20200702/tas_Amon_CESM2-WACCM_ssp585_r1i1p1f1_gn_201501-210012.nc"]
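The three paths above follow the CMIP6 Data Reference Syntax layout (activity, institution, model, experiment, member, table, variable, grid, version) and the matching filename pattern. As a sketch, a small helper can assemble such a path from its components instead of typing it by hand. The function name `cmip6_path` and its argument order are hypothetical, not part of any library.

```python
def cmip6_path(root, activity, institution, source, experiment, member,
               table, variable, grid, version, period):
    """Hypothetical helper: build a CMIP6 file path using the layout of the paths above."""
    # Filename pattern: <variable>_<table>_<source>_<experiment>_<member>_<grid>_<period>.nc
    filename = f"{variable}_{table}_{source}_{experiment}_{member}_{grid}_{period}.nc"
    # Directory pattern: <root>/<activity>/<institution>/<source>/<experiment>/<member>/<table>/<variable>/<grid>/<version>
    return "/".join([root, activity, institution, source, experiment,
                     member, table, variable, grid, version, filename])

path = cmip6_path("/data/CMIP6", "ScenarioMIP", "NCAR", "CESM2-WACCM", "ssp585",
                  "r1i1p1f1", "Amon", "tas", "gn", "v20200702", "201501-210012")
print(path)
```

In the notebook, `root` would be `home + "/data/CMIP6"`, and only the institution, model and version change between the three files.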

Let's load the three datasets.

In [None]:
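ds_list = []
for ffile in files:
    ds = xr.open_dataset(ffile)
    # Tag the data with the model name (global attribute source_id) as a scalar coordinate
    ds = ds.assign_coords(model_id=ds.source_id)
    # Models may use different calendars and mid-month timestamps: map every step
    # to the first day of its month (datetime64) so the three time axes match
    ds['time'] = pd.to_datetime(ds.time.dt.strftime("%Y-%m"))
    # Drop the scalar 'height' coordinate when the file provides one
    if 'height' in ds.coords:
        ds = ds.drop_vars('height')
    ds_list.append(ds['tas'])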
    
print(ds_list)

### Spatial grid

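As the printed output (reproduced below) shows, the three models share the same spatial grid (192 latitudes × 288 longitudes) and the same monthly time axis (1032 steps, 2015-01 to 2100-12), so we can compare them directly at the grid-cell level without any regridding.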

```
<xarray.DataArray 'tas' (time: 1032, lat: 192, lon: 288)>
[57065472 values with dtype=float32]
Coordinates:
  * time      (time) datetime64[ns] 2015-01-01 2015-02-01 ... 2100-12-01
  * lat       (lat) float64 -90.0 -89.06 -88.12 -87.17 ... 88.12 89.06 90.0
  * lon       (lon) float64 0.0 1.25 2.5 3.75 5.0 ... 355.0 356.2 357.5 358.8
    model_id  <U12 'CMCC-CM2-SR5'

<xarray.DataArray 'tas' (time: 1032, lat: 192, lon: 288)>
[57065472 values with dtype=float32]
Coordinates:
  * time      (time) datetime64[ns] 2015-01-01 2015-02-01 ... 2100-12-01
  * lat       (lat) float64 -90.0 -89.06 -88.12 -87.17 ... 88.12 89.06 90.0
  * lon       (lon) float64 0.0 1.25 2.5 3.75 5.0 ... 355.0 356.2 357.5 358.8
    model_id  <U9 'CMCC-ESM2'

<xarray.DataArray 'tas' (time: 1032, lat: 192, lon: 288)>
[57065472 values with dtype=float32]
Coordinates:
  * lat       (lat) float64 -90.0 -89.06 -88.12 -87.17 ... 88.12 89.06 90.0
  * lon       (lon) float64 0.0 1.25 2.5 3.75 5.0 ... 355.0 356.2 357.5 358.8
  * time      (time) datetime64[ns] 2015-01-01 2015-02-01 ... 2100-12-01
    model_id  <U11 'CESM2-WACCM'
```
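Comparing the printed coordinates by eye works for three models; a programmatic check is safer as the list grows. A minimal sketch, where `same_grid` is a hypothetical helper and the synthetic grid only mimics the spacing printed above (an assumption based on the displayed values):

```python
import numpy as np
import xarray as xr

def same_grid(arrays, dims=("lat", "lon")):
    """Hypothetical helper: True if all arrays share (numerically) the same coordinates on `dims`."""
    ref = arrays[0]
    return all(
        ref[d].shape == other[d].shape and np.allclose(ref[d].values, other[d].values)
        for other in arrays[1:]
        for d in dims
    )

# Synthetic grid with 192 latitudes and 288 longitudes, like the output above
lat = np.linspace(-90.0, 90.0, 192)
lon = np.arange(288) * 1.25
a = xr.DataArray(np.zeros((192, 288)), dims=("lat", "lon"), coords={"lat": lat, "lon": lon})
b = a + 1.0                             # same grid, different values
c = a.assign_coords(lon=lon + 0.5)      # shifted longitudes: a different grid

print(same_grid([a, b]), same_grid([a, b, c]))
```

In the notebook, `same_grid(ds_list)` would be the check; a `False` result would mean regridding is needed before a cell-by-cell comparison.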

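Now we can concatenate the three DataArrays along a new ```model_id``` dimension: since each one carries ```model_id``` as a scalar coordinate, ```xr.concat``` promotes it to a dimension coordinate holding the three model names.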

In [None]:
full_ds = xr.concat(ds_list, dim='model_id')
full_ds
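To see what `xr.concat` does with the scalar `model_id` coordinate, here is a minimal sketch with two tiny synthetic arrays; the model names `model-A` and `model-B` are hypothetical.

```python
import numpy as np
import xarray as xr

# Two synthetic members, each tagged with a scalar model_id coordinate as in the loading loop
members = [
    xr.DataArray(np.full(3, float(i)), dims="time", coords={"time": [0, 1, 2]}).assign_coords(model_id=name)
    for i, name in enumerate(["model-A", "model-B"])
]

# The scalar coordinate becomes a new leading dimension with one label per member
stacked = xr.concat(members, dim="model_id")
print(stacked.dims, stacked.model_id.values)
```

Because `model_id` ends up as an indexed dimension, later cells can select a model by name with `.sel(model_id=...)` instead of by position.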

Next, we plot the three monthly time series for a single location, the grid cell nearest to 41.9°N, 12.49°E (Rome).

In [None]:
figure(figsize=(15, 8), dpi=80)
full_ds.sel(lat=41.9, lon=12.49, method="nearest").plot(
    x="time", hue="model_id"
)
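With `method="nearest"`, `.sel` returns the grid cell closest to the requested coordinates, so the selected point is not exactly 41.9°N, 12.49°E. Note also that longitudes on this grid run from 0 to 360. A minimal sketch on a synthetic grid shaped like the one printed above (an assumption), with `to_0_360` as a hypothetical helper for western longitudes:

```python
import numpy as np
import xarray as xr

# Synthetic grid with the same shape and spacing as the model output above
lat = np.linspace(-90.0, 90.0, 192)
lon = np.arange(288) * 1.25
field = xr.DataArray(np.zeros((lat.size, lon.size)), dims=("lat", "lon"),
                     coords={"lat": lat, "lon": lon})

# Nearest grid cell to Rome; labels are numbers, not strings
point = field.sel(lat=41.9, lon=12.49, method="nearest")
print(float(point.lat), float(point.lon))

def to_0_360(lon_deg):
    """Hypothetical helper: map a longitude in [-180, 180) to the grid's [0, 360) convention."""
    return lon_deg % 360.0

# A western location (74 W) must be converted before selecting on this grid
west_point = field.sel(lat=40.7, lon=to_0_360(-74.0), method="nearest")
```

Printing `point.lat` and `point.lon` is a quick way to confirm which cell a time series actually belongs to.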

We can compute yearly means by grouping the monthly values by calendar year and averaging over the ```time``` dimension within each group.

In [None]:
yearly_means = full_ds.groupby("time.year").mean()
yearly_means
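`groupby("time.year").mean()` gives every month the same weight. A common refinement weights each month by its length; below is a minimal sketch on synthetic data using `dt.days_in_month`. One caveat, stated as an assumption: because the loading loop converted the time axis to standard-calendar `datetime64`, `days_in_month` counts February 29 in leap years even if a model runs on a 365-day calendar, so the weights would only be approximate for such models.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Two years of synthetic monthly values on a standard (Gregorian) calendar
time = pd.date_range("2015-01-01", periods=24, freq="MS")
monthly = xr.DataArray(np.arange(24, dtype=float), dims="time", coords={"time": time})

# Unweighted: every month counts equally, as in the cell above
plain_yearly = monthly.groupby("time.year").mean()

# Weighted: each month counts in proportion to its number of days
days = monthly.time.dt.days_in_month
weighted_yearly = (monthly * days).groupby("time.year").sum() / days.groupby("time.year").sum()
print(plain_yearly.values, weighted_yearly.values)
```

For near-surface temperature the two versions usually differ only slightly, but the weighted form is the one that matches a true time average.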

We can then plot faceted maps of the yearly means, with one row per model and one column per selected year (every 20th year from 2015, i.e. 2015, 2035, 2055, 2075 and 2095).

In [None]:
# Every 20th year: one row per model, one column per year, longitude on x and latitude on y
fg = yearly_means.isel(year=slice(None, None, 20)).plot(
        x="lon", y="lat", row="model_id",
        col="year",
)
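A positional slice with a step of 20 keeps every 20th entry of the `year` axis, starting from the first one. A minimal sketch on a synthetic axis with the same years as `yearly_means` (2015 to 2100):

```python
import numpy as np
import xarray as xr

years = np.arange(2015, 2101)  # 2015..2100, as in yearly_means
yearly = xr.DataArray(np.zeros(years.size), dims="year", coords={"year": years})

# Same selection as in the facet plot above
selected = yearly.isel(year=slice(None, None, 20)).year.values
print(selected)
```

Changing the step (or passing an explicit list such as `year=[2020, 2050, 2100]` to `.sel`) controls which columns appear in the faceted plot.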

We can also compute the absolute difference between the yearly means of any two models (e.g. ```CMCC-CM2-SR5``` and ```CESM2-WACCM```) and plot it on a map for a specific year.

In [None]:
year = 2017
# Absolute difference between two models, selected by name, for one year
diff_cmcc_cesm = abs(
    yearly_means.sel(model_id="CMCC-CM2-SR5", year=year)
    - yearly_means.sel(model_id="CESM2-WACCM", year=year)
)

In [None]:
fig = plt.figure(figsize=(10, 5), dpi=100)
p = diff_cmcc_cesm.plot(
    cmap="GnBu",
    subplot_kws=dict(projection=ccrs.PlateCarree()),
    transform=ccrs.PlateCarree(),
)
p.axes.set_global()
p.axes.set_aspect('auto', adjustable=None)
p.axes.coastlines()
p.axes.gridlines()
plt.title(f"Absolute difference between CMCC-CM2-SR5 and CESM2-WACCM yearly mean for year {int(diff_cmcc_cesm.year)}")
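Two ways of differencing models along `model_id` are worth contrasting: selecting by label (independent of the order in which files were concatenated) and positional `diff`, which computes later minus earlier and keeps a length-1 `model_id` dimension. A minimal sketch on a synthetic stack with hypothetical model names:

```python
import numpy as np
import xarray as xr

# Synthetic stack: 2 hypothetical models x 2 years x a 2x2 grid
stack = xr.DataArray(
    np.arange(16, dtype=float).reshape(2, 2, 2, 2),
    dims=("model_id", "year", "lat", "lon"),
    coords={"model_id": ["model-A", "model-B"], "year": [2016, 2017],
            "lat": [0.0, 1.0], "lon": [0.0, 1.0]},
)

# Label-based: pick each model by name, then take the absolute difference
abs_diff = abs(stack.sel(model_id="model-A", year=2017) - stack.sel(model_id="model-B", year=2017))

# Positional: diff along model_id returns (second - first) with a length-1 model_id dimension
pos_diff = stack.isel(model_id=[0, 1]).diff("model_id").sel(year=2017)
print(abs_diff.values, pos_diff.sizes)
```

The label-based form keeps working if the order of `files` changes, which is why it is the safer default for pairwise comparisons.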

We can also compare the three ```yearly_means``` time series at the same grid point (nearest to 41.9°N, 12.49°E).

In [None]:
figure(figsize=(15, 8), dpi=80)
yearly_means.sel(lat=41.9, lon=12.49, method="nearest").plot(
    x="year", hue="model_id"
)

As another example, we can restrict the data to the **Northern Hemisphere** and compute the (unweighted) **average** over the ```(lat,lon)``` dimensions.

In [None]:
# Keep every latitude from the equator to the North Pole (lat is ascending), then average over space
north_mean = full_ds.sel(lat=slice(0., 90.)).mean(dim=['lon','lat'])
north_mean
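An unweighted mean over `(lat, lon)` gives every grid cell the same weight, although cells on a regular latitude-longitude grid shrink towards the pole. A minimal sketch of an area-weighted alternative, assuming cos(latitude) is an adequate proxy for cell area on this grid, using `DataArray.weighted` on a synthetic field equal to the latitude itself so the effect is easy to see:

```python
import numpy as np
import xarray as xr

# Synthetic grid shaped like the model output; the field value is the latitude
lat = np.linspace(-90.0, 90.0, 192)
lon = np.arange(288) * 1.25
field = xr.DataArray(np.repeat(lat[:, None], lon.size, axis=1), dims=("lat", "lon"),
                     coords={"lat": lat, "lon": lon})

# Northern Hemisphere only, as in the cell above
nh = field.sel(lat=slice(0.0, 90.0))

# Unweighted: every grid cell counts equally
unweighted_mean = nh.mean(dim=["lat", "lon"])

# Area-weighted: cos(lat) down-weights the small cells near the pole
weights = np.cos(np.deg2rad(nh.lat))
weighted_mean = nh.weighted(weights).mean(dim=["lat", "lon"])
print(float(unweighted_mean), float(weighted_mean))
```

Applied to the notebook, the same `weights` computed from `full_ds.sel(lat=slice(0., 90.)).lat` would give an area-weighted version of `north_mean`.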

In [None]:
figure(figsize=(15, 8), dpi=80)
north_mean.plot(hue='model_id')

And this is the multi-model mean, i.e. the Northern Hemisphere series averaged across ```model_id```:

In [None]:
figure(figsize=(15, 8), dpi=80)
north_mean.mean(dim='model_id').plot()

Finally, we can further aggregate over time to compute the yearly Northern Hemisphere mean for each model.

In [None]:
figure(figsize=(15, 8), dpi=80)
north_mean.groupby("time.year").mean().plot(hue='model_id')
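`groupby("time.year")` replaces `time` with an integer `year` dimension. If a datetime axis is preferred (for example to overlay monthly and yearly curves on the same x-axis), `resample` gives the same values while keeping `time`. A minimal sketch on synthetic monthly data:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Three years of synthetic monthly values
time = pd.date_range("2015-01-01", periods=36, freq="MS")
monthly = xr.DataArray(np.random.default_rng(0).normal(size=36), dims="time", coords={"time": time})

by_year = monthly.groupby("time.year").mean()   # new integer 'year' dimension
resampled = monthly.resample(time="YS").mean()  # keeps 'time', labelled at the start of each year
print(by_year.year.values, resampled.time.values)
```

Both give unweighted means of the months in each calendar year; the choice only affects how the result is labelled and plotted.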