In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
import gym
import numpy as np

# Define Policy Network
class Policy(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, output_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.softmax(self.fc3(x))
        return x

def discount_returns(rewards, gamma=0.99):
    discounted_returns = []
    running_add = 0
    for t, r in enumerate(reversed(rewards)):
        # print(t)
        running_add = running_add * gamma ** t + r
        discounted_returns.insert(0, running_add)
    return discounted_returns

# Initialize the environment
env = gym.make('CartPole-v1')
input_dim = env.observation_space.shape[0]
output_dim = env.action_space.n

# Initialize the policy network
policy = Policy(input_dim, output_dim)

# Initialize the optimizer
optimizer = optim.Adam(policy.parameters(), lr=0.01)

# Training loop
for episode in range(1):
    state = env.reset()
    episode_reward = 0

    states = []
    actions = []
    rewards = []

    while True:
        # env.render()  # Render the environment

        states.append(torch.tensor(state, dtype=torch.float32))

        action_probs = policy(torch.tensor(state, dtype=torch.float32))

        # take the action based on the probability. Higher probability will have a higher chance, but not always
        action = torch.multinomial(action_probs, num_samples=1).item() 
        
        actions.append(action)

        # print("Action Probs: {}, Action: {}".format(action_probs, action))
        next_state, reward, done, _ = env.step(action)
        rewards.append(reward)

        episode_reward += reward
        state = next_state

        if done:
            # Compute discounted rewards
            discounted_returns = torch.tensor(discount_returns(rewards), dtype=torch.float32)
            print("Discounted Returns: {}".format(discounted_returns))

            # Compute policy loss, resulting the log value of the choosen probability of each state that choosen by action in each timestep 
            log_probs = torch.log(policy(torch.stack(states)))[range(len(actions)), actions]
            
            # Compute policy loss
            policy_loss = -(log_probs * discounted_returns).mean()

            # print("States: {}".format(states))
            # print("Stacked States: {}".format(torch.stack(states)))
            # print("Policy: {}".format(policy(torch.stack(states))))
            # print("Log: {}".format(torch.log(policy(torch.stack(states)))))
            # print("Actions: {}".format(actions))
            # print("Len Actions: {}".format(len(actions)))
            # print("Range Actions: {}".format([range(len(actions)), actions]))
            print("Log Probs: {}".format(log_probs))
            # print("States: {}, Actions: {}".format(policy(torch.stack(states)), [range(len(actions)), actions]))
            
            print("Policy Loss: {}".format(policy_loss))

            # Update policy
            optimizer.zero_grad()
            policy_loss.backward()
            optimizer.step()

            print("Episode: {}, Reward: {}".format(episode, episode_reward))
            break

env.close()


Discounted Returns: tensor([2.6934, 2.7432, 2.7957, 2.8511, 2.9097, 2.9719, 3.0378, 3.1080, 3.1830,
        3.2632, 3.3493, 3.4419, 3.5419, 3.6499, 3.7671, 3.8943, 4.0325, 4.1829,
        4.3464, 4.5240, 4.7165, 4.9244, 5.1478, 5.3864, 5.6393, 5.9049, 6.1804,
        6.4624, 6.7459, 7.0252, 7.2929, 7.5408, 7.7595, 7.9387, 8.0677, 8.1356,
        8.1315, 8.0456, 7.8692, 7.5955, 7.2199, 6.7406, 6.1590, 5.4797, 4.7106,
        3.8628, 2.9504, 1.9900, 1.0000])
Log Probs: tensor([-0.7105, -0.6999, -0.6963, -0.6919, -0.6892, -0.6848, -0.6737, -0.7185,
        -0.6743, -0.7181, -0.6747, -0.7180, -0.6749, -0.7181, -0.7117, -0.7006,
        -0.6966, -0.6922, -0.6889, -0.7015, -0.6976, -0.6924, -0.6886, -0.7011,
        -0.6891, -0.6856, -0.6764, -0.7157, -0.7088, -0.6861, -0.7080, -0.6999,
        -0.6941, -0.6871, -0.7068, -0.6884, -0.7061, -0.6901, -0.7046, -0.6921,
        -0.6828, -0.6765, -0.6775, -0.6671, -0.7350, -0.7150, -0.7077, -0.7060,
        -0.6922], grad_fn=<IndexBackward0>)
Poli