In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import gymnasium as gym
from itertools import count
from collections import Counter
import numpy as np
from torch.distributions import Categorical
from torch import optim

In [2]:
class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(4, 128)
        self.layer2 = nn.Linear(128, 128)

    def forward(self, x) -> torch.Tensor:
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return x

In [3]:
class Actor(nn.Module):
    def __init__(self, base: nn.Module):
        super().__init__()
        self.base = base
        self.linear = nn.Linear(128, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.base(x)
        x = self.linear(x)
        return F.softmax(x, dim=1)

In [4]:
class Critic(nn.Module):
    def __init__(self, base: nn.Module):
        super().__init__()
        self.base = base
        self.linear = nn.Linear(128, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.base(x)
        return self.linear(x)

In [5]:
# backbone = Backbone()
# actor = Actor(backbone)
# critic = Critic(backbone)

# actor.load_state_dict(torch.load("../checkpoints/actor_1000.pt"))

In [8]:
backbone = Backbone()
actor = Actor(backbone)
critic = Critic(backbone)

env = gym.make("CartPole-v1")
lr = 0.0001
# lr = 0.0002
actor_optimizer = optim.Adam(actor.parameters(), lr=lr)
critic_optimizer = optim.Adam(critic.parameters(), lr=lr)

t = 0
state, _ = env.reset()
state = torch.Tensor(np.array(state))

while t < 100_000:
    actor_optimizer.zero_grad()
    critic_optimizer.zero_grad()
    
    done = False
    t_start = t
    rewards = []
    states = []
    log_probs = []
    while not done and t - t_start != 20:
        state = state.unsqueeze(0)
        probs = actor(state)
        m = Categorical(probs)
        action = m.sample()
        action, log_prob = action.item(), m.log_prob(action)
        
        state, reward, terminated, truncated, info = env.step(action)
        state = torch.Tensor(np.array(state))

        rewards.append(reward)
        states.append(state)
        log_probs.append(log_prob)

        done = truncated or terminated
        t += 1
    
    r = 0 if done else critic(state)
    actor_loss = torch.tensor([0.0])
    critic_loss = torch.tensor([0.0])
    for ri, si, logi in zip(reversed(rewards), reversed(states), reversed(log_probs)):
        r = ri + 0.9 * r
        v = critic(si)
        tmp_diff = (r - v)
        actor_loss += logi * tmp_diff.clone().detach()
        critic_loss += tmp_diff ** 2
    actor_loss.backward(retain_graph=True)
    critic_loss.backward(retain_graph=True)
    
    actor_optimizer.step()
    critic_optimizer.step()
    if done:
        state, _ = env.reset()
        state = torch.Tensor(np.array(state))
env.close()

In [27]:
env = gym.make('CartPole-v1', render_mode="human")
state, _ = env.reset()
actions = []
rewards = 0

backbone = Backbone()
actor = Actor(backbone)
# actor.load_state_dict(torch.load("../checkpoints/actor_5000.pt"))
actor.load_state_dict(torch.load("../models/actor.pt"))

for _ in count():
    state = torch.tensor(np.array(state)).unsqueeze(0)
    action_probs = actor(state)
    print(action_probs)
    action = action_probs.argmax().item()
    actions.append(action)
    state, reward, terminated, truncated, _ = env.step(action)
    rewards += reward
    if terminated or truncated:
        break

print("Total reward:", rewards)
print(Counter(actions))
env.close()

tensor([[0.1335, 0.8665]], grad_fn=<SoftmaxBackward0>)
tensor([[0.2216, 0.7784]], grad_fn=<SoftmaxBackward0>)
tensor([[0.4320, 0.5680]], grad_fn=<SoftmaxBackward0>)
tensor([[0.7166, 0.2834]], grad_fn=<SoftmaxBackward0>)
tensor([[0.5208, 0.4792]], grad_fn=<SoftmaxBackward0>)
tensor([[0.3412, 0.6588]], grad_fn=<SoftmaxBackward0>)
tensor([[0.6081, 0.3919]], grad_fn=<SoftmaxBackward0>)
tensor([[0.4224, 0.5776]], grad_fn=<SoftmaxBackward0>)
tensor([[0.7018, 0.2982]], grad_fn=<SoftmaxBackward0>)
tensor([[0.5240, 0.4760]], grad_fn=<SoftmaxBackward0>)
tensor([[0.3756, 0.6244]], grad_fn=<SoftmaxBackward0>)
tensor([[0.6266, 0.3734]], grad_fn=<SoftmaxBackward0>)
tensor([[0.4720, 0.5280]], grad_fn=<SoftmaxBackward0>)
tensor([[0.7378, 0.2622]], grad_fn=<SoftmaxBackward0>)
tensor([[0.5893, 0.4107]], grad_fn=<SoftmaxBackward0>)
tensor([[0.4615, 0.5385]], grad_fn=<SoftmaxBackward0>)
tensor([[0.7009, 0.2991]], grad_fn=<SoftmaxBackward0>)
Total reward: 17.0
Counter({0: 9, 1: 8})


In [14]:
env.close()