# DenseCRF Inference

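This notebook post-processes the output of a per-pixel classifier with a dense conditional random field (DenseCRF) to improve object boundaries, then removes connected components smaller than a minimum size.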

The CRF and filtering hyperparameters were found by tuning in a separate notebook and are loaded here from a YAML file.

In [None]:
# Input multi-band raster; its bands are used as features of the bilateral
# CRF kernel.
IMAGE_PATH = 'para_condensed_stacks_2020_h7v3.tif'

# GLM output raster holding the per-pixel probability of the foreground class.
MODEL_PATH = 'para_model_2020_h7v3.tif'

# YAML file with the tuned CRF and minimum-component-size hyperparameters.
PARAMS_PATH = 'params.yaml'

# Destination for the filtered CRF label raster.
OUTPUT_PATH = 'para_densecrf_2020_h7v3.tif'

### Import necessary packages

In [None]:
import pandas as pd
import rasterio
import numpy as np
import matplotlib.pyplot as plt
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral, create_pairwise_gaussian
from tqdm import trange
import yaml
import cv2

### Load hyperparameters

In [None]:
# Load the tuned hyperparameters. The same mapping is splatted into run_crf
# (which ignores unknown keys via **kwargs) and is also read for 'min_size'.
with open(PARAMS_PATH, 'r', encoding='utf-8') as f:
    params = yaml.safe_load(f)

# yaml.safe_load returns None for an empty file and a list/scalar for other
# kinds of documents; fail here with a clear message rather than later with
# an obscure TypeError at `**params` or `params['min_size']`.
if not isinstance(params, dict):
    raise ValueError(
        f"{PARAMS_PATH} must contain a YAML mapping of hyperparameters, "
        f"got {type(params).__name__}"
    )

In [None]:
# Display the loaded hyperparameters as the cell output.
params

### Read input and model rasters

In [None]:
# Read every band of the input raster and move the band axis from first to
# last, giving a (rows, cols, bands) array; run_crf's bilateral kernel reads
# the bands from axis 2.
# NOTE(review): unlike the model probabilities, no NaN/nodata cleanup is done
# on the image; confirm the stack has no NaNs, as they would enter the CRF
# features unchanged.
with rasterio.open(IMAGE_PATH) as src:
    whole_image = np.moveaxis(src.read(), 0, -1)

In [None]:
# Inspect the (rows, cols, bands) shape of the input image.
whole_image.shape

In [None]:
# Read the classifier output. Its raster profile (meta) is kept so the result
# can be written with the same georeferencing. The band axis is moved last and
# singleton axes are dropped; for a single-band raster this yields (rows, cols).
with rasterio.open(MODEL_PATH) as src:
    meta = src.meta
    band_last = np.moveaxis(src.read(), 0, -1)
    whole_probs = np.squeeze(band_last)
    # NaN pixels get probability 0, so in run_crf they favour label 0.
    whole_probs[np.isnan(whole_probs)] = 0

In [None]:
# Inspect the shape of the probability map; it should match whole_image.shape[:2].
whole_probs.shape

### Run DenseCRF inference

In [None]:
def run_crf(image, probs, num_iter=5,
            gaussian_sdim=3, gaussian_compat=3,
            bilateral_sdim=80,
            bilateral_schan1=13, bilateral_schan2=13, bilateral_schan3=13,
            bilateral_schan4=13, bilateral_schan5=13,
            bilateral_compat=10,
            *args, **kwargs):
    """Refine a binary probability map with a fully connected (dense) CRF.

    Parameters
    ----------
    image : ndarray
        Raster with bands on the last axis, shape ``(rows, cols, bands)``.
        Its bands are the appearance features of the bilateral kernel.
    probs : ndarray
        ``(rows, cols)`` probability of label 1; label 0 receives
        ``1 - probs``.
    num_iter : int
        Number of mean-field inference steps.
    gaussian_sdim, gaussian_compat : float
        Spatial scale (pixels) and compatibility weight of the
        appearance-independent smoothness kernel.
    bilateral_sdim : float
        Spatial scale (pixels) of the bilateral kernel.
    bilateral_schan1 ... bilateral_schan5 : float
        Scales applied to the image bands, in order, in the bilateral kernel.
    bilateral_compat : float
        Compatibility weight of the bilateral kernel.
    *args, **kwargs
        Ignored. This lets a full hyperparameter dict (e.g. one that also holds
        ``min_size``) be passed with ``**params``.

    Returns
    -------
    ndarray
        ``(rows, cols)`` integer array holding the most probable label
        (0 or 1) for each pixel.
    """
    # Unary term built from the two-class distribution (label 0, label 1).
    class_probs = np.stack([1 - probs, probs], axis=0)
    unary = unary_from_softmax(class_probs)

    # Smoothness kernel over pixel coordinates only.
    gaussian_feats = create_pairwise_gaussian(
        sdims=[gaussian_sdim] * 2, shape=image.shape[:2])

    # Appearance kernel over pixel coordinates plus scaled image bands (axis 2).
    # NOTE(review): exactly five band scales are supplied; confirm the image
    # has five bands, otherwise bands are scaled inconsistently or indexing
    # fails inside create_pairwise_bilateral.
    band_scales = [bilateral_schan1, bilateral_schan2, bilateral_schan3,
                   bilateral_schan4, bilateral_schan5]
    bilateral_feats = create_pairwise_bilateral(
        sdims=[bilateral_sdim] * 2, schan=band_scales, img=image, chdim=2)

    height, width = image.shape[:2]
    crf = dcrf.DenseCRF2D(width, height, 2)
    crf.setUnaryEnergy(unary)
    for feats, weight in ((gaussian_feats, gaussian_compat),
                          (bilateral_feats, bilateral_compat)):
        crf.addPairwiseEnergy(feats, compat=weight,
                              kernel=dcrf.DIAG_KERNEL,
                              normalization=dcrf.NORMALIZE_SYMMETRIC)

    # Step inference manually so trange can display progress.
    q_dist, scratch_a, scratch_b = crf.startInference()
    for _ in trange(num_iter):
        crf.stepInference(q_dist, scratch_a, scratch_b)

    return np.argmax(q_dist, axis=0).reshape(height, width)


In [None]:
# Run the CRF over the whole tile. params also carries 'min_size' (used in the
# filtering step); run_crf silently absorbs it through **kwargs.
whole_soln = run_crf(whole_image, whole_probs, **params)

### Remove connected components smaller than the minimum size

In [None]:
def filter_by_min_size(soln, min_size):
    """Keep only foreground connected components of at least ``min_size`` pixels.

    Parameters
    ----------
    soln : ndarray
        2-D label image. It is cast to ``uint8`` and its non-zero pixels are
        grouped into 8-connected components.
    min_size : int or float
        Minimum component area, in pixels, for a component to be kept.

    Returns
    -------
    ndarray
        Array with the same shape and dtype as ``soln``: 1 for pixels in a
        kept component, 0 elsewhere (including the background).
    """
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
        soln.astype('uint8'), connectivity=8)

    # One boolean per label, then a single gather over the label image. This
    # replaces a full-image comparison per component (O(K * H * W)), which is
    # very slow on a whole tile with many small components.
    keep = stats[:num_labels, cv2.CC_STAT_AREA] >= min_size
    keep[0] = False  # label 0 is the background and is never kept
    return keep[labels].astype(soln.dtype)

In [None]:
# Drop CRF foreground components smaller than the tuned minimum size (pixels).
whole_soln_filtered = filter_by_min_size(whole_soln,params['min_size'])

### Write out result raster

In [None]:
# Write the filtered labels as band 1, reusing the model raster's profile
# (size, CRS, transform, dtype, nodata). The labels carry the integer dtype of
# np.argmax, which generally differs from meta['dtype'] inherited from
# MODEL_PATH, so cast explicitly instead of relying on an implicit conversion
# during the write.
# NOTE(review): meta's nodata value is inherited from MODEL_PATH; if it is 0,
# background pixels will be read back as nodata — confirm this is intended.
with rasterio.open(OUTPUT_PATH,'w',**meta) as f:
    f.write(whole_soln_filtered.astype(meta['dtype']), 1)