In [None]:
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
import numpy as np
import random
from collections import deque
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gym
from gym.wrappers import FrameStack
from torchvision import transforms as T
from gym.spaces import Box
import math

In [47]:
class NoisyLinear(nn.Module):
    def __init__(self, in_features, out_features, sigma_init=1.0):
        super(NoisyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        self.register_buffer("weight_epsilon", torch.empty(out_features, in_features))

        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        self.register_buffer("bias_epsilon", torch.empty(out_features))

        self.sigma_init = sigma_init
        self.reset_parameters()
        self.reset_noise()

    def reset_parameters(self):
        mu_range = 1 / math.sqrt(self.in_features)
        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.weight_sigma.data.fill_(self.sigma_init / math.sqrt(self.in_features))
        self.bias_mu.data.uniform_(-mu_range, mu_range)
        self.bias_sigma.data.fill_(self.sigma_init / math.sqrt(self.out_features))

    def reset_noise(self):
        self.weight_epsilon.normal_()
        self.bias_epsilon.normal_()

    def forward(self, x):
        if self.training:
            weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
            bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        else:
            weight = self.weight_mu
            bias = self.bias_mu
        return F.linear(x, weight, bias)

In [None]:
class QNet(nn.Module):
    def __init__(self, n_actions):
        super(QNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        self.fc_val = nn.Sequential(
            NoisyLinear(3136, 512),
            nn.ReLU(),
            NoisyLinear(512, n_actions)
        )
        # self.fc_adv = nn.Sequential(
        #     NoisyLinear(3136, 256),
        #     nn.ReLU(),
        #     NoisyLinear(256, n_actions)
        # )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)
        val = self.fc_val(x)
        # adv = self.fc_adv(x)
        # q = val + (adv - adv.mean(dim=1, keepdim=True))
        return val

    def reset_noise(self):
        for m in self.modules():
            if isinstance(m, NoisyLinear):
                m.reset_noise()

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)
        self.len = 0

    def add(self, state, action, reward, next_state, done):
        self.len = min(self.len + 1, self.capacity)

        state = torch.tensor(np.array(state).copy(), dtype=torch.float32)
        next_state = torch.tensor(np.array(next_state).copy(), dtype=torch.float32)
        action = torch.tensor([action], dtype=torch.int64)
        reward = torch.tensor([reward], dtype=torch.float32)
        done = torch.tensor([done], dtype=torch.float32)

        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = map(torch.stack, zip(*batch))
        return states, actions, rewards, next_states, dones

class DQNVariant:
    def __init__(self, action_size):
        self.action_size = action_size
        self.gamma = 0.99
        self.batch_size = 128
        self.learn_start = 5000
        self.target_update_freq = 1000
        self.update_count = 0
        self.tau = 1.0
        
        self.epsilon = 0.2
        # self.eps_decay = 0.99999975
        # self.eps_min = 0.1
        
        self.testing = False

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.q_net = QNet(action_size).to(self.device)
        self.target_net = QNet(action_size).to(self.device)
        self.update(learning=1.0)
        for p in self.target_net.parameters():
            p.requires_grad = False

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=0.00025)
        self.replay_buffer = ReplayBuffer(50000)

    def get_action(self, state):
        deterministic = True
        if(not self.testing):
            self.q_net.reset_noise()
            deterministic = random.random() > self.epsilon
            # self.epsilon *= self.eps_decay
            # self.epsilon = max(self.eps_min, self.epsilon)

        if(not deterministic): return np.random.randint(self.action_size)
        with torch.no_grad():
            state = torch.tensor(np.array(state).copy(), dtype=torch.float32).unsqueeze(0).to(self.device)
            q_values = self.q_net(state)
            action = torch.argmax(q_values).item()
        return action

    def update(self, learning):
        for target_param, param in zip(self.target_net.parameters(), self.q_net.parameters()):
            target_param.data.copy_(learning * param.data + (1 - learning) * target_param.data)

    def train(self):
        if self.replay_buffer.len < self.learn_start:
            return

        self.update_count += 1
        
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        states = states.to(self.device)
        actions = actions.to(self.device)
        rewards = rewards.to(self.device)
        next_states = next_states.to(self.device)
        dones = dones.to(self.device)
        
        self.q_net.reset_noise()
        self.target_net.reset_noise()
        
        q_values = self.q_net(states).gather(1, actions)

        with torch.no_grad():
            next_actions = self.q_net(next_states).argmax(1, keepdim=True)
            next_q_values = self.target_net(next_states).gather(1, next_actions)
            target_q = rewards + (1 - dones) * self.gamma * next_q_values
        
        # with torch.no_grad():
        #     max_next_q = self.target_net(next_states).max(1, keepdim=True)[0]
        #     target_q = rewards + (1 - dones) * self.gamma * max_next_q

        loss = nn.MSELoss()(q_values, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_count += 1
        if self.update_count % self.target_update_freq == 0:
            self.update(learning=self.tau)

In [43]:
class SkipFrame(gym.Wrapper):
    def __init__(self, env, skip):
        super().__init__(env)
        self.skip = skip

    def step(self, action):
        total_reward = 0.0
        for i in range(self.skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            if done: break
        return obs, total_reward, done, info

# class GrayScaleObservation(gym.ObservationWrapper):
#     def __init__(self, env):
#         super().__init__(env)
#         obs_shape = self.observation_space.shape[:2]
#         self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

#     def permute_orientation(self, observation):
#         # permute [H, W, C] array to [C, H, W] tensor
#         observation = np.transpose(observation, (2, 0, 1))
#         observation = torch.tensor(observation.copy(), dtype=torch.float)
#         return observation

#     def observation(self, observation):
#         observation = self.permute_orientation(observation)
#         transform = T.Grayscale()
#         observation = transform(observation)
#         return observation


# class ResizeObservation(gym.ObservationWrapper):
#     def __init__(self, env, shape):
#         super().__init__(env)
#         if isinstance(shape, int):
#             self.shape = (shape, shape)
#         else:
#             self.shape = tuple(shape)

#         obs_shape = self.shape + self.observation_space.shape[2:]
#         self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

#     def observation(self, observation):
#         transforms = T.Compose(
#             [T.Resize(self.shape, antialias=True), T.Normalize(0, 255)]
#         )
#         observation = transforms(observation).squeeze(0)
#         return observation

class TransformObservation(gym.ObservationWrapper):
    def __init__(self, env, shape):
        super().__init__(env)
        self.shape = (shape, shape)
        self.observation_space = Box(
            low=0, high=255,
            shape=(1, *self.shape),
            dtype=np.uint8
        )
        self.transform = T.Compose([
            T.Grayscale(),
            T.Resize(self.shape, antialias=True),
            T.Normalize(0, 255)
        ])
        
    def observation(self, observation):
        observation = np.transpose(observation, (2, 0, 1))
        observation = torch.tensor(observation.copy(), dtype=torch.float)
        observation = self.transform(observation).squeeze(0)
        return observation

In [None]:
import time
def test_agent(agent):
    sim_env = gym_super_mario_bros.make('SuperMarioBros-v0')
    sim_env = JoypadSpace(sim_env, COMPLEX_MOVEMENT)
    sim_env = SkipFrame(sim_env, skip=4)
    sim_env = TransformObservation(sim_env, shape=84)
    sim_env = FrameStack(sim_env, num_stack=4)
    
    agent.testing = True
    agent.q_net.eval()
    
    agent.epsilon = 0.0
    
    state = sim_env.reset()
    done = False
    total_reward = 0
    step = 0

    while not done:
        
        action = agent.get_action(state)
        next_state, reward, done, _ = sim_env.step(action)
        done = done or step >= 5e3
        
        state = next_state
        total_reward += reward
        step += 1
        sim_env.render()
        
        time.sleep(0.1)

    agent.testing = False
    agent.q_net.train()
    
    print(total_reward, step)
    sim_env.close()
    return

In [None]:
def train_agent():    
    torch.autograd.set_detect_anomaly(True)
    env = gym_super_mario_bros.make('SuperMarioBros-v0')
    env = JoypadSpace(env, COMPLEX_MOVEMENT)
    env = SkipFrame(env, skip=4)
    env = TransformObservation(env, shape=84)
    env = FrameStack(env, num_stack=4)

    action_size = env.action_space.n
    state_size = env.observation_space.shape
    agent = DQNVariant(action_size)
    agent.testing = False
    agent.q_net.train()

    def count_trainable_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    model = agent.q_net
    print(f"Total Trainable Parameters: {count_trainable_parameters(model):,}")

    # checkpoint = torch.load("dqn_agent10.pth", map_location=agent.device)
    # agent.q_net.load_state_dict(checkpoint['q_net'])
    # agent.target_net.load_state_dict(checkpoint['target_net'])

    num_episodes = 3000
    reward_history = []
    total_frame = 0

    for episode in range(num_episodes):
        state = env.reset()
        done = False
        total_reward = 0
        step = 0

        while not done:
            
            action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            done = done or step >= 1e3

            agent.replay_buffer.add(state, action, reward, next_state, done)
            agent.train()

            state = next_state
            total_reward += reward
            step += 1
            total_frame += 1

        reward_history.append(total_reward)
        print(episode, total_reward, agent.epsilon, step)
        if (episode + 1) % 10 == 0:
            print(f"Episode {episode + 1}, Avg. Reward: {np.mean(reward_history[-10:])}")
            # test_agent(agent)
            torch.save({
                'q_net': agent.q_net.state_dict(),
                'target_net': agent.target_net.state_dict(),
            }, f"dqn_agent_{episode + 1}.pth")

In [49]:
train_agent()

  logger.warn(


Total Trainable Parameters: 3,297,454
0 627.0 0.2 1001
1 1326.0 0.2 1001
2 1912.0 0.2 951
3 1442.0 0.2 449
4 1011.0 0.2 341
5 1484.0 0.2 586
6 1392.0 0.2 483
7 1899.0 0.2 319
8 2429.0 0.2 775
9 1930.0 0.2 1001
Episode 10, Avg. Reward: 1545.2
10 1910.0 0.2 324
11 1810.0 0.2 559
12 2266.0 0.2 360
13 1024.0 0.2 190
14 656.0 0.2 1001
15 1992.0 0.2 725
16 897.0 0.2 1001
17 1055.0 0.2 165
18 1259.0 0.2 1001
19 1095.0 0.2 1001
Episode 20, Avg. Reward: 1396.4
20 1453.0 0.2 1001
21 1280.0 0.2 1001
22 1850.0 0.2 329
23 658.0 0.2 1001
24 1353.0 0.2 1001
25 1559.0 0.2 759
26 2113.0 0.2 970
27 658.0 0.2 1001
28 1877.0 0.2 330
29 1503.0 0.2 181
Episode 30, Avg. Reward: 1430.4
30 2126.0 0.2 564
31 1606.0 0.2 715
32 2482.0 0.2 414
33 1106.0 0.2 215
34 1367.0 0.2 626
35 2484.0 0.2 708
36 1439.0 0.2 298
37 1427.0 0.2 263
38 1537.0 0.2 397
39 2026.0 0.2 341
Episode 40, Avg. Reward: 1760.0
40 1752.0 0.2 383
41 1060.0 0.2 168
42 2503.0 0.2 390
43 2418.0 0.2 402
44 2363.0 0.2 298
45 1699.0 0.2 249
46 2518.0

KeyboardInterrupt: 

In [None]:
agent = DQNVariant(12)

checkpoint = torch.load("dqn_agent_230.pth", map_location=agent.device)
agent.q_net.load_state_dict(checkpoint['q_net'])
agent.target_net.load_state_dict(checkpoint['target_net'])

test_agent(agent)

  checkpoint = torch.load("dqn_agent_230.pth", map_location=agent.device)


KeyboardInterrupt: 

In [51]:
import cProfile
import pstats
import io

profiler = cProfile.Profile()
try:
    profiler.enable()
    train_agent()  # long-running code
    profiler.disable()
except KeyboardInterrupt:
    profiler.disable()
    print("Interrupted! Profiling results up to this point:")

s = io.StringIO()
ps = pstats.Stats(profiler, stream=s).sort_stats('cumtime')
ps.print_stats(20)
print(s.getvalue())

Total Trainable Parameters: 1,690,284
0 446.0 0.9958805064246858 4128
Interrupted! Profiling results up to this point:
         3390943 function calls (3267985 primitive calls) in 40.783 seconds

   Ordered by: cumulative time
   List reduced from 536 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.080    0.080   40.783   40.783 C:\Users\Danniel\AppData\Local\Temp\ipykernel_23600\1265275117.py:1(train_agent)
     4257    0.024    0.000   27.975    0.007 c:\Users\Danniel\anaconda3\envs\drl-hw3\lib\site-packages\gym\wrappers\frame_stack.py:116(step)
     4257    0.018    0.000   27.912    0.007 c:\Users\Danniel\anaconda3\envs\drl-hw3\lib\site-packages\gym\core.py:313(step)
     4257    0.039    0.000   24.259    0.006 C:\Users\Danniel\AppData\Local\Temp\ipykernel_23600\629983895.py:6(step)
    17027    0.020    0.000   24.221    0.001 c:\Users\Danniel\anaconda3\envs\drl-hw3\lib\site-packages\nes_py\wrappers\joypad_spac