In [1]:
# Fetch the tracker repository; the `model` module imported further down is
# presumably provided by this checkout (see the inline note on the clone line).
!git clone https://github.com/CAMMA-public/ConvLSTM-Surgical-Tool-Tracker.git #contains model architecture
# Work from inside the checkout: `import model` is resolved from the current
# directory and the '../drive/...' data paths used later are relative to it.
%cd ConvLSTM-Surgical-Tool-Tracker

print("Repo cloned and extracted ...")

Cloning into 'ConvLSTM-Surgical-Tool-Tracker'...
remote: Enumerating objects: 76, done.
remote: Counting objects: 100% (76/76), done.
remote: Compressing objects: 100% (55/55), done.
remote: Total 76 (delta 33), reused 54 (delta 18), pack-reused 0
Unpacking objects: 100% (76/76), done.
/content/ConvLSTM-Surgical-Tool-Tracker
Repo cloned and extracted ...


In [None]:
# On Colab, replace the preinstalled TensorFlow with the 1.14 GPU build: the
# rest of this notebook uses TF1 graph-mode APIs (tf.placeholder, tf.Session).
if 'google.colab' in str(get_ipython()):  # colab installs tf.2.2 on default.
    !pip uninstall -y tensorflow
    !pip install tensorflow-gpu==1.14
# imageio is used below to decode the training video.
# NOTE(review): these two installs are unpinned — consider pinning versions so
# a fresh runtime reproduces the same environment.
!pip install imageio
!pip install imageio-ffmpeg

print("Installations completed ...")

In [None]:
# `model` is imported from the current working directory, so the `%cd` into the
# cloned repository must have run before this cell.
import model
import tensorflow as tf
import os
import numpy as np
import cv2
import imageio
import sys
# Echo the working directory so a wrong location is visible immediately.
print("\npwd = ", os.getcwd())

from matplotlib import animation, rc, pyplot as plt
# Tell matplotlib's animation machinery which ffmpeg binary to use when the
# result video is rendered at the end of the notebook.
plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'
from IPython.display import HTML
# Only show TensorFlow log messages of severity ERROR.
tf.logging.set_verbosity(tf.logging.ERROR)
%matplotlib inline

print("imports success...")

In [2]:
# Notebook-wide configuration. The trailing `#@param` comments are Colab form
# annotations and must stay on the same line as the value they describe.
# Frame size the network input is resized to, and number of colour channels.
img_height   = 480 #@param {type:"integer"}
img_width    = 854 #@param {type:"integer"}
img_channel  = 3   #@param {type:"integer"}
# Number of tool classes predicted per frame (one label column per class).
num_classes  = 7   #@param {type:"integer"}
# Pixel offsets applied when heat-map positions are converted to image
# coordinates (see get_center_coordinates / get_box_coordinates).
offset_x     = 20  #@param {type:"integer"}
offset_y     = 11  #@param {type:"integer"}
# Two-digit video identifier; it selects both the .mp4 and its label file.
VIDEO_NUM = "01" #change to train on a different video
data_path    = '../drive/MyDrive/surgical_tracking_training_material/video'+VIDEO_NUM+'.mp4'
#data_path should contain video from cholec80 dataset

print("Model and device variables set .. ")

../drive/MyDrive/surgical_tracking_training_material/video01.mp4
Model and device variables set .. 


In [None]:

#reading training data labels
# Builds tr_label_matrix: one list of floats per annotated frame, where each
# value says whether the corresponding tool is present in that frame.
# The first line of the file is treated as a header and skipped, and the first
# column of every row is dropped (presumably the frame index — confirm against
# the annotation file). Blank lines are ignored so that a trailing newline does
# not produce an empty label row, which could not be fed to labels_ph.
with open("../drive/MyDrive/surgical_tracking_training_material/video" + VIDEO_NUM + "-tool.txt", "r") as f: #path for ground-truth labels text file
  tr_label_matrix = []
  next(f, None)  # skip the header row (no-op on an empty file)
  for line in f:
    fields = line.split()
    if not fields:
      continue
    # Each row/array of labels denotes which collection of tools are present in each frame
    tr_label_matrix.append([float(value) for value in fields[1:]])
 


In [None]:
import imageio
# Load the training frames into memory, keeping one frame out of every
# FRAME_STRIDE (downsampling the video from 25fps to 1fps).
FRAME_STRIDE = 25

tr_frames = []
# Use the reader as a context manager so the underlying video resources are
# released once decoding finishes (the original never closed the reader).
with imageio.get_reader(data_path) as reader:
  for i, frame in enumerate(reader):
    if i % FRAME_STRIDE == 0:
      tr_frames.append(frame)


In [None]:
# Start from an empty default graph so re-running this cell does not pile up
# duplicate ops from a previous run.
tf.reset_default_graph()

# One video frame at a time; height and width are left open because the frame
# is resized to the configured network resolution right below.
img_ph  = tf.placeholder(dtype=tf.float32, shape=[None, None, img_channel], name='inputs') #placeholder for feeding training images
x       = tf.expand_dims(img_ph, 0)  # add a leading batch dimension of size 1
x       = tf.image.resize_bilinear(x, size=(img_height, img_width))
# Frame number placeholder. It gets its own op name: the original reused
# 'inputs', which TensorFlow silently renamed to 'inputs_1'.
seek_ph = tf.placeholder(dtype=tf.int64, shape=[None], name='seek') #placeholder for feeding frame number (used for spatio-temporal calculations)
labels_ph = tf.placeholder(dtype=tf.float32, shape=[num_classes], name='labels') #placeholder for feeding training labels
network = model.Model(images=x, seek=seek_ph, num_classes=num_classes)
logits_float, lhmaps  = network.build_model() #logits collects tool presence probabilities and lhmaps stores location graphs
# From here on logits_float holds sigmoid outputs in [0, 1], not raw logits.
logits_float =  tf.sigmoid(logits_float)
# Rounded sigmoid outputs: 1 where a tool is predicted present, 0 otherwise.
logits  = tf.cast(tf.round(logits_float), tf.int32)
# Zero out the localisation maps of every class predicted absent.
lhmaps  = lhmaps * tf.cast(logits, tf.float32)


print("Model loaded successfully...")

In [None]:
# Binary cross-entropy between the ground-truth labels (labels_ph) and the
# predicted probabilities (logits_float), averaged over the classes.
# Per-class weights, as mentioned in the paper, could be added to the two terms.
eps = 1e-9  # keeps tf.log away from log(0)
present_term = labels_ph * tf.log(logits_float + eps)
absent_term = (1 - labels_ph) * tf.log(1 - logits_float + eps)
cross_entropy = -tf.reduce_mean(present_term + absent_term, name='xentropy')

# Training hyperparameters.
ln_rate = 0.001
# Plain gradient descent is used here; note that `optimizer` is the training
# op returned by minimize(), i.e. running it performs one weight update.
sgd = tf.train.GradientDescentOptimizer(learning_rate=ln_rate)
optimizer = sgd.minimize(cross_entropy)


In [None]:
from re import I


#creating a training session
import sys

# Session configuration: let TensorFlow claim up to 90% of the GPU memory and
# fall back to another device when an op cannot be placed as requested.
sess_config    = tf.ConfigProto(gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9), allow_soft_placement = True, log_device_placement = False)
with tf.Session(config=sess_config) as sess:
    sess.run([tf.local_variables_initializer(), tf.global_variables_initializer()])

    EPOCHS = 15 #decides number of training cycles
    # Frames and label rows are paired by index, so only the common prefix is usable.
    min_frame_count = min(len(tr_frames), len(tr_label_matrix))
    if min_frame_count == 0:
        # Fail early with a clear message instead of a ZeroDivisionError when
        # the average loss is computed.
        raise ValueError("No training data: tr_frames or tr_label_matrix is empty")

    # training loop over the number of epochs
    for epoch in range(EPOCHS):
        total_loss = 0
        for i in range(min_frame_count):
            x_batch = tr_frames[i]
            y_batch = tr_label_matrix[i]

            # one optimisation step on a single (frame, labels) example
            _, loss_val = sess.run([optimizer, cross_entropy], feed_dict={img_ph:x_batch , seek_ph:[i], labels_ph:y_batch})
            total_loss += loss_val

        total_loss /= min_frame_count # computes average loss
        print("avg loss =", str(total_loss))

        if (epoch % 5 == 0 and epoch > 0): #for every 5 epochs, the model is tested on same video frames
            PREDICTIONS    = [] #stores logits(or predictions for each frame)
            CLASS_LHMAPS   = [] #stores heat map
            #change path below to direct predictions to desired file
            with open('../drive/MyDrive/logits_ep'+ str(epoch)+'_of_'+ str(EPOCHS)+'_' + str(min_frame_count) + '.txt', 'w') as f:
                # `seek` is the running frame index. The original never
                # incremented it, so every frame was fed as frame 0 and the
                # stop condition could never trigger. Evaluation covers the
                # same min_frame_count frames that were trained on.
                for seek, frame in enumerate(tr_frames[:min_frame_count]):
                    predict, lhmap = sess.run([logits, lhmaps], feed_dict={img_ph:frame, seek_ph:[seek]}) #evaluates model and collects predictions + heat map
                    PREDICTIONS.append(predict)
                    CLASS_LHMAPS.append(lhmap)
                    # One line per frame: the per-class 0/1 predictions,
                    # space separated (the file used to be left empty).
                    f.write(' '.join(str(value) for value in np.ravel(predict)) + '\n')
            print("Evaluation done after epoch ", epoch)


In [None]:
def get_center_coordinates(lhmap):
    """Return the image-space (cx, cy) of the strongest response in a heat map.

    Args:
        lhmap: 2-D array (rows, cols) holding the localisation map of one class.

    Returns:
        Tuple ``(cx, cy)``: the first position (in row-major order) where the
        map reaches its maximum, rescaled from heat-map cells to image pixels
        using the module-level ``img_width`` / ``img_height`` and shifted by
        ``offset_x`` / ``offset_y``. For an all-zero map this is
        ``(offset_x, offset_y)``.
    """
    # Use the map's real size rather than a hard-coded 60x107 grid, so the
    # rescaling stays correct if the network resolution changes.
    map_height, map_width = lhmap.shape[:2]
    coord = np.where(lhmap == lhmap.max())
    cx    = (coord[1][0] * img_width // map_width) + offset_x
    cy    = (coord[0][0] * img_height // map_height) + offset_y
    return (cx, cy)

def get_box_coordinates(lhmap):
    """Return the image-space bounding box of the positive area of a heat map.

    Args:
        lhmap: 2-D array (rows, cols) holding the localisation map of one class.

    Returns:
        Tuple ``(x0, y0, x1, y1)``: the extent of all cells with a value above
        zero, rescaled from heat-map cells to image pixels using the
        module-level ``img_width`` / ``img_height`` and widened by
        ``offset_x`` / ``offset_y`` on each side. ``(-1, -1, -1, -1)`` when no
        cell is positive.
    """
    # Use the map's real size rather than a hard-coded 60x107 grid, so the
    # rescaling stays correct if the network resolution changes.
    map_height, map_width = lhmap.shape[:2]
    coord = np.where(lhmap > 0)
    if len(coord[0]) > 0 and len(coord[1]) > 0:
        x0 = (coord[1].min() * img_width // map_width) - offset_x
        x1 = (coord[1].max() * img_width // map_width) + offset_x
        y0 = (coord[0].min() * img_height // map_height) - offset_y
        y1 = (coord[0].max() * img_height // map_height) + offset_y
    else:
        x0, x1, y0, y1 = -1, -1, -1, -1
    return (x0, y0, x1, y1)


# Build animators
def build_animators():
    """Draw the predicted tool boxes and centres onto every training frame.

    Reads the module-level ``tr_frames``, ``PREDICTIONS`` and ``CLASS_LHMAPS``
    (the latter two are produced by the evaluation pass of the training cell).

    Returns:
        Tuple ``(fig, frames)``: the matplotlib figure and a list with one
        single-artist list per frame, the layout expected by
        ``animation.ArtistAnimation``.
    """
    BUFFER_BOX_CENTER = []
    # One RGB colour per tool class, indexed by class id.
    colors    = [(255,0,0),(255,255,0),(0,0,255),(255,0,255),(255,128,0),(0,255,255),(0,255,0)]
    radius    = 28
    thickness = 4
    # The frames are already in memory (tr_frames), so no video reader is
    # opened here; the original opened one that was never used or closed.
    fig       = plt.figure()
    # `predict` is unpacked only to keep the three sequences aligned.
    for img, predict, lhmap in zip(tr_frames, PREDICTIONS, CLASS_LHMAPS):
        img_overlay     = img.copy()
        for i in range(num_classes):
            cam         = lhmap[0,:,:,i]
            x1,y1,x2,y2 = get_box_coordinates(cam)
            cx,cy       = get_center_coordinates(cam)
            color       = colors[i]
            cv2.rectangle(img_overlay, (x1,y1), (x2,y2), color, thickness)
            cv2.circle(img_overlay, (cx,cy), radius, color, -1)
        # Classes with an all-zero map get their centre at (offset_x, offset_y);
        # paint over that spot in black so those markers are not shown.
        cv2.circle(img_overlay, (offset_x,offset_y), radius, (0,0,0), -1)
        BUFFER_BOX_CENTER.append([plt.imshow(img_overlay)])
    return fig, BUFFER_BOX_CENTER
        

# Colorizer
def cstr(s, color='black'):
    """Wrap ``s`` in a ``<text>`` tag whose inline style sets the given colour.

    Args:
        s: Content to display; it is formatted with ``str.format``.
        color: CSS colour value for the text, ``'black'`` by default.

    Returns:
        The HTML snippet as a string.
    """
    template = "<text style=color:{shade}>{body}</text>"
    return template.format(shade=color, body=s)

print("Model ready to track...")

In [None]:
# Render every overlay frame once; the animation cell below consumes fig/OVERLAY.
fig, OVERLAY = build_animators()

# Legend mapping each tool to the colour used for its box and centre marker.
# "Irrigator" uses "cyan", matching the (0,255,255) overlay colour; the
# previous value "mouve" is not a CSS colour, so that label was left uncoloured.
HTML('='*20+"> [  Tool Colormap:                                       "
           +cstr("Grasper", "red") +" | "+cstr("Bipolar", "yellow") +"  |  "+cstr("Hook", "blue")+"  |  "
           +cstr("Scissors", "violet")+"  |  " +cstr("Clipper", "orange") 
           +"  |  "+cstr("Irrigator", "cyan") +"  |  "+cstr("Specimen bag  ", "green")+'  ] <'+'='*20 )

In [None]:
# Chain the per-frame overlay artists into one animation and show it inline
# as an HTML5 video.
anim = animation.ArtistAnimation(
    fig,
    OVERLAY,
    interval=160,       # delay between frames, in milliseconds
    repeat_delay=1000,  # pause before the animation loops, in milliseconds
    blit=True,
)
HTML(anim.to_html5_video())