Skip to content

Commit

Permalink
Big cleaning (#16)
Browse files Browse the repository at this point in the history
* big refacto

* preprocess module

* preprocess game to boards

* play & refacto config

* semantic release

* removed sr

* self play

* logging

* policy sampler

* Update tests/conftest.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Update tests/conftest.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Update tests/conftest.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* batched play

* lint

* next version

* next version

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
  • Loading branch information
Xmaster6y and sourcery-ai[bot] committed May 2, 2024
1 parent d6c5011 commit f2fe879
Show file tree
Hide file tree
Showing 79 changed files with 871 additions and 1,511 deletions.
28 changes: 6 additions & 22 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
repos:
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
args: ["--config", "pyproject.toml"]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
Expand All @@ -19,20 +14,9 @@ repos:
rev: 1.7.0
hooks:
- id: poetry-check
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
- id: mypy
additional_dependencies: ['types-requests', 'types-toml']
exclude: scripts
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
args: ['--ignore=E203,W503', '--per-file-ignores=__init__.py:F401']
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
args: ["--settings-path", "pyproject.toml"]
name: isort (python)
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.2
hooks:
- id: ruff
args: [ --fix ]
- id: ruff-format
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2023 Yoann Poupart
Copyright (c) 2024 Yoann Poupart

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
23 changes: 0 additions & 23 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# CI
.PHONY: checks
checks:
poetry run pre-commit run --all-files
Expand All @@ -19,28 +18,6 @@ tests:
docs:
cd docs && poetry run make html

# API
.PHONY: demo
demo:
poetry run python -m demo.main

# Docker
.PHONY: docker-build
docker-build:
docker compose -f docker/docker-compose.yml build

.PHONY: docker-start
docker-start:
docker compose -f docker/docker-compose.yml up

.PHONY: docker-start-bg
docker-start-bg:
docker compose -f docker/docker-compose.yml up -d --build

.PHONY: docker-stop
docker-stop:
docker compose -f docker/docker-compose.yml down

.PHONY: docker-tty
docker-tty:
docker compose -f docker/docker-compose.yml exec fastapi bash
5 changes: 0 additions & 5 deletions apptainer/.gitignore

This file was deleted.

19 changes: 0 additions & 19 deletions apptainer/base.def

This file was deleted.

29 changes: 0 additions & 29 deletions apptainer/make-datasets.sh

This file was deleted.

9 changes: 0 additions & 9 deletions apptainer/script.def

This file was deleted.

24 changes: 6 additions & 18 deletions demo/attention_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def make_plot(
state_boards,
state_cache,
):

if state_cache == []:
gr.Warning("No cache available.")
return None, None, None
Expand All @@ -107,8 +106,7 @@ def make_plot(
num_attention_layers = len(state_cache[state_board_index])
if attention_layer > num_attention_layers:
gr.Warning(
f"Attention layer {attention_layer} does not exist, "
f"using layer {num_attention_layers} instead."
f"Attention layer {attention_layer} does not exist, " f"using layer {num_attention_layers} instead."
)
attention_layer = num_attention_layers

Expand All @@ -120,8 +118,7 @@ def make_plot(
return None, None, None
if attention_head > attention_tensor.shape[1]:
gr.Warning(
f"Attention head {attention_head} does not exist, "
f"using head {attention_tensor.shape[1]+1} instead."
f"Attention head {attention_head} does not exist, " f"using head {attention_tensor.shape[1]+1} instead."
)
attention_head = attention_tensor.shape[1]
try:
Expand All @@ -136,9 +133,7 @@ def make_plot(
heatmap = attention_tensor[0, attention_head - 1, square_index]
if board.turn == chess.BLACK:
heatmap = heatmap.view(8, 8).flip(0).view(64)
svg_board, fig = visualisation.render_heatmap(
board, heatmap, square=square
)
svg_board, fig = visualisation.render_heatmap(board, heatmap, square=square)
with open(f"{constants.FIGURE_DIRECTORY}/attention.svg", "w") as f:
f.write(svg_board)
return f"{constants.FIGURE_DIRECTORY}/attention.svg", board.fen(), fig
Expand Down Expand Up @@ -206,9 +201,7 @@ def next_board(
)
with gr.Column(scale=1):
with gr.Row():
model_name = gr.Textbox(
label="Selected model", lines=1, interactive=False, scale=7
)
model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)

model_df.select(
on_select_model_df,
Expand All @@ -228,10 +221,7 @@ def next_board(
label="Action sequence",
lines=1,
max_lines=1,
value=(
"e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 "
"d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"
),
value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
)
compute_cache_button = gr.Button("Compute cache")

Expand Down Expand Up @@ -295,9 +285,7 @@ def next_board(
inputs=base_inputs,
outputs=outputs + [state_board_index],
)
next_board_button.click(
next_board, inputs=base_inputs, outputs=outputs + [state_board_index]
)
next_board_button.click(next_board, inputs=base_inputs, outputs=outputs + [state_board_index])

attention_layer.change(make_plot, inputs=base_inputs, outputs=outputs)
attention_head.change(make_plot, inputs=base_inputs, outputs=outputs)
Expand Down
41 changes: 12 additions & 29 deletions demo/backend_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from lczero.backends import Backend, GameState, Weights

from demo import constants, utils, visualisation
from lczerolens import move_utils
from lczerolens.utils import lczero as lczero_utils
from lczerolens import move_encodings
from lczerolens.model import lczero as lczero_utils
from lczerolens.xai import PolicyLens


Expand Down Expand Up @@ -74,9 +74,7 @@ def make_policy_plot(
only_legal=only_legal,
illegal_value=0,
)
pickup_agg, dropoff_agg = PolicyLens.aggregate_policy(
policy, int(aggregate_topk)
)
pickup_agg, dropoff_agg = PolicyLens.aggregate_policy(policy, int(aggregate_topk))

if view == "from":
if board.turn == chess.WHITE:
Expand All @@ -90,9 +88,7 @@ def make_policy_plot(
heatmap = dropoff_agg.view(8, 8).flip(0).view(64)
us_them = (board.turn, not board.turn)
if only_legal:
legal_moves = [
move_utils.encode_move(move, us_them) for move in board.legal_moves
]
legal_moves = [move_encodings.encode_move(move, us_them) for move in board.legal_moves]
filtered_policy = torch.zeros(1858)
filtered_policy[legal_moves] = policy[legal_moves]
if (filtered_policy < 0).any():
Expand All @@ -102,11 +98,9 @@ def make_policy_plot(
topk_moves = torch.topk(policy, render_bestk)
arrows = []
for move_index in topk_moves.indices:
move = move_utils.decode_move(move_index, us_them)
move = move_encodings.decode_move(move_index, us_them)
arrows.append((move.from_square, move.to_square))
svg_board, fig = visualisation.render_heatmap(
board, heatmap, arrows=arrows
)
svg_board, fig = visualisation.render_heatmap(board, heatmap, arrows=arrows)
with open(f"{constants.FIGURE_DIRECTORY}/policy.svg", "w") as f:
f.write(svg_board)
raw_policy, _ = lczero_utils.prediction_from_backend(
Expand All @@ -118,7 +112,7 @@ def make_policy_plot(
)
fig_dist = visualisation.render_policy_distribution(
raw_policy,
[move_utils.encode_move(move, us_them) for move in board.legal_moves],
[move_encodings.encode_move(move, us_them) for move in board.legal_moves],
)
return (
f"{constants.FIGURE_DIRECTORY}/policy.svg",
Expand All @@ -140,9 +134,7 @@ def make_policy_plot(
)
with gr.Column(scale=1):
with gr.Row():
model_name = gr.Textbox(
label="Selected model", lines=1, interactive=False, scale=7
)
model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)

model_df.select(
on_select_model_df,
Expand All @@ -161,10 +153,7 @@ def make_policy_plot(
label="Action sequence",
lines=1,
max_lines=1,
value=(
"e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 "
"d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"
),
value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
)
with gr.Group():
with gr.Row():
Expand Down Expand Up @@ -194,15 +183,11 @@ def make_policy_plot(
value=5,
scale=3,
)
only_legal = gr.Checkbox(
label="Only legal", value=True, scale=1
)
only_legal = gr.Checkbox(label="Only legal", value=True, scale=1)

policy_button = gr.Button("Plot policy")
colorbar = gr.Plot(label="Colorbar")
game_info = gr.Textbox(
label="Game info", lines=1, max_lines=1, value=""
)
game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="")
with gr.Column():
image = gr.Image(label="Board")
density_plot = gr.Plot(label="Density")
Expand All @@ -219,6 +204,4 @@ def make_policy_plot(
only_legal,
]
policy_outputs = [image, colorbar, game_info, density_plot]
policy_button.click(
make_policy_plot, inputs=policy_inputs, outputs=policy_outputs
)
policy_button.click(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)
6 changes: 2 additions & 4 deletions demo/convert_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import gradio as gr

from demo import constants, utils
from lczerolens.utils import lczero as lczero_utils
from lczerolens.model import lczero as lczero_utils


def list_models():
Expand Down Expand Up @@ -147,9 +147,7 @@ def get_model_path(
)
with gr.Column(scale=1):
with gr.Row():
model_name = gr.Textbox(
label="Selected model", lines=1, interactive=False, scale=7
)
model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
conversion_status = gr.Textbox(
label="Conversion status",
lines=1,
Expand Down
Loading

0 comments on commit f2fe879

Please sign in to comment.