<a href="https://colab.research.google.com/github/YasJanam/RL1/blob/main/ActorCritic_5/ActorCritic_5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### **Requirement**

In [None]:
#!pip uninstall -y gym gymnasium box2d box2d-py

In [1]:
!pip install gymnasium==0.29.1
!pip install swig
!pip install box2d-py

Collecting gymnasium==0.29.1
  Downloading gymnasium-0.29.1-py3-none-any.whl.metadata (10 kB)
Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: gymnasium
  Attempting uninstall: gymnasium
    Found existing installation: gymnasium 1.2.1
    Uninstalling gymnasium-1.2.1:
      Successfully uninstalled gymnasium-1.2.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
dopamine-rl 4.1.2 requires gymnasium>=1.0.0, but you have gymnasium 0.29.1 which is incompatible.[0m[31m
[0mSuccessfully installed gymnasium-0.29.1
Collecting swig
  Downloading swig-4.3.1.post0-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (3.5 kB)
Downloading swig-4.3.1.post0-py3-none-manylinux_2_12_x86_64.manylin

In [2]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### **Actor-Critic**

In [3]:
class ActorCritic(nn.Module):
  def __init__(self,state_dim,action_dim,hidden_dim):
    super().__init__()
    self.shared = nn.Sequential(
        nn.Linear(state_dim,hidden_dim),
        nn.ReLU()
    )

    self.actor = nn.Sequential(
        nn.Linear(hidden_dim,action_dim),
        nn.Softmax(dim=-1)
    )

    self.critic = nn.Linear(hidden_dim,1)
  def forward(self,x):
    x = self.shared(x)
    return self.actor(x), self.critic(x)

#### **Actor-Critic Trainer**

In [21]:
class ActorCriticTrainer(nn.Module):
  def __init__(self,env,hidden_dim=128,gamma=0.99,lr=0.01,num_episodes=1000,log_interval=100,max_rewards=float('inf')):
    super().__init__()
    self.env = env
    self.hidden_dim = hidden_dim
    self.gamma = gamma
    self.lr = lr
    self.num_episodes = num_episodes
    self.log_interval = log_interval
    self.max_rewards = max_rewards
    state_dim = self.env.observation_space.shape[0]
    action_dim = self.env.action_space.n
    self.model = ActorCritic(state_dim,action_dim,hidden_dim).to(device)
    self.optimizer = optim.Adam(self.model.parameters(),lr=self.lr)

  def train(self):
    for episode in range(self.num_episodes):
      done = False
      state, _ = self.env.reset(seed=42)
      log_probs, values, rewards = [], [], []
      total_reward = 0

      while not done:
        state_t = torch.FloatTensor(state).unsqueeze(0).to(device)
        probs, value = self.model(state_t)

        dist = torch.distributions.Categorical(probs)
        action = dist.sample()

        next_state, reward, terminated, truncated, _ = self.env.step(action.item())
        done = terminated or truncated

        next_state_t = torch.FloatTensor(next_state).unsqueeze(0).to(device)
        next_value = self.model(next_state_t)[1]
        target = reward + (1 - int(done)) * self.gamma * next_value.detach()
        delta = target - value

        actor_loss = -dist.log_prob(action) * delta.detach()
        critic_loss = delta.pow(2)
        loss = actor_loss + 0.5 * critic_loss.mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        state = next_state

        log_probs.append(dist.log_prob(action))
        values.append(value)
        rewards.append(reward)

      # log
      total_reward = sum(rewards)
      if episode % self.log_interval == 0 or episode == (self.num_episodes -1) :
        print(f"Episode {episode}, Total Reward: {total_reward}")

      # early stopping
      if total_reward >= self.max_rewards:
        print(f"Reached max reward {total_reward:.2f}, stopping training!")
        break

    self.env.close()

#### **Test**

In [5]:
def test(model,env):
  state, _ = env.reset()
  done = False
  total_reward = 0

  while not done:
      state_tensor = torch.tensor(state, dtype=torch.float32)

      with torch.no_grad():
          probs = model(state_tensor)[0]
      dist = torch.distributions.Categorical(probs)
      action = dist.sample().item()

      next_state, reward, terminated, truncated, _ = env.step(action)
      done = terminated or truncated

      total_reward += reward
      state = next_state

  return total_reward

In [6]:
def tests(model,env,num_tests):
  total_rewards = []
  for _ in range(num_tests):
    total_reward = test(model,env)
    total_rewards.append(total_reward)
  rewards = [float(x) for x in total_rewards]
  return rewards

#### Labratory

##### CartPole

In [None]:
cartpole = gym.make("CartPole-v1")

In [None]:
cp_trainer = ActorCriticTrainer(env=cartpole,hidden_dim=15,lr=0.001,gamma=0.99,num_episodes=1000)
cp_trainer.train()

##### LunarLander

In [7]:
ll = gym.make('LunarLander-v2')

In [None]:
ll_trainer = ActorCriticTrainer(env=ll,hidden_dim=40,lr=0.0005,gamma=0.99,num_episodes=1500)
ll_trainer.train()

Episode 0, Total Reward: -255.20047094098092
Episode 100, Total Reward: -36.77220490227808
Episode 200, Total Reward: -277.691285019522
Episode 300, Total Reward: -11.405125993455467
Episode 400, Total Reward: -177.71477995920935
Episode 500, Total Reward: -220.52070237102228
Episode 600, Total Reward: 72.90490170835302
Episode 700, Total Reward: 131.9005839922299
Episode 800, Total Reward: 59.68749314118081
Episode 900, Total Reward: 131.6602015669093
Episode 1000, Total Reward: -49.862008796371484
Episode 1100, Total Reward: -165.8433278602023
Episode 1200, Total Reward: 37.198082066680456
Episode 1300, Total Reward: 237.5634051043198
Episode 1400, Total Reward: -70.13732463812649
Episode 1499, Total Reward: 232.38423170693784


In [None]:
res = tests(cp_trainer.model,ll,10)
res

[210.7745528649264,
 208.09202748917886,
 -156.78314090819327,
 -36.2938109469499,
 248.5070988937374,
 44.32950179720703,
 221.78832991182804,
 185.53139697967492,
 204.37582587831378,
 11.825746051306965]