# Final Project - Reinforcements Learning 
Hello dear students,<br> this is the template notebook. Please click on the "File" tab and then on "Save a copy into drive".

---
<br>

### Name and ID:
Student 1: Avraham Raviv, 204355390
<br>
Student 2: Yevgeni Berkovitch, 317079234
<br><br>
<img src="https://play-lh.googleusercontent.com/e_oKlKPISbgdzut1H9opevS7-LTB8-8lsmpCdMkhlnqFenZhpjxbLmx7l158-xQQCIY">

### https://github.com/mpSchrader/gym-sokoban

# Installs

In [1]:
%%capture
!sudo apt-get update
!sudo apt-get install -y xvfb ffmpeg freeglut3-dev
!pip install 'imageio==2.4.0'
!pip install gym
!pip install pygame
!apt-get install python-opengl -y
!apt install xvfb -y
!pip install pyvirtualdisplay
!pip install piglet
!pip install gym
!apt-get install python-opengl -y
!apt install xvfb -y
!pip install gym_sokoban

!imageio_download_bin ffmpeg

# Imports

In [2]:
import random
import time

import numpy as np
import scipy as scp
import matplotlib.pyplot as plt

import base64
import imageio
from pyvirtualdisplay import Display
from IPython.display import HTML

import gym
from gym import error, spaces, utils
from soko_pap import *

from collections import deque
from queue import PriorityQueue

from keras.models import Sequential
from keras.layers import Dense

from tqdm.notebook import tqdm
from collections import defaultdict

In [3]:
%matplotlib inline

In [4]:
imageio.plugins.ffmpeg.download()

In [5]:
from gym import logger as gymlogger
gymlogger.set_level(40) # error only

# Display utils
The cell below contains the video display configuration. No need to make changes here.

In [6]:
def embed_mp4(filename):
    """Embeds an mp4 file in the notebook."""
    video = open(filename,'rb').read()
    b64 = base64.b64encode(video)
    tag = '''
    <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
    Your browser does not support the video tag.
    </video>'''.format(b64.decode())

    return HTML(tag)

# Utils

In [7]:
def get_distances(room_state):
    for i in range(room_state.shape[0]):
        for j in range(room_state.shape[1]):
            if room_state[i][j] == 2:
                target = (i, j)

    distances = np.zeros(shape=room_state.shape)
    visited_cells = set()
    cell_queue = deque()

    visited_cells.add(target)
    cell_queue.appendleft(target)

    while len(cell_queue) != 0:
        cell = cell_queue.pop()
        distance = distances[cell[0]][cell[1]]
        for x,y in ((1,0), (-1,-0), (0,1), (0,-1)):
            next_cell_x, next_cell_y = cell[0]+x, cell[1]+y
            if room_state[next_cell_x][next_cell_y] != 0 and not (next_cell_x, next_cell_y) in visited_cells:
                distances[next_cell_x][next_cell_y] = distance + 1
                visited_cells.add((next_cell_x, next_cell_y))
                cell_queue.appendleft((next_cell_x, next_cell_y))
                
    return distances   

def calc_distances(room_state, distances):
    box = None
    mover = None
    for i in range(room_state.shape[0]):
        for j in range(room_state.shape[1]):            
            if room_state[i][j] == 4:
                box = (i,j)
            
            if room_state[i][j] == 5:
                mover = (i,j)
    
    return mover, box, distances[box[0]][box[1]]   

def box2target_change_reward(room_state, next_room_state, distances):
    if np.array_equal(room_state, next_room_state):
        return -1.0
    
    mover, box, t2b = calc_distances(room_state, distances)
    n_mover, n_box, n_t2b = calc_distances(next_room_state, distances)
    
    change_reward = 0.0
    if n_t2b < t2b:
        change_reward += 5.0
    elif n_t2b > t2b:
        change_reward -= 5.0
        
    m2b = np.sqrt((mover[0]-box[0])**2 + (mover[1]-box[1])**2)
    n_m2b = np.sqrt((n_mover[0]-n_box[0])**2 + (n_mover[1]-n_box[1])**2)
    
    if n_m2b < m2b and m2b >= 2:
        change_reward += 1.0
    elif n_m2b > m2b and n_m2b >= 2:
        change_reward -= 1.0
        
    return change_reward   

# Solution

In [8]:
class SOK_Agent:
    def __init__(self):
        # Construct DQN models
        self.state_size = (25,) 
        self.action_size = 8
        self.model = self._build_model()
        self.target_model = self._build_model()
        self.target_model.set_weights(self.model.get_weights())
        self.batch_size = 8
        
        # Replay buffers
        self.replay_buffer = deque(maxlen=5000)
        self.prioritized_replay_buffer = deque(maxlen=500)
        self.prioritized_replay_batch = 50        
        
        # Hyperparameters
        self.gamma = 0.9
        self.epsilon = 1.0   
        self.epsilon_min = 0.3
        self.epsilon_decay = 0.995
        self.replay_rate = 10
        self.update_beta = 0.999

        self.verbosity = 100 

    def _build_model(self):
        model = Sequential()
        model.add(Dense(25, input_shape=self.state_size, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss='mse', optimizer="adam")        
        return model

    def remember(self, state, action, reward, next_state, done):
        self.replay_buffer.append([state, action, reward, next_state, done])    
        
    def copy_to_prioritized_buffer(self, n):
        for i in range(n):
            self.prioritized_replay_buffer.append(self.replay_buffer[-1-i])  

    def act(self, state, stochastic=False):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        
        act_values = self.model.predict(state, verbose=0)[0]
        
        if stochastic:
            act_probs = np.exp(act_values)/np.exp(act_values).sum()
            return np.random.choice(np.arange(self.action_size), size=1, p=act_probs)[0]
              
        return np.argmax(act_values) 

    def replay(self): 
        if len(self.replay_buffer) < self.batch_size:
            return
        
        if len(self.prioritized_replay_buffer) < self.batch_size//2:
            minibatch = random.sample(self.replay_buffer, self.batch_size) 
        else:    
            minibatch = random.sample(self.replay_buffer, self.batch_size//2) 
            minibatch.extend(random.sample(self.prioritized_replay_buffer, self.batch_size//2))
        
        states = np.zeros((self.batch_size, self.state_size[0]))
        actions = np.zeros(self.batch_size, dtype=int)
        rewards = np.zeros(self.batch_size)
        next_states = np.zeros((self.batch_size, self.state_size[0]))
        statuses = np.zeros(self.batch_size)
        targets = np.zeros((self.batch_size, self.action_size)) 
        
        for i, (state, action, reward, next_state, done) in enumerate(minibatch): 
            states[i] = state.copy()
            actions[i] = action
            rewards[i] = reward
            next_states[i] = next_state.copy()
            statuses[i] = 1 if done else 0    
        
        targets = self.model.predict(states) 
        max_actions = np.argmax(self.model.predict(next_states), axis=1)
        next_rewards = self.target_model.predict(next_states)
        
        ind = 0
        for action, reward, next_reward, max_action, done in zip(actions, rewards, next_rewards, max_actions, statuses):  
            if not done:
                reward += self.gamma * next_reward[max_action]
            targets[ind][action] = reward
            ind += 1
        
        self.model.fit(states, targets, epochs=10, verbose=0) 
        
        self.update_target_model()        
    
        if self.epsilon > self.epsilon_min:
            self.epsilon = self.epsilon * self.epsilon_decay
        
    def update_target_model(self):
        model_w = self.model.get_weights()
        target_model_w = self.target_model.get_weights()
        updated_target_model_w = []
        for i in range(len(model_w)):
            updated_target_model_w.append(self.update_beta*target_model_w[i] + (1-self.update_beta)*model_w[i])
        self.target_model.set_weights(updated_target_model_w)    
            
    def load(self, name):
        self.model.load_weights(name)

    def save(self, name):
        self.model.save_weights(name)

In [9]:
def process_frame(frame):
    f = frame[16:96, 16:96, 0]   
    f = f.reshape(5, 16, 5, 16).max(axis=(1, 3))
    f = f.flatten()
    f = f / 255
    return np.expand_dims(f, axis=0)

## Training

#### Test Suite

In [10]:
def test_agent(stochastic=False):
    current_epsilon = agent.epsilon
    agent.epsilon = 0.0
    num_solved = 0
    solved_in_steps = defaultdict(int)

    for t in tqdm(range(100)):    
        random.seed(t)
        sok = PushAndPullSokobanEnv(dim_room=(7, 7), num_boxes=1)
        sok.set_maxsteps(20)
        steps = 0

        state = sok.get_image('rgb_array')
        done = False
        while not done:
            steps += 1
            action = agent.act(process_frame(state), stochastic)
            if action < 4:
                action += 1
            else:
                action += 5
            state, reward, done, info = sok.step(action)

        if 3 in sok.room_state:            
            num_solved += 1
            solved_in_steps[steps] += 1
    
    agent.epsilon = current_epsilon
    print("*" * 30)
    print("Stochastic" if stochastic else "Deterministic")
    print("*" * 30)
    print("Solved: %d" % num_solved)
    print("=" * 30)
    print(solved_in_steps)
    print("*" * 30)

In [11]:
max_episodes = 5000
max_steps = 100

def init_sok(r):
    random.seed(r+100)
    sok = PushAndPullSokobanEnv(dim_room=(7, 7), num_boxes=1)
    sok.set_maxsteps(max_steps)
    return sok

In [12]:
agent = SOK_Agent()

steps_per_episode = []

for e in range(max_episodes):
    if e % 100 == 0:
        test_agent(stochastic=False)
        test_agent(stochastic=True)
        
    print("Episode: %d" % (e))
    
    sok = init_sok(e)
    random.seed(e)
    
    state = process_frame(sok.get_image('rgb_array'))
    room_state = sok.room_state.copy() 
    distances = get_distances(room_state)
    
    for step in range(sok.max_steps):
        action = agent.act(state)
        if action < 4:
            next_state, reward, done, _ = sok.step(action+1) 
        else:
            next_state, reward, done, _ = sok.step(action+5)         
        
        next_state = process_frame(next_state)        
        next_room_state = sok.room_state
        
        if not done:
            reward += box2target_change_reward(room_state, next_room_state, distances)
        
        agent.remember(state, action, reward, next_state, done)
        
        state = next_state.copy() 
        room_state = next_room_state.copy()                
        
        if (step+1) % agent.replay_rate == 0:
            agent.replay()            
        
        if done:   
            steps_per_episode.append(step+1)
            
            if 3 in sok.room_state:                
                print("SOLVED! Episode %d Steps: %d Epsilon %.4f" % (e, step+1, agent.epsilon))
            else:    
                agent.copy_to_prioritized_buffer(min(agent.prioritized_replay_batch, step+1))
            
            #agent.save("exp1_episode%d.h5" % (e))            
            break

HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Deterministic
******************************
Solved: 2
defaultdict(<class 'int'>, {1: 2})
******************************


HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Stochastic
******************************
Solved: 13
defaultdict(<class 'int'>, {15: 1, 10: 2, 7: 1, 6: 1, 2: 1, 17: 2, 9: 2, 4: 2, 3: 1})
******************************
Episode: 0
Episode: 1
Episode: 2
Episode: 3
Episode: 4
Episode: 5
SOLVED! Episode 5 Steps: 27 Epsilon 0.7705
Episode: 6
Episode: 7
SOLVED! Episode 7 Steps: 9 Epsilon 0.7329
Episode: 8
Episode: 9
Episode: 10
Episode: 11
Episode: 12
Episode: 13
SOLVED! Episode 13 Steps: 5 Epsilon 0.5704
Episode: 14
SOLVED! Episode 14 Steps: 78 Epsilon 0.5507
Episode: 15
Episode: 16
SOLVED! Episode 16 Steps: 43 Epsilon 0.5134
Episode: 17
Episode: 18
Episode: 19
Episode: 20
Episode: 21
Episode: 22
SOLVED! Episode 22 Steps: 25 Epsilon 0.3956
Episode: 23
Episode: 24
Episode: 25
Episode: 26
SOLVED! Episode 26 Steps: 6 Epsilon 0.3404
Episode: 27
Episode: 28
Episode: 29
Episode: 30
Episode: 31
Episode: 32
Episode: 33
Episode: 34
Episode: 35
Episode: 36
Episode: 37
SOLVED! Episode 37 Steps: 3 Epsilon 0.2988
Episod

HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Deterministic
******************************
Solved: 9
defaultdict(<class 'int'>, {1: 3, 2: 1, 3: 5})
******************************


HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Stochastic
******************************
Solved: 11
defaultdict(<class 'int'>, {11: 1, 17: 1, 10: 2, 7: 2, 4: 2, 13: 1, 3: 2})
******************************
Episode: 100
Episode: 101
Episode: 102
SOLVED! Episode 102 Steps: 3 Epsilon 0.2988
Episode: 103
SOLVED! Episode 103 Steps: 93 Epsilon 0.2988
Episode: 104
Episode: 105
Episode: 106
Episode: 107
Episode: 108
Episode: 109
Episode: 110
Episode: 111
Episode: 112
Episode: 113
Episode: 114
SOLVED! Episode 114 Steps: 1 Epsilon 0.2988
Episode: 115
Episode: 116
Episode: 117
Episode: 118
Episode: 119
Episode: 120
Episode: 121
Episode: 122
Episode: 123
Episode: 124
Episode: 125
Episode: 126
Episode: 127
Episode: 128
SOLVED! Episode 128 Steps: 2 Epsilon 0.2988
Episode: 129
SOLVED! Episode 129 Steps: 45 Epsilon 0.2988
Episode: 130
Episode: 131
Episode: 132
Episode: 133
SOLVED! Episode 133 Steps: 99 Epsilon 0.2988
Episode: 134
Episode: 135
Episode: 136
Episode: 137
Episode: 138
Episode: 139
Episode: 140
Episode: 

HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Deterministic
******************************
Solved: 14
defaultdict(<class 'int'>, {3: 7, 1: 5, 2: 2})
******************************


HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Stochastic
******************************
Solved: 11
defaultdict(<class 'int'>, {10: 2, 6: 1, 15: 1, 18: 1, 9: 2, 4: 2, 3: 2})
******************************
Episode: 200
SOLVED! Episode 200 Steps: 2 Epsilon 0.2988
Episode: 201
Episode: 202
Episode: 203
Episode: 204
Episode: 205
Episode: 206
Episode: 207
Episode: 208
Episode: 209
SOLVED! Episode 209 Steps: 2 Epsilon 0.2988
Episode: 210
Episode: 211
Episode: 212
Episode: 213
Episode: 214
SOLVED! Episode 214 Steps: 34 Epsilon 0.2988
Episode: 215
Episode: 216
SOLVED! Episode 216 Steps: 3 Epsilon 0.2988
Episode: 217
Episode: 218
Episode: 219
SOLVED! Episode 219 Steps: 21 Epsilon 0.2988
Episode: 220
Episode: 221
SOLVED! Episode 221 Steps: 6 Epsilon 0.2988
Episode: 222
Episode: 223
Episode: 224
Episode: 225
Episode: 226
Episode: 227
Episode: 228
Episode: 229
Episode: 230
Episode: 231
Episode: 232
Episode: 233
Episode: 234
SOLVED! Episode 234 Steps: 11 Epsilon 0.2988
Episode: 235
Episode: 236
Episode: 237
Episo

HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Deterministic
******************************
Solved: 25
defaultdict(<class 'int'>, {3: 14, 1: 7, 2: 4})
******************************


HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Stochastic
******************************
Solved: 24
defaultdict(<class 'int'>, {6: 3, 7: 3, 3: 2, 10: 3, 4: 6, 1: 2, 16: 1, 11: 2, 5: 2})
******************************
Episode: 300
SOLVED! Episode 300 Steps: 1 Epsilon 0.2988
Episode: 301
Episode: 302
SOLVED! Episode 302 Steps: 2 Epsilon 0.2988
Episode: 303
Episode: 304
Episode: 305
Episode: 306
Episode: 307
SOLVED! Episode 307 Steps: 32 Epsilon 0.2988
Episode: 308
Episode: 309
Episode: 310
SOLVED! Episode 310 Steps: 76 Epsilon 0.2988
Episode: 311
Episode: 312
Episode: 313
Episode: 314
Episode: 315
Episode: 316
Episode: 317
Episode: 318
SOLVED! Episode 318 Steps: 2 Epsilon 0.2988
Episode: 319
Episode: 320
Episode: 321
Episode: 322
SOLVED! Episode 322 Steps: 1 Epsilon 0.2988
Episode: 323
Episode: 324
SOLVED! Episode 324 Steps: 94 Epsilon 0.2988
Episode: 325
Episode: 326
Episode: 327
Episode: 328
Episode: 329
Episode: 330
Episode: 331
Episode: 332
Episode: 333
Episode: 334
SOLVED! Episode 334 Steps: 6 Eps

HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Deterministic
******************************
Solved: 17
defaultdict(<class 'int'>, {3: 6, 2: 4, 1: 7})
******************************


HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Stochastic
******************************
Solved: 12
defaultdict(<class 'int'>, {16: 1, 6: 3, 7: 3, 9: 3, 3: 2})
******************************
Episode: 400
Episode: 401
Episode: 402
Episode: 403
SOLVED! Episode 403 Steps: 3 Epsilon 0.2988
Episode: 404
Episode: 405
Episode: 406
Episode: 407
Episode: 408
Episode: 409
Episode: 410
Episode: 411
Episode: 412
Episode: 413
SOLVED! Episode 413 Steps: 85 Epsilon 0.2988
Episode: 414
Episode: 415
Episode: 416
Episode: 417
Episode: 418
Episode: 419
Episode: 420
Episode: 421
Episode: 422
Episode: 423
Episode: 424
SOLVED! Episode 424 Steps: 50 Epsilon 0.2988
Episode: 425
Episode: 426
Episode: 427
Episode: 428
Episode: 429
Episode: 430
Episode: 431
Episode: 432
Episode: 433
Episode: 434
Episode: 435
SOLVED! Episode 435 Steps: 53 Epsilon 0.2988
Episode: 436
Episode: 437
Episode: 438
SOLVED! Episode 438 Steps: 7 Epsilon 0.2988
Episode: 439
Episode: 440
Episode: 441
SOLVED! Episode 441 Steps: 52 Epsilon 0.2988
Episode: 4

HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Deterministic
******************************
Solved: 4
defaultdict(<class 'int'>, {1: 3, 3: 1})
******************************


HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Stochastic
******************************
Solved: 16
defaultdict(<class 'int'>, {9: 2, 11: 3, 3: 3, 10: 1, 20: 2, 6: 1, 4: 2, 8: 1, 7: 1})
******************************
Episode: 500
SOLVED! Episode 500 Steps: 21 Epsilon 0.2988
Episode: 501
Episode: 502
Episode: 503
SOLVED! Episode 503 Steps: 3 Epsilon 0.2988
Episode: 504
Episode: 505
Episode: 506
Episode: 507
Episode: 508
Episode: 509
Episode: 510
SOLVED! Episode 510 Steps: 1 Epsilon 0.2988
Episode: 511
Episode: 512
SOLVED! Episode 512 Steps: 6 Epsilon 0.2988
Episode: 513
Episode: 514
Episode: 515
Episode: 516
Episode: 517
Episode: 518
Episode: 519
Episode: 520
Episode: 521
Episode: 522
Episode: 523
SOLVED! Episode 523 Steps: 31 Epsilon 0.2988
Episode: 524
Episode: 525
Episode: 526
Episode: 527
Episode: 528
SOLVED! Episode 528 Steps: 8 Epsilon 0.2988
Episode: 529
Episode: 530
SOLVED! Episode 530 Steps: 18 Epsilon 0.2988
Episode: 531
Episode: 532
SOLVED! Episode 532 Steps: 45 Epsilon 0.2988
Episode: 533


HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Deterministic
******************************
Solved: 4
defaultdict(<class 'int'>, {1: 2, 3: 2})
******************************


HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Stochastic
******************************
Solved: 11
defaultdict(<class 'int'>, {9: 1, 10: 2, 18: 1, 17: 1, 11: 1, 14: 1, 6: 1, 4: 1, 3: 2})
******************************
Episode: 600
Episode: 601
Episode: 602
SOLVED! Episode 602 Steps: 45 Epsilon 0.2988
Episode: 603
Episode: 604
SOLVED! Episode 604 Steps: 47 Epsilon 0.2988
Episode: 605
Episode: 606
Episode: 607
Episode: 608
SOLVED! Episode 608 Steps: 18 Epsilon 0.2988
Episode: 609
Episode: 610
Episode: 611
SOLVED! Episode 611 Steps: 15 Epsilon 0.2988
Episode: 612
Episode: 613
SOLVED! Episode 613 Steps: 13 Epsilon 0.2988
Episode: 614
Episode: 615
Episode: 616
Episode: 617
Episode: 618
Episode: 619
Episode: 620
SOLVED! Episode 620 Steps: 3 Epsilon 0.2988
Episode: 621
Episode: 622
Episode: 623
Episode: 624
Episode: 625
SOLVED! Episode 625 Steps: 30 Epsilon 0.2988
Episode: 626
Episode: 627
SOLVED! Episode 627 Steps: 75 Epsilon 0.2988
Episode: 628
SOLVED! Episode 628 Steps: 85 Epsilon 0.2988
Episode: 629
Ep

HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Deterministic
******************************
Solved: 1
defaultdict(<class 'int'>, {1: 1})
******************************


HBox(children=(FloatProgress(value=0.0), HTML(value='')))


******************************
Stochastic
******************************
Solved: 7
defaultdict(<class 'int'>, {15: 1, 3: 2, 6: 3, 4: 1})
******************************
Episode: 700
Episode: 701
Episode: 702
SOLVED! Episode 702 Steps: 6 Epsilon 0.2988
Episode: 703


KeyboardInterrupt: 