In [1]:
from huggingface_hub import login
# Authenticate this session with the Hugging Face Hub. Called with no
# arguments, so no access token is hardcoded in the notebook.
login()

In [2]:
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

In [6]:
# Checkpoint to inspect, and the device it will run on: the MPS backend when
# PyTorch reports it as available, otherwise the CPU.
model_name = "microsoft/phi-3-mini-4k-instruct"
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Load the causal-LM weights, then move the module to the chosen device.
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

# Tokenizer matching the checkpoint above.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Bare last expression: the notebook renders the module tree as cell output.
model

Loading weights:   0%|          | 0/195 [00:00<?, ?it/s]

Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
    (layers): ModuleList(
      (0-31): 32 x Phi3DecoderLayer(
        (self_attn): Phi3Attention(
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (activation_fn): SiLUActivation()
        )
        (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (norm): Phi3RMSNorm((3072,), eps=1e-05)
    (rotary_emb): Phi3RotaryEmbedding()
  )
  (lm_head): Linear(in_features=3072, out_featur

In [9]:
# Prompt made of repeated "index:word" pairs that ends on a bare "1:", so the
# continuation shows whether the model recalls the word paired with index 1.
text = "1:sunshine, 2:apple, 3:koala, 2:apple, 1:sunshine, 3:koala, 1:"

# Encode to a tensor of token ids and place it on the same device as the model.
text_tokenized = tokenizer.encode(text, return_tensors="pt")
text_tokenized = text_tokenized.to(device)

# NOTE(review): do_sample=True without a fixed seed means the printed
# continuation may differ between runs.
output = model.generate(
    text_tokenized,
    max_new_tokens=3,
    num_beams=4,
    do_sample=True,
)

# Decode the first returned sequence (prompt plus continuation) and show it.
text_decoded = tokenizer.decode(output[0])
print(text_decoded)

1:sunshine, 2:apple, 3:koala, 2:apple, 1:sunshine, 3:koala, 1:sunshine


In [None]:
def decode_token_with_logit_lens(
    model,
    device,
    tokenizer,
    input,
    tokens_to_gen=None
):
    """Decode every layer's hidden states through the LM head ("logit lens").

    The prompt is tokenized, optionally extended by ``tokens_to_gen`` generated
    tokens, and run through ``model`` once with ``output_hidden_states=True``.
    Each returned hidden state is passed through ``model.lm_head`` and a
    softmax; the most probable token and its probability are recorded for
    every position.

    Args:
        model: Causal LM exposing ``lm_head``, ``generate`` and a forward call
            that accepts ``output_hidden_states=True``.
        device: Device the tokenized prompt is moved to; should be the device
            ``model`` lives on.
        tokenizer: Tokenizer matching ``model``.
        input: Prompt text. (The name shadows the builtin but is kept so
            existing keyword callers keep working.)
        tokens_to_gen: If not ``None``, number of tokens to generate and append
            to the prompt before the hidden states are inspected.

    Returns:
        dict with three entries:
            ``'text_tokens'``: decoded string of each token in the analysed
                sequence (prompt plus any generated tokens).
            ``'decoded_tokens'``: one list per hidden state, holding the
                top-1 decoded token string at each position.
            ``'decoded_logits'``: same layout, holding the softmax
                probability of that top-1 token. Despite the key name these
                are probabilities, not raw logits; the key is kept for
                backward compatibility.

    Only the first sequence of the batch is analysed, as in the original code.
    """
    encoded = tokenizer(input, return_tensors="pt").to(device)
    input_ids = encoded["input_ids"]

    if tokens_to_gen is not None:
        with torch.no_grad():
            # Keep working on the generated ids directly. Decoding them to text
            # and tokenizing again would render special tokens as text, let the
            # tokenizer add its own specials a second time, and could split the
            # text differently from what the model actually produced.
            input_ids = model.generate(
                input_ids,
                attention_mask=encoded.get("attention_mask"),
                do_sample=True,
                top_p=0.95,
                temperature=0.001,
                top_k=0,  # was misspelled "tok_k", which generate() does not accept
                max_new_tokens=tokens_to_gen,
            )

    text_tokens = [tokenizer.decode([token_id]) for token_id in input_ids[0].tolist()]

    classifier_head = model.lm_head

    decoded_intermediate_tokens = []
    decoded_intermediate_probs = []

    # Inference only: no_grad avoids building an autograd graph for the
    # forward pass and for the per-layer projections.
    with torch.no_grad():
        hidden_states = model(
            input_ids=input_ids,
            # Single sequence without padding, so every position is attended to.
            attention_mask=torch.ones_like(input_ids),
            output_hidden_states=True,
        ).hidden_states

        # NOTE(review): each hidden state goes straight into lm_head without
        # the model's final norm layer. Kept as in the original; confirm this
        # is the intended logit-lens variant.
        for layer_hidden in hidden_states:
            probs = torch.nn.functional.softmax(classifier_head(layer_hidden), dim=-1)
            # Index [0] selects the single batch entry, so the loop below runs
            # over sequence positions rather than over the batch dimension.
            top_probs, top_ids = probs[0].max(dim=-1)
            decoded_intermediate_tokens.append(
                [tokenizer.decode([token_id]) for token_id in top_ids.tolist()]
            )
            decoded_intermediate_probs.append(top_probs.tolist())

    return {'text_tokens': text_tokens,
            'decoded_tokens': decoded_intermediate_tokens,
            'decoded_logits': decoded_intermediate_probs}