![Assimila logo](../helpers/logosmall.png "Assimila")
# Assimila Prise Products
# Analyse Temperature and Anomalies

If the drop-down menus do not appear, restart the notebook by selecting "Kernel" and then "Restart" in the menu bar above.

In [1]:
import os
import sys

sys.path.append('../')
from DQTools.DQTools import Dataset

sys.path.append('../helpers/')
from widgets import Widgets
from data import Data



import matplotlib

import matplotlib.cm as cm
import matplotlib.animation as animation

import matplotlib.cbook as cbook
import matplotlib.image as image

matplotlib.use('nbagg')
import matplotlib.pyplot as plt

import gdal
import numpy as np
import xarray as xr

import ipywidgets as widgets
import pandas as pd
import datetime
import calendar

from IPython.display import clear_output
from IPython.display import display

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
# keyfile = os.path.join(os.path.expanduser("~"), '.assimila_dq.txt')
# product = 'era5'
# subproduct = 't2m'
# region = [533, 59, -36, -21]
# tile = ['era5_africa',]
# ds = Dataset(product, subproduct, region)

In [3]:
# xr = ds.get_data(start=datetime.date(2009, 1, 1), stop=datetime.date(2009, 12, 31))

In [4]:
# print(xr)

In [5]:
def get_chunk_size(filename):
    """
    Build the ``chunks`` tuple describing a raster file.

    The tuple combines the number of bands in the file with the
    internal block size reported by its first band. It is meant to be
    used as the ``chunks`` parameter when a dataset is read through
    xr.open_rasterio() using DASK.

    :param filename: GDAL valid file.
    :return: tuple raster count, x block size, y block size
    """
    dataset = gdal.Open(filename)

    # The internal block size is taken from the first band only
    block = dataset.GetRasterBand(1).GetBlockSize()

    return (dataset.RasterCount, block[0], block[1])

def get_times_from_file_band(fname):
    """
    Extract one timestamp per band from the 'time' metadata of a raster.

    For every band the 'time' item of the band metadata is used; when a
    band does not have one, the 'time' item of the dataset metadata is
    used instead.

    :param fname: GDAL valid file.
    :return: list of np.datetime64, one entry per band, in band order
    :raises OSError: if GDAL cannot open the file
    :raises ValueError: if neither the band nor the dataset metadata
                        holds a 'time' item
    """
    key = 'time'

    d = gdal.Open(fname)
    if d is None:
        # gdal.Open() returns None instead of raising when GDAL
        # exceptions are not enabled; fail here with a clear message
        # rather than with an AttributeError further down.
        raise OSError(f"File {fname} could not be opened by GDAL")

    # Dataset metadata, used as fallback for bands without a 'time' item
    dmd = d.GetMetadata()

    times = []
    # GDAL band numbers start at 1
    for band in range(1, d.RasterCount + 1):
        md = d.GetRasterBand(band).GetMetadata()

        if key in md:
            start_date = md[key]
        elif key in dmd:
            start_date = dmd[key]
        else:
            raise ValueError(f"File {fname} does not have date information")

        times.append(np.datetime64(start_date))

    return times

def get_data(month):
    """
    Read one monthly ERA5 t2m GeoTIFF and compute its standard anomalies.

    The file read is
    '../ERA5/Africa/era5_t2m_era5_africa_2009-{month}.tif'. Values are
    returned as stored in the file; no scale or offset is applied here.

    :param month: month part of the file name, e.g. '03'
    :return: tuple (data_array, stats, times) where
             - data_array is the file content with its dimensions
               renamed to time / latitude / longitude and the band
               timestamps assigned to 'time',
             - stats holds the standard anomalies
               (value - mean) / std, with mean and std taken along the
               first (time) axis of this file only,
             - times is the list of band timestamps.
    """
    # Get monthly ERA5 data
    datadir = '../ERA5/Africa/'
    fname = os.path.join(datadir, f'era5_t2m_era5_africa_2009-{month}.tif')

    times = get_times_from_file_band(fname)
    data_array = xr.open_rasterio(fname)

    data_array = data_array.rename(
        {'x': 'longitude',
         'y': 'latitude',
         'band': 'time'})

    data_array['time'] = times

    # Per-pixel statistics along the time axis
    _mean = data_array.mean(axis=0)
    _std = data_array.std(axis=0)

    # Standard anomalies. The container is created as float explicitly:
    # zeros_like() would otherwise inherit the dtype of the raster, and
    # on an integer raster the anomalies would be truncated when written
    # into it. Broadcasting handles all time steps at once. Pixels with
    # a zero standard deviation yield inf / NaN.
    stats = xr.zeros_like(data_array, dtype=float)
    stats.data[...] = (data_array.data - _mean.data) / _std.data

    return data_array, stats, times

def make_patch_spines_invisible(ax):
    """
    Keep the frame of ``ax`` switched on, but hide its background
    patch and every one of its spines.

    :param ax: axes object exposing set_frame_on(), patch and spines
    """
    ax.set_frame_on(True)
    ax.patch.set_visible(False)

    spines = ax.spines
    for name in spines:
        spines[name].set_visible(False)

def analyse_days(tmin, tmax):
    """
    Minimum, mean and maximum for the DAILY timestep.

    Selects the time window slice(tmin, tmax) from the module-level
    ``t2m_data`` and reduces it along 'time'.

    :param tmin: first timestamp of the window
    :param tmax: last timestamp of the window
    :returns: tuple of DataArrays (min, mean, max) for the window
    """
    window = t2m_data.sel(time=slice(tmin, tmax))

    return window.min('time'), window.mean('time'), window.max('time')

def analyse_months(val):
    """
    Minimum, mean and maximum for the MONTHLY timestep.

    For each day of the month of ``val`` the daily min, mean and max
    are taken from the module-level ``t2m_data``; each of the three
    daily series is then averaged over the month.

    :param val: object with ``year`` and ``month`` attributes
    :returns: tuple of DataArrays (mean of daily mins, mean of daily
              means, mean of daily maxes)
    """
    per_day_min, per_day_mean, per_day_max = [], [], []

    # One (start, end) pair of timestamps per day of the month
    for day_start, day_end in get_tmin_tmax(val):
        one_day = t2m_data.sel(time=slice(day_start, day_end))
        per_day_min.append(one_day.min('time'))
        per_day_mean.append(one_day.mean('time'))
        per_day_max.append(one_day.max('time'))

    # Stack the daily layers along 'time' and average them
    monthly = tuple(
        xr.concat(layers, dim='time').mean('time')
        for layers in (per_day_min, per_day_mean, per_day_max)
    )

    return monthly
    
    
def get_tmin_tmax(val):
    """
    Pair up the first and last timestamp of every day in a month.

    :param val: object with ``year`` and ``month`` attributes
    :return: zip of (start, end) datetimes, one pair per day of the
             month, where start is 00:00 and end is 23:00 of that day
    """
    # monthrange() -> (weekday of the 1st, number of days in the month)
    n_days = calendar.monthrange(val.year, val.month)[1]

    days = [datetime.date(val.year, val.month, day)
            for day in range(1, n_days + 1)]

    first_hour = datetime.time.min
    last_hour = datetime.time(23, 0, 0, 0)

    starts = [datetime.datetime.combine(day, first_hour) for day in days]
    ends = [datetime.datetime.combine(day, last_hour) for day in days]

    return zip(starts, ends)


def on_dropdown_dates_change(*args):
    """
    Handle a change event of the ``dropdown_dates`` widget.

    Computes the min, mean and max maps for the newly selected date
    (a single day or a whole month, depending on ``timestep``), stores
    them in the module-level ``data_list`` and redraws the three map
    panels. The time-series axes are cleared.

    :param args: args[0] is the change notification; its 'new' item is
                 the newly selected date
    """
    global data_list

    val = args[0]['new']
    if val is None:
        # No date is selected (e.g. the dropdown has no options)
        return

    if timestep.value == 'days':
        tmin = datetime.datetime.combine(val, datetime.datetime.min.time())
        tmax = datetime.datetime.combine(val, datetime.time(23, 0, 0, 0))
        mean_min, mean_mean, mean_max = analyse_days(tmin, tmax)
    elif timestep.value == 'months':
        mean_min, mean_mean, mean_max = analyse_months(val)
    else:
        # 'years' (or any other timestep) is not implemented; without
        # this guard the maps below would be undefined.
        return

    data_list = [mean_min, mean_mean, mean_max]

    for i, variable in enumerate(variables):
        ts_plots[i].clear()
        plots[i].clear()

        data_list[i].plot.imshow(ax=plots[i], cmap='RdBu_r',
                                 add_colorbar=False)

        plots[i].set_title(variable)
        plots[i].set_aspect('equal')
        plots[i].grid()

    fig.canvas.draw()
    plt.show()
        
def on_click(event):
    """
    Mouse-click handler: plot the time series at the clicked location.

    A click inside one of the map panels in ``plots`` marks the point
    on every panel and draws, on the matching axis of ``ts_plots``, the
    series of that panel's statistic between the ``start`` and ``end``
    dates (both included) for the grid cell nearest to the click.
    Clicks anywhere else are ignored.

    :param event: matplotlib mouse event; its ``inaxes``, ``xdata`` and
                  ``ydata`` attributes are read
    """
    # xdata / ydata can only be used as longitude / latitude when the
    # click landed on a map panel; on the time-series axes they are in
    # other units and outside any axes they are None.
    if event.inaxes not in plots:
        return

    step = timestep.value
    if step == 'days':
        n_days = (end.value - start.value).days
        # + 1 so that the end date itself is part of the series
        period = [start.value + datetime.timedelta(days=k)
                  for k in range(n_days + 1)]
    elif step == 'months':
        period = pd.date_range(start.value, end.value, freq='MS')
    else:
        # No other timestep is implemented
        return

    if len(period) == 0:
        # Nothing to plot (e.g. end before start); xr.concat() cannot
        # work on an empty list.
        return

    # The statistics are the same for every panel: compute them once.
    # per_step[0] / [1] / [2] collect the min / mean / max layers.
    per_step = [[], [], []]
    for stamp in period:
        if step == 'days':
            tmin = datetime.datetime.combine(stamp,
                                             datetime.datetime.min.time())
            tmax = datetime.datetime.combine(stamp,
                                             datetime.time(23, 0, 0, 0))
            layers = analyse_days(tmin, tmax)
        else:
            layers = analyse_months(stamp)

        for bucket, layer in zip(per_step, layers):
            bucket.append(layer)

    # Concatenate the layers of individual timesteps and redefine the
    # values of the new 'time' coordinate
    ts_data = []
    for layers in per_step:
        series = xr.concat(layers, dim='time')
        series.coords['time'] = ('time', period)
        ts_data.append(series)

    for i, ax in enumerate(plots):
        ts_plots[i].clear()

        # Delete last reference point. Artist.remove() is used because
        # recent matplotlib versions no longer allow deleting items
        # from ax.lines directly.
        if len(ax.lines) > 0:
            ax.lines[0].remove()

        # Draw a point as a reference
        ax.plot(event.xdata, event.ydata,
                marker='o', color='red', markersize=7, alpha=0.7)

        _data = ts_data[i].sel(longitude=event.xdata,
                               latitude=event.ydata,
                               method='nearest')

        ts_plots[i].plot(_data.time.data, _data.data,
                         label=f'ERA5 {variables[i]}',
                         color=colors[i], lw=1.5, alpha=0.5)

        ts_plots[i].set_title('')

        # Colour the y axis like the line it belongs to
        ts_plots[i].yaxis.label.set_color(colors[i])
        tkw = dict(size=4, width=1.5)
        ts_plots[i].tick_params(axis='y', colors=colors[i], **tkw)
        ts_plots[i].set_ylabel(variables[i])

    ts_plots[0].grid()
    # clear() resets the axes, so the offset of the last y axis has to
    # be applied again
    ts_plots[-1].spines["right"].set_position(("axes", 1.07))
    make_patch_spines_invisible(ts_plots[-1])
    ts_plots[-1].spines["right"].set_visible(True)
    fig.canvas.draw()

    
    
def on_date_change(*args):
    """
    Handle a change of the ``start`` or ``end`` date widgets.

    The change notification itself is not needed: the options of the
    date dropdown are rebuilt by on_timestep_change() from the current
    widget values.

    :param args: change notification passed by the widget (ignored)
    """
    on_timestep_change()
    
    
def on_timestep_change(*args):
    """
    Rebuild the options of ``dropdown_dates`` for the chosen timestep.

    The candidate dates are the entries of the module-level ``times``
    lying between the ``start`` and ``end`` widget values (both
    included). For 'days' every such date is offered; for 'months' and
    'years' the month starts / year starts between the first and the
    last candidate are offered.

    :param args: optional change notification; when given, its 'new'
                 item is used as the timestep, otherwise the current
                 value of the ``timestep`` widget is used
    """
    val = args[0]['new'] if args else timestep.value

    # Distinct dates inside [start, end], oldest first
    dt_times = (pd.to_datetime(d).date() for d in times)
    dates = sorted({d for d in dt_times if start.value <= d <= end.value})

    if val == 'days':
        dropdown_dates.options = dates
        return

    if val not in ('months', 'years'):
        return

    if not dates:
        # Empty date range: there is no first / last date to build the
        # month or year starts from.
        dropdown_dates.options = []
        return

    # Pandas offset alias: month start / year start
    rule = 'MS' if val == 'months' else 'YS'
    stamps = pd.date_range(dates[0], dates[-1], freq=rule)
    # Plain dates, like the options used for the 'days' timestep
    dropdown_dates.options = [stamp.date() for stamp in stamps]

In [6]:
def on_button_clicked(b):
    """
    Button event handler: plot the time series at the typed location.

    Marks the point given by the ``lon`` and ``lat`` widgets on every
    map panel in ``plots`` and draws, on the matching axis of
    ``ts_plots``, the series of that panel's statistic between the
    ``start`` and ``end`` dates (both included) for the grid cell
    nearest to that point.

    :param b: the clicked button (not used)
    """
    step = timestep.value
    if step == 'days':
        n_days = (end.value - start.value).days
        # + 1 so that the end date itself is part of the series
        period = [start.value + datetime.timedelta(days=k)
                  for k in range(n_days + 1)]
    elif step == 'months':
        period = pd.date_range(start.value, end.value, freq='MS')
    else:
        # No other timestep is implemented
        return

    if len(period) == 0:
        # Nothing to plot (e.g. end before start); xr.concat() cannot
        # work on an empty list.
        return

    # The statistics are the same for every panel: compute them once.
    # per_step[0] / [1] / [2] collect the min / mean / max layers.
    per_step = [[], [], []]
    for stamp in period:
        if step == 'days':
            tmin = datetime.datetime.combine(stamp,
                                             datetime.datetime.min.time())
            tmax = datetime.datetime.combine(stamp,
                                             datetime.time(23, 0, 0, 0))
            layers = analyse_days(tmin, tmax)
        else:
            layers = analyse_months(stamp)

        for bucket, layer in zip(per_step, layers):
            bucket.append(layer)

    # Concatenate the layers of individual timesteps and redefine the
    # values of the new 'time' coordinate
    ts_data = []
    for layers in per_step:
        series = xr.concat(layers, dim='time')
        series.coords['time'] = ('time', period)
        ts_data.append(series)

    for i, ax in enumerate(plots):
        ts_plots[i].clear()

        # Delete last reference point. Artist.remove() is used because
        # recent matplotlib versions no longer allow deleting items
        # from ax.lines directly.
        if len(ax.lines) > 0:
            ax.lines[0].remove()

        # Draw a point as a reference
        ax.plot(lon.value, lat.value,
                marker='o', color='red', markersize=7, alpha=0.7)

        _data = ts_data[i].sel(longitude=lon.value,
                               latitude=lat.value,
                               method='nearest')

        ts_plots[i].plot(_data.time.data, _data.data,
                         label=f'ERA5 {variables[i]}',
                         color=colors[i], lw=1.5, alpha=0.5)

        ts_plots[i].set_title('')

        # Colour the y axis like the line it belongs to
        ts_plots[i].yaxis.label.set_color(colors[i])
        tkw = dict(size=4, width=1.5)
        ts_plots[i].tick_params(axis='y', colors=colors[i], **tkw)
        ts_plots[i].set_ylabel(variables[i])

    ts_plots[0].grid()
    # clear() resets the axes, so the offset of the last y axis has to
    # be applied again
    ts_plots[-1].spines["right"].set_position(("axes", 1.07))
    make_patch_spines_invisible(ts_plots[-1])
    ts_plots[-1].spines["right"].set_visible(True)
    fig.canvas.draw()

In [7]:
############
# Get data #
############

# Month labels '01' .. '12', as they appear in the monthly file names
months = [f'{number:02d}' for number in range(1, 13)]

datasets = []
anomalies = []
_times = []

for month in months:
    # Get anomalies and datasets
    d, a, t = get_data(month)
    # NOTE(review): 0.01 and 100 look like the scale factor and offset
    # of the stored values - confirm against the file metadata.
    datasets.append(d*0.01+100)
    anomalies.append(a)
    # Timestamps of all months end up in one flat list
    _times.extend(t)

# One array holding every month, joined along 'time'
t2m_data = xr.concat(datasets, dim='time')

_dt_times = [pd.to_datetime(stamp).date() for stamp in _times]

# Distinct calendar dates with data, oldest first
times = sorted(set(_dt_times))

In [8]:
###########
# Widgets #
###########

w = Widgets()

freq = ['days', 'months']#, 'years']

lat = w.latitude()
lon = w.longitude()
button = widgets.Button(description='Get point',
                        disabled=False,
                        button_style='info',
                        tooltip='Click me to plot location data')


start = w.get_date(value=datetime.date(2009,3,1), description='Start')
end = w.get_date(value=datetime.date(2009,4,30), description='End')

dropdown_dates = widgets.Dropdown(options=times,
                                  value=times[0],
                                  description='Date selection',
                                  disabled=False)

timestep = widgets.Dropdown(options=freq,
                            value=freq[0],
                            description='Timestep',
                            disabled=False)

button.on_click(on_button_clicked)
# observe() only registers the callback and returns None, so its result
# is not bound to data_list (data_list is initialised further down).
dropdown_dates.observe(on_dropdown_dates_change, names='value')
start.observe(on_date_change)
end.observe(on_date_change)
# Only 'value' changes carry a timestep name in their 'new' item;
# without the filter, changes of other traits of the dropdown would be
# passed to the handler as if they were timestep names.
timestep.observe(on_timestep_change, names='value')

box1 = widgets.HBox([lat, lon, button])
box2 = widgets.HBox([start, end, timestep, dropdown_dates])
box = widgets.VBox([box1, box2])

################
# Plot objects #
################

fig = plt.figure(figsize=(18, 12))

# Top row: three map panels sharing their x and y axes
t2m_min = plt.subplot2grid((2, 3), (0, 0), colspan=1)

t2m_mean = plt.subplot2grid((2, 3), (0, 1), colspan=1,
                            sharex=t2m_min, sharey=t2m_min)

t2m_max = plt.subplot2grid((2, 3), (0, 2), colspan=1,
                           sharex=t2m_min, sharey=t2m_min)

plots = [t2m_min, t2m_mean, t2m_max]

# Bottom row: one time-series panel with two extra y axes, one per
# variable
ts_p = plt.subplot2grid((2, 3), (1, 0), colspan=3)
ts_p1 = ts_p.twinx()
ts_p2 = ts_p.twinx()

ts_plots = [ts_p, ts_p1, ts_p2]

# Move the third y axis outwards so that it does not overlap the second
ts_p2.spines["right"].set_position(("axes", 1.1))
make_patch_spines_invisible(ts_p2)
ts_p2.spines["right"].set_visible(True)

# Connect the canvas with the event
cid = fig.canvas.mpl_connect('button_press_event', on_click)

variables = ['t2m_min', 't2m_mean', 't2m_max']
colors = ['red', 'blue', 'black']
data_list = [t2m_data[0], t2m_data[0], t2m_data[0]] # TODO: Initialise with displayed date
colorbars = []
images = []

for i, variable in enumerate(variables):
    # No trailing comma after the call: it would turn tmp_img into a
    # 1-tuple instead of the object returned by imshow().
    tmp_img = data_list[i].plot.imshow(ax=plots[i], cmap='RdBu_r',
                                       add_colorbar=False)

    images.append(tmp_img)

    plots[i].set_title(f'{variable}')
    plots[i].set_aspect('equal')
    plots[i].grid()

display(box)
plt.tight_layout()
plt.show()


<IPython.core.display.Javascript object>

VBox(children=(HBox(children=(BoundedFloatText(value=7.5, description='Latitude (y):', layout=Layout(height='a…