<a href="https://colab.research.google.com/github/attentionmech/tensorlens/blob/main/tensorlens/notebooks/tensorlens.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install tensorlens datasets

Collecting tensorlens
  Downloading tensorlens-0.0.1-py3-none-any.whl.metadata (328 bytes)
Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting gunicorn (from tensorlens)
  Downloading gunicorn-23.0.0-py3-none-any.whl.metadata (4.4 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting jedi>=0.16 (from ipython->tensorlens)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading tensorlens-0.0.1-py3-none-any.whl (963 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
import os
import sys
import torch
from datasets import load_dataset
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments)
from tensorlens.tensorlens import trace, viewer

# Device setup
DEVICE = "cpu"

# Seed PyTorch before the model is built so the randomly initialised weights
# (and therefore the first traced snapshot) are identical on every fresh run.
SEED = 0
torch.manual_seed(SEED)

# Model config - tiny GPT-2
# n_embd=40 split across n_head=10 gives 4 dimensions per attention head.
config = GPT2Config(
    vocab_size=50257,
    n_positions=1024,
    n_embd=40,
    n_layer=2,
    n_head=10,
    attn_implementation="eager",
)

# Build the untrained model on DEVICE.
# NOTE(review): Trainer may later relocate the model according to its own
# TrainingArguments device settings - confirm if strict CPU training is required.
model = GPT2LMHeadModel(config).to(DEVICE)

# Reuse the end-of-sequence token as the padding token so that padded
# tokenization has a pad token to work with.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Trace the initial (pre-training) weights, one entry per state-dict key.
for key, tensor in model.state_dict().items():
    trace(key, tensor.detach().cpu().numpy())


# Embed the tensorlens viewer in the notebook output.
viewer(notebook=True, width='100%', height=600)


# Dataset - first 1% of the TinyStories training split
dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")

def tokenize_function(examples):
    """Tokenize the "text" field of a batch, padded/truncated to 128 tokens."""
    encode_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 128,
    }
    return tokenizer(examples["text"], **encode_options)

# Tokenize in batches and drop the raw "text" column from the result.
tokenized_datasets = dataset.map(
    function=tokenize_function,
    batched=True,
    remove_columns=["text"],
)

# Collator for causal language modelling (mlm=False disables masked-LM masking).
data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)

training_args = TrainingArguments(
    # Training schedule
    num_train_epochs=1,
    per_device_train_batch_size=8,
    # Output location; no checkpoints are saved (save_strategy="no")
    output_dir="/tmp/temp_output",
    overwrite_output_dir=True,
    save_strategy="no",
    # Logging / reporting integrations switched off
    report_to="none",
    logging_dir=None,
)

# Callback to trace weights mid-training
class TraceCallback(TrainerCallback):
    """Trainer callback that re-traces every model weight at a fixed step interval.

    Args:
        model: Object exposing ``state_dict()``; its tensors are traced.
        trace_steps: Trace whenever ``state.global_step`` is a positive
            multiple of this value. Must be a positive integer.

    Raises:
        ValueError: If ``trace_steps`` is not positive.
    """

    def __init__(self, model, trace_steps=10):
        # Fail fast: a zero interval would otherwise only surface as a
        # ZeroDivisionError from the modulo in on_step_end, mid-training.
        if trace_steps <= 0:
            raise ValueError(
                f"trace_steps must be a positive integer, got {trace_steps!r}"
            )
        self.model = model
        self.trace_steps = trace_steps

    def on_step_end(self, args, state, control, **kwargs):
        """Trace the weights when the finished step lands on the interval."""
        step = state.global_step
        if step > 0 and step % self.trace_steps == 0:
            self.trace_model_weights(step)

    def trace_model_weights(self, step):
        """Send every tensor in the model's state dict to ``trace``.

        ``step`` is accepted from the caller but is not forwarded to ``trace``.
        """
        for key, tensor in self.model.state_dict().items():
            # Detach and copy to host memory before converting to NumPy.
            trace(key, tensor.detach().cpu().numpy())

# Trainer
# Re-trace the model weights every 10 optimisation steps during training.
weight_tracer = TraceCallback(model, trace_steps=10)

trainer = Trainer(
    callbacks=[weight_tracer],
    data_collator=data_collator,
    train_dataset=tokenized_datasets,
    args=training_args,
    model=model,
)

trainer.train()


<IPython.core.display.Javascript object>

KeyboardInterrupt: 