In [1]:
from pathlib import Path
from datasets import load_dataset, DownloadMode
import torch
import os
import glob
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from mi_utils.util.logit_lens_utils.llama_wrapper import LlamaPromptLens, run_logit_lens_batched

In [2]:
from enum import Enum

class Models(Enum):
    """Model identifiers handed to ``LlamaPromptLens`` as ``model_id``.

    NOTE(review): the values look like local checkpoint directories
    relative to the working directory — confirm against the loader.
    """

    # LLaMA 3 Instruct checkpoint.
    LAIN8B = "Models/LLaMA3Instruct"

    # HF1BitLLM checkpoint (100B-token variant, going by the identifier).
    HF100B = "Models/HF1BitLLM100Btokens"


class Names(Enum):
    """Descriptive name strings, keyed by the same member names as ``Models``."""

    # Name for the LLaMA 3 8B Instruct checkpoint.
    LAIN8B = "Meta-Llama-3-8B-Instruct-fp"

    # Name for the 1.58-bit Llama3 8B checkpoint (100B tokens).
    HF100B = "Llama3-8B-1.58-100B-tokens"

In [3]:
# Cache directory for the Natural Questions download. The original
# machine-specific location remains the default; set the NQ_CACHE_DIR
# environment variable to run the notebook on another machine without
# editing this cell.
NQ_CACHE_DIR = Path(os.environ.get("NQ_CACHE_DIR", r'D:\LogitLensData\nq'))

# Number of leading 'train' examples to load.
NQ_NUM_EXAMPLES = 10

# Both names are kept because other cells may refer to either of them.
filepath = str(NQ_CACHE_DIR)
destination_path = filepath

nq_dataset = load_dataset(
    'sentence-transformers/natural-questions',
    # Dict-valued split: the result is indexed by 'train' in the next cell.
    split={'train': f'train[:{NQ_NUM_EXAMPLES}]'},
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True,
)

In [4]:
# Pull the question and answer columns out of the loaded 'train' split.
nq_queries, nq_answers = (
    nq_dataset['train'][column] for column in ('query', 'answer')
)

### LLaMA 3 8B Instruct — full precision (FP)

In [5]:
# Wrap the LLaMA 3 Instruct checkpoint in a prompt lens on the CPU, with the
# per-layer-norm and sub-block options both switched off.
llama8b_fp = LlamaPromptLens(
    device="cpu",
    model_id=Models.LAIN8B.value,
    include_subblocks=False,
    apply_per_layer_norm=False,
)



Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

Architecture detected: llama
Standard FP16 or FP32 model.


In [6]:
# Run the lens over the NQ queries with the full-precision model.
# The recorded output shows a single batch file being written under
# logs/lens_batches/, named from dataset_name and model_name.
run_logit_lens_batched(
    prompts=nq_queries,
    lens=llama8b_fp,
    model_name="llama8b_fp",
    dataset_name="nq_query",
    batch_size=10,
    proj_precision=None,
)

[✓] Saved batch 0: logs/lens_batches/nq_query_llama8b_fp_batch0.pt
All 10 prompts processed.


In [7]:
# Reload the first saved batch of full-precision lens results for inspection.
# `torch` is already imported in the notebook's first cell, so it is not
# imported again here.
#
# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. It is acceptable here only because the
# file was written by this notebook's own run; never point this call at an
# untrusted file.
# NOTE(review): the saved object appears to be a table rather than plain
# tensors (see the `.tail()` output), which would explain why full
# unpickling is needed — confirm against the writer.
data = torch.load(
    "logs/lens_batches/nq_query_llama8b_fp_batch0.pt",
    weights_only=False,
)

In [8]:
# Show the last five rows of the reloaded batch (five is the default count).
data.tail(n=5)

Unnamed: 0,prompt_id,prompt_text,dataset,vocab_size,layer_index,layer_name,input_ids,target_ids,logits,position
335,9,when did fosters home for imaginary friends start,nq_query,128000,29,layer_28,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(37413), te...","[[tensor(-0.1089), tensor(-0.7619), tensor(-0....","[tensor(0), tensor(1), tensor(2), tensor(3), t..."
336,9,when did fosters home for imaginary friends start,nq_query,128000,30,layer_29,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(37413), te...","[[tensor(-0.5514), tensor(-0.4113), tensor(-0....","[tensor(0), tensor(1), tensor(2), tensor(3), t..."
337,9,when did fosters home for imaginary friends start,nq_query,128000,31,layer_30,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(37413), te...","[[tensor(-2.1905), tensor(-1.0497), tensor(-1....","[tensor(0), tensor(1), tensor(2), tensor(3), t..."
338,9,when did fosters home for imaginary friends start,nq_query,128000,32,layer_31,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(37413), te...","[[tensor(6.7354), tensor(7.3383), tensor(5.922...","[tensor(0), tensor(1), tensor(2), tensor(3), t..."
339,9,when did fosters home for imaginary friends start,nq_query,128000,33,output,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(37413), te...","[[tensor(2.7837), tensor(4.3844), tensor(3.649...","[tensor(0), tensor(1), tensor(2), tensor(3), t..."


### HF1BitLLM — Llama3-8B-1.58 (100B tokens), BitNet

In [9]:
# Wrap the HF1BitLLM checkpoint in a prompt lens on the CPU, with the same
# options as the full-precision lens (no per-layer norm, no sub-blocks).
# The recorded output reports this checkpoint as a BitNet architecture.
llama8b_hf100b = LlamaPromptLens(
    device="cpu",
    model_id=Models.HF100B.value,
    include_subblocks=False,
    apply_per_layer_norm=False,
)



Architecture detected: bitnet
BitNet model (BitLinear layers).


In [None]:
# Run the lens over the same NQ queries with the BitNet model, using the same
# dataset label and batch size as the full-precision run.
run_logit_lens_batched(
    prompts=nq_queries,
    lens=llama8b_hf100b,
    model_name="llama8b_hf100b",
    dataset_name="nq_query",
    batch_size=10,
    proj_precision=None,
)