In [1]:
import random
from typing import Tuple

import numpy as np
import torch
from torch import optim, Tensor

from src.agents import Agent
from src.agents.NNAgent import NNAgent
from src.agents.RandomAgent import RandomAgent
from src.envs.two_player_briscola.TwoPlayerBriscola import TwoPlayerBriscola as Briscola
from time import time
import gymnasium as gym
import wandb

from src.vectorizers.VectorizedEnv import VectorizedEnv
from src.envs.two_player_briscola.TwoPlayerBriscola import TwoPlayerBriscola

from src.envs.two_player_briscola.BriscolaConstants import Constants

In [2]:
## TODO: Test briscola
# Standalone game instance for manual experimentation. `TwoPlayerBriscola` is the
# same class imported above under the `Briscola` alias; the training cells shown
# here build their own vectorized environments instead of using this one.
env = TwoPlayerBriscola()

In [3]:
params = {
    "n_envs": 1024,
    # Moves collected per env per rollout; with two players splitting the deck this is
    # presumably one full game per rollout — confirm against the env.
    "n_steps": Constants.deck_cards // 2,
    "lr": 2.5e-4,
    "mini_batch_size": 256,
    "total_timesteps": 10000,
    "gamma": 1.,
    "lambda": 0.95,
    "seed": 42,
}
# Env steps gathered by one rollout (and consumed by one PPO update).
params["batch_size"] = params["n_envs"] * params["n_steps"]
# Each update consumes a whole rollout of `batch_size` env steps, so the update count is
# total_timesteps // batch_size. Dividing by mini_batch_size (as before) overshot
# total_timesteps by ~batch_size / mini_batch_size times. Clamp to 1 so a small
# total_timesteps still performs one rollout; raise total_timesteps for real training.
params["num_updates"] = max(1, params["total_timesteps"] // params["batch_size"])

# Seed every RNG the notebook touches so Restart & Run All is reproducible.
random.seed(params["seed"])
np.random.seed(params["seed"])
torch.manual_seed(params["seed"])

In [43]:
# Start a Weights & Biases run that records the hyper-parameters in `params`
# (the config cell must have been executed first) and snapshots the notebook code.
wandb_settings = dict(
    project="briscolaBot",
    entity="lettera",
    config=params,
    save_code=True,
)
wandb.init(**wandb_settings)

NameError: name 'params' is not defined

In [4]:
def play_other_player_moves(envs: VectorizedEnv, policy: Agent):
    """Let ``policy`` play for the second agent in every env where it is on turn.

    Performs up to two passes. Each pass collects the sub-environments whose
    ``agent_selection`` is ``agents[1]`` and that are not terminated for that agent,
    queries ``policy`` once for the whole batch and steps each env with its action.
    NOTE(review): two passes presumably cover the opponent having to move twice in a
    row (e.g. answering a trick and then leading the next one) — confirm.

    Relies on the notebook-level ``device`` global, which must be defined before
    this function is called.

    :param envs: vectorized environment exposing its games as ``envs.envs``.
    :param policy: agent for ``agents[1]``; must provide ``get_action(obs, action_mask)``.
    """
    for _ in range(2):
        envs_to_play = [
            sub_env for sub_env in envs.envs
            if sub_env.agent_selection == sub_env.agents[1]
            and not sub_env.terminations[sub_env.agents[1]]
        ]
        # Nobody to move: no env changed, so a further pass would be empty too. This also
        # avoids calling the policy with empty (float64, shape (0,)) batches.
        if not envs_to_play:
            break

        observations = [sub_env.observe(sub_env.agent_selection) for sub_env in envs_to_play]
        obs = np.array([observation["observation"] for observation in observations])
        action_mask = np.array([observation["action_mask"] for observation in observations])
        actions = policy.get_action(torch.tensor(obs).to(device), torch.tensor(action_mask).to(device))

        for sub_env, action in zip(envs_to_play, actions):
            sub_env.step(action)

In [5]:
def get_state_representation(envs: VectorizedEnv) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Snapshot the current turn of every sub-environment as batched tensors.

    Iterates ``envs.last()``, which must yield one 5-tuple
    ``(observation, reward, termination, truncation, info)`` per sub-environment;
    truncation and info are ignored.

    :param envs: vectorized environment to read from.
    :return: ``(observations, action_masks, rewards, dones)`` with one row per env:
        float32 observations, int8 action masks, float32 rewards and int8
        termination flags.
    """
    n_envs = len(envs)
    obs_shape = envs.single_observation_space()["observation"].shape
    n_actions = envs.single_action_space().n

    observations = np.empty((n_envs, *obs_shape), dtype=np.float32)
    masks = np.empty((n_envs, n_actions), dtype=np.int8)
    reward_buf = np.empty(n_envs, dtype=np.float32)
    done_buf = np.empty(n_envs, dtype=np.int8)

    for idx, (observation, reward, terminated, _truncated, _info) in enumerate(envs.last()):
        observations[idx] = observation["observation"]
        masks[idx] = observation["action_mask"]
        reward_buf[idx] = reward
        done_buf[idx] = terminated

    return (
        torch.tensor(observations),
        torch.tensor(masks),
        torch.tensor(reward_buf),
        torch.tensor(done_buf),
    )

In [44]:
def make_briscola_env() -> TwoPlayerBriscola:
    """Zero-argument factory building one fresh two-player Briscola game."""
    return TwoPlayerBriscola()


# `params["n_envs"]` independent games stepped together.
envs = VectorizedEnv(make_briscola_env, params["n_envs"])

In [45]:
device = "cpu"
observation_shape = envs.single_observation_space()["observation"].shape
action_size = envs.single_action_space().n

agent = NNAgent(observation_shape, action_size).to(device)

# Policy controlling the second seat (agents[1]).
other_player = RandomAgent(Constants.hand_cards)

optimizer = optim.Adam(agent.parameters(), lr=params["lr"], eps=1e-5)


def compute_gae(rewards: Tensor, values: Tensor, dones: Tensor,
                next_value: Tensor, next_done: Tensor,
                gamma: float, gae_lambda: float) -> tuple[Tensor, Tensor]:
    """Compute Generalized Advantage Estimation targets for one rollout.

    :param rewards: rewards, shape ``(n_steps, n_envs)``.
    :param values: value estimates of ``obs[t]``, shape ``(n_steps, n_envs)``.
    :param dones: ``dones[t]`` is the termination flag read together with ``obs[t]``.
    :param next_value: value of the observation after the last step, shape ``(n_envs,)``.
    :param next_done: termination flag of that observation, shape ``(n_envs,)``.
    :param gamma: discount factor.
    :param gae_lambda: GAE smoothing factor.
    :return: ``(advantages, returns)`` where ``returns = advantages + values``.
    """
    n_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    last_gae_lambda = 0
    for t in reversed(range(n_steps)):
        if t == n_steps - 1:
            next_non_terminal = 1. - next_done
            next_values = next_value
        else:
            next_non_terminal = 1. - dones[t + 1]
            next_values = values[t + 1]
        # One-step TD error, with bootstrapping cut off where the game ended.
        delta = rewards[t] + gamma * next_values * next_non_terminal - values[t]
        last_gae_lambda = delta + gamma * gae_lambda * next_non_terminal * last_gae_lambda
        advantages[t] = last_gae_lambda
    return advantages, advantages + values


# Rollout storage, indexed [step, env].
rollout_shape = (params["n_steps"], params["n_envs"])
obs = torch.zeros(rollout_shape + observation_shape, device=device)
actions = torch.zeros(rollout_shape, dtype=torch.int8, device=device)
logprobs = torch.zeros(rollout_shape, device=device)
rewards = torch.zeros(rollout_shape, device=device)
dones = torch.zeros(rollout_shape, dtype=torch.int8, device=device)
values = torch.zeros(rollout_shape, device=device)

global_step = 0
# The opponent may be on turn at the start of a game; let it move first.
play_other_player_moves(envs, other_player)
next_obs, action_mask, reward, next_done = get_state_representation(envs)

for update in range(1, params["num_updates"] + 1):
    # --- Rollout: n_steps moves of the learning agent in every env ---
    for step in range(params["n_steps"]):
        global_step += params["n_envs"]

        obs[step] = next_obs
        dones[step] = next_done

        with torch.no_grad():
            action, logprob, _, value = agent.get_action_and_value(next_obs.to(device), action_mask.to(device))
            values[step] = value.flatten()
        actions[step] = action
        logprobs[step] = logprob

        envs.step(actions[step].cpu().numpy())
        # Let the opponent respond before reading the next observation.
        play_other_player_moves(envs, other_player)
        next_obs, action_mask, reward, next_done = get_state_representation(envs)
        rewards[step] = reward.to(device)

    # --- Bootstrap the final observation's value and compute GAE targets ---
    with torch.no_grad():
        # Inputs moved to `device` so this also works when device != "cpu".
        next_value = agent.get_value(next_obs.to(device)).flatten()
        advantages, returns = compute_gae(
            rewards, values, dones, next_value, next_done.to(device),
            params["gamma"], params["lambda"],
        )

    # TODO: the PPO policy/value optimisation step is not implemented yet, so stop
    # after the first rollout.
    # NOTE(review): nothing in this loop resets finished games between rollouts —
    # confirm whether VectorizedEnv auto-resets before removing this break.
    break