# Running machine-learning models to detect hot spots
## Part 2: Inference

#### Requirements
* matplotlib
* numpy
* pandas
* pycaret
* rasterio

In [25]:
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd

import pycaret
import pycaret.classification as cl

import rasterio

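#### Settings
Choose the model and one input raster. The two `DATA` cells below are alternatives: whichever one runs last determines `DATA` and `DATA_OUTPUT`.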

In [2]:
# Trained pycaret classification model, referenced by path without a file extension.
MODEL = '../models/2020-10-30_catBoost_tuned_model'

In [None]:
# Baldwin input raster; the classified result is written next to it with a
# "_classified" suffix.
DATA = '../data_inference/Inference_Baldwin.tif'
DATA_OUTPUT = os.path.splitext(DATA)[0] + '_classified.tif'

In [17]:
# Yana tile input raster; the classified result is written next to it with a
# "_classified" suffix.
DATA = '../data_inference/Yana/Inference_Yana-0000006400-0000019200-016.tif'
DATA_OUTPUT = os.path.splitext(DATA)[0] + '_classified.tif'

#### Model Loading 

In [18]:
# Load the saved pycaret pipeline (transformations + classifier) from MODEL.
model = cl.load_model(MODEL)

Transformation Pipeline and Model Sucessfully Loaded


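### Data Loading
Load all raster bands into memory. Each band's description is later used as the model's feature (column) name, and the bands are reshaped into one row per pixel.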

In [19]:
# Read every band of the input raster, plus the per-band descriptions that
# become the feature (column) names for the model.
with rasterio.open(DATA) as raster:
    data = raster.read()
    columns = raster.descriptions

In [20]:
# Sanity check: the code below treats axis 0 as bands (features) and the
# remaining two axes as the pixel grid; the band count should equal len(columns).
data.shape

(27, 6400, 6400)

### Data Analysis - Prediction 

In [21]:
# One row per pixel, one column per band: flatten each band's pixel grid and
# transpose. The band count is taken from the array rather than hard-coded
# (previously 27), so rasters with a different number of bands also work.
data_reshaped = pd.DataFrame(data.reshape(data.shape[0], -1).T, columns=columns)

In [22]:
# Score every pixel, then fold the flat 'Label' column back onto the
# (rows, cols) pixel grid of the input raster.
prediction = (
    cl.predict_model(model, data_reshaped)['Label']
    .values
    .reshape(data.shape[1:])
)

### Data Export 

In [23]:
# Write the per-pixel labels as a single-band uint8 raster that reuses the
# input's driver, size, transform and CRS.
with rasterio.open(DATA) as src, rasterio.open(
    DATA_OUTPUT,
    mode='w',
    driver=src.driver,
    width=src.width,
    height=src.height,
    count=1,
    dtype='uint8',
    transform=src.transform,
    crs=src.crs,
) as dst:
    dst.write(np.asarray(prediction, dtype=np.uint8), 1)

### Visualization

In [None]:
# Full class map: one color per predicted label, with a colorbar so the
# colors can be read back as label values.
fig, ax = plt.subplots()
image = ax.imshow(prediction)
fig.colorbar(image, ax=ax, label='predicted Label')
ax.set_title('Predicted class per pixel')
ax.set_xlabel('column (pixel)')
ax.set_ylabel('row (pixel)');

In [None]:
# Binary mask of the pixels assigned to a single class.
# NOTE(review): 15 is the class id inspected in the original notebook;
# confirm which class it denotes and name it accordingly.
CLASS_OF_INTEREST = 15
fig, ax = plt.subplots()
ax.imshow(prediction == CLASS_OF_INTEREST)
ax.set_title(f'Pixels predicted as class {CLASS_OF_INTEREST}')
ax.set_xlabel('column (pixel)')
ax.set_ylabel('row (pixel)');

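### Automation
Batch-classify the `.tif` rasters in the inference folder. Note: the loop calls `run_inference`, which is defined in the last cell of this section. Execute that cell first; on a fresh kernel, running top to bottom raises `NameError`.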

In [44]:
# Folder holding the inference rasters to classify in batch.
# NOTE(review): absolute, machine-specific path; presumably the same folder as
# '../data_inference' used above. Consider switching to the relative path.
INFERENCE_DIR = r'K:\127_HotSpotOptimizer\hot_spot_classifier\data_inference'

# Sort for a deterministic processing order, and skip rasters this notebook has
# already written (*_classified.tif). Otherwise a re-run would try to classify
# its own single-band outputs.
flist = sorted(
    path
    for path in glob.glob(os.path.join(INFERENCE_DIR, '*.tif'))
    if not path.endswith('_classified.tif')
)

In [45]:
# Show the rasters that the batch loop will process.
flist

['K:\\127_HotSpotOptimizer\\hot_spot_classifier\\data_inference\\Inference_160W60N-0000000000-0000000000-005.tif',
 'K:\\127_HotSpotOptimizer\\hot_spot_classifier\\data_inference\\Inference_160W60N-0000000000-0000006400-006.tif',
 'K:\\127_HotSpotOptimizer\\hot_spot_classifier\\data_inference\\Inference_160W60N-0000000000-0000012800-004.tif',
 'K:\\127_HotSpotOptimizer\\hot_spot_classifier\\data_inference\\Inference_160W60N-0000006400-0000000000-002.tif',
 'K:\\127_HotSpotOptimizer\\hot_spot_classifier\\data_inference\\Inference_160W60N-0000006400-0000006400-003.tif',
 'K:\\127_HotSpotOptimizer\\hot_spot_classifier\\data_inference\\Inference_160W60N-0000006400-0000012800.tif',
 'K:\\127_HotSpotOptimizer\\hot_spot_classifier\\data_inference\\Inference_Yana-0000000000-0000000000.tif',
 'K:\\127_HotSpotOptimizer\\hot_spot_classifier\\data_inference\\Inference_Yana-0000000000-0000006400.tif',
 'K:\\127_HotSpotOptimizer\\hot_spot_classifier\\data_inference\\Inference_Yana-0000000000-0000012

In [None]:
# Classify every raster in flist, writing "<name>_classified.tif" next to it.
# NOTE: run_inference is defined in a later cell; execute that cell first.
for infile in flist:
    stem = os.path.basename(infile)[:-4]
    run_inference(MODEL, infile, os.path.join(os.path.dirname(infile), f'{stem}_classified.tif'))

In [46]:
def run_inference(model_path, infile_path, outfile_path, model=None):
    """Classify every pixel of a multi-band raster and save the label map.

    Each band of ``infile_path`` becomes one feature column, named by the
    band's description, and every pixel becomes one row scored by the pycaret
    model. The predicted ``'Label'`` values are written to ``outfile_path`` as
    a single-band uint8 raster with the input's driver, size, transform and CRS.

    Parameters
    ----------
    model_path : str
        Path passed to ``pycaret.classification.load_model``. Ignored when
        ``model`` is given.
    infile_path : str
        Input raster. Its band descriptions are used as feature column names.
    outfile_path : str
        Destination raster for the predicted labels.
    model : optional
        An already-loaded pycaret pipeline. Pass it when classifying many
        files so the model is not reloaded from disk for every file.

    Raises
    ------
    ValueError
        If a predicted label is outside the uint8 range (0-255).
    """
    print("Processing dataset input:", infile_path)
    print("Processing dataset output:", outfile_path)

    if model is None:
        model = cl.load_model(model_path)

    # Read the pixels and the georeferencing in one pass instead of reopening
    # the input later just to copy its metadata.
    with rasterio.open(infile_path) as src:
        columns = src.descriptions
        data = src.read()
        out_meta = dict(
            driver=src.driver,
            width=src.width,
            height=src.height,
            count=1,
            dtype='uint8',
            transform=src.transform,
            crs=src.crs,
        )

    # One row per pixel, one column per band. The band count comes from the
    # data itself (it used to be hard-coded as 27) so other layouts also work.
    n_bands = data.shape[0]
    features = pd.DataFrame(data.reshape(n_bands, -1).T, columns=columns)

    prediction = cl.predict_model(model, features)['Label'].values.reshape(*data.shape[1:])

    # Casting straight to uint8 would silently wrap negative or >255 labels.
    labels = np.asarray(prediction).astype(np.int64)
    if labels.size and (labels.min() < 0 or labels.max() > np.iinfo(np.uint8).max):
        raise ValueError(
            f"Predicted labels span {labels.min()}..{labels.max()}, "
            "which does not fit in a uint8 raster"
        )

    with rasterio.open(outfile_path, mode='w', **out_meta) as dst:
        dst.write(labels.astype(np.uint8), 1)