<a href="https://colab.research.google.com/github/Boyinglby/ADL_lab/blob/main/ADL_lab4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!python --version

Python 3.10.12


In [None]:
!pip install -r requirements.txt

In [None]:
!pip install gym[atari]
!pip install ale_py
!pip install autorom[accept-rom-license]

In [None]:
import os
import random
import time

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Imports all our hyperparameters from the other file
from hyperparams import Hyperparameters as params

# stable_baselines3 provides wrappers that greatly simplify
# the preprocessing; read more about them here:
# https://stable-baselines3.readthedocs.io/en/master/common/atari_wrappers.html
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer


# Creates our gym environment with all our wrappers applied.
def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one preprocessed Atari env.

    The returned ``thunk`` is meant to be handed to a gym vector env. When
    called it creates ``env_id`` and layers on, in order: episode statistics,
    optional video recording (only for ``idx == 0`` and only when
    ``capture_video`` is truthy), no-op resets (up to 30), frame skipping
    with max-pooling (skip 4), end-of-life episodes, FIRE-on-reset when the
    game has a FIRE action, reward clipping, 84x84 resize, grayscale, and a
    4-frame stack. The env and both of its spaces are then seeded with ``seed``.
    """
    # Only the first sub-environment records video.
    record_video = capture_video and idx == 0

    def thunk():
        wrapped = gym.wrappers.RecordEpisodeStatistics(gym.make(env_id))
        if record_video:
            wrapped = gym.wrappers.RecordVideo(wrapped, f"videos/{run_name}")

        wrapped = EpisodicLifeEnv(MaxAndSkipEnv(NoopResetEnv(wrapped, noop_max=30), skip=4))
        # Some games only start after FIRE is pressed; press it automatically.
        if "FIRE" in wrapped.unwrapped.get_action_meanings():
            wrapped = FireResetEnv(wrapped)
        wrapped = ClipRewardEnv(wrapped)

        wrapped = gym.wrappers.ResizeObservation(wrapped, (84, 84))
        wrapped = gym.wrappers.GrayScaleObservation(wrapped)
        wrapped = gym.wrappers.FrameStack(wrapped, 4)

        wrapped.seed(seed)
        for space in (wrapped.action_space, wrapped.observation_space):
            space.seed(seed)
        return wrapped

    return thunk


class QNetwork(nn.Module):
    """Convolutional Q-network mapping stacked frames to one Q-value per action.

    The convolutional trunk follows Section 4.1 of Mnih et al. (2013),
    https://arxiv.org/pdf/1312.5602v1.pdf: 16 8x8 filters with stride 4, then
    32 4x4 filters with stride 2, each followed by ReLU. A fully-connected
    hidden layer and a linear output layer with one unit per action follow.
    This network uses 512 hidden units.

    Args:
        env: a vector environment exposing ``single_observation_space``
            (whose ``shape`` is ``(channels, height, width)``) and a discrete
            ``single_action_space`` (whose ``n`` is the number of actions).
    """

    def __init__(self, env):
        super().__init__()
        obs_shape = tuple(env.single_observation_space.shape)
        n_actions = env.single_action_space.n

        # Channels come from the observation space instead of a hard-coded 4,
        # so a different frame-stack depth still works.
        conv_layers = [
            nn.Conv2d(obs_shape[0], 16, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(16, 32, 4, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        ]
        # Infer the flattened feature size with a dummy forward pass.
        # For 84x84 inputs: (84-8)/4+1 = 20, (20-4)/2+1 = 9  ->  9*9*32 = 2592.
        # The pass runs under no_grad and does not consume the RNG, so weight
        # initialisation is identical to building the layers directly.
        with torch.no_grad():
            n_flat = nn.Sequential(*conv_layers)(torch.zeros(1, *obs_shape)).shape[1]

        # Same layer order (and therefore the same state_dict keys) as before.
        self.network = nn.Sequential(
            *conv_layers,
            nn.Linear(n_flat, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def forward(self, x):
        """Return Q-values of shape ``(batch, n_actions)`` for a batch of observations.

        The input is divided by 255 to rescale raw pixel intensities to [0, 1].
        This assumes 8-bit frames; confirm if the preprocessing changes.
        """
        return self.network(x / 255.0)


def linear_schedule(start_e: float, end_e: float, duration: float, t: int):
    """Linearly interpolate from ``start_e`` to ``end_e`` over ``duration`` steps.

    Returns ``start_e`` at ``t == 0`` and ``end_e`` for every ``t >= duration``.
    Works for both decreasing schedules (epsilon decay) and increasing ones.
    A non-positive ``duration`` means "no ramp" and yields ``end_e``
    immediately, instead of dividing by zero.

    Args:
        start_e: value at step 0.
        end_e: value once the ramp is finished.
        duration: number of steps the ramp lasts (may be a float, e.g.
            ``exploration_fraction * total_timesteps``).
        t: current step.
    """
    if duration <= 0 or t >= duration:
        return end_e
    slope = (end_e - start_e) / duration
    value = slope * t + start_e
    # Clamp toward end_e in the schedule's own direction so floating-point
    # rounding can never overshoot the target.
    if end_e <= start_e:
        return max(value, end_e)
    return min(value, end_e)

if __name__ == "__main__":
    run_name = f"{params.env_id}__{params.exp_name}__{params.seed}__{int(time.time())}"

    # Seed every RNG the run touches so results are reproducible.
    random.seed(params.seed)
    np.random.seed(params.seed)
    torch.manual_seed(params.seed)
    torch.backends.cudnn.deterministic = params.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(params.env_id, params.seed, 0, params.capture_video, run_name)])
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    q_network = QNetwork(envs).to(device)
    optimizer = optim.Adam(q_network.parameters(), lr=params.learning_rate)
    # The target network starts as an exact copy of the online network.
    target_network = QNetwork(envs).to(device)
    target_network.load_state_dict(q_network.state_dict())

    # We'll be using experience replay memory for training our DQN.
    # It stores the transitions that the agent observes, allowing us to reuse this data later.
    # By sampling from it randomly, the transitions that build up a batch are decorrelated.
    # It has been shown that this greatly stabilizes and improves the DQN training procedure.
    rb = ReplayBuffer(
        params.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        optimize_memory_usage=False,
        handle_timeout_termination=True,
    )

    obs = envs.reset()
    for global_step in range(params.total_timesteps):
        # Here we get epsilon for our epsilon-greedy policy.
        epsilon = linear_schedule(params.start_e, params.end_e, params.exploration_fraction * params.total_timesteps, global_step)

        # Epsilon-greedy balances exploration and exploitation: with probability
        # epsilon take a random action, otherwise the greedy action under the
        # current Q estimates. Epsilon decays, so the agent becomes greedier.
        if random.random() < epsilon:
            actions = envs.action_space.sample()
        else:
            # Inference only: move the observation to the network's device and
            # skip building an autograd graph.
            with torch.no_grad():
                q_values = q_network(torch.as_tensor(obs, device=device))
            actions = torch.argmax(q_values, dim=1).cpu().numpy()

        # Take a step in the environment
        next_obs, rewards, dones, infos = envs.step(actions)

        # Here we print our reward.
        for info in infos:
            if "episode" in info.keys():
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                break

        # The vector env auto-resets finished sub-envs, so next_obs holds the
        # first frame of the new episode. Store the real final frame instead.
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(dones):
            if d:
                real_next_obs[idx] = infos[idx]["terminal_observation"]

        # Here we store the transitions in D
        rb.add(obs, real_next_obs, actions, rewards, dones, infos)

        obs = next_obs
        # Training
        if global_step > params.learning_starts:
            if global_step % params.train_frequency == 0:
                # Sample random minibatch of transitions from D. Fields:
                # data.observations, data.next_observations, data.actions,
                # data.rewards, data.dones (rewards/dones are (batch, 1)).
                data = rb.sample(params.batch_size)

                with torch.no_grad():
                    # max_a' Q_target(s', a') comes from the *target* network;
                    # using the online network here would defeat its purpose.
                    target_max, _ = target_network(data.next_observations).max(dim=1)
                    # Bellman target y_j = r + gamma * max Q_target * (1 - done).
                    # Flatten the (batch, 1) tensors so they do not broadcast
                    # against (batch,) into a (batch, batch) matrix.
                    td_target = data.rewards.flatten() + params.gamma * target_max * (1 - data.dones.flatten())
                # Q(s, a) for the actions actually taken; squeeze only the
                # action axis so a batch of size 1 keeps its batch dimension.
                old_val = q_network(data.observations).gather(1, data.actions).squeeze(1)
                loss = F.mse_loss(old_val, td_target)

                # perform our gradient descent step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # update target network (soft update weighted by tau)
            if global_step % params.target_network_frequency == 0:
                for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
                    target_network_param.data.copy_(
                        params.tau * q_network_param.data + (1.0 - params.tau) * target_network_param.data
                    )

    if params.save_model:
        model_path = f"runs/{run_name}/{params.exp_name}_model"
        # torch.save does not create parent directories.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        torch.save(q_network.state_dict(), model_path)
        print(f"model saved to {model_path}")

    envs.close()


  deprecation(
  deprecation(
  logger.deprecation(
  self.pid = _posixsubprocess.fork_exec(
  if distutils.version.LooseVersion(
  deprecation(
  logger.deprecation(
  self.pid = _posixsubprocess.fork_exec(
  if distutils.version.LooseVersion(
  deprecation(


global_step=157, episodic_return=1.0
global_step=373, episodic_return=3.0
global_step=488, episodic_return=0.0
global_step=698, episodic_return=2.0
global_step=813, episodic_return=0.0
global_step=1001, episodic_return=2.0
global_step=1191, episodic_return=2.0


  logger.deprecation(
  self.pid = _posixsubprocess.fork_exec(
  if distutils.version.LooseVersion(
  deprecation(


global_step=1306, episodic_return=0.0
global_step=1465, episodic_return=1.0
global_step=1654, episodic_return=2.0
global_step=1769, episodic_return=0.0
global_step=1956, episodic_return=2.0
global_step=2194, episodic_return=3.0
global_step=2382, episodic_return=2.0
global_step=2541, episodic_return=1.0
global_step=2910, episodic_return=6.0
global_step=3069, episodic_return=1.0
global_step=3323, episodic_return=4.0
global_step=3465, episodic_return=1.0
global_step=3622, episodic_return=1.0
global_step=3735, episodic_return=0.0
global_step=3894, episodic_return=1.0
global_step=4055, episodic_return=1.0
global_step=4168, episodic_return=0.0
global_step=4279, episodic_return=0.0
global_step=4390, episodic_return=0.0


  logger.deprecation(
  self.pid = _posixsubprocess.fork_exec(
  if distutils.version.LooseVersion(
  deprecation(


global_step=4533, episodic_return=1.0
global_step=4648, episodic_return=0.0
global_step=4834, episodic_return=2.0
global_step=5051, episodic_return=3.0
global_step=5212, episodic_return=1.0
global_step=5327, episodic_return=0.0
global_step=5466, episodic_return=1.0
global_step=5625, episodic_return=1.0
global_step=5836, episodic_return=2.0
global_step=6094, episodic_return=3.0
global_step=6253, episodic_return=1.0
global_step=6392, episodic_return=1.0
global_step=6628, episodic_return=3.0
global_step=6743, episodic_return=0.0
global_step=7030, episodic_return=4.0
global_step=7143, episodic_return=0.0
global_step=7283, episodic_return=1.0
global_step=7568, episodic_return=4.0
global_step=7726, episodic_return=1.0
global_step=7889, episodic_return=1.0
global_step=8048, episodic_return=1.0
global_step=8189, episodic_return=1.0
global_step=8304, episodic_return=0.0
global_step=8687, episodic_return=6.0
global_step=8893, episodic_return=2.0
global_step=9077, episodic_return=2.0
global_step=

  logger.deprecation(
  self.pid = _posixsubprocess.fork_exec(
  if distutils.version.LooseVersion(
  deprecation(


global_step=11233, episodic_return=1.0
global_step=11421, episodic_return=2.0
global_step=11564, episodic_return=1.0
global_step=11725, episodic_return=1.0
global_step=11840, episodic_return=0.0
global_step=11951, episodic_return=0.0
global_step=12110, episodic_return=1.0
global_step=12283, episodic_return=2.0
global_step=12442, episodic_return=1.0
global_step=12555, episodic_return=0.0
global_step=12668, episodic_return=0.0
global_step=12783, episodic_return=0.0
global_step=12942, episodic_return=1.0
global_step=13106, episodic_return=1.0
global_step=13291, episodic_return=2.0


KeyboardInterrupt: 