In [1]:
import fmrai
import torch
from tqdm import tqdm

# Create model

In [2]:
from transformers import AutoModel, AutoTokenizer
# Load the tokenizer and the encoder from one shared checkpoint name so the
# two can never drift apart.
CHECKPOINT = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = AutoModel.from_pretrained(CHECKPOINT)

In [3]:
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader
from datasets import load_dataset

def _tokenize_batch(examples):
    """Tokenize the 'text' column of one mapped batch, padded to its longest entry."""
    return tokenizer(examples['text'], padding='longest', return_tensors='pt')


# Stream the corpus instead of downloading it in full; only a handful of
# examples are needed below.
dataset = load_dataset('bookcorpus', streaming=True)

batch_size = 4
small_dataset = dataset['train'].take(32)

# Tokenize in groups of `batch_size`, then drop the raw text so that only the
# tokenizer outputs reach the collator.
small_dataset_tokenized = small_dataset.map(
    _tokenize_batch,
    batched=True,
    batch_size=batch_size,
)
small_dataset_tokenized = small_dataset_tokenized.remove_columns(['text'])

# The collator pads every DataLoader batch to its longest sequence and returns
# PyTorch tensors.
collator = DataCollatorWithPadding(tokenizer, padding='longest', return_tensors='pt')
loader = DataLoader(
    small_dataset_tokenized,
    batch_size=batch_size,
    collate_fn=collator,
)

In [4]:
import importlib

In [20]:
# Development helper: re-execute fmrai.analysis.ffn_kv so that edits made to the
# module on disk take effect without restarting the kernel.
import fmrai.analysis.ffn_kv
importlib.reload(fmrai.analysis.ffn_kv)

# importlib.reload() does not rebind names that were pulled in earlier with
# `from ... import`, so every name this notebook uses from the module has to be
# re-imported here. Otherwise KeyValueMaxSearchStrategy would keep pointing at
# the pre-reload object while KeyValueAnalyzer is the freshly reloaded one.
from fmrai.analysis.ffn_kv import KeyValueAnalyzer, KeyValueMaxSearchStrategy

In [6]:
from fmrai.analysis.ffn_kv import KeyValueAnalyzer, KeyValueMaxSearchStrategy
from fmrai.instrument import instrument_model

# Run everything inside the fmrai context and with autograd disabled: only
# forward passes are needed here.
with fmrai.fmrai() as f, torch.no_grad():
    m = instrument_model(model)

    # One untracked forward pass on the first batch before the analyzer is
    # created. NOTE(review): presumably this warms up / traces the
    # instrumented model -- confirm against fmrai.
    first_batch = next(iter(loader))
    m(**first_batch.to(model.device))

    analyzer = KeyValueAnalyzer(KeyValueMaxSearchStrategy.CLS)

    # Second pass over the loader: each forward call happens inside
    # track_batch() so the analyzer can record it.
    for batch in tqdm(loader):
        with analyzer.track_batch():
            m(**batch.to(model.device))

0it [00:00, ?it/s]

73 linears
12 activations



  0%|          | 0/73 [00:00<?, ?it/s][A
  7%|▋         | 5/73 [00:00<00:01, 45.44it/s][A
 15%|█▌        | 11/73 [00:00<00:01, 53.19it/s][A
 26%|██▌       | 19/73 [00:00<00:00, 61.23it/s][A
 36%|███▌      | 26/73 [00:00<00:00, 64.45it/s][A
 47%|████▋     | 34/73 [00:00<00:00, 67.73it/s][A
 58%|█████▊    | 42/73 [00:00<00:00, 71.45it/s][A
 70%|██████▉   | 51/73 [00:00<00:00, 76.83it/s][A
100%|██████████| 73/73 [00:00<00:00, 85.10it/s][A

  0%|          | 0/12 [00:00<?, ?it/s][A
 17%|█▋        | 2/12 [00:00<00:00, 12.75it/s][A
 33%|███▎      | 4/12 [00:00<00:00, 14.87it/s][A
 50%|█████     | 6/12 [00:00<00:00, 15.53it/s][A
 67%|██████▋   | 8/12 [00:00<00:00, 16.45it/s][A
100%|██████████| 12/12 [00:00<00:00, 19.34it/s][A
1it [00:04,  4.46s/it]

linear_top_and_act 12
act_and_linear_bottom 12


8it [00:06,  1.14it/s]


In [8]:
# Display what the analyzer accumulated during the tracked batches.
# In the captured output this is a mapping from 2-tuples of ints to lists of
# (int, float) pairs.
# NOTE(review): presumably keys are (layer, memory index) and values are the
# top-scoring (position, coefficient) entries -- confirm against ffn_kv.
# NOTE(review): this reads private attributes and prints the whole structure;
# prefer a public accessor and a small slice if one is available.
analyzer._accumulator._mem_coeff_bank

{(0, 0): [(49, -0.0044784401543438435),
  (96, -0.009699027054011822),
  (147, -0.011142684146761894),
  (146, -0.012060926295816898),
  (1, -0.01263484451919794),
  (48, -0.013292540796101093),
  (291, -0.013939321041107178),
  (192, -0.015281794592738152),
  (241, -0.015780480578541756),
  (194, -0.015847615897655487)],
 (0, 1): [(147, -0.032944243401288986),
  (338, -0.03957612067461014),
  (192, -0.040040791034698486),
  (193, -0.04233179986476898),
  (241, -0.04359961673617363),
  (194, -0.04556981474161148),
  (49, -0.046144548803567886),
  (289, -0.04920352250337601),
  (48, -0.04983096197247505),
  (336, -0.05117439106106758)],
 (0, 2): [(50, -0.0038569250609725714),
  (99, -0.007161768618971109),
  (240, -0.008113930001854897),
  (195, -0.009400531649589539),
  (51, -0.010447164997458458),
  (96, -0.013469966128468513),
  (336, -0.013811132870614529),
  (288, -0.013934610411524773),
  (0, -0.015919648110866547),
  (192, -0.01767265796661377)],
 (0, 3): [(289, -0.05978241190314

In [14]:
import torch

# Scratch check: print the Python type of a freshly sampled 1x2 tensor.
print(type(torch.randn(1, 2)))

<class 'torch.Tensor'>
