In [1]:
from nnsight import LanguageModel
# from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

In [2]:
MODEL_PATH = "google/gemma-2-2b-it"

device = "cuda" if torch.cuda.is_available() else "cpu"

model = LanguageModel(
    MODEL_PATH,
    device_map="auto",
)
tokenizer = model.tokenizer

print(model)

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemm

In [3]:
prompt_text = """<start_of_turn>user
This is a beatiful car.
Is the above statement positive in sentiment? Answer Yes or No only<end_of_turn>
<start_of_turn>model"""

max_new_tokens = 50

with model.generate(prompt_text, max_new_tokens=max_new_tokens) as gen_tracer:
    output_tokens = model.generator.output.save()

generated = model.tokenizer.decode(output_tokens[0].cpu())
print(generated)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

<bos><start_of_turn>user
This is a beatiful car.
Is the above statement positive in sentiment? Answer Yes or No only<end_of_turn>
<start_of_turn>modelYes 
<end_of_turn>


In [5]:
prompt_text = """<start_of_turn>user
This is a ugly car.
Is the above statement positive in sentiment? Answer Yes or No only<end_of_turn>
<start_of_turn>model
"""

n_steps = 5   # how many tokens to analyze
k = 10        # top-k predictions per step

# Run generation context
with model.generate(prompt_text, max_new_tokens=n_steps) as _:
    # List to hold proxies for each step's logits
    logits_proxies = []
    for _ in range(n_steps):
        # Save current logits (for next-token prediction)
        logits_proxies.append(model.lm_head.output.save())
        # Advance the generator by one token
        model.lm_head.next()

# After exiting the context, extract real tensors and compute top-k
for step_idx, proxy in enumerate(logits_proxies, start=1):
    logits = proxy      # [1, seq_len + step_idx - 1, vocab_size]
    last_logits = logits[0, -1]
    probs = torch.softmax(last_logits, dim=-1)
    topk_probs, topk_ids = torch.topk(probs, k)
    raw_tokens = model.tokenizer.convert_ids_to_tokens(topk_ids.tolist())
    decoded = [model.tokenizer.decode([tid]) for tid in topk_ids.tolist()]

    print(f"Step {step_idx} top-{k} predictions:")
    for tok, rep, p in zip(raw_tokens, decoded, topk_probs.tolist()):
        print(f"  {tok} ({repr(rep)}): {p:.4f}")

Step 1 top-10 predictions:
  No ('No'): 1.0000
  Yes ('Yes'): 0.0000
  Negative ('Negative'): 0.0000
  no ('no'): 0.0000
  NO ('NO'): 0.0000
  False ('False'): 0.0000
  Answer ('Answer'): 0.0000
  Nope ('Nope'): 0.0000
  ▁No (' No'): 0.0000
  N ('N'): 0.0000
Step 2 top-10 predictions:
  ▁ (' '): 0.9822
  . ('.'): 0.0178
  
 ('\n'): 0.0001
  <end_of_turn> ('<end_of_turn>'): 0.0000
  ▁▁ ('  '): 0.0000
  <eos> ('<eos>'): 0.0000
  , (','): 0.0000
  ▁😊 (' 😊'): 0.0000
  

 ('\n\n'): 0.0000
  ▁👍 (' 👍'): 0.0000
Step 3 top-10 predictions:
  
 ('\n'): 1.0000
  

 ('\n\n'): 0.0000
  


 ('\n\n\n'): 0.0000
  <end_of_turn> ('<end_of_turn>'): 0.0000
  🚗 ('🚗'): 0.0000
  ❌ ('❌'): 0.0000
  😜 ('😜'): 0.0000
  <eos> ('<eos>'): 0.0000
  🚫 ('🚫'): 0.0000
  😈 ('😈'): 0.0000
Step 4 top-10 predictions:
  <end_of_turn> ('<end_of_turn>'): 1.0000
  ▁ (' '): 0.0000
  ▁▁ ('  '): 0.0000
  ** ('**'): 0.0000
  Let ('Let'): 0.0000
  <eos> ('<eos>'): 0.0000
  Provide ('Provide'): 0.0000
  Please ('Please'): 0.0000
  ▁Phry

NNsightError: Accessing value before it's been set.