In [1]:
import matplotlib.pyplot as plt
from IPython import display

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from tqdm.notebook import tqdm
from gym.wrappers import StepAPICompatibility
from tqdm import tqdm

In [2]:
seed = 2023
def fix(env, seed):
    env.reset(seed=seed)
    env.action_space.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
      torch.cuda.manual_seed_all(seed)

In [3]:
import gym
import random
env = gym.make('LunarLander-v2')
# env = StepAPICompatibility(env, new_step_api=True)
fix(env, seed)

In [29]:
class PolicyGradientNetwork(nn.Module):

    def __init__(self, hidden_dim=16):
        super().__init__()
        self.fc1 = nn.Linear(8, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 4)

    def forward(self, state):
        hid = torch.tanh(self.fc1(state))
        hid = torch.tanh(hid)
        return F.softmax(self.fc3(hid), dim=-1)

# Value Network
class ValueNet(nn.Module):
    def __init__(self, hidden_dim=16):
        super().__init__()

        self.hidden = nn.Linear(8, hidden_dim)
        self.output = nn.Linear(hidden_dim, 4)

    def forward(self, s):
        outs = self.hidden(s)
        outs = F.relu(outs)
        value = self.output(outs)
        return value
    
from torch.optim.lr_scheduler import StepLR
class PolicyGradientAgent():

    def __init__(self, network):
        self.network = network
        self.optimizer = optim.SGD(self.network.parameters(), lr=0.002)
        self.scheduler = StepLR(self.optimizer, step_size=1000, gamma=0.9)
        self.rewards = None
        self.discounted_rewards = None

    def forward(self, state):
        return self.network(state)

    def learn(self, log_probs, rewards):
        loss = (-log_probs * rewards).sum() # You don't need to revise this to pass simple baseline (but you can)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def learn_A2C(self, log_probs, rewards, states, value_func):
        with torch.no_grad():
            values = value_func(states).squeeze()
        advantages = rewards - values
        loss = (-log_probs * advantages).sum()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def sample(self, state):
        # Convert state to a FloatTensor
        if isinstance(state, np.ndarray):
            state = torch.from_numpy(state).float()  # Convert directly from NumPy array to tensor
        elif isinstance(state, list):
            state = torch.FloatTensor(state)
        else:
            raise ValueError(f"Unexpected state format: {type(state)}")

        # Get action probabilities and sample an action
        action_prob = self.network(state)
        action_dist = Categorical(action_prob)
        action = action_dist.sample()
        log_prob = action_dist.log_prob(action)
        return action.item(), log_prob


    # def sample(self, state):
    #     action_prob = self.network(torch.FloatTensor(state))
    #     action_dist = Categorical(action_prob)
    #     action = action_dist.sample()
    #     log_prob = action_dist.log_prob(action)
    #     return action.item(), log_prob

In [30]:
network = PolicyGradientNetwork()
agent = PolicyGradientAgent(network)

agent.network.train()  # Switch network into training mode
EPISODE_PER_BATCH = 5  # update the agent every 5 episodes
NUM_BATCH = 500        # totally update the agent for 500 episodes
gamma = 0.99            # Discount factor

avg_total_rewards, avg_final_rewards = [], []

prg_bar = tqdm(range(NUM_BATCH))
for batch in prg_bar:

    log_probs, rewards = [], []
    total_rewards, final_rewards = [], []

    # Update the training loop to extract the state correctly
    for episode in range(EPISODE_PER_BATCH):
        result = env.reset()
        if isinstance(result, tuple):
            state, _ = result  # If reset returns a tuple, extract the state
        else:
            state = result  # If it returns only the state

        total_reward, total_step = 0, 0
        episode_rewards = []  # Store episode-specific rewards

        while True:
            action, log_prob = agent.sample(state)  # Get action and log probability
            next_state, reward, done, truncated, _ = env.step(action)

            log_probs.append(log_prob)  # Store log probability
            episode_rewards.append(reward)  # Store immediate reward
            state = next_state
            total_reward += reward
            total_step += 1

            if done or truncated:
                final_rewards.append(reward)
                total_rewards.append(total_reward)
                break


        # Convert episode rewards to discounted cumulative rewards
        discounted_rewards = []
        cumulative_reward = 0

        # Compute cumulative decaying rewards for the episode
        for r in reversed(episode_rewards):
            cumulative_reward = r + gamma * cumulative_reward
            discounted_rewards.insert(0, cumulative_reward)  # Insert at the front to keep the right order

        rewards.extend(discounted_rewards)  # Extend the rewards list with the episode's discounted rewards

    print(f"Rewards shape: {np.shape(rewards)}")

    # Record training process
    avg_total_reward = sum(total_rewards) / len(total_rewards)
    avg_final_reward = sum(final_rewards) / len(final_rewards)
    avg_total_rewards.append(avg_total_reward)
    avg_final_rewards.append(avg_final_reward)
    prg_bar.set_description(f"Total: {avg_total_reward: 4.1f}, Final: {avg_final_reward: 4.1f}")

    # Update the agent using cumulative decaying rewards
    rewards = np.array(rewards)
    rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-9)  # Normalize the rewards
    agent.learn(torch.stack(log_probs), torch.from_numpy(rewards))

    # Print shapes for debugging
    print("Log probs size: ", torch.stack(log_probs).size())
    print("Rewards size: ", torch.from_numpy(rewards).size())

    # Print the AVG total reward achieved so far
    print(f"AVG Total Reward so far: {avg_total_reward:.2f}")


Total: -101.9, Final: -100.0:   0%|          | 1/500 [00:00<02:33,  3.26it/s]

Rewards shape: (437,)
Log probs size:  torch.Size([437])
Rewards size:  torch.Size([437])
AVG Total Reward so far: -101.91


Total: -373.0, Final: -100.0:   0%|          | 2/500 [00:00<02:31,  3.29it/s]

Rewards shape: (543,)
Log probs size:  torch.Size([543])
Rewards size:  torch.Size([543])
AVG Total Reward so far: -372.96


Total: -179.3, Final: -100.0:   1%|          | 3/500 [00:00<02:36,  3.17it/s]

Rewards shape: (507,)
Log probs size:  torch.Size([507])
Rewards size:  torch.Size([507])
AVG Total Reward so far: -179.27


Total: -161.4, Final: -100.0:   1%|          | 4/500 [00:01<02:14,  3.69it/s]

Rewards shape: (365,)
Log probs size:  torch.Size([365])
Rewards size:  torch.Size([365])
AVG Total Reward so far: -161.40


Total: -136.6, Final: -100.0:   1%|          | 5/500 [00:01<02:15,  3.65it/s]

Rewards shape: (437,)
Log probs size:  torch.Size([437])
Rewards size:  torch.Size([437])
AVG Total Reward so far: -136.64


Total: -142.8, Final: -100.0:   1%|          | 6/500 [00:01<02:13,  3.71it/s]

Rewards shape: (490,)
Log probs size:  torch.Size([490])
Rewards size:  torch.Size([490])
AVG Total Reward so far: -142.78


Total: -197.7, Final: -100.0:   1%|▏         | 7/500 [00:01<02:08,  3.85it/s]

Rewards shape: (446,)
Log probs size:  torch.Size([446])
Rewards size:  torch.Size([446])
AVG Total Reward so far: -197.67


Total: -134.8, Final: -100.0:   2%|▏         | 8/500 [00:02<02:02,  4.03it/s]

Rewards shape: (416,)
Log probs size:  torch.Size([416])
Rewards size:  torch.Size([416])
AVG Total Reward so far: -134.78


Total: -288.9, Final: -100.0:   2%|▏         | 9/500 [00:02<02:11,  3.73it/s]

Rewards shape: (585,)
Log probs size:  torch.Size([585])
Rewards size:  torch.Size([585])
AVG Total Reward so far: -288.86


Total: -234.6, Final: -100.0:   2%|▏         | 10/500 [00:02<02:08,  3.82it/s]

Rewards shape: (464,)
Log probs size:  torch.Size([464])
Rewards size:  torch.Size([464])
AVG Total Reward so far: -234.62


Total: -172.3, Final: -100.0:   2%|▏         | 11/500 [00:02<02:11,  3.73it/s]

Rewards shape: (480,)
Log probs size:  torch.Size([480])
Rewards size:  torch.Size([480])
AVG Total Reward so far: -172.28


Total: -148.4, Final: -100.0:   2%|▏         | 12/500 [00:03<02:11,  3.71it/s]

Rewards shape: (509,)
Log probs size:  torch.Size([509])
Rewards size:  torch.Size([509])
AVG Total Reward so far: -148.41


Total: -154.0, Final: -100.0:   3%|▎         | 13/500 [00:03<02:07,  3.81it/s]

Rewards shape: (459,)
Log probs size:  torch.Size([459])
Rewards size:  torch.Size([459])
AVG Total Reward so far: -153.99


Total: -183.2, Final: -100.0:   3%|▎         | 14/500 [00:03<02:02,  3.98it/s]

Rewards shape: (426,)
Log probs size:  torch.Size([426])
Rewards size:  torch.Size([426])
AVG Total Reward so far: -183.18


Total: -198.0, Final: -100.0:   3%|▎         | 15/500 [00:03<02:00,  4.04it/s]

Rewards shape: (455,)
Log probs size:  torch.Size([455])
Rewards size:  torch.Size([455])
AVG Total Reward so far: -198.04


Total: -222.1, Final: -100.0:   3%|▎         | 16/500 [00:04<02:02,  3.96it/s]

Rewards shape: (482,)
Log probs size:  torch.Size([482])
Rewards size:  torch.Size([482])
AVG Total Reward so far: -222.15


Total: -164.7, Final: -100.0:   3%|▎         | 17/500 [00:04<02:03,  3.91it/s]

Rewards shape: (487,)
Log probs size:  torch.Size([487])
Rewards size:  torch.Size([487])
AVG Total Reward so far: -164.74


Total: -109.6, Final: -100.0:   4%|▎         | 18/500 [00:04<01:58,  4.06it/s]

Rewards shape: (421,)
Log probs size:  torch.Size([421])
Rewards size:  torch.Size([421])
AVG Total Reward so far: -109.63


Total: -97.4, Final: -100.0:   4%|▍         | 19/500 [00:04<02:00,  3.98it/s] 

Rewards shape: (455,)
Log probs size:  torch.Size([455])
Rewards size:  torch.Size([455])
AVG Total Reward so far: -97.45


Total: -210.9, Final: -100.0:   4%|▍         | 20/500 [00:05<01:58,  4.05it/s]

Rewards shape: (432,)
Log probs size:  torch.Size([432])
Rewards size:  torch.Size([432])
AVG Total Reward so far: -210.92


Total: -241.0, Final: -100.0:   4%|▍         | 21/500 [00:05<01:57,  4.08it/s]

Rewards shape: (462,)
Log probs size:  torch.Size([462])
Rewards size:  torch.Size([462])
AVG Total Reward so far: -241.01


Total: -178.0, Final: -100.0:   4%|▍         | 22/500 [00:05<02:00,  3.95it/s]

Rewards shape: (532,)
Log probs size:  torch.Size([532])
Rewards size:  torch.Size([532])
AVG Total Reward so far: -177.97


Total: -119.8, Final: -100.0:   5%|▍         | 23/500 [00:06<02:03,  3.86it/s]

Rewards shape: (530,)
Log probs size:  torch.Size([530])
Rewards size:  torch.Size([530])
AVG Total Reward so far: -119.77


Total: -150.3, Final: -100.0:   5%|▍         | 24/500 [00:06<02:01,  3.91it/s]

Rewards shape: (465,)
Log probs size:  torch.Size([465])
Rewards size:  torch.Size([465])
AVG Total Reward so far: -150.33


Total: -149.3, Final: -100.0:   5%|▌         | 25/500 [00:06<02:00,  3.95it/s]

Rewards shape: (470,)
Log probs size:  torch.Size([470])
Rewards size:  torch.Size([470])
AVG Total Reward so far: -149.26


Total: -115.4, Final: -100.0:   5%|▌         | 26/500 [00:06<02:12,  3.57it/s]

Rewards shape: (561,)
Log probs size:  torch.Size([561])
Rewards size:  torch.Size([561])
AVG Total Reward so far: -115.43


Total: -122.2, Final: -100.0:   5%|▌         | 27/500 [00:07<02:09,  3.65it/s]

Rewards shape: (478,)
Log probs size:  torch.Size([478])
Rewards size:  torch.Size([478])
AVG Total Reward so far: -122.15


Total: -84.5, Final: -100.0:   6%|▌         | 28/500 [00:07<02:26,  3.21it/s] 

Rewards shape: (513,)
Log probs size:  torch.Size([513])
Rewards size:  torch.Size([513])
AVG Total Reward so far: -84.50


Total: -170.6, Final: -100.0:   6%|▌         | 29/500 [00:07<02:24,  3.25it/s]

Rewards shape: (439,)
Log probs size:  torch.Size([439])
Rewards size:  torch.Size([439])
AVG Total Reward so far: -170.64


Total: -133.5, Final: -100.0:   6%|▌         | 30/500 [00:08<02:23,  3.27it/s]

Rewards shape: (556,)
Log probs size:  torch.Size([556])
Rewards size:  torch.Size([556])
AVG Total Reward so far: -133.45


Total: -238.6, Final: -100.0:   6%|▌         | 31/500 [00:08<02:27,  3.17it/s]

Rewards shape: (613,)
Log probs size:  torch.Size([613])
Rewards size:  torch.Size([613])
AVG Total Reward so far: -238.62


Total: -221.5, Final: -100.0:   6%|▋         | 32/500 [00:08<02:21,  3.30it/s]

Rewards shape: (515,)
Log probs size:  torch.Size([515])
Rewards size:  torch.Size([515])
AVG Total Reward so far: -221.52


Total: -149.3, Final: -100.0:   7%|▋         | 33/500 [00:09<02:22,  3.29it/s]

Rewards shape: (508,)
Log probs size:  torch.Size([508])
Rewards size:  torch.Size([508])
AVG Total Reward so far: -149.31


Total: -139.5, Final: -100.0:   7%|▋         | 34/500 [00:09<02:17,  3.40it/s]

Rewards shape: (496,)
Log probs size:  torch.Size([496])
Rewards size:  torch.Size([496])
AVG Total Reward so far: -139.53


Total: -162.8, Final: -100.0:   7%|▋         | 35/500 [00:09<02:12,  3.52it/s]

Rewards shape: (476,)
Log probs size:  torch.Size([476])
Rewards size:  torch.Size([476])
AVG Total Reward so far: -162.81


Total: -177.8, Final: -100.0:   7%|▋         | 36/500 [00:09<02:05,  3.69it/s]

Rewards shape: (449,)
Log probs size:  torch.Size([449])
Rewards size:  torch.Size([449])
AVG Total Reward so far: -177.77


Total: -78.5, Final: -80.7:   7%|▋         | 36/500 [00:10<02:05,  3.69it/s]  

Rewards shape: (1390,)


Total: -143.2, Final: -100.0:   7%|▋         | 37/500 [00:11<04:25,  1.74it/s]

Log probs size:  torch.Size([1390])
Rewards size:  torch.Size([1390])
AVG Total Reward so far: -78.53
Rewards shape: (497,)


Total: -143.8, Final: -100.0:   8%|▊         | 38/500 [00:11<03:43,  2.06it/s]

Log probs size:  torch.Size([497])
Rewards size:  torch.Size([497])
AVG Total Reward so far: -143.22
Rewards shape: (478,)


Total: -143.8, Final: -100.0:   8%|▊         | 39/500 [00:11<03:11,  2.41it/s]

Log probs size:  torch.Size([478])
Rewards size:  torch.Size([478])
AVG Total Reward so far: -143.79


Total: -191.6, Final: -100.0:   8%|▊         | 40/500 [00:11<02:59,  2.56it/s]

Rewards shape: (560,)
Log probs size:  torch.Size([560])
Rewards size:  torch.Size([560])
AVG Total Reward so far: -191.59


Total: -167.3, Final: -100.0:   8%|▊         | 41/500 [00:12<02:43,  2.80it/s]

Rewards shape: (515,)
Log probs size:  torch.Size([515])
Rewards size:  torch.Size([515])
AVG Total Reward so far: -167.31


Total: -136.3, Final: -100.0:   8%|▊         | 42/500 [00:12<02:31,  3.01it/s]

Rewards shape: (508,)
Log probs size:  torch.Size([508])
Rewards size:  torch.Size([508])
AVG Total Reward so far: -136.32


Total: -199.6, Final: -100.0:   9%|▊         | 43/500 [00:12<02:23,  3.18it/s]

Rewards shape: (506,)
Log probs size:  torch.Size([506])
Rewards size:  torch.Size([506])
AVG Total Reward so far: -199.63


Total: -100.6, Final: -100.0:   9%|▉         | 44/500 [00:13<02:13,  3.41it/s]

Rewards shape: (449,)
Log probs size:  torch.Size([449])
Rewards size:  torch.Size([449])
AVG Total Reward so far: -100.65


Total: -77.1, Final: -100.0:   9%|▉         | 45/500 [00:13<02:07,  3.58it/s] 

Rewards shape: (447,)
Log probs size:  torch.Size([447])
Rewards size:  torch.Size([447])
AVG Total Reward so far: -77.12


Total: -100.1, Final: -100.0:   9%|▉         | 46/500 [00:13<02:05,  3.61it/s]

Rewards shape: (509,)
Log probs size:  torch.Size([509])
Rewards size:  torch.Size([509])
AVG Total Reward so far: -100.06


Total: -93.2, Final: -100.0:   9%|▉         | 47/500 [00:13<02:08,  3.53it/s] 

Rewards shape: (480,)
Log probs size:  torch.Size([480])
Rewards size:  torch.Size([480])
AVG Total Reward so far: -93.22


Total: -61.1, Final: -80.0:   9%|▉         | 47/500 [00:15<02:08,  3.53it/s] 

Rewards shape: (1486,)


Total: -180.9, Final: -100.0:  10%|▉         | 48/500 [00:15<05:09,  1.46it/s]

Log probs size:  torch.Size([1486])
Rewards size:  torch.Size([1486])
AVG Total Reward so far: -61.13
Rewards shape: (449,)


Total: -180.9, Final: -100.0:  10%|▉         | 49/500 [00:15<04:14,  1.77it/s]

Log probs size:  torch.Size([449])
Rewards size:  torch.Size([449])
AVG Total Reward so far: -180.86


Total: -149.5, Final: -100.0:  10%|█         | 50/500 [00:16<03:40,  2.05it/s]

Rewards shape: (561,)
Log probs size:  torch.Size([561])
Rewards size:  torch.Size([561])
AVG Total Reward so far: -149.51


Total: -102.3, Final: -100.0:  10%|█         | 51/500 [00:16<03:09,  2.37it/s]

Rewards shape: (487,)
Log probs size:  torch.Size([487])
Rewards size:  torch.Size([487])
AVG Total Reward so far: -102.25


Total: -88.6, Final: -100.0:  10%|█         | 52/500 [00:16<03:05,  2.41it/s] 

Rewards shape: (700,)
Log probs size:  torch.Size([700])
Rewards size:  torch.Size([700])
AVG Total Reward so far: -88.59


Total: -201.9, Final: -100.0:  11%|█         | 53/500 [00:17<03:00,  2.47it/s]

Rewards shape: (594,)
Log probs size:  torch.Size([594])
Rewards size:  torch.Size([594])
AVG Total Reward so far: -201.90


Total: -184.1, Final: -100.0:  11%|█         | 54/500 [00:17<03:12,  2.31it/s]

Rewards shape: (840,)
Log probs size:  torch.Size([840])
Rewards size:  torch.Size([840])
AVG Total Reward so far: -184.06


Total: -96.0, Final: -100.0:  11%|█         | 55/500 [00:17<03:02,  2.44it/s] 

Rewards shape: (562,)
Log probs size:  torch.Size([562])
Rewards size:  torch.Size([562])
AVG Total Reward so far: -96.00


Total: -212.8, Final: -100.0:  11%|█         | 56/500 [00:18<03:04,  2.41it/s]

Rewards shape: (741,)
Log probs size:  torch.Size([741])
Rewards size:  torch.Size([741])
AVG Total Reward so far: -212.79


Total: -191.4, Final: -100.0:  11%|█▏        | 57/500 [00:18<02:57,  2.50it/s]

Rewards shape: (657,)
Log probs size:  torch.Size([657])
Rewards size:  torch.Size([657])
AVG Total Reward so far: -191.42


Total: -139.7, Final: -100.0:  12%|█▏        | 58/500 [00:19<02:45,  2.66it/s]

Rewards shape: (576,)
Log probs size:  torch.Size([576])
Rewards size:  torch.Size([576])
AVG Total Reward so far: -139.75


Total: -181.5, Final: -100.0:  12%|█▏        | 59/500 [00:19<02:53,  2.55it/s]

Rewards shape: (732,)
Log probs size:  torch.Size([732])
Rewards size:  torch.Size([732])
AVG Total Reward so far: -181.52


Total: -255.7, Final: -100.0:  12%|█▏        | 60/500 [00:19<02:53,  2.54it/s]

Rewards shape: (694,)
Log probs size:  torch.Size([694])
Rewards size:  torch.Size([694])
AVG Total Reward so far: -255.69


Total: -145.4, Final: -100.0:  12%|█▏        | 61/500 [00:20<02:53,  2.52it/s]

Rewards shape: (704,)
Log probs size:  torch.Size([704])
Rewards size:  torch.Size([704])
AVG Total Reward so far: -145.44


Total: -169.6, Final: -100.0:  12%|█▏        | 62/500 [00:20<02:45,  2.65it/s]

Rewards shape: (553,)
Log probs size:  torch.Size([553])
Rewards size:  torch.Size([553])
AVG Total Reward so far: -169.60


Total: -259.7, Final: -100.0:  13%|█▎        | 63/500 [00:21<02:49,  2.58it/s]

Rewards shape: (739,)
Log probs size:  torch.Size([739])
Rewards size:  torch.Size([739])
AVG Total Reward so far: -259.69


Total: -148.0, Final: -100.0:  13%|█▎        | 64/500 [00:21<02:36,  2.79it/s]

Rewards shape: (531,)
Log probs size:  torch.Size([531])
Rewards size:  torch.Size([531])
AVG Total Reward so far: -147.95


Total: -108.5, Final: -100.0:  13%|█▎        | 65/500 [00:21<02:51,  2.53it/s]

Rewards shape: (839,)
Log probs size:  torch.Size([839])
Rewards size:  torch.Size([839])
AVG Total Reward so far: -108.45


Total: -86.0, Final: -100.0:  13%|█▎        | 66/500 [00:22<02:40,  2.70it/s] 

Rewards shape: (566,)
Log probs size:  torch.Size([566])
Rewards size:  torch.Size([566])
AVG Total Reward so far: -86.02


Total: -102.5, Final: -100.0:  13%|█▎        | 67/500 [00:22<02:27,  2.94it/s]

Rewards shape: (486,)
Log probs size:  torch.Size([486])
Rewards size:  torch.Size([486])
AVG Total Reward so far: -102.52


Total: -102.1, Final: -100.0:  14%|█▎        | 68/500 [00:22<02:31,  2.85it/s]

Rewards shape: (659,)
Log probs size:  torch.Size([659])
Rewards size:  torch.Size([659])
AVG Total Reward so far: -102.07


Total: -169.0, Final: -100.0:  14%|█▍        | 69/500 [00:23<02:32,  2.84it/s]

Rewards shape: (527,)
Log probs size:  torch.Size([527])
Rewards size:  torch.Size([527])
AVG Total Reward so far: -168.96


Total: -273.3, Final: -100.0:  14%|█▍        | 70/500 [00:23<02:43,  2.63it/s]

Rewards shape: (766,)
Log probs size:  torch.Size([766])
Rewards size:  torch.Size([766])
AVG Total Reward so far: -273.30


Total: -161.0, Final: -100.0:  14%|█▍        | 70/500 [00:23<02:26,  2.93it/s]

Rewards shape: (580,)





KeyboardInterrupt: 

In [12]:
env.step(action)

(array([-0.01233473,  1.3889593 , -0.41315836, -0.31314465,  0.01414089,
         0.09668748,  0.        ,  0.        ], dtype=float32),
 0.35363735044923034,
 False,
 False,
 {})

In [11]:
action

2

In [17]:
network = PolicyGradientNetwork()
agent = PolicyGradientAgent(network)

agent.network.train()  # Switch network into training mode
EPISODE_PER_BATCH = 5  # update the agent every 5 episodes
NUM_BATCH = 500        # totally update the agent for 500 episodes
gamma = 0.99           # Discount factor

avg_total_rewards, avg_final_rewards = [], []

prg_bar = tqdm(range(NUM_BATCH))
for batch in prg_bar:
    log_probs, rewards = [], []
    total_rewards, final_rewards = [], []

    # Collect trajectory
    for episode in range(EPISODE_PER_BATCH):
        state = env.reset()
        total_reward, total_step = 0, 0
        episode_rewards = []  # Store episode-specific rewards

        while True:
            action, log_prob = agent.sample(state)  # Get action and log probability
            next_state, reward, done, truncated, _ = env.step(action)

            log_probs.append(log_prob)  # Store log probability
            episode_rewards.append(reward)  # Store immediate reward
            state = next_state
            total_reward += reward
            total_step += 1

            if done or truncated:
                final_rewards.append(reward)
                total_rewards.append(total_reward)
                break

        # Convert episode rewards to discounted cumulative rewards
        discounted_rewards = []
        cumulative_reward = 0

        # Compute cumulative decaying rewards for the episode
        for r in reversed(episode_rewards):
            cumulative_reward = r + gamma * cumulative_reward
            discounted_rewards.insert(0, cumulative_reward)  # Insert at the front to keep the right order

        rewards.extend(discounted_rewards)  # Extend the rewards list with the episode's discounted rewards

    # Record training process
    avg_total_reward = sum(total_rewards) / len(total_rewards)
    avg_final_reward = sum(final_rewards) / len(final_rewards)
    avg_total_rewards.append(avg_total_reward)
    avg_final_rewards.append(avg_final_reward)
    prg_bar.set_description(f"Total: {avg_total_reward: 4.1f}, Final: {avg_final_reward: 4.1f}")

    # Update the agent using cumulative decaying rewards
    rewards = np.array(rewards)
    rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-9)  # Normalize the rewards
    agent.learn(torch.stack(log_probs), torch.from_numpy(rewards).float())

    # Print shapes less frequently to avoid slowing down training
    if batch % 10 == 0:
        print(f"Batch {batch}: Log probs size: {torch.stack(log_probs).size()}, Rewards size: {torch.from_numpy(rewards).size()}")
        print(f"AVG Total Reward so far: {avg_total_reward:.2f}")

plt.plot(avg_total_rewards)
plt.xlabel('Batch')
plt.ylabel('Average Total Reward')
plt.title('Policy Gradient Training Progress')
plt.show()


  0%|          | 0/500 [00:00<?, ?it/s]


ValueError: Unexpected state format: <class 'tuple'>