# Mask R-CNN - Nephrology Inference
This is a custom version of [Mask R-CNN - Train cell nucleus Dataset](https://colab.research.google.com/github/navidyou/Mask-RCNN-implementation-for-cell-nucleus-detection-executable-on-google-colab-/blob/master/mask_RCNN_cell_nucleus_google_colab.ipynb) for Google Colab.

If you run this notebook on Google Colab, the GPU/TPU might not be used because of the TensorFlow version it installs.

## Google Colab Only

Execute this cell only if you are running the notebook on Google Colab. It installs compatible libraries and downloads the required files. Errors might appear during installation; they can be ignored.

In [0]:
!pip install -q scipy==1.1
!pip install -q tensorflow==1.7
!pip install -q keras==2.1.6
GITHUB_REPO = "https://raw.githubusercontent.com/AdrienJaugey/Custom-Mask-R-CNN-for-kidney-s-cell-recognition/master/"
files = ['mrcnn/config.py', 'mrcnn/utils.py', 'mrcnn/model.py', 'mrcnn/visualize.py', 'datasetTools/datasetDivider.py',
         'datasetTools/datasetWrapper.py']
for fileToDownload in files:
  url = GITHUB_REPO + fileToDownload
  !wget -N $url

### Connecting to Google Drive

The first time this cell is executed, a link should appear asking you to grant access to the files of a Google account.
1.   **Follow the link**;
2.   **Choose the account** you want to link;
3.   **Accept**;
4.   **Copy the key** Google gave you;
5.   **Paste the key in the text field** that appeared below the first link you used;
6.   **Press ENTER**.

In [0]:
from google.colab import drive
drive.mount('/content/drive')

### Retrieving your image

Choose how to get your image from the list on the right.  
Use only ```.jp2``` or ```.png``` images.


In [0]:
howToGetImage = "From Google Drive" #@param ["Upload", "From Google Drive"]

#### By upload

In [0]:
if howToGetImage == "Upload":
  print("Please upload the image you want to run the inference on")
  from google.colab import files
  src = list(files.upload().values())[0]

#### By copy from Google Drive

Be sure to set the two variables so that Google Colab can find your file in Google Drive.
Let's say you have this hierarchy in your Google Drive:
```
Root directory of Google Drive
  ├─── Directory1
  └─── Directory2
       ├─── images
       │    └─── example.png
       └─── saved_weights
            └─── weights.h5
```
1.   ```customPathInDrive``` must list all the directories between the root directory and your image file. In the example, it would be ```Directory2/images/```. **Do not forget the final /** if you set this variable;
2.   ```imageFileName``` must be the name of the file you want to copy. In the example, it would be ```example.png```.

Use the text fields available on the right.
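
As a minimal sketch of how the two fields combine into the path that gets copied, the hypothetical helper `build_drive_path` below (not part of the notebook) joins them under the mount point and quotes the result for the shell with `shlex.quote`. It assumes the drive is mounted at `/content/drive`, as in the mount cell.

```python
import posixpath
from shlex import quote

# Mount point used by drive.mount() earlier in this notebook.
DRIVE_ROOT = "/content/drive/My Drive"

def build_drive_path(custom_path_in_drive, file_name, root=DRIVE_ROOT):
    """Hypothetical helper: return the full path of a file stored in Google Drive."""
    # posixpath.join copes with a trailing slash and with an empty directory part.
    return posixpath.join(root, custom_path_in_drive, file_name)

image_path = build_drive_path("Directory2/images/", "example.png")
print(image_path)         # /content/drive/My Drive/Directory2/images/example.png
print(quote(image_path))  # quoted so the space in "My Drive" survives a shell command
```

Leaving `custom_path_in_drive` empty gives a path directly under the Drive root, which matches the comment in the cell below.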

In [0]:
if howToGetImage == "From Google Drive":
  pathToDrive = "'/content/drive/My Drive/"
  # Keep customPathInDrive empty if file directly in root directory of Google Drive
  customPathInDrive = "" #@param {type:"string"}
  imageFileName = "" #@param {type:"string"}
  annotationsFile = True #@param {type:"boolean"}
  
  pathToImage = pathToDrive + customPathInDrive + imageFileName + "'"
  print("Copying {} to {}".format(pathToImage, imageFileName))
  !cp -u $pathToImage $imageFileName

  if annotationsFile:
    annotationsFileName = imageFileName.split('.')[0] + '.xml'
    pathToAnnotations = pathToDrive + customPathInDrive + annotationsFileName + "'"
    print("Copying {} to {}".format(pathToAnnotations, annotationsFileName))
    !cp -u $pathToAnnotations $annotationsFileName

### Retrieving Weights File

This works the same way as retrieving an image file from Google Drive, but for the saved weights file (```.h5``` extension). In the previous example, ```customPathInDrive``` would be ```Directory2/saved_weights/``` and ```weightFileName``` would be ```weights.h5```.

In [0]:
pathToDrive = "'/content/drive/My Drive/"
# Keep customPathInDrive empty if file directly in root directory of Google Drive
customPathInDrive = "" #@param {type:"string"}
weightFileName = "mask_rcnn_nephrologie_649_100.h5" #@param {type:"string"}
pathToWeights = pathToDrive + customPathInDrive + weightFileName + "'"
print("Copying {} to {}".format(pathToWeights, weightFileName))
!cp -u $pathToWeights $weightFileName

## Initialisation

Set ```IMAGE_PATH``` to the name of the image file (in the example, ```example.png```) and ```MODEL_PATH``` to the same value as ```weightFileName```. Check ```saveResults``` if you want the results saved to files. To see them appear, open the **Files tab** in the **vertical navigation bar on the left**; you can then save each file by right-clicking it.

In [0]:
IMAGE_PATH = "" #@param {type:"string"}
MODEL_PATH = "mask_rcnn_nephrologie_649_100.h5" #@param {type:"string"}
DIVISION_SIZE = 1024 #@param {type:"slider", min:896, max:1024, step:1}
saveResults = True #@param {type:"boolean"}
CELLS_CLASS_NAMES = ["tubule_sain", "tubule_atrophique", "nsg_complet", "nsg_partiel", "pac", "vaisseau",
                     "artefact"]

In [0]:
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import os
    import sys
    import random
    import math
    import re
    import numpy as np
    import cv2
    import matplotlib
    import matplotlib.pyplot as plt
    import json
    from shlex import quote
    from time import time, ctime
    from skimage.io import imread, imsave, imshow, imread_collection, concatenate_images
    from skimage.transform import resize

    
    IMAGE_NAME = IMAGE_PATH.split('.')[0]
    import datasetDivider as div

    if '.png' not in IMAGE_PATH:
        print('Converting to png')
        tempPath = IMAGE_NAME + '.png'
        image = imread(IMAGE_PATH)
        imsave(tempPath, image)
        IMAGE_PATH = tempPath

    image = cv2.imread(IMAGE_PATH)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    height, width, _ = image.shape
    xStarts = div.computeStartsOfInterval(width)
    yStarts = div.computeStartsOfInterval(height)

    nbDiv = div.getDivisionsCount(xStarts, yStarts)

    NB_CLASS = len(CELLS_CLASS_NAMES)

    COMPLETE_RESULTS = os.path.exists(IMAGE_NAME + '.xml')

    # Mask R-CNN modules downloaded earlier (each imported once)
    from config import Config
    import utils
    import model as modellib
    import visualize
    from model import log

    %matplotlib inline 

    # Root directory of the project
    ROOT_DIR = os.getcwd()

    # Directory to save logs and trained model
    MODEL_DIR = os.path.join(ROOT_DIR, "logs")

    print("Cell done")
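
The cell above derives `IMAGE_NAME` with `IMAGE_PATH.split('.')[0]`, which keeps only the text before the first dot. The sketch below contrasts that with `os.path.splitext`, which strips only the final extension. The file name is a made-up example, not one from the dataset.

```python
import os

image_path = "biopsy.v2.jp2"  # hypothetical file name containing an extra dot

name_by_split = image_path.split('.')[0]            # text before the first dot
name_by_splitext = os.path.splitext(image_path)[0]  # everything but the last extension

print(name_by_split)     # biopsy
print(name_by_splitext)  # biopsy.v2
```

With file names that contain a single dot, both approaches give the same result, so this only matters for names such as the one above.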

## Creating masks if annotations file found

In [0]:
if COMPLETE_RESULTS:
    import datasetWrapper as wr
    wr.createMasksOfImage('.', IMAGE_NAME, 'data')

## Configurations

In [0]:
class CellsConfig(Config):
    NAME = "cells"
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    NUM_CLASSES = 1 + NB_CLASS
    IMAGE_MIN_DIM = DIVISION_SIZE
    IMAGE_MAX_DIM = DIVISION_SIZE
    RPN_ANCHOR_SCALES = (8, 16, 64, 128, 256)
    TRAIN_ROIS_PER_IMAGE = 800
    STEPS_PER_EPOCH = 400
    VALIDATION_STEPS = 50


config = CellsConfig()

## Notebook Preferences

In [0]:
def get_ax(rows=1, cols=1, size=8):
  return plt.subplots(rows, cols, figsize=(size*cols, size*rows), frameon=False)

In [0]:
if COMPLETE_RESULTS:
    class LightCellsDataset(utils.Dataset):
        __CELLS_CLASS_NAMES = CELLS_CLASS_NAMES.copy()

        def get_class_names(self):
            return self.__CELLS_CLASS_NAMES.copy()

        def load_cells(self):
            # Add classes
            for class_id, class_name in enumerate(self.__CELLS_CLASS_NAMES):
                self.add_class("cells", class_id + 1, class_name)

            img_path = 'data/' + IMAGE_NAME + '/images/'
            self.add_image("cells", image_id=IMAGE_NAME, path=img_path)

        def load_image(self, image_id):

            info = self.image_info[image_id]
            info = info.get("id")

            img = imread('data/' + info + '/images/' + info + '.png')[:, :, :3]

            return img

        def image_reference(self, image_id):
            """Return the cells data of the image."""
            info = self.image_info[image_id]
            if info["source"] == "cells":
                return info["cells"]
            else:
                return super().image_reference(image_id)

        def load_mask(self, image_id):
            """Generate instance masks for cells of the given image ID.
            """
            info = self.image_info[image_id]
            info = info.get("id")

            path = 'data/' + info

            # Counting masks for current image
            number_of_masks = 0
            for masks_dir in os.listdir(path):
                # For each directory excepting /images
                if masks_dir not in self.__CELLS_CLASS_NAMES:
                    continue
                temp_DIR = path + '/' + masks_dir
                # Adding length of directory https://stackoverflow.com/questions/2632205/how-to-count-the-number-of-files-in-a-directory-using-python
                number_of_masks += len(
                    [name for name in os.listdir(temp_DIR) if os.path.isfile(os.path.join(temp_DIR, name))])

            mask = np.zeros([height, width, number_of_masks], dtype=np.uint8)
            iterator = 0
            class_ids = np.zeros((number_of_masks,), dtype=int)
            for masks_dir in os.listdir(path):
                if masks_dir not in self.__CELLS_CLASS_NAMES:
                    continue
                temp_class_id = self.__CELLS_CLASS_NAMES.index(masks_dir) + 1
                for mask_file in next(os.walk(path + '/' + masks_dir + '/'))[2]:
                    mask_ = imread(path + '/' + masks_dir + '/' + mask_file)
                    mask[:, :, iterator] = mask_
                    class_ids[iterator] = temp_class_id
                    iterator += 1
            # Handle occlusions
            occlusion = np.logical_not(mask[:, :, -1]).astype(np.uint8)
            for i in range(number_of_masks - 2, -1, -1):
                mask[:, :, i] = mask[:, :, i] * occlusion
                occlusion = np.logical_and(occlusion, np.logical_not(mask[:, :, i]))
            return mask, class_ids.astype(np.int32)
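
The occlusion step at the end of `load_mask` can be illustrated on a tiny mask stack. The sketch below reuses the same loop on a 1x4 image with two overlapping masks; the function name `remove_overlaps` and the toy data are assumptions made for this example.

```python
import numpy as np

def remove_overlaps(mask):
    """Clear from each mask the pixels already claimed by a later mask in the stack."""
    mask = mask.copy()
    count = mask.shape[-1]
    # Pixels still free after the last mask has claimed its own.
    occlusion = np.logical_not(mask[:, :, -1]).astype(np.uint8)
    for i in range(count - 2, -1, -1):
        mask[:, :, i] = mask[:, :, i] * occlusion
        occlusion = np.logical_and(occlusion, np.logical_not(mask[:, :, i]))
    return mask

stack = np.zeros((1, 4, 2), dtype=np.uint8)
stack[0, :, 0] = [1, 1, 1, 0]  # first mask
stack[0, :, 1] = [0, 0, 1, 1]  # second mask, overlapping the first on one pixel
clean = remove_overlaps(stack)
print(clean[0, :, 0])  # [1 1 0 0]
print(clean[0, :, 1])  # [0 0 1 1]
```

After the loop, no pixel belongs to more than one mask; the mask loaded last keeps the contested pixel.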

In [0]:
if COMPLETE_RESULTS:
    dataset_val = LightCellsDataset()
    dataset_val.load_cells()
    dataset_val.prepare()

## Detection

### Initialisation of the inference model and loading of weights 

In [0]:
class InferenceConfig(CellsConfig):
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

inference_config = InferenceConfig()

# Recreate the model in inference mode
model = modellib.MaskRCNN(mode="inference", config=inference_config, model_dir=MODEL_DIR)

# Load trained weights (fill in path to trained weights here)
assert MODEL_PATH != "", "Provide path to trained weights"
print("Loading weights from ", MODEL_PATH)
model.load_weights(MODEL_PATH, by_name=True)

### Inference

#### Display of the Input Image

In [0]:
if COMPLETE_RESULTS:
    fileName = None
    if saveResults:
      fileName = "{} Expected".format(IMAGE_NAME)
    image_id = dataset_val.image_ids[0]
    gt_mask, gt_class_id = dataset_val.load_mask(image_id)
    gt_bbox = utils.extract_bboxes(gt_mask)
    
    # Pass fileName so that nothing is saved when saveResults is unchecked
    visualize.display_instances(image, gt_bbox, gt_mask, gt_class_id, dataset_val.class_names,
                                colorPerClass=True, figsize=(width / 100, height / 100),
                                title="{} Expected".format(IMAGE_NAME),
                                fileName=fileName)
else:
    _, ax = get_ax(size=16)
    ax.imshow(image)
    ax.axis('off')
    plt.show()

#### Getting predictions for each division

In [0]:
res = []
for divId in range(nbDiv):
    division = div.getImageDivision(image, xStarts, yStarts, divId)
    print('Inference {}/{}'.format(divId + 1, nbDiv))
    results = model.detect([division])
    res.append(results[0])
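
The loop above runs the model once per division of the image. The implementation of `datasetDivider` is not shown in this notebook, so the following is only a sketch of one way overlapping division start offsets could be computed. The function name `compute_tile_starts` and the `min_overlap` parameter are assumptions, not the project's API.

```python
import math

def compute_tile_starts(length, tile_size=1024, min_overlap=0.25):
    """Hypothetical helper: start offsets of tiles covering [0, length) with some overlap."""
    if length <= tile_size:
        return [0]  # a single tile is enough
    # Largest step between two tiles that still keeps the requested overlap.
    max_stride = tile_size * (1 - min_overlap)
    count = math.ceil((length - tile_size) / max_stride) + 1
    # Spread the tiles evenly so the last one ends exactly at the image border.
    stride = (length - tile_size) / (count - 1)
    return [round(i * stride) for i in range(count)]

x_starts = compute_tile_starts(3000)
y_starts = compute_tile_starts(2000)
print(x_starts, y_starts, len(x_starts) * len(y_starts))
```

The number of divisions is the product of the number of horizontal and vertical starts, which is the role `nbDiv` plays in the loop.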

#### Post-processing of the predictions

In [0]:
print("Fusing results of all divisions")
fused_results = utils.fuse_results(res, image)
print("Results fused")

In [0]:
print("Fusing overlapping masks")
fused_mask = utils.fuse_masks(fused_results, 
                              bb_threshold=0.1, mask_threshold=0.1,
                              verbose=1)
print("Masks fused")

In [0]:
print("Removing non-sense masks")
filtered_masks = utils.filter_fused_masks(fused_mask, 
                                          bb_threshold=0.5,
                                          vessel_threshold=0.9,
                                          vessel_class_id=6)
print("Masks filtered")
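
The `vessel_class_id=6` argument relies on class IDs starting at 1, with 0 reserved for the background, as in `load_cells` and in the final display cell. A quick check of that mapping, using the class list from the Initialisation cell:

```python
CELLS_CLASS_NAMES = ["tubule_sain", "tubule_atrophique", "nsg_complet", "nsg_partiel",
                     "pac", "vaisseau", "artefact"]

# ID 0 is the background; real classes are numbered from 1 in list order.
id_to_name = {0: "background"}
for class_id, class_name in enumerate(CELLS_CLASS_NAMES, start=1):
    id_to_name[class_id] = class_name

print(id_to_name[6])  # vaisseau
```

If the order of `CELLS_CLASS_NAMES` changes, `vessel_class_id` has to be updated to match.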

In [0]:
if COMPLETE_RESULTS:
    gt_mask, gt_class_id = dataset_val.load_mask(image_id)
    AP, precisions, recalls, overlaps, confusion_matrix = utils.compute_ap(
        gt_bbox, gt_class_id, gt_mask,
        filtered_masks["rois"], filtered_masks["class_ids"], filtered_masks["scores"], filtered_masks['masks'],
        nb_class=NB_CLASS, confusion_iou_threshold=0.1)
    
    print("Average Precision is about {:5.2f}%".format(AP * 100))

    name = "{} Confusion Matrix".format(IMAGE_NAME)
    name2 = "{} Normalized Confusion Matrix".format(IMAGE_NAME)
    cmap = plt.cm.get_cmap('hot')
    visualize.display_confusion_matrix(confusion_matrix, dataset_val.get_class_names(), title=name,
                                       cmap=cmap, show=False, fileName=name)
    visualize.display_confusion_matrix(confusion_matrix, dataset_val.get_class_names(), title=name2,
                                       cmap=cmap, show=False, normalize=True, fileName=name2)
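
The `confusion_iou_threshold=0.1` argument refers to an intersection over union (IoU) value. As a rough illustration only, and not the code used by `utils.compute_ap`, the hypothetical function below computes the IoU of two boxes, assuming a `(y1, x1, y2, x2)` corner layout.

```python
def box_iou(box_a, box_b):
    """Hypothetical helper: IoU of two boxes given as (y1, x1, y2, x2)."""
    # Corners of the overlapping rectangle, if any.
    y1 = max(box_a[0], box_b[0])
    x1 = max(box_a[1], box_b[1])
    y2 = min(box_a[2], box_b[2])
    x2 = min(box_a[3], box_b[3])
    intersection = max(0, y2 - y1) * max(0, x2 - x1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0.0

iou = box_iou((0, 0, 10, 10), (5, 5, 15, 15))
print(round(iou, 3), iou >= 0.1)  # 0.143 True
```

With a threshold of 0.1, even a modest overlap such as this one counts as a match.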

In [0]:
print("Displaying results")

fileName = None
if saveResults:
  fileName = "{} Predicted".format(IMAGE_NAME)
names = CELLS_CLASS_NAMES.copy()
names.insert(0, 'background')
_ = visualize.display_instances(image, filtered_masks['rois'], filtered_masks['masks'], filtered_masks['class_ids'], 
                                names, filtered_masks['scores'], colorPerClass=True,
                                figsize=(width / 100, height / 100), 
                                fileName=fileName, onlyImage=True)