In [2]:
import random
from copy import deepcopy
from typing import Tuple

import pylab as pl

import numpy as np
import torch
from scipy import stats
from sklearn.metrics import explained_variance_score
from torch import optim, Tensor, nn

from src.agents import Agent
from src.agents.NNAgent import NNAgent
from src.agents.RandomAgent import RandomAgent
from src.envs.two_player_briscola.TwoPlayerBriscola import TwoPlayerBriscola as Briscola
from time import time
import gymnasium as gym
import wandb
from src.utils.AgentPool import AgentPool
from src.utils.training_utils import play_all_moves_of_player, get_state_representation

from src.vectorizers.VectorizedEnv import VectorizedEnv
from src.envs.two_player_briscola.TwoPlayerBriscola import TwoPlayerBriscola

from src.envs.two_player_briscola.BriscolaConstants import Constants

from src.utils.training_utils import play_all_moves_of_players, compute_rating

In [9]:
params = {
    "n_envs" : 2048,
    "n_steps": Constants.deck_cards // 2,
    "lr": 1e-3,
    "lr_decay": 0.997,
    "lr_min": 1e-4,
    "mini_batch_size": 2048,
    "total_timesteps": 50_000_000,
    "gamma": 1.,
    "lambda": 0.95,
    "update_epochs": 2,
    "clip_coef": 0.2,
    "normalize_advantage": True,
    "clip_value_loss": True,
    "value_coef": 0.5,
    "entropy_coef": 1e-3,
    "entropy_decay": 0.998,
    "max_grad_norm": 0.5,
    "ratio_win_reward": 1.,
    "win_reward_increase": 0.0004,
    "n_opponents": 4,
    "self_play_opponents": 2,
    "max_pool_size": 128,
    "add_model_every_x_step": 2,
    "nu": 0.1,
    "hidden_size": 256,
    "activation": nn.Mish,
    "briscola_penalization": 0.,
    "briscola_penalization_decay": 0.99
}
params["batch_size"] = params["n_envs"] * params["n_steps"]
params["num_updates"] = params["total_timesteps"] // params["batch_size"]

In [10]:
run = wandb.init(
        name="briscola penalization continue",
        project="briscolaBot",
        entity="lettera",
        config=params,
        save_code=True,
        sync_tensorboard=False,
        mode="online"
)

In [11]:
vec_env = VectorizedEnv(lambda: TwoPlayerBriscola(), params["n_envs"])

In [13]:
device = "cpu"
observation_shape = vec_env.single_observation_space()["observation"].shape
action_size = vec_env.single_action_space().n

player_policy = NNAgent(observation_shape, action_size, hidden_size=params["hidden_size"], activation=params["activation"]).to(device)
player_policy.load_state_dict(torch.load("briscola_penalization.pt"))

player_name = vec_env[0].agents[0]
opponent_name = vec_env[0].agents[1]

trained_previous = NNAgent(observation_shape, action_size, hidden_size=256).to(device)
trained_previous.load_state_dict(torch.load("agent.pt"))

trained_v2 = NNAgent(observation_shape, action_size, hidden_size=256, activation=nn.Mish).to(device)
trained_v2.load_state_dict(torch.load("agent-v2.pt"))

pool = AgentPool(params["max_pool_size"], nu=params["nu"])

optimizer = optim.Adam(player_policy.parameters(), lr=params["lr"], eps=1e-5)

obs = torch.zeros((params["n_steps"], params["n_envs"]) + observation_shape).to(device)
actions = torch.zeros((params["n_steps"], params["n_envs"]), dtype=torch.int64).to(device)
actions_masks = torch.zeros((params["n_steps"], params["n_envs"]) + (action_size,), dtype=torch.int64).to(device)
logprobs = torch.zeros((params["n_steps"], params["n_envs"])).to(device)
rewards = torch.zeros((params["n_steps"], params["n_envs"])).to(device)
dones = torch.zeros((params["n_steps"], params["n_envs"]), dtype=torch.int8).to(device)
values = torch.zeros((params["n_steps"], params["n_envs"])).to(device)

global_step = 0
start_time = time()
for update in range(params["num_updates"]):
    # Decay lr
    current_lr = optimizer.param_groups[0]["lr"]
    optimizer.param_groups[0]["lr"] = max(params["lr_min"], current_lr*params["lr_decay"])

    # Decay entropy
    params["entropy_coef"] *= params["entropy_decay"]

    # Decay briscola penalization
    params["briscola_penalization"] *= params["briscola_penalization_decay"]

    # Increase ratio_win_reward
    params["ratio_win_reward"] = min(params["ratio_win_reward"] + params["win_reward_increase"], 1)

    # Add agent to pool
    if update % params["add_model_every_x_step"] == 0:
        pool.add_agent(deepcopy(player_policy))

    # Sample agents
    opponent_policies, opponent_indexes = pool.sample_agents(params["n_opponents"] - params["self_play_opponents"])
    opponent_policies += [player_policy] * params["self_play_opponents"]

    # Play episodes
    vec_env.reset()
    play_all_moves_of_players(vec_env, opponent_policies, opponent_name)
    next_obs, action_mask, reward, next_done = get_state_representation(vec_env)
    for step in range(params["n_steps"]):
        global_step += params["n_envs"]

        obs[step] = next_obs
        dones[step] = next_done

        with torch.no_grad():
            action, logprob, _, value = player_policy.get_action_and_value(next_obs.to(device), action_mask.to(device))
            values[step] = value.flatten()
        actions[step] = action
        actions_masks[step] = action_mask.to(device)
        logprobs[step] = logprob

        vec_env.step(actions[step].cpu().numpy(), briscola_penalization=params["briscola_penalization"])
        play_all_moves_of_players(vec_env, opponent_policies, opponent_name)
        next_obs, action_mask, reward, next_done = get_state_representation(vec_env)
        wins = torch.tensor([env.get_game_outcome(player_name) for env in vec_env], dtype=torch.float32)
        rewards[step] = (1 - params["ratio_win_reward"]) * reward + (params["ratio_win_reward"] * next_done * wins).to(device)

    # Update rating
    scores = [env.get_game_outcome(opponent_name) for env in vec_env.get_envs()]
    mean_score_per_opponent = np.empty_like(opponent_indexes, dtype=np.float64)
    for i in range(opponent_indexes.size):
        start, end = (i * len(scores)) // opponent_indexes.size, ((i + 1) * len(scores)) // opponent_indexes.size
        mean_score_per_opponent[i] = np.mean(scores[start:end])

    agent_rating = pool.update_ratings(0., mean_score_per_opponent, opponent_indexes)

    # Bootstrap value
    with torch.no_grad():
        next_value = player_policy.get_value(next_obs).reshape(1, -1)
        advantages = torch.zeros_like(rewards).to(device)
        last_gae_lambda = 0

        for t in reversed(range(params["n_steps"])):
            if t == params["n_steps"] - 1:
                next_non_terminal = 1. - next_done
                next_values = next_value
            else:
                next_non_terminal = 1. - dones[t+1]
                next_values = values[t+1]

            delta = rewards[t] + params["gamma"] * next_values * next_non_terminal - values[t]
            last_gae_lambda = delta + params["gamma"] * params["lambda"] * next_non_terminal * last_gae_lambda
            advantages[t] = last_gae_lambda

        returns = advantages + values

    # Optimize net
    b_obs = obs.reshape((-1,) + observation_shape)
    b_logprobs = logprobs.reshape(-1)
    b_actions = actions.reshape(-1)
    b_action_masks = actions_masks.reshape((-1, action_size))
    b_advantages = advantages.reshape(-1)
    b_returns = returns.reshape(-1)
    b_values = values.reshape(-1)

    clip_fraction = []
    b_indexes = np.arange(params["batch_size"])
    for epoch in range(params["update_epochs"]):
        np.random.shuffle(b_indexes)
        for start in range(0, params["batch_size"], params["mini_batch_size"]):
            end = start + params["mini_batch_size"]
            mb_indexes = b_indexes[start:end]

            _, newlogprob, entropy, newvalue = player_policy.get_action_and_value(b_obs[mb_indexes], b_action_masks[mb_indexes], b_actions[mb_indexes])
            logratio = newlogprob - b_logprobs[mb_indexes]
            ratio = logratio.exp()

            with torch.no_grad():
                # calculate approx_kl https://joschu.net/blog/kl-approx.html
                old_approx_kl = (-logratio).mean()
                approx_kl = ((ratio - 1) - logratio).mean()
                clip_fraction.append(((ratio - 1.0).abs() > params["clip_coef"]).float().mean().item())

            mb_advantages = b_advantages[mb_indexes]
            if params["normalize_advantage"]:
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

            # Policy loss
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * torch.clamp(ratio, 1-params["clip_coef"], 1+params["clip_coef"])
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

            # Value loss
            newvalue = newvalue.view(-1)
            if params["clip_value_loss"]:
                v_loss_unclipped = (newvalue - b_returns[mb_indexes]) ** 2
                v_clipped = b_values[mb_indexes] + torch.clamp(
                    newvalue - b_values[mb_indexes],
                    -params["clip_coef"],
                    params["clip_coef"],
                )
                v_loss_clipped = (v_clipped - b_returns[mb_indexes]) ** 2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                value_loss = 0.5 * v_loss_max.mean()
            else:
                value_loss = 0.5 * ((newvalue - b_returns[mb_indexes]) ** 2).mean()

            entropy_loss = entropy.mean()
            loss = pg_loss - params["entropy_coef"] * entropy_loss + params["value_coef"] * value_loss

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(player_policy.parameters(), params["max_grad_norm"])
            optimizer.step()

    # Logging
    explained_var = explained_variance_score(b_returns.cpu().numpy(), b_values.cpu().numpy())
    if update % 8 == 0:
        outcome_vs_random, rating_vs_random = compute_rating(player_policy, RandomAgent(action_size))
        outcome_vs_past_iterations, rating_vs_past_iterations = compute_rating(player_policy, pool.get_agent(-50))
        _, rating_vs_trained_previous = compute_rating(player_policy, trained_previous)
        _, rating_vs_v2 = compute_rating(player_policy, trained_v2)

    wandb.log({
        "global_step": global_step,
        "learning_rate": optimizer.param_groups[0]["lr"],
        "ratio_win_reward": params["ratio_win_reward"],
        "entropy_coef": params["entropy_coef"],

        "value_loss": value_loss.item(),
        "policy_loss": pg_loss.item(),
        "entropy": entropy_loss.item(),
        "total_loss": loss.item(),

        "old_approx_kl": old_approx_kl.item(),
        "approx_kl": approx_kl.item(),
        "clipfrac": np.mean(clip_fraction),
        "explained_variance": explained_var,
        "SPS": int(global_step / (time() - start_time)),

        "reward_per_game": torch.sum(rewards, dim=0).mean(),
        "points_per_game": sum([env.game_state.agent_points[player_name] for env in vec_env]) / params["n_envs"],
        "mean_outcome": sum([env.get_game_outcome(player_name) for env in vec_env]) / params["n_envs"],

        "outcome_vs_random": outcome_vs_random,
        "rating_vs_random": rating_vs_random,

        "outcome_vs_50_past_iterations": outcome_vs_past_iterations,
        "rating_vs_50_past_iterations": rating_vs_past_iterations,

        "rating_vs_best": rating_vs_trained_previous,
        "rating_vs_v2": rating_vs_v2,

        "pool_ratings": wandb.Histogram(pool.ratings),
        "pool_std": np.std(pool.ratings)
    })

  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal a

KeyboardInterrupt: 

In [133]:
from copy import deepcopy
from typing import List

import torch
from torch import tensor, nn
from torch.distributions import Categorical

from src.agents.Agent import Agent
from src.envs.two_player_briscola.BriscolaConstants import Constants
from src.envs.two_player_briscola.TwoPlayerBriscola import TwoPlayerBriscola, State


def get_player_cards(briscola_env: TwoPlayerBriscola, player: str) -> List[int]:
    return briscola_env.game_state.hand_cards[player]


def find_best_move(briscola_env: TwoPlayerBriscola):
    player = briscola_env.agent_selection
    player_cards = get_player_cards(briscola_env, player)
    if len(player_cards) == 0:
        print(briscola_env.game_state.hand_cards, player)
        return None, briscola_env.game_state.agent_points[player]

    best_move = None
    best_score = -1
    for card in player_cards:
        next_env = deepcopy(briscola_env)
        next_env.step(card)
        _, score = find_best_move(next_env)
        if score > best_score:
            best_score = score
            best_move = card
    return best_move, best_score


def get_env(observation):
    def card_indexes(cards: tensor) -> list[int]:
        return (cards > 0.1).nonzero().squeeze(1).tolist()

    thrown_cards = observation[:Constants.deck_cards]
    briscola_card = card_indexes(observation[Constants.deck_cards:Constants.deck_cards * 2])[0]
    table_card = observation[Constants.deck_cards * 2:Constants.deck_cards * 3]
    player_cards = observation[Constants.deck_cards * 3:Constants.deck_cards * 4]
    player_points = (observation[-2] * Constants.total_points).item()
    opponent_points = (observation[-1] * Constants.total_points).item()

    opponent_cards = card_indexes(1 - thrown_cards - player_cards)
    thrown_cards = card_indexes(thrown_cards)
    player_cards = card_indexes(player_cards)
    player = "player_0" if len(player_cards) > len(opponent_cards) else "player_1"
    opponent = "player_1" if player == "player_0" else "player_0"
    state = State(deck=[],
                  thrown_cards_player=[],
                  thrown_cards=thrown_cards,
                  hand_cards={player: player_cards, opponent: opponent_cards},
                  table_card=Constants.null_card_number if len(card_indexes(table_card)) == 0 else card_indexes(table_card)[0],
                  briscola_card=briscola_card,
                  current_agent=player,
                  agent_points={"player_0": 0, "player_1": 0},
                  num_moves=len(thrown_cards)
                  )

    env = TwoPlayerBriscola()
    env.set_state(state)
    return env


class SearchingAgent(nn.Module, Agent):
    def __init__(self, policy: nn.Module, action_size: int, name: str = "Searching-Agent"):
        super().__init__()
        self.actor = policy
        self.action_size = action_size
        self.name = name

    def get_name(self) -> str:
        return self.name

    def get_probs(self, observations, action_masks):
        observations = self.obs_transform(observations)
        logits = self.actor(observations)
        if action_masks is not None:
            logits[~action_masks.bool()] = -1e8
        probs = Categorical(logits=logits)
        return probs

    def get_actions(self, observations: tensor, action_masks: tensor = None):
        actions = torch.empty(observations.shape[0], dtype=torch.int64)
        for i, observation in enumerate(observations):
            if observation[:Constants.deck_cards].sum() >= 34:
                actions[i], _ = find_best_move(get_env(observation))
            else:
                actions[i] = self.get_probs(observations[i].view(1, -1), action_masks[i].view(1, -1)).sample()
        return actions

    def forward(self, inputs: tensor):
        observation, action_mask = inputs[:, :-self.action_size], inputs[:, -self.action_size:]
        return self.get_actions(observation, action_mask)


In [134]:
searching_agent = SearchingAgent(player_policy.actor, Constants.deck_cards)
searching_agent.obs_transform = lambda x: x

In [135]:
outcomes.extend([compute_rating(searching_agent, trained_previous, 10) for _ in range(1)])

{'player_1': [], 'player_0': []} player_1
{'player_1': [], 'player_0': []} player_1
{'player_1': [], 'player_0': []} player_1
{'player_1': [], 'player_0': []} player_1
{'player_1': [], 'player_0': []} player_1
{'player_1': [], 'player_0': []} player_1
{'player_1': [], 'player_0': [28]} player_1
{'player_1': [], 'player_0': [22]} player_1
{'player_1': [], 'player_0': [28]} player_1
{'player_1': [], 'player_0': [3]} player_1
{'player_1': [], 'player_0': [22]} player_1
{'player_1': [], 'player_0': [3]} player_1
{'player_1': [], 'player_0': []} player_1
{'player_1': [], 'player_0': []} player_1
{'player_1': [], 'player_0': []} player_1
{'player_1': [], 'player_0': []} player_1
{'player_1': [], 'player_0': []} player_1
{'player_1': [], 'player_0': []} player_1
{'player_1': [], 'player_0': [28]} player_1
{'player_1': [], 'player_0': [22]} player_1
{'player_1': [], 'player_0': [28]} player_1
{'player_1': [], 'player_0': [3]} player_1
{'player_1': [], 'player_0': [22]} player_1
{'player_1': []

In [80]:
np.mean([outcome[0]-0.5 for outcome in outcomes]), np.std([outcome[0] for outcome in outcomes]) / np.sqrt(len(outcomes))

(-0.03217391304347826, 0.01028599819443508)

In [23]:
from src.utils.onnx_utils import export_to_onnx
export_to_onnx(trained_previous, "briscola_penalization.onnx")

verbose: False, log level: Level.ERROR



In [24]:
torch.save(player_policy.state_dict(), 'briscola_penalization.pt')
# Save as artifact for version control.
artifact = wandb.Artifact('model', type='model')
artifact.add_file('briscola_penalization.pt')
run.log_artifact(artifact)
wandb.run.finish()

0,1
SPS,█▇███████▇▆▅▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃
approx_kl,█▄▄▃▃▅▃▃▂▆▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
clipfrac,██▇▇▇▆▆▅▆▆▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
entropy,█▇█▄▄▇▆███▆▇▄▅▄▇▃▅▄▇▄▃▆▅▆▆▆▃▃▄▁▆▃▅▆▁▄▆▃▅
entropy_coef,██▇▇▇▆▆▅▅▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
explained_variance,▇▆▆▅▁▆▄▆▆▇▆▄▆█▆▆█▆▅▅▇▃▆▆▆▆▅▆▅▇█▇▅▆▃▇▆▆▆▅
global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
learning_rate,██▇▆▆▆▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_outcome,▁▆▄▁█▄▂▁▅▂▄▂▄▃▂▆▅▆█▆█▄▄▄▇▄▁▄▆▇▃▄▅▅▆▇▆▄▄▄
old_approx_kl,█▇▆▆▆▄▅▄▅▅▂▃▆▄▃▃▂▂▁▃▂▃▂▂▂▁▂▂▂▂▁▃▁▂▁▂▂▂▁▂

0,1
SPS,3005.0
approx_kl,0.00051
clipfrac,0.00409
entropy,0.08023
entropy_coef,0.00011
explained_variance,0.51164
global_step,45096960.0
learning_rate,0.0001
mean_outcome,0.50928
old_approx_kl,0.00119
