## Libraries

In [3]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from statistics import mean
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Dense, PReLU, Conv2D, Flatten
from tensorflow.keras.optimizers import RMSprop

import random as rnd

np.random.seed(42)

In [4]:
import session_info
session_info.show()

## Classes and functions

In [None]:
'''Here you can find all the classes and the custom functions used for this project'''

class SurgeryEnv:
    """Grid-world surgical field in which a knife is moved one cell at a time.

    The field is a square integer matrix of side ``size + 1``. Every cell holds
    one of the marker values defined in ``__init__`` (frame ``0``, uterus,
    target area, ligaments/arteries, peritoneum, knife). An episode ends when
    the knife reaches the frame or the uterus, or when no target cell is left.
    """

    def __init__(self):

        self.size = 10  # size of the grid
        # Matrix that represents the environment; -1 marks cells that have not been assigned yet.
        self.surgery_mat = -np.ones((self.size + 1, self.size + 1), dtype=np.int32)
        self.knife_position_x = 3  # position of the knife (row index)
        self.knife_position_y = 5  # position of the knife (column index)
        self.original_knife_coordinates = (3, 5)
        # Points of the matrix that correspond to the object to be removed (in this case the uterus).
        self.object_points = [(4,4), (4,5),(4,6),(5,4),(5,5),(5,6),(6,4),(6,5),(6,6)]
        self.object_value = 32  # the value that identifies the uterus in the matrix
        # Points of the target area (the objective of the Agent is to cut these points).
        self.target_points = [(3,4), (3,6), (4,7),(4,8),(5,8),(4,3),(4,2),(5,2),(7,4), (8,4), (7,6), (8,6)]
        self.target_value = 64  # the value that identifies the target area in the matrix
        # Points of the area of ligaments and/or arteries (these must be avoided).
        self.ligart_points = [(1,4), (2,4), (1,6),(2,6),(5,3),(4,1),(5,1),(5,7),(4,9), (5,9), (9,4), (9,6)]
        self.ligart_value = 128  # the value that identifies ligaments and arteries in the matrix
        self.peritoneum_value = 255  # the value of the peritoneum (cells in which the Agent can navigate)
        self.knife_value = 16  # the value assigned to the knife and to every cell it has visited
        self.path = []  # positions occupied by the knife before each step
        self.taken_actions = []  # names of the actions taken by the knife
        self.action_space = [0, 1, 2, 3]  # possible actions
        self.actions_dict = {0:'up', 1:'right', 2:'down', 3:'left'}  # action value -> action name
        self.n_actions = len(self.action_space)  # number of possible actions

    def gen_env(self):
        """Build a pristine field (no knife on it) and return a copy of it.

        The matrix is recreated from scratch on every call. Filling only the
        still-unassigned cells is not enough: peritoneum cells the knife has
        visited hold ``knife_value`` and would otherwise survive a reset.
        """
        self.surgery_mat = -np.ones((self.size + 1, self.size + 1), dtype=np.int32)

        # Frame: first/last row and first/last column.
        self.surgery_mat[0] = 0
        self.surgery_mat[-1] = 0
        self.surgery_mat[:, -1] = 0
        self.surgery_mat[:, 0] = 0

        # Object at the center.
        for row, col in self.object_points:
            self.surgery_mat[row, col] = self.object_value

        # Target area.
        for row, col in self.target_points:
            self.surgery_mat[row, col] = self.target_value

        # Ligaments and arteries.
        for row, col in self.ligart_points:
            self.surgery_mat[row, col] = self.ligart_value

        # Everything still unassigned is peritoneum.
        self.surgery_mat[self.surgery_mat == -1] = self.peritoneum_value

        return self.surgery_mat.copy()

    def plot_env(self):
        """Display the current matrix as an image."""
        plt.figure(figsize=(5,5))
        plt.imshow(self.surgery_mat)
        plt.show()

    def step_env(self, action):
        """Move the knife one cell and score the cell it lands on.

        Parameters
        ----------
        action : int
            0 = up, 1 = right, 2 = down, 3 = left.

        Returns
        -------
        tuple
            ``(state, reward, done, info)`` where ``state`` is a copy of the
            matrix after the move and ``info`` exposes the ``path`` and
            ``taken_actions`` lists kept by the environment.

        Raises
        ------
        KeyError
            If ``action`` is not a key of ``actions_dict``.
        """
        reward = 0
        done = False

        # The path records where the knife was *before* this move.
        self.path.append((self.knife_position_x, self.knife_position_y))

        # Apply the action.
        if action == 0:  # up
            self.knife_position_x = self.knife_position_x - 1

        if action == 1:  # right
            self.knife_position_y = self.knife_position_y + 1

        if action == 2:  # down
            self.knife_position_x = self.knife_position_x + 1

        if action == 3:  # left
            self.knife_position_y = self.knife_position_y - 1

        # Score the destination cell before it is overwritten with the knife marker.
        cell = self.surgery_mat[self.knife_position_x, self.knife_position_y]

        if cell == 0:  # frame
            reward = -5
            done = True

        if cell == self.object_value:  # uterus
            reward = -5
            done = True

        if cell == self.target_value:  # target
            reward = +5

        if cell == self.ligart_value:  # ligaments or arteries
            reward = -3

        if cell == self.peritoneum_value:  # peritoneum
            reward = -1

        if cell == self.knife_value:  # cells already visited
            reward = -3

        # Mark the new position of the knife.
        self.surgery_mat[self.knife_position_x, self.knife_position_y] = self.knife_value

        # The episode is also over once the whole target area has been cut.
        if not (self.surgery_mat == self.target_value).any():
            done = True

        self.taken_actions.append(self.actions_dict[action])  # keep trace of the current action

        info = {'path': self.path, 'taken_actions': self.taken_actions}
        return self.surgery_mat.copy(), reward, done, info

    def reset_env(self):
        """Restore the pristine field, clear the history and place the knife at its start cell."""
        # Generate the basic environment.
        self.surgery_mat = self.gen_env()

        # Clean the history of the knife.
        self.path = []
        self.taken_actions = []

        # Position the knife.
        self.knife_position_x = self.original_knife_coordinates[0]
        self.knife_position_y = self.original_knife_coordinates[1]
        self.surgery_mat[self.knife_position_x, self.knife_position_y] = self.knife_value

'''Replay function for training the Agent on past steps'''        
def replay(model, replay_memory, minibatch_size, env_size, gamma = 0.7):
    """Train ``model`` on a random minibatch of stored transitions.

    Parameters
    ----------
    model
        Object exposing Keras-style ``predict(inputs)`` and
        ``fit(inputs, targets, epochs=..., verbose=...)``.
    replay_memory : sequence of dict
        Stored steps with keys ``'s'``, ``'a'``, ``'r'``, ``'s_prime'`` and
        ``'done'``. Each state must hold ``env_size`` values.
    minibatch_size : int
        Number of transitions to sample (with replacement).
    env_size : int
        Number of values in a flattened state.
    gamma : float
        Discount factor applied to the best Q value of the next state.

    Returns
    -------
    The same ``model`` object, after fitting.

    Raises
    ------
    ValueError
        If ``replay_memory`` is empty.
    """
    if len(replay_memory) == 0:
        raise ValueError("replay_memory must contain at least one transition")

    # Sample positions instead of records, so the whole buffer does not have to
    # be converted into an object array on every call.
    picks = np.random.randint(len(replay_memory), size=minibatch_size)
    minibatch = [replay_memory[k] for k in picks]

    # Split states, actions, rewards etc. into arrays with one row per sampled step.
    inputs = np.array([step['s'] for step in minibatch], dtype=np.float64).reshape(minibatch_size, env_size)
    next_states = np.array([step['s_prime'] for step in minibatch], dtype=np.float64).reshape(minibatch_size, env_size)
    actions = np.array([step['a'] for step in minibatch], dtype=int)
    rewards = np.array([step['r'] for step in minibatch], dtype=np.float64)
    finished = np.array([step['done'] for step in minibatch], dtype=bool)

    # One prediction per batch rather than two per transition. The number of
    # actions is taken from the model output, so no notebook global is needed.
    targets = np.array(model.predict(inputs), dtype=np.float64).reshape(minibatch_size, -1)
    best_next_q = np.max(np.asarray(model.predict(next_states)).reshape(minibatch_size, -1), axis=1)

    # Only the Q value of the action actually taken is replaced: the reward for
    # terminal steps, the reward plus the discounted best next value otherwise.
    targets[np.arange(minibatch_size), actions] = np.where(
        finished, rewards, rewards + gamma * best_next_q
    )

    model.fit(inputs, targets, epochs=10, verbose=0) #training the model

    return model

'''Function for creating the animation of the Agent in the Environment'''
def animate_func(i):
    """Show frame ``i`` on the module-level image ``im`` and return it in a list."""
    current_frame = frames[i]
    im.set_array(current_frame)
    return [im]
        

## Variables

In [None]:
# Create the environment, fill its matrix (frame, uterus, target area,
# ligaments/arteries, peritoneum) and display it as an image.
surgery_env = SurgeryEnv()
surgery_env.gen_env() 
surgery_env.plot_env()

In [None]:
# Number of moves the Agent can choose from.
n_actions = surgery_env.n_actions
# Shape of the grid plus a trailing axis of length 1 (for the model definition).
input_shape = surgery_env.surgery_mat.shape + (1,)
print('Input shape:', input_shape)

## Model

In [None]:
'''This is the model used in the original paper'''
# Two convolutional layers followed by a dense head with one linear output per action.
model = Sequential(
    [
        Conv2D(16, kernel_size=2, input_shape=input_shape, activation='relu', name="conv2d_layer1"),
        Conv2D(32, kernel_size=3, activation='relu', name="conv2d_layer2"),
        Flatten(name="flatten_layer"),
        Dense(256, activation='relu', name="dense_layer1"),
        Dense(n_actions, activation='linear', name="dense_layer2"),
    ],
    name="sequential_layer",
)
model.compile(optimizer=RMSprop(), loss='MSE')

model.summary()

In [None]:
'''This is the model I used which is way simpler than the original one (as you can see by the number of parameters).'''
# Fully connected network on the flattened grid: two hidden layers as wide as
# the input, each followed by PReLU, then one output per action.
model = Sequential(
    [
        Dense(surgery_env.surgery_mat.size, input_shape=(surgery_env.surgery_mat.size,)),
        PReLU(),
        Dense(surgery_env.surgery_mat.size),
        PReLU(),
        Dense(n_actions),
    ],
    name="sequential_layer",
)
model.compile(optimizer='adam', loss='MSE')

model.summary()

## Training

In [None]:
# Number of training episodes.
n_episodes = 1500
# Number of stored steps sampled on each call of the Replay function.
minibatch_size = 32
# Upper bound on the number of steps kept in the replay memory.
mem_max_size = 100_000

In [None]:
%matplotlib qt

epsilon = 1.0 #probability of taking a random action instead of the greedy one
r_sums = [] #list to contain the sums of the rewards for each episode
replay_memory = [] #list to contain the steps used to retrain the model

for n in range(n_episodes):

    surgery_env.reset_env() 
    done = False
    r_sum = 0
    
    while not done:
        s = surgery_env.surgery_mat.reshape(1,-1).copy()
        
        if np.random.rand() < epsilon: #to enhance exploration
            a = rnd.choice(surgery_env.action_space) #choose a random action
        else: 
            qvals_s = model.predict(s)
            a = np.argmax(qvals_s)
        
        s_prime , r, done, _ = surgery_env.step_env(a)
        r_sum += r
        
        #if the buffer is full, pop the least recent step; '>=' keeps the buffer at
        #mem_max_size items at most once the new step is appended
        if len(replay_memory) >= mem_max_size:
            replay_memory.pop(0)
            
        replay_memory.append({"s":s,"a":a,"r":r,"s_prime":s_prime.reshape(1,-1),"done":done}) #store the step
        model= replay(model, replay_memory, minibatch_size = minibatch_size, env_size = surgery_env.surgery_mat.size) #train the model on past steps
        
    #decrease the epsilon value, clamping it so it never drops below the 0.001 floor
    epsilon = max(0.001, epsilon - 0.005)
    
    r_sums.append(r_sum)
    plt.plot(r_sums, color='red') #plot of the rewards
    
    if n >= 3: #a cubic trend line needs at least 4 points
        
        #same x positions that plt.plot(r_sums) uses: 0, 1, ..., len(r_sums)-1
        x = np.arange(len(r_sums))
        z = np.polyfit(x, r_sums, 3)
        p = np.poly1d(z)
        plt.plot(x, p(x), linewidth = 2.5, color='blue')
    
    
    plt.title(f'Episode {n+1},  \u03B5= {round(epsilon,3)},  Avg reward={round(mean(r_sums),3)}')
    
    if (n!=n_episodes - 1): #until the training is not finished
        
        plt.draw()
        plt.pause(1.0)
        plt.cla()
        
    else:
        plt.show()
        
model.save('./knifeAgent_'+str(n_episodes)+'Episodes.h5') #saving the model

## Testing

In [None]:
# Load the saved model so that testing does not require retraining.
trained_model = load_model('./knifeAgent_'+str(n_episodes)+'Episodes.h5')

# Fresh environment with the knife at its starting cell.
test_env = SurgeryEnv()
test_env.reset_env()

rsum = 0
frames = [test_env.surgery_mat.copy()] #first frame of the animation

while True:

    # Always take the action with the highest predicted Q value.
    s = test_env.surgery_mat.reshape(1,-1).copy()
    qvals = trained_model.predict(s)
    action = np.argmax(qvals)
    _ , reward, finished, info = test_env.step_env(action)
    print('Path:', info['path'])
    print('Actions', info['taken_actions'])
    print()
    rsum += reward

    frames.append(test_env.surgery_mat.copy()) #one frame per step

    if finished:
        break

print('Total reward:', rsum)


## Animation

In [None]:
%%capture

plt.rcParams["animation.html"] = "jshtml"

# Both values are simply the number of recorded frames.
fps = nSeconds = len(frames)

fig = plt.figure(figsize=(5,5))

# Draw the first frame; animate_func then swaps the image data frame by frame.
a = frames[0]
im = plt.imshow(a, interpolation='none', aspect='auto', vmin=0, vmax=255)
anim = FuncAnimation(fig, animate_func, frames=len(frames), interval=10000 / fps)  # interval in ms

gif_path = './knifeAgent_'+str(n_episodes)+'Episodes.gif'
anim.save(gif_path, writer='imagemagick', fps=fps // 2)