In [None]:
# Environment setup (IPython shell escapes): install the JAX / Flax / pgx / mctx stack.
!pip install --upgrade pip
!pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install "tensorflow[and-cuda]"
!pip install "flax[all]"
!pip install gdown
!pip install pgx
!pip install tqdm
!pip install pydantic
# GitHub forks installed after the PyPI packages above.
# NOTE(review): presumably the pgx fork is what provides `pgx.bughouse` (imported
# later) and overrides the PyPI pgx; confirm.
!pip install git+https://github.com/aminwoo/pgx.git
!pip install git+https://github.com/lowrollr/mctx-az.git

# Download a shared Google Drive folder into the working directory.
# NOTE(review): presumably this is the checkpoint the CheckpointManager restores
# further down; confirm that the downloaded folder lands where it is read from.
!gdown https://drive.google.com/drive/folders/13FcUDoZC5bvKjel_5qlLEiWuhKD4V5zn?usp=sharing --folder

import time
from dataclasses import dataclass

import flax.linen as nn
import jax
import jax.numpy as jnp


def mish(x):
    """Mish activation, elementwise: ``x * tanh(softplus(x))``."""
    gate = jnp.tanh(jax.nn.softplus(x))
    return gate * x


@dataclass
class AZResnetConfig:
    """Hyper-parameters consumed by ``AZResnet``."""

    num_blocks: int  # how many ResidualBlock layers make up the trunk
    channels: int  # feature width of the stem, the trunk and the first policy conv
    policy_channels: int  # output features of the second policy-head conv
    value_channels: int  # output features of the 1x1 value-head conv
    num_policy_labels: int  # length of the policy logits vector


class ResidualBlock(nn.Module):
    """Residual block: conv-BN-mish, conv-BN, optional SE gate, skip add, mish.

    Attributes:
        channels: Output features of both 3x3 convolutions. The skip connection
            adds the block input to the branch output, so the input is expected
            to carry ``channels`` features on its last axis.
        se: Whether to gate the branch output with squeeze-and-excitation.
        se_ratio: Bottleneck reduction of the SE MLP
            (``channels // se_ratio`` hidden units).
    """

    channels: int
    se: bool
    se_ratio: int = 4

    @nn.compact
    def __call__(self, x, train: bool):
        """Apply the block.

        Args:
            x: Input activations. The SE pooling averages axes 1 and 2, so this
                presumably is (batch, height, width, channels); confirm.
            train: Forwarded to each BatchNorm as ``use_running_average=not train``.

        Returns:
            ``mish(x + branch(x))``.
        """
        y = nn.Conv(
            features=self.channels, kernel_size=(3, 3), padding=(1, 1), use_bias=False
        )(x)
        y = nn.BatchNorm(use_running_average=not train)(y)
        y = mish(y)
        # The second conv consumes the first stage's activations. It used to be
        # applied to `x`, which silently discarded the conv/BN/mish stage above.
        # Parameter names/shapes are unchanged (given a `channels`-wide input),
        # but checkpoints trained with the old wiring will produce different
        # outputs with this block.
        y = nn.Conv(
            features=self.channels, kernel_size=(3, 3), padding=(1, 1), use_bias=False
        )(y)
        y = nn.BatchNorm(use_running_average=not train)(y)

        if self.se:
            # Squeeze: global average over the spatial axes (1, 2).
            squeeze = jnp.mean(y, axis=(1, 2), keepdims=True)

            # Excitation: bottleneck MLP producing a per-channel hard-sigmoid gate.
            excitation = nn.Dense(
                features=self.channels // self.se_ratio, use_bias=True
            )(squeeze)
            excitation = nn.relu(excitation)
            excitation = nn.Dense(features=self.channels, use_bias=True)(excitation)
            excitation = nn.hard_sigmoid(excitation)

            y = y * excitation

        return mish(x + y)


class AZResnet(nn.Module):
    """AlphaZero-style residual network with a policy head and a value head.

    Submodules are created inline under ``@nn.compact``, so Flax names them by
    creation order. Reordering the layers below would change the parameter tree
    that restored checkpoints must match.
    """

    # Network hyper-parameters (widths, depth, number of policy labels).
    config: AZResnetConfig

    @nn.compact
    def __call__(self, x, train: bool):
        """Run the network on a batch.

        Args:
            x: Batched input planes; axis 0 is the batch axis. Presumably
                (batch, height, width, features) in Flax's channels-last
                layout; confirm against the environment's observation.
            train: Forwarded to every BatchNorm as ``use_running_average=not train``.

        Returns:
            ``(policy, value)``: ``policy`` holds unnormalized logits of shape
            (batch, config.num_policy_labels); ``value`` has shape (batch,) and
            lies in [-1, 1] because of the final tanh.
        """
        batch_size = x.shape[0]

        # Stem: 3x3 conv -> BatchNorm -> mish.
        x = nn.Conv(
            features=self.config.channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            use_bias=False,
        )(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = mish(x)

        # Trunk: num_blocks residual blocks, all with squeeze-and-excitation.
        for _ in range(self.config.num_blocks):
            x = ResidualBlock(channels=self.config.channels, se=True)(x, train=train)

        # policy head
        # Two 3x3 conv/BN/mish stages, flatten, then a dense layer to logits.
        policy = nn.Conv(
            features=self.config.channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            use_bias=False,
        )(x)
        policy = nn.BatchNorm(use_running_average=not train)(policy)
        policy = mish(policy)
        policy = nn.Conv(
            features=self.config.policy_channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            use_bias=False,
        )(policy)
        policy = nn.BatchNorm(use_running_average=not train)(policy)
        policy = mish(policy)
        policy = policy.reshape((batch_size, -1))
        policy = nn.Dense(features=self.config.num_policy_labels)(policy)

        # value head
        # 1x1 conv/BN/mish, flatten, 256-unit hidden layer, then a tanh scalar.
        value = nn.Conv(
            features=self.config.value_channels, kernel_size=(1, 1), use_bias=False
        )(x)
        value = nn.BatchNorm(use_running_average=not train)(value)
        value = mish(value)
        value = value.reshape((batch_size, -1))
        value = nn.Dense(features=256)(value)
        value = mish(value)
        value = nn.Dense(features=1)(value)
        value = nn.tanh(value)
        # (batch, 1) -> (batch,)
        value = value.squeeze(axis=1)

        return policy, value
    

import os

from typing import Any
from tqdm.auto import tqdm

import numpy as np
import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint as ocp
import chex
    
import datetime
import os
import pickle
import random 
import time
import requests
from threading import Thread
from functools import partial
from typing import Optional, List

import chex
import jax
import jax.numpy as jnp
import numpy as np 
import mctx
import optax
import pgx
from pgx.bughouse import _time_advantage
from flax.training import train_state
from pydantic import BaseModel
from tqdm import tqdm

# Network definition: 15 SE-residual blocks of width 256; the policy head emits
# 2 * 64 * 78 + 1 = 9985 logits.
model_configs = AZResnetConfig(
    num_blocks=15,
    channels=256,
    policy_channels=4,
    value_channels=8,
    num_policy_labels=2 * 64 * 78 + 1,
)
net = AZResnet(model_configs)

# Restore step 0 from the checkpoint directory /kaggle/working.
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
    "/kaggle/working",
    options=options,
    item_handlers=ocp.PyTreeCheckpointHandler(),
)
ckpt = mngr.restore(0)

# Variable collections in the layout `net.apply` expects.
params = {collection: ckpt[collection] for collection in ("params", "batch_stats")}

devices = jax.local_devices()
num_devices = len(devices)
print("Number of devices:", num_devices)

class Config(BaseModel):
    """Self-play run settings; unknown fields are rejected (``extra = "forbid"``)."""

    env_id: pgx.EnvId = "bughouse"
    # Drawn once, when this class body runs, so every Config() created in the
    # same process shares this default seed.
    seed: int = random.randint(0, 999999999)
    max_num_iters: int = 1000

    # --- self-play parameters ---
    selfplay_batch_size: int = 1
    num_simulations: int = 800
    max_num_steps: int = 512

    class Config:
        extra = "forbid"


config: Config = Config()

env = pgx.make(config.env_id)
# Batched (vmapped) and jit-compiled entry points. init_fn takes a batch of PRNG
# keys; step_fn takes (states, actions, keys), all batched on axis 0.
init_fn, step_fn = (jax.jit(jax.vmap(fn)) for fn in (env.init, env.step))


def recurrent_fn(params, rng_key: jnp.ndarray, action: jnp.ndarray, state: pgx.State):
    """MCTS transition: apply ``action`` to ``state`` and evaluate the children.

    Args:
        params: Flax variables passed to ``net.apply``.
        rng_key: Split into ``config.selfplay_batch_size`` keys for the env step.
        action: Batch of actions to apply.
        state: Batch of environment states (the search-tree embedding).

    Returns:
        ``(mctx.RecurrentFnOutput, next_state)``. The reward is the one received
        by the player who was to move in ``state``. Value and discount are zeroed
        on terminal children; otherwise the discount is -1.
        NOTE(review): -1 assumes the side to move changes on every step; confirm
        this holds for bughouse.
    """
    mover = state.current_player
    child = step_fn(state, action, jax.random.split(rng_key, config.selfplay_batch_size))

    prior_logits, child_value = net.apply(params, child.observation, train=False)

    # Shift so the max logit is 0, then give illegal actions the most negative
    # finite value so their prior is effectively zero.
    prior_logits = prior_logits - prior_logits.max(axis=-1, keepdims=True)
    prior_logits = jnp.where(
        child.legal_action_mask, prior_logits, jnp.finfo(prior_logits.dtype).min
    )

    done = child.terminated
    batch_index = jnp.arange(child.rewards.shape[0])
    output = mctx.RecurrentFnOutput(
        reward=child.rewards[batch_index, mover],
        discount=jnp.where(done, 0.0, -1.0 * jnp.ones_like(child_value)),
        prior_logits=prior_logits,
        value=jnp.where(done, 0.0, child_value),
    )
    return output, child


@partial(jax.jit, static_argnums=(2, ))
def run_mcts(state, key, num_simulations: int, params, tree: Optional[mctx.Tree] = None):
    """Run AlphaZero MCTS from a batch of root states.

    Args:
        state: Batch of pgx states used as the root embeddings.
        key: PRNG key for the search.
        num_simulations: Number of simulations; static under ``jax.jit``.
        params: Flax variables passed to ``net.apply`` and to ``recurrent_fn``.
        tree: Forwarded to ``mctx.alphazero_policy`` as ``search_tree``. It was
            previously ignored (``None`` was always passed); callers in this
            file pass ``None``.

    Returns:
        The output of ``mctx.alphazero_policy`` (used for ``.action`` and
        ``.action_weights``).
    """
    # Only the first half of the split is used. Keeping the split preserves the
    # exact search key derived from `key`.
    search_key, _ = jax.random.split(key)

    logits, value = net.apply(params, state.observation, train=False)
    root = mctx.RootFnOutput(prior_logits=logits, value=value, embedding=state)

    return mctx.alphazero_policy(
        params=params,
        rng_key=search_key,
        root=root,
        recurrent_fn=recurrent_fn,
        num_simulations=num_simulations,
        invalid_actions=~state.legal_action_mask,
        search_tree=tree,
        # Values are tanh outputs, so normalise Q with fixed bounds [-1, 1].
        qtransform=partial(mctx.qtransform_by_min_max, min_value=-1, max_value=1),
    )


@jax.pmap
def selfplay(params, rng_key: jnp.ndarray):
    """Play self-play on each device and record per-position training data.

    Runs under ``jax.pmap``, so ``params`` and ``rng_key`` carry a leading
    device axis.

    Args:
        params: Replicated Flax variables for ``net.apply``.
        rng_key: PRNG key for this device.

    Returns:
        ``(obs, policy_tgt, num_samples, result)``:
        ``obs`` and ``policy_tgt`` are ``(config.max_num_steps, width)`` buffers.
        Their first ``num_samples`` rows hold the flattened observation and the
        MCTS action weights of every searched position.
        ``result`` is ``abs(state.rewards[0][0])`` of the final state; only batch
        entry 0 is inspected.
        NOTE(review): presumably 0 when the game is cut off by the buffer limit,
        if rewards are only non-zero on terminal steps; confirm.
    """
    rng_key, init_key = jax.random.split(rng_key)
    state = init_fn(jax.random.split(init_key, config.selfplay_batch_size))

    # Buffer widths come from the (static) batched shapes instead of the
    # hard-coded 4096 / 9985, which matched only one observation/action layout.
    obs = jnp.zeros((config.max_num_steps, state.observation.size))
    policy_tgt = jnp.zeros((config.max_num_steps, state.legal_action_mask.size))

    def cond_fun(carry):
        state, _, _, _, i = carry
        # Stop once every game has terminated OR the buffers are full. Without
        # the bound, a long game kept writing past the end (JAX drops
        # out-of-bounds updates) and reported more samples than were stored.
        return jnp.logical_and(~state.terminated.all(), i < config.max_num_steps)

    def body_fun(carry):
        state, rng_key, obs, policy_tgt, i = carry
        # Independent keys for the search and for the env step; a single key
        # used to be reused for both.
        rng_key, search_key, step_key = jax.random.split(rng_key, 3)
        policy_output = run_mcts(state, search_key, config.num_simulations, params, None)

        obs = obs.at[i].set(state.observation.ravel())
        policy_tgt = policy_tgt.at[i].set(policy_output.action_weights.ravel())

        state = step_fn(
            state,
            policy_output.action,
            jax.random.split(step_key, config.selfplay_batch_size),
        )
        return state, rng_key, obs, policy_tgt, i + 1

    state, rng_key, obs, policy_tgt, num_samples = jax.lax.while_loop(
        cond_fun, body_fun, (state, rng_key, obs, policy_tgt, 0)
    )
    return obs, policy_tgt, num_samples, abs(state.rewards[0][0])
            
# Endpoint that receives each self-play game file.
_UPLOAD_URL = "http://ec2-3-84-181-213.compute-1.amazonaws.com:8000/upload"


def _value_targets(result, num_moves, num_kept):
    """Value targets for the first ``num_kept`` of ``num_moves`` recorded positions.

    Position ``num_moves - 1`` (the last one) gets ``result``, and the sign
    alternates going backwards. This matches the previous build-then-reverse
    list construction, but keeps only the rows that were actually stored.
    """
    ply = np.arange(num_kept)
    flipped = (num_moves - 1 - ply) % 2 == 1
    return np.where(flipped, -result, result)


def _upload(path, url=_UPLOAD_URL):
    """POST ``path`` as multipart field ``file`` and print the outcome."""
    try:
        # `with` closes the handle; it used to be leaked on every iteration.
        with open(path, "rb") as fh:
            response = requests.post(url=url, files={"file": fh}, timeout=60)
    except requests.RequestException as err:
        # A transient network failure should not abort the remaining self-play.
        print(f"Error: {err}")
        return
    if response.status_code == 200:
        print("Game sent successfully!")
    else:
        print(f"Error: {response.status_code}")


if __name__ == "__main__":
    print("Running selfplay with initial seed", config.seed)

    params = jax.device_put_replicated(params, devices)
    rng_key = jax.random.PRNGKey(config.seed)

    for _ in tqdm(range(config.max_num_iters)):
        rng_key, subkey = jax.random.split(rng_key)
        # Per-device keys come from `subkey`. Splitting `rng_key` here reused
        # the key that seeds the next iteration.
        keys = jax.random.split(subkey, num_devices)

        obs, policy_tgt, num_samples, result = selfplay(params, keys)
        obs = np.asarray(obs)
        policy_tgt = np.asarray(policy_tgt)
        num_samples = np.asarray(num_samples)
        result = np.array(result)

        # One game per device; this was previously hard-coded to exactly 2
        # devices. Clip to the buffer length so targets line up with stored rows.
        kept = [min(int(n), obs.shape[1]) for n in num_samples]
        obs = np.concatenate([obs[d, :k] for d, k in enumerate(kept)])
        policy_tgt = np.concatenate([policy_tgt[d, :k] for d, k in enumerate(kept)])
        value_tgt = np.concatenate(
            [
                _value_targets(result[d], int(n), k)
                for d, (n, k) in enumerate(zip(num_samples, kept))
            ]
        )

        now = datetime.datetime.now(datetime.timezone.utc)
        # Seconds are included so games finishing in the same minute do not
        # overwrite each other's file.
        filepath = f"/kaggle/working/training-run1-{now.strftime('%Y%m%d-%H%M%S')}"
        np.savez_compressed(filepath, obs=obs, policy_tgt=policy_tgt, value_tgt=value_tgt)

        _upload(filepath + ".npz")



     