In [1]:
# NBVAL_IGNORE_OUTPUT
# Janky code to do different setup when run in a Colab notebook vs VSCode
import os

IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"

try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")
    


if IN_COLAB or IN_GITHUB:
    # %pip install sentencepiece # Llama tokenizer requires sentencepiece
    %pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8
    %pip install torch
    %pip install tiktoken
    # %pip install transformer_lens
    %pip install transformers_stream_generator
    # !huggingface-cli login --token NEEL'S TOKEN

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:
import torch

from transformer_lens import HookedTransformer, HookedEncoderDecoder, HookedEncoder, loading
from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
from typing import List
import gc

# Working copy of every officially supported model name. Later cells remove names from this
# list as each group of models is accounted for.
untested_models = list(loading.OFFICIAL_MODEL_NAMES)

print(f"TransformerLens currently supports {len(untested_models)} models out of the box.")

# When truthy, the run_* helpers also print a sample generation / prediction for each model.
GENERATE = True
# Fill this in if you have llama weights uploaded, and you wish to test those models
LLAMA_MODEL_PATH = ""

TransformerLens currently supports 205 models out of the box.


In [3]:
def mark_models_as_tested(model_set: List[str]) -> None:
    """Remove every name in ``model_set`` from the notebook-level ``untested_models`` list.

    The list is modified in place, one name at a time and in the order given.

    Raises:
        ValueError: if a name is not currently in ``untested_models``. Names processed
            before the failing one have already been removed at that point.
    """
    for tested_name in model_set:
        untested_models.remove(tested_name)
    

def run_set(model_set: List[str], device="cuda") -> None:
    for model in model_set:
        print("Testing " + model)
        tl_model = HookedTransformer.from_pretrained_no_processing(model, device=device)
        if GENERATE:
            print(tl_model.generate("Hello my name is"))
        del tl_model
        gc.collect()
        if IN_COLAB:
            %rm -rf /root/.cache/huggingface/hub/models*

def run_llama_set(model_set: List[str], weight_root: str, device="cuda") -> None:
    for model in model_set:
        print("Testing " + model)
        # to run this, make sure weight root is the root that contains all models with the 
        # sub directories sharing the same name as the model in the list of models
        tokenizer = LlamaTokenizer.from_pretrained(weight_root + model)
        hf_model = LlamaForCausalLM.from_pretrained(weight_root + model, low_cpu_mem_usage=True)
        tl_model = HookedTransformer.from_pretrained_no_processing(
            model, 
            hf_model=hf_model,
            device=device,
            fold_ln=False,
            center_writing_weights=False,
            center_unembed=False,
            tokenizer=tokenizer,
        )
        if GENERATE:
            print(tl_model.generate("Hello my name is"))
        del tl_model
        gc.collect()
        if IN_COLAB:
            %rm -rf /root/.cache/huggingface/hub/models*


def run_encoder_decoder_set(model_set: List[str], device="cuda") -> None:
    for model in model_set:
        print("Testing " + model)
        tokenizer = AutoTokenizer.from_pretrained(model)
        tl_model = HookedEncoderDecoder.from_pretrained(model, device=device)
        if GENERATE:
            # Originally from the t5 demo
            prompt = "Hello, how are you? "
            inputs = tokenizer(prompt, return_tensors="pt")
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]
            decoder_input_ids = torch.tensor([[tl_model.cfg.decoder_start_token_id]]).to(input_ids.device)


            while True:
                logits = tl_model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)
                # logits.shape == (batch_size (1), predicted_pos, vocab_size)

                token_idx = torch.argmax(logits[0, -1, :]).item()
                print("generated token: \"", tokenizer.decode(token_idx), "\", token id: ", token_idx, sep="")

                # append token to decoder_input_ids
                decoder_input_ids = torch.cat([decoder_input_ids, torch.tensor([[token_idx]]).to(input_ids.device)], dim=-1)

                # break if End-Of-Sequence token generated
                if token_idx == tokenizer.eos_token_id:
                    break
        del tl_model
        gc.collect()
        if IN_COLAB:
            %rm -rf /root/.cache/huggingface/hub/models*

def run_encoder_only_set(model_set: List[str], device="cuda") -> None:
    for model in model_set:
        print("Testing " + model)
        tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        tl_model = HookedEncoder.from_pretrained(model, device=device)

        if GENERATE:
            # Slightly adapted version of the BERT demo
            prompt = "The capital of France is [MASK]."

            input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]

            logprobs = tl_model(input_ids)[input_ids == tokenizer.mask_token_id].log_softmax(dim=-1)
            prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())

            print(f"Prompt: {prompt}")
            print(f'Prediction: "{prediction}"')

        del tl_model
        gc.collect()
        if IN_COLAB:
            %rm -rf /root/.cache/huggingface/hub/models*

In [4]:
# The following models can run in the T4 free environment.
# They are only actually loaded (via run_set, on its default device) when running in Colab;
# in every environment the names are then removed from untested_models.
free_compatible = [
    "ai-forever/mGPT",
    "ArthurConmy/redwood_attn_2l",
    "bigcode/santacoder",
    "bigscience/bloom-1b1",
    "bigscience/bloom-560m",
    "distilgpt2",
    "EleutherAI/gpt-neo-1.3B",
    "EleutherAI/gpt-neo-125M",
    "EleutherAI/gpt-neo-2.7B",
    "EleutherAI/pythia-1.4b",
    "EleutherAI/pythia-1.4b-deduped",
    "EleutherAI/pythia-1.4b-deduped-v0",
    "EleutherAI/pythia-1.4b-v0",
    "EleutherAI/pythia-14m",
    "EleutherAI/pythia-160m",
    "EleutherAI/pythia-160m-deduped",
    "EleutherAI/pythia-160m-deduped-v0",
    "EleutherAI/pythia-160m-seed1",
    "EleutherAI/pythia-160m-seed2",
    "EleutherAI/pythia-160m-seed3",
    "EleutherAI/pythia-160m-v0",
    "EleutherAI/pythia-1b",
    "EleutherAI/pythia-1b-deduped",
    "EleutherAI/pythia-1b-deduped-v0",
    "EleutherAI/pythia-1b-v0",
    "EleutherAI/pythia-31m",
    "EleutherAI/pythia-410m",
    "EleutherAI/pythia-410m-deduped",
    "EleutherAI/pythia-410m-deduped-v0",
    "EleutherAI/pythia-410m-v0",
    "EleutherAI/pythia-70m",
    "EleutherAI/pythia-70m-deduped",
    "EleutherAI/pythia-70m-deduped-v0",
    "EleutherAI/pythia-70m-v0",
    "facebook/opt-1.3b",
    "facebook/opt-125m",
    "gpt2",
    "gpt2-large",
    "gpt2-medium",
    "gpt2-xl",
    "meta-llama/Llama-3.2-1B",
    "meta-llama/Llama-3.2-1B-Instruct",
    "microsoft/phi-1",
    "microsoft/phi-1_5",
    "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr",
    "NeelNanda/Attn_Only_1L512W_C4_Code",
    "NeelNanda/Attn_Only_2L512W_C4_Code",
    "NeelNanda/Attn_Only_3L512W_C4_Code",
    "NeelNanda/Attn_Only_4L512W_C4_Code",
    "NeelNanda/GELU_1L512W_C4_Code",
    "NeelNanda/GELU_2L512W_C4_Code",
    "NeelNanda/GELU_3L512W_C4_Code",
    "NeelNanda/GELU_4L512W_C4_Code",
    "NeelNanda/SoLU_10L1280W_C4_Code",
    "NeelNanda/SoLU_10L_v22_old",
    "NeelNanda/SoLU_12L1536W_C4_Code",
    "NeelNanda/SoLU_12L_v23_old",
    "NeelNanda/SoLU_1L512W_C4_Code",
    "NeelNanda/SoLU_1L512W_Wiki_Finetune",
    "NeelNanda/SoLU_1L_v9_old",
    "NeelNanda/SoLU_2L512W_C4_Code",
    "NeelNanda/SoLU_2L_v10_old",
    "NeelNanda/SoLU_3L512W_C4_Code",
    "NeelNanda/SoLU_4L512W_C4_Code",
    "NeelNanda/SoLU_4L512W_Wiki_Finetune",
    "NeelNanda/SoLU_4L_v11_old",
    "NeelNanda/SoLU_6L768W_C4_Code",
    "NeelNanda/SoLU_6L_v13_old",
    "NeelNanda/SoLU_8L1024W_C4_Code",
    "NeelNanda/SoLU_8L_v21_old",
    "Qwen/Qwen-1_8B",
    "Qwen/Qwen-1_8B-Chat",
    "Qwen/Qwen1.5-0.5B",
    "Qwen/Qwen1.5-0.5B-Chat",
    "Qwen/Qwen1.5-1.8B",
    "Qwen/Qwen1.5-1.8B-Chat",
    "Qwen/Qwen2-0.5B",
    "Qwen/Qwen2-0.5B-Instruct",
    "Qwen/Qwen2-1.5B",
    "Qwen/Qwen2-1.5B-Instruct",
    "Qwen/Qwen2.5-0.5B",
    "Qwen/Qwen2.5-0.5B-Instruct",
    "Qwen/Qwen2.5-1.5B",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "roneneldan/TinyStories-1Layer-21M",
    "roneneldan/TinyStories-1M",
    "roneneldan/TinyStories-28M",
    "roneneldan/TinyStories-2Layers-33M",
    "roneneldan/TinyStories-33M",
    "roneneldan/TinyStories-3M",
    "roneneldan/TinyStories-8M",
    "roneneldan/TinyStories-Instruct-1M",
    "roneneldan/TinyStories-Instruct-28M",
    "roneneldan/TinyStories-Instruct-2Layers-33M",
    "roneneldan/TinyStories-Instruct-33M",
    "roneneldan/TinyStories-Instruct-3M",
    "roneneldan/TinyStories-Instruct-8M",
    "roneneldan/TinyStories-Instuct-1Layer-21M",
    "stanford-crfm/alias-gpt2-small-x21",
    "stanford-crfm/arwen-gpt2-medium-x21",
    "stanford-crfm/battlestar-gpt2-small-x49",
    "stanford-crfm/beren-gpt2-medium-x49",
    "stanford-crfm/caprica-gpt2-small-x81",
    "stanford-crfm/celebrimbor-gpt2-medium-x81",
    "stanford-crfm/darkmatter-gpt2-small-x343",
    "stanford-crfm/durin-gpt2-medium-x343",
    "stanford-crfm/eowyn-gpt2-medium-x777",
    "stanford-crfm/expanse-gpt2-small-x777",
]

if IN_COLAB:
    run_set(free_compatible)
    
# Executed whether or not the models were run above.
mark_models_as_tested(free_compatible)

In [5]:
# Models that are loaded through run_set on its default device ("cuda") when running in Colab.
# NOTE(review): the name suggests these need a paid Colab GPU runtime -- confirm.
paid_gpu_models = [
    "01-ai/Yi-6B",
    "01-ai/Yi-6B-Chat",
    "bigscience/bloom-1b7",
    "bigscience/bloom-3b",
    "bigscience/bloom-7b1",
    "codellama/CodeLlama-7b-hf",
    "codellama/CodeLlama-7b-Instruct-hf",
    "codellama/CodeLlama-7b-Python-hf",
    "EleutherAI/pythia-2.8b",
    "EleutherAI/pythia-2.8b-deduped",
    "EleutherAI/pythia-2.8b-deduped-v0",
    "EleutherAI/pythia-2.8b-v0",
    "EleutherAI/pythia-6.9b",
    "EleutherAI/pythia-6.9b-deduped",
    "EleutherAI/pythia-6.9b-deduped-v0",
    "EleutherAI/pythia-6.9b-v0",
    "facebook/opt-2.7b",
    "facebook/opt-6.7b",
    "google/gemma-2-2b",
    "google/gemma-2-2b-it",
    "google/gemma-2b",
    "google/gemma-2b-it",
    "google/gemma-7b",
    "google/gemma-7b-it",
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-3.1-8B",
    "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.2-3B",
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "microsoft/phi-2",
    "microsoft/Phi-3-mini-4k-instruct",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mistral-Nemo-Base-2407",
    "Qwen/Qwen-7B",
    "Qwen/Qwen-7B-Chat",
    "Qwen/Qwen1.5-4B",
    "Qwen/Qwen1.5-4B-Chat",
    "Qwen/Qwen1.5-7B",
    "Qwen/Qwen1.5-7B-Chat",
    "Qwen/Qwen2-7B",
    "Qwen/Qwen2-7B-Instruct",
    "Qwen/Qwen2.5-3B",
    "Qwen/Qwen2.5-3B-Instruct",
    "Qwen/Qwen2.5-7B",
    "Qwen/Qwen2.5-7B-Instruct",
    "stabilityai/stablelm-base-alpha-3b",
    "stabilityai/stablelm-base-alpha-7b",
    "stabilityai/stablelm-tuned-alpha-3b",
    "stabilityai/stablelm-tuned-alpha-7b",
]

if IN_COLAB:
    run_set(paid_gpu_models)
    
# Executed whether or not the models were run above.
mark_models_as_tested(paid_gpu_models)

In [6]:
# Models that are loaded through run_set on the "cpu" device when running in Colab.
# NOTE(review): the name suggests these need a paid Colab runtime and do not fit on its GPU -- confirm.
paid_cpu_models = [
    "EleutherAI/gpt-j-6B",
    "EleutherAI/gpt-neox-20b",
    "EleutherAI/pythia-12b",
    "EleutherAI/pythia-12b-deduped",
    "EleutherAI/pythia-12b-deduped-v0",
    "EleutherAI/pythia-12b-v0",
    "facebook/opt-13b",
    "google/gemma-2-9b",
    "google/gemma-2-9b-it",
    "meta-llama/Llama-2-13b-chat-hf",
    "meta-llama/Llama-2-13b-hf",
    "Qwen/Qwen-14B",
    "Qwen/Qwen-14B-Chat",
    "Qwen/Qwen1.5-14B",
    "Qwen/Qwen1.5-14B-Chat",
    "Qwen/Qwen2.5-14B",
    "Qwen/Qwen2.5-14B-Instruct",
]

if IN_COLAB:
    run_set(paid_cpu_models, "cpu")
    
# Executed whether or not the models were run above.
mark_models_as_tested(paid_cpu_models)

In [7]:
# These models are never loaded by this notebook; they are only removed from untested_models.
# NOTE(review): the name suggests they cannot be run in any Colab environment -- confirm.
incompatible_models = [
    "01-ai/Yi-34B",
    "01-ai/Yi-34B-Chat",
    "facebook/opt-30b",
    "facebook/opt-66b",
    "google/gemma-2-27b",
    "google/gemma-2-27b-it",
    "meta-llama/Llama-2-70b-chat-hf",
    "meta-llama/Llama-3.1-70B",
    "meta-llama/Llama-3.1-70B-Instruct",
    "meta-llama/Meta-Llama-3-70B",
    "meta-llama/Meta-Llama-3-70B-Instruct",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mixtral-8x7B-v0.1",
    "Qwen/Qwen2.5-32B",
    "Qwen/Qwen2.5-32B-Instruct",
    "Qwen/Qwen2.5-72B",
    "Qwen/Qwen2.5-72B-Instruct",
    "Qwen/QwQ-32B-Preview",
]

mark_models_as_tested(incompatible_models)

In [8]:
# The following models take a few extra steps to function. Check the official demo for more
# information on how to use. 7b and 13b will work in the paid environment. 30b and 65b will not work
# in Colab
not_hosted_models = [
    "llama-7b-hf",
    "llama-13b-hf",
    "llama-30b-hf",
    "llama-65b-hf",
]

# Only attempted when LLAMA_MODEL_PATH has been filled in (an empty string skips the run).
if LLAMA_MODEL_PATH:
    run_llama_set(not_hosted_models, LLAMA_MODEL_PATH)

# Executed whether or not the models were run above.
mark_models_as_tested(not_hosted_models)

In [9]:
# These all work on the free version of Colab
encoder_decoders = [
    "google-t5/t5-base",
    "google-t5/t5-large",
    "google-t5/t5-small",
]
# Only loaded when running in Colab, using run_encoder_decoder_set's default device.
if IN_COLAB:
    run_encoder_decoder_set(encoder_decoders)

# Executed whether or not the models were run above.
mark_models_as_tested(encoder_decoders)

In [10]:
# This model works on the free version of Colab
encoder_only_models = ["bert-base-cased"]

# Only loaded when running in Colab, using run_encoder_only_set's default device.
if IN_COLAB:
    run_encoder_only_set(encoder_only_models)

# Executed whether or not the model was run above.
mark_models_as_tested(encoder_only_models)

In [11]:
# This list is not passed to mark_models_as_tested in this cell, so its entries stay in
# untested_models and are printed by the final cell.
# NOTE(review): the name suggests these models are known not to load correctly -- confirm.
broken_models = [
    "Baidicoot/Othello-GPT-Transformer-Lens",
]

In [12]:
# Any models printed below this cell have not been tested. This should always remain blank. If
# your PR fails due to this notebook, most likely you need to check any new model changes to
# ensure that this notebook is up to date.
# One model name per line; prints a single empty line when nothing is left.
print("\n".join(untested_models))

Baidicoot/Othello-GPT-Transformer-Lens
