# xarray tutorial
Modified heavily from [Stephan Hoyer](http://stephanhoyer.com), Rossbypalooza, 2016

-------------------

## xarray basics

In [8]:
# standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xarray as xr

%matplotlib inline

#### `xarray.Dataset` is like a Python dictionary (of `xarray.DataArray` objects)

We'll use the "air_temperature" tutorial dataset:

In [9]:
ds = xr.tutorial.load_dataset('air_temperature')

In [10]:
ds

<xarray.Dataset>
Dimensions:  (lat: 25, lon: 53, time: 2920)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 62.5 60.0 57.5 55.0 52.5 ...
  * lon      (lon) float32 200.0 202.5 205.0 207.5 210.0 212.5 215.0 217.5 ...
  * time     (time) datetime64[ns] 2013-01-01 2013-01-01T06:00:00 ...
Data variables:
    air      (time, lat, lon) float32 241.2 242.5 243.5 244.0 244.09999 ...
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...

In [11]:
ds['air']

<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)>
array([[[241.2    , 242.5    , ..., 235.5    , 238.59999],
        [243.79999, 244.5    , ..., 235.29999, 239.29999],
        ...,
        [295.9    , 296.19998, ..., 295.9    , 295.19998],
        [296.29   , 296.79   , ..., 296.79   , 296.6    ]],

       [[242.09999, 242.7    , ..., 233.59999, 235.79999],
        [243.59999, 244.09999, ..., 232.5    , 235.7    ],
        ...,
        [296.19998, 296.69998, ..., 295.5    , 295.1    ],
        [296.29   , 297.19998, ..., 296.4    , 296.6    ]],

       ...,

       [[245.79   , 244.79   , ..., 243.98999, 244.79   ],
        [249.89   , 249.29   , ..., 242.48999, 244.29   ],
        ...,
        [296.29   , 297.19   , ..., 295.09   , 294.38998],
        [297.79   , 298.38998, ..., 295.49   , 295.19   ]],

       [[245.09   , 244.29   , ..., 241.48999, 241.79   ],
        [249.89   , 249.29   , ..., 240.29   , 241.68999],
        ...,
        [296.09   , 296.88998, ..., 295.69   , 

In [12]:
ds.keys()

KeysView(<xarray.Dataset>
Dimensions:  (lat: 25, lon: 53, time: 2920)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 62.5 60.0 57.5 55.0 52.5 ...
  * lon      (lon) float32 200.0 202.5 205.0 207.5 210.0 212.5 215.0 217.5 ...
  * time     (time) datetime64[ns] 2013-01-01 2013-01-01T06:00:00 ...
Data variables:
    air      (time, lat, lon) float32 241.2 242.5 243.5 244.0 244.09999 ...
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...)

In [13]:
ds.dims

Frozen(SortedKeysDict({'lat': 25, 'time': 2920, 'lon': 53}))

In [14]:
ds.attrs

OrderedDict([('Conventions', 'COARDS'),
             ('title', '4x daily NMC reanalysis (1948)'),
             ('description',
              'Data is from NMC initialized reanalysis\n(4x/day).  These are the 0.9950 sigma level values.'),
             ('platform', 'Model'),
             ('references',
              'http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanalysis.html')])

In [15]:
ds['air'].identical(ds.air)

True

In [16]:
ds['air'].values

array([[[241.2 , ..., 238.6 ],
        ...,
        [296.29, ..., 296.6 ]],

       ...,

       [[245.09, ..., 241.79],
        ...,
        [297.69, ..., 295.69]]], dtype=float32)

In [17]:
ds['air'].dims

('time', 'lat', 'lon')

In [18]:
ds['air'].attrs

OrderedDict([('long_name', '4xDaily Air temperature at sigma level 995'),
             ('units', 'degK'),
             ('precision', 2),
             ('GRIB_id', 11),
             ('GRIB_name', 'TMP'),
             ('var_desc', 'Air temperature'),
             ('dataset', 'NMC Reanalysis'),
             ('level_desc', 'Surface'),
             ('statistic', 'Individual Obs'),
             ('parent_stat', 'Other'),
             ('actual_range', array([185.16, 322.1 ], dtype=float32))])

#### Reading and writing netCDF

Under the covers, this uses scipy or the [netCDF4-Python](https://github.com/unidata/netcdf4-python) library:

In [19]:
ds.to_netcdf('another-copy-2.nc')

  for k, v in iteritems(variables))


In [20]:
xr.open_dataset('another-copy-2.nc')

<xarray.Dataset>
Dimensions:  (lat: 25, lon: 53, time: 2920)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 62.5 60.0 57.5 55.0 52.5 ...
  * lon      (lon) float32 200.0 202.5 205.0 207.5 210.0 212.5 215.0 217.5 ...
  * time     (time) datetime64[ns] 2013-01-01 2013-01-01T06:00:00 ...
Data variables:
    air      (time, lat, lon) float32 ...
Attributes:
    Conventions:  COARDS
    title:        4x daily NMC reanalysis (1948)
    description:  Data is from NMC initialized reanalysis\n(4x/day).  These a...
    platform:     Model
    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...

### Indexing with named dimensions

In [22]:
ds.air.sel(time='2013-01-01')

<xarray.DataArray 'air' (time: 4, lat: 25, lon: 53)>
array([[[241.2    , 242.5    , ..., 235.5    , 238.59999],
        [243.79999, 244.5    , ..., 235.29999, 239.29999],
        ...,
        [295.9    , 296.19998, ..., 295.9    , 295.19998],
        [296.29   , 296.79   , ..., 296.79   , 296.6    ]],

       [[242.09999, 242.7    , ..., 233.59999, 235.79999],
        [243.59999, 244.09999, ..., 232.5    , 235.7    ],
        ...,
        [296.19998, 296.69998, ..., 295.5    , 295.1    ],
        [296.29   , 297.19998, ..., 296.4    , 296.6    ]],

       [[242.29999, 242.2    , ..., 236.09999, 238.7    ],
        [244.59999, 244.39   , ..., 232.     , 235.7    ],
        ...,
        [296.19998, 296.5    , ..., 296.     , 295.6    ],
        [296.4    , 296.29   , ..., 297.     , 296.79   ]],

       [[241.89   , 241.79999, ..., 235.5    , 237.59999],
        [246.29999, 245.29999, ..., 231.5    , 234.5    ],
        ...,
        [297.     , 297.5    , ..., 296.6    , 296.29   ],
    

In [23]:
# `lat` is stored in descending order (75.0, 72.5, ...), so the latitude slice
# runs from 60 down to 50; `lon` ascends, so its slice runs from 200 up to 270.
# Label-based slices include both endpoints (5 latitudes and 29 longitudes here).
ds.air.sel(lat=slice(60, 50), lon=slice(200, 270))

<xarray.DataArray 'air' (time: 2920, lat: 5, lon: 29)>
array([[[273.69998, 273.6    , ..., 246.2    , 246.79999],
        [274.79   , 275.19998, ..., 250.7    , 249.5    ],
        ...,
        [276.69998, 277.4    , ..., 249.59999, 249.39   ],
        [277.29   , 277.4    , ..., 249.89   , 252.29999]],

       [[272.1    , 272.69998, ..., 245.2    , 246.79999],
        [274.     , 274.4    , ..., 248.89   , 248.89   ],
        ...,
        [275.79   , 276.     , ..., 252.     , 251.79999],
        [276.29   , 276.4    , ..., 249.29999, 252.09999]],

       ...,

       [[274.29   , 273.88998, ..., 258.69   , 256.19   ],
        [275.59   , 276.29   , ..., 258.69   , 257.19   ],
        ...,
        [276.79   , 277.29   , ..., 255.39   , 254.18999],
        [277.59   , 278.29   , ..., 254.18999, 254.59   ]],

       [[272.59   , 271.99   , ..., 256.49   , 255.18999],
        [274.29   , 274.49   , ..., 260.29   , 259.49   ],
        ...,
        [276.88998, 277.29   , ..., 258.79   , 2

In [25]:
# Nearest grid point to Chicago (41.8781 N, 87.6298 W). The west longitude is
# converted to the dataset's 0-360 degrees-east convention, and `tolerance=5`
# limits how far the matched grid point may be from the requested coordinates.
ds.air.sel(lat=41.8781, lon=360-87.6298, method='nearest', tolerance=5)

<xarray.DataArray 'air' (time: 2920)>
array([268.9    , 264.19998, 261.4    , ..., 255.18999, 254.68999, 257.88998],
      dtype=float32)
Coordinates:
    lat      float32 42.5
    lon      float32 272.5
  * time     (time) datetime64[ns] 2013-01-01 2013-01-01T06:00:00 ...
Attributes:
    long_name:     4xDaily Air temperature at sigma level 995
    units:         degK
    precision:     2
    GRIB_id:       11
    GRIB_name:     TMP
    var_desc:      Air temperature
    dataset:       NMC Reanalysis
    level_desc:    Surface
    statistic:     Individual Obs
    parent_stat:   Other
    actual_range:  [185.16 322.1 ]

## Computation

You can do arithmetic directly on `Dataset` and `DataArray` objects. Labels are preserved, although attributes are removed.

In [None]:
2 * ds

You can also apply NumPy "universal functions" like `np.sqrt` to `DataArray` objects:

In [None]:
np.sqrt(ds.air)

xarray also implements standard aggregation functions:

In [None]:
ds.max()

In [None]:
ds.mean(dim='time')

In [None]:
ds.median(dim=['lat', 'lon'])

In [None]:
# Maximum surface air temperature over time for New York
# (latitude=40.7128, longitude=-74.0060).
# The dataset's `lon` coordinate is expressed in degrees east (it starts at
# 200), so the west longitude is shifted into the 0-360 convention first;
# passing -74.0060 directly leaves no grid point within `tolerance`.
ds.sel(lat=40.7128, lon=360 - 74.0060, method='nearest', tolerance=2).max()

Convert the dataset from Kelvin to degrees Celsius and save to a new netCDF file.

Don't forget to fix the temperature units! Recall `degC = degK - 273`.

In [None]:
# Convert from Kelvin to degrees Celsius using degC = degK - 273.
ds_celsius = ds - 273
# Arithmetic drops attributes, so record the units of the converted data
# explicitly; the values are now in degrees Celsius, not Kelvin.
ds_celsius.air.attrs['units'] = 'degC'
# Name the output file after the units it actually contains.
ds_celsius.to_netcdf('temperature-in-celsius.nc')

----------------------

## xarray comes with built-in plotting, based on matplotlib

In [None]:
ds.mean(['lat', 'lon']).air.plot()

In [None]:
ds.min('time').air.plot()

## Time series

Xarray implements the "split-apply-combine" paradigm with `groupby`. This works really well for calculating climatologies:

In [None]:
# Average over 'time' explicitly so each season keeps its own lat/lon map,
# matching the monthly climatology computed in the next cell. The reduction
# dimension used when none is given has differed between xarray releases.
ds.groupby('time.season').mean('time')

In [None]:
clim = ds.groupby('time.month').mean('time')

In [None]:
clim

You can also do arithmetic with groupby objects, which repeats the arithmetic over each group:

In [None]:
# Subtract the monthly climatology from every time step; the groupby arithmetic
# pairs each time step with the `clim` entry for its month.
anomalies = ds.groupby('time.month') - clim

In [None]:
anomalies

Resample adjusts a time series to a new resolution:

In [None]:
# Daily minimum and maximum of the air temperature series.
# The old `resample(freq, dim=..., how=...)` signature has been removed from
# xarray; the supported form names the dimension as a keyword and applies the
# aggregation as a method on the resulting resample object.
tmin = ds.air.resample(time='1D').min()
tmax = ds.air.resample(time='1D').max()

In [None]:
tmin

In [None]:
ds_extremes = xr.Dataset({'tmin': tmin, 'tmax': tmax})

In [None]:
ds_extremes

## Pandas

[Pandas](http://pandas.pydata.org) is the best way to work with tabular data (e.g., CSV files) in Python. It's also a highly flexible data analysis tool, with way more functionality than xarray.

In [None]:
df = ds.to_dataframe()

In [None]:
df.head()

Pandas provides very robust tools for reading and writing CSV:

In [None]:
print(df.head(10).to_csv())

Of course, it's just as easy to convert back from pandas:

In [None]:
xr.Dataset.from_dataframe(df)

If you're using pandas 0.18 or newer, you can write `df.to_xarray()`.

### Things you can do with pandas

In [None]:
df.describe()

In [None]:
df.sample(10)

Statistical visualization with [Seaborn](https://seaborn.pydata.org/):

In [None]:
import seaborn as sns

# Pointwise selection: `Dataset.sel_points` has been removed from xarray.
# Indexing `lat` and `lon` with DataArrays that share a new 'location'
# dimension picks one (lat, lon) pair per city instead of the full
# lat x lon outer product.
city_names = ['Chicago', 'San Francisco']
city_lat = xr.DataArray([41.8781, 37.7749], dims='location',
                        coords={'location': city_names})
# West longitudes are converted to the dataset's 0-360 degrees-east convention.
city_lon = xr.DataArray([360 - 87.6298, 360 - 122.4194], dims='location',
                        coords={'location': city_names})

data = (ds_extremes
        .sel(lat=city_lat, lon=city_lon, method='nearest', tolerance=3)
        .to_dataframe()
        .reset_index()
        .assign(month=lambda x: x.time.dt.month))

plt.figure(figsize=(10, 5))
# Current seaborn releases require x / y / hue to be passed as keywords.
sns.violinplot(x='month', y='tmax', hue='location', data=data,
               split=True, inner=None)

-------------------

## Exercise

Calculate anomalies for `tmin`. Plot a 2D map of these anomalies for `2014-12-31`.

In [None]:
# Monthly climatology of the daily minimum temperature.
tmin_clim = tmin.groupby('time.month').mean('time')
# Anomaly: each daily minimum minus the climatology of its calendar month.
tmin_anom = tmin.groupby('time.month') - tmin_clim
# Pick the requested day and draw it with xarray's built-in plotting.
tmin_anom.sel(time='2014-12-31').plot()

-------------------

## xarray also works for data that doesn't fit in memory

Here's a quick demo of [how xarray can leverage dask](http://xarray.pydata.org/en/stable/dask.html) to work with data that doesn't fit in memory. This lets xarray substitute for tools like `cdo` and `nco`.

In [None]:
! ls /project/rossby/datasets/Tiffany

In [None]:
! ls /project/rossby/datasets/Tiffany/T925

In [None]:
# Copy the T925 directory into /scratch/local/era-interim (`rsync -a` is archive
# mode: recursive, preserving file metadata).
# NOTE(review): both paths are specific to the original cluster -- adjust them
# for your own environment before running.
! rsync -a /project/rossby/datasets/Tiffany/T925 /scratch/local/era-interim

Tell dask we want to use 4 threads (one for each core we have):

In [None]:
import dask
from multiprocessing.pool import ThreadPool

# `dask.set_options` has been removed from dask; `dask.config.set` is its
# replacement and accepts the same `pool` setting, so computations run on
# this 4-thread pool.
dask.config.set(pool=ThreadPool(4))

Open a bunch of netCDF files from disk using `xarray.open_mfdataset`:

In [None]:
# Combine every netCDF file matching the glob into a single dataset, read with
# the scipy backend. `chunks` sets the block size along each dimension
# (100 time steps x 121 latitudes x 121 longitudes per block).
# NOTE(review): this rebinds `ds`, replacing the tutorial dataset used by the
# earlier cells -- re-running those cells afterwards will see this dataset.
ds = xr.open_mfdataset('/scratch/local/era-interim/T925/*.nc', engine='scipy',
                       chunks={'time': 100, 'latitude': 121, 'longitude': 121})

In [None]:
ds

In [None]:
# Size of the dataset in GiB: the byte count scaled by 2**-30.
ds.nbytes * (2 ** -30)

In [None]:
%time ds_seasonal = ds.groupby('time.season').mean('time')

In [None]:
%time ds_seasonal.load()

In [None]:
# Facet plot of variable 't' with one panel per season. The seasons are listed
# explicitly so the selection (and hence the panels) follows DJF, MAM, JJA, SON
# order.
(ds_seasonal['t']
 .sel(season=['DJF', 'MAM', 'JJA', 'SON'])
 .plot(col='season', size=3, cmap='Spectral_r'))

For more details, read this blog post: http://continuum.io/blog/xray-dask