In [1]:
import torch
import nnsight
import datasets
import activation_server.text_dataset as text_dataset
from torch.utils.data import DataLoader
import h5py
import os

# Hide every CUDA device from this process so the rest of the notebook runs on
# CPU (the recorded output below reports CUDA as unavailable).
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Report what torch can see after the mask above.
cuda_available = torch.cuda.is_available()
print("CUDA available:", cuda_available)

if not cuda_available:
    print("No CUDA devices found")
else:
    num_gpus = torch.cuda.device_count()
    print("Number of GPUs:", num_gpus)
    for gpu_index in range(num_gpus):
        print(f"GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")

  from .autonotebook import tqdm as notebook_tqdm


CUDA available: False
No CUDA devices found


In [2]:
# Wrap GPT-2 (small) from the Hugging Face hub in an nnsight LanguageModel so
# intermediate module inputs/outputs can be traced and saved.
model = nnsight.LanguageModel(
    'openai-community/gpt2',
    device_map='auto',
    dispatch=True,
)

In [3]:
# Freeze every parameter: activations are only read here, never trained on.
# This is the cell's last expression, so Jupyter displays the module repr below.
model.requires_grad_(False)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  (generator): Generator(
    (streamer): Streamer()
  

In [4]:
# Raw text corpus: the `train` split of the Hugging Face dataset
# `Skylion007/openwebtext`.
dataset = datasets.load_dataset(
    path='Skylion007/openwebtext',
    split='train',
)

In [5]:
# Wrap the raw corpus in the project's TextDataset, handing it the model's
# tokenizer. Each item it yields is already a batch of token ids (the DataLoader
# below uses batch_size=None).
# NOTE(review): the positional `40` is presumably the number of sequences per
# batch — the throughput cell measures 40920 = 40 * 1023 tokens per batch.
# seq_len=1023 stays within GPT-2's 1024 positions (wpe: Embedding(1024, 768));
# TODO confirm why 1023 rather than 1024 is used.
token_dataset = text_dataset.TextDataset(
    dataset,
    model.tokenizer,
    40,
    drop_last_batch=False,
    seq_len=1023,
)

In [6]:
# batch_size=None disables DataLoader's automatic batching: every item coming
# out of TextDataset is already a full token batch and is passed through as-is.
# NOTE: this is ONE shared iterator. Batches consumed by the benchmark cells
# below are never seen by the HDF5 storage cell.
text_dataset_loader = iter(
    DataLoader(
        dataset=token_dataset,
        batch_size=None,
        shuffle=False,
        num_workers=5,
        prefetch_factor=5,
        worker_init_fn=text_dataset.worker_init_fn,
    )
)

Token indices sequence length is longer than the specified maximum sequence length for this model (1561 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1174 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2459 > 1024). Running this sequence through the model will result in indexing errors


Token indices sequence length is longer than the specified maximum sequence length for this model (2027 > 1024). Running this sequence through the model will result in indexing errors


In [7]:
import time
from tqdm import tqdm

# Measure raw data-loader throughput (tokenization + batching, no model) over a
# fixed wall-clock window.
# NOTE: this consumes batches from the shared `text_dataset_loader` iterator,
# so they will not reach the cells below.
LOADER_BENCH_SECONDS = 10
start_time = time.time()
end_time = start_time + LOADER_BENCH_SECONDS
total_samples = 0  # counts token ids (batch.numel()), not sequences
total_batches = 0

for batch in tqdm(text_dataset_loader):
    if time.time() > end_time:
        break
    total_samples += batch.numel()
    total_batches += 1

elapsed = time.time() - start_time
samples_per_sec = total_samples / elapsed
batches_per_sec = total_batches / elapsed

print(f"\nThroughput:")
print(f"Tokens/sec: {samples_per_sec:.1f}")
print(f"Batches/sec: {batches_per_sec:.1f}")
# With a short window the first batch (worker start-up + tokenization) may not
# arrive before the deadline; avoid a ZeroDivisionError in that case.
if total_batches:
    print(f"Average batch size: {total_samples/total_batches:.1f} tokens")
else:
    print("No batches received before the time limit.")

Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1561 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1174 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2027 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2459 > 1024). Running this sequence through the model will result in indexing errors


440it [00:10, 43.90it/s]



Throughput:
Samples/sec: 1794505.3
Batches/sec: 43.9
Average batch size: 40920.0


In [7]:
def extract_activations(model, tokens):
    """Collect every transformer block's MLP input and output in one traced pass.

    For each layer ``i`` this saves ``model.transformer.h[i].ln_2.input`` (the
    tensor fed into the LayerNorm that precedes the MLP, i.e. the un-normalized
    MLP input) and ``model.transformer.h[i].mlp.output``.

    Args:
        model: an ``nnsight.LanguageModel`` wrapping a GPT-2-style model that
            exposes ``transformer.h[i].ln_2`` / ``.mlp`` and ``config.n_layer``.
        tokens: token ids accepted by ``model.trace``; in this notebook a
            (batch, seq_len) tensor.

    Returns:
        Tensor of shape (batch, seq_len, 2, n_layer, d_model). Index 0 on the
        third axis is the MLP input, index 1 is the MLP output.
    """
    # Read the depth from the config instead of hard-coding 12 (GPT-2 small) so
    # larger GPT-2 variants work and the result matches the HDF5 store, which
    # is sized with model.config.n_layer.
    n_layer = model.config.n_layer
    mlp_ins = []
    mlp_outs = []
    with model.trace(tokens):
        for layer in range(n_layer):
            block = model.transformer.h[layer]
            mlp_ins.append(block.ln_2.input.save())
            mlp_outs.append(block.mlp.output.save())
    # Each saved tensor is (batch, seq_len, d_model); stacking on dim=2 gives
    # (batch, seq_len, n_layer, d_model).
    mlp_ins = torch.stack(mlp_ins, dim=2)
    mlp_outs = torch.stack(mlp_outs, dim=2)
    # Insert the in/out axis before the layer axis.
    return torch.stack([mlp_ins, mlp_outs], dim=2)  # batch seq_len in/out n_layer d_model

In [9]:
import time
from tqdm import tqdm


# Measure end-to-end activation-extraction throughput (BOS prepend + traced
# forward pass) over a fixed wall-clock window. The deadline is only checked
# between batches, so the run can overshoot by up to one batch (~14 s/batch in
# the recorded run).
# NOTE: this consumes batches from the shared `text_dataset_loader` iterator.
EXTRACT_BENCH_SECONDS = 50
start_time = time.time()
end_time = start_time + EXTRACT_BENCH_SECONDS
total_tokens = 0
total_batches = 0

print("Running activation extraction benchmark...")
for batch in tqdm(text_dataset_loader):
    if time.time() > end_time:
        break

    # prepend BOS token like in main loop: shift right by one and overwrite
    # position 0 (the last token of each row is dropped)
    batch = torch.roll(batch, shifts=1, dims=1)
    batch[:, 0] = model.config.bos_token_id

    # Extract activations (result unused; only the cost is measured)
    mlp_acts = extract_activations(model, batch)

    total_tokens += batch.numel()
    total_batches += 1

elapsed = time.time() - start_time
tokens_per_sec = total_tokens / elapsed
batches_per_sec = total_batches / elapsed

print(f"\nActivation Extraction Throughput:")
print(f"Tokens/sec: {tokens_per_sec:,.1f}")
print(f"Batches/sec: {batches_per_sec:.1f}")
# Avoid a ZeroDivisionError if no batch completed inside the window.
if total_batches:
    print(f"Average batch size: {total_tokens/total_batches:.1f} tokens")
else:
    print("No batches completed before the time limit.")


Running activation extraction benchmark...


0it [00:00, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
4it [00:57, 14.41s/it]



Activation Extraction Throughput:
Tokens/sec: 2,839.1
Batches/sec: 0.1
Average batch size: 40920.0 tokens


In [9]:
# Scratch/debug cell: opens an nnsight trace on whatever `batch` is left over
# from an earlier cell (hidden state) and only prints a message; nothing is saved.
# NOTE(review): the recorded output (": ") looks truncated/interrupted, and
# nnsight presumably runs a full forward pass when the context exits.
# Consider deleting this cell.
with model.trace(batch):
    print('processing')
    

: 

In [None]:
# Intentionally empty scratch cell. A bare undefined name here would raise
# NameError and stop Restart & Run All before the storage cell below.

In [8]:
# Stream MLP in/out activations for `store_size` token positions into an HDF5
# file. Each row is one token position with shape (2, n_layer, d_model):
# index 0 = MLP input, 1 = MLP output.
# The directory can be overridden with ACTIVATION_STORE_DIR instead of relying
# on a machine-specific absolute path.
store_path = os.environ.get('ACTIVATION_STORE_DIR', '/var/local/glang/activations')
filename = 'clt-activations-10M.h5'
store_size = 10000000  # number of token positions (rows) to store
actv_size = model.config.n_embd

# h5py.File(..., "w") fails if the target directory does not exist.
os.makedirs(store_path, exist_ok=True)

with h5py.File(os.path.join(store_path, filename), "w") as f:
    h5_dataset = f.create_dataset(
        'tensor', (store_size, 2, model.config.n_layer, model.config.n_embd), dtype='float32'
    )

    h5_pointer = 0  # number of rows written so far
    for batch in text_dataset_loader:
        print(h5_pointer / store_size * 100, "% done")
        # prepend BOS (important): shift right by one and overwrite position 0;
        # the last token of each row is dropped.
        batch = torch.roll(batch, shifts=1, dims=1)
        batch[:, 0] = model.config.bos_token_id

        # extract activations: (batch, seq_len, 2, n_layer, d_model)
        mlp_acts = extract_activations(model, batch)

        # store activations: merge batch and seq_len into one row axis
        mlp_acts = mlp_acts.flatten(0, 1)

        # Write at most the rows still free, so the final batch is truncated.
        n_take = min(mlp_acts.shape[0], store_size - h5_pointer)
        h5_dataset[h5_pointer : h5_pointer + n_take] = mlp_acts[:n_take].cpu().numpy()
        h5_pointer += n_take

        # Stop as soon as the store is full. The old check only fired on
        # overflow, so an exact fill wasted one extra extraction pass.
        if h5_pointer >= store_size:
            break

    # Record how many rows hold real data. This is less than store_size if the
    # loader ran out; the remaining rows stay unwritten.
    h5_dataset.attrs['n_filled'] = h5_pointer

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


0.0 % done
0.4092 % done
0.8184 % done
1.2276 % done
1.6368 % done
2.046 % done
2.4552 % done
2.8644 % done
3.2736 % done
3.6828 % done
4.092 % done
4.501200000000001 % done
4.9104 % done
5.3196 % done
5.7288 % done
6.138 % done
6.5472 % done
6.9564 % done
7.3656 % done
7.7748 % done
8.184 % done
8.5932 % done
9.002400000000002 % done
9.4116 % done
9.8208 % done
10.23 % done
10.6392 % done
11.048399999999999 % done
11.4576 % done
11.8668 % done
12.276 % done
12.6852 % done
13.0944 % done
13.503599999999999 % done
13.9128 % done
14.322 % done
14.7312 % done
15.140400000000001 % done
15.5496 % done
15.9588 % done
16.368 % done
16.7772 % done
17.1864 % done
17.5956 % done
18.004800000000003 % done
18.414 % done
18.8232 % done
19.2324 % done
19.6416 % done
20.0508 % done
20.46 % done
20.8692 % done
21.2784 % done
21.6876 % done
22.096799999999998 % done
22.506 % done
22.9152 % done
23.3244 % done
23.7336 % done
24.1428 % done
24.552 % done
24.9612 % done
25.3704 % done
25.779600000000002 %