# Evidence of Learned Look-Ahead

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/tutorials/evidence-of-learned-look-ahead.ipynb)

# Setup

In [None]:
import importlib.util

DEV = True

if importlib.util.find_spec("google.colab") is not None:
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"

In [None]:
if MODE == "colab":
    %pip install -q lczerolens
elif MODE == "colab-dev":
    !rm -r lczerolens
    !git clone https://github.com/Xmaster6y/lczerolens -b main
    %pip install -q ./lczerolens

In [None]:
!wget https://figshare.com/ndownloader/files/46473526?private_link=adc80845c00b67c8fce5 -O interesting_puzzles.pkl
!wget https://figshare.com/ndownloader/files/46473529?private_link=adc80845c00b67c8fce5 -O lc0.onnx

# When `wget` fail, e.g., "403 Forbidden"
# %pip install gdown
# !gdown https://drive.google.com/uc?id=1GT6I7FAgxWIxA-tzsifBQx0MkKZcR_qz -O interesting_puzzles.pkl
# !gdown https://drive.google.com/uc?id=1PB097ZKd_zTaPHxLK29WKUWmv6KcZ15T -O lc0.onnx

# Checking Assets

In [None]:
import pickle
import chess

In [None]:
# Load the downloaded puzzle table and preview its first rows.
# NOTE(security): pickle.load can execute arbitrary code; only load this file
# from the trusted download link above, never from an unknown source.
with open("interesting_puzzles.pkl", mode="rb") as puzzles_file:
    puzzles = pickle.load(puzzles_file)

puzzles.head()

In [None]:
# Load the network from the `lc0.onnx` file downloaded above and show its summary.
from lczerolens import LczeroModel

model = LczeroModel.from_path("lc0.onnx")
model

In [None]:
import IPython

from lczerolens.board import LczeroBoard

# Pick one puzzle by its index in the table and build the clean and corrupted positions.
puzzle = puzzles.loc[19612]
moves = puzzle.Moves.split()

board = LczeroBoard(puzzle.FEN)
# Apply the first move of the solution line before analysing the position
# (presumably the opponent's setup move, as in the Lichess puzzle format — TODO confirm).
board.push_uci(moves[0])

# The corrupted position comes from its own stored FEN; no move is applied to it here.
corrupted_board = LczeroBoard(puzzle.corrupted_fen)

display(board, corrupted_board)

In [None]:
# Run the clean and the corrupted positions through the model in a single call
# and compare their "wdl" outputs (presumably win/draw/loss — TODO confirm).
out = model(board, corrupted_board)
out["wdl"]

## Visualising Attention

In [None]:
# Encoder layer and attention head whose softmax-ed attention pattern we inspect.
layer = 9
head = 5
softmax_module_name = f"encoder{layer}/mha/QK/softmax"

# Record the attention pattern while tracing a forward pass on `board`.
# Indexing [0, head] keeps the first batch element and the chosen head
# (assumes the output is laid out as (batch, head, query, key) — TODO confirm).
with model.trace(board):
    softmax_module = getattr(model, softmax_module_name)
    attention = softmax_module.output[0, head].save()

attention.shape

In [None]:
# Render one row of the attention pattern on the board as an SVG heatmap.
square = chess.F4

# Row `square` of the pattern; presumably the attention from this query square
# to every key square (TODO confirm, including board orientation when black is to move).
attention_row = attention[square].detach()
boardsvg, _ = board.render_heatmap(attention_row)

display(IPython.display.HTML(boardsvg))

## Probing Analysis

In [None]:
# TODO: complete this analysis

## Activation Patching

In [None]:
from lczerolens.lenses import ActivationLens  # TODO: replace with tdhook utils

# Module whose activations are cached for both positions (clean first, then corrupted).
MODULE = "encoder13/ln2"
act_lens = ActivationLens(MODULE)

clean_acts, corrupted_acts = (
    act_lens.analyse(model, position) for position in (board, corrupted_board)
)

In [None]:
# Shape of the cached output of MODULE for the corrupted position.
corrupted_acts[MODULE + "_output"].shape

In [None]:
with model.trace(board):
    out = model.output.save()