# LearningNMS Demo

## Install and Import Requirements

In [None]:
!pip install -e histomics_detect/ 

In [None]:
# Standard library
import json
import os
import sys
import warnings

from abc import ABC

# Set before TensorFlow is imported so that native (C++) log messages emitted
# during the import itself are filtered as well.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image

import tensorflow as tf
import tensorflow.keras.backend as kb
# import tensorflow_addons as tfa

import yaml

# Make the local histomics_detect checkout importable.
sys.path.append("/tf/notebooks/histomics_detect/")

# Silence all Python warnings for the rest of the notebook.
warnings.filterwarnings('ignore')

In [None]:
from histomics_detect.anchors import create_anchors
from histomics_detect.networks import field_size, residual_backbone, transfer_layers, rpn
from histomics_detect.roialign import roialign
from histomics_detect.metrics import iou, greedy_iou_mapping as greedy_iou
from histomics_detect.models import FasterRCNN
from histomics_detect.models import LearningNMS
from histomics_detect.models.faster_rcnn import map_outputs
from histomics_detect.io import resize
from histomics_detect.augmentation import flip, crop
from histomics_detect.boxes import unparameterize, filter_edge_boxes, parameterize
from histomics_detect.visualization.visualization import _plot_boxes
from histomics_detect.visualization.lnms_visualization import plot_inference, run_plot
from histomics_detect.models.experiment_utils import run_experiments

from histomics_detect.models.block_model import BlockModel
from histomics_detect.models.compression_network import CompressionNetwork
from histomics_detect.models.lnms_model import LearningNMS
from histomics_detect.boxes.match import cluster_assignment
from histomics_detect.models.lnms_loss import normal_loss, clustering_loss, paper_loss, xor_loss, normal_clustering_loss, calculate_labels, _pos_neg_loss_calculation
from histomics_detect.metrics.lnms import tf_linear_sum_assignment
from histomics_detect.models.model_utils import extract_data

## Load Data

In [None]:
#import dataset related packages
from histomics_detect.io import dataset, resize
from histomics_detect.augmentation import crop, flip, jitter, shrink
# from histomics_detect.visualization import plot_inference
import numpy as np
import os

# Directory that is handed to dataset() as the data root.
path = '/tf/notebooks/data/DLBCL/'

# Scale factor applied by the resize() steps of both dataset pipelines.
factor = 3

# Training parameters
train_tile = 224        # input image size (pixels)
min_area_thresh = 0.5   # fraction of object area that must be in a crop to be included

# Tensor versions of the parameters above.
width = tf.constant(train_tile, dtype=tf.int32)
height = tf.constant(train_tile, dtype=tf.int32)
min_area = tf.constant(min_area_thresh, dtype=tf.float32)

#define filename parsers
def png_parser(png):
    """Split an image filename into a ``(case, roi)`` pair.

    The final extension is removed first. ``case`` is the text before the
    first ``'.'`` of what remains and ``roi`` is everything after that dot
    (an empty string when there is no dot).
    """
    stem, _ = os.path.splitext(png)
    case, _, roi = stem.partition('.')
    return case, roi

def csv_parser(csv):
    """Split an annotation filename into a ``(case, roi)`` pair.

    The final extension is removed and the remainder is split on ``'.'``.
    ``case`` is the first field; ``roi`` joins the second field (if any)
    with the last three fields, again separated by ``'.'``.
    """
    stem, _ = os.path.splitext(csv)
    fields = stem.split('.')
    roi_fields = fields[1:2] + fields[-3:]
    return fields[0], '.'.join(roi_fields)

# Case identifiers follow the pattern 'DCBT_<number>_CMYC'; the two splits are
# listed by case number.
training = ['DCBT_%d_CMYC' % case_no
            for case_no in (2, 3, 5, 9, 10, 12, 14, 18, 19, 20, 21, 22)]
# Single-case alternative (disabled):
# training = ['DCBT_10_CMYC']
validation = ['DCBT_%d_CMYC' % case_no
              for case_no in (1, 4, 6, 8, 11, 13, 15, 16, 17)]


# Create the raw training and validation datasets (training uses train_tile,
# validation passes 0 as the fourth argument).
ds_train_roi = dataset(path, png_parser, csv_parser, train_tile, training)
ds_validation_roi = dataset(path, png_parser, csv_parser, 0, validation)

# Each element is a 3-tuple; presumably (image, boxes, name) -- confirm against
# histomics_detect.io.dataset. The third item is passed through untouched.

# Training pipeline: resize -> crop -> flip -> jitter -> shrink -> prefetch.
ds_train_roi = (
    ds_train_roi
    .map(lambda image, boxes, extra: (*resize(image, boxes, 1.0*factor), extra))
    .map(lambda image, boxes, extra: (*crop(image, boxes, width, height,
                                            min_area_thresh), extra))
    .map(lambda image, boxes, extra: (*flip(image, boxes), extra))
    .map(lambda image, boxes, extra: (image, jitter(boxes, 0.05), extra))
    .map(lambda image, boxes, extra: (image, shrink(boxes, 0.05), extra))
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Validation pipeline: resize only, no augmentation.
ds_validation_roi = (
    ds_validation_roi
    .map(lambda image, boxes, extra: (*resize(image, boxes, 1.0*factor), extra))
    .prefetch(tf.data.experimental.AUTOTUNE)
)

## Load Faster R-CNN model

In [None]:
# import model ------------------------------------------------------------------------------------------

#import network generation and training packages
from histomics_detect.networks.rpns import rpn
from histomics_detect.models.faster_rcnn import FasterRCNN

# Anchor sizes: width/height in pixels (at input magnification) of the square,
# 1:1 aspect-ratio anchors.
# A larger set, [32, 48, 64, 76, 96, 108], used to be assigned here and was
# immediately overwritten by the line below; it is kept only as a reference.
anchor_px = tf.constant([32, 48, 64], dtype=tf.int32)


#feature network parameters
backbone_stride = 1 #strides in feature generation network convolution
backbone_blocks = 14 #number of residual blocks to use in backbone
backbone_dimension = 256 #number of features generated by rpn convolution

#rpn network parameters
rpn_kernel = [3] #kernel size for rpn convolution
rpn_act_conv = ['relu'] #activation for rpn convolutional layers

#anchor filtering parameters
neg_max = 128 #maximum number of negative/positive anchors to keep in each roi
pos_max = 128
rpn_lmbda = 10.0 #weighting for rpn regression loss
roialign_tiles = 3.0 #roialign - number of horizontal/vertical tiles in a proposal
roialign_pool = 2.0 #roialign - number of horizontal/vertical samples in each tile
roialing_pool = roialign_pool #misspelled legacy name, kept so existing references still resolve

#create backbone and rpn networks (ImageNet-pretrained ResNet50 without its top)
resnet50 = tf.keras.applications.ResNet50(
    include_top=False, weights='imagenet', input_tensor=None,
    input_shape=(train_tile, train_tile, 3), pooling=None)
rpnetwork, backbone = rpn(resnet50, n_anchors=tf.size(anchor_px),
                          stride=backbone_stride, blocks=backbone_blocks, 
                          kernels=rpn_kernel, dimensions=[backbone_dimension],
                          activations=rpn_act_conv)

#create FasterRCNN keras model
faster_model = FasterRCNN(rpnetwork, backbone, [width, height], anchor_px, rpn_lmbda)

#compile FasterRCNN model with losses
#(binary cross-entropy on logits and Huber, passed in that order)
faster_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss=[tf.keras.losses.BinaryCrossentropy(from_logits=True),
                    tf.keras.losses.Huber()])

## Load Weights of Faster R-CNN model

In [None]:
# Location of the saved Faster R-CNN weights.
faster_path = '/tf/notebooks/networks/cpk_ly_3'

# Restore those weights into the Faster R-CNN model.
faster_model.load_weights(faster_path)

## Train Faster R-CNN model

In [None]:
# Train the Faster R-CNN model. This must call `faster_model`: the name `model`
# is only bound later, in the LearningNMS cell, so using it here fails on a
# fresh kernel (or trains the wrong network after out-of-order execution).
# NOTE(review): Keras treats an integer validation_freq as "every N epochs", so
# with epochs=1 and validation_freq=50 validation presumably never runs --
# confirm the intended values.
faster_model.fit(x=ds_train_roi, batch_size=1, epochs=1, verbose=1,
                 validation_data=ds_validation_roi, validation_freq=50)

## Load LearningNMS configs

In [None]:
configs_path = './histomics_detect/example/lnms_configs.yml'

# Parse the YAML configuration file.
with open(configs_path) as config_file:
    configs = yaml.safe_load(config_file)

# Entries under 'special_configs' are expected to be strings holding Python
# expressions; each one is evaluated and stored under the same key at the top
# level of `configs`.
# SECURITY NOTE: eval() executes arbitrary code -- only load configuration
# files from a trusted source.
for key, value in configs['special_configs'].items():
    try:
        configs[key] = eval(value.replace('\\\'', '\''))
    except Exception as err:
        # Still best effort (a bad entry is skipped, not fatal), but report it
        # instead of hiding it. print() is used because warnings are globally
        # silenced in this notebook.
        print('Skipping special config', repr(key), '-', type(err).__name__, err)

# The ROI-align settings are cast to int32 tensors.
configs['roialign_pool'] = tf.cast(configs['roialign_pool'], tf.int32)
configs['roialign_tiles'] = tf.cast(configs['roialign_tiles'], tf.int32)

## Train LearningNMS

In [None]:
# Compression network built from the configured feature/anchor sizes and the
# Faster R-CNN backbone.
compression_net = CompressionNetwork(
    configs['feature_size'],
    configs['anchor_size'],
    faster_model.backbone,
)

# LearningNMS model reusing the Faster R-CNN rpn and backbone together with
# the compression layers.
model = LearningNMS(
    configs,
    faster_model.rpnetwork,
    faster_model.backbone,
    compression_net.compression_layers,
    [configs['width'], configs['height']],
)

# Adam optimizer whose learning rate starts at 1e-5 and decays exponentially
# (rate 0.9 per 1000 steps).
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-5, decay_steps=1000, decay_rate=0.9)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule))

# Short demo training run: a single epoch limited to two steps.
history_callback = model.fit(
    x=ds_train_roi,
    batch_size=1,
    epochs=1,
    verbose=1,
    steps_per_epoch=2,
)

## Plot Inference

In [None]:
# Arguments for run_plot(); their meaning is inferred from the names, so
# confirm against histomics_detect.visualization.lnms_visualization.run_plot.
index = 0                        # presumably which validation sample to plot
filter_edge = True               # presumably filter boxes at the image edge
save_fig = False                 # presumably whether the figure is written to disk
fig_path = 'path/to/image.png'   # presumably the output file used when save_fig is True

run_plot(
    ds_validation_roi,
    model,
    index,
    save_fig,
    fig_path,
    filter_edge,
)