In [50]:
import time
import random
from collections import deque

import numpy as np
import matplotlib.pyplot as plt

In [51]:
import torch
import torch.nn as nn
from torch.optim import Adam
import torch.functional as F

import gymnasium as gym

In [52]:
class NN(nn.Module):
    """Dueling Q-network: a shared trunk feeding separate value and advantage heads.

    The value head emits one scalar V(s) per state, the advantage head emits
    one number A(s, a) per action, and the two are combined as

        Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))

    Args:
        n_states: size of the observation vector fed to the first layer.
        n_actions: number of discrete actions (width of the Q-value output).
    """

    def __init__(self, n_states, n_actions):
        super().__init__()
        self.n_states = n_states
        self.n_actions = n_actions

        self.relu = nn.ReLU()

        # Shared feature trunk.
        self.fc1 = nn.Linear(self.n_states, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 256)

        # State-value head: a single scalar per state.
        self.fcv4 = nn.Linear(256, 128)
        self.fcv5 = nn.Linear(128, 1)

        # Advantage head: one output per action.
        self.fca4 = nn.Linear(256, 128)
        self.fca5 = nn.Linear(128, self.n_actions)

    def forward(self, states):
        """Return Q-values with the same leading shape as `states`.

        Works for a batch `(B, n_states) -> (B, n_actions)` as well as for a
        single un-batched observation `(n_states,) -> (n_actions,)`.
        """
        x = self.relu(self.fc1(states))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))

        values = self.fcv5(self.relu(self.fcv4(x)))
        adv = self.fca5(self.relu(self.fca4(x)))

        # Centre the advantages over the ACTION axis (last dim), never over the
        # batch axis: each state's Q-values must not depend on which other
        # states happen to share its batch.
        return values + (adv - adv.mean(dim = -1, keepdim = True))

In [53]:
class PrioritizedReplayBuffer:
    """Proportional prioritized replay buffer that stores n-step transitions.

    Raw one-step transitions pass through a sliding window of at most `steps`
    entries. Each stored item is
    `(state, action, n_step_return, next_state, done, n_steps_used)` where the
    return is discounted with `gamma` and never crosses an episode boundary.

    Args:
        size: maximum number of stored n-step transitions (oldest are evicted).
        steps: length of the n-step window.
        alpha: exponent applied to priorities when building sampling probabilities.
        beta: initial importance-sampling exponent.
        beta_increment: relative growth of `beta` per `sample()` call
            (`beta *= 1 + beta_increment`, capped at 1).
        ep_err: small constant added to every updated priority so that no
            transition ends up with zero sampling probability.
        gamma: discount factor used for the n-step return.
    """

    def __init__(self, size = 100000, steps = 5, alpha = 0.9, beta = 0.1, beta_increment = 0.001, ep_err = 1e-5, gamma = 0.9):
        self.size = size
        self.steps = steps
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        self.gamma = gamma
        self.episilon = ep_err

        self.priority = deque(maxlen = size)
        self.buffer = deque(maxlen = size)
        self.n_step_buffer = deque(maxlen = steps)

    def __len__(self):
        """Number of n-step transitions currently available for sampling."""
        return len(self.buffer)

    def _store_oldest(self):
        """Store the n-step transition that starts at the oldest window entry."""
        state, action = self.n_step_buffer[0][:2]
        R = 0.0
        # The window never holds more than `steps` entries (deque maxlen), so
        # the loop ends either at a terminal step or at the end of the window.
        for idx, (_, _, r, n_s, d) in enumerate(self.n_step_buffer):
            R += (self.gamma ** idx) * r
            next_state, done, steps = n_s, d, idx + 1
            if d:
                break

        self.buffer.append((state, action, R, next_state, done, steps))
        # New transitions get the current maximum priority (1.0 when empty).
        self.priority.append(max(self.priority, default = 1.0))

    def add(self, states, actions, reward, next_state, done):
        """Push one environment step through the n-step window.

        While the episode is running, a transition is stored once the window
        holds `steps` entries. When `done` is truthy every entry still in the
        window is flushed (with a shortened return) and the window is cleared,
        so each step is stored exactly once and a window never mixes steps
        from two episodes.
        """
        self.n_step_buffer.append((states, actions, reward, next_state, done))

        if done:
            while self.n_step_buffer:
                self._store_oldest()
                self.n_step_buffer.popleft()
            return

        if len(self.n_step_buffer) == self.steps:
            # The oldest entry is evicted automatically by the next append.
            self._store_oldest()

    def sample(self, batch_size):
        """Draw `batch_size` transitions (with replacement) proportionally to priority**alpha.

        Returns:
            `(state, action, reward, next_state, done, steps, indices, weights)`;
            the first six are tuples of length `batch_size`, `indices` are the
            buffer positions that were drawn and `weights` are the
            importance-sampling weights normalised so their maximum is 1.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self.buffer:
            raise ValueError("cannot sample from an empty replay buffer")

        prios = np.array(list(self.priority), dtype = np.float64)
        prob = prios ** self.alpha
        prob /= prob.sum()

        indices = np.random.choice(len(self.buffer), batch_size, p = prob)
        samples = [self.buffer[i] for i in indices]

        weights = (len(self.buffer) * prob[indices]) ** (-self.beta)
        weights /= weights.max()

        # Anneal beta towards 1.
        self.beta = min(1.0, self.beta * (1 + self.beta_increment))

        state, action, reward, next_state, done, steps = zip(*samples)
        return state, action, reward, next_state, done, steps, indices, weights

    def update_priorities(self, indices, td_error):
        """Set the priority of each sampled index to `|td_error| + ep_err`.

        Priorities are stored as plain Python floats regardless of whether the
        errors arrive as floats, NumPy scalars or 0-d tensors.
        """
        for idx, err in zip(indices, td_error):
            self.priority[idx] = abs(float(err)) + self.episilon

        
        

In [54]:
class Agent:
    """Double-DQN agent using the `NN` Q-network and `PrioritizedReplayBuffer`.

    Actions are chosen epsilon-greedily from `policy_net`; targets are
    bootstrapped from `target_net`, which is synchronised with `policy_net`
    at construction and then every `target_update_freq` training steps.

    Args:
        env: environment exposing the Gymnasium `reset()` / `step()` /
            `action_space.sample()` API.
        n_states: length of the observation vector.
        n_actions: number of discrete actions.
        gamma: discount factor (also forwarded to the replay buffer so the
            n-step returns and the bootstrap term use the same discount).
        episilon: initial exploration probability.
        episilon_decay: multiplicative decay applied to `episilon` after each episode.
        episilon_min: lower bound for `episilon`.
    """

    def __init__(self, env ,n_states, n_actions, gamma = 0.9, episilon = 1, episilon_decay = 0.995, episilon_min = 0.01):
        self.env = env
        self.n_states = n_states
        self.n_actions = n_actions
        self.gamma = gamma
        self.episilon = episilon
        self.episilon_decay = episilon_decay
        self.episilon_min = episilon_min
        self.target_update_freq = 100
        self.batch_size = 64
        self.steps = 0

        # Bookkeeping that train()/plot() read; defined here so those methods
        # are usable even before playingLoop() has run.
        self.loss_value = 0.0
        self.num_episodes = 0
        self.rewards_per_episode = []
        self.loss_per_episode = []

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.policy_net = NN(n_states, n_actions)
        self.target_net = NN(n_states, n_actions)
        # Both networks are initialised independently; copy the weights so the
        # first targets are not produced by an unrelated random network.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.buffer = PrioritizedReplayBuffer(gamma = gamma)

        self.policy_net.to(self.device)
        self.target_net.to(self.device)

        self.loss = nn.MSELoss(reduction = 'none')
        self.optimizer = Adam(self.policy_net.parameters(), lr = 1e-5)
        self.addMemory()

    def addMemory(self, n_transitions = 1000):
        """Pre-fill the replay buffer by taking `n_transitions` random actions."""
        state = self.env.reset()[0]
        for _ in range(n_transitions):
            action = self.env.action_space.sample()
            next_state, reward, term, trun, _ = self.env.step(action)
            done = int(term or trun)
            self.buffer.add(state, action, reward, next_state, done)
            state = self.env.reset()[0] if done else next_state

    def select_action(self, state):
        """Return a random action with probability `episilon`, else the greedy one."""
        if random.random() < self.episilon:
            return self.env.action_space.sample()

        # Force float32 so observations delivered as float64 (or as lists)
        # match the dtype of the network parameters.
        state = torch.as_tensor(state, dtype = torch.float32, device = self.device)
        with torch.no_grad():
            qValues = self.policy_net(state)
        return torch.argmax(qValues).item()

    def train(self):
        """Run one optimisation step on a prioritized batch.

        Does nothing until the buffer holds at least `batch_size` transitions.
        Adds the scalar loss to `self.loss_value`, refreshes the priorities of
        the sampled transitions and periodically syncs the target network.
        """
        if len(self.buffer.buffer) < self.batch_size:
            return

        states, actions, rewards, next_states, done, steps, indices, weights = self.buffer.sample(self.batch_size)

        # Stack the per-sample arrays first; building a tensor directly from a
        # tuple of arrays is slow.
        states = torch.tensor(np.array(states), dtype = torch.float32).to(self.device)
        actions = torch.tensor(actions, dtype = torch.int64).unsqueeze(1).to(self.device)
        rewards = torch.tensor(rewards, dtype = torch.float32).unsqueeze(1).to(self.device)
        next_states = torch.tensor(np.array(next_states), dtype = torch.float32).to(self.device)
        done = torch.tensor(done, dtype = torch.float32).unsqueeze(1).to(self.device)
        steps = torch.tensor(steps, dtype = torch.float32).unsqueeze(1).to(self.device)
        weights = torch.tensor(weights, dtype = torch.float32).unsqueeze(1).to(self.device)

        qValue = self.policy_net(states).gather(1, actions)

        with torch.no_grad():
            # Double DQN: the policy net picks the action, the target net scores it.
            next_actions = self.policy_net(next_states).argmax(1, keepdim = True)
            next_q = self.target_net(next_states).gather(1, next_actions)
            # `rewards` already holds the n-step return, hence gamma ** steps.
            target_qValues = rewards + (self.gamma ** steps) * next_q * (1 - done)

        td_errors = target_qValues - qValue

        loss = self.loss(qValue, target_qValues)
        loss = (weights * loss).mean()

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
        self.optimizer.step()

        # Accumulate a plain float: adding the tensor itself would keep every
        # step's autograd graph referenced for the rest of the episode.
        self.loss_value += loss.item()

        new_priorities = td_errors.detach().abs().cpu().flatten().tolist()
        self.buffer.update_priorities(indices, new_priorities)

        self.steps += 1
        if self.steps % self.target_update_freq == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

    def playingLoop(self, num_episodes = 500):
        """Train for `num_episodes` episodes.

        Returns:
            The list of total rewards, one entry per episode (also kept in
            `self.rewards_per_episode`; summed losses are in `self.loss_per_episode`).
        """
        self.num_episodes = num_episodes
        self.rewards_per_episode = []
        self.loss_per_episode = []

        for ep in range(num_episodes):
            self.loss_value = 0.0
            state = self.env.reset()[0]
            episode_reward = 0
            done = 0

            while not done:
                action = self.select_action(state)
                next_state, reward, term, trun, _ = self.env.step(action)
                # NOTE(review): truncation is treated like termination, so a
                # time-limit cut-off also disables bootstrapping - confirm this
                # is intended.
                done = int(term or trun)
                self.buffer.add(state, action, reward, next_state, done)

                self.train()

                state = next_state
                episode_reward += reward

            self.episilon = max(self.episilon_min, self.episilon_decay * self.episilon)
            self.rewards_per_episode.append(episode_reward)
            self.loss_per_episode.append(float(self.loss_value))

            if ep % 50 == 0:
                avg_reward = np.mean(self.rewards_per_episode[-10:])
                print(f"Episode {ep} | Avg Reward: {avg_reward:.2f} | Epsilon: {self.episilon:.2f}")

        return self.rewards_per_episode

    def plot(self, interval = 50):
        """Plot moving averages (window `interval`) of episode reward and loss."""
        n_recorded = len(self.rewards_per_episode)
        if n_recorded == 0:
            return

        # A window longer than the data would make np.convolve(mode="same")
        # return more points than there are episodes.
        window = max(1, min(interval, n_recorded))
        kernel = np.ones(window) / window
        moving_average_rewards = np.convolve(self.rewards_per_episode, kernel, mode = "same")
        moving_average_loss = np.convolve(self.loss_per_episode, kernel, mode = "same")

        episodes = range(n_recorded)
        plt.plot(episodes, moving_average_rewards, label = "Moving Average")
        plt.plot(episodes, moving_average_loss, label = "Loss Per Episode")
        plt.xlabel("Episodes")
        plt.ylabel("Values per episode")
        plt.legend()

In [None]:
episodes = 1000
env = gym.make("CartPole-v1", render_mode = "rgb_array")

# Read the network dimensions from the environment instead of hard-coding them.
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n
agent = Agent(env, n_states, n_actions)

# Take the rewards from the attribute the training loop fills in, so `rewards`
# is the per-episode reward list rather than the method's return value.
agent.playingLoop(episodes)
rewards = agent.rewards_per_episode