# Set up the notebook: import libraries and define plotting functions

In [None]:
% pylab notebook
from datacube.storage import masking
from datetime import datetime
from skimage import exposure
import datacube
from datacube.helpers import ga_pq_fuser
# Open a connection to the default Data Cube; app='crops' labels this
# connection (presumably used by datacube for logging/usage tracking).
dc = datacube.Datacube(app='crops')

In [None]:
def threeBandImage(ds, bands, time=0, figsize=(10, 10), projection='projected'):
    '''
    threeBandImage takes three spectral bands and plots them on the RGB bands of an image.

    The bands are stacked into the red, green and blue channels (in the order
    given). Values of -999 are treated as no-data (set to NaN). The composite is
    histogram-equalised over its finite pixels before display, and the axis tick
    labels show the dataset's spatial coordinates.

    Inputs:
    ds -   Dataset containing the bands to be plotted. Each band is indexed as
           [time, y, x]: time is the first dimension and the last two are spatial.
    bands - list of three bands to be plotted (red, green, blue order)

    Optional:
    time - Index value of the time dimension of ds to be plotted
    figsize - dimensions for the output figure
    projection - options are 'projected' or 'geographic'. Only selects the axis
                 labels (Eastings/Northings or Longitude/Latitude).
    '''
    def _pixel_edges(centres):
        # Outer edges of the first and last pixel from pixel-centre coordinates.
        # The sign of the step is kept, so ascending and descending axes both work.
        half = (centres[1] - centres[0]) / 2.0 if centres.size > 1 else 0.5
        return centres[0] - half, centres[-1] + half

    # Use the band's own spatial dimension names (its last two dims) rather than
    # a hard-coded ds.x / ds.y, so e.g. latitude/longitude datasets also work.
    ydim, xdim = ds[bands[0]].dims[-2:]
    x_first, x_last = _pixel_edges(ds[xdim].values)
    y_first, y_last = _pixel_edges(ds[ydim].values)

    _, ny, nx = ds[bands[0]].shape
    rawimg = np.zeros((ny, nx, 3), dtype=np.float32)
    for i, colour in enumerate(bands):
        rawimg[:, :, i] = ds[colour][time].values
    # Treat -999 as no-data (assumed fill value of the loaded product -- confirm)
    # so those pixels are excluded from the equalisation by the mask below.
    rawimg[rawimg == -999] = np.nan
    img_toshow = exposure.equalize_hist(rawimg, mask=np.isfinite(rawimg))

    _, ax = plt.subplots(figsize=figsize)
    # Passing the real extent (left, right, bottom, top) lets matplotlib place
    # ticks at true coordinates. Labelling the default pixel-index ticks with
    # the coordinate arrays would show the first few coordinate values rather
    # than the ones actually under each tick.
    ax.imshow(img_toshow, extent=(x_first, x_last, y_last, y_first))
    ax.set_title(str(ds.time[time].values), fontweight='bold', fontsize=16)
    if projection == 'geographic':
        ax.set_xlabel('Longitude', fontweight='bold')
        ax.set_ylabel('Latitude', fontweight='bold')
    else:
        ax.set_xlabel('Eastings', fontweight='bold')
        ax.set_ylabel('Northings', fontweight='bold')

def threeBandImage_subplots(ds, bands, num_cols, figsize=(10, 10), projection='projected', left=0.125,
                            right=0.9, bottom=0.1, top=0.9, wspace=0.2, hspace=0.4):
    '''
    threeBandImage_subplots takes three spectral bands and multiple time steps, and plots them
    on the RGB bands of an image, one subplot per time step.

    For each time step, the bands are stacked into the red, green and blue
    channels. Values of -999 are treated as no-data (NaN), and the composite is
    histogram-equalised over its finite pixels. Any unused grid cells are removed.

    Inputs:
    ds -   Dataset containing the bands to be plotted. Each band is indexed as
           [time, y, x]: time is the first dimension and the last two are spatial.
    bands - list of three bands to be plotted (red, green, blue order)
    num_cols - number of columns for the subplot

    Optional:
    figsize - dimensions for the output figure
    projection - options are 'projected' or 'geographic'. Only selects the axis
                 labels (Eastings/Northings or Longitude/Latitude).
    left  = 0.125  # the space on the left side of the subplots of the figure
    right = 0.9    # the space on the right side of the subplots of the figure
    bottom = 0.1   # the space on the bottom of the subplots of the figure
    top = 0.9      # the space on the top of the subplots of the figure
    wspace = 0.2   # the amount of width reserved for blank space between subplots
    hspace = 0.4   # the amount of height reserved for white space between subplots

    Raises:
    ValueError - if ds has no time steps to plot
    '''
    def _pixel_edges(centres):
        # Outer edges of the first and last pixel from pixel-centre coordinates.
        # The sign of the step is kept, so ascending and descending axes both work.
        half = (centres[1] - centres[0]) / 2.0 if centres.size > 1 else 0.5
        return centres[0] - half, centres[-1] + half

    timesteps = ds.time.size
    if timesteps == 0:
        raise ValueError('ds contains no time steps to plot')
    # Ceiling division: enough rows to hold every time step.
    num_rows = -(-timesteps // num_cols)
    # squeeze=False always yields a 2-D array of Axes, so ravel() also works
    # for a single row, a single column or a 1x1 grid.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    fig.subplots_adjust(left=left, right=right, bottom=bottom, top=top, wspace=wspace, hspace=hspace)

    # Spatial extent is the same for every time step, so compute it once.
    # Using the band's own dim names avoids hard-coding ds.x / ds.y.
    ydim, xdim = ds[bands[0]].dims[-2:]
    x_first, x_last = _pixel_edges(ds[xdim].values)
    y_first, y_last = _pixel_edges(ds[ydim].values)
    extent = (x_first, x_last, y_last, y_first)
    _, ny, nx = ds[bands[0]].shape

    axes_flat = axes.ravel()
    for index, ax in enumerate(axes_flat[:timesteps]):
        rawimg = np.zeros((ny, nx, 3), dtype=np.float32)
        for i, colour in enumerate(bands):
            rawimg[:, :, i] = ds[colour][index].values
        # Treat -999 as no-data (assumed fill value -- confirm) so it is
        # excluded from the histogram equalisation via the mask.
        rawimg[rawimg == -999] = np.nan
        img_toshow = exposure.equalize_hist(rawimg, mask=np.isfinite(rawimg))
        # A real extent gives tick labels at the true coordinates.
        ax.imshow(img_toshow, extent=extent)
        ax.set_title(str(ds.time[index].values), fontweight='bold')
        ax.tick_params(labelsize=8)
        if projection == 'geographic':
            ax.set_xlabel('Longitude', fontweight='bold')
            ax.set_ylabel('Latitude', fontweight='bold')
        else:
            ax.set_xlabel('Eastings', fontweight='bold')
            ax.set_ylabel('Northings', fontweight='bold')

    # The grid can have up to num_cols - 1 spare cells; remove every one of them
    # (not just the first) so no empty axes are left on the figure.
    for ax in axes_flat[timesteps:]:
        fig.delaxes(ax)
    plt.draw()

# Load some data

In [None]:
# Study-area bounds (latitude/longitude) and date range used for the Landsat load.
query = {
        'lat': (-35.51, -35.69),
        'lon': (146.81, 147.19),
        'time':('2017-01-01', '2017-12-31')
        }

# Load the 'ls8_nbar_albers' product (presumably Landsat 8 NBAR reflectance) for
# the query. group_by='solar_day' groups observations from the same day into one
# time step.
test_area = dc.load(product='ls8_nbar_albers', group_by='solar_day', **query)

In [None]:
# Show the loaded dataset (dimensions, coordinates and bands) for inspection.
test_area

In [None]:
# Plot every time step as a red/green/blue composite, three subplots per row.
threeBandImage_subplots(test_area, bands = ['red', 'green', 'blue'], num_cols = 3, figsize = [10, 30])

# Grab the accompanying pixel quality data

In [None]:
# Load the matching pixel-quality product with the same query. ga_pq_fuser is
# the fuse_func used to combine PQ observations grouped into the same solar day.
test_area_pq = dc.load(product='ls8_pq_albers', group_by='solar_day', fuse_func = ga_pq_fuser, **query)

In [None]:
# Show the pixel-quality dataset for inspection.
test_area_pq

In [None]:
# Required pixel-quality flag values, passed as keyword arguments to
# masking.make_mask: no ACCA/Fmask cloud or cloud shadow, no saturated band,
# and contiguous data across bands.
mask_components = dict(
    cloud_acca='no_cloud',
    cloud_shadow_acca='no_cloud_shadow',
    cloud_shadow_fmask='no_cloud_shadow',
    cloud_fmask='no_cloud',
    blue_saturated=False,
    green_saturated=False,
    red_saturated=False,
    nir_saturated=False,
    swir1_saturated=False,
    swir2_saturated=False,
    contiguous=True,
)

In [None]:
# Build a boolean mask from the PQ flags: True where every flag in
# mask_components has the requested value.
quality_mask = masking.make_mask(test_area_pq, **mask_components)

In [None]:
# Show the mask dataset for inspection.
quality_mask

## Check that the mask looks OK

In [None]:
# Plot the mask for the time step at index 3 as a quick visual sanity check.
# NOTE(review): hard-coded index; this fails if fewer than 4 time steps were loaded.
quality_mask.isel(time = 3).pixelquality.plot(figsize = [10, 10])

## Apply the mask to the Landsat data

In [None]:
# Keep Landsat pixels where the quality mask is True. xarray's where() fills
# all other pixels with NaN.
cleaned_data = test_area.where(quality_mask.pixelquality)

In [None]:
# Show the masked dataset for inspection.
cleaned_data

# Load Water Observations from Space (WOfS) data for the same location

In [None]:
# Connect to a separate Data Cube instance, configured by its own config file,
# that provides the WOfS product.
# NOTE(review): absolute, user-specific config path; this will not work for other users.
dcwofs = datacube.Datacube(config='/home/156/cek156/wofscube.conf')

# NOTE(review): this rebinds `query` from the Landsat cell. The extent here
# (lat -35.51..-33.66, lon 146.81..151.43) is much larger than the Landsat
# query (lat -35.51..-35.69, lon 146.81..147.19). With no 'time' key,
# presumably every available time step is loaded. Confirm this is intended
# if the goal is to match the Landsat study area.
query = {
        'lat': (-35.51, -33.66),
        'lon': (146.81, 151.43)
        }

# Load the 'old_wofs' product from the WOfS Data Cube for the query above.
test_area_wofs = dcwofs.load(product='old_wofs', **query)

In [None]:
# Show the WOfS dataset for inspection.
test_area_wofs

In [None]:
# List the products available in the WOfS Data Cube (e.g. to confirm the
# 'old_wofs' product name used above).
dcwofs.list_products()