### Imports

In [None]:
import random
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchmetrics.classification import BinaryAccuracy

### Hyperparameters

In [None]:

# Number of training episodes run by the training loop.
episodes = 100

# Replay memory: maximum number of stored transitions (an integer count, so it
# is written as an int rather than the float literal 1e4) and the number of
# transitions drawn for each learning step.
capacity = 10_000
sample_size = 4

# Optimizer step size, reward discount (gamma) and the soft-update mixing
# factor (tau) used when blending local weights into the target network.
learning_rate = 0.0001
discount_factor = 0.99
interpolation_parameter = 1e-3

# Exploration schedule: epsilon starts at `epsilon_start`, is reduced by
# `epsilon_decay_rate` after every episode and never drops below `epsilon_end`.
epsilon_start = 1.0
epsilon_decay_rate = 1.7 / episodes
epsilon_end = 0.05

### Agent's Environment

In [None]:
# Training environment, created without a render mode.
env = gym.make("CartPole-v1", render_mode=None)

# Sizes handed to the agent: the first dimension of the observation space
# and the number of discrete actions.
observation_shape = env.observation_space.shape
state_size = observation_shape[0]
action_size = env.action_space.n

### Artificial Neural Network

In [None]:
class Network(nn.Module):
    """Fully connected network mapping a state vector to one value per action.

    Three ReLU-activated hidden layers of 720 units each are followed by a
    linear output layer with ``action_size`` units.
    """

    def __init__(self, state_size, action_size, seed=42):
        """Create the four linear layers.

        Args:
            state_size: Number of input features.
            action_size: Number of output units.
            seed: Stored on the instance as ``self.seed``; it is not used to
                seed any random number generator here.
        """
        super().__init__()
        hidden_units = 720
        self.seed = seed
        self.fc1 = nn.Linear(state_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, hidden_units)
        self.fc4 = nn.Linear(hidden_units, action_size)

    def forward(self, state):
        """Return the output-layer values for ``state``."""
        activations = state
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            activations = F.relu(hidden_layer(activations))
        return self.fc4(activations)

### Replay Memory

In [None]:
class Memory(object):
    """Bounded FIFO store of transitions with random mini-batch sampling.

    Each stored event is indexed as
    ``(state, action, reward, next_state, terminated)``.
    """

    def __init__(self, capacity):
        """Create an empty memory holding at most ``capacity`` events."""
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.capacity = capacity
        self.memory_bank = []

    def push(self, event):
        """Append ``event``; drop the oldest entry once capacity is exceeded."""
        self.memory_bank.append(event)
        if len(self.memory_bank) > self.capacity:
            self.memory_bank.pop(0)

    def _stack(self, batch, index):
        """Vertically stack field ``index`` of every event in ``batch``."""
        return np.vstack([event[index] for event in batch])

    def sample(self, sample_size):
        """Draw ``sample_size`` distinct events and return them as tensors.

        Returns:
            Tuple ``(states, next_states, actions, rewards, terminations)``
            of tensors on ``self.device``. Actions are ``long``; every other
            tensor is ``float``. Note that ``next_states`` comes second.
        """
        drawn = random.sample(self.memory_bank, sample_size)
        # ``None`` entries are skipped, so the batch may be smaller than drawn.
        batch = [event for event in drawn if event is not None]

        def as_float(array):
            return torch.from_numpy(array).float().to(self.device)

        states = as_float(self._stack(batch, 0))
        actions = torch.from_numpy(self._stack(batch, 1)).long().to(self.device)
        rewards = as_float(self._stack(batch, 2))
        next_states = as_float(self._stack(batch, 3))
        # Termination flags pass through uint8 before becoming 0.0 / 1.0.
        terminations = as_float(self._stack(batch, 4).astype(np.uint8))
        return states, next_states, actions, rewards, terminations

### Agent

In [None]:
class Agent():
    """DQN agent with a trained local network and a soft-updated target network.

    The local network is optimised with Adam (``amsgrad=True``); the target
    network is moved towards it after every learning step via ``soft_update``.
    """

    # A learning step is attempted once every this many calls to ``step``.
    UPDATE_EVERY = 4

    def __init__(self, state_size, action_size):
        """Build both networks, the optimizer and an empty replay memory.

        Args:
            state_size: Number of features in a state vector.
            action_size: Number of discrete actions.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.state_size = state_size
        self.action_size = action_size
        self.local_network = Network(state_size,action_size).to(self.device)
        self.target_network = Network(state_size,action_size).to(self.device)
        self.optimizer = optim.Adam(self.local_network.parameters(), lr=learning_rate, amsgrad=True)
        self.memory = Memory(capacity)
        self.t_step = 0

    def act(self, state, epsilon):
        """Choose an action for ``state``.

        With probability ``epsilon`` the action comes from a fixed rule on
        ``state[2]`` (the pole angle): 0 when it is negative, otherwise 1.
        Otherwise the action with the largest local-network output is chosen.

        Args:
            state: NumPy state vector.
            epsilon: Probability of taking the rule-based action.

        Returns:
            The chosen action as a Python int.
        """
        if random.random() < epsilon:
            # Exploration uses the angle rule, so no forward pass is needed.
            pole_angle = state[2]
            return 0 if pole_angle < 0 else 1

        state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        self.local_network.eval()
        with torch.no_grad():
            action_values = self.local_network(state_tensor)
        self.local_network.train()
        return torch.argmax(action_values).item()

    def step(self, state, action, reward, next_state, terminated):
        """Store one transition and, every ``UPDATE_EVERY`` calls, learn from a batch.

        Learning is skipped until the memory holds more than ``sample_size``
        transitions.
        """
        self.memory.push((state,action,reward,next_state,terminated))
        self.t_step = (self.t_step + 1) % self.UPDATE_EVERY
        if self.t_step == 0 and len(self.memory.memory_bank) > sample_size:
            experiences = self.memory.sample(sample_size)
            self.learn(experiences, discount_factor)

    def learn(self, experiences, discount_factor):
        """Run one gradient step on the local network and soft-update the target.

        Args:
            experiences: Tuple ``(states, next_states, actions, rewards,
                terminations)`` in the order returned by ``Memory.sample``.
            discount_factor: Discount applied to the bootstrapped next-state value.
        """
        states, next_states, actions, rewards, terminations = experiences

        # The targets are treated as constants: no gradient flows through the
        # target network.
        with torch.no_grad():
            next_q_targets = self.target_network(next_states).max(1)[0].unsqueeze(1)
            # (1 - terminations) removes the bootstrap term for terminal transitions.
            q_targets = rewards + discount_factor * next_q_targets * (1 - terminations)

        q_expected = self.local_network(states).gather(1, actions)
        loss = F.mse_loss(q_expected, q_targets)

        # Clear stale gradients, back-propagate the loss, then update. Without
        # zero_grad()/backward() the optimizer step has no gradients to apply
        # and the local network never changes.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.soft_update(self.local_network, self.target_network, interpolation_parameter)

    def soft_update(self, local_model, target_model, interpolation_parameter):
        """Blend local weights into the target model in place.

        Each target parameter becomes
        ``tau * local + (1 - tau) * target`` with ``tau = interpolation_parameter``.
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(interpolation_parameter * local_param.data + (1.0 - interpolation_parameter) * target_param.data)

### Instantiate Agent

In [None]:
# Single agent instance shared by the training and evaluation cells.
agent = Agent(state_size=state_size, action_size=action_size)

### Training

In [None]:
# Train the agent for `episodes` episodes, lowering epsilon after each one.
epsilon = epsilon_start
best = (0, 0)  # (episode number, score) of the highest-scoring episode so far

for e in range(1, episodes+1):
    state, _ = env.reset()
    score = 0
    p_angle = state[2]
    terminated = truncated = False
    # An episode ends when the environment reports termination OR truncation;
    # ignoring truncation lets a well-balanced episode run past its time limit.
    while not (terminated or truncated):
        action = agent.act(state, epsilon)
        next_state, reward, terminated, truncated, _ = env.step(action)
        # The reported score uses the environment's own reward, not the shaped one.
        score += reward

        # Reward shaping. The terminal penalty is checked first so it applies
        # regardless of the pole angle; previously it sat behind the angle
        # tests and was only reachable when the angle was exactly zero.
        moved_further_from_upright = (
            (p_angle < 0 and next_state[2] < p_angle)
            or (p_angle > 0 and next_state[2] > p_angle)
        )
        if terminated:
            reward = -20
        elif moved_further_from_upright:
            reward = -.25

        # Only true termination is stored, so a truncated transition still
        # bootstraps from the next state during learning.
        agent.step(state, action, reward, next_state, terminated)
        state = next_state
        p_angle = next_state[2]

    if score > best[1]:
        best = (e, score)
    epsilon = max(epsilon_end, epsilon - epsilon_decay_rate)
    print(f'Episode: {e}\tScore: {score}\tBest Score: {best[1]}')

print(f'Best Epoch: {best[0]}, Best Score: {best[1]}')
env.close()

### Test And Visualize

In [None]:
# Run five rendered episodes with the trained agent and print each score.
env = gym.make("CartPole-v1", render_mode='human')

try:
    for e in range(1, 6):
        state, _ = env.reset()
        score = 0
        terminated = truncated = False
        # Stop on termination or truncation; otherwise an agent that keeps
        # the pole balanced would never leave this loop.
        while not (terminated or truncated):
            # A small epsilon keeps a little of the rule-based action choice.
            action = agent.act(state, 0.05)
            state, reward, terminated, truncated, _ = env.step(action)
            score += reward
        print(f'Episode: {e}\tScore: {score}')
finally:
    # Close the environment (and its render window) even if an episode fails.
    env.close()