In this notebook,
we estimate the FLOPs needed to compute the forward pass
of various KataGo models.

This notebook works in the provided Docker container.
It has only been tested with a GPU available:
the FLOP-measurement cells place models and inputs on a CUDA device,
so as written they require a GPU (running them on CPU instead would likely be very slow).

### Load libraries

In [1]:
import contextlib
import os
import pathlib
import sys
import warnings

import git.repo
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ptflops
import requests
import thop
import torch
from bs4 import BeautifulSoup
from torch import nn
from tqdm.auto import tqdm

# We need to import some libraries from upstream KataGo.
# KataGo doesn't have a setup.py, so we modify sys.path instead (kinda hacky).
_working_tree_dir = git.repo.Repo(".", search_parent_directories=True).working_tree_dir
if _working_tree_dir is None:
    # A bare repository has no working tree (and hence no submodule checkout);
    # fail clearly instead of building a bogus "None/..." path.
    raise RuntimeError("Could not locate the git working tree containing this notebook.")
GIT_ROOT = pathlib.Path(_working_tree_dir)
_KATAGO_PYTHON_DIR = str(GIT_ROOT / "submodules/KataGo-pytorch-rewrite/python")
# Guard the append so re-running this cell does not stack duplicate entries.
if _KATAGO_PYTHON_DIR not in sys.path:
    sys.path.append(_KATAGO_PYTHON_DIR)

import modelconfigs
from model_pytorch import Model

### Load KataGo training history

In [2]:
# Pull data from https://katagotraining.org/networks/
KATAGO_NETWORKS_URL = "https://katagotraining.org/networks/"
# CSS classes of the table rows that describe a network checkpoint.
NETWORK_ROW_CLASSES = ("normalNetworkStyle", "strongestNetworkStyle")

response = requests.get(KATAGO_NETWORKS_URL, timeout=60)
# Fail loudly on an HTTP error instead of silently parsing an error page.
response.raise_for_status()
soup = BeautifulSoup(response.content, "html.parser")

# Get the name of all the networks (the first cell of each matching row).
network_names = [
    row.td.text.strip()
    for row in soup.find_all(
        "tr",
        {"class": lambda css_class: css_class in NETWORK_ROW_CLASSES},
    )
]
# The table is expected to end with "kata1-random", which we drop.
assert (
    network_names and network_names[-1] == "kata1-random"
), "Unexpected layout of the KataGo networks table"
network_names = network_names[:-1]
len(network_names)

540

In [3]:
def parse_network_name(name: str) -> dict[str, int | str]:
    net_size = name.split("-")[1]
    if net_size.endswith("x2"):
        net_size = net_size.rstrip("x2")

    return dict(
        name=name,
        net_size=net_size,
        steps=int(name.split("-")[2].lstrip("s")),
        rows=int(name.split("-")[3].lstrip("d")),
    )

# One row per network checkpoint, with its name split into fields.
df_nets = pd.DataFrame(
    [parse_network_name(network_name) for network_name in network_names]
)
df_nets.head()

Unnamed: 0,name,net_size,steps,rows
0,kata1-b60c320-s6782286336-d3070935549,b60c320,6782286336,3070935549
1,kata1-b60c320-s6769829376-d3067673297,b60c320,6769829376,3067673297
2,kata1-b60c320-s6757237760-d3064295323,b60c320,6757237760,3064295323
3,kata1-b60c320-s6744642560-d3061231329,b60c320,6744642560,3061231329
4,kata1-b60c320-s6729327872-d3057177418,b60c320,6729327872,3057177418


In [4]:
# Distinct architectures, in order of first appearance in the table.
UNIQUE_NET_SIZES = df_nets["net_size"].unique()
UNIQUE_NET_SIZES

array(['b60c320', 'b40c256', 'b20c256', 'b15c192', 'b10c128', 'b6c96'],
      dtype=object)

### Collect FLOP data

In [5]:
class FilteredWriter:
    """Filters out an annoying message from ptflops.

    Intended as the target of ``contextlib.redirect_stdout``. The known ptflops
    batch-size message and whitespace-only writes (such as the separate "\\n"
    that ``print`` emits) are dropped; any other text is re-emitted via
    ``warnings.warn``.
    """

    # Exact ptflops message that we deliberately ignore.
    _IGNORED_MESSAGES = (
        "Warning! No positional inputs found for a module, assuming batch size is 1.",
    )

    def write(self, message: str) -> int:
        """Handle one write; returns ``len(message)`` like ``TextIO.write``."""
        if message.strip() and message not in self._IGNORED_MESSAGES:
            warnings.warn(message)
        return len(message)

    def flush(self) -> None:
        """No-op; needed because code writing to stdout may call ``flush()``."""


def measure_macs(
    model: nn.Module,
    batch_size: int = 256,
    device: torch.device | str | None = None,
) -> dict[str, float | int]:
    """Measure per-batch-element MACs (and parameter counts) of a model.

    The model is profiled with both ptflops and thop on one batch of random
    inputs so the two libraries' estimates can be cross-checked.

    Args:
        model: Model to profile (e.g. a KataGo ``Model``). Its forward pass is
            called with a spatial input of shape (batch_size, 22, 19, 19) and
            a global input of shape (batch_size, 19).
        batch_size: Number of random positions in the profiled batch (>= 1).
        device: Device to create the inputs on. Defaults to the device of the
            model's first parameter, or "cuda" if the model has no parameters.

    Returns:
        Dict with ``ptflops_macs`` and ``thop_macs`` (MACs per batch element),
        ``ptflops_params`` and ``thop_params`` (parameter counts reported by
        each library), and the ``batch_size`` used.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cuda"

    # Get inputs to model.
    # Input shapes obtained via debugging the following command:
    #   python submodules/KataGo-pytorch-rewrite/python/test.py \
    #     -npzdir /nas/ucb/k8/go-attack/victimplay/ttseng-avoid-pass-alive-coldstart-39-20221025-175949/selfplay/t0-s532017152-d133516007/tdata \
    #     -model-kind b6c96 \
    #     -pos-len 19 \
    #     -batch-size 256
    binaryInputNCHW = torch.randn(batch_size, 22, 19, 19, device=device)
    globalInputNC = torch.randn(batch_size, 19, device=device)

    ptflops_macs: float
    ptflops_params: float
    thop_macs: float
    thop_params: float
    # Counting MACs only needs forward passes, so skip building autograd graphs
    # (saves memory for large models at large batch sizes).
    with torch.no_grad():
        # Measure via ptflops. The inputs are passed as keyword arguments, so
        # ptflops assumes a batch size of 1 (the message filtered out by
        # FilteredWriter) and reports the total for the whole batch.
        with contextlib.redirect_stdout(FilteredWriter()):  # type: ignore
            ptflops_macs, ptflops_params = ptflops.get_model_complexity_info(
                model,
                (1,),  # unused: input_constructor ignores its argument
                input_constructor=lambda _: dict(
                    input_spatial=binaryInputNCHW,
                    input_global=globalInputNC,
                ),
                as_strings=False,
                print_per_layer_stat=False,
            )  # type: ignore

        # Measure via thop (also a whole-batch total).
        thop_macs, thop_params = thop.profile(  # type: ignore
            model,
            inputs=(binaryInputNCHW, globalInputNC),
            verbose=False,
        )

    # Both counts cover the entire batch; normalize to one batch element.
    return dict(
        ptflops_macs=ptflops_macs / batch_size,
        ptflops_params=ptflops_params,
        thop_macs=thop_macs / batch_size,
        thop_params=thop_params,
        batch_size=batch_size,
    )

In [6]:
results = []
for net_size in tqdm(UNIQUE_NET_SIZES):
    # MAC counts depend on layer shapes, not weight values, so freshly
    # initialized weights are fine here.
    model = Model(modelconfigs.config_of_name[net_size], 19)
    model.initialize()
    model.cuda()
    model.eval()

    # Profile at several batch sizes so we can check per-element MACs are stable.
    results.extend(
        {**measure_macs(model, batch_size=bs), "net_size": net_size}
        for bs in (2, 4, 8, 16, 32, 64, 128, 256)
    )

df = pd.DataFrame(results)
df.head()

  0%|          | 0/6 [00:00<?, ?it/s]

Unnamed: 0,ptflops_macs,ptflops_params,thop_macs,thop_params,batch_size,net_size
0,38891420000.0,108559237,38877250000.0,108502149.0,2,b60c320
1,38891390000.0,108559237,38877220000.0,108502149.0,4,b60c320
2,38891380000.0,108559237,38877210000.0,108502149.0,8,b60c320
3,38891370000.0,108559237,38877200000.0,108502149.0,16,b60c320
4,38891370000.0,108559237,38877200000.0,108502149.0,32,b60c320


In [7]:
# Per-architecture mean and spread of the per-element MAC estimates across
# batch sizes; a low std means the measurement is batch-size independent.
df_mean = df.groupby("net_size").agg(
    ptflops_macs_mean=("ptflops_macs", "mean"),
    ptflops_macs_std=("ptflops_macs", "std"),
    thop_macs_mean=("thop_macs", "mean"),
    thop_macs_std=("thop_macs", "std"),
)
df_mean

Unnamed: 0_level_0,ptflops_macs_mean,ptflops_macs_std,thop_macs_mean,thop_macs_std
net_size,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
b10c128,1053432000.0,9348.777258,1052373000.0,9318.170885
b15c192,3537091000.0,11684.545004,3534840000.0,11647.713606
b20c256,8394444000.0,14019.621081,8390522000.0,13977.256327
b40c256,16707090000.0,16354.697158,16699440000.0,16306.799048
b60c320,38891380000.0,18689.773234,38877210000.0,18636.341769
b6c96,350580500.0,7013.009512,350054800.0,6988.628164


In [8]:
# Average the two libraries' per-element estimates (they agree closely).
mean_macs = (df_mean["ptflops_macs_mean"] + df_mean["thop_macs_mean"]) / 2
MACS_DICT = dict(mean_macs.items())
MACS_DICT

{'b10c128': 1052902056.7067871,
 'b15c192': 3535965536.369873,
 'b20c256': 8392483407.783936,
 'b40c256': 16703260895.197998,
 'b60c320': 38884293414.61206,
 'b6c96': 350317665.0437012}

### Estimate how many FLOPs KataGo's self-play used during training (forward passes only)

In [9]:
# Only training rows up to (and including) this checkpoint are counted.
STRONGEST_CONFIDENTLY_RATED_NET = "kata1-b40c256-s11840935168-d2898845681"
STRONGEST_CONFIDENTLY_RATED_ROWS = parse_network_name(STRONGEST_CONFIDENTLY_RATED_NET)[
    "rows"
]

# https://github.com/lightvector/KataGo/blob/12b8dd4ce74367b4efa4678c9fe11597f55929f5/cpp/configs/training/selfplay8b20.cfg#L96
VISITS_PER_ROW = 1000

# Checkpoints to count, ordered by cumulative rows.
# NOTE(review): b60c320 is excluded; presumably it was trained in parallel
# with the b40c256 run, so its rows would be double counted. Confirm.
counted_nets = (
    df_nets.query(f"rows <= {STRONGEST_CONFIDENTLY_RATED_ROWS} & net_size != 'b60c320'")
    .sort_values("rows")
    .reset_index(drop=True)
)

# Charge each checkpoint for the rows added since the previous one: each row
# costs VISITS_PER_ROW forward passes, at 2 FLOPs per MAC.
tot_flops = 0
prv_rows = 0
for size, cur_rows in zip(counted_nets["net_size"], counted_nets["rows"]):
    tot_flops += 2 * MACS_DICT[size] * VISITS_PER_ROW * (cur_rows - prv_rows)
    prv_rows = cur_rows

tot_flops

8.272999624612053e+22