In [6]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
#import rl_utils
import random
import numpy as np

class Critic(nn.Module):

    def __init__(self, state_dim, hidden_dim):
        super(Critic, self).__init__()
        self.Linear1 = nn.Linear(state_dim, hidden_dim)
        self.Linear2 = nn.Linear(hidden_dim, 1)

    def forward(self, states):
        out = F.relu(self.Linear1(states))
        out = self.Linear2(out)
        return out


class Actor(nn.Module):

    def __init__(self, state_dim, hidden_dim, action_dim):
        super(Actor, self).__init__()
        self.Linear1 = nn.Linear(state_dim, hidden_dim)
        self.Linear2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, states):
        out = F.relu(self.Linear1(states))
        out = F.softmax(self.Linear2(out), dim=1)
        return out


class ActorCritic:

    def __init__(self, state_dim, hidden_dim, action_dim, gamma):
        self.actor = Actor(state_dim, hidden_dim, action_dim)
        self.critic = Critic(state_dim, hidden_dim)
        self.action_dim = action_dim
        self.gamma = gamma
        self.actor_optimizer = optim.Adam(params=self.actor.parameters(), lr=2e-3)
        self.critic_optimizer = optim.Adam(params=self.critic.parameters(), lr=2e-3)

    def update(self, data):
        states = torch.tensor(data['states'], dtype=torch.float)
        actions = torch.tensor(data['actions'], dtype=torch.long).view(-1, 1)
        next_states = torch.tensor(data['next_states'], dtype=torch.float)
        rewards = torch.tensor(data['rewards'], dtype=torch.float)
        dones = torch.tensor(data['done'], dtype=torch.long)
        td_target = rewards + (self.gamma * self.critic(next_states).squeeze()) * (1 - dones)
        td_delta = td_target - self.critic(states).squeeze()
        log_probs = torch.log(self.actor(states).gather(1, actions)).squeeze()
        actor_loss = torch.mean(td_delta.detach() * -log_probs)
        critic_loss = torch.mean(F.mse_loss(self.critic(states).squeeze(), td_target.detach()))
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()
        critic_loss.backward()
        self.actor_optimizer.step()
        self.critic_optimizer.step()

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        return action_dist.sample().item()



lr = 2e-3
num_episodes = 5000
hidden_dim = 128
gamma = 0.98

env_name = 'CartPole-v1'
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
#env.seed(0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = ActorCritic(state_dim, hidden_dim, action_dim, gamma)

for i in range(num_episodes):
    data = {
        'states': [],
        'actions': [],
        'next_states': [],
        'rewards': [],
        'done': [],
    }
    done = 0
    state = env.reset()
    G = 0
    while not done:
        action = agent.take_action(state)
        next_state, reward, done, _ = env.step(action)
        data['states'].append(state)
        data['actions'].append(action)
        data['next_states'].append(next_state)
        data['rewards'].append(reward)
        data['done'].append(done)
        state = next_state
        G += reward
    agent.update(data)
    if i % 10 == 0:
        print(G)

ValueError: expected sequence of length 4 at dim 2 (got 0)