In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import gym
import numpy as np

# Define Policy Network
class Policy(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, output_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.softmax(self.fc3(x))
        return x

# Discounted rewards function
def discount_rewards(rewards, gamma=0.99):
    discounted_rewards = []
    running_add = 0
    for r in reversed(rewards):
        running_add = running_add * gamma + r
        discounted_rewards.insert(0, running_add)
    return discounted_rewards

# Initialize the environment
env = gym.make('CartPole-v1')
input_dim = env.observation_space.shape[0]
output_dim = env.action_space.n

# Initialize the policy network
policy = Policy(input_dim, output_dim)

# Initialize the optimizer
optimizer = optim.Adam(policy.parameters(), lr=0.01)

# Training loop
for episode in range(3):
    state = env.reset()
    episode_reward = 0

    states = []
    actions = []
    rewards = []

    while True:
        env.render()  # Render the environment

        states.append(torch.tensor(state, dtype=torch.float32))

        action_probs = policy(torch.tensor(state, dtype=torch.float32))
        action = torch.multinomial(action_probs, num_samples=1).item()
        actions.append(action)

        # print("Action Probs: {}, Action: {}".format(action_probs, action))
        next_state, reward, done, _ = env.step(action)
        rewards.append(reward)

        episode_reward += reward
        state = next_state

        if done:
            # Compute discounted rewards
            discounted_rewards = torch.tensor(discount_rewards(rewards), dtype=torch.float32)
            
            # Compute policy loss
            log_probs = torch.log(policy(torch.stack(states)))[range(len(actions)), actions]
            policy_loss = -(log_probs * discounted_rewards).mean()

            print("States: {}, Actions: {}".format(policy(torch.stack(states)), [range(len(actions)), actions]))

            # Update policy
            optimizer.zero_grad()
            policy_loss.backward()
            optimizer.step()

            print("Episode: {}, Reward: {}".format(episode, episode_reward))
            break

env.close()


States: tensor([[0.5124, 0.4876],
        [0.5200, 0.4800],
        [0.5131, 0.4869],
        [0.5090, 0.4910],
        [0.5044, 0.4956],
        [0.5011, 0.4989],
        [0.4973, 0.5027],
        [0.4936, 0.5064],
        [0.4962, 0.5038],
        [0.4928, 0.5072],
        [0.4952, 0.5048],
        [0.4977, 0.5023],
        [0.5011, 0.4989],
        [0.4966, 0.5034],
        [0.4994, 0.5006],
        [0.4950, 0.5050],
        [0.4975, 0.5025]], grad_fn=<SoftmaxBackward0>), Actions: range(0, 17)
Episode: 0, Reward: 17.0
States: tensor([[0.5894, 0.4106],
        [0.5868, 0.4132],
        [0.5809, 0.4191],
        [0.5719, 0.4281],
        [0.5658, 0.4342],
        [0.5685, 0.4315],
        [0.5750, 0.4250],
        [0.5816, 0.4184],
        [0.5715, 0.4285],
        [0.5782, 0.4218],
        [0.5673, 0.4327],
        [0.5588, 0.4412]], grad_fn=<SoftmaxBackward0>), Actions: range(0, 12)
Episode: 1, Reward: 12.0
States: tensor([[0.6629, 0.3371],
        [0.6636, 0.3364],
        [0.6635,