In [2]:
!pip install gym

Collecting gym
  Downloading gym-0.26.2.tar.gz (721 kB)
     ---------------------------------------- 0.0/721.7 kB ? eta -:--:--
     ---------------------------------------- 0.0/721.7 kB ? eta -:--:--
     ---------------------------------------- 0.0/721.7 kB ? eta -:--:--
     ---------------------------------------- 0.0/721.7 kB ? eta -:--:--
     -------------- ------------------------- 262.1/721.7 kB ? eta -:--:--
     -------------- ------------------------- 262.1/721.7 kB ? eta -:--:--
     -------------- ------------------------- 262.1/721.7 kB ? eta -:--:--
     -------------------------- --------- 524.3/721.7 kB 419.4 kB/s eta 0:00:01
     -------------------------- --------- 524.3/721.7 kB 419.4 kB/s eta 0:00:01
     --------------------------------------- 721.7/721.7 kB 426.3 kB/s  0:00:01
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel:



In [4]:
import torch
from torch import nn, optim
from torch.nn import functional as F
import torch.multiprocessing as mp
import numpy as np
import gymnasium as gym
import warnings

warnings.filterwarnings("ignore", category=DeprecationWarning)

# === Actor-Critic Neural Network ===
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.l1 = nn.Linear(4, 25)
        self.l2 = nn.Linear(25, 50)
        self.actor_lin1 = nn.Linear(50, 2)
        self.l3 = nn.Linear(50, 25)
        self.critic_lin1 = nn.Linear(25, 1)

    def forward(self, x):
        x = F.normalize(x, dim=0)
        y = F.relu(self.l1(x))
        y = F.relu(self.l2(y))
        actor = F.log_softmax(self.actor_lin1(y), dim=0)
        c = F.relu(self.l3(y.detach()))
        critic = torch.tanh(self.critic_lin1(c))
        return actor, critic


# === Run an Episode ===
def run_episode(worker_env, worker_model):
    state, _ = worker_env.reset()
    state = torch.from_numpy(state).float()
    values, logprobs, rewards = [], [], []
    done = False
    while not done:
        policy, value = worker_model(state)
        values.append(value)
        logits = policy.view(-1)
        action_dist = torch.distributions.Categorical(logits=logits)
        action = action_dist.sample()
        logprob_ = policy.view(-1)[action]
        logprobs.append(logprob_)
        result = worker_env.step(action.item())
        if len(result) == 5:
            next_state, reward, terminated, truncated, info = result
            done = terminated or truncated
        else:
            next_state, reward, done, info = result
        if done:
            reward = -10
            worker_env.reset()
        else:
            reward = 1.0
        rewards.append(reward)
        state = torch.from_numpy(next_state).float()
    return values, logprobs, rewards


# === Update Parameters ===
def update_params(worker_opt, values, logprobs, rewards, clc=0.1, gamma=0.95):
    rewards = torch.tensor(rewards).flip(dims=(0,)).view(-1)
    logprobs = torch.stack(logprobs).flip(dims=(0,)).view(-1)
    values = torch.stack(values).flip(dims=(0,)).view(-1)
    Returns = []
    ret_ = torch.tensor([0.0])
    for r in range(rewards.shape[0]):
        ret_ = rewards[r] + gamma * ret_
        Returns.append(ret_)
    Returns = torch.stack(Returns).view(-1)
    Returns = F.normalize(Returns, dim=0)
    actor_loss = -1 * logprobs * (Returns - values.detach())
    critic_loss = torch.pow(values - Returns, 2)
    loss = actor_loss.sum() + clc * critic_loss.sum()
    loss.backward()
    worker_opt.step()
    return actor_loss, critic_loss, len(rewards)


In [5]:
def worker(t, worker_model, counter, params):
    worker_env = gym.make("CartPole-v1")
    worker_opt = optim.Adam(lr=1e-4, params=worker_model.parameters())
    worker_opt.zero_grad()
    for i in range(params['epochs']):
        worker_opt.zero_grad()
        values, logprobs, rewards = run_episode(worker_env, worker_model)
        actor_loss, critic_loss, eplen = update_params(worker_opt, values, logprobs, rewards)
        counter.value += 1


if __name__ == "__main__":
    MasterNode = ActorCritic()
    MasterNode.share_memory()

    params = {'epochs': 1000, 'n_workers': 4}
    counter = mp.Value('i', 0)
    processes = []

    for i in range(params['n_workers']):
        p = mp.Process(target=worker, args=(i, MasterNode, counter, params))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
    for p in processes:
        p.terminate()

    print("Training completed.")
    print("Total updates:", counter.value)


Training completed.
Total updates: 0


In [6]:
env = gym.make("CartPole-v1", render_mode="human")
state, _ = env.reset()

for i in range(100):
    state_t = torch.from_numpy(state).float()
    logits, value = MasterNode(state_t)
    action_dist = torch.distributions.Categorical(logits=logits)
    action = action_dist.sample()
    result = env.step(action.item())
    if len(result) == 5:
        next_state, reward, terminated, truncated, info = result
        done = terminated or truncated
    else:
        next_state, reward, done, info = result
    if done:
        print("Lost")
        state, _ = env.reset()
    else:
        state = next_state
    env.render()

env.close()
print("Testing complete.")


Lost
Lost
Lost
Testing complete.
