In [1]:
import random
from typing import Tuple

import pylab as pl

import numpy as np
import torch
from torch import optim, Tensor, nn

from src.agents import Agent
from src.agents.NNAgent import NNAgent
from src.agents.RandomAgent import RandomAgent
from src.envs.two_player_briscola.TwoPlayerBriscola import TwoPlayerBriscola as Briscola
from time import time
import gymnasium as gym
import wandb
from src.utils.training_utils import play_all_moves_of_player, get_state_representation

from src.vectorizers.VectorizedEnv import VectorizedEnv
from src.envs.two_player_briscola.TwoPlayerBriscola import TwoPlayerBriscola

from src.envs.two_player_briscola.BriscolaConstants import Constants

In [2]:
# Hyper-parameters of the PPO run. The whole dict is also passed to wandb as
# the run configuration, so every key below is recorded with the run.
params = {
    "n_envs": 2048,                        # environments stepped in parallel
    "n_steps": Constants.deck_cards // 2,  # learner decisions collected per env per update
    "lr": 3e-3,                            # Adam learning rate
    "mini_batch_size": 1024,               # samples per gradient step
    "total_timesteps": 10_000_000,         # budget used to derive num_updates below
    "gamma": 1.,                           # discount factor used in the GAE recursion
    "lambda": 0.9,                         # GAE lambda
    "update_epochs": 2,                    # passes over each collected batch
    "clip_coef": 0.3,                      # clipping range for the policy ratio and the value update
    "normalize_advantage": True,           # standardise advantages inside each mini-batch
    "clip_value_loss": True,               # use the clipped value objective
    "value_coef": 0.5,                     # weight of the value loss in the total loss
    "entropy_coef": 1e-2,                  # weight of the entropy bonus
    "max_grad_norm": 0.5,                  # global gradient-norm clipping threshold
    "reward_for_win": 0.1,                 # forwarded to the TwoPlayerBriscola constructor
    "seed": 42,                            # seed for the Python, NumPy and PyTorch global RNGs
}

# Seed every global RNG the notebook can reach so that Restart & Run All
# reproduces network initialisation, mini-batch shuffling and action sampling.
# NOTE(review): if the environments keep private RNGs they are not covered
# here - confirm against the env implementation.
random.seed(params["seed"])
np.random.seed(params["seed"])
torch.manual_seed(params["seed"])

# Derived sizes: one update consumes n_envs * n_steps transitions.
params["batch_size"] = params["n_envs"] * params["n_steps"]
params["num_updates"] = params["total_timesteps"] // params["batch_size"]

In [3]:
# Open the Weights & Biases run that receives the metrics logged during
# training. The hyper-parameter dict is attached as the run config and no
# explicit run name is supplied.
run = wandb.init(
    project="briscolaBot",
    entity="lettera",
    name=None,
    config=params,
    mode="online",
    save_code=True,
    sync_tensorboard=False,
)

[34m[1mwandb[0m: Currently logged in as: [33mlettera[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
def _make_env() -> TwoPlayerBriscola:
    """Build one two-player Briscola environment using the configured win reward."""
    return TwoPlayerBriscola(reward_for_win=params["reward_for_win"])


# The factory and the number of environments are handed to the vectorizer.
vec_env = VectorizedEnv(_make_env, params["n_envs"])

In [5]:
# Device that hosts the network and all rollout buffers.
device = "cpu"

# Per-environment observation shape and number of discrete actions.
observation_shape = vec_env.single_observation_space()["observation"].shape
action_size = vec_env.single_action_space().n

# The learner network.
agent = NNAgent(observation_shape, action_size).to(device)

# Agent identifiers are read from the first environment; the second one is
# driven by a RandomAgent during data collection.
player = vec_env[0].agents[0]
other_player = vec_env[0].agents[1]
other_player_policy = RandomAgent(action_size)

optimizer = optim.Adam(agent.parameters(), lr=params["lr"], eps=1e-5)

# Rollout storage, indexed as [step, env, ...]. It is allocated once and
# overwritten on every update.
rollout_shape = (params["n_steps"], params["n_envs"])
obs = torch.zeros(rollout_shape + observation_shape, device=device)
actions = torch.zeros(rollout_shape, dtype=torch.int8, device=device)
actions_masks = torch.zeros(rollout_shape + (action_size,), dtype=torch.int8, device=device)
logprobs = torch.zeros(rollout_shape, device=device)
rewards = torch.zeros(rollout_shape, device=device)
dones = torch.zeros(rollout_shape, dtype=torch.int8, device=device)
values = torch.zeros(rollout_shape, device=device)

# Counters used for logging (environment steps taken and wall-clock start).
global_step = 0
start_time = time()
# PPO training loop. Each update (1) collects n_steps learner decisions from
# every environment, (2) turns them into GAE advantages and returns,
# (3) optimises the clipped PPO objective for update_epochs passes over the
# batch, and (4) logs diagnostics to wandb.
for update in range(1, params["num_updates"] + 1):
    # ---- 1. Rollout ----
    # Every update starts from freshly reset environments. The opponent helper
    # is called before the learner's state is read (helper semantics live in
    # src.utils.training_utils).
    vec_env.reset()
    play_all_moves_of_player(vec_env, other_player_policy, other_player)
    next_obs, action_mask, reward, next_done = get_state_representation(vec_env)
    for step in range(params["n_steps"]):
        global_step += params["n_envs"]

        # Store the state the action is chosen from, together with its done flag.
        obs[step] = next_obs
        dones[step] = next_done

        # Sampling needs no gradients; log-probs and values are recomputed with
        # gradients in the optimisation phase.
        with torch.no_grad():
            action, logprob, _, value = agent.get_action_and_value(next_obs.to(device), action_mask.to(device))
            values[step] = value.flatten()
        actions[step] = action
        actions_masks[step] = action_mask.to(device)
        logprobs[step] = logprob

        vec_env.step(actions[step].cpu().numpy())
        play_all_moves_of_player(vec_env, other_player_policy, other_player)
        next_obs, action_mask, reward, next_done = get_state_representation(vec_env)
        # Reward read after the learner's step and the opponent helper call.
        rewards[step] = reward.to(device)

    # ---- 2. Generalised Advantage Estimation ----
    with torch.no_grad():
        # Bootstrap from the state that follows the last stored step. The
        # observation is moved to `device` exactly like in the rollout call so
        # this keeps working when `device` is not the CPU.
        next_value = agent.get_value(next_obs.to(device)).reshape(1, -1)
        advantages = torch.zeros_like(rewards).to(device)
        last_gae_lambda = 0

        # Backward recursion: A_t = delta_t + gamma * lambda * nonterminal * A_{t+1}.
        for t in reversed(range(params["n_steps"])):
            if t == params["n_steps"] - 1:
                # The final mask comes straight from the environment helper, so
                # it is moved to `device` before being mixed with device tensors.
                next_non_terminal = (1. - next_done).to(device)
                next_values = next_value
            else:
                next_non_terminal = 1. - dones[t+1]
                next_values = values[t+1]

            delta = rewards[t] + params["gamma"] * next_values * next_non_terminal - values[t]
            last_gae_lambda = delta + params["gamma"] * params["lambda"] * next_non_terminal * last_gae_lambda
            advantages[t] = last_gae_lambda

        # Value-function regression targets.
        returns = advantages + values

    # ---- 3. Optimisation ----
    # Flatten the [step, env] axes into a single batch axis.
    b_obs = obs.reshape((-1,) + observation_shape)
    b_logprobs = logprobs.reshape(-1)
    b_actions = actions.reshape(-1).long()
    b_action_masks = actions_masks.reshape((-1, action_size))
    b_advantages = advantages.reshape(-1)
    b_returns = returns.reshape(-1)
    b_values = values.reshape(-1)

    clip_fraction = []
    b_indexes = np.arange(params["batch_size"])
    for epoch in range(params["update_epochs"]):
        # A new random mini-batch partition for every epoch.
        np.random.shuffle(b_indexes)
        for start in range(0, params["batch_size"], params["mini_batch_size"]):
            end = start + params["mini_batch_size"]
            mb_indexes = b_indexes[start:end]

            # Re-evaluate the stored actions under the current policy.
            _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_indexes], b_action_masks[mb_indexes], b_actions[mb_indexes])
            logratio = newlogprob - b_logprobs[mb_indexes]
            ratio = logratio.exp()

            with torch.no_grad():
                # calculate approx_kl http://joschu.net/blog/kl-approx.html
                old_approx_kl = (-logratio).mean()
                approx_kl = ((ratio - 1) - logratio).mean()
                clip_fraction += [((ratio - 1.0).abs() > params["clip_coef"]).float().mean().item()]

            mb_advantages = b_advantages[mb_indexes]
            if params["normalize_advantage"]:
                # Standardise per mini-batch; the epsilon guards against zero std.
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

            # Policy loss: pessimistic (max) of the unclipped and clipped surrogate losses.
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * torch.clamp(ratio, 1-params["clip_coef"], 1+params["clip_coef"])
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

            # Value loss, optionally clipped around the value predicted at rollout time.
            newvalue = newvalue.view(-1)
            if params["clip_value_loss"]:
                v_loss_unclipped = (newvalue - b_returns[mb_indexes]) ** 2
                v_clipped = b_values[mb_indexes] + torch.clamp(
                    newvalue - b_values[mb_indexes],
                    -params["clip_coef"],
                    params["clip_coef"],
                )
                v_loss_clipped = (v_clipped - b_returns[mb_indexes]) ** 2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                value_loss = 0.5 * v_loss_max.mean()
            else:
                value_loss = 0.5 * ((newvalue - b_returns[mb_indexes]) ** 2).mean()

            # Entropy enters the loss with a negative sign, i.e. as a bonus.
            entropy_loss = entropy.mean()
            loss = pg_loss - params["entropy_coef"] * entropy_loss + params["value_coef"] * value_loss

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(agent.parameters(), params["max_grad_norm"])
            optimizer.step()

    # ---- 4. Diagnostics ----
    # Explained variance of the rollout-time value predictions w.r.t. the returns.
    y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
    var_y = np.var(y_true)
    explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

    # Loss and KL entries are those of the last mini-batch of the last epoch;
    # clipfrac is averaged over every mini-batch of this update.
    wandb.log({
        "global_step": global_step,
        "learning_rate": optimizer.param_groups[0]["lr"],
        "value_loss": value_loss.item(),
        "policy_loss": pg_loss.item(),
        "entropy": entropy_loss.item(),
        "old_approx_kl": old_approx_kl.item(),
        "approx_kl": approx_kl.item(),
        "clipfrac": np.mean(clip_fraction),
        "explained_variance": explained_var,
        "SPS": int(global_step / (time() - start_time)),
        # Logged as a Python float, like the other scalar metrics.
        "reward_per_game": torch.sum(rewards, dim=0).mean().item(),
        "ratio_game_won": sum([env.game_state.agent_points["player_0"] > Constants.total_points / 2 for env in vec_env]) / params["n_envs"],
        "points_per_game": sum([env.game_state.agent_points["player_0"] for env in vec_env]) / params["n_envs"]
    })

In [7]:
# Write the trained weights to disk, then upload the same file to the wandb
# run as a 'model' artifact so the checkpoint is versioned with the run.
checkpoint_path = 'agent.pt'
torch.save(agent.state_dict(), checkpoint_path)

artifact = wandb.Artifact('model', type='model')
artifact.add_file(checkpoint_path)
run.log_artifact(artifact)

# Close the active wandb run.
wandb.run.finish()

0,1
SPS,█▂▁▁▂▃▃▂▂▃▃▂▃▃▃▃▃▃▃▃▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
approx_kl,▂▁▁▁▂▃▃▅▆▆▆▅▆▆▆▅█▆▆▅▆▄█▆▆▄▆▇▅▅▅▆▅▆▅▇▅▅▇▆
clipfrac,▃▁▂▂▃▄▆▆▇▇██▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▅▅▅▅▆▅▅▅▅▅▅▆▅
entropy,███▇▇▇▇▆▆▅▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
explained_variance,▁▆▇▇████████████████████████████████████
global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
old_approx_kl,▂▂▁▂▃▃▄▅▅▆▇▆▇▄▆▇▅█▇▄▅▆▆▇▃▃▇▆▃▃▄▇▄▆▄▅▇█▅▄
points_per_game,▂▁▁▁▁▂▂▂▃▃▄▅▅▆▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇█▇▇███
policy_loss,▆▆▆▃▃▃▂▄▁▂▄▄▂▂▃▄▃▃▂▂▇▃▃▄▃▄▅▆▃█▅▄▅▄▃▃▃▃▅▄

0,1
SPS,6238.0
approx_kl,0.03673
clipfrac,0.11361
entropy,0.4333
explained_variance,0.89845
global_step,1720320.0
learning_rate,0.003
old_approx_kl,0.02785
points_per_game,74.30273
policy_loss,-0.02311
