In [None]:
import os
import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.ensemble import RandomForestClassifier

import rasterio
import rioxarray as rioxr
import geopandas as gpd

import warnings

import planetary_computer as pc

import gc # garbage collector

import dask
import dask_gateway
import dask.array as da
from dask_ml.wrappers import ParallelPostFit

import iceplant_detection_functions as ipf
import model_prep_and_evals as mpe

In [None]:
# Start the wall-clock timer for the whole notebook; the last cell prints the elapsed time.
t_alpha = time.time()

# Import training set

In [None]:
# Read the training set from disk once, then slice features and labels from the same frame
# (previously the CSV was parsed twice, once per slice).
train_set = pd.read_csv(os.path.join(os.getcwd(), 'data_sampling_workflow', 'train_set.csv'))

# select features from r (Red band) to avg_lidar // excludes geometry, aoi, naip_id, polygon_id and iceplant features
X_train = train_set.loc[:, 'r':'avg_lidar']

# select iceplant feature column (the labels the classifier is trained on)
y_train = train_set.loc[:, 'iceplant']

In [None]:
# Remove the lidar-derived columns so that only the remaining features are used for training.
# NOTE: X_train is rebound here, so re-running this cell alone raises a KeyError
# (the columns are already gone); re-run the loading cell first.
lidar_columns = ['lidar', 'max_lidar', 'min_lidar', 'min_max_diff', 'avg_lidar']
X_train = X_train.drop(columns=lidar_columns)
X_train.head(3)

In [None]:
# Summarise the class balance of the training labels with the project helper
# (presumably iceplant vs. non-iceplant proportions; see model_prep_and_evals for the exact output).
mpe.iceplant_proportions(y_train)

## Train model

In [None]:
# Fit a 100-tree random forest with a fixed seed. The estimator is wrapped in
# ParallelPostFit so that the later predict() call can operate on a dask array.
train_start = time.time()

forest = RandomForestClassifier(n_estimators=100, random_state=42)
rfc = ParallelPostFit(forest)
rfc.fit(X_train, y_train)

train_seconds = time.time() - train_start
print('time to train: ', train_seconds)

# Select NAIP scene and LIDAR year

In [None]:
# NAIP item ids for the campus scene, one per acquisition. Judging by the dates embedded
# in the ids they are ordered newest first (2020, 2018, 2016, 2014, 2012).
campus_itemids = ['ca_m_3411934_sw_11_060_20200521',
                  'ca_m_3411934_sw_11_060_20180722_20190209',
                  'ca_m_3411934_sw_11_.6_20160713_20161004',
                  'ca_m_3411934_sw_11_1_20140601_20141030',
                  'ca_m_3411934_sw_11_1_20120505_20120730']

In [None]:
# **************************************************************
# ********* SPECIFY ITEMID AND LIDAR YEAR TO MATCH HERE ********

# `itemid` selects which scene is classified; `year` is only used in this cell to label
# the output file and the plot title, so keep it consistent with the date in `itemid`.
itemid = campus_itemids[0] # NAIP scene over campus point (index 0 is the 2020 acquisition)
year = 2020

# When True, the prediction raster is written to <cwd>/temp/<filename> by the saving cell.
save_raster = False
filename = 'SPECTRAL_campus_'+str(year)+'_predictions.tif'

# When True, the prediction raster is displayed by the plotting cell.
plot_predictions = True
graph_title = "SPECTRAL PREDICTIONS : "+str(year)+" campus point NAIP scene"

# **************************************************************
# **************************************************************

# Pre-process NAIP scene for prediction

In [None]:
# Look up the catalog item for the selected scene id and report how long the lookup took.
fetch_start = time.time()
item = ipf.get_item_from_id(itemid)
fetch_seconds = time.time() - fetch_start
print('retrieved itemid. time:', fetch_seconds)

In [None]:
# Read bands 1-4 of the scene raster, then let the project helper turn them into the
# feature table used for prediction (the meaning of `thresh` is defined by the helper).
naip_bands = ipf.get_raster_from_item(item).read([1, 2, 3, 4])
df = ipf.features_over_aoi(item, naip_bands, thresh=0.05)

# Only the feature table is needed from here on; drop the extra reference to the raw bands.
del naip_bands
df.head(3)

This array is REALLY BIG, mostly because the NDVI feature is a float, which forces the whole array to a float dtype. Maybe rescale the NDVI to make it uint16? Not sure whether this would affect the model's predictions.

# Make dask data array of pixel values from NAIP scene raster

In [None]:
# Turn the feature table into a plain NumPy matrix and wrap it as a dask array.
# An integer `chunks` value is applied as the block size along every dimension, so the
# matrix is split into blocks of up to 728802 rows.
pixel_matrix = np.array(df)
da_pixels = da.from_array(pixel_matrix, chunks=728802)
da_pixels

# Predict using dask

In [None]:
# Start a Dask Gateway cluster and connect a client to it; subsequent dask computations
# in this kernel run on that cluster.
cluster = dask_gateway.GatewayCluster()
client = cluster.get_client()
# Alternative to a fixed size (currently disabled): cluster.adapt(minimum=4, maximum=50)
# Request a fixed pool of 30 workers.
cluster.scale(30)
# Display the client summary as the cell output.
client

In [None]:
# Build the prediction for every pixel row. The result is a lazy dask collection:
# nothing is evaluated until .compute() is called in the next cell.
scene_preds = rfc.predict(da_pixels)
scene_preds

In [None]:
# Run the distributed prediction, pull the result into local memory and print the
# elapsed seconds.
t0 = time.time()
preds = scene_preds.compute()
print((time.time()-t0))

# The predictions are now local and nothing below uses dask, so release the cluster.
# Otherwise the 30 workers stay allocated, and re-running the cluster cell would start
# a second cluster next to the first one.
# NOTE(review): close() shuts the cluster down under dask-gateway's default
# shutdown_on_close setting - confirm for this deployment.
client.close()
cluster.close()

# Convert predictions back to image

In [None]:
# The item's 'proj:shape' property holds the raster dimensions as (rows, columns).
shape = item.properties['proj:shape']
nrows, ncols = shape[0], shape[1]

# Place the per-pixel predictions back on the scene grid, using df.index to locate them.
# NOTE: `preds` is rebound from the flat prediction array to the image, so re-run the
# compute cell before re-running this one.
with warnings.catch_warnings():
    # All warnings raised by the conversion helper are silenced here.
    warnings.simplefilter("ignore")
    preds = ipf.preds_to_image_3labels(nrows, ncols, df.index, preds)
    print('converted back to image')

## Plot predictions if required

In [None]:
# Show the prediction raster when plotting is enabled in the configuration cell.
# The flag is tested by truthiness rather than compared with `== True`.
if plot_predictions:
    fig, ax = plt.subplots(figsize=(15, 15))
    # Set the title on the Axes itself so the figure does not rely on pyplot's current-axes state.
    ax.set_title(graph_title)
    ax.imshow(preds)
    plt.show()

## Save predictions if required

In [None]:
# Write the prediction raster to <cwd>/temp/<filename> as a single-band uint8 GeoTIFF
# when saving is enabled in the configuration cell.
if save_raster:

    # Create the output folder if needed; opening a file for writing fails when it is missing.
    output_dir = os.path.join(os.getcwd(), 'temp')
    os.makedirs(output_dir, exist_ok=True)

    # The source scene is reopened only for its georeferencing (CRS and transform).
    # NOTE(review): `rast` is never closed - if the helper returns a rasterio dataset,
    # consider opening it in a `with` block.
    rast = ipf.get_raster_from_item(item)

    with rasterio.open(
        os.path.join(output_dir, filename),  # file path
        'w',              # w = write
        driver='GTiff',   # format
        height=preds.shape[0],
        width=preds.shape[1],
        count=1,          # number of raster bands in the dataset
        dtype=rasterio.uint8,
        crs=rast.crs,
        transform=rast.transform,
    ) as dst:
        # Band indices are 1-based: write the predictions into band 1.
        dst.write(preds.astype(rasterio.uint8), 1)

In [None]:
# Report the wall-clock seconds elapsed since t_alpha was set at the top of the notebook.
print('total time:', time.time() - t_alpha )
