<div align="center">
<a href="http://camma.u-strasbg.fr/">
<img src="lib/camma_logo.png" width="18%">
</a>
</div>



Weakly-supervised ConvLSTM Surgical Tool Tracker
================
------
**A re-implementation of surgical tool tracker in**<br>
<i>Nwoye, C. I., Mutter, D., Marescaux, J., & Padoy, N. (2019). 
    Weakly supervised convolutional LSTM approach for tool tracking in laparoscopic videos. 
    International journal of computer assisted radiology and surgery, 14(6), 1059-1067.<br></i>
    <b> Please cite this paper if you use this code or any part of it. </b>
    
    
(c) Research Group CAMMA, University of Strasbourg, France<br>
Website: http://camma.u-strasbg.fr<br>
Code author: Chinedu Nwoye <br>
    
-----

This notebook provides the code for:

1. Model architecture and all its libraries
2. Evaluation
3. Visualization

The code relies on the `tf.contrib` library, which is only available in TensorFlow 1.x. TensorFlow versions newer than 1.12 are discouraged.

<br> Download code and libraries

In [None]:
# Clone the tracker repository and make the checkout the working directory,
# so that the relative paths used by the following cells resolve inside it.
!git clone https://github.com/CAMMA-public/ConvLSTM-Surgical-Tool-Tracker.git
%cd ConvLSTM-Surgical-Tool-Tracker

<br> Download sample video data

In [None]:
# Download the sample-data archive and unpack it into the current directory.
# NOTE(review): presumably this produces the `data/` folder that `data_path`
# points into -- confirm against the archive contents.
!wget --content-disposition https://s3.unistra.fr/camma_public/github/convlstm_tracker/data.zip
!unzip data.zip
# Printed unconditionally: it does not verify that wget/unzip succeeded.
print("Download completed ...")

<br>Download model weights

In [None]:
# Download the model-weights archive and unpack it into the current directory.
# NOTE(review): presumably this produces the `ckpt/` folder that `ckpt_path`
# refers to -- confirm against the archive contents.
!wget --content-disposition https://s3.unistra.fr/camma_public/github/convlstm_tracker/ckpt.zip
!unzip ckpt.zip
# Printed unconditionally: it does not verify that wget/unzip succeeded.
print("Download completed ...")

##### Some important installations

In [None]:
# Replace whatever TensorFlow build is installed with the pinned GPU build 1.14.
!pip uninstall -y tensorflow
!pip install tensorflow-gpu==1.14
# imageio is used further down to read the video; imageio-ffmpeg is its ffmpeg plugin package.
# NOTE(review): these two packages are installed without a version pin.
!pip install imageio
!pip install imageio-ffmpeg

##### Imports

In [None]:
import model
import tensorflow as tf
import os
import numpy as np
import cv2
import imageio
import sys
from matplotlib import animation, rc, pyplot as plt
plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'
from IPython.display import HTML

# Only let TensorFlow log messages of level ERROR through.
tf.logging.set_verbosity(tf.logging.ERROR)
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
print("imports success...")

##### Variables & Device setup

In [None]:
# Notebook configuration. The `#@param` annotations are Colab form markers that
# expose each value as an editable field; keep them on the same line as the value.
img_height   = 480 #@param {type:"integer"}
img_width    = 854 #@param {type:"integer"}
img_channel  = 3   #@param {type:"integer"}
# Number of tool classes; the helper functions iterate over this many heat-map channels.
num_classes  = 7   #@param {type:"integer"}
# Pixel offsets applied by the coordinate helpers (added to centres and to the
# right/bottom box edges, subtracted from the left/top box edges).
offset_x     = 20  #@param {type:"integer"}
offset_y     = 11  #@param {type:"integer"}
# Index of the GPU made visible to TensorFlow (exported below).
gpu_usable   = 0   #@param {type:"integer"}
data_path    = 'data/surgical_video.avi' #@param {type:"string"} you can modify this if you evaluate on a different video
ckpt_path    = 'ckpt' #@param {type:"string"}
            
# Restrict CUDA to the selected device; the variable must be a string.
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_usable) 
print("Variables set, model to run on Device: GPU: ", gpu_usable)

##### Model architecture

In [None]:
# Start from an empty default graph so that re-running this cell does not add a
# second copy of the network to the graph (which would either fail or create
# variables whose names no longer match the checkpoint).
tf.reset_default_graph()

# One video frame of arbitrary height/width; the channel count comes from the config cell.
img_ph  = tf.placeholder(dtype=tf.float32, shape=[None, None, img_channel], name='inputs')
# Add a leading batch axis, then resize to the configured network input resolution.
x       = tf.expand_dims(img_ph, 0)
x       = tf.image.resize_bilinear(x, size=(img_height, img_width))
# Fed with the frame index by the evaluation loop.
# The placeholder name is kept as in the original graph definition.
seek_ph = tf.placeholder(dtype=tf.int64, shape=[None], name='inputs')
network = model.Model(images=x, seek=seek_ph, num_classes=num_classes)
logits, lhmaps  = network.build_model()
# Turn the raw class scores into hard 0/1 presence decisions (sigmoid, then round).
logits  = tf.cast(tf.round(tf.sigmoid(logits)), tf.int32)
# Zero out the localisation maps of every class predicted absent.
lhmaps  = lhmaps * tf.cast(logits, tf.float32)

print("Model loaded successfully...")

##### Saver and weights

In [None]:
# Create the saver for the graph built above and locate the checkpoint to restore.
with tf.name_scope("saver_and_writer"):
    saver = tf.train.Saver()
    # get_checkpoint_state returns None when the directory holds no checkpoint
    # index; fail here with an actionable message instead of an AttributeError.
    state = tf.train.get_checkpoint_state(ckpt_path)
    if state is None or not state.model_checkpoint_path:
        raise FileNotFoundError(
            "No TensorFlow checkpoint found in '{}'. "
            "Download and unzip the model weights before running this cell.".format(ckpt_path)
        )
    ckpt  = state.model_checkpoint_path

print('Loading checkpoint from :',ckpt)

##### Evaluate on video dataset

In [None]:
# Run the network over every frame of the video and collect, per frame, the
# binary class predictions and the (masked) class localisation maps.
PREDICTIONS    = []
CLASS_LHMAPS   = []
sess_config    = tf.ConfigProto(gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9), allow_soft_placement = True, log_device_placement = False)
# Open the video reader as a context manager so that the underlying file /
# ffmpeg process is released even if inference fails part-way through.
with imageio.get_reader(data_path) as reader, tf.Session(config=sess_config) as sess:
    sess.run([tf.local_variables_initializer(), tf.global_variables_initializer()])
    # Overwrite the freshly initialised variables with the trained weights.
    saver.restore(sess, ckpt)
    # `seek` is the 0-based frame index, passed to the model through `seek_ph`.
    for seek, frame in enumerate(reader):
        predict, lhmap = sess.run([logits, lhmaps], feed_dict={img_ph:frame, seek_ph:[seek]})
        PREDICTIONS.append(predict)
        CLASS_LHMAPS.append(lhmap)

print("Evaluation done...")

#### Helper functions

In [None]:
# Get coordinates

def get_center_coordinates(lhmap, frame_size=(480, 854), offsets=None):
    """Return the frame-pixel position of the strongest response in a heat map.

    Args:
        lhmap: 2-D array (rows, cols) holding one class localisation map.
        frame_size: (height, width) of the frame the coordinates refer to.
            Defaults to the 480x854 resolution used throughout the notebook.
        offsets: optional (x, y) pixel offsets added to the result. When
            omitted, the notebook-level ``offset_x`` / ``offset_y`` are used.

    Returns:
        ``(cx, cy)`` as plain Python ints. If several cells share the maximum,
        the first one in row-major order is used.
    """
    if offsets is None:
        offsets = (offset_x, offset_y)
    off_x, off_y = offsets
    frame_h, frame_w = frame_size
    # Scale by the map's own resolution instead of assuming a 60x107 grid.
    map_h, map_w = lhmap.shape
    row, col = np.unravel_index(np.argmax(lhmap), lhmap.shape)
    cx = int(col) * frame_w // map_w + off_x
    cy = int(row) * frame_h // map_h + off_y
    return (int(cx), int(cy))

def get_box_coordinates(lhmap, frame_size=(480, 854), offsets=None):
    """Return the frame-pixel bounding box of all positive cells in a heat map.

    Args:
        lhmap: 2-D array (rows, cols) holding one class localisation map.
        frame_size: (height, width) of the frame the coordinates refer to.
            Defaults to the 480x854 resolution used throughout the notebook.
        offsets: optional (x, y) pixel margins; subtracted from the left/top
            edge and added to the right/bottom edge. When omitted, the
            notebook-level ``offset_x`` / ``offset_y`` are used.

    Returns:
        ``(x0, y0, x1, y1)`` as plain Python ints, or ``(-1, -1, -1, -1)``
        when the map contains no value greater than zero. The box is not
        clipped to the frame, so edges may fall outside it.
    """
    rows, cols = np.nonzero(lhmap > 0)
    if rows.size == 0:
        return (-1, -1, -1, -1)
    if offsets is None:
        offsets = (offset_x, offset_y)
    off_x, off_y = offsets
    frame_h, frame_w = frame_size
    # Scale by the map's own resolution instead of assuming a 60x107 grid.
    map_h, map_w = lhmap.shape
    x0 = int(cols.min()) * frame_w // map_w - off_x
    x1 = int(cols.max()) * frame_w // map_w + off_x
    y0 = int(rows.min()) * frame_h // map_h - off_y
    y1 = int(rows.max()) * frame_h // map_h + off_y
    return (int(x0), int(y0), int(x1), int(y1))


# Build animators
def build_animators():
    """Render the tracking overlay for every evaluated frame.

    Re-reads the video at ``data_path`` and pairs each frame with the stored
    ``PREDICTIONS`` / ``CLASS_LHMAPS`` entry. For every tool class that was
    detected in a frame, draws its bounding box and a filled disc at the
    strongest heat-map response, using one fixed colour per class.

    Returns:
        ``(fig, frames)`` where ``frames`` is a list of single-artist lists,
        the form expected by ``matplotlib.animation.ArtistAnimation``.
    """
    colors    = [(255,0,0),(255,255,0),(0,0,255),(255,0,255),(255,128,0),(0,255,255),(0,255,0)]
    radius    = 28
    thickness = 4
    frames    = []
    fig, ax   = plt.subplots()
    # Context manager: the reader is closed once all frames are rendered.
    with imageio.get_reader(data_path) as reader:
        for img, _predict, lhmap in zip(reader, PREDICTIONS, CLASS_LHMAPS):
            img_overlay = img.copy()
            for i in range(num_classes):
                cam = lhmap[0, :, :, i]
                # The model cell multiplies each map by its 0/1 prediction, so an
                # all-zero map means the tool was predicted absent. Skip it rather
                # than drawing a marker in the corner and masking it with a black
                # disc, which also hid part of the frame.
                if not np.any(cam):
                    continue
                color = colors[i]
                # A box only exists when at least one cell is positive; otherwise
                # get_box_coordinates returns its (-1, -1, -1, -1) sentinel.
                if np.any(cam > 0):
                    x1, y1, x2, y2 = get_box_coordinates(cam)
                    cv2.rectangle(img_overlay, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness)
                cx, cy = get_center_coordinates(cam)
                cv2.circle(img_overlay, (int(cx), int(cy)), radius, color, -1)
            frames.append([ax.imshow(img_overlay)])
    return fig, frames
        

# Colorizer
def cstr(s, color='black'):
    """Wrap ``s`` in a ``<text>`` tag whose inline style sets the given colour."""
    markup = "<text style=color:{}>{}</text>"
    colored = markup.format(color, s)
    return colored

# Status message shown once the helper functions of this cell are defined.
print("Model ready to track...")

#### Tracking the video
Let's track the instruments in the video<br>
Colormap: display the legend for the tracker

In [None]:
# Render one overlay image per frame; OVERLAY is consumed by the animation cell.
fig, OVERLAY = build_animators()

# Legend mapping each tool to the colour used for its overlay. Every colour
# must be a valid CSS colour name, otherwise the entry is shown uncoloured;
# the Irrigator is drawn as (0,255,255), i.e. cyan.
HTML('='*20+"> [  Tool Colormap:                                       "
           +cstr("Grasper", "red") +" | "+cstr("Bipolar", "yellow") +"  |  "+cstr("Hook", "blue")+"  |  "
           +cstr("Scissors", "violet")+"  |  " +cstr("Clipper", "orange") 
           +"  |  "+cstr("Irrigator", "cyan") +"  |  "+cstr("Specimen bag  ", "green")+'  ] <'+'='*20 )

In [None]:
# Play the pre-rendered overlay frames: 160 ms per frame, with a 1000 ms pause
# before each repeat.
anim = animation.ArtistAnimation(fig, OVERLAY, interval=160, blit=True, repeat_delay=1000)
# Encode the animation as an HTML5 <video> element and display it as the cell output.
HTML(anim.to_html5_video())

The End