# CM30359: Investigation of Deep Q-learning on Breakout - Group 9

## 1. Imports and Dependencies

In [None]:
import numpy as np
import time

!pip install gym[classic_control,atari,accept-rom-license]==0.26.0
!pip install typing-extensions --upgrade
!pip install moviepy

import gym
from gym.utils.save_video import save_video
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count, compress

import torch
import torch.nn as nn
import torch.optim as optim

from gym.wrappers import AtariPreprocessing, FrameStack

- ### GPU utilisation check

In [None]:
# Detect whether matplotlib is drawing through an inline backend; the IPython
# display helper is only imported in that case.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display
# Switch matplotlib to interactive mode.
plt.ion()

# Check if GPU can be used: `device` is "cuda" when available, otherwise "cpu".
# It is read as a global by the agent classes defined later in this notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

## 2. Preprocessing

In [None]:
def Preprocessing_env(env):
    """Wrap ``env`` with Atari preprocessing followed by a 4-frame stack.

    The preprocessing wrapper is configured for 84x84 grayscale observations
    without a trailing channel axis, unscaled pixel values, up to 30 no-op
    steps on reset, and no termination on life loss.

    Args:
        env: The environment to wrap.

    Returns:
        The environment wrapped in ``AtariPreprocessing`` and ``FrameStack``.
    """
    atari_options = dict(
        noop_max=30,
        screen_size=84,
        terminal_on_life_loss=False,
        grayscale_obs=True,
        grayscale_newaxis=False,
        scale_obs=False,
    )
    preprocessed = gym.wrappers.AtariPreprocessing(env, **atari_options)
    return gym.wrappers.FrameStack(preprocessed, 4)

## 3. Network Architectures

- ### Dueling DDQN

In [None]:
class DDDQNModel(nn.Module): # DDDQN
    """Dueling Q-network: shared convolutional features feeding two heads.

    One head estimates a scalar state value V(s), the other one advantage per
    action A(s, a). ``forward`` combines them as
    ``Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)``.

    Args:
        input_shape: Shape of a single observation, channels first; only
            ``input_shape[0]`` is used as the number of input channels.
        n_actions: Number of discrete actions (width of the advantage head).
    """

    def __init__(self, input_shape, n_actions):
        super(DDDQNModel, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        conv_out_size = self.get_conv_out_size(input_shape)

        self.state_value = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )
        self.action_advantage = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def get_conv_out_size(self, shape):
        """Return the flattened feature count the conv stack yields for ``shape``."""
        # Shape probe only: no gradients are needed for the dummy forward pass.
        with torch.no_grad():
            conv_size = self.conv(torch.zeros(1, *shape))
        return int(np.prod(conv_size.size()))

    def forward(self, x):
        """Return Q-values of shape ``(batch, n_actions)`` for a batch ``x``."""
        conv_out = self.conv(x).view(x.size()[0], -1)
        value = self.state_value(conv_out)
        advantage = self.action_advantage(conv_out)
        # Centre the advantages per sample (over the action dimension). The
        # previous code averaged over the whole tensor, which made every
        # sample's Q-values depend on the other samples in the batch. For a
        # batch of one the two forms give the same result.
        return value + advantage - advantage.mean(dim=1, keepdim=True)

- ### DQN/DDQN

In [None]:
class DQNModel(nn.Module): # DQN/DDQN
    """Convolutional Q-network with a single fully connected output head.

    Args:
        input_shape: Shape of a single observation, channels first; only
            ``input_shape[0]`` is used as the number of input channels.
        n_actions: Number of discrete actions (width of the output layer).
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()
        in_channels = input_shape[0]
        feature_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*feature_layers)

        flat_features = self.get_conv_out_size(input_shape)
        head_layers = [
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        ]
        self.value = nn.Sequential(*head_layers)

    def get_conv_out_size(self, shape):
        """Return the flattened feature count the conv stack yields for ``shape``."""
        probe = torch.zeros(1, *shape)
        return int(self.conv(probe).numel())

    def forward(self, x):
        """Return Q-values of shape ``(batch, n_actions)`` for a batch ``x``."""
        batch_size = x.size(0)
        features = self.conv(x)
        flattened = features.view(batch_size, -1)
        return self.value(flattened)

- ## Memory Efficiency Attempt

In [None]:
## Attempt at a memory-efficient replay buffer (commented out)

# Experience = collections.namedtuple('Experience', field_names=['state', 'action', 'reward', 
#            'done', 'new_state'])
# Samplable = collections.namedtuple('Samplable', field_names=['samplable'])

# class ExperienceReplay:
#     def __init__(self, capacity):
#         self.buffer = collections.deque(maxlen=capacity)
#         self.sample_buffer = collections.deque(maxlen=capacity)
#     def __len__(self):
#         return len(self.buffer)
#     def append(self, state,action,reward,is_done,new_state,first_samplable):
#         if first_samplable == 0: # AKA if this is 4th frame of the game
# #             print("First frame")
# #             print(np.shape(state))
# #             print(np.shape(state[0]))
# #             self.buffer.append(Experience(state[i],action,reward,is_done,new_state[i]) for i in range(3))
#             self.buffer.append(Experience(state[0],action,reward,is_done,new_state[0]))
#             self.buffer.append(Experience(state[1],action,reward,is_done,new_state[1]))
#             self.buffer.append(Experience(state[2],action,reward,is_done,new_state[2]))
#             self.sample_buffer.append(Samplable(False))
#             self.sample_buffer.append(Samplable(False))
#             self.sample_buffer.append(Samplable(False))
# #         print(np.shape(self.buffer[0][0]))

#         self.buffer.append(Experience(state[3],action,reward,is_done,new_state[3]))
#         self.sample_buffer.append(Samplable(True))
#         for i in range(3): 
#             self.sample_buffer[i] = False
#         #print(sys.getsizeof(self.buffer[0][0]))
  
#     def batch_sample(self, batch_size):
#         samplable_list = list(compress(range(len(self.sample_buffer)), self.sample_buffer))
#         indices = np.random.choice(len(samplable_list), batch_size, replace=False)
#         states, actions, rewards, dones, new_states = [], [], [], [], []
#         for idx in indices:
#             states.append(np.stack([self.buffer[idx-i][0] for i in range(4)], 0))
#             actions.append(self.buffer[idx][1])
#             rewards.append(self.buffer[idx][2])
#             dones.append(self.buffer[idx][3])
#             new_states.append(np.stack([self.buffer[idx-i][4] for i in range(4)], 0))
            
# #         print(np.shape(states))
# #         print(np.shape(states[0]))
        
# #         print(np.shape(actions))
# #         print(np.shape(actions[0]))
              
#         return np.array(states), np.array(actions), np.array(rewards,dtype=np.float32), np.array(dones, dtype=np.uint8),np.array(new_states)

## 4. Experience Replay Buffer

In [None]:
# One stored transition: (state, action, reward, done, next_state).
Experience = namedtuple(
    'Experience',
    field_names=['state', 'action', 'reward', 'done', 'next_state'],
)


class ExperienceReplayBuffer:
    """Fixed-capacity FIFO store of ``Experience`` tuples with uniform sampling."""

    def __init__(self, capacity):
        # A bounded deque drops the oldest entry once `capacity` is reached.
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def append(self, experience):
        """Add one transition to the buffer."""
        self.buffer.append(experience)

    def batch_sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A 5-tuple of NumPy arrays ``(states, actions, rewards, dones,
            next_states)``; rewards are float32 and dones are uint8.
        """
        chosen = np.random.choice(len(self.buffer), batch_size, replace=False)
        picked = [self.buffer[position] for position in chosen]
        states, actions, rewards, dones, next_states = zip(*picked)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(dones, dtype=np.uint8),
            np.array(next_states),
        )

## 5. Graph results

In [None]:
class Graphing:
    """Track episode rewards and log them, with epsilon, to a SummaryWriter."""

    def __init__(self):
        self.writer = SummaryWriter(comment="-ALE/BreakoutNoFrameskip-v4")
        self.episode_rewards = []
        self.best_mean_reward = float("-inf")

    def add_info(self, reward, frame_idx, epsilon):
        """Record one finished episode and log its scalars at ``frame_idx``.

        Returns:
            True when the running mean reward beats the best seen so far
            (the stored best is updated in that case), otherwise False.
        """
        self.episode_rewards.append(reward)
        mean_reward = self.get_mean_reward()

        scalars = (
            ("epsilon", epsilon),
            ("reward_100", mean_reward),
            ("reward", reward),
        )
        for tag, value in scalars:
            self.writer.add_scalar(tag, value, frame_idx)

        if not (self.best_mean_reward < mean_reward):
            return False
        self.best_mean_reward = mean_reward
        return True

    def get_best_reward(self):
        """Return the best running mean reward recorded so far."""
        return self.best_mean_reward

    def get_mean_reward(self):
        """Return the mean reward over the most recent 100 episodes."""
        recent = self.episode_rewards[-100:]
        return np.mean(recent)

## 6. Process Environment

In [None]:
class ProcessingEnv:
    """Environment wrapper that tracks state, episode reward and video frames.

    It keeps the current observation, counts episodes, sums the reward of the
    running episode and collects rendered frames for the episodes whose video
    is written to the ``videos`` folder.

    Args:
        env: The environment to drive; it is reset once on construction.
    """

    def __init__(self, env):
        self.env = env
        self.episode = 0
        self.video = []
        self.reset()

    @staticmethod
    def _is_recorded_episode(episode_index):
        """Return True for episode indices whose video is written to disk."""
        return episode_index in [10] or episode_index % 250 == 0

    def reset(self):
        """Save the finished episode's video if due, then start a new episode."""
        self.episode += 1
        save_video(self.video, "videos", fps=25,
                   episode_trigger=self._is_recorded_episode,
                   episode_index=self.episode)
        self.video = []
        self.state, _ = self.env.reset()
        self.fire()
        self.total_rewards = 0.0

    def get_episode(self):
        return self.episode

    def get_total_rewards(self):
        return self.total_rewards

    def get_state(self):
        return self.state

    def set_state(self, state):
        self.state = state

    def fire(self):
        """Step the environment with action 1 and keep the resulting observation."""
        self.state, _, _, _, _ = self.env.step(1)

    def step(self, action):
        """Step the environment, accumulate the reward and record a frame if needed.

        Returns:
            The ``(next_state, reward, is_done, truncated, info)`` tuple from
            the wrapped environment, unchanged.
        """
        next_state, reward, is_done, truncated, info = self.env.step(action)
        self.total_rewards += reward
        # Frames collected now are handed to save_video by the NEXT reset(),
        # which uses index self.episode + 1. Rendering only when that index
        # will be saved avoids holding every frame of every episode in memory.
        if self._is_recorded_episode(self.episode + 1):
            self.video.append(self.env.render())

        return next_state, reward, is_done, truncated, info

- ## Start tensorboard

In [None]:
from torch.utils.tensorboard import SummaryWriter
%load_ext tensorboard
import datetime
# Print the wall-clock time at which this cell is executed.
print(">>>Training starts at ",datetime.datetime.now())

## 7. Agents

- ### DQN

In [None]:
class DQN:
    """DQN agent: epsilon-greedy Q-learning with replay and a target network.

    Args:
        env: A ``ProcessingEnv``-style wrapper exposing ``get_state``,
            ``set_state``, ``step``, ``fire``, ``reset``, ``episode`` and the
            wrapped environment as ``env.env``.
        gamma: Discount factor.
        batch_size: Number of transitions per optimisation step.
        experience_replay_buffer_size: Capacity of the replay buffer.
        learning_rate: Adam learning rate.
        target_network_update_frequency: Frames between target-network syncs.
        experience_replay_start_size: Transitions collected before learning.
        update_frequency: Frames between optimisation steps.
        epsilon_start: Initial exploration rate.
        epsilon_decay: Frame index at which epsilon reaches ``epsilon_end``.
        epsilon_end: Final exploration rate.
    """

    def __init__(self, env, gamma, batch_size, experience_replay_buffer_size, learning_rate, 
                 target_network_update_frequency, experience_replay_start_size, update_frequency,
                 epsilon_start, epsilon_decay, epsilon_end):
        self.env = env
        self.GAMMA = gamma
        self.BATCH_SIZE = batch_size
        self.EXPERIENCE_REPLAY_BUFFER_SIZE = experience_replay_buffer_size
        self.LEARNING_RATE = learning_rate
        self.TARGET_NETWORK_UPDATE_FREQUENCY = target_network_update_frequency
        self.EXPERIENCE_REPLAY_START_SIZE = experience_replay_start_size
        self.UPDATE_FREQUENCY = update_frequency

        self.EPSILON_START = epsilon_start
        self.EPSILON_DECAY = epsilon_decay
        self.EPSILON_END = epsilon_end
        self.buffer = ExperienceReplayBuffer(self.EXPERIENCE_REPLAY_BUFFER_SIZE)
        self.net = DQNModel(env.env.observation_space.shape, env.env.action_space.n).to(device)
        self.target_net = DQNModel(env.env.observation_space.shape, env.env.action_space.n).to(device)
        # Start the target network from the online network's weights instead
        # of an unrelated random initialisation.
        self.target_net.load_state_dict(self.net.state_dict())
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.LEARNING_RATE)
        self.criterion_loss = nn.SmoothL1Loss()

        self.frame_idx = 1
        self.num_of_lives = 5

        # Flip to True to resume from previously saved checkpoints.
        restart = False
        if restart:
            # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
            checkpoint1 = torch.load('checkpoint1Score109.pth', map_location=device)
            print("Loaded checkpoint 1")
            checkpoint2 = torch.load('checkpoint2Score109.pth', map_location=device)
            print("Loaded checkpoint 2")
            self.net.load_state_dict(checkpoint1['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint1['optimizer_state_dict'])
            self.EPSILON_START = checkpoint1['epsilon']
            print(self.EPSILON_START)
            self.criterion_loss = checkpoint1['loss']
            self.target_net.load_state_dict(checkpoint2['model_state_dict'])

        self.net.train()
        # The optimizer only holds self.net's parameters; the target network
        # is only ever evaluated.
        self.target_net.eval()
        self.epsilon = self.EPSILON_START
        self.graph = Graphing()

    def get_action(self, state):
        """Return an epsilon-greedy action for ``state``."""
        if np.random.random() < self.epsilon:
            return self.env.env.action_space.sample()
        # Add a leading batch axis; np.asarray avoids the copy=False keyword,
        # which raises on NumPy 2 whenever a copy is unavoidable.
        state_a = np.expand_dims(np.asarray(state), 0)
        state_v = torch.tensor(state_a, dtype=torch.float32).to(device)
        with torch.no_grad():  # action selection needs no autograd graph
            q_vals_v = self.net(state_v)
        _, act_v = torch.max(q_vals_v, dim=1)
        return int(act_v.item())

    def _optimize_model(self):
        """Run one gradient step on a batch sampled from the replay buffer."""
        states, actions, rewards, dones, next_states = self.buffer.batch_sample(self.BATCH_SIZE)

        states_v = torch.tensor(states, dtype=torch.float32).to(device)
        next_states_v = torch.tensor(next_states, dtype=torch.float32).to(device)
        actions_v = torch.tensor(actions, dtype=torch.int64).to(device)
        rewards_v = torch.tensor(rewards).to(device)
        # Boolean mask: indexing with a uint8 tensor is deprecated in PyTorch.
        done_mask = torch.tensor(dones, dtype=torch.bool).to(device)

        q_vals = self.net(states_v)
        state_action_values = q_vals.gather(1, actions_v.unsqueeze(-1)).squeeze(-1)

        with torch.no_grad():
            next_state_values = self.target_net(next_states_v).max(1)[0]  # DQN
            # Terminal transitions do not bootstrap from the next state.
            next_state_values[done_mask] = 0.0

        estimated_state_action_values = rewards_v + self.GAMMA * next_state_values

        loss_huber = self.criterion_loss(state_action_values, estimated_state_action_values)
        self.optimizer.zero_grad()
        loss_huber.backward()
        self.optimizer.step()

    def train(self, episode_num):
        """Train until ``episode_num`` episodes or 4,000,000 frames are reached.

        Checkpoints are written to ``checkpoint1.pth`` (best mean reward),
        ``checkpoint2.pth`` (target network at each sync) and
        ``checkpoint3.pth`` (online network when training stops).
        """
        while True:
            state = self.env.get_state()
            action = self.get_action(state)
            next_state, reward, is_done, truncated, info = self.env.step(action)
            # A truncated episode must also be reset, otherwise the loop keeps
            # stepping an environment that has already finished.
            episode_over = is_done or truncated

            # CROP REWARDS - IMPORTANT
            reward = min(reward, 1)

            # STORING IN EXPERIENCE REPLAY BUFFER
            lives = info.get("lives", self.num_of_lives)
            if lives < self.num_of_lives:
                # LIFE LOSS CONSIDERED TERMINAL
                experience = Experience(state, action, reward, True, next_state)
                self.num_of_lives = lives
                # Only relaunch the ball while the episode is still running;
                # a finished environment must be reset, not stepped.
                if not episode_over:
                    self.env.fire()
            else:
                experience = Experience(state, action, reward, is_done, next_state)
                self.env.set_state(next_state)
            self.buffer.append(experience)

            if episode_over:
                self.num_of_lives = 5
                best_so_far = self.graph.add_info(self.env.get_total_rewards(), self.frame_idx, self.epsilon)
                if best_so_far:
                    torch.save({'model_state_dict': self.net.state_dict(),
                                'optimizer_state_dict': self.optimizer.state_dict(),
                                'epsilon': float(self.epsilon),
                                'loss': self.criterion_loss}, 'checkpoint1.pth')
                    print("Best mean reward updated %.3f, %d" % (self.graph.get_best_reward(), self.frame_idx))
                self.env.reset()

            if self.frame_idx >= 4000000 or self.env.episode >= episode_num:
                print("Solved in %d frames!" % self.frame_idx)
                torch.save({'model_state_dict': self.net.state_dict()}, 'checkpoint3.pth')
                break

            # Keep acting without learning until the buffer is warm.
            if len(self.buffer) < self.EXPERIENCE_REPLAY_START_SIZE:
                self.frame_idx += 1
                continue

            self.epsilon = np.interp(self.frame_idx, [0, self.EPSILON_DECAY], [self.EPSILON_START, self.EPSILON_END])

            if self.frame_idx % self.UPDATE_FREQUENCY == 0:
                self._optimize_model()

            if self.frame_idx % self.TARGET_NETWORK_UPDATE_FREQUENCY == 0:
                print("%d:  %d games, mean reward %.3f, (epsilon %.2f)" % (self.frame_idx, self.env.get_episode(), self.graph.get_mean_reward(), self.epsilon))
                self.target_net.load_state_dict(self.net.state_dict())
                torch.save({'model_state_dict': self.target_net.state_dict()}, 'checkpoint2.pth')
            self.frame_idx += 1

    def play(self, episode_num, checkpoint_path='checkpoint2DQN.pth'):
        """Evaluate a saved model for ``episode_num`` episodes.

        Args:
            episode_num: Number of episodes to play.
            checkpoint_path: File holding a ``model_state_dict`` entry.

        Returns:
            ``(episode_rewards, best_reward, mean_reward)``.
        """
        checkpoint1 = torch.load(checkpoint_path, map_location=device)
        print("Loaded checkpoint 1")
        self.net.load_state_dict(checkpoint1['model_state_dict'])
        self.epsilon = 0.001
        episode_rewards = []
        for _ in range(episode_num):
            self.frame_idx = 0
            self.num_of_lives = 5
            self.env.reset()
            is_done = False
            while not is_done:
                state = self.env.get_state()
                action = self.get_action(state)
                next_state, _, terminated, truncated, info = self.env.step(action)
                is_done = terminated or truncated

                lives = info.get("lives", self.num_of_lives)
                if lives < self.num_of_lives:
                    print("Lost life")
                    self.num_of_lives = lives
                    # Do not step an environment that has already finished.
                    if not is_done:
                        self.env.fire()
                else:
                    self.env.set_state(next_state)

                # Hard cap on the number of steps per evaluation episode.
                if self.frame_idx >= 10000:
                    break
                self.frame_idx += 1

            total_reward = self.env.get_total_rewards()
            print(total_reward)
            print(self.frame_idx)
            episode_rewards.append(total_reward)
        return episode_rewards, max(episode_rewards), sum(episode_rewards) / len(episode_rewards)





# Build the preprocessed Breakout environment and the DQN agent.
env = gym.make("BreakoutNoFrameskip-v4", render_mode="rgb_array")
processed_env = ProcessingEnv(Preprocessing_env(env))

dqn = DQN(
    env=processed_env,
    gamma=0.99,
    batch_size=32,
    experience_replay_buffer_size=100000,
    learning_rate=1e-4,
    target_network_update_frequency=10000,
    experience_replay_start_size=5000,
    update_frequency=4,
    epsilon_start=1,
    epsilon_decay=500000,
    epsilon_end=0.01,
)

# To evaluate a saved checkpoint instead of training, use:
# print(dqn.play(30))
dqn.train(5550)

- ### DDQN

In [None]:
class DDQN:
    """Double DQN agent: the online net picks the next action, the target net scores it.

    Args:
        env: A ``ProcessingEnv``-style wrapper exposing ``get_state``,
            ``set_state``, ``step``, ``fire``, ``reset``, ``episode`` and the
            wrapped environment as ``env.env``.
        gamma: Discount factor.
        batch_size: Number of transitions per optimisation step.
        experience_replay_buffer_size: Capacity of the replay buffer.
        learning_rate: Adam learning rate.
        target_network_update_frequency: Frames between target-network syncs.
        experience_replay_start_size: Transitions collected before learning.
        update_frequency: Frames between optimisation steps.
        epsilon_start: Initial exploration rate.
        epsilon_decay: Frame index at which epsilon reaches ``epsilon_end``.
        epsilon_end: Final exploration rate.
    """

    def __init__(self, env, gamma, batch_size, experience_replay_buffer_size, learning_rate, 
                 target_network_update_frequency, experience_replay_start_size, update_frequency,
                 epsilon_start, epsilon_decay, epsilon_end):
        self.env = env
        self.GAMMA = gamma
        self.BATCH_SIZE = batch_size
        self.EXPERIENCE_REPLAY_BUFFER_SIZE = experience_replay_buffer_size
        self.LEARNING_RATE = learning_rate
        self.TARGET_NETWORK_UPDATE_FREQUENCY = target_network_update_frequency
        self.EXPERIENCE_REPLAY_START_SIZE = experience_replay_start_size
        self.UPDATE_FREQUENCY = update_frequency

        self.EPSILON_START = epsilon_start
        self.EPSILON_DECAY = epsilon_decay
        self.EPSILON_END = epsilon_end
        self.buffer = ExperienceReplayBuffer(self.EXPERIENCE_REPLAY_BUFFER_SIZE)
        self.net = DQNModel(env.env.observation_space.shape, env.env.action_space.n).to(device)
        self.target_net = DQNModel(env.env.observation_space.shape, env.env.action_space.n).to(device)
        # Start the target network from the online network's weights instead
        # of an unrelated random initialisation.
        self.target_net.load_state_dict(self.net.state_dict())
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.LEARNING_RATE)
        self.criterion_loss = nn.SmoothL1Loss()

        self.frame_idx = 1
        self.num_of_lives = 5

        # Flip to True to resume from previously saved checkpoints.
        restart = False
        if restart:
            # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
            checkpoint1 = torch.load('checkpoint1Score109.pth', map_location=device)
            print("Loaded checkpoint 1")
            checkpoint2 = torch.load('checkpoint2Score109.pth', map_location=device)
            print("Loaded checkpoint 2")
            self.net.load_state_dict(checkpoint1['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint1['optimizer_state_dict'])
            self.EPSILON_START = checkpoint1['epsilon']
            print(self.EPSILON_START)
            self.criterion_loss = checkpoint1['loss']
            self.target_net.load_state_dict(checkpoint2['model_state_dict'])

        self.net.train()
        # The optimizer only holds self.net's parameters; the target network
        # is only ever evaluated.
        self.target_net.eval()
        self.epsilon = self.EPSILON_START
        self.graph = Graphing()

    def get_action(self, state):
        """Return an epsilon-greedy action for ``state``."""
        if np.random.random() < self.epsilon:
            return self.env.env.action_space.sample()
        # Add a leading batch axis; np.asarray avoids the copy=False keyword,
        # which raises on NumPy 2 whenever a copy is unavoidable.
        state_a = np.expand_dims(np.asarray(state), 0)
        state_v = torch.tensor(state_a, dtype=torch.float32).to(device)
        with torch.no_grad():  # action selection needs no autograd graph
            q_vals_v = self.net(state_v)
        _, act_v = torch.max(q_vals_v, dim=1)
        return int(act_v.item())

    def _optimize_model(self):
        """Run one double-Q gradient step on a batch from the replay buffer."""
        states, actions, rewards, dones, next_states = self.buffer.batch_sample(self.BATCH_SIZE)

        states_v = torch.tensor(states, dtype=torch.float32).to(device)
        next_states_v = torch.tensor(next_states, dtype=torch.float32).to(device)
        actions_v = torch.tensor(actions, dtype=torch.int64).to(device)
        rewards_v = torch.tensor(rewards).to(device)
        # Boolean mask: indexing with a uint8 tensor is deprecated in PyTorch.
        done_mask = torch.tensor(dones, dtype=torch.bool).to(device)

        q_vals = self.net(states_v)
        state_action_values = q_vals.gather(1, actions_v.unsqueeze(-1)).squeeze(-1)

        # The whole target is a constant w.r.t. the loss, so both forward
        # passes below run without building an autograd graph.
        with torch.no_grad():
            # Online network chooses the next action ...
            _, act_v = torch.max(self.net(next_states_v), dim=1)
            # ... and the target network supplies that action's value.
            target_q_vals = self.target_net(next_states_v)
            next_state_values = target_q_vals.gather(1, act_v.unsqueeze(1)).reshape(-1)
            # Terminal transitions do not bootstrap from the next state.
            next_state_values[done_mask] = 0.0

        estimated_state_action_values = rewards_v + self.GAMMA * next_state_values

        loss_huber = self.criterion_loss(state_action_values, estimated_state_action_values)
        self.optimizer.zero_grad()
        loss_huber.backward()
        self.optimizer.step()

    def train(self, episode_num):
        """Train until ``episode_num`` episodes or 4,000,000 frames are reached.

        Checkpoints are written to ``checkpoint1.pth`` (best mean reward),
        ``checkpoint2.pth`` (target network at each sync) and
        ``checkpoint3.pth`` (online network when training stops).
        """
        while True:
            state = self.env.get_state()
            action = self.get_action(state)
            next_state, reward, is_done, truncated, info = self.env.step(action)
            # A truncated episode must also be reset, otherwise the loop keeps
            # stepping an environment that has already finished.
            episode_over = is_done or truncated

            # NOTE(review): reward clipping is disabled in this agent, unlike
            # the DQN and DDDQN agents in this notebook - confirm this is
            # intentional before comparing results.
            # reward = min(reward, 1)

            # STORING IN EXPERIENCE REPLAY BUFFER
            lives = info.get("lives", self.num_of_lives)
            if lives < self.num_of_lives:
                # LIFE LOSS CONSIDERED TERMINAL
                experience = Experience(state, action, reward, True, next_state)
                self.num_of_lives = lives
                # Only relaunch the ball while the episode is still running;
                # a finished environment must be reset, not stepped.
                if not episode_over:
                    self.env.fire()
            else:
                experience = Experience(state, action, reward, is_done, next_state)
                self.env.set_state(next_state)
            self.buffer.append(experience)

            if episode_over:
                self.num_of_lives = 5
                best_so_far = self.graph.add_info(self.env.get_total_rewards(), self.frame_idx, self.epsilon)
                if best_so_far:
                    torch.save({'model_state_dict': self.net.state_dict(),
                                'optimizer_state_dict': self.optimizer.state_dict(),
                                'epsilon': float(self.epsilon),
                                'loss': self.criterion_loss}, 'checkpoint1.pth')
                    print("Best mean reward updated %.3f, %d" % (self.graph.get_best_reward(), self.frame_idx))
                self.env.reset()

            if self.frame_idx >= 4000000 or self.env.episode >= episode_num:
                print("Solved in %d frames!" % self.frame_idx)
                torch.save({'model_state_dict': self.net.state_dict()}, 'checkpoint3.pth')
                break

            # Keep acting without learning until the buffer is warm.
            if len(self.buffer) < self.EXPERIENCE_REPLAY_START_SIZE:
                self.frame_idx += 1
                continue

            self.epsilon = np.interp(self.frame_idx, [0, self.EPSILON_DECAY], [self.EPSILON_START, self.EPSILON_END])

            if self.frame_idx % self.UPDATE_FREQUENCY == 0:
                self._optimize_model()

            if self.frame_idx % self.TARGET_NETWORK_UPDATE_FREQUENCY == 0:
                print("%d:  %d games, mean reward %.3f, (epsilon %.2f)" % (self.frame_idx, self.env.get_episode(), self.graph.get_mean_reward(), self.epsilon))
                self.target_net.load_state_dict(self.net.state_dict())
                torch.save({'model_state_dict': self.target_net.state_dict()}, 'checkpoint2.pth')
            self.frame_idx += 1

    def play(self, episode_num, checkpoint_path='checkpoint2DDQN20k2.pth'):
        """Evaluate a saved model for ``episode_num`` episodes.

        Args:
            episode_num: Number of episodes to play.
            checkpoint_path: File holding a ``model_state_dict`` entry.

        Returns:
            ``(episode_rewards, best_reward, mean_reward)``.
        """
        checkpoint1 = torch.load(checkpoint_path, map_location=device)
        print("Loaded checkpoint 1")
        self.net.load_state_dict(checkpoint1['model_state_dict'])
        self.epsilon = 0.05
        episode_rewards = []
        for _ in range(episode_num):
            self.frame_idx = 0
            self.num_of_lives = 5
            self.env.reset()
            is_done = False
            while not is_done:
                state = self.env.get_state()
                action = self.get_action(state)
                next_state, _, terminated, truncated, info = self.env.step(action)
                is_done = terminated or truncated

                lives = info.get("lives", self.num_of_lives)
                if lives < self.num_of_lives:
                    print("Lost life")
                    self.num_of_lives = lives
                    # Do not step an environment that has already finished.
                    if not is_done:
                        self.env.fire()
                else:
                    self.env.set_state(next_state)

                # Hard cap on the number of steps per evaluation episode.
                if self.frame_idx >= 18000:
                    break
                self.frame_idx += 1

            total_reward = self.env.get_total_rewards()
            print(total_reward)
            print(self.frame_idx)
            episode_rewards.append(total_reward)
        return episode_rewards, max(episode_rewards), sum(episode_rewards) / len(episode_rewards)
            


# Build the preprocessed Breakout environment and the DDQN agent.
env = gym.make("BreakoutNoFrameskip-v4", render_mode="rgb_array")
processed_env = ProcessingEnv(Preprocessing_env(env))

ddqn = DDQN(
    env=processed_env,
    gamma=0.99,
    batch_size=32,
    experience_replay_buffer_size=100000,
    learning_rate=1e-4,
    target_network_update_frequency=10000,
    experience_replay_start_size=5000,
    update_frequency=4,
    epsilon_start=1,
    epsilon_decay=500000,
    epsilon_end=0.01,
)
ddqn.train(5550)
# To evaluate a saved checkpoint instead of training, use:
# print(ddqn.play(30))



- ### DDDQN

In [None]:
class DDDQN:
    """Dueling Double DQN agent: double-Q targets on top of ``DDDQNModel``.

    Args:
        env: A ``ProcessingEnv``-style wrapper exposing ``get_state``,
            ``set_state``, ``step``, ``fire``, ``reset``, ``episode`` and the
            wrapped environment as ``env.env``.
        gamma: Discount factor.
        batch_size: Number of transitions per optimisation step.
        experience_replay_buffer_size: Capacity of the replay buffer.
        learning_rate: Adam learning rate.
        target_network_update_frequency: Frames between target-network syncs.
        experience_replay_start_size: Transitions collected before learning.
        update_frequency: Frames between optimisation steps.
        epsilon_start: Initial exploration rate.
        epsilon_decay: Frame index at which epsilon reaches ``epsilon_end``.
        epsilon_end: Final exploration rate.
    """

    def __init__(self, env, gamma, batch_size, experience_replay_buffer_size, learning_rate, 
                 target_network_update_frequency, experience_replay_start_size, update_frequency,
                 epsilon_start, epsilon_decay, epsilon_end):
        self.env = env
        self.GAMMA = gamma
        self.BATCH_SIZE = batch_size
        self.EXPERIENCE_REPLAY_BUFFER_SIZE = experience_replay_buffer_size
        self.LEARNING_RATE = learning_rate
        self.TARGET_NETWORK_UPDATE_FREQUENCY = target_network_update_frequency
        self.EXPERIENCE_REPLAY_START_SIZE = experience_replay_start_size
        self.UPDATE_FREQUENCY = update_frequency

        self.EPSILON_START = epsilon_start
        self.EPSILON_DECAY = epsilon_decay
        self.EPSILON_END = epsilon_end
        self.buffer = ExperienceReplayBuffer(self.EXPERIENCE_REPLAY_BUFFER_SIZE)
        self.net = DDDQNModel(env.env.observation_space.shape, env.env.action_space.n).to(device)
        self.target_net = DDDQNModel(env.env.observation_space.shape, env.env.action_space.n).to(device)
        # Start the target network from the online network's weights instead
        # of an unrelated random initialisation.
        self.target_net.load_state_dict(self.net.state_dict())
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.LEARNING_RATE)
        self.criterion_loss = nn.SmoothL1Loss()

        self.frame_idx = 1
        self.num_of_lives = 5

        # Flip to True to resume from previously saved checkpoints.
        restart = False
        if restart:
            # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
            checkpoint1 = torch.load('checkpoint1Score109.pth', map_location=device)
            print("Loaded checkpoint 1")
            checkpoint2 = torch.load('checkpoint2Score109.pth', map_location=device)
            print("Loaded checkpoint 2")
            self.net.load_state_dict(checkpoint1['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint1['optimizer_state_dict'])
            self.EPSILON_START = checkpoint1['epsilon']
            print(self.EPSILON_START)
            self.criterion_loss = checkpoint1['loss']
            self.target_net.load_state_dict(checkpoint2['model_state_dict'])

        self.net.train()
        # The optimizer only holds self.net's parameters; the target network
        # is only ever evaluated.
        self.target_net.eval()
        self.epsilon = self.EPSILON_START
        self.graph = Graphing()

    def get_action(self, state):
        """Return an epsilon-greedy action for ``state``."""
        if np.random.random() < self.epsilon:
            return self.env.env.action_space.sample()
        # Add a leading batch axis; np.asarray avoids the copy=False keyword,
        # which raises on NumPy 2 whenever a copy is unavoidable.
        state_a = np.expand_dims(np.asarray(state), 0)
        state_v = torch.tensor(state_a, dtype=torch.float32).to(device)
        with torch.no_grad():  # action selection needs no autograd graph
            q_vals_v = self.net(state_v)
        _, act_v = torch.max(q_vals_v, dim=1)
        return int(act_v.item())

    def _optimize_model(self):
        """Run one double-Q gradient step on a batch from the replay buffer."""
        states, actions, rewards, dones, next_states = self.buffer.batch_sample(self.BATCH_SIZE)

        states_v = torch.tensor(states, dtype=torch.float32).to(device)
        next_states_v = torch.tensor(next_states, dtype=torch.float32).to(device)
        actions_v = torch.tensor(actions, dtype=torch.int64).to(device)
        rewards_v = torch.tensor(rewards).to(device)
        # Boolean mask: indexing with a uint8 tensor is deprecated in PyTorch.
        done_mask = torch.tensor(dones, dtype=torch.bool).to(device)

        q_vals = self.net(states_v)
        state_action_values = q_vals.gather(1, actions_v.unsqueeze(-1)).squeeze(-1)

        # The whole target is a constant w.r.t. the loss, so both forward
        # passes below run without building an autograd graph.
        with torch.no_grad():
            # Online network chooses the next action ...
            _, act_v = torch.max(self.net(next_states_v), dim=1)
            # ... and the target network supplies that action's value.
            target_q_vals = self.target_net(next_states_v)
            next_state_values = target_q_vals.gather(1, act_v.unsqueeze(1)).reshape(-1)
            # Terminal transitions do not bootstrap from the next state.
            next_state_values[done_mask] = 0.0

        estimated_state_action_values = rewards_v + self.GAMMA * next_state_values

        loss_huber = self.criterion_loss(state_action_values, estimated_state_action_values)
        self.optimizer.zero_grad()
        loss_huber.backward()
        self.optimizer.step()

    def train(self, episode_num):
        """Train until ``episode_num`` episodes or 4,000,000 frames are reached.

        Checkpoints are written to ``checkpoint1.pth`` (best mean reward),
        ``checkpoint2.pth`` (target network at each sync) and
        ``checkpoint3.pth`` (online network when training stops).
        """
        while True:
            state = self.env.get_state()
            action = self.get_action(state)
            next_state, reward, is_done, truncated, info = self.env.step(action)
            # A truncated episode must also be reset, otherwise the loop keeps
            # stepping an environment that has already finished.
            episode_over = is_done or truncated

            # CROP REWARDS - IMPORTANT
            reward = min(reward, 1)

            # STORING IN EXPERIENCE REPLAY BUFFER
            lives = info.get("lives", self.num_of_lives)
            if lives < self.num_of_lives:
                # LIFE LOSS CONSIDERED TERMINAL
                experience = Experience(state, action, reward, True, next_state)
                self.num_of_lives = lives
                # Only relaunch the ball while the episode is still running;
                # a finished environment must be reset, not stepped.
                if not episode_over:
                    self.env.fire()
            else:
                experience = Experience(state, action, reward, is_done, next_state)
                self.env.set_state(next_state)
            self.buffer.append(experience)

            if episode_over:
                self.num_of_lives = 5
                best_so_far = self.graph.add_info(self.env.get_total_rewards(), self.frame_idx, self.epsilon)
                if best_so_far:
                    torch.save({'model_state_dict': self.net.state_dict(),
                                'optimizer_state_dict': self.optimizer.state_dict(),
                                'epsilon': float(self.epsilon),
                                'loss': self.criterion_loss}, 'checkpoint1.pth')
                    print("Best mean reward updated %.3f, %d" % (self.graph.get_best_reward(), self.frame_idx))
                self.env.reset()

            if self.frame_idx >= 4000000 or self.env.episode >= episode_num:
                print("Solved in %d frames!" % self.frame_idx)
                torch.save({'model_state_dict': self.net.state_dict()}, 'checkpoint3.pth')
                break

            # Keep acting without learning until the buffer is warm.
            if len(self.buffer) < self.EXPERIENCE_REPLAY_START_SIZE:
                self.frame_idx += 1
                continue

            self.epsilon = np.interp(self.frame_idx, [0, self.EPSILON_DECAY], [self.EPSILON_START, self.EPSILON_END])

            if self.frame_idx % self.UPDATE_FREQUENCY == 0:
                self._optimize_model()

            if self.frame_idx % self.TARGET_NETWORK_UPDATE_FREQUENCY == 0:
                print("%d:  %d games, mean reward %.3f, (epsilon %.2f)" % (self.frame_idx, self.env.get_episode(), self.graph.get_mean_reward(), self.epsilon))
                self.target_net.load_state_dict(self.net.state_dict())
                torch.save({'model_state_dict': self.target_net.state_dict()}, 'checkpoint2.pth')

            self.frame_idx += 1

    def play(self, episode_num, checkpoint_path='checkpoint3DDDQN.pth'):
        """Evaluate a saved model for ``episode_num`` episodes.

        Args:
            episode_num: Number of episodes to play.
            checkpoint_path: File holding a ``model_state_dict`` entry.

        Returns:
            ``(episode_rewards, best_reward, mean_reward)``.
        """
        checkpoint1 = torch.load(checkpoint_path, map_location=device)
        print("Loaded checkpoint 1")
        self.net.load_state_dict(checkpoint1['model_state_dict'])
        self.epsilon = 0.05
        episode_rewards = []
        for _ in range(episode_num):
            self.frame_idx = 0
            self.num_of_lives = 5
            self.env.reset()
            is_done = False
            while not is_done:
                state = self.env.get_state()
                action = self.get_action(state)
                next_state, _, terminated, truncated, info = self.env.step(action)
                is_done = terminated or truncated

                lives = info.get("lives", self.num_of_lives)
                if lives < self.num_of_lives:
                    print("Lost life")
                    self.num_of_lives = lives
                    # Do not step an environment that has already finished.
                    if not is_done:
                        self.env.fire()
                else:
                    self.env.set_state(next_state)

                # Hard cap on the number of steps per evaluation episode.
                if self.frame_idx >= 10000:
                    break
                self.frame_idx += 1

            total_reward = self.env.get_total_rewards()
            print(total_reward)
            print(self.frame_idx)
            episode_rewards.append(total_reward)
        return episode_rewards, max(episode_rewards), sum(episode_rewards) / len(episode_rewards)


# Build the preprocessed Breakout environment and the DDDQN agent.
env = gym.make("BreakoutNoFrameskip-v4", render_mode="rgb_array")
processed_env = ProcessingEnv(Preprocessing_env(env))

dddqn = DDDQN(
    env=processed_env,
    gamma=0.99,
    batch_size=32,
    experience_replay_buffer_size=100000,
    learning_rate=1e-4,
    target_network_update_frequency=10000,
    experience_replay_start_size=5000,
    update_frequency=4,
    epsilon_start=1,
    epsilon_decay=500000,
    epsilon_end=0.01,
)
dddqn.train(5550)
# To evaluate a saved checkpoint instead of training, use:
# print(dddqn.play(30))




- ### Random

In [None]:
class RANDOM:
    """Baseline agent that plays with randomly sampled actions.

    A new action is sampled every sixth step and repeated in between.

    Args:
        env: A ``ProcessingEnv``-style wrapper exposing ``reset``, ``step``,
            ``fire``, ``set_state``, ``get_total_rewards`` and the wrapped
            environment as ``env.env``.
    """

    def __init__(self, env):
        self.env = env
        self.frame_idx = 0
        self.num_of_lives = 5
        self.epsilon = 0.05
        self.episode_rewards = []
        self.action = 1

    def play(self, episode_num):
        """Play ``episode_num`` episodes and return reward statistics.

        Rewards are appended to ``self.episode_rewards``, so the statistics
        cover every episode played by this instance so far.

        Returns:
            ``(episode_rewards, best_reward, mean_reward)``.
        """
        for _ in range(episode_num):
            self.frame_idx = 0
            self.num_of_lives = 5
            self.env.reset()
            is_done = False
            while not is_done:
                if self.frame_idx % 6 == 0:
                    self.action = self.env.env.action_space.sample()
                next_state, _, terminated, truncated, info = self.env.step(self.action)
                # A truncated episode is finished too; without this the loop
                # would keep stepping an environment that has already ended.
                is_done = terminated or truncated

                lives = info.get("lives", self.num_of_lives)
                if lives < self.num_of_lives:
                    # LIFE LOSS CONSIDERED TERMINAL
                    self.num_of_lives = lives
                    # Only relaunch the ball while the episode is still running.
                    if not is_done:
                        self.env.fire()
                else:
                    self.env.set_state(next_state)
                self.frame_idx += 1

            self.episode_rewards.append(self.env.get_total_rewards())
        return self.episode_rewards, max(self.episode_rewards), sum(self.episode_rewards) / len(self.episode_rewards)

        



# Build the preprocessed Breakout environment and the random baseline agent.
env = gym.make("BreakoutNoFrameskip-v4", render_mode="rgb_array")
processed_env = ProcessingEnv(Preprocessing_env(env))
rand = RANDOM(env=processed_env)

# Prints (episode_rewards, best_reward, mean_reward) for 30 episodes.
print(rand.play(30))

In [None]:
# Print the wall-clock time at which this cell is executed.
print(">>>Training ends at ",datetime.datetime.now())

## 8. TensorBoard Results

In [None]:
tensorboard  --logdir=runs