In [1]:
import random
from copy import deepcopy
from typing import Tuple

import pylab as pl

import numpy as np
import torch
from scipy import stats
from sklearn.metrics import explained_variance_score
from torch import optim, Tensor, nn

from src.agents import Agent
from src.agents.NNAgent import NNAgent
from src.agents.RandomAgent import RandomAgent
from src.envs.two_player_briscola.TwoPlayerBriscola import TwoPlayerBriscola as Briscola
from time import time
import gymnasium as gym
import wandb
from src.utils.AgentPool import AgentPool
from src.utils.training_utils import play_all_moves_of_player, get_state_representation

from src.vectorizers.VectorizedEnv import VectorizedEnv
from src.envs.two_player_briscola.TwoPlayerBriscola import TwoPlayerBriscola

from src.envs.two_player_briscola.BriscolaConstants import Constants

from src.utils.training_utils import play_all_moves_of_players, compute_rating

In [2]:
# Hyperparameters for the PPO self-play run. The whole dict is later handed to
# wandb.init(config=...) so every value here is recorded with the run.
params = {
    # Seed for the Python, NumPy and torch global RNGs (set at the end of this cell).
    "seed": 0,

    # Rollout size: one rollout is n_steps moves of the learning player in n_envs parallel games.
    "n_envs" : 2048,
    "n_steps": Constants.deck_cards // 2,

    # Optimiser schedule: lr is multiplied by lr_decay once per update and floored at lr_min.
    "lr": 3e-3,
    "lr_decay": 0.995,
    "lr_min": 3e-4,
    "mini_batch_size": 1024,
    "total_timesteps": 10_000_000,

    # PPO / GAE coefficients.
    "gamma": 1.,
    "lambda": 0.9,
    "update_epochs": 2,
    "clip_coef": 0.3,
    "normalize_advantage": True,
    "clip_value_loss": True,
    "value_coef": 0.5,
    "entropy_coef": 1e-2,
    "max_grad_norm": 0.5,

    # Weight of the end-of-game win bonus in the reward; the environment reward gets 1 - ratio_win_reward.
    "ratio_win_reward": 0.1,

    # Opponents per rollout: n_opponents in total, of which self_play_opponents are the current policy
    # and the remainder are sampled from the agent pool.
    "n_opponents": 4,
    "self_play_opponents": 2,
    "max_pool_size": 128,
    "add_model_every_x_step": 1,
    "nu": 0.1,  # forwarded to AgentPool(nu=...)

    # Network architecture.
    "hidden_size": 256,
    "activation": nn.Mish
}
params["batch_size"] = params["n_envs"] * params["n_steps"]
params["num_updates"] = params["total_timesteps"] // params["batch_size"]

# Seed every global RNG the training cell draws from (minibatch shuffling uses np.random,
# action sampling uses torch) so that a Restart & Run All is repeatable.
# NOTE(review): the Briscola environments may hold their own RNG that needs separate seeding — confirm.
random.seed(params["seed"])
np.random.seed(params["seed"])
torch.manual_seed(params["seed"])

In [3]:
# Open the Weights & Biases run that the training cell logs to.
# The hyperparameter dict is attached as the run config.
wandb_settings = dict(
    name="More self-play",
    project="briscolaBot",
    entity="lettera",
    config=params,
    save_code=True,
    sync_tensorboard=False,
    mode="online",
)
run = wandb.init(**wandb_settings)

[34m[1mwandb[0m: Currently logged in as: [33mlettera[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Build params["n_envs"] independent two-player Briscola games; the class itself
# serves as the zero-argument factory for each environment.
vec_env = VectorizedEnv(TwoPlayerBriscola, params["n_envs"])

In [5]:
# PPO self-play training loop.
#
# Each update: snapshot the current policy into the pool, sample opponents, play one
# full rollout in every environment, update the pool ratings, compute GAE advantages,
# run the clipped PPO optimisation and log metrics to wandb.
device = "cpu"
observation_shape = vec_env.single_observation_space()["observation"].shape
action_size = vec_env.single_action_space().n

# Policy being trained, plus the names of the two seats in every game.
player_policy = NNAgent(observation_shape, action_size, hidden_size=params["hidden_size"], activation=params["activation"]).to(device)
player_name = vec_env[0].agents[0]
opponent_name = vec_env[0].agents[1]

# Fixed benchmark opponent loaded from disk. Its obs_transform drops the first 82
# observation columns before the forward pass.
trained_vs_random = NNAgent(162, action_size, obs_transform=lambda x: x[:, 82:]).to(device)
# map_location keeps the checkpoint loadable when `device` differs from the one it was saved on.
trained_vs_random.load_state_dict(torch.load("train-vs-random.pt", map_location=device))

pool = AgentPool(params["max_pool_size"], nu=params["nu"])

optimizer = optim.Adam(player_policy.parameters(), lr=params["lr"], eps=1e-5)

# Rollout buffers, indexed [step, env].
obs = torch.zeros((params["n_steps"], params["n_envs"]) + observation_shape).to(device)
actions = torch.zeros((params["n_steps"], params["n_envs"]), dtype=torch.int64).to(device)
actions_masks = torch.zeros((params["n_steps"], params["n_envs"]) + (action_size,), dtype=torch.int64).to(device)
logprobs = torch.zeros((params["n_steps"], params["n_envs"])).to(device)
rewards = torch.zeros((params["n_steps"], params["n_envs"])).to(device)
dones = torch.zeros((params["n_steps"], params["n_envs"]), dtype=torch.int8).to(device)
values = torch.zeros((params["n_steps"], params["n_envs"])).to(device)

global_step = 0
start_time = time()
for update in range(1, params["num_updates"] + 1):
    # Decay lr
    current_lr = optimizer.param_groups[0]["lr"]
    optimizer.param_groups[0]["lr"] = max(params["lr_min"], current_lr*params["lr_decay"])

    # Add agent to pool
    if update % params["add_model_every_x_step"] == 0:
        pool.add_agent(deepcopy(player_policy))

    # Sample agents: pool opponents first, then the live policy for the self-play slots.
    opponent_policies, opponent_indexes = pool.sample_agents(params["n_opponents"] - params["self_play_opponents"])
    opponent_policies += [player_policy] * params["self_play_opponents"]

    # Play episodes
    vec_env.reset()
    play_all_moves_of_players(vec_env, opponent_policies, opponent_name)
    next_obs, action_mask, reward, next_done = get_state_representation(vec_env)
    for step in range(params["n_steps"]):
        global_step += params["n_envs"]

        # dones[step] is the done flag of the state the action at `step` is taken from.
        obs[step] = next_obs
        dones[step] = next_done

        with torch.no_grad():
            action, logprob, _, value = player_policy.get_action_and_value(next_obs.to(device), action_mask.to(device))
            values[step] = value.flatten()
        actions[step] = action
        actions_masks[step] = action_mask.to(device)
        logprobs[step] = logprob

        vec_env.step(actions[step].cpu().numpy())
        play_all_moves_of_players(vec_env, opponent_policies, opponent_name)
        next_obs, action_mask, reward, next_done = get_state_representation(vec_env)
        wins = torch.tensor([env.game_winner() == player_name for env in vec_env], dtype=torch.float32)
        # Blend the environment reward with a win bonus paid when the game is done.
        # The whole expression is built from environment-side tensors and moved to
        # `device` once, so no term is left behind when device != "cpu".
        rewards[step] = ((1 - params["ratio_win_reward"]) * reward + params["ratio_win_reward"] * next_done * wins).to(device)

    # Update rating
    scores = [env.get_game_outcome(opponent_name) for env in vec_env.get_envs()]
    # The environments are shared between ALL entries of opponent_policies (pool
    # opponents followed by the self-play copies), so the slice width must be derived
    # from the full list; only the leading slices belong to pool opponents.
    # Assumes play_all_moves_of_players assigns environments to policies in contiguous,
    # equal-sized chunks in list order — TODO confirm against training_utils.
    n_policies = len(opponent_policies)
    mean_score_per_opponent = np.empty_like(opponent_indexes, dtype=np.float64)
    for i in range(opponent_indexes.size):
        start, end = (i * len(scores)) // n_policies, ((i + 1) * len(scores)) // n_policies
        mean_score_per_opponent[i] = np.mean(scores[start:end])

    agent_rating = pool.update_ratings(0., mean_score_per_opponent, opponent_indexes)

    # Bootstrap value
    with torch.no_grad():
        next_value = player_policy.get_value(next_obs.to(device)).reshape(1, -1)
        advantages = torch.zeros_like(rewards).to(device)
        last_gae_lambda = 0

        # Generalised Advantage Estimation, accumulated backwards through the rollout.
        for t in reversed(range(params["n_steps"])):
            if t == params["n_steps"] - 1:
                next_non_terminal = (1. - next_done).to(device)
                next_values = next_value
            else:
                next_non_terminal = 1. - dones[t+1]
                next_values = values[t+1]

            delta = rewards[t] + params["gamma"] * next_values * next_non_terminal - values[t]
            last_gae_lambda = delta + params["gamma"] * params["lambda"] * next_non_terminal * last_gae_lambda
            advantages[t] = last_gae_lambda

        returns = advantages + values

    # Optimize net: flatten the [step, env] buffers into one batch.
    b_obs = obs.reshape((-1,) + observation_shape)
    b_logprobs = logprobs.reshape(-1)
    b_actions = actions.reshape(-1)
    b_action_masks = actions_masks.reshape((-1, action_size))
    b_advantages = advantages.reshape(-1)
    b_returns = returns.reshape(-1)
    b_values = values.reshape(-1)

    clip_fraction = []
    b_indexes = np.arange(params["batch_size"])
    for epoch in range(params["update_epochs"]):
        np.random.shuffle(b_indexes)
        for start in range(0, params["batch_size"], params["mini_batch_size"]):
            end = start + params["mini_batch_size"]
            mb_indexes = b_indexes[start:end]

            _, newlogprob, entropy, newvalue = player_policy.get_action_and_value(b_obs[mb_indexes], b_action_masks[mb_indexes], b_actions[mb_indexes])
            logratio = newlogprob - b_logprobs[mb_indexes]
            ratio = logratio.exp()

            with torch.no_grad():
                # calculate approx_kl https://joschu.net/blog/kl-approx.html
                old_approx_kl = (-logratio).mean()
                approx_kl = ((ratio - 1) - logratio).mean()
                clip_fraction.append(((ratio - 1.0).abs() > params["clip_coef"]).float().mean().item())

            mb_advantages = b_advantages[mb_indexes]
            if params["normalize_advantage"]:
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

            # Policy loss (clipped surrogate objective)
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * torch.clamp(ratio, 1-params["clip_coef"], 1+params["clip_coef"])
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

            # Value loss
            newvalue = newvalue.view(-1)
            if params["clip_value_loss"]:
                v_loss_unclipped = (newvalue - b_returns[mb_indexes]) ** 2
                v_clipped = b_values[mb_indexes] + torch.clamp(
                    newvalue - b_values[mb_indexes],
                    -params["clip_coef"],
                    params["clip_coef"],
                )
                v_loss_clipped = (v_clipped - b_returns[mb_indexes]) ** 2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                value_loss = 0.5 * v_loss_max.mean()
            else:
                value_loss = 0.5 * ((newvalue - b_returns[mb_indexes]) ** 2).mean()

            entropy_loss = entropy.mean()
            loss = pg_loss - params["entropy_coef"] * entropy_loss + params["value_coef"] * value_loss

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(player_policy.parameters(), params["max_grad_norm"])
            optimizer.step()

    # Logging
    explained_var = explained_variance_score(b_returns.cpu().numpy(), b_values.cpu().numpy())
    # The benchmark evaluations run on updates 1, 9, 17, ...; in between, the most
    # recent results are logged again unchanged.
    if (update - 1) % 8 == 0:
        outcome_vs_random, rating_vs_random = compute_rating(player_policy, RandomAgent(action_size))
        # NOTE(review): presumably get_agent(-50) is the snapshot added 50 insertions ago — confirm against AgentPool.
        outcome_vs_past_iterations, rating_vs_past_iterations = compute_rating(player_policy, pool.get_agent(-50))
        _, rating_vs_trained_random = compute_rating(player_policy, trained_vs_random, 1024)

    # The loss and KL entries below come from the last minibatch of the last epoch.
    wandb.log({
        "global_step": global_step,
        "learning_rate": optimizer.param_groups[0]["lr"],

        "value_loss": value_loss.item(),
        "policy_loss": pg_loss.item(),
        "entropy": entropy_loss.item(),
        "total_loss": loss.item(),

        "old_approx_kl": old_approx_kl.item(),
        "approx_kl": approx_kl.item(),
        "clipfrac": np.mean(clip_fraction),
        "explained_variance": explained_var,
        "SPS": int(global_step / (time() - start_time)),

        # .item() so a plain Python float is logged, like the other scalar metrics.
        "reward_per_game": torch.sum(rewards, dim=0).mean().item(),
        "points_per_game": sum([env.game_state.agent_points[player_name] for env in vec_env]) / params["n_envs"],
        "mean_outcome": sum([env.get_game_outcome(player_name) for env in vec_env]) / params["n_envs"],

        "outcome_vs_random": outcome_vs_random,
        "rating_vs_random": rating_vs_random,

        "outcome_vs_50_past_iterations": outcome_vs_past_iterations,
        "rating_vs_50_past_iterations": rating_vs_past_iterations,

        "rating_vs_trained_random": rating_vs_trained_random,

        "pool_ratings": wandb.Histogram(pool.ratings),
        "pool_std": np.std(pool.ratings),
        "pool_size": len(pool)
    })

  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal action, executing {action} instead")
  warn(f"Tried to execute an illegal a

In [None]:
# Rebuild the network with exactly the architecture of `player_policy` (same observation
# shape, action count, hidden size and activation), so that the weights saved from it
# fit and the exported model uses the same activation as the trained one.
# Requires the config and training cells to have run, and 'agent.pt' to have been
# written by the cell that saves `player_policy`.
agent = NNAgent(observation_shape, action_size, hidden_size=params["hidden_size"], activation=params["activation"])
agent.load_state_dict(torch.load("agent.pt"))

# NOTE(review): `export_to_onnx` is not defined or imported in any visible cell —
# add the import to the import cell, otherwise this raises NameError on a fresh kernel.
export_to_onnx(agent)

In [8]:
# Write the trained policy weights to a local checkpoint file.
weights_file = 'agent.pt'
torch.save(player_policy.state_dict(), weights_file)

# Attach the checkpoint to the wandb run as a versioned model artifact,
# then close the run.
artifact = wandb.Artifact('model', type='model')
artifact.add_file(weights_file)
run.log_artifact(artifact)
wandb.run.finish()

0,1
SPS,▁██▇▅▆▅▅▆▆▆▆▆▆▆▆▆▆▆▇▇▆▆▅▅▅▄▅▅▅▅▅▅▅▅▅▅▅▅▅
approx_kl,▂▄▅▄▅▅█▅▆▅▄▆█▅▆▃▅▄▅▅▄▃▃▃▄▃▂▇▃▂▂▄▃▃▃▂▁▃▂▁
clipfrac,▄██▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▃▃▂▃▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁
entropy,█▇▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
explained_variance,▁▇▇█████████████████████████████████████
global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
learning_rate,███▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁
mean_outcome,▁▃▄▃██▂▂▂▂▄▇▃▁▃▂▂▂▃▂▁▃▅▄▃▃▃▂▃▂▃▄▂▃▂▂▂▂▃▁
old_approx_kl,▃▃▆▅▄▃▆▇▅▆▅█▅▇▆▃▅▄▄▅▃▄▂▃▅▅▄█▂▁▃▁▃▃▄▃▁▃▁▃
outcome_vs_50_past_iterations,▂▂▆▆▇▇▇▇▇███▅▅▃▃▃▃▃▃▃▂▂▂▃▃▂▂▁▂▂▂▂▂▁▁▃▃▂▂

0,1
SPS,3030.0
approx_kl,0.02259
clipfrac,0.05898
entropy,0.21368
explained_variance,0.80936
global_step,7618560.0
learning_rate,0.00118
mean_outcome,0.5022
old_approx_kl,0.03229
outcome_vs_50_past_iterations,0.52441
