In [1]:
import gym
from gym import spaces, error, utils
from gym.utils import seeding
import numpy as np
import copy
import random
%run 'solver.py'

In [2]:
N_TUBES = 5
N_COLORS = 4
H_TUBES = 4

DEFAULT_BOARD = [[3,2,1,3],[2,1,1,2],[1,2,3,4],[4,4,0,0],[3,4,0,0]]
arrCompleted = [0]*N_TUBES
Board = copy.deepcopy(DEFAULT_BOARD)
for arr in Board:
    for a in range(len(arr)-1,-1,-1):
        if(arr[a]==0):
            arr.pop(a)

In [3]:
States = {}
root = Node(None, Board,arrCompleted, N_COLORS, H_TUBES,N_TUBES, (-1, -1), 0, 0)
graph1 = Graph(root)
N_MOVES = graph1.getNeededMoves(root)
States = graph1.generateStates(root,H_TUBES)

In [4]:
class BallSortEnv(gym.Env):
    def __init__(self, alpha=0.02):
        self.action_space = {}
        count=0
        for i in range(0,N_TUBES):
            for j in range(0,N_TUBES):
                if (i != j):
                    self.action_space[count] = (i,j)
                    count +=1
        
        self.observation_space = States
        self.alpha = alpha
        self.reset()

    def render(self, mode='human'):
        print(self.board)
        
        
    def step(self, action):
        info = {}
        if(self.checkGameOver()):
            done = True
            reward = (N_MOVES*0.1) + 1.0
            return self.board, reward , done ,info
        else: 
            done = False
            reward = -0.1
        
        if(self.checkValidMove(action)):
            self.moveBall(action)   

        return self.board, reward , done ,info
    
    def reset(self):
        self.board = copy.deepcopy(DEFAULT_BOARD)
        return self.board
    
    def close(self):
        pass
    
    def checkGameOver(self):
        numColsCompleted = 0
        for col in self.board:
            if(col.count(col[0]) == len(col) and col[0] != 0):
                numColsCompleted += 1
        if (numColsCompleted == N_COLORS):
            return True
        else: return False
    
    def checkValidMove(self,move):
        fromCol,toCol = move
        fromIndex = -1
        toIndex = -1
        for number in range(len(self.board[fromCol])-1,-1,-1):
            if(self.board[fromCol][number] != 0):
                fromIndex = number
                break
                    
        for number in range(0,len(self.board[toCol])):
            if(self.board[toCol][number] == 0):
                toIndex = number  
                break
    
        if(fromIndex == -1 or toIndex == -1):
            return False
        
        if(toIndex != 0):
            if(self.board[fromCol][fromIndex] != self.board[toCol][toIndex-1]):
                return False
        
        
        return True
    
    def moveBall(self,move):
        fromCol,toCol = move
        fromIndex = -1
        toIndex = -1
        for number in range(len(self.board[fromCol])-1,-1,-1):
            if(self.board[fromCol][number] != 0):
                fromIndex = number
                break
                    
        for number in range(0,len(self.board[toCol])):
            if(self.board[toCol][number] == 0):
                toIndex = number  
                break
                
        ball = self.board[fromCol][fromIndex]
        self.board[fromCol][fromIndex] = 0
        self.board[toCol][toIndex] = ball
        
def get_key(val,my_dict):
    for key, value in my_dict.items():
         if val == value:
             return key
    return -1
    

In [5]:
env = BallSortEnv()
action_space_size = len(env.action_space)
state_space_size = len(env.observation_space)
q_table = np.zeros((state_space_size, action_space_size))

In [6]:
num_episodes = 1000
max_steps_per_episode = 200 # but it won't go higher than 1

learning_rate = 0.1
discount_rate = 0.99

exploration_rate = 1
max_exploration_rate = 1
min_exploration_rate = 0.01

exploration_decay_rate = 0.1 #if we decrease it, will learn slower

In [7]:
rewards_all_episodes = []

# Q-Learning algorithm
for episode in range(num_episodes):
    env.reset()
    
    state = 0
    
    done = False
    rewards_current_episode = 0
    
    for step in range(max_steps_per_episode):
        
        # Exploration -exploitation trade-off
        exploration_rate_threshold = random.uniform(0,1)
        if exploration_rate_threshold > exploration_rate:
            numAction = np.argmax(q_table[state,:])
            action = env.action_space[numAction]
        else:
            numAction = random.randint(0,len(env.action_space)-1)
            action = env.action_space[numAction]    
            
            
            
        new_state, reward, done, info = env.step(action)
        new_index = get_key(new_state,env.observation_space)
        
        q_table[state, numAction] = (1 - learning_rate) * q_table[state, numAction] + \
            learning_rate * (reward + discount_rate * np.max(q_table[new_index,:]))
        
        state = new_index
        rewards_current_episode += reward
        
        if done == True: 
            break
    exploration_rate = min_exploration_rate + \
        (max_exploration_rate - min_exploration_rate) * np.exp(-exploration_decay_rate * episode)
    
    rewards_all_episodes.append(rewards_current_episode)
    
# Calculate and print the average reward per 10 episodes
rewards_per_thousand_episodes = np.split(np.array(rewards_all_episodes), num_episodes / 100)
count = 100
print("********** Average  reward per thousand episodes **********\n")

for r in rewards_per_thousand_episodes:
    print(count, ": ", str(sum(r / 100)))
    count += 100
    
# Print updated Q-table
print("\n\n********** Q-table **********\n")
for table in range(0,len(q_table)):
    print(table," : ",q_table[table])

********** Average  reward per thousand episodes **********

100 :  -19.626999999999963
200 :  -18.347999999999978
300 :  -16.677999999999976
400 :  -14.182999999999986
500 :  -5.135000000000007
600 :  0.9640000000000006
700 :  0.9870000000000007
800 :  0.9880000000000007
900 :  0.9860000000000007
1000 :  0.9880000000000007


********** Q-table **********

0  :  [-0.84172048 -0.83703816 -0.83936119 -0.84019513 -0.84106769 -0.84101753
  3.67848363 -0.84193029 10.13231464 -0.84188605 -0.83536084 -0.84009199
 -0.83696854 -0.83553606 -0.84228876 -0.8380267  -0.83980582 -0.83588211
 -0.84047394 79.48138652]
1  :  [-0.7849874  -0.79259637 -0.78148105 -0.78442968 -0.77810026 -0.78433902
 -0.77843682 -0.78351441 -0.78674453 -0.77843682 -0.77843682 -0.78649102
 -0.77843682 -0.77843682 -0.77843682 -0.79082812 -0.77803487 -0.78535604
 -0.78600283 -0.71113897]
2  :  [-0.85032946 -0.85083225 -0.84229284 -0.84863083 -0.8410327  -0.84134466
 -0.84917212 -0.84279443 -0.84188605 -0.84162568 -0.84094612

472  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
473  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
474  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
475  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
476  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
477  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
478  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
479  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
480  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
481  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
482  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
483  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
484  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
485  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
486  :  [0. 0. 0. 0.

805  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
806  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
807  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
808  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
809  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
810  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
811  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
812  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
813  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
814  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
815  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
816  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
817  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
818  :  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
819  :  [0. 0. 0. 0.