In [1]:
import cv2
import argparse
import imutils
import numpy as np
import os
import random 
import _pickle as pickles
from math import sqrt

In [2]:
# This is the environment created for the RL agent to work on 
class colour_env:
    
    def __init__(self, fn, c):
        
        # the image to perform coloring in
        self.img = fn 
        #the part of the image to be colored in the form of a contour
        self.c = c
        #the initial color space of the object
        self.color = [0,0,0]
        #possible actions that the RL agent can perform, increase or decrease the value of blues, greens and reds
        self._actions = ['bp','bn','gp','gn','rp','rn']
        
    def color_img(self):
        
        #making color tuple 
        b = int(self.color[0])
        g = int(self.color[1])
        r = int(self.color[2])
        full_color = (b,g,r)
        # to color the shape passed
        cv2.drawContours(self.img, [self.c], -1, full_color, -1)

    def get_cur_state(self):
        
        #finding the current color of the contour
        M = cv2.moments(self.c)
        cX = int(M["m10"] / M["m00"])
        cY = int(M["m01"] / M["m00"])
        #getting the color in the pixel placed in the centre of the contor
        self.color = self.img[cX, cY]
        return self.color
    
    def init(self):
        
        #returning the current color
        return self.color
    
    def do_action(self,action):
        
        #Assigning the color value for all the actions
        if action == 'bp':
            self.color[0] += 65
            self.color[1] -= 65
            self.color[2] -= 65
        elif action == 'bn':
            self.color[0] -= 65
            self.color[1] += 65
            self.color[2] += 65
        elif action == 'gp':
            self.color[1] += 65
            self.color[0] -= 65
            self.color[2] -= 65
        elif action == 'gn':
            self.color[1] -= 65
            self.color[0] += 65
            self.color[2] += 65
        elif action == 'rp':
            self.color[2] += 65
            self.color[0] -= 65
            self.color[1] -= 65
        elif action == 'rn':
            self.color[2] -=65
            self.color[1] += 65
            self.color[0] += 65
            
    def check_state(self):
        
        #checking if the color values are within the range of [0,255]
        if self.color[0] > 255:
            self.color[0] = 255
        if self.color[1] > 255:
            self.color[1] = 255
        if self.color[2] > 255:
            self.color[2] = 255
        if self.color[0] < 0:
            self.color[0] = 0
        if self.color[1] < 0:
            self.color[1] = 0
        if self.color[2] < 0:
            self.color[2] = 0
            
        
    def reward(self, a, shape):
        
        # color the image with the present color
        self.color_img()
        #checking for the shape to be colored 
        # if the detected shape is triangle is Blue
        if shape == 0:
            if np.argmax(self.color) == 0:
                if self.color[0] >= 200 and self.color[1] <= 100 and self.color[2] <= 100:
                    return 30
                return 10
            elif self.color[0] == 0:
                return -30
            else:
                return -10
        
        #if the detected shape is rectangle is Green
        if shape == 1:
            if np.argmax(self.color) == 1:
                if self.color[0] <= 100 and self.color[1] >= 200 and self.color[2] <= 100:
                    return 30
                return 10
            elif self.color[1] == 0:
                return -30
            else:
                return -10
        
        #if the detected shape is circle is Red
        if shape == 2:
            if np.argmax(self.color) == 2:
                if self.color[0] <= 100 and self.color[1] <= 100 and self.color[2] >= 200:
                    return 30
                return 10
            elif self.color[2] == 0:
                return -30
            else:
                return -10
        
        
    def is_goal(self, r):
        
        #defining the goal state
        if r == 30:
            return True
        else:
            return False
        
    def get_actions(self):
        
        #getting an action to be performed
        return self._actions

In [3]:
# The RL agent to color the objects in the image
class RLAgent: 
    def __init__(self, env, shape):
        
        #the shape of the incoming object
        self.shape = shape
        #environment for the RL agent
        self.env = env
        #number of possible actions
        self.n_a = len(env.get_actions())
        #definition of the Q table in the form of a dictionary
        self.Q = {}
    
    #checking if the state is between the range of [0,255]
    def check_s(self,s):
        
        if s[0] > 255:
            s[0] = 255
        if s[1] > 255:
            s[1] = 255
        if s[2] > 255:
            s[2] = 255
        if s[0] < 0:
            s[0] = 0
        if s[1] < 0:
            s[1] = 0
        if s[2] < 0:
            s[2] = 0
        return s
    
    #To read the Q table from the respective Q-table file 
    def reading_shape(self):
        #table found for triangle
        if self.shape == 0:
            print("file0")
            with open('file0.txt','rb') as handle:
                self.Q = pickles.loads(handle.read())
        #table found for rectangle
        if self.shape == 1:
            print("file1")
            with open('file1.txt','rb') as handle:
                self.Q = pickles.loads(handle.read())
        #table found for circle
        if self.shape == 2:
            print("file2")
            with open('file2.txt','rb') as handle:
                self.Q = pickles.loads(handle.read())
    
    #To create a file that contains the Q-table for the detected shape
    def writing_shape(self):
        
        #to write Q-table of triangle
        if self.shape == 0:
            with open ('file0.txt','wb') as handle:
                pickles.dump(self.Q,handle)
        #to write Q-table of rectangle
        if self.shape == 1:
            with open ('file1.txt','wb') as handle:
                pickles.dump(self.Q,handle)
        #to write Q-table of circle
        if self.shape == 2:
            with open ('file2.txt','wb') as handle:
                pickles.dump(self.Q,handle)
    
    #function to pick the action 
    def epsilon_greed(self, epsilon, s):
        
        #checks the validity of s
        s = self.check_s(s)
        #converts it into a tuple
        s = tuple(s)
        #if the state not present then create an entry
        if s not in self.Q:
            self.Q[s] = dict((action, 0) for action in env.get_actions())
        #generated random action 
        if np.random.rand() < epsilon:
            return random.choice(env.get_actions())
        else:
            #greedy possible action
            Q_maximum = max(self.Q[s].values())
            possible_actions = []
            for i in list(env.get_actions()):
                if(self.Q[s][i] == Q_maximum):
                    possible_actions.append(i)
            return random.choice(possible_actions)  

    def train(self, **params):
        
        # parameters
        gamma = params.pop('gamma', 0.99)
        alpha = params.pop('alpha', 0.1)
        epsilon= params.pop('epsilon', 0.1)
        maxiter= params.pop('maxiter', 100) 
        maxstep= params.pop('maxstep', 1000) 
        #training the agent multiple time
        for j in range(maxiter):
            if j == 0:
                #for the first training the colors are initialised to just white
                s = self.env.init()
            else:
                #this is for other iterations
                s = self.env.get_cur_state()
            a = self.epsilon_greed(epsilon, s)
            s = tuple(s)
            rewards = []
            
            for step in range(maxstep):
                s1 = self.env.get_cur_state()
                a1 = self.epsilon_greed(epsilon,s1)
                s1 = tuple(s1)
                self.env.do_action(a1)
                
                reward = self.env.reward(a,self.shape)
                rewards.append(reward)
                #updating the Q table
                self.Q[s][a] = self.Q[s][a] + alpha * (reward + (gamma*max(self.Q[s1].values()))- self.Q[s][a])
                #checking for the final state
                if self.env.is_goal(reward):
                    self.Q[s1][a1] = 0
                    break
                #re-assigning the state and the action    
                s = s1
                a = a1
        #copying the q table on the file
        self.writing_shape()
        #returns the colored image and the rewards collected
        return self.env.img ,reward 
    
    def test(self,c ,shape, image, maxstep=1000):
        #parameters 
        epsilon = 0.1
        self.shape = shape
        self.env.img = image
        self.env.c = c
        #read the respective Q table from the file of the detected shape
        self.reading_shape()
        #initialise the color space
        s = self.env.init()
        #pick an action 
        a = self.epsilon_greed(epsilon, s)
        s = tuple(s)
        rewards = []
        #coloring and finding the apt color for the detected object
        for step in range(maxstep):
                s1 = self.env.get_cur_state()
                a1 = self.epsilon_greed(epsilon,s1)
                s1 = tuple(s1)
                self.env.do_action(a1)
                #receive the rewards
                reward = self.env.reward(a,self.shape)
                rewards.append(reward)
                #check for goal state 
                if self.env.is_goal(reward):
                    self.Q[s1][a1] = 0
                    break
                #re-setting the state and actions   
                s = s1
                a = a1
        #returns the colored image and rewards
        return self.env.img, rewards

In [4]:
#class for detecting the shape
class ShapeDetector:
    def __init__(self):
        #for maintaing a list with appropriate edges detected
        self.approx_copy = []
    
    #to detect anomalies in edges
    def dist(self,a,b):
        x1,y1,x2,y2 = a[0],a[1],b[0],b[1]
        dist = sqrt( (x2 - x1)**2 + (y2 - y1)**2 )
        #discard an edge with a distance less than 5
        if dist < 5:
            self.approx_copy = self.approx_copy.tolist()
            b = b.tolist()
            if b in self.approx_copy:
                self.approx_copy.remove(b)
            self.approx_copy = np.array(self.approx_copy)
                
    def detect(self, c):
    # initialize the shape name and approximate the contour
        shape = "unidentified"
        peri = cv2.arcLength(c, True)
        approx = cv2.approxPolyDP(c, 0.01 * peri, True)
        approx = approx.reshape(-1,2)
        self.approx_copy = np.copy(approx)
        #procedure to eliminate irrelevant edges
        for i in range(len(approx)-1):
            x= i+1
            for j in range(x, len(approx)):
                self.dist(approx[i],approx[j])
        
        #3 vertices detected for a triangle
        if len(self.approx_copy) == 3:
            shape = 0
            print("triangle")
 
        # 4 vertices detected for a rectangle
        elif len(approx) == 4:
            shape = 1
            print("rectangle")
        
        #if its not either of the above then its a circle
        else:
            shape = 2
            print("circle")
 
        # return the name of the shape
        return shape


In [10]:
# load the image and resize it to a smaller factor so that
# the shapes can be approximated better

type_input = input("Enter 1 for overlapping and 2 for non-overlapping:")

if type_input == "1":
    image = cv2.imread('image-2.png')
    resized = imutils.resize(image, width=300)
    ratio = image.shape[0] / float(resized.shape[0])
    gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)[1]

    element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(37,37))
    thresh = cv2.erode(thresh,element)
    
    cv2.imwrite('temp.png',thresh)
    image = cv2.imread('temp.png')
    resized = imutils.resize(image, width=300)
    ratio = image.shape[0] / float(resized.shape[0])

else:
    image = cv2.imread('image-2.png')
    resized = imutils.resize(image, width=300)
    ratio = image.shape[0] / float(resized.shape[0])
    gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    thresh = cv2.threshold(blurred, 60, 255, cv2.THRESH_BINARY)[1]

# find contours in the thresholded image and initialize the
# shape detector
cnts = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
cnts = cnts[0] if imutils.is_cv2() else cnts[1]
sd = ShapeDetector()
resized_copy = np.copy(resized)
x=0

# loop over the contours
for c in cnts:
    # compute the center of the contour, then detect the name of the
    # shape using only the contour
    shape = sd.detect(c)
    env = colour_env(resized_copy, c)
    agent = RLAgent(env, shape)
    agent.train(gamma=0.99, 
                    alpha=0.1, 
                    epsilon=0.1, 
                    maxiter=100, 
                    maxstep=100)
    
 #show the output image
for d in cnts:
    shape = sd.detect(d)
    cv2.imwrite('messi'+str(x)+'.png',resized)
    resized, rewards = agent.test(d, shape, resized)
    cv2.imwrite('messigray'+str(x)+'.png',resized)
    x=x+1

Enter 1 for overlapping and 2 for non-overlapping:1
circle
triangle
circle
file2
triangle
file0
