# Final Project - Reinforcements Learning 
Hello dear students,<br> this is the template notebook. Please click on the "File" tab and then on "Save a copy into drive".

---
<br>

### Name and ID:
Student 1: Avraham Raviv, 204355390
<br>
Student 2: Yevgeni Berkovitch, 317079234
<br><br>
<img src="https://play-lh.googleusercontent.com/e_oKlKPISbgdzut1H9opevS7-LTB8-8lsmpCdMkhlnqFenZhpjxbLmx7l158-xQQCIY">

### https://github.com/mpSchrader/gym-sokoban

# Installs

In [1]:
%%capture
!sudo apt-get update
!sudo apt-get install -y xvfb ffmpeg freeglut3-dev
!pip install 'imageio==2.4.0'
!pip install gym
!pip install pygame
!apt-get install python-opengl -y
!apt install xvfb -y
!pip install pyvirtualdisplay
!pip install piglet
!pip install gym
!apt-get install python-opengl -y
!apt install xvfb -y
!pip install gym_sokoban

!imageio_download_bin ffmpeg

# Imports

In [2]:
import random

import numpy as np
import matplotlib.pyplot as plt

import base64
import imageio
from pyvirtualdisplay import Display
from IPython.display import HTML

import gym
from gym import error, spaces, utils
from soko_pap import *

from collections import deque

from keras.models import Sequential
from keras.layers import Dense, Conv2D, Flatten

from tqdm.notebook import tqdm
from collections import defaultdict

In [3]:
%matplotlib inline

In [4]:
imageio.plugins.ffmpeg.download()

In [5]:
from gym import logger as gymlogger
gymlogger.set_level(40) # error only

In [6]:
import warnings
warnings.filterwarnings('ignore')

# Display utils
The cell below contains the video display configuration. No need to make changes here.

In [7]:
def embed_mp4(filename):
    """Embeds an mp4 file in the notebook."""
    video = open(filename,'rb').read()
    b64 = base64.b64encode(video)
    tag = '''
    <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
    Your browser does not support the video tag.
    </video>'''.format(b64.decode())

    return HTML(tag)

# Utils

In [8]:
def get_distances_for_target(room_state, target):
    distances = np.zeros(shape=room_state.shape)
    visited_cells = set()
    cell_queue = deque()

    visited_cells.add(target)
    cell_queue.appendleft(target)

    while len(cell_queue) != 0:
        cell = cell_queue.pop()
        distance = distances[cell[0]][cell[1]]
        for x,y in ((1,0), (-1,-0), (0,1), (0,-1)):
            next_cell_x, next_cell_y = cell[0]+x, cell[1]+y
            if room_state[next_cell_x][next_cell_y] != 0 and not (next_cell_x, next_cell_y) in visited_cells:
                distances[next_cell_x][next_cell_y] = distance + 1
                visited_cells.add((next_cell_x, next_cell_y))
                cell_queue.appendleft((next_cell_x, next_cell_y))
                
    return distances

def get_distances(room_state):
    targets = []
    for i in range(room_state.shape[0]):
        for j in range(room_state.shape[1]):
            if room_state[i][j] in (2, 3):
                targets.append((i, j))

    distances1 = get_distances_for_target(room_state, targets[0])
    distances2 = get_distances_for_target(room_state, targets[1])
    return np.minimum(distances1, distances2)

def calc_distances(room_state, distances):
    boxes = []
    for i in range(room_state.shape[0]):
        for j in range(room_state.shape[1]):            
            if room_state[i][j] in (3,4):
                boxes.append((i,j))
    
    return distances[boxes[0][0]][boxes[0][1]] + distances[boxes[1][0]][boxes[1][1]]   

def box2target_change_reward(room_state, next_room_state, distances):
    if np.array_equal(room_state, next_room_state):
        return -5.0
    
    t2b = calc_distances(room_state, distances)
    n_t2b = calc_distances(next_room_state, distances)
    
    change_reward = 0.0
    if n_t2b < t2b:
        change_reward += 5.0
    elif n_t2b > t2b:
        change_reward -= 5.0
        
    return change_reward   

# Solution

In [9]:
class SOK_Agent:
    def __init__(self):
        # Construct DQN models
        self.state_size = (112,112,1) 
        self.action_size = 8
        self.model = self._build_model()
        self.target_model = self._build_model()
        self.target_model.set_weights(self.model.get_weights())
        self.batch_size = 8
        
        # Replay buffers
        self.replay_buffer = deque(maxlen=50000)
        self.prioritized_replay_buffer = deque(maxlen=50000)
        
        # Hyperparameters
        self.gamma = 0.9
        self.epsilon = 1.0   
        self.epsilon_min = 0.3
        self.epsilon_decay = 0.995
        self.replay_rate = 10
        self.update_beta = 0.999
        
        self.action_rotation_map = {
            0: 2,
            1: 3,
            2: 1,
            3: 0,
            4: 6,
            5: 7,
            6: 5,
            7: 4
        }

    def _build_model(self):
        model = Sequential()
        model.add(Conv2D(32, (16,16), strides=(16,16), input_shape=self.state_size, activation='relu'))
        model.add(Conv2D(64, (3,3), activation='relu'))
        model.add(Conv2D(64, (3,3), padding='same', activation='relu'))
        model.add(Flatten())
        model.add(Dense(512, activation='relu'))    
        model.add(Dense(64, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss='mse', optimizer="adam")        
        return model

    def remember(self, state, action, reward, next_state, done):
        self.replay_buffer.append([state, action, reward, next_state, done])    
        
    def copy_to_prioritized_buffer(self, n):
        for i in range(n):
            self.prioritized_replay_buffer.append(self.replay_buffer[-1-i])  

    def act(self, state, stochastic=False):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        
        act_values = self.model.predict(state, verbose=0)[0]
        
        if stochastic:
            act_probs = np.exp(act_values)/np.exp(act_values).sum()
            return np.random.choice(np.arange(self.action_size), size=1, p=act_probs)[0]
              
        return np.argmax(act_values) 

    def replay(self): 
        if len(self.replay_buffer) < self.batch_size:
            return
        
        if len(self.prioritized_replay_buffer) < self.batch_size//2:
            minibatch = random.sample(self.replay_buffer, self.batch_size) 
        else:    
            minibatch = random.sample(self.replay_buffer, self.batch_size//2) 
            minibatch.extend(random.sample(self.prioritized_replay_buffer, self.batch_size//2))
        
        states = np.zeros((self.batch_size*4, self.state_size[0], self.state_size[1]))
        actions = np.zeros(self.batch_size*4, dtype=int)
        rewards = np.zeros(self.batch_size*4)
        next_states = np.zeros((self.batch_size*4, self.state_size[0], self.state_size[1]))
        statuses = np.zeros(self.batch_size*4)
        targets = np.zeros((self.batch_size*4, self.action_size)) 
        
        for i, (state, action, reward, next_state, done) in enumerate(minibatch): 
            for rot in range(4):  
                ind = i*4+rot
                if rot != 0:
                    state = np.rot90(state, axes=(1,2))
                    next_state = np.rot90(next_state, axes=(1,2))
                    action = self.action_rotation_map.get(action)

                states[ind] = state.copy()
                actions[ind] = action
                rewards[ind] = reward
                next_states[ind] = next_state.copy()
                statuses[ind] = 1 if done else 0          
        
        targets = self.model.predict(states) 
        max_actions = np.argmax(self.model.predict(next_states), axis=1)
        next_rewards = self.target_model.predict(next_states)
        
        ind = 0
        for action, reward, next_reward, max_action, done in zip(actions, rewards, next_rewards, max_actions, statuses):  
            if not done:
                reward += self.gamma * next_reward[max_action]
            targets[ind][action] = reward
            ind += 1
        
        self.model.fit(states, targets, epochs=10, verbose=0) 
        
        self.update_target_model()        
    
        if self.epsilon > self.epsilon_min:
            self.epsilon = self.epsilon * self.epsilon_decay    
        
    def update_target_model(self):
        model_w = self.model.get_weights()
        target_model_w = self.target_model.get_weights()
        updated_target_model_w = []
        for i in range(len(model_w)):
            updated_target_model_w.append(self.update_beta*target_model_w[i] + (1-self.update_beta)*model_w[i])
        self.target_model.set_weights(updated_target_model_w)    
            
    def load(self, name):
        self.model.load_weights(name)

    def save(self, name):
        self.model.save_weights(name)

In [10]:
def process_frame(frame):
    f = frame.mean(axis=2)
    f = f / 255
    return np.expand_dims(f, axis=0)

## Training

#### Test Suite

In [11]:
max_episodes = 50000
max_steps = 30

def init_sok(r):
    random.seed(r)
    sok = PushAndPullSokobanEnv(dim_room=(7, 7), num_boxes=2)
    sok.set_maxsteps(max_steps)
    return sok

In [12]:
def test_agent(e, cur_record, stochastic=False):
    current_epsilon = agent.epsilon
    agent.epsilon = 0.0
    num_solved = 0
    solved_in_steps = defaultdict(int)

    for t in tqdm(range(100)):    
        sok = init_sok(t)
        steps = 0

        state = sok.get_image('rgb_array')
        done = False
        while not done:
            steps += 1
            action = agent.act(process_frame(state), stochastic)
            if action < 4:
                action += 1
            else:
                action += 5
            state, reward, done, info = sok.step(action)

        if 3 in sok.room_state:            
            num_solved += 1
            solved_in_steps[steps] += 1
    
    agent.epsilon = current_epsilon    
    print("Episode %d Solved: %d" % (e+1, num_solved))
    
    if num_solved > cur_record:
        agent.save("models\Q3_02A_%d.h5" % (num_solved)) 
        cur_record = num_solved
        
    return num_solved, cur_record  

In [None]:
agent = SOK_Agent()

running_puzzles = 0
running_solved = 0
record = 0
solved_tests = []

for e in range(max_episodes):
    sok = init_sok(e+100)
    random.seed(e)
    running_puzzles += 1
    
    state = process_frame(sok.get_image('rgb_array'))
    room_state = sok.room_state.copy() 
    distances = get_distances(room_state)
    
    for step in range(sok.max_steps):
        action = agent.act(state)
        if action < 4:
            next_state, reward, done, _ = sok.step(action+1) 
        else:
            next_state, reward, done, _ = sok.step(action+5)         
        
        next_state = process_frame(next_state)        
        next_room_state = sok.room_state
        
        if not done:
            reward += box2target_change_reward(room_state, next_room_state, distances)
        
        agent.remember(state, action, reward, next_state, done)
        
        state = next_state.copy() 
        room_state = next_room_state.copy()                
        
        if (step+1) % agent.replay_rate == 0:
            agent.replay()            
        
        if done: 
            if sok.boxes_on_target == 2:  
                agent.copy_to_prioritized_buffer(step+1)  
                running_solved += 1
                
            if (e+1) % 10 == 0 and e > 0:
                print(f"{running_solved} | {running_puzzles}") 

                if (e+1) % 100 == 0:
                    running_puzzles = 0
                    running_solved = 0
                    
            break
            
    if (e+1) % 100 == 0 and e > 0:
        num_solved, record = test_agent(e, record, stochastic=False) 
        solved_tests.append(num_solved)

0 | 10
2 | 20
2 | 30
3 | 40
3 | 50
3 | 60
5 | 70
6 | 80
6 | 90
7 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 100 Solved: 29
2 | 10
2 | 20
2 | 30
2 | 40
3 | 50
5 | 60
5 | 70
5 | 80
6 | 90
[SOKOBAN] Retry . . .
7 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 200 Solved: 32
2 | 10
4 | 20
4 | 30
4 | 40
5 | 50
[SOKOBAN] Retry . . .
5 | 60
[SOKOBAN] Retry . . .
6 | 70
6 | 80
6 | 90
6 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 300 Solved: 30
1 | 10
3 | 20
4 | 30
8 | 40
9 | 50
10 | 60
11 | 70
12 | 80
13 | 90
14 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 400 Solved: 47
1 | 10
2 | 20
3 | 30
4 | 40
5 | 50
[SOKOBAN] Retry . . .
8 | 60
[SOKOBAN] Retry . . .
9 | 70
10 | 80
10 | 90
12 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 500 Solved: 53
2 | 10
[SOKOBAN] Retry . . .
4 | 20
8 | 30
8 | 40
[SOKOBAN] Retry . . .
9 | 50
10 | 60
13 | 70
[SOKOBAN] Retry . . .
13 | 80
14 | 90
16 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 600 Solved: 54
3 | 10
7 | 20
8 | 30
11 | 40
14 | 50
[SOKOBAN] Retry . . .
18 | 60
20 | 70
22 | 80
22 | 90
24 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 700 Solved: 46
1 | 10
[SOKOBAN] Retry . . .
2 | 20
4 | 30
5 | 40
8 | 50
11 | 60
16 | 70
19 | 80
[SOKOBAN] Retry . . .
20 | 90
23 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 800 Solved: 54
3 | 10
3 | 20
[SOKOBAN] Retry . . .
6 | 30
9 | 40
11 | 50
15 | 60
15 | 70
19 | 80
22 | 90
24 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 900 Solved: 55
2 | 10
[SOKOBAN] Retry . . .
4 | 20
6 | 30
9 | 40
10 | 50
12 | 60
16 | 70
17 | 80
22 | 90
25 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 1000 Solved: 53
1 | 10
[SOKOBAN] Retry . . .
[SOKOBAN] Retry . . .
[SOKOBAN] Retry . . .
5 | 20
8 | 30
11 | 40
14 | 50
17 | 60
19 | 70
21 | 80
23 | 90
24 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 1100 Solved: 55
2 | 10
5 | 20
7 | 30
9 | 40
13 | 50
15 | 60
16 | 70
19 | 80
20 | 90
22 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 1200 Solved: 64
0 | 10
5 | 20
7 | 30
9 | 40
[SOKOBAN] Retry . . .
9 | 50
10 | 60
11 | 70
13 | 80
15 | 90
[SOKOBAN] Retry . . .
19 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 1300 Solved: 56
3 | 10
4 | 20
[SOKOBAN] Retry . . .
7 | 30
[SOKOBAN] Retry . . .
11 | 40
16 | 50
17 | 60
19 | 70
22 | 80
24 | 90
25 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 1400 Solved: 61
3 | 10
7 | 20
10 | 30
[SOKOBAN] Retry . . .
12 | 40
16 | 50
18 | 60
19 | 70
22 | 80
[SOKOBAN] Retry . . .
24 | 90
26 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 1500 Solved: 48
3 | 10
4 | 20
7 | 30
[SOKOBAN] Retry . . .
8 | 40
10 | 50
[SOKOBAN] Retry . . .
16 | 60
20 | 70
22 | 80
23 | 90
24 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 1600 Solved: 52
4 | 10
6 | 20
10 | 30
13 | 40
[SOKOBAN] Retry . . .
17 | 50
19 | 60
20 | 70
22 | 80
24 | 90
[SOKOBAN] Retry . . .
27 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 1700 Solved: 61
[SOKOBAN] Retry . . .
2 | 10
4 | 20
6 | 30
9 | 40
12 | 50
[SOKOBAN] Retry . . .
14 | 60
18 | 70
[SOKOBAN] Retry . . .
19 | 80
[SOKOBAN] Retry . . .
22 | 90
26 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 1800 Solved: 56
2 | 10
[SOKOBAN] Retry . . .
5 | 20
[SOKOBAN] Retry . . .
8 | 30
[SOKOBAN] Retry . . .
12 | 40
[SOKOBAN] Retry . . .
14 | 50
18 | 60
20 | 70
[SOKOBAN] Retry . . .
24 | 80
26 | 90
30 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 1900 Solved: 61
3 | 10
7 | 20
[SOKOBAN] Retry . . .
10 | 30
[SOKOBAN] Retry . . .
15 | 40
20 | 50
23 | 60
24 | 70
[SOKOBAN] Retry . . .
27 | 80
[SOKOBAN] Retry . . .
30 | 90
32 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 2000 Solved: 53
4 | 10
[SOKOBAN] Retry . . .
5 | 20
[SOKOBAN] Retry . . .
9 | 30
11 | 40
16 | 50
20 | 60
23 | 70
25 | 80
28 | 90
28 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 2100 Solved: 59
1 | 10
5 | 20
[SOKOBAN] Retry . . .
8 | 30
14 | 40
19 | 50
21 | 60
21 | 70
26 | 80
28 | 90
29 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 2200 Solved: 60
1 | 10
6 | 20
8 | 30
[SOKOBAN] Retry . . .
8 | 40
12 | 50
17 | 60
21 | 70
25 | 80
26 | 90
28 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 2300 Solved: 63
3 | 10
4 | 20
6 | 30
10 | 40
12 | 50
[SOKOBAN] Retry . . .
13 | 60
[SOKOBAN] Retry . . .
14 | 70
17 | 80
20 | 90
26 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 2400 Solved: 60
3 | 10
7 | 20
9 | 30
12 | 40
13 | 50
16 | 60
18 | 70
22 | 80
27 | 90
[SOKOBAN] Retry . . .
31 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 2500 Solved: 56
5 | 10
9 | 20
[SOKOBAN] Retry . . .
14 | 30
16 | 40
21 | 50
24 | 60
26 | 70
29 | 80
32 | 90
33 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 2600 Solved: 62
5 | 10
8 | 20
10 | 30
14 | 40
16 | 50
19 | 60
21 | 70
22 | 80
27 | 90
30 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 2700 Solved: 65
1 | 10
3 | 20
8 | 30
11 | 40
12 | 50
[SOKOBAN] Retry . . .
15 | 60
21 | 70
25 | 80
27 | 90
[SOKOBAN] Retry . . .
32 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 2800 Solved: 63
4 | 10
7 | 20
9 | 30
[SOKOBAN] Retry . . .
11 | 40
12 | 50
15 | 60
18 | 70
20 | 80
24 | 90
28 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 2900 Solved: 55
2 | 10
5 | 20
9 | 30
12 | 40
15 | 50
18 | 60
23 | 70
26 | 80
28 | 90
31 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 3000 Solved: 59
6 | 10
10 | 20
14 | 30
[SOKOBAN] Retry . . .
18 | 40
22 | 50
25 | 60
31 | 70
34 | 80
[SOKOBAN] Retry . . .
38 | 90
42 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .

Episode 3100 Solved: 57
3 | 10
7 | 20
12 | 30
15 | 40
19 | 50
21 | 60
26 | 70
29 | 80
34 | 90
[SOKOBAN] Retry . . .
37 | 100


HBox(children=(FloatProgress(value=0.0), HTML(value='')))

[SOKOBAN] Retry . . .
