# <font color='blue'> Random forest classification model </font>

## Read input data, including ground truth data and images

In [None]:
# Import GDAL, NumPy, and matplotlib
from osgeo import gdal, gdal_array
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Tell GDAL to raise Python exceptions (rather than returning None on failure)
# and register all format drivers.
gdal.UseExceptions()
gdal.AllRegister()

# Read the multiband input image and the ground-truth raster (referred to as
# "roi" below). Replace the placeholder paths with the real file locations.
img_ds = gdal.Open(r"path", gdal.GA_ReadOnly)
roi_ds = gdal.Open(r"path", gdal.GA_ReadOnly)

# Pick one NumPy dtype that can hold every band without loss. Taking only
# band 1's type would silently truncate/wrap any later band stored with a
# wider or different type (e.g. a Float32 index band stacked with UInt16
# reflectance bands). For single-type rasters this is just that type.
band_dtypes = [
    gdal_array.GDALTypeCodeToNumericTypeCode(img_ds.GetRasterBand(b + 1).DataType)
    for b in range(img_ds.RasterCount)
]
img_dtype = np.result_type(*band_dtypes)

# Stack the bands into a (rows, cols, bands) array; GDAL band indices are 1-based.
img = np.zeros(
    (img_ds.RasterYSize, img_ds.RasterXSize, img_ds.RasterCount),
    dtype=img_dtype,
)
for b in range(img.shape[2]):
    img[:, :, b] = img_ds.GetRasterBand(b + 1).ReadAsArray()

# Ground-truth class codes, read from the first band of the roi raster.
# NOTE(review): codes above 32767 would wrap when cast to int16 — confirm
# all class values fit.
roi = roi_ds.GetRasterBand(1).ReadAsArray().astype(np.int16)


# Preview settings. `display_band` indexes the last (band) axis of `img`;
# negative values count from the end, so -2 is the second-to-last band.
# The grey-scale stretch limits must be chosen for your data's value range.
display_band = -2
display_vmin, display_vmax = -2000, 10000

# Left panel: one band of the input image; right panel: ground truth (roi).
plt.subplot(121)
plt.imshow(img[:, :, display_band], cmap=plt.cm.Greys_r,
           vmin=display_vmin, vmax=display_vmax)
plt.title('Input Data')

plt.subplot(122)
plt.imshow(roi, cmap=plt.cm.Spectral)
plt.title("Ground Truth Data")

plt.show()

## Using scikit-learn to split the ground-truth and image data into training and testing sets


In [None]:
from sklearn.model_selection import train_test_split

# Ground-truth pixels with a value > 0 are labelled samples; zero entries mark
# non-valid (unlabelled) pixels. Build the mask once and reuse it below.
labelled = roi > 0

# The mask is applied to the image's row/column axes, so both rasters must
# share the same pixel grid. Fail early with a clear message rather than a
# cryptic boolean-index error further down.
if labelled.shape != img.shape[:2]:
    raise ValueError(
        'Ground-truth raster shape {roi} does not match image shape {img}'.format(
            roi=labelled.shape, img=img.shape[:2]))

# How many ground-truth samples are there?
n_samples = labelled.sum()
print('There are {n} samples'.format(n=n_samples))

# Which class codes (classification targets) occur in the ground truth?
# NOTE(review): an earlier note said class value 1 and class value 627 both
# mean "irrigated"; presumably one of them is non-irrigated — confirm.
targets = np.unique(roi[labelled])
print('The training data include {n} classes: {classes}'.format(n=targets.size,
                                                                classes=targets))

# features: values of every image band at each labelled pixel (model inputs),
#           shape (n_samples, n_bands)
# labels:   ground-truth class of each of those pixels (to be predicted)
features = img[labelled, :]
labels = roi[labelled]

# Hold out 30% of the samples for testing; a fixed random_state keeps the
# split reproducible. (train_test_split accepts lists, NumPy arrays, SciPy
# sparse matrices or pandas DataFrames.) Consider stratify=labels if the
# classes are strongly imbalanced.
train_features, test_features, train_labels, test_labels = train_test_split(
    features, labels, test_size=0.3, random_state=42)

# Sanity-check the sizes of the split.
print('Training Features Shape:', train_features.shape)
print('Training Labels Shape:', train_labels.shape)
print('Testing Features Shape:', test_features.shape)
print('Testing Labels Shape:', test_labels.shape)