# Dendrite Spine Identification Project

#### Automating the annotation of dendritic spines in images using Faster Region-based Convolutional Neural Networks (Faster R-CNN) and the TensorFlow Object Detection API.

Heavily based on:
- https://github.com/ily-R/Deep-Learning-for-Dendritic-Spines-Detection/blob/master/report.pdf
- https://arxiv.org/abs/1506.01497

## Preprocessing

In [1]:
# data handling
import pandas as pd
import numpy as np
import os
import cv2
import sys

# tensorflow & object-detection
import tensorflow.compat.v1 as tf
from object_detection.utils import visualization_utils as vis_util
from object_detection.utils import label_map_util


In [2]:
# --- file paths (absolute ones are resolved against the working directory) ---
CW_DIR = os.getcwd()  # current working directory
MODEL_DIR, TEST_DIR = 'object_detection/inference_graph', 'test/img/'
# NOTE: despite its name, CKPT_DIR points at the exported frozen graph file
# (frozen_inference_graph.pb), not at a checkpoint directory.
CKPT_DIR = os.path.join(CW_DIR, MODEL_DIR, 'frozen_inference_graph.pb')
LABELS_DIR = os.path.join(CW_DIR, 'training', 'labelmap.pbtxt')  # label map file

In [3]:
# test data: every entry in TEST_DIR, de-duplicated and sorted by name,
# as a NumPy string array
img_files = np.array(sorted(set(os.listdir(TEST_DIR))))
img_files[:3]

array(['000001.jpg', '000002.jpg', '000003.jpg'], dtype='<U10')

## TensorFlow Object Detection & Model Inference

In [4]:
# Load the exported frozen inference graph into its own tf.Graph and open a
# session bound to that graph; `sess` is kept open for the inference cells.
detection_graph = tf.Graph()

# Read the serialized GraphDef protobuf from disk (binary mode).
with tf.gfile.GFile(CKPT_DIR, 'rb') as fid:
    serialized_graph = fid.read()

od_graph = tf.GraphDef()
od_graph.ParseFromString(serialized_graph)

with detection_graph.as_default():
    # name='' imports the nodes without a name prefix, so tensors can be
    # looked up by their plain names (e.g. 'image_tensor:0').
    tf.import_graph_def(od_graph, name='')
    sess = tf.Session(graph=detection_graph)

sess

<tensorflow.python.client.session.Session at 0x18087ec5128>

In [5]:
# Fetch the detector's input/output tensors from the loaded graph by name:
#   image_tensor:0      -> input: a batch of dendrite images
#   detection_boxes:0   -> boxes, each marking an image region where a spine was detected
#   detection_scores:0  -> confidence level for each detected box
#   detection_classes:0 -> class label id for each detected box
#   num_detections:0    -> detection count
img_ten, boxes_ten, scores_ten, classes_ten, n_detected = (
    detection_graph.get_tensor_by_name(tensor_name)
    for tensor_name in (
        'image_tensor:0',
        'detection_boxes:0',
        'detection_scores:0',
        'detection_classes:0',
        'num_detections:0',
    )
)

n_detected

<tf.Tensor 'num_detections:0' shape=<unknown> dtype=float32>

In [6]:
# Conduct spine detection for each dendrite image: run the detector, draw the
# detected boxes onto the image, upscale it, and save it under OUTPUT_DIR.
n_imgs = img_files.shape[0]
# Single-class category index. Labels are not drawn (skip_labels=True below),
# so the 'name' entry does not appear on the output images.
cat_idx = {1: {'id': 1, 'name': 'dendrite'}}
scale_factor = 220  # output size as a percentage of the original (220 -> 2.2x)
OUTPUT_DIR = 'output_imgs'
# cv2.imwrite does not create missing folders; it just returns False.
os.makedirs(OUTPUT_DIR, exist_ok=True)

for img_file in img_files:
    # load img and run model on it
    IMG_DIR = os.path.join(CW_DIR, TEST_DIR, img_file)
    img = cv2.imread(IMG_DIR)
    if img is None:
        # cv2.imread returns None (it does not raise) for missing or
        # unreadable files, e.g. stray non-image entries in TEST_DIR.
        print(f'Skipping unreadable image: {IMG_DIR}', file=sys.stderr)
        continue
    expanded_img = np.expand_dims(img, axis=0)  # add a leading batch axis
    feed_dict = {img_ten: expanded_img}
    (boxes, scores, classes, num) = sess.run(
        [boxes_ten, scores_ten, classes_ten, n_detected], feed_dict=feed_dict)
    # Draw boxes onto `img` in place. The return value is not needed; in the
    # upstream Object Detection API it is the annotated image array itself,
    # so unpacking it into two names would raise.
    vis_util.visualize_boxes_and_labels_on_image_array(
        image=img,
        boxes=np.squeeze(boxes),
        classes=np.squeeze(classes).astype(np.int32),
        scores=np.squeeze(scores),
        category_index=cat_idx,
        instance_masks=None,
        instance_boundaries=None,
        keypoints=None,
        use_normalized_coordinates=True,
        max_boxes_to_draw=20,
        min_score_thresh=0.60,
        agnostic_mode=False,
        line_thickness=1,
        groundtruth_box_visualization_color='black',
        skip_scores=True,
        skip_labels=True
    )

    width = int(img.shape[1] * scale_factor / 100)
    height = int(img.shape[0] * scale_factor / 100)
    dims = (width, height)
    resized_img = cv2.resize(img, dims, interpolation=cv2.INTER_AREA)
    # write img to the output folder; imwrite reports failure via its return value
    out_path = os.path.join(OUTPUT_DIR, img_file)
    if not cv2.imwrite(out_path, resized_img):
        print(f'Failed to write image: {out_path}', file=sys.stderr)

In [None]:
import matplotlib.pyplot as plt


def display_images(img_files, idx=0, num_show=6):
    """Plot annotated output images in a 3-column grid and save the figure.

    Each image is read from ``./output_imgs/<file name>`` and shown with the
    title ``Dendrite <n>``. The finished figure is written to
    ``example_output_img.jpg`` if at least one image was drawn.

    Args:
        img_files: sequence of image file names to look up in ``./output_imgs/``.
        idx: zero-based grid position at which the first image is placed.
        num_show: maximum number of images (from the start of ``img_files``)
            to display.
    """
    files_to_show = img_files[:num_show]
    n_cols = 3
    # At least 2 rows, matching the previous fixed 2x3 layout; extra rows are
    # added when more images are requested (a 3-digit subplot code such as
    # 231 + idx only addresses positions 1-6 of a 2x3 grid).
    n_rows = max(2, (idx + len(files_to_show) + n_cols - 1) // n_cols)
    plt.figure(figsize=(16, 6 * n_rows))
    plt.subplots_adjust(wspace=0.05, hspace=0.01)

    n_drawn = 0
    for i, file_name in enumerate(files_to_show):
        path = f'./output_imgs/{file_name}'
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            print(f'Skipping unreadable image: {path}', file=sys.stderr)
            continue
        plt.subplot(n_rows, n_cols, idx + 1)
        idx += 1
        plt.title(f"Dendrite {i+1}")
        # OpenCV loads images as BGR; matplotlib expects RGB.
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.axis('off')
        n_drawn += 1

    # Save once, after every subplot has been drawn.
    if n_drawn:
        plt.savefig("example_output_img.jpg")


display_images(img_files)
