Counts the total number of visits used in a training run. We're using this as a proxy for the number of inferences, which is a proxy for the computational cost of a training run.

In [1]:
import collections
import re
import sys
from pathlib import Path
from typing import Counter, Dict, Iterable, Optional, Tuple

import sgfmill.sgf
from tqdm.notebook import tqdm

# Raise the interpreter's recursion limit to 4000.
# NOTE(review): presumably so that long games (deeply nested SGF trees) don't
# hit RecursionError while being processed -- confirm this is still required.
sys.setrecursionlimit(4000)

In [2]:
def name_to_architecture(name: str) -> str:
    """Infer a network architecture label from a player/model name.

    Args:
        name: Player name as recorded in the SGF (e.g. the PB/PW value).

    Returns:
        "t0" for names starting with "t0-s", "random" for the random player,
        "b40c256" for names containing "cp505", and otherwise the
        ``b<N>c<M>...`` component found between dashes in the name, with a
        trailing "x2" stripped.

    Raises:
        ValueError: If none of the rules above apply to ``name``.
    """
    if name.startswith("t0-s"):
        # The real architecture is probably b6c96, but that depends on the
        # training run, so report the generic "t0" label and leave the
        # interpretation to the caller.
        return "t0"
    if name == "random":
        return "random"
    if "cp505" in name:
        return "b40c256"

    found = re.search("-(b[0-9]+c[0-9]+[^-]*)-", name)
    if found is None:
        raise ValueError("Can't determine architecture:", name)

    label = found.group(1)
    # Drop a trailing "x2" so that e.g. "b40c256x2" is reported as "b40c256".
    return label[:-2] if label.endswith("x2") else label


def get_player_algos(game: sgfmill.sgf.Sgf_game) -> Tuple[Optional[str], Optional[str]]:
    """Extract each player's search algorithm from the root node of a game.

    The algorithm is read from an ``algo=<name>,`` entry inside the root
    node's BR (black) and WR (white) properties.

    Args:
        game: Parsed SGF game.

    Returns:
        ``(black_algo, white_algo)``; an element is None when the
        corresponding property contains no ``algo=...,`` entry.
    """
    algo_pattern = re.compile("algo=([^,]+),")
    root_node = game.get_root()

    def algo_for(property_name: str) -> Optional[str]:
        found = algo_pattern.search(root_node.get(property_name))
        return found.group(1) if found is not None else None

    # Black is looked up first, then white, matching the returned order.
    return algo_for("BR"), algo_for("WR")


def count_node_visits(node: sgfmill.sgf.Tree_node) -> Optional[int]:
    """(helper) Returns visits of a single SGF node.

    The visit count is parsed from a `` v=<digits>`` token in the node's
    C[] comment property.

    Args:
        node: SGF tree node to inspect.

    Returns:
        The visit count, or None if the node has no C[] property or the
        comment contains no `` v=<digits>`` token.
    """
    if not node.has_property("C"):
        return None
    # Raw string: "\d" in a plain string literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning from Python 3.12 onwards).
    visits_match = re.search(r" v=(\d+)", node.get("C"))
    if visits_match is None:
        return None
    return int(visits_match.group(1))


def count_descendants_visits(
    node: sgfmill.sgf.Tree_node, num_nodes_to_ignore: int
) -> Tuple[int, int]:
    """(helper) Returns (black_visits, white_visits) of an SGF node and its descendants.

    Args:
        node: Root of the subtree to total up.
        num_nodes_to_ignore: Number of tree levels, starting at ``node``
            itself, whose move nodes are expected to carry no visit count.
            Every move node at a deeper level is expected to carry one.

    Returns:
        ``(black_visits, white_visits)`` summed over every move node in the
        subtree that has a visit count.

    Raises:
        AssertionError: If a move node inside the ignored levels has a visit
            count, or a move node below them lacks one.
    """
    b_visits = 0
    w_visits = 0

    # Walk the tree with an explicit stack rather than recursion: the tree is
    # one level deep per move, so a recursive walk needs one Python frame per
    # move and can exhaust the interpreter's recursion limit on long games.
    # The totals are plain integer sums, so the visiting order is irrelevant.
    pending = [(node, num_nodes_to_ignore)]
    while pending:
        current, remaining_to_ignore = pending.pop()
        pending.extend((child, remaining_to_ignore - 1) for child in current)

        player_color, _ = current.get_move()
        if player_color is None:
            # Not a move node (e.g. the root), so nothing to count.
            continue

        node_visits = count_node_visits(current)
        # Visit counts must be absent exactly on the ignored leading levels.
        assert (node_visits is None) == (remaining_to_ignore > 0)
        if node_visits is None:
            continue
        if player_color == "b":
            b_visits += node_visits
        else:
            w_visits += node_visits
    return b_visits, w_visits


def count_sgf_visits(game: sgfmill.sgf.Sgf_game) -> Tuple[int, int]:
    """Returns an estimate of (black_visits, white_visits) in a game.

    There are some visits that I'm (tomtseng) not sure we can count solely from looking at SGFs:
    - If config param sidePositionProb > 0, then each side position costs a whole search, plus it
      with 25% probability takes one more evaluation and generates a new side position. (Play::runGame())
    - Forking takes a few extra evaluations to select the forking move in Play::maybeForkGame().
    - After a game finishes, a few visits are used to estimate the lead by PlayUtils::computeLead().
    - Probably other stuff I'm missing.
    We could estimate these visits by extracting config parameters from victimplay/selfplay logs,
    but we'll just ignore these visits under the assumption that their contribution to the total
    number of visits is small.

    Args:
        game: Parsed SGF game whose root C[] comment contains
            ``startTurnIdx=<int>`` and ``gtype=<name>`` entries.

    Returns:
        ``(black_visits, white_visits)``: the per-move visit counts found in
        the tree, plus one visit per handicap stone (black only) and one
        visit per initial move up to ``startTurnIdx``, except in forked games.

    Raises:
        ValueError: If the game type is one whose counting is not implemented
            or is not a recognized game type.
        AssertionError: If the root comment is missing or malformed, or if a
            game with AB[] handicap stones is typed "normal".
    """
    game_root = game.get_root()

    assert game_root.has_property("C"), "SGF root node has no C[] property"
    game_info = game_root.get("C")
    # Raw string: "\d" in a plain string literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning from Python 3.12 onwards).
    match = re.search(r"startTurnIdx=(\d+),.*,gtype=([a-z]+)", game_info)
    assert match is not None, "SGF root comment lacks startTurnIdx/gtype"
    start_turn_index = int(match.group(1))
    game_type = match.group(2)

    if game_type in ["cleanuptraining", "sgfpos", "hintpos", "hintfork", "other"]:
        # I (tomtseng) haven't looked into whether we need to count visits
        # in these games differently.
        raise ValueError("Counting visits for game type not implemented:", game_type)
    if game_type not in ["normal", "fork", "handicap", "asym"]:
        raise ValueError("Unknown game type:", game_type)

    b_visits = 0
    w_visits = 0

    # The obvious way to count handicap stones is to look at the HA[] property
    # with game.get_handicap(). However, KataGo populates the property incorrectly
    # if white spends its first moves passing (which can happen early in training
    # when the models are weak), counting the first stones played by black as handicap
    # stones. Instead, we look at the AB[] property.
    has_handicap_stones = game_root.has_property("AB")
    if has_handicap_stones:
        # Expect the game type to be "handicap" or overwritten by another
        # game type with higher priority.
        assert game_type != "normal"

        if game_type != "fork":
            num_handicap_stones = len(game_root.get("AB"))
            # In handicap games, each handicap stone is placed using an evaluation
            # of the black player's policy. Forked games don't count since the
            # evaluation was performed in the parent game.
            b_visits += num_handicap_stones

    if game_type != "fork":
        # In forked games, the start_turn_index represents forked moves, whereas
        # in other games, it represents moves sampled from the policy net
        # (config parameter `initGamesWithPolicy`).
        # Black moves first, so black gets the extra move when the count is odd.
        b_visits += (start_turn_index + 1) // 2
        w_visits += start_turn_index // 2

    root_b_visits, root_w_visits = count_descendants_visits(
        node=game_root,
        # Ignore the root node as well since it doesn't mark a move.
        num_nodes_to_ignore=start_turn_index + 1,
    )
    return b_visits + root_b_visits, w_visits + root_w_visits


def get_step_count(model_name: str) -> int:
    """Extract the training step count embedded in a model name.

    Args:
        model_name: Model name containing a ``-s<digits>-`` component, or the
            literal "random".

    Returns:
        The integer following ``-s``; 0 for the "random" model.

    Raises:
        ValueError: If the name has no ``-s<digits>-`` component.
    """
    if model_name == "random":
        return 0
    found = re.search("-s([0-9]+)-", model_name)
    if found is not None:
        return int(found.group(1))
    raise ValueError("Unable to determine step count:", model_name)


def count_training_run_visits(
    train_dir: Path,
    max_step_count: Optional[int] = None,
) -> Counter:
    """Tally visits per network architecture over a training run's SGFs.

    Every line of every ``selfplay/*/sgfs/*.sgfs`` file under ``train_dir`` is
    parsed as one SGF game. Visits made by a player using "MCTS" are
    attributed to that player's architecture. Visits made by a player using
    "AMCTS" or "EMCTS1" are split evenly between the two players'
    architectures (so totals may be fractional).

    Args
    ----
    train_dir: Root directory of the training run.
    max_step_count: If set, then games are only examined for trained models with
        strictly fewer than this many steps.

    Returns
    -------
    Counter mapping architecture name to its total number of visits.

    Raises
    ------
    ValueError: If a player uses a search algorithm other than the ones listed
        above, or if visit counting fails for a game (the offending SGF is
        printed before the error is re-raised).
    """
    totals = collections.Counter()
    for sgfs_dir in tqdm(list(train_dir.glob("selfplay/*/sgfs"))):
        # The parent directory is named after the model that generated the games.
        if (
            max_step_count is not None
            and get_step_count(sgfs_dir.parent.name) >= max_step_count
        ):
            continue

        for sgfs_path in tqdm(list(sgfs_dir.glob("*.sgfs")), leave=False):
            with open(sgfs_path) as sgfs_file:
                # One SGF per line; materialized so tqdm knows the total.
                for sgf_text in tqdm(list(sgfs_file), leave=False):
                    game = sgfmill.sgf.Sgf_game.from_string(sgf_text)
                    black_name = game.get_player_name("b")
                    white_name = game.get_player_name("w")
                    black_arch = name_to_architecture(black_name)
                    white_arch = name_to_architecture(white_name)
                    black_algo, white_algo = get_player_algos(game)

                    try:
                        black_visits, white_visits = count_sgf_visits(game)
                    except (ValueError, KeyError):
                        print("Failed to count visits for SGF:", sgf_text)
                        raise

                    # Black's visits are attributed first, then white's.
                    per_player = (
                        (black_algo, black_arch, black_visits),
                        (white_algo, white_arch, white_visits),
                    )
                    for algo, own_arch, visits in per_player:
                        if algo == "MCTS":
                            totals[own_arch] += visits
                        elif algo in ("AMCTS", "EMCTS1"):
                            # Split evenly between both players' networks.
                            half = visits / 2
                            totals[black_arch] += half
                            totals[white_arch] += half
                        else:
                            raise ValueError("Handling algo not implemented:", algo)
    return totals


# Sanity-checking on some SGFs.

# gtype=normal, startTurnIdx=0: nothing is added up front, so the totals are
# just the per-move counts: black 600 + 601 + 601 = 1802, white 8 + 8 = 16.
assert (1802, 16) == count_sgf_visits(
    sgfmill.sgf.Sgf_game.from_string(
        "(;FF[4]GM[1]SZ[15]PB[t0-s307440128-d77282325]PW[victim-kata1-b40c256-s11840935168-d2898845681.bin.gz]BR[algo=EMCTS1,v=600,rsym=4,opp_v=1,opp_rsym=1]WR[algo=MCTS,v=8,rsym=4]HA[0]DT[20221107-022651~20221107-050948]KM[6.5]RU[koPOSITIONALscoreAREAtaxNONEsui1]RE[B+50.5]C[startTurnIdx=0,initTurnNum=0,gameHash=17784BDC77ADF20CBC03129EE9FA405F,gtype=normal];B[mh]C[0.40 0.60 0.00 -37.3 v=600 weight=0.76];W[lm]C[0.54 0.46 0.00 1.0 v=8 weight=0.91];B[md]C[0.40 0.60 0.00 -34.3 v=601 weight=0.72];W[dd]C[0.56 0.44 0.00 1.2 v=8 weight=0.88];B[le]C[0.39 0.61 0.00 -33.3 v=601 weight=0.66])"
    )
)
# gtype=handicap, startTurnIdx=23, AB[] lists 5 stones:
# black = 5 (handicap) + (23 + 1) // 2 = 12 (initial moves) + 100 = 117,
# white = 23 // 2 = 11 (initial moves) + 100 = 111.
assert (117, 111) == count_sgf_visits(
    sgfmill.sgf.Sgf_game.from_string(
        "(;FF[4]GM[1]SZ[19]PB[t0-s9053440-d2582634]PW[t0-s9053440-d2582634]BR[algo=MCTS,v=600,rsym=4]WR[algo=MCTS,v=600,rsym=4]HA[6]DT[20231031-195311~20231031-205212]KM[1.5]RU[koPOSITIONALscoreAREAtaxNONEsui1]RE[B+45.5]AB[je][bf][ch][do][ss]C[startTurnIdx=23,initTurnNum=0,gameHash=DD69CBEACFCEC0502E08745793D39CD1,gtype=handicap,newNeuralNetTurn115=t0-s10069760-d2819464];B[eb];W[kj];B[mr];W[lh];B[dp];W[bl];B[bm];W[kh];B[df];W[jp];B[pn];W[jh];B[hh];W[fd];B[dn];W[fh];B[hb];W[rj];B[bn];W[fp];B[ok];W[rr];B[rf];W[hl]C[0.50 0.50 0.00 -1.9 v=100 weight=0.07];B[ll]C[0.50 0.50 0.00 -1.7 v=100 weight=0.01])"
    )
)
# gtype=fork, startTurnIdx=2: the two forked moves add nothing, so only the
# per-move counts remain: black 600, white 100.
assert (600, 100) == count_sgf_visits(
    sgfmill.sgf.Sgf_game.from_string(
        "(;FF[4]GM[1]SZ[13]PB[t0-s10069760-d2819464]PW[t0-s10069760-d2819464]BR[algo=MCTS,v=600,rsym=4]WR[algo=MCTS,v=600,rsym=4]HA[0]DT[20231031-205103~20231031-211333]KM[1]RU[koPOSITIONALscoreAREAtaxNONEsui1]RE[W+29]C[startTurnIdx=2,initTurnNum=0,gameHash=7369701C5306EB15A1820F217C86A639,gtype=fork,usedInitialPosition=1];B[kj];W[fd];B[ih]C[0.50 0.50 0.00 -0.2 v=600 weight=0.56];W[jk]C[0.50 0.50 0.00 -0.3 v=100 weight=0.00])"
    )
)

In [None]:
# This function can take a long time to run, so I'll paste the results below.
# Only selfplay directories whose model step count is strictly below
# max_step_count are included in the tally.
count_training_run_visits(
    train_dir=Path(
        "/nas/ucb/k8/go-attack/victimplay/ttseng-avoid-pass-alive-coldstart-39-20221025-175949/"
    ),
    max_step_count=545065216,
)

# cyclic:
# - b40c256: 179,348,619,656
# - t0: 91,066,142,195
# - random: 310,825,776
# - b15c192: 16,957,564
# - b10c128: 9,985,812
# - b6c96: 8,119,280
# - b20c256: 6,129,528
#
# cyclic-r1:
# - b40c256: 18,669,701,896
# - t0: 12,854,060,712
#
# cyclic-r2:
# - b40c256: 41,985,434,602
# - t0: 19,107,129,898
#
# cyclic-r3:
# - b40c256: 22,138,221,404.5
# - t0: 15,685,005,020.5