In [18]:
# 'AutoModelForCausalLM' is just an interface into the language model.
# 'CausalLM' refers to left-to-right language modeling, or next-token-prediction.
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

import torch

In [1]:
# Checkpoint identifier; reused further down to load the matching tokenizer.
model_id = "microsoft/Phi-3-mini-4k-instruct"
# Full causal-LM wrapper. Its `lm_head` attribute is used at the end of the
# notebook to turn the final hidden state into vocabulary logits.
causal_lm_model = AutoModelForCausalLM.from_pretrained(model_id)
# The inner transformer only (token embedding, decoder layers, final norm);
# the manual forward pass below is driven through this object.
model = causal_lm_model.model

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
# Display the module tree of the inner transformer to see which sub-modules
# the manual forward pass below has to call, and in what order.
model

Phi3Model(
  (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
  (embed_dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-31): 32 x Phi3DecoderLayer(
      (self_attn): Phi3SdpaAttention(
        (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
        (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
        (rotary_emb): Phi3RotaryEmbedding()
      )
      (mlp): Phi3MLP(
        (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
        (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
        (activation_fn): SiLU()
      )
      (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
      (resid_attn_dropout): Dropout(p=0.0, inplace=False)
      (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
      (post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
    )
  )
  (norm): Phi3RMSNorm((3072,), eps=1e-05)
)

In [3]:
# The output projection lives on the causal-LM wrapper, not on `model`;
# display it to check its input/output feature sizes.
causal_lm_model.lm_head

Linear(in_features=3072, out_features=32064, bias=False)

In [4]:
# The token-embedding table: the first module applied to the input ids below.
model.embed_tokens

Embedding(32064, 3072, padding_idx=32000)

In [5]:
# Number of decoder layers the manual loop below iterates over.
len(model.layers)

32

In [7]:
# Load the tokenizer published under the same checkpoint id as the model.
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [13]:
def phi_prompt(*, system: str, user: str) -> str:
    """Assemble a single-turn chat prompt string from a system and a user message.

    The result consists of five newline-separated lines: a ``<|system|>``
    header, the system text followed by ``<|end|>``, a ``<|user|>`` header,
    the user text followed by ``<|end|>``, and a bare ``<|assistant|>`` header
    with no trailing newline.

    Args:
        system: System message text. An empty string still produces the
            system header and its ``<|end|>`` marker.
        user: User message text, inserted verbatim (no stripping).

    Returns:
        The assembled prompt string.
    """
    system_turn = ("<|system|>", f"{system}<|end|>")
    user_turn = ("<|user|>", f"{user}<|end|>")
    assistant_header = ("<|assistant|>",)
    return "\n".join(system_turn + user_turn + assistant_header)

In [14]:
# Build a chat prompt with an empty system message; the user text deliberately
# ends mid-sentence (note the trailing space) so the model has a word to fill in.
prompt = phi_prompt(system="", user="A journey of a thousand miles begins with a single ")

# return_tensors="pt" asks the tokenizer for PyTorch tensors.
inputs = tokenizer(prompt, return_tensors="pt")

# Token ids only; displayed shape is (batch, number of prompt tokens).
input_batch = inputs["input_ids"]
input_batch.shape

torch.Size([1, 16])

In [15]:
# Look up one embedding vector per input token id.
inputs_embeds = model.embed_tokens(input_batch)
# The displayed shape gains a trailing embedding dimension relative to input_batch.
inputs_embeds.shape

torch.Size([1, 16, 3072])

In [16]:
# The token embeddings are the initial hidden states fed to the first decoder layer.
hidden_states = inputs_embeds

In [19]:
# One position index per prompt token (0, 1, ..., seq_len - 1), with a leading
# batch dimension added so the shape is (1, seq_len) like input_batch.
# torch.arange builds the index tensor directly instead of converting a Python
# range, and device= keeps it on the same device as the input ids.
position_ids = torch.arange(
    input_batch.shape[1], device=input_batch.device
).unsqueeze(dim=0)

In [20]:
# Manual forward pass through the decoder stack. Each layer is applied as two
# residual sub-blocks: norm -> self-attention -> add, then norm -> MLP -> add.
#
# Always restart from the embeddings so that re-running this cell gives the
# same result instead of pushing already-transformed states through again.
hidden_states = inputs_embeds

# Inference only: disable autograd so activations of all layers are not
# retained for a backward pass that never happens.
with torch.no_grad():
    for i, layer in enumerate(model.layers):
        print("layer", i)

        # --- self-attention sub-block ---
        residual = hidden_states
        hidden_states = layer.input_layernorm(hidden_states)
        # No attention_mask is passed.
        # NOTE(review): this relies on the attention module applying causal
        # masking by default when the mask is omitted; confirm for the
        # installed transformers version.
        attn_outputs, _self_attn_weights, _present_key_value = layer.self_attn(
            hidden_states=hidden_states,
            position_ids=position_ids,
            output_attentions=False,
            use_cache=False,
        )
        # Apply the layer's attention-residual dropout, symmetric with the MLP
        # branch below (p=0.0 in the printed model, so values are unchanged).
        hidden_states = residual + layer.resid_attn_dropout(attn_outputs)

        # --- MLP sub-block ---
        residual = hidden_states
        hidden_states = layer.post_attention_layernorm(hidden_states)
        hidden_states = layer.mlp(hidden_states)
        hidden_states = residual + layer.resid_mlp_dropout(hidden_states)

layer 0
layer 1
layer 2
layer 3
layer 4
layer 5
layer 6
layer 7
layer 8
layer 9
layer 10
layer 11
layer 12
layer 13
layer 14
layer 15
layer 16
layer 17
layer 18
layer 19
layer 20
layer 21
layer 22
layer 23
layer 24
layer 25
layer 26
layer 27
layer 28
layer 29
layer 30
layer 31


In [21]:
# Apply the model's final norm to the output of the last decoder layer.
normed_hidden_state = model.norm(hidden_states)

# Keep only the last sequence position (index -1 along dim 1): that is the
# state from which the next token is predicted.
last_token_hidden_state = normed_hidden_state[:, -1]

In [22]:
# Project the last token's hidden state to one score per vocabulary entry.
batch_logits = causal_lm_model.lm_head(last_token_hidden_state)
# Take the first row of the batch (the only row here, since a single prompt was tokenized).
logits = batch_logits[0]

In [23]:
# Greedy choice: the vocabulary index with the highest logit.
token = logits.argmax()
# Map the chosen token id back to text.
tokenizer.decode(token)

'step'