<a href="https://colab.research.google.com/github/Vyshnavijulapelly/Reinforcement-Learning/blob/main/RL_Lab_05.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
import argparse
import collections
import random
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import cv2


# Replay buffer
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random mini-batch sampling."""

    def __init__(self, capacity: int):
        # A deque with ``maxlen`` drops the oldest transition once it is full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct transitions and return them column-wise.

        Returns:
            A 5-tuple of NumPy arrays ``(states, actions, rewards, next_states,
            dones)``; rewards are ``float32`` and dones are ``uint8``.
        """
        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one tuple per field.
        states, actions, rewards, next_states, dones = zip(*picked)
        stacked_states = np.stack(states)
        stacked_next_states = np.stack(next_states)
        action_arr = np.array(actions)
        reward_arr = np.array(rewards, dtype=np.float32)
        done_arr = np.array(dones, dtype=np.uint8)
        return (
            stacked_states,
            action_arr,
            reward_arr,
            stacked_next_states,
            done_arr,
        )

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)


# Q-networks
class DQN_MLP(nn.Module):
    """Fully connected Q-network: two 128-unit ReLU hidden layers, one output per action."""

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        hidden = 128
        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the layer indices.
        layers = [
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, act_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the Q-values produced by ``self.net`` for input ``x``."""
        q_values = self.net(x)
        return q_values


class DQN_CNN(nn.Module):
    """Convolutional Q-network: three conv layers followed by a 512-unit head.

    ``input_shape`` is ``(channels, height, width)`` and the spatial size must
    be 84x84. ``forward`` divides its input by 255 before the convolutions.
    """

    def __init__(self, input_shape, act_dim):
        super().__init__()
        channels, height, width = input_shape
        assert (height, width) == (84, 84), "CNN expects 84x84 input"

        feature_layers = [
            nn.Conv2d(channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*feature_layers)

        # Size the first linear layer from an actual pass through the conv stack.
        flat_features = self._get_conv_out(input_shape)
        head_layers = [
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, act_dim),
        ]
        self.fc = nn.Sequential(*head_layers)

    def _get_conv_out(self, shape):
        """Return the flattened feature count the conv stack yields for ``shape``."""
        probe = torch.zeros(1, *shape)
        features = self.conv(probe)
        return int(np.prod(features.size()))

    def forward(self, x):
        """Scale ``x`` by 1/255, run the conv stack, flatten per sample, apply the head."""
        batch = x.size()[0]
        scaled = x / 255.0
        features = self.conv(scaled)
        flattened = features.view(batch, -1)
        return self.fc(flattened)


# Atari preprocessing wrappers
class MaxAndSkipEnv(gym.Wrapper):
    """Repeat each action ``skip`` times and return the pixel-wise max of the last two frames.

    Rewards from the repeated inner steps are summed. The returned observation
    is the element-wise maximum over the last two frames observed *during this
    call*, so an episode that ends part-way through the repeat never yields
    pixels left over from an earlier ``step`` call.

    Args:
        env: Environment to wrap; its observations are stored as ``uint8``.
        skip: Number of inner environment steps per outer step (must be >= 1).

    Raises:
        ValueError: If ``skip`` is smaller than 1.
    """

    def __init__(self, env, skip=4):
        if skip < 1:
            # With zero iterations step() would have nothing to return.
            raise ValueError(f"skip must be >= 1, got {skip}")
        super().__init__(env)
        self._skip = skip
        self._obs_buffer = np.zeros(
            (2,) + env.observation_space.shape, dtype=np.uint8
        )

    def step(self, action):
        """Apply ``action`` up to ``skip`` times, stopping early when the episode ends.

        Returns:
            ``(max_frame, total_reward, terminated, truncated, info)`` where the
            last three values come from the final inner step that was executed.
        """
        total_reward = 0.0
        terminated = truncated = False
        info = {}
        frames_seen = 0
        for _ in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            # Slide the two-frame window: slot 0 = previous frame, slot 1 = newest.
            self._obs_buffer[0] = self._obs_buffer[1]
            self._obs_buffer[1] = obs
            frames_seen += 1
            total_reward += reward
            if terminated or truncated:
                break
        if frames_seen == 1:
            # Only one frame belongs to this call; do not mix in an older one.
            self._obs_buffer[0] = self._obs_buffer[1]
        max_frame = self._obs_buffer.max(axis=0)
        return max_frame, total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the wrapped environment and return its result unchanged."""
        return self.env.reset(**kwargs)


class FrameProcessor(gym.ObservationWrapper):
    """Convert observations to 84x84 single-channel ``uint8`` frames.

    Each observation is passed through ``cv2.COLOR_RGB2GRAY``, resized to
    84x84 with area interpolation, and given a trailing channel axis.
    """

    def __init__(self, env):
        super().__init__(env)
        frame_shape = (84, 84, 1)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=frame_shape, dtype=np.uint8
        )

    def observation(self, obs):
        """Return ``obs`` as a grayscale frame of shape ``(84, 84, 1)``."""
        gray = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
        # Append the channel axis so downstream code sees an HWC layout.
        return np.expand_dims(resized, axis=-1)


class FrameStack(gym.Wrapper):
    """Concatenate the ``k`` most recent observations along the last axis."""

    def __init__(self, env, k):
        super().__init__(env)
        self.k = k
        self.frames = deque([], maxlen=k)
        base_shape = env.observation_space.shape
        # Only the trailing (channel) dimension grows with the stack depth.
        stacked_shape = (base_shape[0], base_shape[1], base_shape[2] * k)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=stacked_shape, dtype=np.uint8
        )

    def reset(self, **kwargs):
        """Reset the wrapped env and fill the stack with ``k`` copies of the first frame."""
        first_obs, info = self.env.reset(**kwargs)
        self.frames.extend([first_obs] * self.k)
        return self._get_obs(), info

    def step(self, action):
        """Step the wrapped env, push the new frame and return the stacked observation."""
        result = self.env.step(action)
        new_obs, reward, terminated, truncated, info = result
        self.frames.append(new_obs)
        return self._get_obs(), reward, terminated, truncated, info

    def _get_obs(self):
        """Return the buffered frames joined along the last axis, oldest first."""
        return np.concatenate(tuple(self.frames), axis=-1)


# Environment builder
def make_env(env_id: str, seed: int, cnn: bool, frameskip: int, frame_stack: int):
    """Create a Gymnasium environment wrapped for either the CNN or the MLP agent.

    Args:
        env_id: Gymnasium environment id. Ids containing ``"ALE/"`` are treated
            as Atari and created with ``frameskip=1`` so that skipping is done
            by ``MaxAndSkipEnv`` instead of the emulator.
        seed: Seed applied to the action and observation spaces.
        cnn: If true, apply the pixel pipeline (``MaxAndSkipEnv`` ->
            ``FrameProcessor`` -> ``FrameStack``); otherwise normalise the
            observations and cast them to ``float32``.
        frameskip: Action-repeat count for ``MaxAndSkipEnv`` (CNN mode only).
        frame_stack: Number of frames stacked by ``FrameStack`` (CNN mode only).

    Returns:
        The wrapped environment, with ``RecordEpisodeStatistics`` outermost.

    Raises:
        ValueError: If ``cnn`` is set for a non-Atari environment id.
    """
    is_atari = "ALE/" in env_id
    # Validate before constructing anything so a bad combination does not
    # leave an un-closed environment behind.
    if cnn and not is_atari:
        raise ValueError("--cnn was set but env is not Atari. Use ALE/Breakout-v5 etc.")

    env = gym.make(env_id, frameskip=1) if is_atari else gym.make(env_id)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)

    if cnn:
        env = MaxAndSkipEnv(env, skip=frameskip)
        # Frames are handed to FrameProcessor in their native channel order:
        # it converts with cv2.COLOR_RGB2GRAY, so reversing the channels first
        # would apply the red/blue luminance weights to the wrong planes.
        env = FrameProcessor(env)
        env = FrameStack(env, k=frame_stack)
    else:
        env = gym.wrappers.NormalizeObservation(env)
        env = gym.wrappers.TransformObservation(
            env,
            lambda x: x.astype(np.float32),
            observation_space=gym.spaces.Box(
                low=-np.inf,
                high=np.inf,
                shape=env.observation_space.shape,
                dtype=np.float32,
            ),
        )

    env = gym.wrappers.RecordEpisodeStatistics(env)
    return env


# Training loop
def train(args):
    """Train a DQN agent (optionally Double DQN) and print periodic reward summaries.

    Builds the environment with ``make_env``, an online and a target Q-network
    (``DQN_CNN`` when ``args.cnn`` is set, otherwise ``DQN_MLP``) and a
    ``ReplayBuffer``, then runs ``args.total_steps`` environment steps with
    linearly decayed epsilon-greedy exploration and a Smooth-L1 TD loss.

    Args:
        args: Namespace providing ``env``, ``seed``, ``cnn``, ``frameskip``,
            ``frame_stack``, ``lr``, ``buffer_size``, ``batch_size``,
            ``gamma``, ``eps_start``, ``eps_end``, ``eps_decay_steps``,
            ``target_update``, ``learning_starts``, ``train_freq``,
            ``total_steps``, ``double`` and ``log_interval``.
    """
    env = make_env(args.env, args.seed, args.cnn, args.frameskip, args.frame_stack)
    try:
        obs_shape = env.observation_space.shape
        act_dim = env.action_space.n

        if args.cnn:
            obs_shape = (obs_shape[2], obs_shape[0], obs_shape[1])  # HWC->CHW
            q_net = DQN_CNN(obs_shape, act_dim)
            target_q_net = DQN_CNN(obs_shape, act_dim)
        else:
            q_net = DQN_MLP(obs_shape[0], act_dim)
            target_q_net = DQN_MLP(obs_shape[0], act_dim)

        target_q_net.load_state_dict(q_net.state_dict())
        optimizer = optim.Adam(q_net.parameters(), lr=args.lr)
        loss_fn = nn.SmoothL1Loss()
        buffer = ReplayBuffer(args.buffer_size)
        device = next(q_net.parameters()).device

        def to_network_layout(observation):
            # The CNN consumes channel-first frames; vector observations pass through.
            if args.cnn:
                return np.transpose(observation, (2, 0, 1))
            return observation

        epsilon = args.eps_start
        # Guard against a zero-length schedule (would divide by zero).
        epsilon_decay = (args.eps_start - args.eps_end) / max(1, args.eps_decay_steps)
        # random.sample needs at least batch_size stored transitions.
        min_buffer_size = max(args.learning_starts, args.batch_size)
        global_step = 0
        episode_rewards = []

        obs, _ = env.reset(seed=args.seed)
        obs = to_network_layout(obs)

        while global_step < args.total_steps:
            # Epsilon-greedy
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    obs_t = torch.tensor(
                        obs, dtype=torch.float32, device=device
                    ).unsqueeze(0)
                    action = q_net(obs_t).argmax(dim=1).item()

            next_obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            next_obs = to_network_layout(next_obs)

            # Store only true termination as the "done" flag: a time-limit
            # truncation is not a terminal state, so the target must still
            # bootstrap from next_obs in that case.
            buffer.push(obs, action, reward, next_obs, terminated)
            obs = next_obs

            if done:
                obs, _ = env.reset()
                obs = to_network_layout(obs)
                if "episode" in info:
                    ep_r = info["episode"]["r"]
                    episode_rewards.append(ep_r)
                    if len(episode_rewards) % args.log_interval == 0:
                        avg_r = np.mean(episode_rewards[-args.log_interval :])
                        print(f"Step {global_step}, AvgReward {avg_r:.2f}, Eps {epsilon:.3f}")

            # Training step
            if len(buffer) >= min_buffer_size and global_step % args.train_freq == 0:
                (
                    batch_obs,
                    batch_actions,
                    batch_rewards,
                    batch_next_obs,
                    batch_dones,
                ) = buffer.sample(args.batch_size)

                batch_obs_t = torch.tensor(batch_obs, dtype=torch.float32, device=device)
                batch_actions_t = torch.tensor(batch_actions, dtype=torch.int64, device=device)
                batch_rewards_t = torch.tensor(batch_rewards, dtype=torch.float32, device=device)
                batch_next_obs_t = torch.tensor(batch_next_obs, dtype=torch.float32, device=device)
                batch_dones_t = torch.tensor(batch_dones, dtype=torch.float32, device=device)

                # No pixel rescaling here: DQN_CNN.forward already divides by
                # 255, and scaling twice would feed training batches at a
                # different scale from the one used for action selection.

                q_values = (
                    q_net(batch_obs_t)
                    .gather(1, batch_actions_t.unsqueeze(1))
                    .squeeze(1)
                )

                with torch.no_grad():
                    if args.double:
                        # Double DQN: online net picks the action, target net scores it.
                        next_actions = q_net(batch_next_obs_t).argmax(dim=1)
                        next_q = (
                            target_q_net(batch_next_obs_t)
                            .gather(1, next_actions.unsqueeze(1))
                            .squeeze(1)
                        )
                    else:
                        next_q = target_q_net(batch_next_obs_t).max(1)[0]
                    target = batch_rewards_t + args.gamma * (1 - batch_dones_t) * next_q

                loss = loss_fn(q_values, target)

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(q_net.parameters(), 10.0)
                optimizer.step()

            if global_step % args.target_update == 0:
                target_q_net.load_state_dict(q_net.state_dict())

            epsilon = max(args.eps_end, epsilon - epsilon_decay)
            global_step += 1

        print("Training finished.")
    finally:
        # Release the environment even if training aborts with an exception.
        env.close()


if __name__ == "__main__":
    # allow_abbrev=False keeps unrelated launcher flags from being matched as
    # abbreviations of the long options defined below.
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--env", type=str, default="CartPole-v1")
    parser.add_argument("--total-steps", type=int, default=5000)  # shorter for Colab demo
    parser.add_argument("--buffer-size", type=int, default=10000)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--eps-start", type=float, default=1.0)
    parser.add_argument("--eps-end", type=float, default=0.1)
    parser.add_argument("--eps-decay-steps", type=int, default=1000)
    parser.add_argument("--target-update", type=int, default=1000)
    parser.add_argument("--learning-starts", type=int, default=1000)
    parser.add_argument("--train-freq", type=int, default=1)
    parser.add_argument("--cnn", action="store_true")
    parser.add_argument("--frameskip", type=int, default=4)
    parser.add_argument("--frame-stack", type=int, default=4)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--double", action="store_true")
    parser.add_argument("--log-interval", type=int, default=10)

    # parse_known_args honours real command-line options when run as a script
    # and skips arguments it does not recognise (such as the ``-f <kernel file>``
    # pair a Jupyter/Colab kernel is launched with) instead of exiting.
    args, _unknown_args = parser.parse_known_args()

    train(args)


Step 200, AvgReward 20.10, Eps 0.820
Step 367, AvgReward 16.70, Eps 0.670
Step 525, AvgReward 15.80, Eps 0.527
Step 660, AvgReward 13.50, Eps 0.406
Step 789, AvgReward 12.90, Eps 0.290
Step 891, AvgReward 10.20, Eps 0.198
Step 1000, AvgReward 10.90, Eps 0.100
Step 1108, AvgReward 10.80, Eps 0.100
Step 1213, AvgReward 10.50, Eps 0.100
Step 1310, AvgReward 9.70, Eps 0.100
Step 1403, AvgReward 9.30, Eps 0.100
Step 1502, AvgReward 9.90, Eps 0.100
Step 1600, AvgReward 9.80, Eps 0.100
Step 1698, AvgReward 9.80, Eps 0.100
Step 1798, AvgReward 10.00, Eps 0.100
Step 1898, AvgReward 10.00, Eps 0.100
Step 1991, AvgReward 9.30, Eps 0.100
Step 2110, AvgReward 11.90, Eps 0.100
Step 2210, AvgReward 10.00, Eps 0.100
Step 2328, AvgReward 11.80, Eps 0.100
Step 2495, AvgReward 16.70, Eps 0.100
Step 2654, AvgReward 15.90, Eps 0.100
Step 2766, AvgReward 11.20, Eps 0.100
Step 2903, AvgReward 13.70, Eps 0.100
Step 3060, AvgReward 15.70, Eps 0.100
Step 3270, AvgReward 21.00, Eps 0.100
Step 3468, AvgReward 19.