In [1]:
import torch
import nnsight
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LlamaConfig, LlamaForCausalLM

import os

from elk.extraction.prompt_loading import load_prompts

os.environ['HF_TOKEN'] = "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"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "meta-llama/Llama-2-7b-hf"
#model_name = "EleutherAI/pythia-12b"
#model_name = "gpt2"

# first, load the state dict of the model from the model name

model = nnsight.LanguageModel(model_name, device="cpu")._load(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)

# check if tokenizer has a pad and bos token
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
if tokenizer.bos_token_id is None:
    tokenizer.bos_token = tokenizer.eos_token

Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.01s/it]


In [3]:
import inspect
print(inspect.getsource(model.model._update_causal_mask))

    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
            target_length = self.config.max_position_embeddings
        else:  # dynamic cache
            target_length = (
                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
            )

        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask

In [4]:
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head):

In [5]:
class InitialEmbedding(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.word_embed = model.model.wte
        self.pos_embed = model.model.wpe
    
    def forward(self, input_ids):
        # TODO : check that
        pos_ids = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
        print(pos_ids)
        return self.word_embed(input_ids) + self.pos_embed(pos_ids)

def get_embed(model):
    #return InitialEmbedding(model)
    return model.model.embed_tokens

def get_block(model, layer):
    return model.model.layers[layer]

In [6]:
import time

@torch.no_grad()
def get_acts(statements, tokenizer, model, batch_size=32, layers=None, intermediate_device="cpu", compute_device=DEVICE):
    """
    Get given layer activations for the statements.
    Return dictionary of stacked activations.

    Caution: Layer 0 is embedding layer, layer 1 is the first transformer layer, so model.model.h[0]
    """
    t_to_cpu = 0
    t_to_gpu = 0
    t_start = time.perf_counter_ns()
    model.eval().to(intermediate_device)
    t_to_cpu += time.perf_counter_ns() - t_start
    if layers is None:
        layers = list(range(model.config.num_hidden_layers + 1))

    # get last token indexes for all statements
    last_tokens = [len(tokenizer.encode(statement)) - 1 for statement in statements]

    #print(last_tokens)

    current_hiddens = []
    cache_positions = []
    positions_ids = []
    all_hiddens = [[] for _ in range(model.config.num_hidden_layers + 1)]

    t_start = time.perf_counter_ns()
    embed = get_embed(model).to(compute_device)
    t_to_gpu += time.perf_counter_ns() - t_start

    bos_token = tokenizer.bos_token_id
    
    for batch_start in range(0, len(statements), batch_size):
        batch = statements[batch_start:min(batch_start + batch_size, len(statements))]
        # TODO : check for BOS token and last token (should be ".")
        input_ids = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).input_ids
        t_start = time.perf_counter_ns()
        input_ids = input_ids.to(compute_device)
        t_to_gpu += time.perf_counter_ns() - t_start

        if bos_token is not None and bos_token != input_ids[0, 0]:
            input_ids = torch.cat([torch.zeros(input_ids.size(0), 1, device=input_ids.device, dtype=input_ids.dtype).fill_(bos_token), input_ids], dim=1)

        # print(input_ids)
        # print(input_ids.size())

        current_hiddens.append(embed(input_ids))
        cache_positions.append(torch.arange(input_ids.size(1), device=compute_device))
        positions_ids.append(cache_positions[-1].unsqueeze(0))
        if 0 in layers:
            t_start = time.perf_counter_ns()
            all_hiddens[0].append(current_hiddens[-1][torch.arange(input_ids.size(0)), last_tokens[batch_start:batch_start + input_ids.size(0)]].to(intermediate_device))
            t_to_cpu += time.perf_counter_ns() - t_start
        
        t_start = time.perf_counter_ns()
        input_ids = input_ids.to(intermediate_device)
        t_to_cpu += time.perf_counter_ns() - t_start

    t_start = time.perf_counter_ns()
    embed.to(intermediate_device)
    t_to_cpu += time.perf_counter_ns() - t_start

    t_free = 0
    t_start = time.perf_counter_ns()
    torch.cuda.empty_cache()
    t_free += time.perf_counter_ns() - t_start

    #causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

    for block_idx in range(max(layers)):
        t_start = time.perf_counter_ns()
        decoder_layer = get_block(model, block_idx).to(compute_device)
        t_to_gpu += time.perf_counter_ns() - t_start

        for batch_idx, batch in enumerate(current_hiddens):
            cache_position = cache_positions[batch_idx]
            position_ids = positions_ids[batch_idx]
            causal_mask = model.model._update_causal_mask(None, batch, cache_position) 
            out = decoder_layer(
                batch,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=None,
                output_attentions=False,
                use_cache=False,
                cache_position=cache_position,
            )[0]
            if block_idx + 1 in layers:
                idx = last_tokens[batch_idx * batch_size:batch_idx * batch_size + out.size(0)]
                t_start = time.perf_counter_ns()
                all_hiddens[block_idx + 1].append(
                    out[
                        torch.arange(out.size(0), device=out.device),
                        idx
                    ].to(intermediate_device)
                )
                t_to_cpu += time.perf_counter_ns() - t_start
            current_hiddens[batch_idx] = out

        t_start = time.perf_counter_ns()
        decoder_layer.to(intermediate_device)
        t_to_cpu += time.perf_counter_ns() - t_start
        t_start = time.perf_counter_ns()
        torch.cuda.empty_cache()
        t_free += time.perf_counter_ns() - t_start
    
    print(f"Time to CPU : {t_to_cpu / 1e9}")
    print(f"Time to GPU : {t_to_gpu / 1e9}")
    print(f"Time free : {t_free / 1e9}")
    
    return {layer: torch.cat(acts) for layer, acts in enumerate(all_hiddens) if len(acts) > 0}

@torch.no_grad()
def generate_acts(
    cfg,
    layers=None,
    split_type = "train",
    rank=0,
    world_size=1,
    output_dir="acts",
    noperiod=False,
    intermediate_device="cpu",
    compute_device=DEVICE,
):
    tokenizer = AutoTokenizer.from_pretrained(cfg.model)
    model = AutoModelForCausalLM.from_pretrained(cfg.model).to(intermediate_device)
    
    if layers is None:
        layers = list(range(model.config.num_hidden_layers + 1))

    ds_names = cfg.datasets

    prompt_ds = load_prompts(
        ds_names[0],
        binarize=cfg.binarize,
        num_shots=cfg.num_shots,
        split_type=split_type,
        template_path=cfg.template_path,
        rank=rank,
        world_size=world_size,
        seed=cfg.seed,
    )
    
    num_yielded = 0
    for example_id, example in enumerate(prompt_ds):
        num_variants = len(example["prompts"])
        num_choices = len(example["prompts"][0])

        hidden_dict = {
            f"hidden_{layer_idx}": torch.empty(
                num_variants,
                num_choices,
                model.config.hidden_size,
                device=intermediate_device,
                dtype=torch.int16,
            )
            for layer_idx in layers
        }
        
        text_questions = []
        statements = []
        for i, record in enumerate(example["prompts"]):
            variant_questions = []

            # Iterate over answers
            for j, choice in enumerate(record):
                text = choice["question"]

                variant_questions.append(
                    dict(
                        {
                            "template_id": i,
                            "template_name": example["template_names"][i],
                            "text": dict(
                                {
                                    "question": text,
                                    "answer": choice["answer"],
                                }
                            ),
                            "example_id": example_id,
                        }
                    )
                )
                statements.append(choice["question"] + " " + choice["answer"])
            
            text_questions.append(text)
        
        acts = get_acts(statements, tokenizer, model, layers=layers, intermediate_device=intermediate_device, compute_device=compute_device)

        # Fill hidden_dict with activations
        for layer_idx, act in acts.items():
            idx = 0
            for i, record in enumerate(example["prompts"]):
                for j, choice in enumerate(record):
                    hidden_dict[f"hidden_{layer_idx}"][i, j] = act[idx]
                    idx += 1
            
        # We skipped a variant because it was too long; move on to the next example
        if len(text_questions) != num_variants:
            continue

        out_record = dict(
            label=example["label"],
            variant_ids=example["template_names"],
            text_questions=text_questions,
            **hidden_dict,
        )

        num_yielded += 1
        yield out_record

results :
gpt2 
- extract_hiddens : 6-7-8 min, 15 au début
- generate_acts : 

In [7]:
acts = get_acts(["Hello world.", "Hello hello !."], tokenizer, model, batch_size=2)

for act in acts.values():
    print(act.shape)

for layer_idx, act in acts.items():
    print(layer_idx, act.shape)


# TODO : use nnsight to check the activations are the same

#generate_acts

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
  attn_output = torch.nn.functional.scaled_dot_product_attention(


Time to CPU : 4.3362326
Time to GPU : 5.3387427
Time free : 0.1654169
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
torch.Size([2, 4096])
0 torch.Size([2, 4096])
1 torch.Size([2, 4096])
2 torch.Size([2, 4096])
3 torch.Size([2, 4096])
4 torch.Size([2, 4096])
5 torch.Size([2, 4096])
6 torch.Size([2, 4096])
7 torch.Size([2, 4096])
8 torch.Size