This notebook allows one to estimate the compute involved in training a model.

### Load libraries

In [1]:
import pathlib

from sgf_parser import game_info

### Load training run data

In [2]:
def get_game_infos(data_dir: str):
    """Parse every SGF file located under ``{data_dir}/selfplay``.

    Args:
        data_dir: Root directory of a training run. The games are looked up
            in its ``selfplay`` subdirectory.

    Returns:
        The result of ``game_info.read_and_parse_all_files`` for the paths
        returned by ``game_info.find_sgf_files``, parsed with
        ``fast_parse=True``.
    """
    selfplay_root = pathlib.Path(f"{data_dir}/selfplay")
    discovered_paths = game_info.find_sgf_files(root=selfplay_root)
    return game_info.read_and_parse_all_files(discovered_paths, fast_parse=True)

In [3]:
# Training run for passing adversary
# get_game_infos parses the SGFs under this run's `selfplay` subdirectory.
GAME_INFOS_PASS = get_game_infos(
    "/nas/ucb/tony/go-attack/training/emcts1-curr/cp127-to-505-v1"
)
# Displayed as the cell output: how many game infos were loaded.
len(GAME_INFOS_PASS)

201896

In [4]:
# Training run for cyclic adversary
# https://www.notion.so/chaiberkeley/adv-b6-600-vs-avoid-pass-alive-v1-curriculum-start-at-cp39-again-ba7cb9bd348d409db3e6fb9c89a56227
# get_game_infos parses the SGFs under this run's `selfplay` subdirectory.
GAME_INFOS_DEF = get_game_infos(
    "/nas/ucb/k8/go-attack/victimplay/ttseng-avoid-pass-alive-coldstart-39-20221025-175949"
)
# Displayed as the cell output: how many game infos were loaded.
len(GAME_INFOS_DEF)

1015622

### Estimate compute

In [5]:
# Multiply-accumulate (MAC) counts keyed by network size.
# The network entries were computed in the estimate-flops-katago.ipynb notebook
# (output of Cell 8, also stored there in a global variable called MACS_DICT).
# The final `"random": 0` entry is not from that notebook; it is added by us here.
MACS_DICT = {
    "b10c128": 1052902056.7067871,
    "b15c192": 3535965536.369873,
    "b20c256": 8392483407.783936,
    "b40c256": 16703260895.197998,
    "b60c320": 38884293414.61206,
    "b6c96": 350317665.0437012,
    "random": 0,
}


def compute_flops(
    info: dict[str, int | str],
    adv_net_size: str,
    verbose: bool = False,
) -> float:
    """
    Given an game info dict, returns the number of flops used in that game.

    Assumes a constant adv_net_size.
    """

    if info["adv_name"] == "random":
        adv_net_size = "random"
    assert adv_net_size in MACS_DICT

    adv_visits: int = int(info["adv_visits"]) if info["adv_visits"] is not None else 600

    victim_net_size: str = info["victim_name"].split("-")[2]  # type: ignore
    if victim_net_size.endswith("x2"):
        victim_net_size = victim_net_size[:-2]
    victim_visits: int = int(info["victim_visits"])

    num_moves: int = int(info["num_moves"])
    adv_moves = num_moves / 2
    victim_moves = num_moves / 2

    # 2 flops per MAC
    adv_flops = (
        2
        * adv_moves
        * adv_visits
        * ((MACS_DICT[adv_net_size] + MACS_DICT[victim_net_size]) / 2)
    )

    victim_flops = 2 * victim_moves * victim_visits * MACS_DICT[victim_net_size]

    if verbose:
        print("adv:", adv_net_size, adv_visits)
        print("victim:", victim_net_size, victim_visits)
        print("# moves:", num_moves)
        print("adv flops:", adv_flops)
        print("victim flops:", victim_flops)

    return adv_flops + victim_flops


# Spot-check a single game from each run, with verbose=True so the
# per-player breakdown is printed before the total.
print(compute_flops(GAME_INFOS_DEF[135010], adv_net_size="b6c96", verbose=True))
print()
print(compute_flops(GAME_INFOS_PASS[50], adv_net_size="b6c96", verbose=True))

adv: b6c96 600
victim: b40c256 4096
# moves: 401
adv flops: 2051545500797076.5
victim flops: 2.7435039207319132e+16
2.948658470811621e+16

adv: b6c96 600
victim: b40c256 1
# moves: 28
adv flops: 143250059906030.28
victim flops: 467691305065.54395
143717751211095.8


In [6]:
def count_rows_and_moves(game_infos):
    """Print the total move count, the row count, and their ratio.

    The move total is the sum of ``num_moves`` over all games. The row count
    is the largest integer found in the third dash-separated field of any
    non-"random" ``adv_name``, after stripping leading "d" characters
    (e.g. the ``d136760487`` part of ``t0-s545065216-d136760487``).
    NOTE(review): presumably that field is the cumulative number of training
    rows seen by the adversary checkpoint -- confirm.

    Args:
        game_infos: Sequence of game info dicts with ``num_moves`` and
            ``adv_name`` keys. It is iterated twice.

    Returns:
        None; the three statistics are printed.
    """
    move_total = sum(game["num_moves"] for game in game_infos)

    row_counters = []
    for game in game_infos:
        adv_name = game["adv_name"]
        if adv_name == "random":
            continue
        data_field = adv_name.split("-")[2]
        row_counters.append(int(data_field.lstrip("d")))
    row_total = max(row_counters)

    # There are roughly 2x as many moves as rows,
    # since we only train on moves made by the adversary.
    print("Tot moves:", move_total)
    print("Tot rows:", row_total)
    print("Moves / row:", move_total / row_total)


# Print move/row statistics for each training run in turn.
print("Pass run stats")
count_rows_and_moves(GAME_INFOS_PASS)

print()
print("Def run stats")
count_rows_and_moves(GAME_INFOS_DEF)

Pass run stats
Tot moves: 22254080
Tot rows: 15021246
Moves / row: 1.481506926922041

Def run stats
Tot moves: 305376648
Tot rows: 137006591
Moves / row: 2.2289193955639695


In [7]:
# Sum the per-game estimate over every game of the pass run,
# treating the adversary as a b6c96 network throughout.
print("Pass run: Total FLOPs")
print(sum(compute_flops(info, adv_net_size="b6c96") for info in GAME_INFOS_PASS))

Pass run: Total FLOPs
1.125046864835051e+20


In [8]:
# Only count games in which the victim searched with at most 256 visits.
# (The label previously said "256 moves", but the filter is on victim_visits.)
print("Def run: FLOPs up to playing against victim with 256 visits")
print(
    sum(
        compute_flops(info, adv_net_size="b6c96")
        for info in GAME_INFOS_DEF
        if int(info["victim_visits"]) <= 256
    )
)

Def run: FLOPs up to playing against victim with 256 moves
1.132901791968054e+21


In [9]:
# Compute up to t0-s545065216-d136760487
# which was used in the 1mil/10mil experiments.
# https://www.notion.so/chaiberkeley/match-adv-s545m-v600-vs-cp505-v1mil-10search_threads-50ccb5697f404c559a1763b7cfa759ef
# The threshold below is the `s545065216` step count from that name:
# only games whose adv_steps is at or below it are summed.
print("Def run: FLOPs up to 1mil visit experiment adversary")
print(
    sum(
        compute_flops(info, adv_net_size="b6c96")
        for info in GAME_INFOS_DEF
        if int(info["adv_steps"]) <= 545065216
    )
)

Def run: FLOPs up to 1mil visit experiment adversary
9.579539505273113e+21
