In [None]:
# Install the notebook's runtime dependencies (CUDA-enabled JAX/TensorFlow extras, Flax, and helper libraries).
!pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install "tensorflow[and-cuda]"
!pip install "flax[all]"
# gdown is used by the next cell to download a Google Drive folder.
!pip install gdown
!pip install pgx
# NOTE(review): a pgx fork is installed right after the PyPI release; presumably the fork
# (which the later 'bughouse' env id seems to rely on) is meant to win — confirm that pip
# really replaces the first install.
!pip install git+https://github.com/aminwoo/pgx.git
!pip install mctx
!pip install tqdm
!pip install pydantic

In [None]:
# Download the shared Google Drive folder into /kaggle/working/, which is also the default
# checkpoint directory of TrainerModule below (presumably the folder holds the checkpoint
# restored later — confirm).
!gdown https://drive.google.com/drive/folders/17-BtU1koT2nulH6NzsQum9fTQCMGuN2g?usp=sharing -O /kaggle/working/ --folder

In [None]:
import time
from dataclasses import dataclass

import flax.linen as nn
import jax
import jax.numpy as jnp


def mish(x):
    """Mish activation, ``x * tanh(softplus(x))``, applied element-wise."""
    gate = jnp.tanh(jax.nn.softplus(x))
    return x * gate


@dataclass
class AZResnetConfig:
    """Hyperparameters read by ``AZResnet`` when it builds the network."""

    # How many residual blocks are stacked after the stem convolution.
    num_blocks: int
    # Feature width of the stem, the residual blocks and the first policy-head conv.
    channels: int
    # Output features of the second (last) policy-head convolution.
    policy_channels: int
    # Output features of the value-head 1x1 convolution.
    value_channels: int
    # Length of the policy logits vector produced by the final policy Dense layer.
    num_policy_labels: int

class ResidualBlock(nn.Module):
    """Residual block: conv-BN-mish-conv-BN, optional squeeze-and-excitation, skip add.

    Attributes:
        channels: Output features of both 3x3 convolutions. The skip connection adds the
            block input to the result, so the input is expected to already have this
            many features.
        se: Whether to rescale the residual branch with a squeeze-and-excitation gate.
        se_ratio: Reduction factor of the gate's bottleneck Dense layer
            (``channels // se_ratio`` features).
    """

    channels: int
    se: bool
    se_ratio: int = 4

    @nn.compact
    def __call__(self, x, train: bool):
        """Apply the block.

        Args:
            x: Input feature map (channels-last, as expected by ``nn.Conv``).
            train: When True, BatchNorm uses batch statistics; otherwise running averages.

        Returns:
            ``mish(x + residual)`` with the same shape as the residual branch output.
        """
        y = nn.Conv(
            features=self.channels, kernel_size=(3, 3), padding=(1, 1), use_bias=False
        )(x)
        y = nn.BatchNorm(use_running_average=not train)(y)
        y = mish(y)
        # Fix: the second convolution must consume the output of the first stage (y).
        # It was previously fed the block input x, which threw away the first
        # conv/BN/mish result and made the block effectively one convolution deep.
        # NOTE(review): parameter names and shapes are unchanged, but weights trained
        # with the old wiring will not produce the same outputs — confirm how the
        # checkpoint in use was trained.
        y = nn.Conv(
            features=self.channels, kernel_size=(3, 3), padding=(1, 1), use_bias=False
        )(y)
        y = nn.BatchNorm(use_running_average=not train)(y)

        if self.se:
            # Squeeze: average over axes 1 and 2, keeping them as size-1 dims so the
            # resulting gate broadcasts against y.
            squeeze = jnp.mean(y, axis=(1, 2), keepdims=True)

            # Excitation: bottleneck MLP producing one gate value per channel.
            excitation = nn.Dense(
                features=self.channels // self.se_ratio, use_bias=True
            )(squeeze)
            excitation = nn.relu(excitation)
            excitation = nn.Dense(features=self.channels, use_bias=True)(excitation)
            excitation = nn.hard_sigmoid(excitation)

            y = y * excitation

        return mish(x + y)


class AZResnet(nn.Module):
    """Residual tower with a policy head and a value head.

    Attributes:
        config: ``AZResnetConfig`` holding the layer sizes used below.
    """

    config: AZResnetConfig

    @nn.compact
    def __call__(self, x, train: bool):
        """Run the network.

        Args:
            x: Batched input; its leading dimension is treated as the batch size.
            train: Forwarded to every BatchNorm (batch statistics when True, running
                averages when False) and to each residual block.

        Returns:
            ``(policy, value)``: policy logits of width ``config.num_policy_labels`` per
            batch entry, and one tanh-squashed scalar per batch entry.
        """
        cfg = self.config
        n = x.shape[0]
        frozen_stats = not train
        # Arguments shared by every 3x3 convolution in the stem and the policy head.
        conv3x3 = dict(kernel_size=(3, 3), padding=(1, 1), use_bias=False)

        # stem
        trunk = nn.Conv(features=cfg.channels, **conv3x3)(x)
        trunk = mish(nn.BatchNorm(use_running_average=frozen_stats)(trunk))

        # residual tower (squeeze-and-excitation always enabled)
        for _ in range(cfg.num_blocks):
            trunk = ResidualBlock(channels=cfg.channels, se=True)(trunk, train=train)

        # policy head: two conv-BN-mish stages, then a Dense layer over the flattened map
        p = nn.Conv(features=cfg.channels, **conv3x3)(trunk)
        p = mish(nn.BatchNorm(use_running_average=frozen_stats)(p))
        p = nn.Conv(features=cfg.policy_channels, **conv3x3)(p)
        p = mish(nn.BatchNorm(use_running_average=frozen_stats)(p))
        policy_logits = nn.Dense(features=cfg.num_policy_labels)(p.reshape((n, -1)))

        # value head: 1x1 conv-BN-mish, hidden Dense(256), then a single tanh output
        v = nn.Conv(
            features=cfg.value_channels, kernel_size=(1, 1), use_bias=False
        )(trunk)
        v = mish(nn.BatchNorm(use_running_average=frozen_stats)(v))
        v = mish(nn.Dense(features=256)(v.reshape((n, -1))))
        v = nn.tanh(nn.Dense(features=1)(v))

        # Drop the trailing size-1 axis so the value has one entry per batch element.
        return policy_logits, v.squeeze(axis=1)

In [None]:
import os

from typing import Any
from tqdm.auto import tqdm

import numpy as np
import jax
import jax.numpy as jnp
import optax
import orbax
import orbax.checkpoint as ocp
import chex

from flax.training import orbax_utils
from flax.training.train_state import TrainState
from flax import linen as nn
from flax.training import train_state


class TrainState(train_state.TrainState):
    """Flax ``train_state.TrainState`` extended with a ``batch_stats`` field.

    NOTE(review): this definition rebinds the ``TrainState`` name that was imported from
    ``flax.training.train_state`` earlier in this cell.
    """
    # Extra pytree carried alongside the base state; presumably the BatchNorm running
    # statistics collection — confirm against the training code.
    batch_stats: chex.ArrayTree


class TrainerModule:
    """Holds a model instance together with an Orbax checkpoint manager."""

    def __init__(
        self,
        model_class: nn.Module,
        model_configs: Any,
        optimizer_name: str,
        optimizer_params: dict,
        x: Any,
        ckpt_dir: str = '/kaggle/working/',
        max_checkpoints: int = 99,
        seed=42,
    ):
        '''
        Store the settings, build the (parameter-less) model and open the checkpoint manager.

        Inputs:
            model_class - Class implementing the neural network
            model_configs - Configuration object passed as the only argument to model_class
            optimizer_name - Name of the optimizer (only stored here)
            optimizer_params - Hyperparameters of the optimizer (only stored here)
            x - Example input; accepted but not used by this constructor
            ckpt_dir - Directory handed to the checkpoint manager (created if missing)
            max_checkpoints - Maximum number of checkpoints the manager keeps
            seed - Seed value (only stored here)
        '''
        super().__init__()

        # Plain bookkeeping of the constructor arguments.
        self.seed = seed
        self.ckpt_dir = ckpt_dir
        self.model_class = model_class
        self.model_configs = model_configs
        self.optimizer_name = optimizer_name
        self.optimizer_params = optimizer_params

        # Instantiate the module definition only; no parameters are created here.
        self.model = model_class(model_configs)

        # Checkpoint manager rooted at ckpt_dir, using a PyTree checkpointer.
        manager_options = ocp.CheckpointManagerOptions(
            max_to_keep=max_checkpoints, create=True
        )
        self.checkpoint_manager = ocp.CheckpointManager(
            self.ckpt_dir, ocp.PyTreeCheckpointer(), manager_options
        )

    def load_checkpoint(self, step=0) -> TrainState:
        '''
        Restore the checkpoint saved under `step` and return its 'train_state' entry.

        The restored entry is also cached on `self.state`.
        NOTE(review): callers in this notebook index the result with string keys
        (e.g. state['params']), so the restored object may be a mapping rather than a
        TrainState instance — confirm.
        '''
        restored = self.checkpoint_manager.restore(step)
        self.state = restored['train_state']
        return self.state

In [None]:
import datetime
import os
import pickle
import random 
import time
import requests
from threading import Thread
from functools import partial
from typing import Optional, List

import chex
import jax
import jax.numpy as jnp
import numpy as np 
import mctx
import optax
import pgx
from flax.training import train_state
from pydantic import BaseModel
from tqdm import tqdm


class Sample(BaseModel):
    """Payload schema for one self-play game, as posted by the loop at the end of this cell."""
    # One flattened observation per recorded position.
    obs: List[List[float]]
    # One flattened search action-weight vector per recorded position.
    policy_tgt: List[List[float]]
    # One integer value target per recorded position.
    value_tgt: List[int]


# Network hyperparameters. num_policy_labels evaluates to 9985, so the highest policy
# index is 9984 (the index that the search functions below overwrite).
model_configs = AZResnetConfig(
    num_blocks=15,
    channels=256,
    policy_channels=4,
    value_channels=8,
    num_policy_labels=2 * 64 * 78 + 1,
)
net = AZResnet(model_configs)

# The trainer is only used here to restore a saved training state; the optimizer
# settings and example input are stored/ignored by its constructor.
trainer = TrainerModule(
    model_class=AZResnet,
    model_configs=model_configs,
    optimizer_name='lion',
    optimizer_params={'learning_rate': 1},
    x=jnp.ones((1, 8, 16, 32)),
)
state = trainer.load_checkpoint(20240406121656)

# Variable collections passed to net.apply, in the same key order as before.
params = {collection: state[collection] for collection in ('params', 'batch_stats')}
# Jitted inference-mode forward pass (BatchNorm uses running averages).
forward = jax.jit(partial(net.apply, train=False))

devices = jax.local_devices()
num_devices = len(devices)
print('Number of devices:', num_devices)

class Config(BaseModel):
    """Self-play run settings; unknown fields are rejected."""

    # Pydantic v2 model configuration. This replaces the nested `class Config`, which is
    # the deprecated v1 spelling; the rest of this notebook already uses the v2 API
    # (`model_dump`).
    model_config = {'extra': 'forbid'}

    # Environment identifier handed to pgx.make.
    env_id: pgx.EnvId = 'bughouse'
    # NOTE: this default is drawn once, when the class body is executed, not per instance.
    seed: int = random.randint(0, 999999999)
    # Number of self-play games played by the driver loop.
    max_num_iters: int = 1000
    # selfplay params
    selfplay_batch_size: int = 1
    num_simulations: int = 800
    max_num_steps: int = 512


config: Config = Config()

env = pgx.make(config.env_id)

def recurrent_fn(params, rng_key: jnp.ndarray, action: jnp.ndarray, state: pgx.State):
    """Search transition function: step the environment and evaluate the new position.

    Args:
        params: Network variables passed through to ``forward``.
        rng_key: PRNG key, split into one key per environment in the batch.
        action: Batched actions to apply.
        state: Batched environment state used as the search embedding.

    Returns:
        ``(mctx.RecurrentFnOutput, next_state)`` where the output carries the reward of
        the player who just moved, the discount, the masked prior logits and the value.
    """
    # Remember who is moving before the step; rewards are read for that player.
    mover = state.current_player
    step_keys = jax.random.split(rng_key, config.selfplay_batch_size)
    next_state = jax.vmap(env.step)(state, action, step_keys)

    prior, value = forward(params, next_state.observation)
    # Raise the logit at index 9984 to the per-row maximum.
    # NOTE(review): 9984 is the highest policy index for the current config; presumably
    # a special (pass-like) action — confirm.
    prior = prior.at[:, 9984].set(jnp.max(prior, axis=1))

    # mask invalid actions
    prior = prior - jnp.max(prior, axis=-1, keepdims=True)
    prior = jnp.where(
        next_state.legal_action_mask, prior, jnp.finfo(prior.dtype).min
    )

    done = next_state.terminated
    batch_index = jnp.arange(next_state.rewards.shape[0])
    reward = next_state.rewards[batch_index, mover]

    # Terminal positions contribute no bootstrap value and stop the backup.
    value = jnp.where(done, 0.0, value)
    # Discount of -1 flips the sign of the value between consecutive positions.
    discount = jnp.where(done, 0.0, -1.0 * jnp.ones_like(value))

    output = mctx.RecurrentFnOutput(
        reward=reward,
        discount=discount,
        prior_logits=prior,
        value=value,
    )
    return output, next_state


@partial(jax.jit, static_argnums=(2,))
def run_mcts(state, key, num_simulations: int, tree: Optional[mctx.Tree] = None):
    """Run one search from ``state`` and return the policy output.

    Args:
        state: Batched environment state used as the root embedding.
        key: PRNG key; only the first half of its split is consumed by the search.
        num_simulations: Number of simulations (static under jit).
        tree: Accepted but currently ignored; the search is always started with
            ``search_tree=None``.

    Returns:
        The object returned by ``mctx.alphazero_policy``.
    """
    search_key, _ = jax.random.split(key)

    root_logits, root_value = forward(params, state.observation)
    # Raise the logit at index 9984 to the per-row maximum (same adjustment as in
    # recurrent_fn).
    row_max = jnp.max(root_logits, axis=1)
    root_logits = root_logits.at[:, 9984].set(row_max)

    return mctx.alphazero_policy(
        params=params,
        rng_key=search_key,
        root=mctx.RootFnOutput(
            prior_logits=root_logits, value=root_value, embedding=state
        ),
        recurrent_fn=recurrent_fn,
        num_simulations=num_simulations,
        invalid_actions=~state.legal_action_mask,
        search_tree=None,
        qtransform=partial(mctx.qtransform_by_min_max, min_value=-1, max_value=1),
    )
    

if __name__ == '__main__':
    # Self-play driver: play config.max_num_iters games, one at a time, and POST the
    # recorded (observation, search policy, value target) samples of each game.
    init_fn = jax.jit(jax.vmap(env.init))
    step_fn = jax.jit(jax.vmap(env.step))

    print('Running selfplay with initial seed', config.seed)

    rng_key = jax.random.PRNGKey(config.seed)

    for _ in tqdm(range(config.max_num_iters)):
        game_id = random.randint(0, 999999999)
        print(f'Playing game id: {game_id}')

        rng_key, sub_key = jax.random.split(rng_key)
        keys = jax.random.split(sub_key, config.selfplay_batch_size)
        state = init_fn(keys)
        # No search tree is carried between moves.
        tree = None

        obs = []
        policy_tgt = []

        while not bool(state.terminated.all()):
            # Use independent keys for the search and for the environment step; the
            # same sub-key was previously consumed by both.
            rng_key, search_key, step_key = jax.random.split(rng_key, 3)
            policy_output = run_mcts(state, search_key, config.num_simulations, tree)

            # Convert to plain Python lists once, so the pydantic model and the JSON
            # encoder do not have to walk device arrays element by element.
            obs.append(np.asarray(state.observation).ravel().tolist())
            policy_tgt.append(np.asarray(policy_output.action_weights).ravel().tolist())

            keys = jax.random.split(step_key, config.selfplay_batch_size)
            state = step_fn(state, policy_output.action, keys)

        # Value targets alternate in sign backwards from the final recorded position,
        # which receives +|reward|.
        # NOTE(review): this assumes the side that made the last move is the one that
        # did not lose, and that sides strictly alternate — confirm for this env.
        outcome = abs(int(state.rewards[0][0]))
        num_positions = len(obs)
        value_tgt = [
            outcome * (-1) ** (num_positions - 1 - ply) for ply in range(num_positions)
        ]

        game_data = {
            "obs": obs,
            "policy_tgt": policy_tgt,
            "value_tgt": value_tgt,
        }

        # NOTE(review): hard-coded, unauthenticated plain-HTTP endpoint — consider
        # moving it to configuration.
        url = f"http://ec2-18-208-220-129.compute-1.amazonaws.com:8000/game/{game_id}/"
        sample = Sample(**game_data)
        try:
            # A timeout keeps a stalled server from hanging the self-play loop forever.
            response = requests.post(url, json=sample.model_dump(), timeout=60)
        except requests.RequestException as exc:
            # Best-effort upload: report the failure and move on to the next game
            # instead of aborting the whole run.
            print(f"Error: {exc}")
            continue

        if response.status_code == 200:
            print("Game sent successfully!")
        else:
            print(f"Error: {response.status_code}")