In [None]:
!pip install gdown
!pip install pgx
!pip install chess
!pip install git+https://github.com/aminwoo/pgx.git
!pip install git+https://github.com/lowrollr/mctx-az.git

In [None]:
!gdown https://drive.google.com/drive/folders/1Z-GHFXG2r-9mBTQ5kyAtzaCx0d6jnhQ9?usp=sharing --folder

In [None]:
import time
from dataclasses import dataclass

import flax.linen as nn
import jax
import jax.numpy as jnp


def mish(x):
    """Mish activation: ``x * tanh(softplus(x))``, computed elementwise."""
    gate = jnp.tanh(jax.nn.softplus(x))
    return x * gate


@dataclass
class AZResnetConfig:
    """Hyperparameters read by ``AZResnet`` when it builds its layers."""

    # Number of ResidualBlock layers stacked after the stem convolution.
    num_blocks: int
    # Feature count of the stem conv, every residual block and the first policy conv.
    channels: int
    # Feature count of the second policy-head conv (its output is flattened into a Dense layer).
    policy_channels: int
    # Feature count of the 1x1 value-head conv.
    value_channels: int
    # Width of the final policy Dense layer, i.e. the number of policy logits.
    num_policy_labels: int


class ResidualBlock(nn.Module):
    """Residual block: conv-BN-mish-conv-BN, optional squeeze-excitation, skip add.

    Module fields:
        channels: feature count of both 3x3 convolutions. The skip connection
            adds the input to the result, so the input must already have this
            many channels.
        se: when true, rescale the residual branch with a squeeze-and-excitation
            gate before the skip addition.
        se_ratio: bottleneck divisor for the hidden width of the gate.
    """

    channels: int
    se: bool
    se_ratio: int = 4

    @nn.compact
    def __call__(self, x, train: bool):
        """Apply the block to ``x``.

        Args:
            x: input feature map; axes 1 and 2 are averaged by the SE squeeze.
            train: when true, BatchNorm uses batch statistics; otherwise the
                stored running averages.

        Returns:
            ``mish(x + residual)`` with the same shape as ``x``.
        """
        y = nn.Conv(
            features=self.channels, kernel_size=(3, 3), padding=(1, 1), use_bias=False
        )(x)
        y = nn.BatchNorm(use_running_average=not train)(y)
        y = mish(y)
        # The second convolution consumes the activated output of the first one.
        # Feeding it `x` again would silently discard the first conv/BN/mish.
        # NOTE(review): parameter names and shapes are unchanged, so existing
        # checkpoints still load, but weights trained on the old graph never
        # trained the first convolution — confirm against the training code.
        y = nn.Conv(
            features=self.channels, kernel_size=(3, 3), padding=(1, 1), use_bias=False
        )(y)
        y = nn.BatchNorm(use_running_average=not train)(y)

        if self.se:
            # Squeeze: average over axes 1 and 2, keeping them as size-1 dims so
            # the gate broadcasts back over the feature map.
            squeeze = jnp.mean(y, axis=(1, 2), keepdims=True)

            # Excitation: bottleneck MLP producing one gate value per channel.
            excitation = nn.Dense(
                features=self.channels // self.se_ratio, use_bias=True
            )(squeeze)
            excitation = nn.relu(excitation)
            excitation = nn.Dense(features=self.channels, use_bias=True)(excitation)
            excitation = nn.hard_sigmoid(excitation)

            y = y * excitation

        return mish(x + y)


class AZResnet(nn.Module):
    """Residual tower with a policy head and a value head.

    A 3x3 stem convolution feeds ``config.num_blocks`` squeeze-and-excitation
    residual blocks; both heads branch off that shared trunk. Submodules are
    created in the order stem, blocks, policy head, value head, which must be
    kept because Flax derives automatic parameter names from creation order.
    """

    config: AZResnetConfig

    @nn.compact
    def __call__(self, x, train: bool):
        """Run the network.

        Args:
            x: batched input; axis 0 is the batch axis.
            train: when true, BatchNorm layers use batch statistics.

        Returns:
            ``(policy, value)``: policy logits with ``num_policy_labels``
            columns, and one tanh-squashed value per batch entry.
        """
        cfg = self.config
        use_running_average = not train
        n_batch = x.shape[0]

        # stem
        trunk = nn.Conv(
            features=cfg.channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            use_bias=False,
        )(x)
        trunk = mish(nn.BatchNorm(use_running_average=use_running_average)(trunk))

        # residual tower
        for _ in range(cfg.num_blocks):
            block = ResidualBlock(channels=cfg.channels, se=True)
            trunk = block(trunk, train=train)

        # policy head
        p = nn.Conv(
            features=cfg.channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            use_bias=False,
        )(trunk)
        p = mish(nn.BatchNorm(use_running_average=use_running_average)(p))
        p = nn.Conv(
            features=cfg.policy_channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            use_bias=False,
        )(p)
        p = mish(nn.BatchNorm(use_running_average=use_running_average)(p))
        policy = nn.Dense(features=cfg.num_policy_labels)(p.reshape((n_batch, -1)))

        # value head
        v = nn.Conv(
            features=cfg.value_channels, kernel_size=(1, 1), use_bias=False
        )(trunk)
        v = mish(nn.BatchNorm(use_running_average=use_running_average)(v))
        v = mish(nn.Dense(features=256)(v.reshape((n_batch, -1))))
        v = nn.tanh(nn.Dense(features=1)(v))

        return policy, v.squeeze(axis=1)
    
import requests
from bs4 import BeautifulSoup


def get_session_key(username: str, password: str) -> str:
    """Log into chess.com and retrieve php session key

    Fetches the login page to obtain the form's ``_token`` value, posts the
    credentials to ``login_check`` and reads the session cookie from that
    response.

    Args:
        username (str): chess.com username
        password (str): chess.com password

    Returns:
        str: PHPSESSID

    Raises:
        RuntimeError: if the login page has no ``_token`` input, or the login
            response sets no cookie at all.
    """
    login_url = "https://www.chess.com/login_and_go?returnUrl=https://www.chess.com/"
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/72.0.3626.121 Safari/537.36",
        "Referer": login_url
    }

    # The context manager releases the session's pooled connections even when
    # a request or the page parsing fails.
    with requests.Session() as s:
        response = s.get(login_url, allow_redirects=True)
        soup = BeautifulSoup(response.content, "html.parser")
        token_input = soup.find("input", {"name": "_token"})
        if token_input is None:
            raise RuntimeError("chess.com login page did not contain a _token input")
        token = token_input.get("value")

        login_data = {"_username": username,
                      "_password": password,
                      "login": "",
                      "_target_path": "https://www.chess.com/game/live",
                      "_token": token
                      }

        # Redirects are not followed so the cookies set by login_check itself
        # stay visible on this response.
        r = s.post("https://www.chess.com/login_check", data=login_data, headers=headers, allow_redirects=False, verify=True)

    # Prefer the cookie by name; this does not depend on PHPSESSID being the
    # first cookie in the raw header.
    session_id = r.cookies.get("PHPSESSID")
    if session_id:
        return session_id

    # Fallback: the original behaviour, the value of the first cookie in the header.
    set_cookie = r.headers.get("Set-Cookie")
    if not set_cookie:
        raise RuntimeError("chess.com login response did not set a session cookie")
    return set_cookie.split("=")[1].split(";")[0]

import math

import chess


def _move_dict_to_obj(move_dict):
    """Build a ``chess.Move`` from a decoded move dictionary.

    ``move_dict`` holds the keys ``from_square``, ``to_square``, ``drop`` and
    ``promotion``; squares are square names and pieces are piece symbols, with
    ``None`` for absent parts. When ``from_square`` is ``None`` the
    destination square is used as the origin too.
    """

    def _piece_type(symbol):
        # Piece symbol -> python-chess piece type, passing None through.
        if symbol is None:
            return None
        return chess.Piece.from_symbol(symbol).piece_type

    origin = move_dict["from_square"]
    if origin is None:
        origin = move_dict["to_square"]
    from_sq = chess.parse_square(origin)

    target = move_dict["to_square"]
    to_sq = None if target is None else chess.parse_square(target)

    dropped = _piece_type(move_dict["drop"])
    promoted = _piece_type(move_dict["promotion"])

    return chess.Move(
        from_square=from_sq,
        to_square=to_sq,
        drop=dropped,
        promotion=promoted,
    )


# tcn_decode and tcn_encode are 1:1 port of chess-tcn npm library that chess.com uses
def tcn_decode(n):
    """Decode a TCN move string into a list of ``chess.Move`` objects.

    The string is read two characters at a time: the first character encodes
    the origin square (or a dropped piece), the second the destination square
    (or a promotion piece plus direction).
    """
    tcn_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!?{~}(^)[_]@#$,./&-*++="
    piece_chars = "qnrbkp"

    def square_name(index):
        # File letter from the low three bits' worth, rank number from the rest.
        return tcn_chars[index % 8] + str(index // 8 + 1)

    moves = []
    for pos in range(0, len(n), 2):
        fields = dict.fromkeys(("from_square", "to_square", "drop", "promotion"))
        first = tcn_chars.index(n[pos])
        second = tcn_chars.index(n[pos + 1])
        if second > 63:
            # Promotion: three codes per piece select the piece, and the code
            # modulo three selects the file offset of the destination.
            fields["promotion"] = piece_chars[(second - 64) // 3]
            second = first + (-8 if first < 16 else 8) + ((second - 1) % 3) - 1
        if first > 75:
            fields["drop"] = piece_chars[first - 79]
        else:
            fields["from_square"] = square_name(first)
        fields["to_square"] = square_name(second)
        moves.append(_move_dict_to_obj(fields))
    return moves


def tcn_encode(n):
    """Encode a sequence of UCI move strings as a TCN string.

    Each move becomes two characters. Drops (``"Q@e4"``) encode the piece in
    the first character; promotions (``"e7e8q"``) fold the piece and the file
    offset into the second character.
    """
    tcn_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!?{~}(^)[_]@#$,./&-*++="
    piece_chars = "qnrbkp"

    pairs = []
    for uci in n:
        if uci[1] == "@":
            src = 79 + piece_chars.index(uci[0].lower())
        else:
            src = tcn_chars.index(uci[0]) + 8 * (int(uci[1]) - 1)
        dst = tcn_chars.index(uci[2]) + 8 * (int(uci[3]) - 1)
        if len(uci) > 4:
            if dst < src:
                offset = 9 + dst - src
            else:
                offset = dst - src - 7
            dst = 3 * piece_chars.index(uci[4]) + 64 + offset
        pairs.append(tcn_chars[src] + tcn_chars[dst])
    return "".join(pairs)


def mirrorMoveUCI(uci_move):
    """Parse ``uci_move``, mirror it with ``mirrorMove`` and return the UCI string."""
    mirrored = mirrorMove(chess.Move.from_uci(uci_move))
    return mirrored.uci()


def mirrorMove(move):
    """Return a copy of ``move`` with ``chess.square_mirror`` applied to both squares.

    The promotion and drop piece types are carried over unchanged.
    """
    src = chess.square_mirror(move.from_square)
    dst = chess.square_mirror(move.to_square)
    return chess.Move(src, dst, promotion=move.promotion, drop=move.drop)

In [None]:
from functools import partial
from time import time

import jax
import jax.numpy as jnp
import mctx
import orbax.checkpoint as ocp
from pgx.bughouse import Action, Bughouse, State, _set_current_player, _time_advantage

# Directory the Orbax checkpoint manager reads from.
ckpt_dir = "/kaggle/working"

options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
    ckpt_dir,
    options=options,
    item_handlers=ocp.PyTreeCheckpointHandler())
# Restore the checkpoint stored under step 1.
ckpt = mngr.restore(1)

# Variable collections passed to the network's apply() inside search().
variables = {"params": ckpt["params"], "batch_stats": ckpt["batch_stats"]}

seed = 42
key = jax.random.PRNGKey(seed)
# A single key: every vmapped call in this file runs with a batch of one.
keys = jax.random.split(key, 1)
env = Bughouse()
# Batched (vmap) and compiled (jit) environment step / reset functions.
step_fn = jax.jit(jax.vmap(env.step))
init_fn = jax.jit(jax.vmap(env.init))

@jax.jit
def search(state, variables):
    """Run an AlphaZero-style MCTS from ``state`` and return the mctx policy output.

    Args:
        state: batched environment state used as the root embedding.
        variables: network variables (``params`` and ``batch_stats``) passed
            to ``AZResnet.apply``.

    Returns:
        The result of ``mctx.alphazero_policy`` after 200 simulations with no
        Dirichlet noise and temperature 0.
    """
    net_config = AZResnetConfig(
        num_blocks=15,
        channels=256,
        policy_channels=4,
        value_channels=8,
        num_policy_labels=2*64*78+1,
    )
    forward = jax.jit(partial(AZResnet(net_config).apply, train=False))

    def recurrent_fn(variables, rng_key: jnp.ndarray, action: jnp.ndarray, state):
        # Remember who is moving before the environment advances.
        mover = state.current_player
        next_state = step_fn(state, action, jax.random.split(rng_key, 1))

        logits, value = forward(variables, next_state.observation)
        # Shift so the largest logit is zero, then push illegal actions to the
        # dtype minimum.
        shifted = logits - jnp.max(logits, axis=-1, keepdims=True)
        prior = jnp.where(
            next_state.legal_action_mask, shifted, jnp.finfo(shifted.dtype).min
        )

        # Reward of the player who just moved, one entry per batch element.
        batch_index = jnp.arange(next_state.rewards.shape[0])
        reward = next_state.rewards[batch_index, mover]

        # Terminal states contribute no value and stop the backup (discount 0);
        # all other states use a discount of -1.
        done = next_state.terminated
        value = jnp.where(done, 0.0, value)
        discount = jnp.where(done, 0.0, -1.0 * jnp.ones_like(value))

        step_output = mctx.RecurrentFnOutput(
            reward=reward,
            discount=discount,
            prior_logits=prior,
            value=value,
        )
        return step_output, next_state

    root_logits, root_value = forward(variables, state.observation)
    root = mctx.RootFnOutput(
        prior_logits=root_logits, value=root_value, embedding=state
    )

    return mctx.alphazero_policy(
        params=variables,
        rng_key=key,
        root=root,
        recurrent_fn=recurrent_fn,
        num_simulations=200,
        invalid_actions=~state.legal_action_mask,
        qtransform=partial(mctx.qtransform_by_min_max, min_value=-1, max_value=1),
        dirichlet_fraction=0.0,
        temperature=0.0
    )

# Build a batch-of-one initial state, then run one environment step and one
# search. The step result is discarded; presumably these calls are a warm-up so
# jit compilation happens before a game starts — TODO confirm.
state = init_fn(keys)
step_fn(state, jnp.int32([0]), keys)
out = search(state, variables)

from pgx.experimental.bughouse import make_policy_labels
# Policy labels: indexed by action id (labels[action]) and searched by move
# string (labels.index(...)) in the client code.
labels = make_policy_labels()

In [None]:
import asyncio
import json
import time

import random
import jax
import jax.numpy as jnp
import typer
import websockets
import yaml
from pgx.bughouse import (Action, Bughouse, _set_clock,
                          _set_current_player, _time_advantage, _is_promotion)


# Batched (vmap) and compiled (jit) wrappers around pgx bughouse helpers; the
# client calls each of them with batch-of-one arguments.
update_clock = jax.jit(jax.vmap(_set_clock))
update_player = jax.jit(jax.vmap(_set_current_player))
time_advantage = jax.jit(jax.vmap(_time_advantage))
is_promotion = jax.jit(jax.vmap(_is_promotion))

# Random integer chosen once at import; sent as the "l" field of the timesync
# extension in handshake and heartbeat messages.
ping = random.randint(11, 69)

class Client:
    """
    Client class to play an account

    Holds the login session, the cometd message counters, and a batch-of-one
    bughouse state that mirrors both boards. Incoming game messages are applied
    to that state; when it is this account's turn the engine search picks a
    move and it is sent back over the websocket.
    """

    def __init__(self, config) -> None:
        """Log in and initialise all bookkeeping.

        Args:
            config: mapping with the keys ``"username"``, ``"password"``,
                ``"partner"``, ``"opponent"`` and ``"board"``.
        """
        self.phpsessid = get_session_key(config["username"], config["password"])
        self.username = config["username"]
        self.partner = config["partner"]
        self.opponent = config["opponent"]
        self.board_num = config["board"]
        self.clientId = ""
        self.ply = 0
        self.gameId = -1
        self.side = -1
        self.id = 1  # id attached to (and incremented after) every outgoing message
        self.ack = 1
        self.playing = False
        self.state = None
        self.lengths = [0, 0]  # length of the TCN move string last seen per board
        self.times = [[1200, 1200], [1200, 1200]]  # clocks, indexed [board][player]
        self.turn = [0, 0]  # side to move per board
        self.key = jax.random.PRNGKey(seed)
        self.thinking = False  # guards against starting a second search concurrently
        self.new_game()

    def new_game(self) -> None:
        """Reset the environment state, move counters, clocks and turns."""
        self.key, subkey = jax.random.split(self.key)
        keys = jax.random.split(subkey, 1)
        self.state = init_fn(keys)
        self.lengths = [0, 0]
        self.times = [[1200, 1200], [1200, 1200]]
        self.turn = [0, 0]

    def _player_to_move(self, board_num: int):
        """Batch-of-one player index for ``update_player``.

        This is the side to move on board 0, and its complement on board 1.
        """
        turn = self.turn[board_num]
        return jnp.int32([turn]) if board_num == 0 else jnp.int32([1 - turn])

    async def play_move(self, board_num: int, move: str, ws=None) -> None:
        """Apply ``move`` on ``board_num`` to the local state if it is legal.

        Args:
            board_num: board the move belongs to (0 or 1).
            move: UCI move string as seen from the real board.
            ws: when given, the move is also sent to the server, after a
                randomised delay, before the local state is stepped.
        """
        move_uci = move
        if self.turn[board_num] == 1:
            move_uci = mirrorMoveUCI(move_uci)
        move_uci = str(board_num) + move_uci
        if move_uci.endswith("q"): # Treat queen promotion as default move
            move_uci = move_uci[:-1]
        action = labels.index(move_uci)

        self.state = update_player(self.state, self._player_to_move(board_num))
        if self.state.legal_action_mask[0][action]:
            if ws:
                mover = self.turn[board_num]
                # NOTE(review): time.sleep blocks the event loop, so no other
                # message is handled during the delay. It is kept because an
                # awaitable sleep would let other handlers change self.state
                # between the legality check and the step below.
                if self.times[board_num][mover] < 1750 and self.times[board_num][mover] > 50 and self.times[1 - board_num][1 - mover] > 50:
                    time.sleep(random.random() * 2.0)
                else:
                    time.sleep(random.random() * 0.3)
                await self.send_move(ws, tcn_encode([move]))

            self.key, subkey = jax.random.split(self.key)
            keys = jax.random.split(subkey, 1)
            self.state = step_fn(self.state, jnp.int32([action]), keys)
            self.turn[board_num] = 1 - self.turn[board_num]
            print(action, "Move played:", move, "on board", board_num)

    async def send_partnership(self, ws) -> None:
        """Send a ``RequestBughousePair`` message addressed to the configured partner."""
        data = [
            {
                "channel": "/service/game",
                "data": {
                    "tid": "RequestBughousePair",
                    "to": self.partner,
                    "from": self.username,
                },
                "id": self.id,
                "clientId": self.clientId,
            },
        ]
        await ws.send(json.dumps(data))
        self.id += 1

    async def rematch(self, ws) -> None:
        """Send an unrated bughouse ``Challenge`` to the configured opponent as a rematch."""
        data = [
            {
                "channel": "/service/game",
                "data": {
                    "tid": "Challenge",
                    "uuid": "",
                    "to": self.opponent,
                    "from": self.username,
                    "gametype": "bughouse",
                    "initpos": None,
                    "rated": False,
                    "minrating": None,
                    "maxrating": None,
                    "rematchgid": self.gameId,
                    "color": 2 if self.side == 0 else 1,
                    "basetime": 1200,
                    "timeinc": 0
                },
                "id": self.id,
                "clientId": self.clientId,
            },
        ]
        await ws.send(json.dumps(data))
        self.id += 1

    async def seek_game(self, ws) -> None:
        """Send an open, rated bughouse ``Challenge`` (no addressee)."""
        data = [
            {
                "channel": "/service/game",
                "data": {
                    "tid": "Challenge",
                    "uuid": "",
                    "to": None,
                    "from": self.username,
                    "gametype": "bughouse",
                    "initpos": None,
                    "rated": True,
                    "minrating": 1800,
                    "maxrating": None,
                    "basetime": 1800,
                    "timeinc": 0
                },
                "id": self.id,
                "clientId": self.clientId,
            },
        ]
        await ws.send(json.dumps(data))
        self.id += 1

    async def send_move(self, ws, move: str) -> None:
        """Send a ``Move`` message for the current game.

        Args:
            ws: open websocket connection.
            move: the move in TCN encoding.
        """
        data = [
            {
                "channel": "/service/game",
                "data": {
                    "move": {
                        "gid": self.gameId,
                        "move": move,
                        "seq": self.ply,
                        "uid": self.username,
                    },
                    "tid": "Move",
                },
                "id": self.id,
                "clientId": self.clientId,
            },
        ]
        await ws.send(json.dumps(data))
        self.id += 1

    def update_clock(self, board_num, times):
        """Store the reported clocks for ``board_num``.

        The largest drop seen on that board since the last report is also
        charged to the player to move on the other board.
        """
        delta = max(self.times[board_num][i] - times[i] for i in range(2))
        self.times[1 - board_num][self.turn[1 - board_num]] -= delta
        self.times[board_num] = times

    def update_clock_and_player(self):
        """Push the current player and the clocks into the environment state.

        The clocks are written twice: once as recorded, then again after
        shifting our board's two clocks by the part of the time advantage that
        exceeds 20 in magnitude.
        """
        self.state = update_player(self.state, self._player_to_move(self.board_num))
        # Copy every row: the adjustment below must only affect what is
        # written into the environment, never the clocks kept in self.times.
        t = [list(clock) for clock in self.times]
        if self.turn[0] != 0:
            t[0] = t[0][::-1]
        if self.turn[1] != 0:
            t[1] = t[1][::-1]

        self.state = update_clock(self.state, jnp.int32([t]))
        delta = time_advantage(self.state).item()
        # Ignore advantages within +/-20; beyond that apply only the excess.
        if delta >= 20:
            delta -= 20
        elif delta <= -20:
            delta += 20
        else:
            delta = 0
        t[self.board_num][0] -= delta
        t[self.board_num][1] += delta
        self.state = update_clock(self.state, jnp.int32([t]))

    async def start(self) -> None:
        """Connect to the live server, handshake, and dispatch incoming messages."""
        # Call each jitted helper once before connecting; the results of the
        # last three calls are discarded.
        self.update_clock_and_player()
        time_advantage(self.state)
        update_player(self.state, jnp.int32([0]))
        is_promotion(self.state, jnp.int32([0]))

        #print("Started")
        async with websockets.connect("wss://live2.chess.com/cometd", extra_headers=[("Cookie", f"PHPSESSID={self.phpsessid}")]) as ws:
            data = [
                {
                    "version":"1.0",
                    "minimumVersion":"1.0",
                    "channel":"/meta/handshake",
                    "supportedConnectionTypes":["ssl-websocket"],
                    "advice":{"timeout":60000,"interval":0},
                    "clientFeatures":{
                        "protocolversion":"2.1",
                        "clientname":"LC6;chrome/121.0.6167/browser;Windows 10;jxk3sm4;78.0.2",
                        "skiphandshakeratings":True,
                        "adminservice":True,
                        "announceservice":True,
                        "arenas":True,
                        "chessgroups":True,
                        "clientstate":True,
                        "events":True,
                        "gameobserve":True,
                        "genericchatsupport":True,
                        "genericgamesupport":True,
                        "guessthemove":True,
                        "multiplegames":True,
                        "multiplegamesobserve":True,
                        "offlinechallenges":True,
                        "pingservice":True,
                        "playbughouse":True,
                        "playchess":True,
                        "playchess960":True,
                        "playcrazyhouse":True,
                        "playkingofthehill":True,
                        "playoddschess":True,
                        "playthreecheck":True,
                        "privatechats":True,
                        "stillthere":True,
                        "teammatches":True,
                        "tournaments":True,
                        "userservice":True},
                    "serviceChannels":["/service/user"],
                    "ext":{
                        "ack":True,
                        "timesync":{"tc":int(time.time()*1000),"l":ping,"o":0}
                    },
                    "id":self.id,
                    "clientId":None
                }
            ]
            await ws.send(json.dumps(data))
            self.id += 1

            # The event loop only keeps weak references to tasks, so hold a
            # strong reference until each handler finishes.
            pending = set()
            async for message in ws:
                message = json.loads(message)[0]
                task = asyncio.create_task(self.handle_message(ws, message))
                pending.add(task)
                task.add_done_callback(pending.discard)

    async def handle_message(self, ws, message: dict) -> None:
        """React to one decoded server message.

        Handles, in order: client-id assignment, partnership requests and
        confirmations, connect/handshake heartbeats, and game updates (clock
        and move synchronisation followed by an engine move when it is our
        turn).

        Args:
            ws: open websocket connection used for replies.
            message: the first element of a decoded server frame.
        """
        #print(message)
        # Get Client ID
        if "clientId" in message:
            self.clientId = message["clientId"]
            await self.send_partnership(ws)
            await self.seek_game(ws)

        if "data" in message and "tid" in message["data"] and message["data"]["tid"] == "RequestBughousePair" and "from" in message["data"]:
            await self.send_partnership(ws)

        if "data" in message and "tid" in message["data"] and message["data"]["tid"] == "BughousePair":
            print(f"Partnered to {self.partner}")

        # Send heartbeat back to server
        if (message["channel"] == "/meta/connect" or message["channel"] == "/meta/handshake") and message["successful"]:
            if message["channel"] == "/meta/connect":
                self.ack = message["ext"]["ack"]
            data = [{"channel":"/meta/connect","connectionType":"ssl-websocket","ext":{"ack":self.ack,"timesync":{"tc":int(time.time()*1000),"l":ping,"o":0}},"id":self.id,"clientId":self.clientId}]
            await ws.send(json.dumps(data))
            self.id += 1

        # Handle game logic
        if "data" in message and "game" in message["data"] and "status" in message["data"]["game"]:
            game = message["data"]["game"]
            if game["status"] == "finished":
                if self.playing:
                    print("Game ended")
                    self.playing = False
                    self.new_game()
                    #await self.rematch(ws)
                    # NOTE(review): blocking sleep; it stalls the event loop
                    # for two seconds before seeking the next game.
                    time.sleep(2)
                    await self.seek_game(ws)
            else:
                if game["status"] == "starting":
                    self.playing = True

                # Our index in the player list, or -1 when this message is
                # about a game we are not a player in.
                user_index = -1
                for i, player in enumerate(game["players"]):
                    if player["uid"].lower() == self.username.lower():
                        user_index = i
                        break

                times = game["clocks"]
                tcn_moves = game["moves"]
                # Only the last move (final two TCN characters) is decoded.
                move = "" if not tcn_moves else tcn_decode(tcn_moves[-2:])[0].uci()
                if user_index != -1:
                    self.gameId = game["id"]
                    self.ply = game["seq"]
                    self.side = user_index

                    # Update clock times
                    self.update_clock(self.board_num, times)

                    # Update state with new move
                    if move and self.lengths[self.board_num] < len(tcn_moves) and self.turn[self.board_num] != self.side:
                        self.lengths[self.board_num] = len(tcn_moves)
                        await self.play_move(self.board_num, move)
                else:
                    self.update_clock(1 - self.board_num, times)

                    if move and self.lengths[1 - self.board_num] < len(tcn_moves):
                        self.lengths[1 - self.board_num] = len(tcn_moves)
                        await self.play_move(1 - self.board_num, move)

                # If our turn to move play engine move
                if not self.thinking and self.turn[self.board_num] == self.side and not self.state.terminated.any():
                    self.thinking = True
                    try:
                        self.update_clock_and_player()

                        # search() takes the network variables as its second
                        # argument.
                        out = search(self.state, variables)
                        action = out.action

                        move_uci = labels[action[0]]
                        if len(move_uci) < 6 and is_promotion(self.state, action)[0]:
                            move_uci += "q"

                        print(move_uci)
                        if move_uci != "pass" and int(move_uci[0]) == self.board_num:
                            move_uci = move_uci[1:]
                            if self.turn[self.board_num] == 1:
                                move_uci = mirrorMoveUCI(move_uci)
                            await self.play_move(self.board_num, move_uci, ws)
                    finally:
                        # Always release the guard; otherwise one failed search
                        # would stop the client from ever moving again.
                        self.thinking = False

# Account settings read by Client.__init__. Every string is a blank placeholder
# here and has to be filled in before running; "board" selects which of the two
# boards (0 or 1) this account plays on.
config = {
    "username": "",
    "password": "",
    "board": 0,
    "opponent": "",
    "partner": "",
}
print("Logging into " + config["username"])
# Constructing the client logs in immediately (its __init__ calls get_session_key).
client = Client(config)
# NOTE(review): asyncio.run() raises RuntimeError if an event loop is already
# running (as in a Jupyter kernel) — confirm how this cell is executed.
asyncio.run(client.start())
