In [5]:
from sac_tf import SAC__Agent
from env_wrappers import *
from gym_car_intersect.envs import CarRacingHackatonContinuous2

import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import matplotlib.animation as animation

In [25]:
class Holder:
    '''
    Class to hold agent, environment and replay buffer.
    It is also the place to control the hyperparameters of the learning process.

    The replay buffer is a ring buffer stored column-wise in ``self.buffer``:
    index 0 - state, 1 - action, 2 - reward, 3 - new state, 4 - done flag.
    ``self.cur_write_index`` is the slot the next transition is written to and
    ``self.full_buf_size`` is the number of slots that already hold data.
    '''

    def __init__(self, batch_size=32, hidden_size=256, buffer_size=10 * 1000):
        '''
        Build the environment, the agent and an empty replay buffer.

        :param batch_size: number of transitions per minibatch in `iterate_over_buffer`.
        :param hidden_size: forwarded to `SAC__Agent` as `hidden_size`.
        :param buffer_size: maximum number of transitions kept in the replay buffer.
        '''
        self.batch_size = batch_size

        # for reward history
        self.update_steps_count = 0
        self.game_count = 0
        self.history = []

        # init replay buffer
        self.cur_write_index = 0
        self.buffer_size = buffer_size
        self.full_buf_size = 0
        self.buffer = [
            # state
            [],
            # action
            [],
            # reward
            [],
            # new state
            [],
            # done
            [],
        ]

        # init environment and agent
        env = CarRacingHackatonContinuous2(num_bots=0, start_file=None, is_discrete=True)
        env = chainerrl.wrappers.ContinuingTimeLimit(env, max_episode_steps=1000)
        env = MaxAndSkipEnv(env, skip=4)
#         env = DiscreteWrapper(env)
        env = WarpFrame(env, channel_order='hwc')
        self.env = env

        self.agent = SAC__Agent(
            picture_shape=(84, 84, 3),
            extra_size=12,
            action_size=5,
            hidden_size=hidden_size
        )
        self.env_state = None
        self.reset_env()

    def reset_env(self, inc_counter=True):
        '''
        Reset the environment and remember its initial state in `self.env_state`.

        :param inc_counter: when true, `self.game_count` is increased by one.
        '''
        self.env_state = self.env.reset()
        if inc_counter:
            self.game_count += 1

    def insert_N_sample_to_replay_memory(self, N, temperature=0.5):
        '''
        Play N environment steps with the agent and store every transition
        in the replay buffer, overwriting the oldest slots once it is full.

        :param N: number of environment steps to play.
        :param temperature: forwarded to `agent.get_single_action`.
        '''
        for _ in range(N):

            if self.env_state is None:
                self.reset_env()

            action = self.agent.get_single_action(
                self.env_state,
                need_argmax=False,
                temperature=temperature,
            )
            # the environment gets the index of the largest action component,
            # while the raw `action` returned by the agent is what is stored
            new_state, reward, done, info = self.env.step(np.argmax(action))

            # grow all five columns together until the ring buffer reaches its capacity
            if len(self.buffer[0]) <= self.cur_write_index:
                for column in self.buffer:
                    column.append(None)
            # state
            self.buffer[0][self.cur_write_index] = self.env_state
            # action
            self.buffer[1][self.cur_write_index] = action
            # reward
            self.buffer[2][self.cur_write_index] = np.array([reward])
            # new state
            self.buffer[3][self.cur_write_index] = new_state
            # done flag
            self.buffer[4][self.cur_write_index] = 1.0 if done else 0.0
            self.env_state = new_state

            # advance the write position, wrapping around at the end of the buffer
            self.cur_write_index += 1
            if self.cur_write_index >= self.buffer_size:
                self.cur_write_index = 0

            if self.full_buf_size < self.buffer_size:
                self.full_buf_size += 1

            # reset env if done
            if done:
                self.reset_env()

    def iterate_over_buffer(self, steps):
        '''
        Yield exactly `steps` shuffled minibatches from the replay buffer.

        Every pass over the buffer uses a fresh random permutation; the last
        batch of a pass may be shorter than `self.batch_size`. Each batch is a
        tuple of five numpy arrays: (state, action, reward, new state, done).

        Nothing is yielded when the buffer is empty or `steps` is not positive
        (previously an empty buffer made this loop forever).

        :param steps: total number of minibatches to produce.
        '''
        if self.full_buf_size <= 0 or steps <= 0:
            return

        columns = [np.array(column) for column in self.buffer]
        produced = 0
        while True:
            order = np.arange(self.full_buf_size)
            np.random.shuffle(order)

            for start in range(0, len(order), self.batch_size):
                batch_indexes = order[start : start + self.batch_size]
                # materialise the batch right away: a lazy generator expression
                # would pick up later values of the loop variables if it were
                # consumed after this generator has advanced
                yield tuple(column[batch_indexes] for column in columns)
                produced += 1
                if produced >= steps:
                    return

    def update_agent(
            self,
            update_step_num=500,
            temperature=0.5,
            gamma=0.7,
            v_exp_smooth_factor=0.8,
            need_update_VSmooth=False
    ):
        '''
        Run `update_step_num` agent updates on minibatches from the replay buffer.

        `self.update_steps_count` is increased by one per update. All other
        arguments are forwarded unchanged to `agent.update_step`.
        '''
        for batch in self.iterate_over_buffer(update_step_num):
            self.update_steps_count += 1
            self.agent.update_step(
                batch,
                temperature=temperature,
                gamma=gamma,
                v_exp_smooth_factor=v_exp_smooth_factor,
                need_update_VSmooth=need_update_VSmooth,
            )

    def iterate_over_test_game(self, max_steps=1000):
        '''
        Play one game from a fresh reset and yield
        `(state, action, reward, done)` after every step.

        The game stops after the step that reports `done` or after `max_steps`
        steps. The reset increases `self.game_count`. Actions are requested
        from the agent with `temperature=1`.

        The generator's return value (carried by `StopIteration`) is
        `(None, None, was_game_finit)`, where the flag tells whether the
        game ended with `done`.
        '''
        self.reset_env(inc_counter=True)
        was_game_finit = False
        for _ in range(max_steps):
            action = self.agent.get_single_action(
                self.env_state,
                need_argmax=False,
                temperature=1,
            )
            self.env_state, reward, done, info = self.env.step(np.argmax(action))

            yield self.env_state, action, reward, done

            if done:
                was_game_finit = True
                break
        return None, None, was_game_finit

    def get_test_game_total_reward(
            self,
            max_steps=1000,
            temperature=10,
            add_to_memory=True,
    ):
        '''
        Play one test game and return the sum of all its rewards,
        including the reward of the final (`done`) step.

        :param max_steps: upper bound on the number of steps in the game.
        :param temperature: kept for interface compatibility; it is not
            forwarded, `iterate_over_test_game` always uses temperature 1.
            NOTE(review): confirm which temperature test games should use.
        :param add_to_memory: when true, `[game_count, total_reward]` is
            appended to `self.history`.
        '''
        total_reward = 0

        # the generator stops by itself after the `done` step, so every
        # yielded reward (the terminal one too) belongs to the game
        for _, _, reward, _ in self.iterate_over_test_game(max_steps=max_steps):
            total_reward += reward

        if add_to_memory:
            self.history.append([self.game_count, total_reward])

        return total_reward

    def get_test_game_mean_reward(
        self,
        n_games=10,
        max_steps=1000,
        temperature=10,
        add_to_memory=True
    ):
        '''
        Play `n_games` test games and return the mean of their total rewards.

        :param n_games: number of games to average over; must be positive
            (zero raises ZeroDivisionError).
        :param max_steps: upper bound on the number of steps of every game.
        :param temperature: passed to `get_test_game_total_reward`, which
            currently does not forward it to the agent.
        :param add_to_memory: when true, `[game_count, mean_reward]` is appended
            to `self.history`; the individual games are never recorded.
        '''
        reward_sum = 0
        for _ in range(n_games):
            reward_sum += self.get_test_game_total_reward(max_steps, temperature, add_to_memory=False)
        mean_reward = reward_sum / n_games

        if add_to_memory:
            self.history.append([self.game_count, mean_reward])

        return mean_reward

    def get_history(self):
        '''Return the recorded `[game_count, reward]` pairs as a numpy array.'''
        return np.array(self.history)

In [26]:
# Build the holder with its default hyperparameters (batch_size=32,
# hidden_size=256, buffer_size=10000). The constructor creates the
# environment and the agent and performs the first environment reset.
holder = Holder()

In [27]:
# Play 10 environment steps with the agent and store the resulting
# transitions in the replay buffer (default temperature=0.5).
holder.insert_N_sample_to_replay_memory(10)

In [None]:
%%time

# Time 10 agent update steps on minibatches drawn from the replay buffer.
holder.update_agent(update_step_num=10)