From 9c25848fb5cb18531c639a462d0a8314bd17e20f Mon Sep 17 00:00:00 2001 From: Xmaster6y <66315201+Xmaster6y@users.noreply.github.com> Date: Wed, 15 May 2024 10:36:26 +0200 Subject: [PATCH] new script --- .vscode/launch.json | 47 +------- scripts/lrp/__init__.py | 0 scripts/lrp/plane_analysis.py | 217 ++++++++++++++++++++++++++++++++++ scripts/results/.gitignore | 2 + scripts/visualisation.py | 200 +++++++++++++++++++++++++++++++ 5 files changed, 421 insertions(+), 45 deletions(-) create mode 100644 scripts/lrp/__init__.py create mode 100644 scripts/lrp/plane_analysis.py create mode 100644 scripts/results/.gitignore create mode 100644 scripts/visualisation.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 22e8014..f0eaf15 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -19,53 +19,10 @@ "justMyCode": false }, { - "name": "Script CRP concepts", + "name": "Script Plane Analysis", "type": "debugpy", "request": "launch", - "module": "scripts.find_concepts", - "console": "integratedTerminal", - "justMyCode": false - }, - { - "name": "Script CRP clusters", - "type": "debugpy", - "request": "launch", - "module": "scripts.cluster_latent_relevances", - "console": "integratedTerminal", - "justMyCode": false - }, - { - "name": "Script make datasets", - "type": "debugpy", - "request": "launch", - "module": "scripts.make_datasets", - "console": "integratedTerminal", - "justMyCode": false - } - , - { - "name": "Script sample exploration", - "type": "debugpy", - "request": "launch", - "module": "scripts.sample_exploration", - "console": "integratedTerminal", - "justMyCode": false - } - , - { - "name": "Script simple sae", - "type": "debugpy", - "request": "launch", - "module": "scripts.simple_sae", - "console": "integratedTerminal", - "justMyCode": false - } - , - { - "name": "Debug", - "type": "debugpy", - "request": "launch", - "module": "ignored.debug", + "module": "scripts.lrp.plane_analysis", "console": "integratedTerminal", "justMyCode": false } diff --git a/scripts/lrp/__init__.py b/scripts/lrp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/lrp/plane_analysis.py b/scripts/lrp/plane_analysis.py new file mode 100644 index 0000000..5479a59 --- /dev/null +++ b/scripts/lrp/plane_analysis.py @@ -0,0 +1,217 @@ +"""Script to compute the importance of each plane for the model. + +Run with: +``` +poetry run python -m scripts.lrp.plane_analysis +``` +""" + +import argparse +from loguru import logger + +from datasets import Dataset +from torch.utils.data import DataLoader +import torch + +from lczerolens import Lens +from lczerolens.encodings import move as move_encoding +from lczerolens.xai import MulticlassConcept +from lczerolens.model import ForceValueFlow, PolicyFlow +from lczerolens.xai import concept +from scripts import visualisation + + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def main(args): + dataset = Dataset.from_json( + "./assets/TCEC_game_collection_random_boards_bestlegal.jsonl", features=MulticlassConcept.features + ) + logger.info(f"Loaded dataset with {len(dataset)} boards.") + if args.target == "policy": + wrapper = PolicyFlow.from_path(f"./assets/{args.model_name}").to(DEVICE) + init_rel_fn = concept.concept_init_rel + + elif args.target == "value": + wrapper = ForceValueFlow.from_path(f"./assets/{args.model_name}").to(DEVICE) + init_rel_fn = None + else: + raise ValueError(f"Target '{args.target}' not supported.") + lens = Lens.from_name("lrp") + if not lens.is_compatible(wrapper): + raise ValueError(f"Lens of type 'lrp' not compatible with model '{args.model_name}'.") + + dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=concept.concept_collate_fn) + + iter_analyse = lens.analyse_batched_boards( + dataloader, + wrapper, + target=None, + return_output=True, + init_rel_fn=init_rel_fn, + ) + all_stats = { + "relative_piece_relevance": [], + "absolute_piece_relevance": [], + "plane_relevance_proportion": [], + "relative_piece_relevance_proportion": [], + "absolute_piece_relevance_proportion": [], + } + n_plotted = 0 + for batch in iter_analyse: + batched_relevances, boards, *infos = batch + relevances, outputs = batched_relevances + labels = infos[0] + for rel, out, board, label in zip(relevances, outputs, boards, labels): + max_config_rel = rel[:12].abs().max().item() + if max_config_rel == 0: + continue + if n_plotted < args.plot_first_n: + if board.turn: + heatmap = rel.sum(dim=0).view(64) + else: + heatmap = rel.sum(dim=0).flip(0).view(64) + if args.target == "policy": + move = move_encoding.decode_move(label, (board.turn, not board.turn), board) + else: + move = None + visualisation.render_heatmap( + board, + heatmap, + arrows=[(move.from_square, move.to_square)] if move is not None else None, + normalise="abs", + save_to=f"./scripts/results/{args.target}_heatmap_{n_plotted}.png", + ) + n_plotted += 1 + + plane_order = "PNBRQKpnbrqk" + piece_relevance = {} + for i, letter in enumerate(plane_order): + num = (rel[i] != 0).sum().item() + if num == 0: + piece_relevance[letter] = 0 + else: + piece_relevance[letter] = rel[i].sum().item() / num + + if piece_relevance["q"] / max_config_rel > 0.9 and args.target == "value": + if board.turn: + heatmap = rel.sum(dim=0).view(64) + else: + heatmap = rel.sum(dim=0).flip(0).view(64) + if args.target == "policy": + move = move_encoding.decode_move(label, (board.turn, not board.turn), board) + else: + move = None + visualisation.render_heatmap( + board, + heatmap, + arrows=[(move.from_square, move.to_square)] if move is not None else None, + normalise="abs", + save_to=f"./scripts/results/{args.target}_heatmap_{n_plotted}.png", + ) + raise SystemExit + + if any([piece_relevance[k] / max_config_rel > 0.9 for k in "pnbrqk"]) and args.target == "policy": + if board.turn: + heatmap = rel.sum(dim=0).view(64) + else: + heatmap = rel.sum(dim=0).flip(0).view(64) + if args.target == "policy": + move = move_encoding.decode_move(label, (board.turn, not board.turn), board) + else: + move = None + visualisation.render_heatmap( + board, + heatmap, + arrows=[(move.from_square, move.to_square)] if move is not None else None, + normalise="abs", + save_to=f"./scripts/results/{args.target}_heatmap_{n_plotted}.png", + ) + raise SystemExit + + all_stats["absolute_piece_relevance"].append(piece_relevance) + all_stats["relative_piece_relevance"].append({k: v / max_config_rel for k, v in piece_relevance.items()}) + + total_relevance = rel.abs().sum().item() + clock = board.fullmove_number * 2 - (not board.turn) + proportion = rel.abs().sum(dim=(1, 2)).div(total_relevance).tolist() + all_stats["plane_relevance_proportion"].append({clock: proportion}) + all_stats["relative_piece_relevance_proportion"].append( + {clock: [v / max_config_rel for v in piece_relevance.values()]} + ) + all_stats["absolute_piece_relevance_proportion"].append({clock: proportion[:12]}) + + logger.info(f"Processed {len(all_stats['relative_piece_relevance'])} boards.") + + visualisation.render_boxplot( + all_stats["relative_piece_relevance"], + y_label="Relevance", + title="Relative Relevance", + save_to=f"./scripts/results/{args.target}_piece_relative_relevance.png", + ) + visualisation.render_boxplot( + all_stats["absolute_piece_relevance"], + y_label="Relevance", + title="Absolute Relevance", + save_to=f"./scripts/results/{args.target}_piece_absolute_relevance.png", + ) + visualisation.render_proportion_through_index( + all_stats["plane_relevance_proportion"], + plane_type="Pieces", + y_label="Proportion of relevance", + y_log=True, + max_index=200, + title="Proportion of relevance per piece", + save_to=f"./scripts/results/{args.target}_plane_config_relevance.png", + ) + visualisation.render_proportion_through_index( + all_stats["plane_relevance_proportion"], + plane_type="H0", + y_label="Proportion of relevance", + y_log=True, + max_index=200, + title="Proportion of relevance per plane", + save_to=f"./scripts/results/{args.target}_plane_H0_relevance.png", + ) + visualisation.render_proportion_through_index( + all_stats["plane_relevance_proportion"], + plane_type="Hist", + y_label="Proportion of relevance", + y_log=True, + max_index=200, + title="Proportion of relevance per plane", + save_to=f"./scripts/results/{args.target}_plane_hist_relevance.png", + ) + visualisation.render_proportion_through_index( + all_stats["relative_piece_relevance_proportion"], + plane_type="Pieces", + y_label="Proportion of relevance", + y_log=False, + max_index=200, + title="Proportion of relevance per piece", + save_to=f"./scripts/results/{args.target}_piece_plane_relative_relevance.png", + ) + visualisation.render_proportion_through_index( + all_stats["absolute_piece_relevance_proportion"], + plane_type="Pieces", + y_label="Proportion of relevance", + y_log=False, + max_index=200, + title="Proportion of relevance per piece", + save_to=f"./scripts/results/{args.target}_piece_plane_absolute_relevance.png", + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser("plane-importance") + parser.add_argument("--model_name", type=str, default="64x6-2018_0627_1913_08_161.onnx") + parser.add_argument("--target", type=str, default="value") + parser.add_argument("--batch_size", type=int, default=100) + parser.add_argument("--plot_first_n", type=int, default=5) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/results/.gitignore b/scripts/results/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/scripts/results/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/scripts/visualisation.py b/scripts/visualisation.py new file mode 100644 index 0000000..27ba14b --- /dev/null +++ b/scripts/visualisation.py @@ -0,0 +1,200 @@ +""" +Visualisation utils. +""" + +import chess +import chess.svg +import matplotlib +import matplotlib.pyplot as plt +import numpy as np + +from lczerolens.encodings import board as board_encoding + +COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000) +ALPHA = 1.0 +NORM = matplotlib.colors.Normalize(vmin=0, vmax=1, clip=False) + + +def render_heatmap( + board, + heatmap, + square=None, + vmin=None, + vmax=None, + arrows=None, + normalise="none", + save_to=None, +): + """ + Render a heatmap on the board. + """ + if normalise == "abs": + a_max = heatmap.abs().max() + if a_max != 0: + heatmap = heatmap / a_max + vmin = -1 + vmax = 1 + if vmin is None: + vmin = heatmap.min() + if vmax is None: + vmax = heatmap.max() + norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False) + + color_dict = {} + for square_index in range(64): + color = COLOR_MAP(norm(heatmap[square_index])) + color = (*color[:3], ALPHA) + color_dict[square_index] = matplotlib.colors.to_hex(color, keep_alpha=True) + fig = plt.figure(figsize=(1, 6)) + ax = plt.gca() + ax.axis("off") + fig.colorbar( + matplotlib.cm.ScalarMappable(norm=norm, cmap=COLOR_MAP), + ax=ax, + orientation="vertical", + fraction=1.0, + ) + if square is not None: + try: + check = chess.parse_square(square) + except ValueError: + check = None + else: + check = None + if arrows is None: + arrows = [] + + svg_board = chess.svg.board( + board, + check=check, + fill=color_dict, + size=350, + arrows=arrows, + ) + if save_to is not None: + plt.savefig(save_to) + with open(save_to.replace(".png", ".svg"), "w") as f: + f.write(svg_board) + plt.close() + else: + plt.close() + return svg_board, fig + + +def render_boxplot( + data, + filter_null=True, + y_label=None, + title=None, + save_to=None, +): + labels = data[0].keys() + boxed_data = {label: [] for label in labels} + for d in data: + for label in labels: + v = d.get(label) + if v == 0.0 and filter_null: + continue + boxed_data[label].append(v) + plt.boxplot(boxed_data.values(), notch=True, vert=True, patch_artist=True, labels=labels) + plt.ylabel(y_label) + plt.title(title) + if save_to is not None: + plt.savefig(save_to) + plt.close() + else: + plt.show() + + +def render_proportion_through_index( + data, + plane_type="H0", + max_index=None, + y_log=False, + y_label=None, + title=None, + save_to=None, +): + if plane_type == "H0": + indexed_data = { + "H0": {}, + "Hist": {}, + "Meta": {}, + } + for d in data: + index, proportion = next(iter(d.items())) + if max_index is not None and index > max_index: + continue + if index not in indexed_data["H0"]: + indexed_data["H0"][index] = [sum(proportion[:13])] + indexed_data["Hist"][index] = [sum(proportion[13:104])] + indexed_data["Meta"][index] = [sum(proportion[104:])] + else: + indexed_data["H0"][index].append(sum(proportion[:13])) + indexed_data["Hist"][index].append(sum(proportion[13:104])) + indexed_data["Meta"][index].append(sum(proportion[104:])) + + elif plane_type == "Hist": + indexed_data = { + "H0": {}, + "H1": {}, + "H2": {}, + "H3": {}, + "H4": {}, + "H5": {}, + "H6": {}, + "H7": {}, + "Castling": {}, + "Remaining": {}, + } + for d in data: + index, proportion = next(iter(d.items())) + if max_index is not None and index > max_index: + continue + if index not in indexed_data["H0"]: + for i in range(8): + indexed_data[f"H{i}"][index] = [sum(proportion[13 * i : 13 * (i + 1)])] + indexed_data["Castling"][index] = [sum(proportion[104:108])] + indexed_data["Remaining"][index] = [sum(proportion[108:])] + else: + for i in range(8): + indexed_data[f"H{i}"][index].append(sum(proportion[13 * i : 13 * (i + 1)])) + indexed_data["Castling"][index].append(sum(proportion[104:108])) + indexed_data["Remaining"][index].append(sum(proportion[108:])) + + elif plane_type == "Pieces": + relative_plane_order = board_encoding.get_plane_order((chess.WHITE, chess.BLACK)) + indexed_data = {letter: {} for letter in relative_plane_order} + for d in data: + index, proportion = next(iter(d.items())) + if max_index is not None and index > max_index: + continue + if index not in indexed_data[relative_plane_order[0]]: + for i, letter in enumerate(relative_plane_order): + indexed_data[letter][index] = [proportion[i]] + else: + for i, letter in enumerate(relative_plane_order): + indexed_data[letter][index].append(proportion[i]) + else: + raise ValueError(f"Invalid plane type: {plane_type}") + + n_curves = len(indexed_data) + for i, (label, curve_data) in enumerate(indexed_data.items()): + indices = sorted(list(curve_data.keys())) + mean_curve = [np.mean(curve_data[idx]) for idx in indices] + std_curve = [np.std(curve_data[idx]) for idx in indices] + c = COLOR_MAP(i / (n_curves - 1)) + plt.plot(indices, mean_curve, label=label, c=c) + lower_bound = np.array(mean_curve) - np.array(std_curve) + upper_bound = np.array(mean_curve) + np.array(std_curve) + plt.fill_between(indices, lower_bound, upper_bound, alpha=0.2, color=c) + if y_log: + plt.yscale("log") + plt.legend() + plt.ylabel(y_label) + plt.title(title) + if save_to is not None: + plt.savefig(save_to) + plt.close() + else: + plt.show()