## Load Packages 

In [None]:
%matplotlib inline

import datacube
import numpy as np
import xarray as xr
from joblib import load
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

import sys
sys.path.insert(1, '../../Tools/')
from dea_tools.datahandling import load_ard
from dea_tools.plotting import rgb, display_map
from dea_tools.bandindices import calculate_indices
from dea_tools.classification import predict_xr
from dea_tools.dask import create_local_dask_cluster

import warnings
warnings.filterwarnings("ignore")


## Set up a dask cluster

In [None]:
# Start a local Dask cluster for the lazy (dask_chunks) loads below;
# spare_mem is passed as "2Gb" (presumably memory held back from the workers).
create_local_dask_cluster(spare_mem="2Gb")

## Analysis parameters

In [None]:
# Path of the serialized model that is loaded with joblib further down.
model_path = "rf_fmc.pickle"

# Test sites: name -> (latitude, longitude) in decimal degrees.
# Uncomment an entry to include that site in the prediction loop.
testing_locations = dict(
    Namadgi=(-35.675, 149.0540),
    Mallacoota=(-37.5162, 149.6735),
    # Geraldton=(-28.850, 114.746),
    # Ravensthorpe=(-33.5048, 119.839),
)

# Half-width, in degrees, of the box loaded around each site's centre.
buffer = 0.125

# Output projection and pixel size (in output_crs units; metres for EPSG:3577).
output_crs = "EPSG:3577"
resolution = (-20, 20)

# Dask chunk size, in pixels, used when loading lazily.
dask_chunks = dict(x=1000, y=1000)

# Start and end dates of the imagery search window.
time = ("2019-11-20", "2019-11-22")

# Sentinel-2 NBART bands to load.
# NOTE(review): the red-edge bands are loaded but dropped before prediction
# (only red/green/blue/nir/swir plus NDVI/NDII are kept); confirm they are needed.
measurements = [
    # visible
    "nbart_blue", "nbart_green", "nbart_red",
    # red edge
    "nbart_red_edge_1", "nbart_red_edge_2", "nbart_red_edge_3",
    # near infrared
    "nbart_nir_1", "nbart_nir_2",
    # shortwave infrared
    "nbart_swir_2", "nbart_swir_3",
]

## Connect to the datacube

In [None]:
# Open a Datacube handle, tagged with the application name "fmc".
dc = datacube.Datacube(app="fmc")

## Import the model

In [None]:
# Deserialize the trained model with joblib, then force single-job
# prediction (n_jobs=1); parallelism presumably comes from the Dask
# cluster instead. NOTE: joblib files are pickles - only load trusted files.
model = load(model_path)
model = model.set_params(n_jobs=1)

## Set up datacube query


In [None]:
# Base datacube query shared by every test location; the spatial
# bounds (x / y) are added per location in the prediction loop.
query = dict(
    time=time,
    resolution=resolution,
    output_crs=output_crs,
    dask_chunks=dask_chunks,
    measurements=measurements,
)

## Loop through test locations and predict

In [None]:
# Load imagery for each test location, derive the model features and
# predict. One result Dataset per location is appended to `predictions`,
# in the same order as `testing_locations`.
predictions = []

for location_name, (lat, lon) in testing_locations.items():

    print(f'Working on: {location_name}')

    # Build a per-location query (box of +/- `buffer` degrees around the
    # site) instead of mutating the shared `query` dict.
    location_query = {
        **query,
        'x': (lon - buffer, lon + buffer),
        'y': (lat + buffer, lat - buffer),
    }

    # Load Sentinel-2A/B ARD with s2cloudless cloud masking applied
    ds = load_ard(dc=dc,
                  products=["ga_s2am_ard_3", "ga_s2bm_ard_3"],
                  cloud_mask="s2cloudless",
                  mask_pixel_quality=True,
                  **location_query)

    # Normalised-difference indices used as extra features
    ds['ndvi'] = (ds.nbart_nir_1 - ds.nbart_red) / (ds.nbart_nir_1 + ds.nbart_red)
    ds['ndii'] = (ds.nbart_nir_1 - ds.nbart_swir_2) / (ds.nbart_nir_1 + ds.nbart_swir_2)

    # Keep only the feature variables. NOTE(review): this order presumably
    # has to match the feature order the model was trained with - do not reorder.
    ds = ds[['ndvi', 'ndii', 'nbart_red', 'nbart_green', 'nbart_blue',
             'nbart_nir_1', 'nbart_nir_2', 'nbart_swir_2', 'nbart_swir_3']]

    # Predict class values (proba=False) and compute eagerly; return_input=True
    # keeps the input features alongside the 'Predictions' variable.
    predicted = predict_xr(model,
                           ds,
                           proba=False,
                           persist=False,
                           clean=True,
                           return_input=True).compute()

    predictions.append(predicted)
    # No `break` here: every location in `testing_locations` is processed.

## Create a colormap consistent with the current Australian Fuel Monitoring System and plot the predictions

In [None]:
# Colour ramp from red (low) through pale yellow to blue (high), intended to
# match the Australian Fuel Monitoring System colours.
colors = [(0.87, 0, 0), (1, 1, 0.73), (0.165, 0.615, 0.957)]
cmap = LinearSegmentedColormap.from_list('fmc', colors, N=256)

# Plot every location that was predicted rather than only the leaked loop
# variable `predicted`. zip() pairs each site name with its result in
# insertion order and stops at the shorter sequence, so it also copes with
# a prediction loop that exited early.
# vmin/vmax fix a common colour scale (0-136) across all plots.
for location_name, prediction in zip(testing_locations, predictions):
    image = prediction.Predictions.plot.imshow(figsize=(10, 10), cmap=cmap,
                                               vmin=0, vmax=136)
    image.axes.set_title(location_name)