In [None]:
import fmrai
from tqdm import tqdm
import torch

## Load model & tokenizer

In [None]:
from transformers import AutoModel, AutoTokenizer

# Run on the GPU when one is available, otherwise fall back to the CPU.
# (An unconditional `.cuda()` raises on machines without CUDA.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# BERT:
model = AutoModel.from_pretrained('bert-base-uncased').to(device)
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# # GPT:
# model = AutoModel.from_pretrained('openai-gpt').to(device)
# tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
# tokenizer.pad_token = ' '

# # Llama 2:
# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf', load_in_4bit=True)
# tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
# tokenizer.pad_token = tokenizer.eos_token

## Prepare dataset

In [None]:
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader
from datasets import load_dataset

# streaming=True iterates over the corpus lazily instead of downloading it
# in full before use.
dataset = load_dataset('bookcorpus', streaming=True)

batch_size = 4
# Only the first 256 training examples are used for this analysis.
small_dataset = dataset['train'].take(256)

# Tokenize to plain, unpadded id lists. Padding and tensor conversion are left
# to the collator below, so the result no longer depends on the map batches
# lining up with the DataLoader batches. truncation=True caps each example at
# the tokenizer's maximum model length so an over-long text cannot overflow
# the model's position embeddings.
small_dataset_tokenized = small_dataset.map(
    lambda x: tokenizer(x['text'], truncation=True),
    batched=True,
    batch_size=batch_size
).remove_columns(['text'])

# The collator pads every DataLoader batch to its longest sequence and returns
# PyTorch tensors.
collator = DataCollatorWithPadding(tokenizer, padding='longest', return_tensors='pt')
loader = DataLoader(small_dataset_tokenized, batch_size=batch_size, collate_fn=collator)

## Analyze and plot

In [None]:
from fmrai.instrument import instrument_model
from fmrai.analysis.attention import AttentionHeadClusterAnalyzer

# Everything below runs inside the fmrai context with autograd disabled.
with fmrai.fmrai(), torch.no_grad():
    m = instrument_model(model)
    run_device = model.device

    # Warm-up: push one batch through the instrumented model *before* the
    # analyzer exists. Per the original author's note, the first forward pass
    # also includes parameter creation in the computation graph, so its graph
    # differs from the graphs of later iterations.
    # The loop below starts a fresh iteration over `loader`; this warm-up pass
    # is simply not tracked.
    first_batch = next(iter(loader))
    m(**first_batch.to(run_device))

    analyzer = AttentionHeadClusterAnalyzer()
    for batch in tqdm(loader):
        # Each forward pass is recorded by the analyzer as one tracked batch.
        with analyzer.track_batch():
            m(**batch.to(run_device))

    # Analyze the tracked batches and draw the result on a 16x16 figure.
    analyzer.analyze().plot(figsize=(16, 16))
    # Drop the analyzer once plotting is done.
    del analyzer