In [1]:
import torch
import torch.nn as nn
import random
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from tqdm import tqdm
import pickle 
from gym_super_mario_bros.actions import RIGHT_ONLY
import gym
import numpy as np
import collections 
import cv2

In [2]:
class MaxAndSkipEnv(gym.Wrapper):
    """Repeat one action over several emulator frames.

    Each `step` forwards the same action up to `skip` times, sums the rewards,
    and returns the element-wise maximum over the (at most two) most recent
    raw observations.
    """

    def __init__(self, env=None, skip=4):
        """Return only every `skip`-th frame"""
        super().__init__(env)
        self._skip = skip
        # Two-slot window of raw frames; older frames fall out automatically.
        self._obs_buffer = collections.deque(maxlen=2)

    def step(self, action):
        """Apply `action` for up to `skip` frames; stop early when the episode ends.

        Returns the max-pooled frame, the summed reward, the last `done` flag
        and the last `info` reported by the wrapped environment.
        """
        done = None
        reward_sum = 0.0
        for _ in range(self._skip):
            frame, frame_reward, done, info = self.env.step(action)
            self._obs_buffer.append(frame)
            reward_sum += frame_reward
            if done:
                break
        recent_frames = np.stack(self._obs_buffer)
        return recent_frames.max(axis=0), reward_sum, done, info

    def reset(self):
        """Clear past frame buffer and init to first obs"""
        self._obs_buffer.clear()
        first_frame = self.env.reset()
        self._obs_buffer.append(first_frame)
        return first_frame


class ProcessFrame84(gym.ObservationWrapper):
    """
    Observation wrapper that converts an RGB frame to greyscale,
    resizes it to 84x110 and crops rows 18..101 to get 84x84.

    Returns a uint8 numpy array of shape (84, 84, 1)
    """

    def __init__(self, env=None):
        super().__init__(env)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(84, 84, 1), dtype=np.uint8
        )

    def observation(self, obs):
        """Delegate to the static `process` helper."""
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        """Greyscale, resize and crop one raw frame.

        The frame's element count must match one of the known
        (height, width) resolutions with 3 channels; anything else trips
        the "Unknown resolution." assertion.
        """
        known_resolutions = ((210, 160), (250, 160), (240, 256))
        for height, width in known_resolutions:
            if frame.size == height * width * 3:
                img = np.reshape(frame, [height, width, 3]).astype(np.float32)
                break
        else:
            assert False, "Unknown resolution."

        # Weighted sum of the three colour channels -> single luminance plane.
        red, green, blue = img[:, :, 0], img[:, :, 1], img[:, :, 2]
        grey = red * 0.299 + green * 0.587 + blue * 0.114

        # cv2.resize takes (width, height): result is 110 rows x 84 columns.
        resized = cv2.resize(grey, (84, 110), interpolation=cv2.INTER_AREA)
        cropped = resized[18:102, :]
        return np.reshape(cropped, [84, 84, 1]).astype(np.uint8)


class ImageToPyTorch(gym.ObservationWrapper):
    """Move the channel axis (index 2) of each observation to the front."""

    def __init__(self, env):
        super().__init__(env)
        source_shape = self.observation_space.shape
        # Last axis first, followed by the first two axes.
        channel_first_shape = (source_shape[-1], source_shape[0], source_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=channel_first_shape, dtype=np.float32
        )

    def observation(self, observation):
        """Return `observation` with axis 2 relocated to position 0."""
        return np.moveaxis(observation, source=2, destination=0)


class ScaledFloatFrame(gym.ObservationWrapper):
    """Convert observations to float32 and divide by 255 (0..255 --> 0..1)."""

    def observation(self, obs):
        """Return a new float32 array holding `obs / 255.0`."""
        as_float = np.array(obs).astype(np.float32)
        return as_float / 255.0


class BufferWrapper(gym.ObservationWrapper):
    """Keep a rolling window of observations stacked along axis 0.

    The observation space is the wrapped space with its bounds repeated
    `n_steps` times along axis 0.  The same internal buffer object is
    returned from every call, updated in place.
    """

    def __init__(self, env, n_steps, dtype=np.float32):
        super().__init__(env)
        self.dtype = dtype
        inner_space = env.observation_space
        stacked_low = inner_space.low.repeat(n_steps, axis=0)
        stacked_high = inner_space.high.repeat(n_steps, axis=0)
        self.observation_space = gym.spaces.Box(stacked_low, stacked_high, dtype=dtype)

    def reset(self):
        """Zero the window, reset the wrapped env and push its first observation."""
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        first_observation = self.env.reset()
        return self.observation(first_observation)

    def observation(self, observation):
        """Shift the window one slot towards index 0 and store the newest entry last."""
        window = self.buffer
        window[:-1] = window[1:]
        window[-1] = observation
        return window


def make_env(env):
    """Wrap `env` with the preprocessing stack used by this notebook.

    Order (innermost first): MaxAndSkipEnv, ProcessFrame84, ImageToPyTorch,
    BufferWrapper with 4 steps, ScaledFloatFrame, and finally JoypadSpace
    restricted to the RIGHT_ONLY action list.
    """
    for wrapper_cls in (MaxAndSkipEnv, ProcessFrame84, ImageToPyTorch):
        env = wrapper_cls(env)
    stacked_env = BufferWrapper(env, 4)
    scaled_env = ScaledFloatFrame(stacked_env)
    return JoypadSpace(scaled_env, RIGHT_ONLY)

In [3]:
class DQNSolver(nn.Module):
    """Convolutional Q-network: three conv layers followed by two linear layers.

    `input_shape` is (channels, height, width); the output has one value per
    action for every item in the batch.
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()
        in_channels = input_shape[0]
        conv_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*conv_layers)

        # Size of the flattened conv output, measured with a dummy input.
        flat_features = self._get_conv_out(input_shape)
        head_layers = [
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        ]
        self.fc = nn.Sequential(*head_layers)

    def _get_conv_out(self, shape):
        """Return the number of elements the conv stack emits for one input of `shape`."""
        probe = torch.zeros(1, *shape)
        return self.conv(probe).numel()

    def forward(self, x):
        """Run the conv stack, flatten per batch item, then apply the linear head."""
        batch_size = x.size()[0]
        features = self.conv(x)
        return self.fc(features.view(batch_size, -1))
    

class DQNAgent:
    """Epsilon-greedy (double) DQN agent with a prioritized replay memory.

    Transitions are stored in five parallel tensors (``STATE_MEM``,
    ``ACTION_MEM``, ``REWARD_MEM``, ``STATE2_MEM``, ``DONE_MEM``) with one
    ``priority`` row per transition.  The networks run on ``self.device``
    (CUDA when available, otherwise CPU); sampled batches are moved there
    for each update.
    """

    def __init__(self, state_space, action_space, max_memory_size, batch_size, gamma, lr,
                 dropout, exploration_max, exploration_min, exploration_decay, double_dq):
        """Load pretrained network weights and (if present) a saved replay memory.

        Args:
            state_space: shape of one observation, passed to ``DQNSolver``.
            action_space: number of discrete actions.
            max_memory_size: replay capacity; once reached, the oldest
                transition is dropped for every new one.
            batch_size: number of transitions sampled per replay update.
            gamma: discount factor.
            lr: Adam learning rate.
            dropout: accepted for interface compatibility; not used here.
            exploration_max: initial epsilon.
            exploration_min: lower bound for epsilon.
            exploration_decay: multiplicative epsilon decay per replay update.
            double_dq: use two networks (``dq1.pt`` / ``dq2.pt``) instead of
                one (``dq.pt``).

        Raises:
            Whatever ``torch.load`` / ``load_state_dict`` raise when a weight
            checkpoint is missing or incompatible (weights are mandatory).
        """
        self.state_space = state_space
        self.action_space = action_space
        self.double_dq = double_dq
        # Use the GPU when there is one, but keep the agent usable on CPU-only hosts.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Define DQN layers
        if self.double_dq:
            self.dqn1 = self._load_network("dq1.pt")
            self.dqn2 = self._load_network("dq2.pt")
            self.optimizer1 = torch.optim.Adam(self.dqn1.parameters(), lr=lr)
            self.optimizer2 = torch.optim.Adam(self.dqn2.parameters(), lr=lr)
        else:
            self.dqn = self._load_network("dq.pt")
            self.optimizer = torch.optim.Adam(self.dqn.parameters(), lr=lr)

        # Create memory
        self.max_memory_size = max_memory_size
        self.memory_sample_size = batch_size
        self._init_memory()

        # Learning parameters
        self.gamma = gamma
        # Kept as a public attribute; the replay update uses the unreduced
        # functional form so per-sample importance weights can be applied.
        self.l1 = nn.SmoothL1Loss().to(self.device)
        self.exploration_rate = exploration_max
        self.exploration_min = exploration_min
        self.exploration_decay = exploration_decay

    def _load_network(self, checkpoint_path):
        """Build a ``DQNSolver`` on ``self.device`` and load weights from ``checkpoint_path``."""
        network = DQNSolver(self.state_space, self.action_space).to(self.device)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        network.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
        return network

    def _init_memory(self):
        """Load the saved replay tensors, or start empty when they are unusable."""
        try:
            self.STATE_MEM = torch.load("STATE_MEM.pt", map_location="cpu")
            self.ACTION_MEM = torch.load("ACTION_MEM.pt", map_location="cpu")
            self.REWARD_MEM = torch.load("REWARD_MEM.pt", map_location="cpu")
            self.STATE2_MEM = torch.load("STATE2_MEM.pt", map_location="cpu")
            self.DONE_MEM = torch.load("DONE_MEM.pt", map_location="cpu")
            lengths = {len(self.STATE_MEM), len(self.ACTION_MEM), len(self.REWARD_MEM),
                       len(self.STATE2_MEM), len(self.DONE_MEM)}
            if len(lengths) != 1:
                raise ValueError("saved replay buffers have inconsistent lengths")
        except Exception:
            # Best effort: a missing, corrupt or inconsistent snapshot means an
            # empty memory.  KeyboardInterrupt/SystemExit are not swallowed.
            self.STATE_MEM = torch.zeros(0, *self.state_space)
            self.ACTION_MEM = torch.zeros(0, 1)
            self.REWARD_MEM = torch.zeros(0, 1)
            self.STATE2_MEM = torch.zeros(0, *self.state_space)
            self.DONE_MEM = torch.zeros(0, 1)

        stored = len(self.STATE_MEM)
        # Uniform initial priorities; an empty memory gets an empty priority
        # tensor instead of dividing by zero.
        if stored > 0:
            self.priority = torch.ones(stored, 1) * (1 / stored)
        else:
            self.priority = torch.zeros(0, 1)
        self.num_in_queue = stored

        # starting_position == 1 makes `remember` drop the oldest row on insert.
        if self.num_in_queue >= self.max_memory_size:
            self.starting_position = 1
        else:
            self.starting_position = 0

    def remember(self, state, action, reward, state2, done):
        """Append one transition, evicting the oldest once the memory is full.

        The new transition gets the current maximum priority (1 when the
        memory is empty).
        """
        start = self.starting_position
        self.STATE_MEM = torch.cat((self.STATE_MEM[start:], state.float()), dim=0)
        self.ACTION_MEM = torch.cat((self.ACTION_MEM[start:], action.float()))
        self.REWARD_MEM = torch.cat((self.REWARD_MEM[start:], reward.float()))
        self.STATE2_MEM = torch.cat((self.STATE2_MEM[start:], state2.float()), dim=0)
        self.DONE_MEM = torch.cat((self.DONE_MEM[start:], done.float()))

        # Tensor-level max instead of Python's max() over every row.
        if self.priority.numel() > 0:
            newest_priority = self.priority.max().reshape(1, 1)
        else:
            newest_priority = torch.ones(1, 1)
        self.priority = torch.cat((self.priority[start:], newest_priority))

        self.num_in_queue += 1
        if self.num_in_queue >= self.max_memory_size:
            self.starting_position = 1

    def get_probabilities(self, priority_scale):
        """Return sampling probabilities: priority ** priority_scale, normalized to sum to 1."""
        scaled = torch.pow(self.priority, priority_scale)
        return scaled / scaled.sum()

    def get_importance(self, probabilities):
        """Return importance-sampling weights for the given sampling probabilities.

        Weights are 1 / (N * p), normalized by their maximum and raised to
        ``1 - exploration_rate``.
        """
        importance = (1 / len(self.STATE_MEM)) / probabilities
        importance_normalized = importance / importance.max()
        # The annealed value is what must be returned (it used to be computed
        # into a misspelled name and discarded).
        return torch.pow(importance_normalized, 1 - self.exploration_rate)

    def set_priorities(self, indices, abs_errors, offset=0.1):
        """Store ``abs_errors + offset`` as the priorities of ``indices``."""
        # Detach so the priority tensor never becomes part of an autograd graph.
        errors = torch.as_tensor(abs_errors).detach().to(self.priority.device)
        self.priority[indices] = errors + offset

    def recall(self, priority_scale=1.0):
        """Sample ``memory_sample_size`` transitions in proportion to priority.

        Returns:
            ((STATE, ACTION, REWARD, STATE2, DONE), idx, importance) where
            ``idx`` is a list of sampled row indices and ``importance`` holds
            one weight per sampled row.
        """
        sample_probabilities = self.get_probabilities(priority_scale)
        # Plain floats: random.choices over tensor rows does per-element tensor
        # arithmetic in Python, which is far slower for the same result.
        idx = random.choices(range(len(self.STATE_MEM)),
                             weights=sample_probabilities.flatten().tolist(),
                             k=self.memory_sample_size)
        batch = (self.STATE_MEM[idx], self.ACTION_MEM[idx], self.REWARD_MEM[idx],
                 self.STATE2_MEM[idx], self.DONE_MEM[idx])
        importance = self.get_importance(sample_probabilities[idx])
        return batch, idx, importance

    def act(self, state):
        """Pick an action epsilon-greedily; returns a (1, 1) CPU tensor."""
        if random.random() < self.exploration_rate:
            return torch.tensor([[random.randrange(self.action_space)]])
        if self.double_dq:
            network = self.dqn1 if random.random() < 0.5 else self.dqn2
        else:
            network = self.dqn
        # Action selection needs no gradients.
        with torch.no_grad():
            q_values = network(state.to(self.device))
        return torch.argmax(q_values).unsqueeze(0).unsqueeze(0).cpu()

    def experience_replay(self):
        """Run one prioritized-replay update and decay epsilon.

        Does nothing until the memory holds at least one batch.
        """
        if self.memory_sample_size > len(self.STATE_MEM):
            return

        # Q-Learning update is Q(S, A) <- Q(S, A) + α[r + γ max_a Q(S', a) - Q(S, A)]
        (STATE, ACTION, REWARD, STATE2, DONE), idx, importance = self.recall()
        STATE = STATE.to(self.device)
        ACTION = ACTION.to(self.device)
        REWARD = REWARD.to(self.device)
        STATE2 = STATE2.to(self.device)
        DONE = DONE.to(self.device)
        NOTDONE = 1 - DONE

        if self.double_dq:
            if random.random() < 0.5:
                # Update DQN2 using DQN1
                online, bootstrap, optimizer = self.dqn2, self.dqn1, self.optimizer2
            else:
                # Update DQN1 using DQN2
                online, bootstrap, optimizer = self.dqn1, self.dqn2, self.optimizer1
        else:
            online = bootstrap = self.dqn
            optimizer = self.optimizer

        optimizer.zero_grad()
        # The bootstrap target is a constant for this update: no gradient may
        # flow through it (matters most when online and bootstrap are the same net).
        with torch.no_grad():
            next_best = bootstrap(STATE2).max(1).values.unsqueeze(1)
            target = REWARD + torch.mul(self.gamma * next_best, NOTDONE)
        current = online(STATE).gather(1, ACTION.long())

        self.set_priorities(idx, torch.abs(current - target).detach().cpu())

        # Weight each sample's loss by its importance weight to correct for
        # the non-uniform sampling.
        per_sample = nn.functional.smooth_l1_loss(current, target, reduction="none")
        loss = (importance.to(self.device) * per_sample).mean()
        loss.backward()
        optimizer.step()

        self.exploration_rate *= self.exploration_decay
        self.exploration_rate = max(self.exploration_rate, self.exploration_min)

In [4]:
def vectorize_action(action, action_space):
    """Return a list with a single 1 at index `action` and 0 elsewhere.

    `action` zeros come first, then the 1, then one zero for every index in
    range(action + 1, action_space).
    """
    leading_zeros = [0] * len(range(action))
    trailing_zeros = [0] * len(range(action + 1, action_space))
    return leading_zeros + [1] + trailing_zeros

In [5]:
def run(training_mode):
    """Play 50 episodes of SuperMarioBros-1-1 with a pretrained double-DQN agent.

    Args:
        training_mode: when true, the best reward so far is read from
            ``best_reward.pkl``, every transition is stored in the agent's
            replay memory, and one replay update runs per environment step.
            When false the agent only acts.

    The environment is always closed on exit, including when the loop is
    interrupted or raises.
    """

    def to_batch(observation):
        # One observation -> float32 tensor with a leading batch axis of 1.
        # Converting the ndarray directly avoids the slow list-of-arrays path.
        return torch.tensor(np.asarray(observation), dtype=torch.float32).unsqueeze(0)

    env = gym_super_mario_bros.make('SuperMarioBros-1-1-v0')
    env = make_env(env)
    try:
        observation_space = env.observation_space.shape
        action_space = env.action_space.n
        agent = DQNAgent(state_space=observation_space,
                         action_space=action_space,
                         max_memory_size=8000,
                         batch_size=32,
                         gamma=0.90,
                         lr=0.00025,
                         dropout=0.,
                         exploration_max=0.06,
                         exploration_min=0.02,
                         exploration_decay=0.99,
                         double_dq=True)

        num_episodes = 50
        total_rewards = []

        if training_mode:
            # NOTE(review): pickle.load can execute arbitrary code; only use a
            # best_reward.pkl produced by this notebook.
            with open("best_reward.pkl", 'rb') as f:
                best_reward = pickle.load(f)

            print(best_reward)

        for ep_num in tqdm(range(num_episodes)):
            state = to_batch(env.reset())
            total_reward = 0
            steps = 0

            while True:
                env.render()
                action = agent.act(state)
                steps += 1
                state_next, reward, terminal, _ = env.step(int(action[0]))
                state_next = to_batch(state_next)
                reward = torch.tensor([reward]).unsqueeze(0)
                total_reward += reward
                terminal = torch.tensor([int(terminal)]).unsqueeze(0)

                if training_mode:
                    agent.remember(state, action, reward, state_next, terminal)
                    agent.experience_replay()

                state = state_next
                if terminal:
                    break

            if training_mode:
                if total_reward > best_reward:
                    best_reward = total_reward

                    # Persistence is currently disabled:
                    # with open("best_reward.pkl", "wb") as f:
                        # pickle.dump(best_reward, f)
                    # if agent.double_dq:
                        # torch.save(agent.dqn1.state_dict(), "dq1.pt")
                        # torch.save(agent.dqn2.state_dict(), "dq2.pt")
                    # else:
                        # torch.save(agent.dqn.state_dict(), "dq.pt")
            total_rewards.append(total_reward)

            print("Total reward after episode {} is {}".format(ep_num + 1, total_rewards[-1]))

        # if training_mode:
        #     with open("total_rewards.pkl", "wb") as f:
        #         pickle.dump(total_rewards, f)
        #
        #     torch.save(agent.STATE_MEM,  "STATE_MEM.pt")
        #     torch.save(agent.ACTION_MEM, "ACTION_MEM.pt")
        #     torch.save(agent.REWARD_MEM, "REWARD_MEM.pt")
        #     torch.save(agent.STATE2_MEM, "STATE2_MEM.pt")
        #     torch.save(agent.DONE_MEM,   "DONE_MEM.pt")
    finally:
        # Release the emulator/render window even on KeyboardInterrupt.
        env.close()

# Start the episode loop with training enabled: every transition is stored in
# the agent's replay memory and a replay update runs after each step.
run(training_mode=True)

  0%|          | 0/50 [00:00<?, ?it/s]

tensor([[1924.]])


  return (self.ram[0x86] - self.ram[0x071c]) % 256
  2%|▏         | 1/50 [02:46<2:15:40, 166.14s/it]

Total reward after episode 1 is tensor([[1415.]])


  4%|▍         | 2/50 [03:11<1:39:00, 123.77s/it]

Total reward after episode 2 is tensor([[249.]])


  6%|▌         | 3/50 [06:14<1:50:55, 141.62s/it]

Total reward after episode 3 is tensor([[1331.]])


  8%|▊         | 4/50 [07:10<1:28:57, 116.03s/it]

Total reward after episode 4 is tensor([[608.]])


 10%|█         | 5/50 [09:14<1:28:45, 118.35s/it]

Total reward after episode 5 is tensor([[1323.]])


 12%|█▏        | 6/50 [10:33<1:18:15, 106.71s/it]

Total reward after episode 6 is tensor([[624.]])


 14%|█▍        | 7/50 [12:28<1:18:11, 109.11s/it]

Total reward after episode 7 is tensor([[623.]])


KeyboardInterrupt: 