# Train SAEs on Model Activations

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/train_saes.ipynb)

## Setup

In [None]:
# Install switch read by the next cell: True clones the lczerolens repository
# and installs it from source, False installs the `lczerolens` package with pip.
DEV = False

In [None]:
if DEV:
    !rm -r lczerolens
    !git clone https://github.com/Xmaster6y/lczerolens -b <branch>
    !pip install -q ./lczerolens
else:
    !pip install -q lczerolens

In [None]:
# Download the network weights with gdown and save them as
# `lc0-10-4238.onnx`, the file name the model-loading cell reads.
!gdown 1cxC8_8vw7akfPyc9cZxwaAbLG2Zl4XiT -O lc0-10-4238.onnx

In [None]:
import torch

# Fail fast: the model is moved to "cuda" further down, so stop here with a
# clear message when no CUDA device is available.
if not torch.cuda.is_available():
    raise RuntimeError("This notebook requires a GPU")

## Load the Model and the Dataset

In [None]:
from datasets import load_dataset
from lczerolens import LczeroModel

# Load the network from the ONNX file downloaded earlier and move it to the GPU.
model = LczeroModel.from_path(
    "lc0-10-4238.onnx"
).to("cuda")

# Load the "lczero-planning/boards" dataset; its "train" and "test" splits
# feed the activation buffers below.
dataset = load_dataset("lczero-planning/boards")
# Bare expression at the end of the cell so the notebook displays the dataset.
dataset

In [None]:
# Bare expression: the notebook displays the loaded model's representation.
model

## Set Up the Activation Buffer

In [None]:
import chess
import einops

def collate_fn(batch):
    """Turn a batch of dataset rows into a list of `chess.Board` objects.

    Each row is indexed with the keys "fen" and "moves": the board is created
    from the FEN string, then every entry of "moves" is parsed as a UCI move
    and pushed onto it in order.

    Args:
        batch: Iterable of rows supporting `row["fen"]` and `row["moves"]`.

    Returns:
        A list with one board per row, in the same order as `batch`.
    """

    def _replay(row):
        # Read both fields before building the board so a missing key is
        # reported before any FEN parsing happens.
        start_fen, uci_moves = row["fen"], row["moves"]
        position = chess.Board(start_fen)
        for uci in uci_moves:
            position.push(chess.Move.from_uci(uci))
        return position

    return [_replay(row) for row in batch]

def compute_fn(batch, model, contrastive, lens):
    """Run the lens on a batch of boards and flatten the result to 2-D.

    Calls `lens.analyse` with the boards unpacked as positional arguments and
    `model` as a keyword, takes element 0 of what it returns, and requires
    that element to hold exactly one entry. That entry is rearranged with the
    pattern "b c h w -> (b h w) c", so the `b`, `h` and `w` axes are merged
    into rows and `c` becomes the column axis.

    Args:
        batch: The boards to analyse (as produced by `collate_fn`).
        model: Model forwarded to `lens.analyse`.
        contrastive: Accepted but not used by this function.
        lens: Object exposing `analyse(*boards, model=...)`.

    Returns:
        The single stored activation, rearranged to shape `(b * h * w, c)`.

    Raises:
        NotImplementedError: If the storage does not hold exactly one entry.
    """
    analysis = lens.analyse(*batch, model=model)
    storage = analysis[0]
    # Only the single-module case is handled.
    if len(storage.keys()) != 1:
        raise NotImplementedError
    (acts,) = storage.values()
    return einops.rearrange(acts, "b c h w -> (b h w) c")


In [None]:
from lczerolens.lenses import ActivationLens, ActivationBuffer

# Module name handed to ActivationLens (presumably the layer whose
# activations are captured; confirm against the lczerolens docs).
MODULE_NAME = "block9/conv2/relu"
# NOTE(review): LENS is not referenced again in the visible cells, although
# compute_fn declares a `lens` parameter. Confirm how it is meant to be wired.
LENS = ActivationLens(MODULE_NAME)
# The three sizes below are passed positionally, in this order, to both
# ActivationBuffer constructors further down.
N_BATCHES_IN_BUFFER = 15
COMPUTE_BATCH_SIZE = 1_000
TRAIN_BATCH_SIZE = 10_000


In [None]:
# Build one activation buffer per dataset split with identical settings:
# the "train" split feeds train_buffer and the "test" split feeds val_buffer.
# NOTE(review): compute_fn is declared with four parameters
# (batch, model, contrastive, lens) but is passed here unbound. Confirm how
# ActivationBuffer supplies `contrastive` and `lens` when it calls it.
train_buffer, val_buffer = (
    ActivationBuffer(
        model,
        dataset[split],
        compute_fn,
        N_BATCHES_IN_BUFFER,
        COMPUTE_BATCH_SIZE,
        TRAIN_BATCH_SIZE,
        dataloader_kwargs={"collate_fn": collate_fn},
    )
    for split in ("train", "test")
)

In [None]:
# Smoke test: pull a single batch from the training buffer and print its shape.
acts = next(iter(train_buffer))
print("Out acts: ", acts.shape)
# Reads the buffer's private `_buffer` attribute, concatenated along dim 0,
# to show the shape of what is currently stored. Debugging aid only.
print("Stored acts: ", torch.cat(train_buffer._buffer, dim=0).shape)

## Train an SAE

## Evaluate an SAE