# CDACS Model Experiment - Inference

## Import the required libraries

In [None]:
import tensorflow as tf

# Turn on memory growth for every visible GPU so TensorFlow allocates device
# memory on demand instead of reserving it all when the runtime starts.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

# Show the detected GPUs as this cell's output.
physical_devices

In [None]:
import os
from IPython.display import clear_output
import matplotlib.pyplot as plt

In [None]:
from module.model_utils import *
from module.metrics import *
from module.dataset_utils import BasicDatasetProcess

## Import the datasets and inspect their details in the output

In [None]:
import datasets.camelyon16

# Dataset wrappers for the camelyon16 H&E data; 'HE_CR' presumably selects the
# H&E color-region variant of the dataset -- TODO confirm in module.dataset_utils.
dataset_wrappers_he = BasicDatasetProcess.get_dataset_wrapper_from_dataset(
    'camelyon16',
    'HE_CR',
)

dataset_wrappers_he

In [None]:
import datasets.fluorescence

# Dataset wrappers for the fluorescence (IF) data; 'IF_CR' presumably selects
# the IF color-region variant of the dataset -- TODO confirm in module.dataset_utils.
dataset_wrappers_if = BasicDatasetProcess.get_dataset_wrapper_from_dataset(
    'fluorescence',
    'IF_CR',
)

dataset_wrappers_if

### Preprocess the datasets in batches using the Color Deconvolution (CD) algorithm

In [None]:
# Run the wrapper's processing step on the 'test' split of both datasets
# (H&E first, then IF -- tuple elements are evaluated left to right).
he_test, if_test = (
    dataset_wrappers_he['test'].process(),
    dataset_wrappers_if['test'].process(),
)

(he_test, if_test)

## Set up the UNet + MobileNetV2 hybrid model for an input size of 1024

In [None]:
# `patch_size` is the default tile size used by `large_prediction`;
# `input_size` is handed to the model constructor. How a 1000-px patch maps
# onto the 1024 model input is decided inside the model helper -- TODO confirm.
patch_size, input_size = 1000, 1024

In [None]:
# Build the hybrid model: 2 output channels (presumably background/foreground
# classes -- TODO confirm) from a single-channel input of size `input_size`.
model_obj = MobileNetV2_1024_Model(
    output_channels=2,
    input_channels=1,
    input_size=input_size,
)
model = model_obj.model

# The loss is computed from raw logits (from_logits=True); Jaccard and Dice
# coefficients are tracked alongside accuracy.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy', jacard_coef, dice_coef],
)

## Load the best CDACS model checkpoint

In [None]:
# Placeholder: point this at the trained CDACS checkpoint before running.
weight_path = '/path/to/your/weight.hdf5'
model.load_weights(filepath=weight_path)

## Inference

### Define the output paths and helper functions used to generate result images

In [None]:
def large_prediction(image, mask, cr_mask=None, patch_size=patch_size, batch_size=2):
    """Predict a segmentation mask for one large image and build a comparison figure.

    The image is predicted patch-wise via ``model_obj.easy_predict_single``.
    When a color-region mask is supplied, the binarized prediction is ANDed
    with it so that only pixels inside the color region remain positive.

    Args:
        image: Input image; presumably an eager ``tf.Tensor`` of shape
            (H, W, C) -- TODO confirm against the dataset pipeline.
        mask: Ground-truth segmentation mask; only used for display here.
        cr_mask: Optional color-region mask (eager tensor or array-like).
            ``None`` skips the post-processing step.
        patch_size: Patch size forwarded to ``easy_predict_single``.
        batch_size: Batch size forwarded to ``easy_predict_single``
            (defaults to the previously hard-coded value of 2).

    Returns:
        Tuple ``(image, mask, pred_mask, pp_pred_mask, fig)``, where
        ``pp_pred_mask`` is the boolean post-processed prediction and ``fig``
        is whatever ``display`` returns (presumably a matplotlib Figure).
    """
    pred_mask = model_obj.easy_predict_single(image, patch_size=patch_size, batch_size=batch_size)

    pp_pred_mask = np.array(pred_mask, dtype=bool)
    # Identity check on purpose: with a NumPy mask `!= None` is element-wise
    # and using the resulting array in `if` raises "truth value is ambiguous".
    if cr_mask is not None:
        # np.asarray works for both eager tensors and NumPy arrays, whereas
        # `.numpy()` only exists on tensors.
        cr_mask = np.asarray(cr_mask, dtype=bool)
        pp_pred_mask = np.logical_and(pp_pred_mask, cr_mask)

    fig = display(
        [image, mask, pred_mask, pp_pred_mask],
        show=False,
        dpi=300,
        title=['Input Image', 'True Mask', 'Predicted Mask', 'Post-processed Mask'],
    )

    return image, mask, pred_mask, pp_pred_mask, fig

In [None]:
from tqdm import tqdm

# Inference artifacts go under ./inferences, with one sub-folder per dataset.
output_path = 'inferences'
he_output_path = os.path.join(output_path, 'CDACS_HECR')
if_output_path = os.path.join(output_path, 'CDACS_IFCR')
for _out_dir in (output_path, he_output_path, if_output_path):
    os.makedirs(_out_dir, exist_ok=True)

def save_all_imgs(dataset_path, image, mask, pred, file_name, fig, save=True):
    """Optionally save the input, ground truth, prediction and figure, then show the figure.

    Args:
        dataset_path: Output directory for the PNG files (expected to exist).
        image: Input image; channel 0 (``image[..., 0]``) is saved in grayscale.
        mask: Ground-truth mask; channel 0 is saved in grayscale.
        pred: Predicted mask; channel 0 is saved in grayscale.
        file_name: Source file name. Only its final extension is stripped to
            build ``<stem>_input.png``, ``<stem>_gt.png``, ``<stem>_pred.png``
            and ``<stem>_fig.png``.
        fig: Figure to save with ``savefig``; it is closed after being shown.
        save: When False nothing is written; the figure is only shown.
    """
    # splitext removes just the last extension, so "a.b.png" keeps the stem
    # "a.b" instead of collapsing to "a" and colliding with other files.
    basename = os.path.splitext(file_name)[0]

    if save:
        plt.imsave(os.path.join(dataset_path, f'{basename}_input.png'), image[..., 0], cmap='gray')
        plt.imsave(os.path.join(dataset_path, f'{basename}_gt.png'), mask[..., 0], cmap='gray')
        plt.imsave(os.path.join(dataset_path, f'{basename}_pred.png'), pred[..., 0], cmap='gray')

        fig.savefig(os.path.join(dataset_path, f'{basename}_fig.png'))
    plt.show()
    # This helper runs once per image; close the figure explicitly so open
    # high-DPI figures do not pile up (and trigger matplotlib's >20 warning).
    if fig is not None:
        plt.close(fig)

### Run inference on the H&E dataset

In [None]:
for record in tqdm(he_test.iterobjs()):
    # Fields read from each processed H&E test record.
    img = record['image']
    mask = record['segmentation_mask']
    file_name = record['file_name'].numpy().decode('utf-8')
    cr_mask = record['color_region_mask']

    # Predict, restrict the prediction to the color region, then save/show.
    image, mask, pred, pp_pred, fig = large_prediction(img, mask, cr_mask=cr_mask)
    save_all_imgs(he_output_path, image, mask, pp_pred, file_name=file_name, fig=fig)

### Run inference on the fluorescence (IF) dataset

In [None]:
for obj in tqdm(if_test.iterobjs()):
    # Fields read from each processed fluorescence test record.
    img, mask = obj['image'], obj['segmentation_mask']
    file_name = obj['file_name'].numpy().decode('utf-8')

    cr_mask = obj['color_region_mask']
    image, mask, pred, pp_pred, fig = large_prediction(img, mask, cr_mask, patch_size)

    # Fluorescence results belong in the IF output folder, not the H&E one,
    # so the two datasets' outputs never mix or overwrite each other.
    save_all_imgs(if_output_path, image, mask, pp_pred, file_name, fig)