# Train SAEs

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/train_saes.ipynb)

## Setup

In [1]:
# Where the notebook runs:
#   "colab"     -> pip-install the released lczerolens package
#   "colab-dev" -> clone the main branch from GitHub and install it
#   "local"     -> use the current environment as-is (nothing is installed)
MODE = "local"  # "colab" | "colab-dev" | "local"

In [2]:
if MODE == "colab":
    !pip install -q lczerolens
elif MODE == "colab-dev":
    !rm -r lczerolens
    !git clone https://github.com/Xmaster6y/lczerolens -b main
    !pip install -q ./lczerolens

In [3]:
!gdown 1cxC8_8vw7akfPyc9cZxwaAbLG2Zl4XiT -O lc0-10-4238.onnx

In [4]:
import torch

# Fail fast: the model is moved to "cuda" in the next section, so stop here with a
# clear message when no CUDA device is available.
if not torch.cuda.is_available():
    raise RuntimeError("This notebook requires a GPU")

## Load the Model and the Dataset

In [5]:
from datasets import load_dataset
from lczerolens import LczeroModel

# Load the ONNX network and move it to the GPU.
model = LczeroModel.from_path("lc0-10-4238.onnx").to("cuda")

# Board-position dataset with train/test splits; the last expression displays its summary.
dataset = load_dataset("lczero-planning/boards")
dataset

Downloading readme:   0%|          | 0.00/475 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/167M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/41.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2231423 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/557856 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['gameid', 'moves', 'fen'],
        num_rows: 2231423
    })
    test: Dataset({
        features: ['gameid', 'moves', 'fen'],
        num_rows: 557856
    })
})

In [6]:
# Show the model's module tree. Module names such as "block0/conv2/relu" follow the
# same naming scheme as the MODULE_NAME later given to the activation lens.
model

GraphModule(
  (inputconv): Conv2d(112, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (inputconv/relu): ReLU()
  (block0/conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block0/conv1/relu): ReLU()
  (block0/conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block0/conv2/se/pooled): OnnxGlobalAveragePoolWithKnownInputShape()
  (initializers): Module()
  (block0/conv2/se/squeeze): OnnxSqueezeDynamicAxes()
  (block0/conv2/se/matmul1): OnnxMatMul()
  (block0/conv2/se/add1): OnnxBinaryMathOperation()
  (block0/conv2/se/relu): ReLU()
  (block0/conv2/se/matmul2): OnnxMatMul()
  (block0/conv2/se/add2): OnnxBinaryMathOperation()
  (block0/conv2/se/reshape): OnnxReshape()
  (block0/conv2/se/split): OnnxSplit13()
  (block0/conv2/se/sigmoid): Sigmoid()
  (block0/conv2/se/mul): OnnxBinaryMathOperation()
  (block0/conv2/se/add3): OnnxBinaryMathOperation()
  (block0/conv2/mixin): OnnxBinaryMathOperation()
  (block0/conv2/relu): ReL

## Set Up the Activation Buffers

In [13]:
import chess
import einops

def collate_fn(batch):
    """Rebuild one ``chess.Board`` per dataset row.

    Each row supplies a starting position under ``"fen"`` and a sequence of UCI
    move strings under ``"moves"``. The moves are replayed, in order, on top of
    that position. ``Board.push`` does not check move legality.

    Passed through ``dataloader_kwargs`` so that each batch becomes a plain list
    of boards.
    """

    def _replay(row):
        # Read both fields before building the board (same order as before).
        fen, moves = row["fen"], row["moves"]
        board = chess.Board(fen)
        for uci in moves:
            board.push(chess.Move.from_uci(uci))
        return board

    return [_replay(row) for row in batch]

def _compute_fn(batch, model, lens):
    boards = batch
    storage = lens.analyse(*boards, model=model)[0]
    if len(storage.keys()) != 1:
        raise NotImplementedError
    acts = next(iter(storage.values()))
    return einops.rearrange(acts, "b c h w -> (b h w) c")


In [17]:
from lczerolens.lenses import ActivationLens, ActivationBuffer

# Module whose activations are collected, and the lens that records them.
MODULE_NAME = "block9/conv2/relu"
LENS = ActivationLens(MODULE_NAME)

# Activation-buffer sizing. _compute_fn emits one row per spatial position (h * w)
# of each board. The stored shape printed below (960_000 rows) equals
# 15 * 1_000 * 64, i.e. 64 positions per board.
N_BATCHES_IN_BUFFER = 15
COMPUTE_BATCH_SIZE = 1_000   # presumably boards per compute batch
TRAIN_BATCH_SIZE = 10_000    # activation rows yielded per buffer iteration

def compute_fn(batch, model):
    """Two-argument adapter around ``_compute_fn`` that always uses ``LENS``.

    Handed to ``ActivationBuffer`` below, which presumably calls it as
    ``compute_fn(batch, model)``. ``LENS`` is looked up at call time.
    """
    return _compute_fn(batch=batch, model=model, lens=LENS)

In [18]:
def make_activation_buffer(split):
    """Build an ``ActivationBuffer`` that streams ``compute_fn`` activations over ``split``.

    The same model, compute function, sizing constants and ``collate_fn`` are used for
    every split, so train and validation buffers cannot drift apart.
    A fresh ``dataloader_kwargs`` dict is created per call.
    """
    return ActivationBuffer(
        model,
        split,
        compute_fn,
        N_BATCHES_IN_BUFFER,
        COMPUTE_BATCH_SIZE,
        TRAIN_BATCH_SIZE,
        dataloader_kwargs={"collate_fn": collate_fn},
    )


train_buffer = make_activation_buffer(dataset["train"])
val_buffer = make_activation_buffer(dataset["test"])

In [19]:
acts = next(iter(train_buffer))
# Each buffer iteration yields one training batch of TRAIN_BATCH_SIZE activation rows.
assert acts.shape[0] == TRAIN_BATCH_SIZE, acts.shape
print("Out acts: ", acts.shape)

# Report the stored size without materializing it: torch.cat over the whole buffer
# would allocate a second full-size copy (960_000 x 128 in the output below) only to
# read its shape. Chunks must share trailing dims anyway (torch.cat requires it).
# NOTE(review): `_buffer` is a private attribute of ActivationBuffer; it may change.
stored_chunks = train_buffer._buffer
n_stored_rows = sum(chunk.shape[0] for chunk in stored_chunks)
print("Stored acts: ", torch.Size([n_stored_rows, *stored_chunks[0].shape[1:]]))

Out acts:  torch.Size([10000, 128])
Stored acts:  torch.Size([960000, 128])


## Train an SAE

## Evaluate an SAE