# In [1]:  (Jupyter notebook cell prompt left over from the export)
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import os
import time

####################################################################################

# Training environment, created without a render mode.
# NOTE(review): an earlier variant of this line passed render_mode="human" to
# make() (on "LunarLander-v2") to watch the agent on screen — re-enable that
# argument here if visual inspection is wanted.
env = gym.make("LunarLander-v3")


####################################################################################


####################################################################################
### Networks
class PolicyNetwork(nn.Module):
    """Two-layer MLP mapping an observation to action probabilities.

    Args:
        obs_dim: Length of the observation vector. When ``None`` it is read
            from the module-level ``env`` (``env.observation_space.shape[0]``).
        n_actions: Number of discrete actions. When ``None`` it is read from
            the module-level ``env`` (``env.action_space.n``).
        hidden_size: Width of the single hidden layer.
    """

    def __init__(self, obs_dim=None, n_actions=None, hidden_size=128):
        super().__init__()
        # Only fall back to the global environment when a size is not given,
        # so the network can also be built without an environment in scope.
        if obs_dim is None:
            obs_dim = env.observation_space.shape[0]
        if n_actions is None:
            n_actions = env.action_space.n
        # Layer layout (and therefore state-dict keys fc.0 / fc.2) is kept
        # exactly as before.
        self.fc = nn.Sequential(
            nn.Linear(obs_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return action probabilities for ``x``; the last dimension sums to 1."""
        return self.fc(x)


class ValueNetwork(nn.Module):
    """Two-layer MLP mapping an observation to a single scalar value estimate.

    Args:
        obs_dim: Length of the observation vector. When ``None`` it is read
            from the module-level ``env`` (``env.observation_space.shape[0]``).
        hidden_size: Width of the single hidden layer.
    """

    def __init__(self, obs_dim=None, hidden_size=128):
        super().__init__()
        # Only fall back to the global environment when the size is not given,
        # so the network can also be built without an environment in scope.
        if obs_dim is None:
            obs_dim = env.observation_space.shape[0]
        # Layer layout (and therefore state-dict keys fc.0 / fc.2) is kept
        # exactly as before.
        self.fc = nn.Sequential(
            nn.Linear(obs_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x):
        """Return the value estimate for ``x`` with a trailing dimension of 1."""
        return self.fc(x)


# Per-episode metric buffers, filled by the training loop.
# NOTE(review): total_episode_rewards_plots is not written anywhere in the
# visible code — confirm it is still needed.
sum_episode_rewards = []
total_episode_rewards_plots = []
policy_loss_plot = []
episode_reward_plot = []

# Policy network first, value network second: the construction order is kept
# so that weight initialisation draws from the RNG in the same sequence.
policy_net = PolicyNetwork()
value_net = ValueNetwork()

# One Adam optimizer per network, each with its own learning rate from P.
policy_optimizer = optim.Adam(params=policy_net.parameters(), lr=P.LR)
value_optimizer = optim.Adam(params=value_net.parameters(), lr=P.LRV)

####################################################################################
def _discounted_returns(rewards, gamma):
    """Return the discounted return G_t for every step of one episode.

    Args:
        rewards: Per-step rewards in the order they were received.
        gamma: Discount factor applied to future rewards.

    Returns:
        A list whose element t equals ``sum_k gamma**k * rewards[t + k]``.
    """
    discounted = []
    G = 0.0
    # Walk backwards so each return reuses the one after it. Appending and
    # reversing once is O(n); inserting at the front on every step is O(n^2).
    for reward in reversed(rewards):
        G = reward + gamma * G
        discounted.append(G)
    discounted.reverse()
    return discounted


start_time = time.time()

try:
    for episode in range(P.NUM_EPISODES):

        state = env.reset()[0]
        collect_episode = []

        ####################################################################################
        # Collect one full episode with the current stochastic policy.
        while True:
            state_tensor = torch.from_numpy(state).float().unsqueeze(0)

            action_probs = policy_net(state_tensor)
            action = torch.multinomial(action_probs, 1).item()
            # Reuse the forward pass that produced the sample; the tensor still
            # carries the autograd graph needed for the policy update below.
            action_prob = action_probs[0, action]

            next_state, reward, terminated, truncated, _ = env.step(action)
            collect_episode.append((state, action, reward, action_prob))

            state = next_state

            if terminated or truncated:
                break
        ####################################################################################

        ####################################################################################
        # Compute (normalised) Monte-Carlo returns.
        episode_reward = [step_reward for _, _, step_reward, _ in collect_episode]
        total_reward = sum(episode_reward)

        returns = torch.tensor(
            _discounted_returns(episode_reward, P.GAMMA), dtype=torch.float32
        )
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-7)
        else:
            # The sample std of a single value is NaN, which would propagate
            # into both networks' weights; only centre the return in that case.
            returns = returns - returns.mean()
        ####################################################################################

        ####################################################################################
        # REINFORCE-with-baseline update over the whole episode (MC).
        # The value network is evaluated once on the stacked states instead of
        # once per time step.
        states = torch.from_numpy(
            np.stack([step_state for step_state, _, _, _ in collect_episode])
        ).float()
        values = value_net(states).squeeze(-1)

        # The advantage is treated as a constant in both losses, so it is
        # computed from a detached copy of the baseline.
        deltas = returns - values.detach()

        log_probs = torch.log(
            torch.stack([step_prob for _, _, _, step_prob in collect_episode])
        )
        policy_loss = -(log_probs * deltas).sum()
        # With delta held constant, the gradient of -delta * V(s) equals the
        # gradient of 0.5 * (G - V(s))**2, i.e. a squared-error step on V.
        value_loss = -(deltas * values).sum()

        policy_optimizer.zero_grad()
        policy_loss.backward()
        policy_optimizer.step()

        value_optimizer.zero_grad()
        value_loss.backward()
        value_optimizer.step()

        sum_episode_rewards.append(total_reward)
        ####################################################################################

        ####################################################################################
        # Logging and bookkeeping for the plots.
        print(
            f"Episode: {episode + 1}/{P.NUM_EPISODES},"
            f"Timesteps: {len(episode_reward)}, "

            f"Episodic Reward: {total_reward}, "
            f"Policy Loss: {policy_loss.item():.4f}"
        )

        # Append the value directly instead of indexing sum_episode_rewards by
        # episode number, which is wrong if that list was already populated.
        episode_reward_plot.append(total_reward)

        policy_loss_plot.append(policy_loss.item())
        ####################################################################################

except KeyboardInterrupt:
    # Manual interruption is an expected way to stop training early; the
    # metric lists filled so far remain available.
    pass
finally:
    # Release the environment even when training stops on an unexpected error.
    env.close()




# Notebook cell output: KeyboardInterrupt

