In [None]:
"""
A prototype implementation of Use Case 9
"""

import xarray as xr
import os
import numpy as np
from ipywidgets import interact
# Needed for current resampling implementation
from mpl_toolkits import basemap
import time
# Pretty plots on maps
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
# For linear regression calculation
from scipy import stats
%matplotlib inline

In [None]:
# This is just a prototype, so the data locations default to the original
# hard coded paths. Set CCI_CLOUD_ECV_PATH / CCI_AEROSOL_ECV_PATH in the
# environment before starting the kernel to use a different copy of the data.
CLOUD_ECV_PATH = os.environ.get(
    'CCI_CLOUD_ECV_PATH', '/home/ccitbx/cci_data/cloud/2008/')
AEROSOL_ECV_PATH = os.environ.get(
    'CCI_AEROSOL_ECV_PATH', '/home/ccitbx/cci_data/aerosol/2008_monthly/')

In [None]:
# Define the functions that the prototype is going to use

def read_data(path):
    """
    Read in multiple netCDF files and combine them in an xarray dataset.

    Every file matching ``*.nc`` directly inside ``path`` is handed to
    ``xr.open_mfdataset`` and concatenated along the 'time' dimension.
    The glob pattern that is used gets printed.

    :rtype: xr.Dataset
    :param path: Path to the folder, with or without a trailing separator
    :return: The resulting dataset
    """
    # os.path.join does not double the separator when 'path' already ends
    # with one, and does not turn an empty path into the filesystem root,
    # unlike plain string concatenation with os.sep.
    pattern = os.path.join(path, '*.nc')
    print(pattern)
    # NOTE(review): recent xarray releases expect combine='nested' whenever
    # concat_dim is given - confirm against the installed version.
    dataset = xr.open_mfdataset(pattern, concat_dim='time')
    return dataset

class Plotter:
    """
    A helper class for interactive plotting.

    interact() can't pass keyword arguments that don't define a widget
    to the desired function, so the datasets and the optional bounding
    box are stored on the instance instead.
    """

    def __init__(self, datasets, extents=None):
        """
        :param datasets: a dictionary of datasets, keys have to correspond
        to the names offered to 'plot_datasets'
        :param extents: a sequence [lon_min, lon_max, lat_min, lat_max] that
        defines the bounding box to show in the plot. It is passed unchanged
        to the axes' set_extent() together with a PlateCarree CRS, which
        expects the x (longitude) bounds first. If not provided (None or
        empty), a global plot is made
        """
        self.datasets = datasets
        self.extents = extents

    def plot_datasets(self, names, t):
        """
        A function for interactive plotting.

        Draws a filled contour plot of one time step of one dataset on a
        PlateCarree map with coastlines, in a new 16x8 figure.

        :param names: key into 'self.datasets' selecting the array to plot
        :param t: integer index along the 'time' dimension
        """
        array_slice = self.datasets[names].isel(time=t)
        plt.figure(figsize=(16, 8))
        ax = plt.axes(projection=ccrs.PlateCarree())

        # Explicit None/length test instead of plain truthiness, so that a
        # numpy array of extents does not raise an "ambiguous truth value"
        # error; None and empty sequences still give a global plot.
        if self.extents is not None and len(self.extents) > 0:
            ax.set_extent(self.extents, ccrs.PlateCarree())
        else:
            ax.set_global()
        ax.coastlines()
        array_slice.plot.contourf(ax=ax, transform=ccrs.PlateCarree())
    
    
def resample_slice(slice, grid_lon, grid_lat, order=1):
    """
    Resample a single time slice of a larger xr.DataArray

    The values are interpolated with basemap.interp from the slice's own
    'lon'/'lat' coordinates onto the target grid.

    :param slice: xr.DataArray single slice (the name shadows the builtin
    'slice' inside this function only; kept for interface compatibility)
    :param grid_lon: meshgrid of longitudes for the new grid
    :param grid_lat: meshgrid of latitudes for the new grid
    :param order: Interpolation method 0 - nearest neighbour, 1 - bilinear (default), 3 - cubic spline
    :return: xr.DataArray, resampled slice. It wraps the bare interpolated
    values; no coordinates, name or attributes are attached here.
    """
    # 'order' has to be forwarded explicitly, otherwise basemap.interp
    # always uses its own default regardless of what the caller asked for.
    result = basemap.interp(slice.values, slice['lon'].data, slice['lat'].data,
                            grid_lon, grid_lat, order=order)
    return xr.DataArray(result)


def resample_array(array, lon, lat, order=1):
    """
    Resample the given xr.DataArray to a new grid defined by lat and lon

    Each time step is resampled separately with resample_slice and the
    results are reassembled under the original name, dims and attrs.
    The name and dims of the array being processed are printed.

    :param array: xr.DataArray with lat,lon and time coordinates
    :param lon: 'lon' xr.DataArray attribute for the new grid
    :param lat: 'lat' xr.DataArray attribute for the new grid
    :param order: 0 for nearest-neighbor interpolation, 1 for bilinear interpolation,
    3 for cubic spline (default 1). order=3 requires scipy.ndimage.
    :return: A new xr.DataArray on the new grid, chunked with one time step
    per chunk. 'array' itself is not modified; it is returned unchanged if
    it lacks any of the 'time', 'lat' or 'lon' dimensions.
    """
    # Don't do anything if this DataArray has different dims than expected
    if not {'time', 'lat', 'lon'}.issubset(array.dims):
        return array

    print(array.name)
    print(array.dims)

    grid_lon, grid_lat = np.meshgrid(lon.data, lat.data)
    # Forward the requested interpolation order to every per-slice call.
    kwargs = {'grid_lon': grid_lon, 'grid_lat': grid_lat, 'order': order}
    temp_array = array.groupby('time').apply(resample_slice, **kwargs)
    # One chunk per time step, each spanning the full spatial grid.
    chunks = [1] + list(temp_array.shape[1:])
    return xr.DataArray(temp_array.values,
                        name=array.name,
                        dims=array.dims,
                        coords={'time': array.time, 'lat': lat, 'lon': lon},
                        attrs=array.attrs).chunk(chunks=chunks)

In [None]:
# Read in data: each call globs the *.nc files in the given folder and
# combines them into a single dataset along 'time' (see read_data).
ds_clouds = read_data(CLOUD_ECV_PATH)
ds_aerosol = read_data(AEROSOL_ECV_PATH)

In [None]:
# Rename the aerosol dataset's coordinates to correspond with clouds.
# rename() returns a new dataset, so the result is assigned back instead of
# relying on the 'inplace' keyword. Only names that are still present get
# renamed, which makes it safe to run this cell more than once.
ds_aerosol = ds_aerosol.rename({
    old: new
    for old, new in {'latitude': 'lat', 'longitude': 'lon'}.items()
    if old in ds_aerosol.dims or old in ds_aerosol.variables
})

# Select the variables of interest
cc_total = ds_clouds.cc_total
aerosol = ds_aerosol.AOD550_mean

In [None]:
# Regrid the cc_total to the grid used for aerosol, so that both variables
# are defined on the aerosol 'lat'/'lon' coordinates. order=1 is the
# bilinear option according to resample_array's docstring.
cc_total_resampled = resample_array(cc_total, aerosol['lon'], aerosol['lat'], order=1)

In [None]:
# Define plotting parameters
# 'names' lists the keys of 'ds'; Plotter.plot_datasets uses the selected
# name to look up the array to draw. Both variables are reused (and 'ds' is
# rebound) by the subset plotting cell further down.
names = ['clouds', 'aerosol']
ds = {'clouds':cc_total_resampled, 'aerosol':aerosol}

In [None]:
# Interactive global plot (no extents given): a selector for the dataset
# name plus an integer slider t = 0..11 (step 1) for the time index.
# NOTE(review): the upper bound 11 assumes both arrays hold 12 time steps
# (presumably the months of 2008) - confirm against the data.
disp = Plotter(ds)
interact(disp.plot_datasets, names = names, t=(0,11,1))

In [None]:
# Select a spatial subset (Africa): label-based selection of
# lat in [-40, 40] and lon in [-20, 60] on the shared aerosol grid.
# NOTE(review): label slices depend on the coordinate ordering; if 'lat' is
# stored in descending order slice(-40., 40.) may select nothing - confirm.
lat_slice = slice(-40., 40.)
lon_slice = slice(-20., 60.)
cc_total_sub = cc_total_resampled.sel(lat=lat_slice,lon=lon_slice)
aerosol_sub = aerosol.sel(lat=lat_slice, lon=lon_slice)

# Select a time subset, feb - sept. Selecting by indices now, as the
# time definition is not uniform for both datasets
# NOTE(review): slice(1,10) keeps indices 1..9, i.e. nine time steps. If
# index 0 is January that is Feb - Oct rather than Feb - Sept; confirm
# whether the comment or the slice reflects the intent.
time_slice = slice(1,10)
cc_total_sub = cc_total_sub.isel(time=time_slice)
aerosol_sub = aerosol_sub.isel(time=time_slice)


In [None]:
# Print a text summary of both subsets for inspection before plotting
print(cc_total_sub)
print(aerosol_sub)

In [None]:
# Plot the subset
ds = {'clouds':cc_total_sub, 'aerosol':aerosol_sub}

disp = Plotter(ds)
interact(disp.plot_datasets, names = names, t=(0,8))

In [None]:
def plot_scatter(t):
    """
    Scatter-plot cc_total against AOD550_mean for one time step and overlay
    a linear regression line.

    Reads the notebook-level 'cc_total_sub' and 'aerosol_sub' arrays,
    prints the regression statistics and shows the figure.

    :param t: integer index along the 'time' dimension of both subsets
    :return: None
    """
    x = cc_total_sub.isel(time=t).values.flatten()
    y = aerosol_sub.isel(time=t).values.flatten()

    # Keep only the points where both values are present. A NaN in either
    # input would make every statistic returned by linregress NaN.
    valid = ~(np.isnan(x) | np.isnan(y))
    xm = x[valid]
    ym = y[valid]

    # Do a linear regression
    # (std_err is what scipy reports as the standard error of the slope)
    slope, intercept, r_value, p_value, std_err = stats.linregress(xm, ym)
    print("Slope: " + str(slope))
    print("Intercept: " + str(intercept))
    print("R squared: " + str(r_value*r_value))
    print("P value: " + str(p_value))
    print("Standard error of the estimate: " + str(std_err))
    line = slope*xm + intercept

    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot(111)
    ax.plot(xm, ym, '.', xm, line, 'r-')
    ax.set_xlabel('cc_total')
    ax.set_ylabel('AOD550_mean')
    plt.show()

In [None]:
# Interactive scatter plot and regression for one time step at a time;
# the slider covers time indices 0..8 of the subsets.
interact(plot_scatter, t=(0,8))