<a href="https://colab.research.google.com/github/Loki-33/RL-Algos/blob/main/A2CRITIC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import gym
import torch.nn as nn
import torch.optim as optim

In [None]:
# CartPole-v1 with off-screen RGB rendering, so frames can be fetched via env.render().
env = gym.make(
  "CartPole-v1",
  render_mode="rgb_array",
)

  deprecation(
  deprecation(


In [None]:
# Sizes for the network: number of discrete actions and length of the
# first axis of the observation space.
n_actions = env.action_space.n
n_obs = env.observation_space.shape[0]

In [None]:
# Training hyperparameters.
episodes = 1_000       # number of episodes the training loop runs
lr = 0.001             # learning rate handed to the optimizer
gamma = 0.99           # discount factor (same value as compute_returns' default)
value_coef = 0.5       # weight of the critic's squared-error term in the loss
entropy_coef = 0.01    # weight of the policy-entropy term in the loss

In [None]:
class A2C(nn.Module):
  """Actor-critic network: one shared hidden layer feeding two linear heads.

  Args:
    n_obs: size of the observation vector fed to the first linear layer.
    n_acts: number of discrete actions, i.e. the width of the actor head.
    hidden_size: width of the shared hidden layer. Defaults to 128, the value
      that used to be hard-coded, so existing ``A2C(n_obs, n_acts)`` calls
      build the same architecture.
  """

  def __init__(self, n_obs, n_acts, hidden_size=128):
    super().__init__()

    # Shared feature extractor used by both heads.
    self.net = nn.Sequential(
      nn.Linear(n_obs, hidden_size),
      nn.ReLU(),
    )

    # Actor head: one unnormalised logit per action.
    self.actor = nn.Linear(hidden_size, n_acts)
    # Critic head: a single state-value estimate.
    self.critic = nn.Linear(hidden_size, 1)

  def forward(self, x):
    """Return ``(logits, value)`` for ``x``.

    ``x`` has ``n_obs`` as its last dimension. ``logits`` keeps the leading
    dimensions of ``x`` and ends in ``n_acts``; ``value`` has the critic's
    trailing size-1 dimension squeezed away, so it has the leading dimensions
    of ``x`` only.
    """
    features = self.net(x)
    return self.actor(features), self.critic(features).squeeze(-1)

In [None]:
# Build the actor-critic network and an Adam optimizer over all of its parameters.
model = A2C(n_obs=n_obs, n_acts=n_actions)
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [None]:
def compute_returns(rewards, dones, next_value, gamma=0.99):
  """Compute discounted returns for one rollout, bootstrapping from ``next_value``.

  Walks the rollout backwards with
  ``R_t = rewards[t] + gamma * R_{t+1} * (1 - dones[t])``, starting from
  ``R = next_value``, so a truthy ``dones[t]`` cuts off everything after step ``t``.

  Args:
    rewards: sequence of per-step rewards.
    dones: sequence of per-step end-of-episode flags, indexed like ``rewards``.
    next_value: value estimate for the state after the last step; a Python
      number or a single-element tensor/array (tensors are detached first).
    gamma: discount factor.

  Returns:
    A float64 NumPy array with one return per step (empty if ``rewards`` is empty).
  """
  # A critic output normally still carries autograd history. The returns are
  # plain regression targets, so reduce the bootstrap value to a Python float
  # before anything is written into the NumPy buffer.
  if torch.is_tensor(next_value):
    next_value = next_value.detach().cpu().numpy()
  R = float(np.asarray(next_value, dtype=np.float64).reshape(()))

  returns = np.zeros(len(rewards))
  for i in reversed(range(len(rewards))):
    # Truth-test the flag so Python bools, NumPy bools and 0/1 numbers all work.
    keep = 0.0 if dones[i] else 1.0
    R = float(rewards[i]) + gamma * R * keep
    returns[i] = R
  return returns

In [None]:
# Train the actor-critic: play one full episode, then take a single gradient step on it.
batch_data = []      # not used by this loop; kept defined in case another cell reads it (TODO confirm, then remove)
episode_reward = []  # undiscounted return of every finished episode

for ep in range(episodes):
  # gym >= 0.26 returns (obs, info) from reset(); older versions return only obs.
  reset_out = env.reset()
  state = reset_out[0] if isinstance(reset_out, tuple) else reset_out
  done = False
  rewards = []
  dones = []
  values = []
  log_probs = []
  entropies = []

  while not done:
    state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    logits, value = model(state_tensor)
    # Build the distribution straight from logits; this avoids a separate
    # softmax followed by a log inside log_prob/entropy.
    dist = torch.distributions.Categorical(logits=logits)
    action = dist.sample()

    # Accept both step() signatures: 5-tuple (terminated, truncated) and the
    # legacy 4-tuple (done, info), where the time limit is reported via info.
    step_out = env.step(action.item())
    if len(step_out) == 5:
      next_state, reward, terminated, truncated, _ = step_out
      terminated, truncated = bool(terminated), bool(truncated)
    else:
      next_state, reward, step_done, info = step_out
      truncated = bool(info.get("TimeLimit.truncated", False))
      terminated = bool(step_done) and not truncated
    done = terminated or truncated

    # Only a real termination zeroes the bootstrap value; an episode cut off
    # by the time limit still bootstraps from the critic's estimate.
    dones.append(terminated)
    values.append(value)
    rewards.append(reward)
    log_probs.append(dist.log_prob(action))
    entropies.append(dist.entropy())

    state = next_state

  # Bootstrap value of the final state: a constant target, so no autograd graph.
  with torch.no_grad():
    final_state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    _, next_value = model(final_state)
  returns = compute_returns(rewards, dones, next_value.item(), gamma)

  # Each per-step tensor has shape (1,), so concatenating yields flat (T,)
  # tensors. Stacking log_probs into (T, 1) and multiplying by a (T,) advantage
  # would broadcast to (T, T) and wipe out the policy gradient.
  returns = torch.as_tensor(returns, dtype=torch.float32)
  values = torch.cat(values)
  log_probs = torch.cat(log_probs)
  entropies = torch.cat(entropies)

  advantage = returns - values.detach()
  if advantage.numel() > 1:
    # std() of a single element is NaN, so only normalise multi-step episodes.
    advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
  policy_loss = -(log_probs * advantage).mean()
  value_loss = value_coef * (returns - values).pow(2).mean()
  # Negative sign: minimising the total loss must *raise* entropy (exploration bonus).
  entropy_loss = -entropy_coef * entropies.mean()

  loss = policy_loss + value_loss + entropy_loss

  optimizer.zero_grad()
  loss.backward()
  torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
  optimizer.step()

  episode_reward.append(sum(rewards))

  if ep % 10 == 0:
    print(f"Episode {ep}, Return: {sum(rewards):.2f}, Loss: {loss.item():.4f}")


Episode 0, Return: 10.00, Loss: 21.9274
Episode 10, Return: 56.00, Loss: 373.8926
Episode 20, Return: 16.00, Loss: 44.4280
Episode 30, Return: 16.00, Loss: 42.5700
Episode 40, Return: 13.00, Loss: 28.0122
Episode 50, Return: 15.00, Loss: 34.5982
Episode 60, Return: 10.00, Loss: 14.2502
Episode 70, Return: 30.00, Loss: 120.1060
Episode 80, Return: 12.00, Loss: 18.3501
Episode 90, Return: 26.00, Loss: 86.7701
Episode 100, Return: 12.00, Loss: 16.4571
Episode 110, Return: 19.00, Loss: 42.2440
Episode 120, Return: 24.00, Loss: 64.6781
Episode 130, Return: 84.00, Loss: 596.1627
Episode 140, Return: 16.00, Loss: 23.9809
Episode 150, Return: 43.00, Loss: 180.6599
Episode 160, Return: 12.00, Loss: 9.4395
Episode 170, Return: 36.00, Loss: 126.1465
Episode 180, Return: 49.00, Loss: 215.8163
Episode 190, Return: 57.00, Loss: 273.1129
Episode 200, Return: 50.00, Loss: 208.8251
Episode 210, Return: 27.00, Loss: 51.9026
Episode 220, Return: 60.00, Loss: 273.4620
Episode 230, Return: 23.00, Loss: 24.

In [None]:
from IPython.display import clear_output
import time

In [None]:
# Play one episode with the current policy, redrawing the rendered frame each step.
# gym >= 0.26 returns (obs, info) from reset(); older versions return only obs.
reset_out = env.reset()
state = reset_out[0] if isinstance(reset_out, tuple) else reset_out
total_reward = 0
done = False

fig, ax = plt.subplots()
ax.axis('off')
image = None  # image artist created on the first frame, then updated in place

try:
  while not done:
    frame = env.render()
    # Some gym versions return a list of frames for render_mode='rgb_array',
    # others a single array; indexing a single array would slice off one pixel row.
    if isinstance(frame, (list, tuple)):
      frame = frame[-1]

    # Reuse one image artist instead of stacking a new image on the axes every step.
    if image is None:
      image = ax.imshow(frame)
    else:
      image.set_data(frame)
    display(fig)
    clear_output(wait=True)
    time.sleep(0.03)

    with torch.no_grad():
      state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
      logits, _ = model(state_tensor)
      action = torch.distributions.Categorical(logits=logits).sample().item()

    # Accept both step() signatures: 5-tuple (terminated, truncated) and the
    # legacy 4-tuple (done, info).
    step_out = env.step(action)
    if len(step_out) == 5:
      next_state, reward, terminated, truncated, _ = step_out
      done = bool(terminated) or bool(truncated)
    else:
      next_state, reward, step_done, _ = step_out
      done = bool(step_done)
    total_reward += reward
    state = next_state
finally:
  # Release the environment even if the cell is interrupted mid-episode.
  env.close()
print(f"Total Reward: {total_reward}")