In [None]:
import gym_sokoban.envs
from gym_sokoban.envs import SokobanEnv
from gym.spaces import Discrete, Box
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
from collections import deque
from copy import deepcopy
from Networks import ValueNetwork, ClassificationNetwork

In [None]:
class Node:
    """A node of the search tree built by ``MCTS``.

    A node stores an environment snapshot plus, for each (0-based) action index:
    the child node reached by that action (or ``None`` if not expanded yet), the
    immediate reward observed when taking it, and the backed-up value estimate of
    that child. ``get_best_action`` returns the action shifted to 1-based, which is
    how ``MCTS``/``run_experiment`` pass actions to ``env.step``.
    """

    def __init__(self, state_snapshot, parent = None, parent_action = 0, actions_num = 4):
        """
        Args:
            state_snapshot: snapshot used to restore the environment to this node.
            parent: parent ``Node``, or ``None`` for the root.
            parent_action: 0-based action index that led from ``parent`` to this node.
            actions_num: number of actions that can be expanded from this node.
        """
        self.state_snapshot = state_snapshot
        self.childs = [None] * actions_num
        self.rewards = np.zeros(actions_num)
        self.estimated_values = np.zeros(actions_num)
        self.value = 0.0
        self.actions_num = actions_num
        self.parent = parent
        self.parent_action = parent_action

    def get_by_action(self, action):
        """Return the child reached by ``action`` (``None`` if not expanded)."""
        return self.childs[action]

    def set_node_for_action(self, action, other_node):
        """Attach ``other_node`` as the child reached by ``action``."""
        self.childs[action] = other_node

    def set_reward_for_action(self, action, reward):
        """Record the immediate reward observed for ``action``."""
        self.rewards[action] = reward

    def update_value_for_action(self, action, value):
        """Back up ``value`` as the estimate of ``action``'s child and refresh this node's value.

        The node's value is the best ``reward + estimated value`` over all actions;
        unexpanded actions contribute 0.
        """
        # Previously the incoming value was discarded, so estimates never propagated.
        self.estimated_values[action] = value
        self.value = (self.rewards + self.estimated_values).max()

    def get_value(self):
        """Return this node's current value."""
        return self.value

    def set_value(self, value):
        """Overwrite this node's value (used for freshly expanded leaves)."""
        self.value = value

    def get_parent(self):
        """Return ``(parent, parent_action)``; ``parent`` is ``None`` at the root."""
        return self.parent, self.parent_action

    def get_snapshot(self):
        """Return the environment snapshot stored in this node."""
        return self.state_snapshot

    def get_best_action(self):
        """Return the greedy action as a 1-based id (0-based index + 1)."""
        return self.argmax(self.estimated_values + self.rewards) + 1

    def argmax(self, values):
        """Return the index of the maximum of ``values``, breaking ties uniformly at random."""
        tie = []
        for idx, val in enumerate(values):
            if len(tie) == 0 or val > values[tie[0]]:
                tie = [idx]
            elif val == values[tie[0]]:
                tie.append(idx)
        return np.random.choice(tie)

    def select(self):
        """Descend greedily by ``reward + estimated value`` until an unexpanded action is hit.

        Returns:
            ``(node, action)``: the node to expand and the 0-based action to expand it with.
        """
        node = self
        while True:
            action = node.argmax(node.rewards + node.estimated_values)
            child = node.childs[action]
            if child is None:
                return node, action
            node = child

In [None]:
class MCTS:
    """Greedy tree search over environment snapshots, guided by a learned value estimate.

    Each call to ``iterate`` builds a fresh tree rooted at the given snapshot,
    expands it ``rollout_times`` times (one environment step per expansion), backs
    the network's estimate up to the root, and returns the best root action.
    """

    def __init__(self, env, network, rollout_times = 20):
        """
        Args:
            env: environment providing ``set_state``, ``step`` and ``get_state_snapshot``.
            network: object whose ``get_V(state)[0][0]`` is used as a value estimate.
            rollout_times: number of expansions per ``iterate`` call.
        """
        self.env = env
        self.network = network
        self.rollout_times = rollout_times
        # Not read by the search itself; kept so existing attribute access keeps working.
        self.states_to_train = np.array([])
        self.targets = np.array([])

    def _estimate_value(self, state, net_type):
        """Turn the network output for ``state`` into a value estimate.

        For ``net_type == 'value'`` the output is used directly. Otherwise the output
        is treated as a probability and mapped to 10.0 (> 0.9), -5.0 (< 0.1) or 0.0.
        """
        output = self.network.get_V(state)[0][0]
        if net_type == 'value':
            return output
        if output > 0.9:
            return 10.0
        if output < 0.1:
            return -5.0
        return 0.0

    def iterate(self, init_snapshot, net_type, train):
        """Search from ``init_snapshot`` and return the chosen action.

        Args:
            init_snapshot: environment snapshot to search from.
            net_type: ``'value'`` or any other string for classification-style outputs.
            train: accepted for API compatibility; not used by the search.

        Returns:
            The best root action as a 1-based id, ready for ``env.step``.
        """
        root = Node(init_snapshot)
        for _ in range(self.rollout_times):
            parent_node, action = root.select()

            # Restore the leaf's state on *this* search's env (not a global) and expand one step.
            self.env.set_state(parent_node.get_snapshot())
            # NOTE(review): ``done`` is ignored, so a terminal state can still be expanded
            # further on later rollouts — confirm whether that is intended.
            state, reward, _done, _ = self.env.step(action + 1)
            estimated_V = self._estimate_value(state, net_type)

            child = Node(self.env.get_state_snapshot(), parent_node, action)
            child.set_value(estimated_V)

            parent_node.set_reward_for_action(action, reward)
            parent_node.set_node_for_action(action, child)

            # Back the new estimate up towards the root.
            node = child
            while node is not None:
                value = node.get_value()
                parent, parent_action = node.get_parent()
                if parent is not None:
                    parent.update_value_for_action(parent_action, value)
                node = parent

        return root.get_best_action()

In [None]:
# 1-based action ids, matching the ``action + 1`` convention MCTS uses with env.step.
# NOTE(review): these names are not referenced elsewhere in this notebook; the move
# each id corresponds to is presumed from the names — confirm against the env.
up, down, left, right = 1, 2, 3, 4

In [None]:
def run_experiment(env, network, net_type, rollouts, init_states, train, iteration_number_max):
    """Play each starting position with MCTS and record which ones were solved.

    For every snapshot in ``init_states`` the agent repeatedly runs an MCTS search
    and takes the returned action. The episode ends when the environment reports
    ``done`` or after ``iteration_number_max`` moves (at least one move is always made).

    Args:
        env: environment providing ``set_state``, ``step``, ``get_state_snapshot`` and ``render``.
        network: estimator given to ``MCTS``; when ``train`` is true it is also
            updated through ``network.fit`` after each episode.
        net_type: forwarded to ``MCTS.iterate`` (``'value'`` or classification-style).
        rollouts: number of expansions per MCTS search.
        init_states: iterable of environment snapshots to start from.
        train: if true, fit ``network`` on the states visited in each episode.
        iteration_number_max: maximum number of moves per episode.

    Returns:
        ``np.ndarray`` of bools, one per starting position. An entry is True when
        the episode ended with ``done``, including on the last allowed move.
    """
    mcts = MCTS(env, network, rollouts)
    solved = []

    for i, snapshot in enumerate(init_states):
        print("play", i)
        states_to_train = []
        iteration_number = 0
        done = False

        while True:
            iteration_number += 1
            action = mcts.iterate(snapshot, net_type, train=train)
            env.set_state(snapshot)
            if train:
                states_to_train.append(env.render('rgb_array'))

            _, _, done, _ = env.step(action)
            snapshot = env.get_state_snapshot()
            # ``>=`` also terminates when iteration_number_max <= 0.
            if done or iteration_number >= iteration_number_max:
                break

        # Solved is decided by ``done`` itself. Comparing the move count to the cap
        # misclassified a position solved on exactly the last allowed move.
        episode_solved = bool(done)

        if train:
            # The last visited state gets 10.0 if solved (0.0 otherwise); each
            # earlier state gets 0.1 less than the one after it.
            targets = [10.0 if episode_solved else 0.0]
            for _ in range(1, len(states_to_train)):
                targets.append(targets[-1] - 0.1)
            targets.reverse()
            network.fit(np.array(states_to_train), np.array(targets), epochs=1, batch_size=len(targets), validation_split=0.0)

        solved.append(episode_solved)
        print("solved:", episode_solved)

    return np.array(solved)

In [None]:
def generate_starting_positions(env, number):
    """Reset ``env`` ``number`` times and return the list of resulting snapshots, in order."""
    def _fresh_snapshot():
        # Each reset starts a new episode; capture it so it can be replayed later.
        env.reset()
        return env.get_state_snapshot()

    return [_fresh_snapshot() for _ in range(number)]

In [None]:
env = SokobanEnv((6, 6), 50, 2)

In [None]:
starting_positions = generate_starting_positions(env, 50)

In [None]:
prelearn_network = ClassificationNetwork(env.observation_space.shape, learning_rate = 1e-2)
prelearn_network.load_weights('classification_network')

In [None]:
prelearned_no_train_class_20 = run_experiment(env, prelearn_network, net_type='classification', rollouts=20, init_states=starting_positions, train=False, iteration_number_max=50)
prelearned_no_train_class_20.mean()

In [None]:
prelearned_no_train_class_40 = run_experiment(env, prelearn_network, net_type='classification', rollouts=40, init_states=starting_positions, train=False, iteration_number_max=50)
prelearned_no_train_class_40.mean()

In [None]:
prelearned_no_train_class_60 = run_experiment(env, prelearn_network, net_type='classification', rollouts=60, init_states=starting_positions, train=False, iteration_number_max=50)
prelearned_no_train_class_60.mean()

In [None]:
prelearn_network = ValueNetwork(env.observation_space.shape, learning_rate = 1e-2)
prelearn_network.load_weights('value_network')

In [None]:
prelearned_no_train_value_20 = run_experiment(env, prelearn_network, net_type='value', rollouts=20, init_states=starting_positions, train=False, iteration_number_max=50)
prelearned_no_train_value_20.mean()

In [None]:
prelearned_no_train_value_40 = run_experiment(env, prelearn_network, net_type='value', rollouts=40, init_states=starting_positions, train=False, iteration_number_max=50)
prelearned_no_train_value_40.mean()

In [None]:
prelearned_no_train_value_60 = run_experiment(env, prelearn_network, net_type='value', rollouts=60, init_states=starting_positions, train=False, iteration_number_max=50)
prelearned_no_train_value_60.mean()

In [None]:
# The "With simple fields" values are hardcoded because the run that produced them
# was overwritten by a later run of the experiments.
plt.plot([20, 40, 60], [0.42, 0.38, 0.4], label='With simple fields')

# run_experiment returns one bool per starting position, so plot the mean (solved rate),
# not the raw arrays (which would draw one line per starting position).
plt.plot([20, 40, 60], [prelearned_no_train_value_20.mean(), prelearned_no_train_value_40.mean(), prelearned_no_train_value_60.mean()], label='No simple fields')
plt.legend()
plt.xlabel('Rollouts')
plt.ylabel('Solved rate')
plt.title('6x6 fields, two boxes')
plt.show()

In [None]:
prelearn_network = ValueNetwork(env.observation_space.shape, learning_rate = 1e-2)
prelearn_network.load_weights('value_network')

In [None]:
prelearned_train_value_20 = run_experiment(env, prelearn_network, net_type='value', rollouts=20, init_states=starting_positions, train=True, iteration_number_max=50)
prelearned_train_value_20.mean()

In [None]:
prelearn_network = ValueNetwork(env.observation_space.shape, learning_rate = 1e-2)
prelearn_network.load_weights('value_network')

In [None]:
prelearned_train_value_40 = run_experiment(env, prelearn_network, net_type='value', rollouts=40, init_states=starting_positions, train=True, iteration_number_max=50)
prelearned_train_value_40.mean()

In [None]:
prelearn_network = ValueNetwork(env.observation_space.shape, learning_rate = 1e-2)
prelearn_network.load_weights('value_network')

In [None]:
prelearned_train_value_60 = run_experiment(env, prelearn_network, net_type='value', rollouts=60, init_states=starting_positions, train=True, iteration_number_max=50)
prelearned_train_value_60.mean()

In [None]:
# Plot solved rates (means of the per-position bool arrays). The labels were
# previously swapped: the no_train results are the ones *without* training.
plt.plot([20, 40, 60], [prelearned_no_train_value_20.mean(), prelearned_no_train_value_40.mean(), prelearned_no_train_value_60.mean()], label='Value Network without training')
plt.plot([20, 40, 60], [prelearned_train_value_20.mean(), prelearned_train_value_40.mean(), prelearned_train_value_60.mean()], label='Value Network with training')
plt.legend()
plt.xlabel('Rollouts')
plt.ylabel('Solved rate')
plt.title('6x6 fields, two boxes')
plt.show()