### Load libraries

In [None]:
import dataclasses
import os
import pathlib
import random
import re
from collections import defaultdict
from typing import List, Dict, Union

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
from statsmodels.stats.proportion import proportion_confint
from sgf_parser import game_info

# Typeset every figure through LaTeX (Times, 10pt, pdflatex for PGF export),
# then switch to a colorblind-friendly colour cycle.
_latex_rc = {
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.size": 10,
    "font.weight": "medium",
    "pgf.rcfonts": False,
    "pgf.texsystem": "pdflatex",
    "text.usetex": True,
}
plt.rcParams.update(_latex_rc)
plt.style.use("tableau-colorblind10")

### Utilities

In [None]:
def flatten_2d_list(lists):
    """Concatenate an iterable of iterables into one flat list.

    Uses ``itertools.chain`` rather than ``sum(lists, [])``, which copies the
    growing accumulator on every addition (quadratic in the total length) and
    raises ``TypeError`` for inner iterables that are not lists.

    Args:
        lists: Iterable of iterables, e.g. a list of lists.

    Returns:
        A new list holding the elements of each inner iterable, in order.
    """
    import itertools

    return list(itertools.chain.from_iterable(lists))


def parse_for_match(df: pd.DataFrame, victim_name_prefix="cp505-v"):
    """Annotate a match DataFrame in place with adversary/victim columns.

    A player is the adversary when its name contains "adv"; the other player
    is the victim. Adds or overwrites the columns ``adv_name``,
    ``adv_color``, ``victim_name``, ``victim_color``, ``adv_win`` and
    ``victim_visits``.

    Args:
        df: Games with at least ``b_name``, ``w_name`` and ``win_color``
            columns. Modified in place.
        victim_name_prefix: Prefix preceding the visit count in the victim's
            name, e.g. "cp505-v" turns "cp505-v64" into 64. Only its length
            is used, so any prefix of the same length works.

    Raises:
        AssertionError: if some game does not have exactly one adversary.
    """
    adv_is_black = df.b_name.str.contains("adv")
    adv_is_white = df.w_name.str.contains("adv")
    # Every game must have exactly one adversary.
    assert (~adv_is_black == adv_is_white).all()
    victim_is_black = ~adv_is_black
    victim_is_white = ~adv_is_white

    df.loc[adv_is_black, "adv_name"] = df.b_name[adv_is_black]
    df.loc[adv_is_white, "adv_name"] = df.w_name[adv_is_white]
    df.loc[adv_is_black, "adv_color"] = "b"
    df.loc[adv_is_white, "adv_color"] = "w"

    df.loc[victim_is_black, "victim_name"] = df.b_name[victim_is_black]
    df.loc[victim_is_white, "victim_name"] = df.w_name[victim_is_white]
    df.loc[victim_is_black, "victim_color"] = "b"
    df.loc[victim_is_white, "victim_color"] = "w"

    # Item assignment, not ``df.adv_win = ...``: attribute assignment only
    # updates a column that already exists and otherwise silently sets a plain
    # Python attribute (with a UserWarning) instead of creating a column.
    df["adv_win"] = df.adv_color == df.win_color
    df["victim_visits"] = df.victim_name.str.slice(
        start=len(victim_name_prefix)
    ).astype(int)


def parse_sgfs(paths):
    """Parse the SGF files found under each directory of ``paths`` into one DataFrame.

    Args:
        paths: Iterable of directories (str or ``pathlib.Path``); each is
            searched with ``game_info.find_sgf_files``.

    Returns:
        A ``pd.DataFrame`` built from the records returned by
        ``game_info.read_and_parse_all_files`` for every directory, in the
        order the directories were given.
    """
    per_directory = []
    for path in paths:
        sgf_files = game_info.find_sgf_files(pathlib.Path(path))
        per_directory.append(
            game_info.read_and_parse_all_files(sgf_files, fast_parse=True)
        )
    return pd.DataFrame(flatten_2d_list(per_directory))

### cp505 vs cp127

In [None]:
# Read the Elo table of the cp127-vs-cp505 match summary.
match_path = pathlib.Path("/nas/ucb/tony/go-attack/matches/cp127-vs-cp505/summary.txt")
lines = match_path.read_text().splitlines()

# The Elo rows sit between these two header lines of the summary.
elo_lines_start_idx = lines.index("Elos (+/- one approx standard error):")
elo_lines_end_idx = lines.index(
    "Pairwise approx % likelihood of superiority of row over column:"
)
elo_lines = lines[elo_lines_start_idx + 1 : elo_lines_end_idx]

bot_entries: List[Dict[str, Union[float, int, str]]] = []
for elo_line in elo_lines:
    # Rows are parsed as "<net>-v<visits> : <elo> +/- <std>".
    name = elo_line.split(" ")[0]
    name_parts = name.split("-v")
    bot_entries.append(
        {
            "name": name_parts[0],
            "visits": int(name_parts[1]),
            "elo": float(elo_line.split(":")[1].split("+/-")[0]),
            "std": float(elo_line.split("+/-")[1]),
        }
    )

df = pd.DataFrame(bot_entries)
# Shift so that the weakest bot sits at Elo 0.
df.elo -= df.elo.min()
# Only a cell's last expression is displayed, so preview the table here
# (mid-cell it was a no-op).
df.head()

In [None]:
# Elo against visit count for each network, with a log-scaled x axis.
print(df.elo.max())
fig, axs = plt.subplots(1, 1, constrained_layout=True, figsize=(5.50107, 3), dpi=240)
sns.lineplot(data=df, x="visits", y="elo", hue="name", ax=axs)
axs.set_xscale("log")

### A-MCTS-R and A-MCTS-S++ vs. varying victim visits

In [None]:
def plot_victim_visit_sweep(df, victim_label, plot_name):
    """Plot A-MCTS-R and A-MCTS-S++ (adversary at 200 visits) win rates against
    a victim with varying visit counts, and save it as ``{plot_name}.pgf``.

    Args:
        df: Match games; annotated in place by ``parse_for_match``. Adversaries
            whose name contains "ov1" are plotted as A-MCTS-S++, all others as
            A-MCTS-R.
        victim_label: LaTeX label for the victim used in the y-axis title.
        plot_name: Output filename without the ``.pgf`` extension.
    """
    parse_for_match(df)
    fig, axs = plt.subplots(
        1, 1, constrained_layout=True, figsize=(2.64049, 2), dpi=240
    )

    def win_percent_by_victim_visits(query):
        """Mean adversary win rate (in %) per victim visit count for ``query``."""
        matched = df.query(query)
        return 100 * matched.groupby("victim_visits").mean(numeric_only=True).adv_win

    # A-MCTS-R is only shown up to 128 victim visits and is drawn on top.
    win_percent_by_victim_visits(
        "adv_name.str.contains('ov1') == False & victim_visits <= 128"
    ).plot(label="A-MCTS-R", zorder=10, linestyle="--")
    win_percent_by_victim_visits("adv_name.str.contains('ov1')").plot(
        label="A-MCTS-S++"
    )

    plt.ylim(-5, 105)
    plt.xscale("log")
    plt.ylabel(f"Adv. win \\% vs. {victim_label}")
    plt.xlabel("Victim visits")
    plt.legend()

    # Major ticks at powers of ten up to the largest victim visit count;
    # minor ticks at each multiple of a power of ten below the next power,
    # never exceeding that largest visit count.
    max_victim_visits = df.victim_visits.max()
    top_power = int(np.floor(np.log10(max_victim_visits)))
    major_ticks = [10**power for power in range(top_power + 1)]
    minor_ticks = [
        tick
        for base in major_ticks
        for tick in range(base, min(10 * base, max_victim_visits + base), base)
    ]
    plt.xticks(major_ticks)
    plt.xticks(minor_ticks, minor=True)

    plt.savefig(f"{plot_name}.pgf", backend="pgf")

In [None]:
# Victim-visit sweep for adversary s34mil from the unhardened training run.
unhardened_sweep_dirs = [
    "/nas/ucb/tony/go-attack/matches/cp505-perfect-victim-modeling",
    "/nas/ucb/tony/go-attack/matches/cp505-ov1",
]
df = parse_sgfs(unhardened_sweep_dirs)
plot_victim_visit_sweep(
    df,
    victim_label="$\\texttt{Latest}$",
    plot_name="adv-vs-cp505-vary-visits",
)

In [None]:
# Victim-visit sweep for adversary 349mil from the hardened training run.
df = parse_sgfs(
    ["/nas/ucb/k8/go-attack/match/ttseng-hard-vic-v-sweep-20221108-182437"]
)
plot_victim_visit_sweep(
    df,
    victim_label="$\\texttt{Latest}_\\texttt{def}$",
    plot_name="adv-vs-cp505h-vary-visits",
)

In [None]:
# Victim-visit sweep for adversary 497mil from the hardened training run
# (A-MCTS-R and A-MCTS-S++ "ov1" match directories).
hardened_497m_sweep_dirs = [
    "/nas/ucb/k8/go-attack/match/ttseng-hard-vic-v-sweep-s497m-20221114-214156",
    "/nas/ucb/k8/go-attack/match/ttseng-hard-vic-v-sweep-s497m-ov1-20221114-213705",
]
df = parse_sgfs(hardened_497m_sweep_dirs)
plot_victim_visit_sweep(
    df,
    victim_label="$\\texttt{Latest}_\\texttt{def}$",
    plot_name="adv-497mil-vs-cp505h-vary-visits",
)

### Perfect victim modeling cp127 (old adversary)

In [None]:
# Old cp127 adversary vs. cp127 at varying victim visits. Uses the shared
# parse_sgfs helper instead of repeating the find/read/DataFrame steps.
df = parse_sgfs(
    ["/nas/ucb/tony/go-attack/matches/cp127-perfect-victim-modeling/sgfs"]
)
# The default prefix "cp505-v" has the same length as "cp127-v", so the
# victim visit count is still sliced out correctly.
parse_for_match(df)

In [None]:
# Sanity-check the parsed player names and visit counts.
for column in ("victim_name", "victim_visits", "adv_name"):
    print(df[column].unique())

In [None]:
fig, axs = plt.subplots(1, 1, constrained_layout=True, figsize=(2.64049, 2), dpi=240)

# Variants of the s35783424 adversary, drawn in this order (the colour cycle
# assigns colours in draw order).
for adversary, legend_label in [
    ("adv-s35783424-v200", "A-MCTS-R"),
    ("adv-s35783424-v200-ov1", "A-MCTS-S++"),
    ("adv-s35783424-v200-ov1-os1", "A-MCTS-S"),
]:
    win_rate = (
        df.query(f"adv_name == '{adversary}'")
        .groupby("victim_visits")
        .mean(numeric_only=True)
        .adv_win
    )
    (100 * win_rate).plot(label=legend_label)

plt.ylim(-5, 105)
plt.xscale("log")
plt.ylabel("Adv. win \\%")
plt.xlabel("cp127 visits")
plt.legend()

### Strongest cp127 adversary

In [None]:
# Strongest cp127 adversary vs. cp127 and cp505 victims, loaded with the
# shared parse_sgfs helper.
df = parse_sgfs(["/nas/ucb/tony/go-attack/matches/cp127-cp505-vs-strong-adv/sgfs/"])
parse_for_match(df)
# Victim names begin with the network name ("cpNNN", five characters).
df["victim_net"] = df.victim_name.str.slice(stop=len("cp505"))
len(df)

In [None]:
# Sanity-check the parsed names, networks and visit counts.
for column in ("victim_name", "victim_net", "victim_visits", "adv_name"):
    print(df[column].unique())

In [None]:
plt.figure(figsize=(14, 4))

victim_net: str
adv_name: str
# Only A-MCTS-S adversaries (names ending in "-os1"), in their original order.
os1_adversaries = [name for name in df.adv_name.unique() if name.endswith("-os1")]
for i, victim_net in enumerate(df.victim_net.unique()):
    # One subplot per victim network (the layout has room for two).
    plt.subplot(1, 2, i + 1)
    for adv_name in os1_adversaries:
        matching = df[(df.adv_name == adv_name) & (df.victim_net == victim_net)]
        win_rate = matching.groupby("victim_visits").mean(numeric_only=True).adv_win
        win_rate.plot(label=adv_name)

    plt.xscale("log")
    plt.ylabel("adv. win rate")
    plt.title(f"Adversary against {victim_net}")
    plt.xlabel(f"{victim_net} visits")
    plt.legend()

### Strongest cp505 adversary

In [None]:
# Strongest cp505 adversary vs. cp127 and cp505 victims.
df = parse_sgfs(["/nas/ucb/tony/go-attack/matches/cp505-adv-emcts1.4/sgfs"])

# Drop adversary-vs-adversary games: parse_for_match requires exactly one
# adversary per game. (Kept as one expression so no temporary masks leak
# into the notebook namespace.)
has_two_adversaries = df.b_name.str.contains("adv") & df.w_name.str.contains("adv")
df = df[~has_two_adversaries].copy()
# Victim names look like "bot-cpNNN-v<visits>"; "bot-cp127-v" has the same
# length as every such prefix, so it extracts the visit count for all victims.
parse_for_match(df, victim_name_prefix="bot-cp127-v")
df["victim_net"] = df.victim_name.str.slice(start=len("bot-"), stop=len("bot-cp505"))
len(df)

In [None]:
# Sanity-check the parsed names, networks and visit counts.
for column in ("victim_name", "victim_net", "victim_visits", "adv_name"):
    print(df[column].unique())

In [None]:
plt.figure(figsize=(14, 4))

victim_net: str
adv_name: str
all_adversaries = df.adv_name.unique()
for i, victim_net in enumerate(df.victim_net.unique()):
    # One subplot per victim network, one line per adversary.
    plt.subplot(1, 2, i + 1)
    for adv_name in all_adversaries:
        matching = df[(df.adv_name == adv_name) & (df.victim_net == victim_net)]
        win_rate = matching.groupby("victim_visits").mean(numeric_only=True).adv_win
        win_rate.plot(label=adv_name)

    plt.xscale("log")
    plt.ylabel("adv. win rate")
    plt.title(f"Adversary against {victim_net}")
    plt.xlabel(f"{victim_net} visits")
    plt.legend()

In [None]:
fig, axs = plt.subplots(1, 1, constrained_layout=True, figsize=(2.64049, 2), dpi=240)

# Transfer of the cp505-trained adversary to each victim network.
adv_name: str = "adv-cp505-v1-s34090496-v600"
victim_net: str
net_labels = {"cp505": "$\\texttt{Latest}$"}
for victim_net in reversed(df.victim_net.unique()):
    matching = df[(df.adv_name == adv_name) & (df.victim_net == victim_net)]
    win_pct = 100 * matching.groupby("victim_visits").mean(numeric_only=True).adv_win
    win_pct.plot(
        label=net_labels.get(victim_net, "$\\texttt{cp127}$"),
        linestyle="--" if victim_net == "cp127" else None,
    )

plt.ylim(-5, 105)
plt.xscale("log")
plt.ylabel("$\\texttt{Latest}$-trained adv. win \\%")
plt.xlabel("Victim visits")
plt.legend()

plt.savefig("adv505-transfer.pgf", backend="pgf")

### b10 cp127 adversary

In [None]:
# b10 cp127 adversary vs. victims at 1 to 1024 visits, loaded with the shared
# parse_sgfs helper.
df = parse_sgfs(["/nas/ucb/ttseng/go_attack/match/b10-vs-v1to1024/sgfs"])
# Victim names look like "bot-cpNNN-v<visits>"; any same-length prefix works.
parse_for_match(df, victim_name_prefix="bot-cp505-v")
df["victim_net"] = df.victim_name.str.slice(start=len("bot-"), stop=len("bot-cp505"))
len(df)

In [None]:
# Sanity-check the parsed names, networks and visit counts.
for column in ("victim_name", "victim_net", "victim_visits", "adv_name"):
    print(df[column].unique())

In [None]:
fig, axs = plt.subplots(1, 1, constrained_layout=True, figsize=(2.64049, 2), dpi=240)

# Transfer of the cp127-trained adversary to each victim network.
victim_net: str
for victim_net in df.victim_net.unique():
    is_latest = victim_net == "cp505"
    matching = df[df.victim_net == victim_net]
    win_pct = 100 * matching.groupby("victim_visits").mean(numeric_only=True).adv_win
    win_pct.plot(
        label="$\\texttt{Latest}$" if is_latest else "$\\texttt{cp127}$",
        linestyle="--" if victim_net == "cp127" else None,
    )

plt.ylim(-5, 105)
plt.xscale("log")
plt.ylabel("$\\texttt{cp127}$-trained adv. win \\%")
plt.xlabel("Victim visits")

# List legend entries in reverse plotting order.
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles[::-1], labels[::-1])
plt.savefig("adv127-transfer.pgf", backend="pgf")

### A-MCTS-S, varying adv. visits

In [None]:
def plot_adv_visit_sweep(
    df, df2, victim_visits, victim_label, plot_name, sweep_adv_visits=200
):
    """Plot A-MCTS-S win rate as adversary visits vary, at fixed victim visits.

    Single A-MCTS-S++ and A-MCTS-R reference points, taken from the
    victim-visit sweep in ``df2``, are overlaid at ``sweep_adv_visits``.
    The win rates are printed and the figure is saved as ``{plot_name}.pgf``.

    Args:
        df: Games of the adversary-visit sweep (A-MCTS-S); annotated in place
            by ``parse_for_match``.
        df2: Games of the victim-visit sweep; adversaries with "ov1" in their
            name are A-MCTS-S++, the other one A-MCTS-R. Annotated in place.
        victim_visits: Victim visit count both data sets are restricted to.
        victim_label: LaTeX label for the victim used in the y-axis title.
        plot_name: Output filename without the ``.pgf`` extension.
        sweep_adv_visits: Adversary visit count used for the ``df2`` games;
            the x-position of the A-MCTS-S++ / A-MCTS-R markers.

    Raises:
        ValueError: if ``df2`` has more than two adversaries at
            ``victim_visits``.
    """
    from matplotlib import ticker as mticker

    parse_for_match(df)
    parse_for_match(df2)

    visit_query = f"victim_visits == {victim_visits}"
    # Computed once and reused for both the printout and the plot.
    s_win_rate = (
        df.query(visit_query).groupby("adv_visits").mean(numeric_only=True).adv_win
    )
    print("A-MCTS-S win rate:\n", s_win_rate)

    df2_points = (
        df2.query(visit_query).groupby("adv_name").mean(numeric_only=True).adv_win
    )
    print("A-MCTS-{S++,R} win rate:\n", df2_points)
    if df2_points.shape[0] > 2:
        raise ValueError(
            "Expected at most 2 adversaries (A-MCTS-S++ and A-MCTS-R) at "
            f"{victim_visits} victim visits, got {df2_points.shape[0]}"
        )
    splusplus_point = None
    perfect_modeling_point = None
    for name, win_rate in df2_points.items():
        if "ov1" in name:
            splusplus_point = win_rate
        else:
            perfect_modeling_point = win_rate

    fig, ax = plt.subplots(
        1,
        1,
        constrained_layout=True,
        figsize=(2.64049, 2),
        dpi=240,
    )

    (100 * s_win_rate).plot(label="A-MCTS-S")

    if splusplus_point is not None:
        plt.plot(
            sweep_adv_visits,
            100 * splusplus_point,
            "+",
            label="A-MCTS-S++",
            markersize=10,
            color="tab:red",
        )

    if perfect_modeling_point is not None:
        plt.plot(
            sweep_adv_visits,
            100 * perfect_modeling_point,
            "o",
            label="A-MCTS-R",
            markersize=6,
            color="tab:green",
        )

    plt.ylim(-5, 105)
    plt.xscale("log")

    # A large numticks keeps a major tick at every power of ten (and minor
    # ticks between them) instead of letting LogLocator thin them out.
    # https://stackoverflow.com/a/73094650/1337463
    ax.xaxis.set_major_locator(mticker.LogLocator(numticks=999))
    ax.xaxis.set_minor_locator(mticker.LogLocator(numticks=999, subs="auto"))

    plt.xlabel("Adversary visits")
    plt.ylabel(f"Adv. win \\% vs. {victim_label}")
    plt.legend()
    plt.savefig(f"{plot_name}.pgf", backend="pgf")

In [None]:
# Latest (cp505) at 64 visits: A-MCTS-S adversary-visit sweep (df) plus the
# A-MCTS-R / A-MCTS-S++ victim-visit sweep used for reference points (df2).
df = parse_sgfs(["/nas/ucb/tony/go-attack/matches/cp505-v64-vs-adv-1-to-8192/sgfs"])
df2 = parse_sgfs(
    [
        "/nas/ucb/tony/go-attack/matches/cp505-perfect-victim-modeling",
        "/nas/ucb/tony/go-attack/matches/cp505-ov1",
    ]
)
plot_adv_visit_sweep(
    df,
    df2,
    victim_visits=64,
    victim_label="$\\texttt{Latest}$",
    plot_name="adv-vs-cp505-vary-visits2",
)

In [None]:
# Latest_def (hardened cp505) at 32 visits.
df = parse_sgfs(
    ["/nas/ucb/k8/go-attack/match/ttseng-hard-adv-v-sweep-v16-v32-20221109-102538"]
)
df2 = parse_sgfs(
    ["/nas/ucb/k8/go-attack/match/ttseng-hard-vic-v-sweep-20221108-182437"]
)
plot_adv_visit_sweep(
    df,
    df2,
    victim_visits=32,
    victim_label="$\\texttt{Latest}_\\texttt{def}$",
    plot_name="adv-vs-cp505h-v32-vary-adv-visits",
)

In [None]:
# Latest_def (hardened cp505) at 2048 visits against the 497mil adversary.
df = parse_sgfs(
    ["/nas/ucb/k8/go-attack/match/ttseng-hard-adv-v-sweep-s497m-20221114-215552"]
)
df2 = parse_sgfs(
    [
        "/nas/ucb/k8/go-attack/match/ttseng-hard-vic-v-sweep-s497m-20221114-214156",
        "/nas/ucb/k8/go-attack/match/ttseng-hard-vic-v-sweep-s497m-ov1-20221114-213705",
    ]
)
print(df2.adv_name.unique())
plot_adv_visit_sweep(
    df,
    df2,
    victim_visits=2048,
    victim_label="$\\texttt{Latest}_\\texttt{def}$",
    plot_name="adv-497mil-vs-cp505h-v2048-vary-adv-visits",
)

### Elo combined

In [None]:
import pathlib
from typing import Dict, List, Union

fig, axs = plt.subplots(1, 1, constrained_layout=True, figsize=(5.50107, 3), dpi=240)


def _read_elo_entries(summary_path, elo_offset):
    """Parse the Elo table of a match ``summary.txt`` into a list of dicts.

    Args:
        summary_path: Path of the summary file.
        elo_offset: Constant added to every Elo so separate runs can share one
            axis.

    Returns:
        One dict per Elo row with keys ``name``, ``rank``, ``visits``, ``elo``
        (shifted by ``elo_offset``) and ``std``. ``std`` is the reported
        standard error and is deliberately NOT shifted: moving the origin of
        the Elo scale does not change an uncertainty.
    """
    lines = pathlib.Path(summary_path).read_text().splitlines()
    start = lines.index("Elos (+/- one approx standard error):")
    end = lines.index(
        "Pairwise approx % likelihood of superiority of row over column:"
    )
    entries = []
    for elo_line in lines[start + 1 : end]:
        # Rows look like "cp103-v1600         :   128.86 +/- 13.25".
        name = elo_line.split(" ")[0]
        entries.append(
            {
                "name": name,
                # NOTE(review): for names containing "bot" this keeps only the
                # single character name[3]; confirm that is intended.
                "rank": name[3] if "bot" in name else name.split("-")[0],
                "visits": int(name.split("-v")[1]),
                "elo": float(elo_line.split(":")[1].split("+/-")[0]) + elo_offset,
                "std": float(elo_line.split("+/-")[1]),
            }
        )
    return entries


# visit-exp3 supplies every network except cp127; all bots of the
# cp127-vs-cp505 run are added on top. Each run has its own Elo offset.
bot_entries: List[Dict[str, Union[float, int, str]]] = [
    entry
    for entry in _read_elo_entries(
        "/nas/ucb/tony/go-attack/matches/visit-exp3/summary.txt", -483.746
    )
    if "cp127" not in entry["rank"]
]
bot_entries += _read_elo_entries(
    "/nas/ucb/tony/go-attack/matches/cp127-vs-cp505/summary.txt", 591.59
)

df = pd.DataFrame(bot_entries)
# One Elo-vs-visits line per network.
for rank in sorted(df["rank"].unique()):
    df[df["rank"] == rank].groupby("visits").mean(numeric_only=True).elo.plot(
        label=rank
    )
# Single reference point for cp103 at 1600 visits.
plt.plot(1600, 800.31, "rx", label="cp103")
plt.ylabel("Elo")
plt.xlabel("Visits")
plt.xscale("log")

# Order the legend by checkpoint number ("cp<N>"), newest first.
handles, labels = plt.gca().get_legend_handles_labels()
labels, handles = zip(
    *sorted(zip(labels, handles), key=lambda t: int(t[0].split("p")[1]))
)
handles = list(handles)[::-1]
labels = list(labels)[::-1]
for i in range(len(labels)):
    orig_label = labels[i]
    labels[i] = f"$\\texttt{{{orig_label}}}$"
    if orig_label == "cp505":
        labels[i] += " ($\\texttt{Latest}$)"
    elif orig_label == "cp103":
        labels[i] += " ($\\texttt{Original}$)"
plt.legend(
    handles,
    labels,
    title="Network",
    loc="center left",
    bbox_to_anchor=(1, 0.5),
    fancybox=True,
)

plt.savefig("elo-by-visits.pgf", backend="pgf")

In [None]:
# Elo curve again, but with updated match configs that are closer to the original KataGo configs

fig, axs = plt.subplots(
    1,
    1,
    constrained_layout=True,
    figsize=(5.50107, 3),
    dpi=240,
)

match_path = pathlib.Path("/nas/ucb/ttseng/go_attack/match/elo-221115/summary.txt")
lines = match_path.read_text().splitlines()
# The Elo rows sit between these two header lines of the summary.
elo_lines_start_idx = lines.index("Elos (+/- one approx standard error):")
elo_lines_end_idx = lines.index(
    "Pairwise approx % likelihood of superiority of row over column:"
)
elo_lines = lines[elo_lines_start_idx + 1 : elo_lines_end_idx]

# Constant shift applied to every Elo of this run.
elo_offset = 562.778
bot_entries: List[Dict[str, Union[float, int, str]]] = []
for elo_line in elo_lines:
    name = elo_line.split(" ")[0]
    entry = {
        "name": name,
        "rank": name[3] if "bot" in name else name.split("-")[0],
        "visits": int(name.split("-v")[1]),
        "elo": float(elo_line.split(":")[1].split("+/-")[0]) + elo_offset,
        # The standard error is an uncertainty, not a location on the Elo
        # scale, so it must not be shifted by the offset.
        "std": float(elo_line.split("+/-")[1]),
    }
    bot_entries.append(entry)

df = pd.DataFrame(bot_entries)
# One Elo-vs-visits line per network.
for rank in sorted(df["rank"].unique()):
    df[df["rank"] == rank].groupby("visits").mean(numeric_only=True).elo.plot(
        label=rank
    )
# Single reference point for cp103 at 1600 visits.
plt.plot(1600, 800.31, "rx", label="cp103")
plt.ylabel("Elo")
plt.xlabel("Visits")
plt.xscale("log")
handles, labels = plt.gca().get_legend_handles_labels()


def label_to_key(label):
    """Sort key for network labels such as "cp505" or "cp505h".

    Uses the checkpoint number; a hardened ("h"-suffixed) network sorts just
    below its unhardened counterpart.
    """
    checkpoint_num = int(re.sub("[^0-9]", "", label))
    is_hardened = label[-1] == "h"
    return checkpoint_num - (0.5 if is_hardened else 0)


# Newest checkpoint first in the legend.
labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: label_to_key(t[0])))
handles = list(handles)[::-1]
labels = list(labels)[::-1]
for i in range(len(labels)):
    orig_label = labels[i]
    labels[i] = f"$\\texttt{{{orig_label}}}$"
    if orig_label == "cp505":
        labels[i] += " ($\\texttt{Latest}$)"
    elif orig_label == "cp505h":
        labels[i] = "$\\texttt{cp505}_\\texttt{def}$ ($\\texttt{Latest}_\\texttt{def}$)"
    elif orig_label == "cp103":
        labels[i] += " ($\\texttt{Original}$)"
plt.legend(
    handles,
    labels,
    title="Network",
    loc="upper left",
    bbox_to_anchor=(1, 1),
    fancybox=True,
)

plt.savefig("elo-by-visits-2.pgf", backend="pgf")

# TODO(tomtseng):
# * rerun this after /nas/ucb/ttseng/go_attack/match/elo-221115/ is actually done
# * shift everything to put cp505-v1 at 0
# * remove cp103 red x assuming it matches up (after shifting)

### Win rate across training

In [None]:
def get_victim_active_ranges(df):
    """Get the range of adversary training steps during which each victim was used.

    Only "normal" games on 19x19 boards are considered. Victims are
    identified as ``<net>-v<visits>``, where ``<net>`` is the victim name with
    a leading "kata1-" and a trailing ".bin.gz"/".txt.gz" removed.

    Args:
        df: Games with ``gtype``, ``board_size``, ``victim_name``,
            ``victim_visits`` and ``adv_steps`` columns. Not modified.

    Returns:
        Dict mapping victim id to ``(first_step, last_step)``, ordered by
        ``last_step``.
    """
    normal = df[df.gtype == "normal"]
    # Remove the exact prefix/suffix with anchored regexes. ``str.strip``
    # removes any of the given *characters* from both ends, which e.g. also
    # ate the leading "b" of "b40c256..." and trailing digits.
    net_names = normal.victim_name.str.replace(r"^kata1-", "", regex=True).str.replace(
        r"\.(?:bin|txt)\.gz$", "", regex=True
    )
    # ``assign`` returns a new frame, avoiding a write into a filtered view.
    named = normal.assign(
        victim_name_v2=net_names + "-v" + normal.victim_visits.astype("str")
    )
    df19 = named[named.board_size == 19]

    steps = df19.groupby("victim_name_v2").adv_steps
    first_steps = steps.min()
    last_steps = steps.max()

    victim_ranges: dict[str, tuple[int, int]] = {
        v: (first_steps[v], last_steps[v]) for v in df19.victim_name_v2.unique()
    }
    # Order victims by the last step at which they were active.
    return dict(sorted(victim_ranges.items(), key=lambda item: item[1][1]))


def get_victim_change_steps(df):
    """Return the first adversary training step of each victim's active range.

    The steps follow the order of ``get_victim_active_ranges`` (sorted by each
    victim's last active step).
    """
    change_steps = []
    for first_step, _last_step in get_victim_active_ranges(df).values():
        change_steps.append(first_step)
    return change_steps


def plot_training(
    df,
    victim_name_to_plot_label,
    highlighted_point_step,
    victim_change_steps,
    plot_name,
    ignored_adversaries=(),
    legend_ncol=None,
):
    """Plot win rate throughout training.

    For each victim this draws the mean adversary win rate per adversary
    checkpoint, a Clopper-Pearson confidence band (``ALPHA = 0.05``) and a
    diamond at ``highlighted_point_step``. Dotted vertical lines mark the
    steps at which the training victim changed. The figure is saved as
    ``{plot_name}.pgf``.

    Params:
        df: Games with ``board_size``, ``adv_name``, ``victim_name``,
          ``adv_steps`` and ``adv_win`` columns; only 19x19 games are plotted.
          Not modified.
        victim_name_to_plot_label: Key = victims to plot, value = label on plot legend
        highlighted_point_step: Point (specified by adversary training steps) to mark
          with a special marker on the plot; must exist for every plotted victim
        victim_change_steps: Training steps at which the victim changed
        plot_name: Filename of plot, without the ``.pgf`` extension
        ignored_adversaries: Names of adversaries to ignore due to bad data
        legend_ncol: Override the number of legend columns (default: one per
          victim)
    """
    ALPHA = 0.05
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    df19 = df.loc[df.board_size == 19]
    df19 = df19.loc[~df19.adv_name.isin(list(ignored_adversaries))]
    # ``assign`` returns a new frame, so the new column is not written into a
    # slice of the caller's ``df`` (SettingWithCopyWarning).
    df19 = df19.assign(adv_win_perc=df19.adv_win * 100)

    fig, axs = plt.subplots(
        1, 1, constrained_layout=True, figsize=(5.50107, 3), dpi=240
    )

    for i, (victim_name, victim_label) in enumerate(victim_name_to_plot_label.items()):
        # Colour the line, band and marker of one victim explicitly so they
        # always match, and wrap around if there are more victims than colours.
        color = colors[i % len(colors)]
        victim_df = df19[df19.victim_name == victim_name]
        by_step = victim_df.groupby("adv_steps")
        win_perc = by_step.adv_win_perc.mean()
        win_perc.plot(label=victim_label, color=color)

        # Clopper-Pearson ("beta") interval from per-checkpoint win counts:
        # https://www.statsmodels.org/dev/generated/statsmodels.stats.proportion.proportion_confint.html
        ci_low, ci_high = proportion_confint(
            by_step.adv_win.sum().to_numpy(),
            by_step.size().to_numpy(),
            alpha=ALPHA,
            method="beta",
        )
        plt.fill_between(
            win_perc.index, ci_low * 100, ci_high * 100, alpha=0.3, color=color
        )
        plt.plot(
            highlighted_point_step,
            win_perc.loc[highlighted_point_step].item(),
            "D",
            color=color,
        )

    max_step = df19.adv_steps.max()
    for xc in victim_change_steps:
        # Skip the start of training and changes after the last checkpoint.
        if 0 < xc < max_step:
            plt.axvline(x=xc, ls=":", lw=1, color=colors[3])

    if legend_ncol is None:
        legend_ncol = len(victim_name_to_plot_label)
    plt.ylabel(r"Adversary win rate \%")
    plt.xlabel("Adversary training steps")
    plt.margins(x=0)
    plt.legend(
        loc="lower center", bbox_to_anchor=(0.5, 1.0), ncols=legend_ncol, fancybox=True
    )
    # Extend the x-axis slightly past the highlighted point.
    _, x_max = plt.xlim()
    plt.xlim(right=max(x_max, 1.02 * highlighted_point_step))

    plt.savefig(f"{plot_name}.pgf", backend="pgf")

In [None]:
# Steps at which the training victim changed, precomputed with:
#   training_df = parse_sgfs(["/nas/ucb/tony/go-attack/training/emcts1-curr/cp127-to-505-v1/selfplay"])
#   victim_change_steps = get_victim_change_steps(training_df)
victim_change_steps = [12560640, 16541696, 22263040, 25102336]

df = parse_sgfs(["/nas/ucb/ttseng/go_attack/match/adv-checkpoints"])
# The bot names in this experiment contain a typo:
# bot-cp505-v2 is really cp505 with one visit.
df.loc[df.victim_name == "bot-cp505-v2", "victim_name"] = "bot-cp505-v1"

print("All victims:", df.victim_name.unique())
victim_name_to_plot_label = {
    "bot-cp127-v1": r"\texttt{cp127}",
    "bot-cp505-v1": r"\texttt{Latest}",
}
plot_training(
    df,
    victim_name_to_plot_label,
    highlighted_point_step=34090496,
    victim_change_steps=victim_change_steps,
    plot_name="adv-training",
    ignored_adversaries=["adv-s1122816-d254183-v600"],
)

In [None]:
# Steps at which the training victim changed, precomputed with:
#   training_df = parse_sgfs(["/nas/ucb/k8/go-attack/victimplay/ttseng-avoid-pass-alive-coldstart-39-20221025-175949/selfplay"])
#   victim_change_steps = get_victim_change_steps(training_df)
# NOTE: 227013120 appears twice; that only draws the same dotted line twice.
victim_change_steps = [
    0,
    30365184,
    34502144,
    38566400,
    87698176,
    164908288,
    194422784,
    198203136,
    205049600,
    212038912,
    218883584,
    227013120,
    227013120,
    230932992,
    417232384,
    466991360,
    486887168,
    491808000,
    495654912,
    499575296,
    503639552,
    509699584,
    516545024,
    522673664,
]

df = parse_sgfs(
    ["/nas/ucb/k8/go-attack/match/ttseng-hard-adv-checkpoint-sweep-20221108"]
)
print("All victims:", df.victim_name.unique())
victim_name_to_plot_label = {
    "cp39h-v1": r"$\texttt{cp39}_\texttt{def}$",
    "cp127h-v1": r"$\texttt{cp127}_\texttt{def}$",
    "cp505h-v1": r"$\texttt{Latest}_\texttt{def}$",
}
plot_training(
    df,
    victim_name_to_plot_label,
    highlighted_point_step=349284096,
    victim_change_steps=victim_change_steps,
    plot_name="adv-training-hardened",
)

In [None]:
# Same hardened training run (and victim change steps) as above, evaluated
# against the 497mil adversary checkpoint sweep.
victim_change_steps = [
    0,
    30365184,
    34502144,
    38566400,
    87698176,
    164908288,
    194422784,
    198203136,
    205049600,
    212038912,
    218883584,
    227013120,
    227013120,
    230932992,
    417232384,
    466991360,
    486887168,
    491808000,
    495654912,
    499575296,
    503639552,
    509699584,
    516545024,
    522673664,
]
df = parse_sgfs(
    ["/nas/ucb/k8/go-attack/match/ttseng-hard-adv-checkpoint-sweep-497mil-221115"]
)
print("All victims:", df.victim_name.unique())
victim_name_to_plot_label = {
    "cp39h-v1": r"$\texttt{cp39}_\texttt{def}$",
    "cp127h-v1": r"$\texttt{cp127}_\texttt{def}$",
    "cp505h-v1": r"$\texttt{Latest}_\texttt{def}$",
    "cp505h-v2048": r"$\texttt{Latest}_\texttt{def}$ (2048 visits)",
}
plot_training(
    df,
    victim_name_to_plot_label,
    highlighted_point_step=497721856,
    victim_change_steps=victim_change_steps,
    plot_name="adv-497mil-training-hardened",
    legend_ncol=2,
)

### Baseline attacks

In [None]:
# Edge-attack baseline per victim visit count: mean KataGo win margin and its
# +/- band half-width, then win rate (fraction) and its +/- band half-width.
edge_x = np.array([2, 4, 8, 16, 32])
edge_y = np.array([148.75, 125.57, 270.91, 311.96, 311.07])
edge_cis = np.array([32.64096667, 31.42023977, 26.51559246, 14.94472483, 16.49799086])
edge_winrate = np.array([0.51, 0.57, 0.17, 0.01, 0.03])
edge_winrate_cis = np.array([0.0979804, 0.09703485, 0.07362403, 0.01950175, 0.0334351])

In [None]:
# Mirror-attack baseline per victim visit count: mean KataGo win margin and
# its +/- band half-width, then win rate (fraction) and its +/- half-width.
mirror_x = np.array([2, 4, 8, 16, 32])
mirror_y = np.array([207.19, 239.88709677, 256.18041237, 276.08646617, 253.62])
mirror_cis = np.array([25.79337205, 16.40883013, 21.08132841, 12.34433871, 22.5459692])
mirror_winrate = np.array([0.09, 0.05069124, 0.04123711, 0.03007519, 0.04])
mirror_winrate_cis = np.array([0.05609163, 0.02918747, 0.03957035, 0.02052525, 0.038408])

In [None]:
# Spiral-attack baseline per victim visit count: mean KataGo win margin and
# its +/- band half-width, then win rate (fraction) and its +/- half-width.
spiral_x = np.array([2, 4, 8, 16, 32])
spiral_y = np.array([304.79, 291.67, 323.21, 323.89, 323.67])
spiral_cis = np.array([17.63109707, 22.33512761, 11.17424501, 10.9122512, 9.67025367])
spiral_winrate = np.array([0.06, 0.11, 0.0, 0.0, 0.0])
spiral_winrate_cis = np.array([0.04654742, 0.06132639, 0.0, 0.0, 0.0])

In [None]:
fig, ax = plt.subplots(1, 1, constrained_layout=True, figsize=(5.50107, 3), dpi=240)
ax.set_xlabel("Visit count")
ax.set_xscale("log", base=2)
ax.set_ylabel("Average KataGo win margin")
ax.set_xticks(ticks=edge_x, labels=edge_x)

margin_series = [
    (edge_x, edge_y, edge_cis, "Edge attack"),
    (mirror_x, mirror_y, mirror_cis, "Mirror attack"),
    (spiral_x, spiral_y, spiral_cis, "Spiral attack"),
]
# Draw every line before any band; the colour cycle hands out colours in
# draw order.
for xs, ys, _half_width, label in margin_series:
    ax.plot(xs, ys, label=label)
for xs, ys, half_width, _label in margin_series:
    ax.fill_between(xs, ys - half_width, ys + half_width, alpha=0.2)
ax.legend(loc="lower right")
fig.savefig("baseline-attack-win-margins.pgf", backend="pgf")

In [None]:
fig, ax = plt.subplots(1, 1, constrained_layout=True, figsize=(5.50107, 3), dpi=240)
ax.set_xlabel("Visit count")
ax.set_xscale("log", base=2)
# Raw string: "\%" in a normal literal is an invalid escape sequence
# (SyntaxWarning on Python 3.12+). LaTeX needs the backslash for a literal %.
ax.set_ylabel(r"Adversary win \%")

ax.set_xticks(ticks=edge_x, labels=edge_x)
# Win rates are fractions; label the y-axis in percent.
yticks = np.arange(0.0, 0.8, 0.1)
ax.set_yticks(ticks=yticks, labels=[f"{100 * y:.0f}" for y in yticks])

winrate_series = [
    (edge_x, edge_winrate, edge_winrate_cis, "Edge attack"),
    (mirror_x, mirror_winrate, mirror_winrate_cis, "Mirror attack"),
    (spiral_x, spiral_winrate, spiral_winrate_cis, "Spiral attack"),
]
# Draw every line before any band; the colour cycle hands out colours in
# draw order.
for xs, rates, _half_width, label in winrate_series:
    ax.plot(xs, rates, label=label)
for xs, rates, half_width, _label in winrate_series:
    ax.fill_between(xs, rates - half_width, rates + half_width, alpha=0.2)
ax.legend(loc="upper right")
fig.savefig("baseline-attack-winrates.pgf", backend="pgf")

### Figure 5a & 5b

In [None]:
from pathlib import Path
from sgfmill.sgf import Sgf_game


# Load every SGF under the cleaned directory. Figure 5b reads games[0], so
# the order produced by rglob is kept as is.
root = Path("/nas/ucb/norabelrose/latest@1600-dragonslayer-fixed/cleaned")
games = [
    # read_text() closes each file after reading, whereas p.open().read()
    # left the handle open until garbage collection. Both decode with the
    # default text encoding.
    Sgf_game.from_string(p.read_text())
    for p in root.rglob("*.sgf")
]

In [None]:
import numpy as np
import re


def get_stats(games: list, time_to_end: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return (`adv_win_rates`, `victim_win_rates`, `gt_adv_wins`) as a tuple.

    Each move node's "C" comment must start with four numbers: win, loss and
    draw probabilities (treated as white's perspective), then a score. These
    are converted to the adversary's perspective, so BOTH win-rate arrays hold
    the probability that the adversary wins:

    - `adv_win_rates[i, j]`: value recorded on the adversary's move j of game
      i, clipped to [0.01, 0.99]; NaN on the victim's moves.
    - `victim_win_rates[i, j]`: the same quantity on the victim's moves; NaN on
      the adversary's moves.
    - `gt_adv_wins[i]`: 1.0 if the adversary won game i (from the "RE"
      property), else 0.0.

    Games without an "RE" property are skipped; their rows stay all-NaN.

    Args:
        games: `sgfmill` `Sgf_game` objects whose players are named
            "adversary" and "victim".
        time_to_end: If True, align games so they all end at the same time:
            each game is cut to its last `min(game lengths)` moves, so the
            final column is every game's last move. If False, column j is move
            j from the start and the arrays are as wide as the longest game,
            with shorter games NaN-padded at the end.

    NOTE(review): this also rewrites each node's "C" comment in the caller's
    `games` (values rounded to 2 decimals), so a second call on the same games
    parses the rounded values — confirm this side effect is intended.
    """
    game_lengths = np.array([len(g.get_main_sequence()) - 1 for g in games])
    num_moves = min(game_lengths) if time_to_end else max(game_lengths)

    adv_win_rates = np.full((len(games), num_moves), np.nan)
    victim_win_rates = np.full((len(games), num_moves), np.nan)
    gt_adv_wins = np.full(len(games), np.nan)

    # Matches 4 space-separated numbers at the START of the comment (win, loss,
    # draw, then a possibly negative score); only the start is anchored, so any
    # trailing text is ignored.
    probs_regex = re.compile(r'^(\d\.\d+) (\d\.\d+) (\d\.\d+) (-?\d+\.\d+)')
    for i, game in enumerate(games):
        root = game.get_root()
        if not root.has_property("RE"):
            continue

        black, white = root.get('PB').strip('"'), root.get('PW').strip('"')
        assert sorted([black, white]) == ['adversary', 'victim']

        # By default we assume adversary is playing white; flip probs if needed
        adv_color = "b" if black == 'adversary' else "w"
        sequence = game.get_main_sequence()
        sequence.pop(0)  # Remove root node
        if time_to_end:
            sequence = sequence[-num_moves:]

        # "RE" looks like "B+..." / "W+..."; the part before "+" is the winner.
        winner, _ = root.get("RE").split("+")
        adv_won = winner.lower() == adv_color
        gt_adv_wins[i] = adv_won

        # The root was popped above, so j indexes moves (from the aligned start
        # when time_to_end is set).
        for j, node in enumerate(sequence):
            comment = node.get('C').strip()
            assert comment

            maybe_match = probs_regex.match(comment)
            assert maybe_match

            # These are all from white's perspective
            win, loss, draw, score = map(float, maybe_match.groups())
            # Side effect: overwrite the node's comment with rounded values.
            node.set('C', f'{win:.2f} {loss:.2f} {draw:.2f} {score:.2f}')
            if adv_color == "b":
                win, loss = loss, win
                score = -score

            cur_color, _ = node.get_move()
            if cur_color == adv_color:
                # Winrate probabilities in KataGo are logged up to the 2nd decimal place,
                # so we clip them to [0.01, 0.99] to avoid numerical issues in the BCE
                # calculation
                adv_win_rates[i, j] = np.clip(win, 0.01, 0.99)
            else:
                victim_win_rates[i, j] = np.clip(win, 0.01, 0.99)
    
    return adv_win_rates, victim_win_rates, gt_adv_wins


#### Set Matplotlib style params

In [None]:
import matplotlib.pyplot as plt

# Subplot margins for the square figure-5 panels, plus LaTeX typesetting
# (Times, 10pt, pdflatex for PGF export) and the colorblind-friendly palette.
fig5_rc = {
    "figure.subplot.left": 0.2,
    "figure.subplot.right": 0.95,
    "figure.subplot.bottom": 0.2,
    "figure.subplot.top": 0.85,
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.size": 10,
    "font.weight": "medium",
    "pgf.rcfonts": False,
    "pgf.texsystem": "pdflatex",
    "text.usetex": True,
}
plt.rcParams.update(fig5_rc)
plt.style.use("tableau-colorblind10")

In [None]:
def plot(x, label, color):
    """Plot the across-game mean of per-move losses, smoothed, with a +/-1 std band.

    Args:
        x: 2-D array of per-move losses, one row per game (assumed
            (num_games, num_moves) as built by the fig-5a cell — confirm).
            NaN entries are ignored. The last column is assumed to be every
            game's final move, as from ``get_stats(..., time_to_end=True)``.
        label: Legend label for the line.
        color: Matplotlib colour for the line and its band.
    """
    mean_victim_loss = np.nanmean(x, axis=0)
    victim_loss_std = np.nanstd(x, axis=0)

    # 10-move moving average; "valid" mode keeps only full windows, so
    # smoothed index p averages moves p..p+9.
    window = 10
    kernel = np.ones(window) / window
    smoothed_victim_bce_loss = np.convolve(mean_victim_loss, kernel, mode="valid")
    smoothed_victim_bce_loss_std = np.convolve(victim_loss_std, kernel, mode="valid")

    plt.plot(smoothed_victim_bce_loss, label=label, c=color)
    plt.xlabel("Moves until game end")
    plt.ylabel("BCE loss (nats)")

    plt.fill_between(
        np.arange(len(smoothed_victim_bce_loss)),
        np.clip(smoothed_victim_bce_loss - smoothed_victim_bce_loss_std, a_min=0, a_max=None),
        np.clip(smoothed_victim_bce_loss + smoothed_victim_bce_loss_std, a_min=0, a_max=None),
        alpha=0.2,
        color=color,
    )
    # Place ticks at round "moves until end" values counted back from the
    # last point (whose window ends on the final move). Reversing the labels
    # of ticks placed from the start was only right when the length was
    # 1 mod 40; otherwise every label was off by up to 39 moves.
    num_points = len(smoothed_victim_bce_loss)
    moves_until_end = np.arange(0, num_points, 40)
    plt.xticks(num_points - 1 - moves_until_end, moves_until_end)

fig = plt.figure(figsize=(2.6448 + 0.15, 2.6448), dpi=150)

adv_win_rates, victim_win_rates, gt_adv_wins = get_stats(games, time_to_end=True)
# Per-move binary cross-entropy between the adversary-win probability seen on
# the victim's moves and the actual game outcome.
outcome = gt_adv_wins[:, None]
raw_victim_loss = -(
    outcome * np.log(victim_win_rates) + (1 - outcome) * np.log(1 - victim_win_rates)
)
# NOTE(review): games without a result have NaN in gt_adv_wins, which
# astype(bool) maps to True (counted as adversary wins) — confirm none exist.
adv_won_mask = gt_adv_wins.astype(bool)
plot(
    raw_victim_loss[adv_won_mask],
    label="Adv. win",
    color=(0.5529411764705883, 0.8980392156862745, 0.6313725490196078),
)
plot(
    raw_victim_loss[~adv_won_mask],
    label="Victim win",
    color=(0.8156862745098039, 0.7333333333333333, 1.0),
)
plt.legend(loc="lower center", bbox_to_anchor=(0.44, 1.0), ncol=2, fancybox=True)

plt.savefig("fig-5a.pgf", backend="pgf")

In [None]:
# Subplot margins for this square panel.
plt.rcParams.update(
    {
        "figure.subplot.left": 0.2,
        "figure.subplot.right": 0.95,
        "figure.subplot.bottom": 0.2,
        "figure.subplot.top": 0.85,
    }
)

adv_win_rates, victim_win_rates, gt_adv_wins = get_stats(games)
# First 220 moves of the first game: adversary values at even move indices,
# victim values at odd ones (assumes the adversary moves first in games[0]).
first_game = 0
adv_trunc = adv_win_rates[first_game, 0:220:2]
victim_trunc = victim_win_rates[first_game, 1:220:2]

plt.figure(figsize=(2.6448 + 0.15, 2.6448), dpi=150)
plt.plot(adv_trunc, label="Adv.")
plt.plot(victim_trunc, label="Victim")
# Point k of each series is move 2k (adv) / 2k+1 (victim); label ticks with
# the adversary's move numbers and zoom into points 80-110.
tick_positions = np.arange(80, 110, 5)
plt.xticks(tick_positions, tick_positions * 2)
plt.xlim(80, 110)

plt.legend(loc="lower center", bbox_to_anchor=(0.5, 1.0), ncol=2, fancybox=True)
plt.xlabel("Move number")
plt.ylabel("Probability of adversary win")

plt.savefig("fig-5b.pgf", backend="pgf")