In [2]:
import re
from functools import partial

import torch

from model_utils import load_model_and_tokenizer, get_submodule

In [3]:
# Load the Phi-2 checkpoint and its tokenizer via the local model_utils helper.
model_name = "microsoft/phi-2"
model, tokenizer = load_model_and_tokenizer(model_name)

Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.81s/it]

Loaded microsoft/phi-2





In [4]:
# Parameter names from the state dict (e.g. "model.layers.0.self_attn.v_proj.weight"),
# listed one per line so the module hierarchy can be read off the prefixes.
modules = [*model.state_dict()]
print(*modules, sep="\n")

model.embed_tokens.weight
model.layers.0.self_attn.q_proj.weight
model.layers.0.self_attn.q_proj.bias
model.layers.0.self_attn.k_proj.weight
model.layers.0.self_attn.k_proj.bias
model.layers.0.self_attn.v_proj.weight
model.layers.0.self_attn.v_proj.bias
model.layers.0.self_attn.dense.weight
model.layers.0.self_attn.dense.bias
model.layers.0.mlp.fc1.weight
model.layers.0.mlp.fc1.bias
model.layers.0.mlp.fc2.weight
model.layers.0.mlp.fc2.bias
model.layers.0.input_layernorm.weight
model.layers.0.input_layernorm.bias
model.layers.1.self_attn.q_proj.weight
model.layers.1.self_attn.q_proj.bias
model.layers.1.self_attn.k_proj.weight
model.layers.1.self_attn.k_proj.bias
model.layers.1.self_attn.v_proj.weight
model.layers.1.self_attn.v_proj.bias
model.layers.1.self_attn.dense.weight
model.layers.1.self_attn.dense.bias
model.layers.1.mlp.fc1.weight
model.layers.1.mlp.fc1.bias
model.layers.1.mlp.fc2.weight
model.layers.1.mlp.fc2.bias
model.layers.1.input_layernorm.weight
model.layers.1.input_layer

In [5]:
# Print one layer's attention block to see the names of its submodules
# (q_proj / k_proj / v_proj / dense / rotary_emb in the output below).
print(get_submodule(model, "model.layers.1.self_attn"))

PhiSdpaAttention(
  (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
  (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
  (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
  (dense): Linear(in_features=2560, out_features=2560, bias=True)
  (rotary_emb): PhiRotaryEmbedding()
)


In [6]:
import re

In [7]:
# Regex over *module* paths. Dots are escaped so they only match literal dots,
# and `\d+` restricts the wildcard to a layer index, so paths such as
# "model.layers.0.mlp.self_attn.v_proj" or "...v_proj_extra" cannot match.
target_pattern = r"model\.layers\.\d+\.self_attn\.v_proj"

# Walk the module tree directly rather than parameter names from the state
# dict: named_modules() yields each submodule once (the state-dict scan hit
# every v_proj twice, via its .weight and .bias), and fullmatch selects the
# v_proj module itself rather than a longer path that merely contains it.
target_modules = {
    name: submodule
    for name, submodule in model.named_modules()
    if re.fullmatch(target_pattern, name)
}


In [8]:
# Show the matched v_proj submodules, keyed by their module path.
target_modules

{'model.layers.0.self_attn.v_proj': Linear(in_features=2560, out_features=2560, bias=True),
 'model.layers.1.self_attn.v_proj': Linear(in_features=2560, out_features=2560, bias=True),
 'model.layers.2.self_attn.v_proj': Linear(in_features=2560, out_features=2560, bias=True),
 'model.layers.3.self_attn.v_proj': Linear(in_features=2560, out_features=2560, bias=True),
 'model.layers.4.self_attn.v_proj': Linear(in_features=2560, out_features=2560, bias=True),
 'model.layers.5.self_attn.v_proj': Linear(in_features=2560, out_features=2560, bias=True),
 'model.layers.6.self_attn.v_proj': Linear(in_features=2560, out_features=2560, bias=True),
 'model.layers.7.self_attn.v_proj': Linear(in_features=2560, out_features=2560, bias=True),
 'model.layers.8.self_attn.v_proj': Linear(in_features=2560, out_features=2560, bias=True),
 'model.layers.9.self_attn.v_proj': Linear(in_features=2560, out_features=2560, bias=True),
 'model.layers.10.self_attn.v_proj': Linear(in_features=2560, out_features=2560,

In [9]:
from functools import partial

In [10]:
# Activations captured by the forward hooks, keyed by module path. Each forward
# pass appends one tensor per batch row (i.e. per input sample) to every list.
recorded_outputs = {k: [] for k in target_modules.keys()}

def record_output(module, input, output, name):
    """Forward hook: store `output` split along dim 0 under `recorded_outputs[name]`.

    Args:
        module: the hooked module (unused).
        input: the module's positional inputs (unused).
        output: the module's output tensor; each slice along dim 0 is stored
            as a separate entry.
        name: key into `recorded_outputs`, bound via functools.partial.
    """
    # detach() so the stored activations do not keep the autograd graph (and
    # every intermediate tensor it references) alive after the forward pass.
    recorded_outputs[name].extend(output.detach().unbind(0))

# Hooks stay attached to the model across cell executions. Remove the ones a
# previous run of this cell registered; otherwise each re-run adds another
# hook and every activation gets recorded multiple times.
for _handle in globals().get("hook_handles", []):
    _handle.remove()

hook_handles = []
for name, module in target_modules.items():
    hook_handles.append(module.register_forward_hook(partial(record_output, name=name)))


In [11]:
# Short prompts of different lengths, so the tokenized batch needs padding.
input_samples = [
    "The dog is on the mat.",
    "Do you know the muffin man",
    "Hasta la vista baby",
    "Sugar spice and everything nice",
]

In [12]:
# Batch-tokenize into PyTorch tensors. padding=True pads the shorter prompts up
# to the longest one; padded positions get attention_mask == 0.
input_tokens = tokenizer(input_samples, padding=True, return_tensors="pt")

In [13]:
# Inspect the batch: token ids plus the attention mask (0 marks padded positions).
input_tokens

{'input_ids': tensor([[  464,  3290,   318,   319,   262,  2603,    13],
        [ 5211,   345,   760,   262, 27563,   259,   582],
        [   39, 40197,  8591,   410, 12523,  5156, 50256],
        [   50, 35652, 25721,   290,  2279,  3621, 50256]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 0]])}

In [18]:
# Prefer Apple-silicon MPS (the original target), then CUDA, and fall back to
# CPU, so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model.to(device).eval()
input_tokens = {k: v.to(device) for k, v in input_tokens.items()}

# Inference only: no_grad() skips building an autograd graph for the forward
# pass. The forward hooks registered earlier fire during this call.
with torch.no_grad():
    model_outputs = model(**input_tokens, output_hidden_states=False)


In [23]:
# Sanity check: how many activation tensors each hooked module has recorded.
for name, outs in recorded_outputs.items():
    print(f"{name}:", len(outs))

model.layers.0.self_attn.v_proj: 4
model.layers.1.self_attn.v_proj: 4
model.layers.2.self_attn.v_proj: 4
model.layers.3.self_attn.v_proj: 4
model.layers.4.self_attn.v_proj: 4
model.layers.5.self_attn.v_proj: 4
model.layers.6.self_attn.v_proj: 4
model.layers.7.self_attn.v_proj: 4
model.layers.8.self_attn.v_proj: 4
model.layers.9.self_attn.v_proj: 4
model.layers.10.self_attn.v_proj: 4
model.layers.11.self_attn.v_proj: 4
model.layers.12.self_attn.v_proj: 4
model.layers.13.self_attn.v_proj: 4
model.layers.14.self_attn.v_proj: 4
model.layers.15.self_attn.v_proj: 4
model.layers.16.self_attn.v_proj: 4
model.layers.17.self_attn.v_proj: 4
model.layers.18.self_attn.v_proj: 4
model.layers.19.self_attn.v_proj: 4
model.layers.20.self_attn.v_proj: 4
model.layers.21.self_attn.v_proj: 4
model.layers.22.self_attn.v_proj: 4
model.layers.23.self_attn.v_proj: 4
model.layers.24.self_attn.v_proj: 4
model.layers.25.self_attn.v_proj: 4
model.layers.26.self_attn.v_proj: 4
model.layers.27.self_attn.v_proj: 4
mo

In [30]:
import torch

In [32]:
# Stack one layer's per-sample activations back into a batch tensor.
# Shape: (n_samples, seq_len, hidden_dim) -> here 4 prompts x 7 token
# positions (padding included) x 2560 v_proj outputs.
torch.stack(recorded_outputs["model.layers.13.self_attn.v_proj"]).shape

torch.Size([4, 7, 2560])

In [39]:
# The mask's (n_samples, seq_len) shape lines up with the first two dims of the
# stacked activations, so it can be used to drop padded positions.
input_tokens["attention_mask"].size()

torch.Size([4, 7])

In [40]:
def principal_components(data, n_components):
    """Return the leading principal directions of a 2-D data matrix.

    Args:
        data: tensor of shape (n_samples, n_features); rows are observations.
        n_components: number of leading principal directions to return.

    Returns:
        Tensor of shape (n_features, k) with orthonormal columns, ordered by
        decreasing explained variance, where
        k = min(n_components, n_samples, n_features). Project data onto the
        components with ``(data - data.mean(dim=0)) @ proj``. The sign of each
        column is arbitrary (an inherent SVD ambiguity).
    """
    centered = data - data.mean(dim=0)
    # torch.linalg.svd returns Vh = V^H. The principal directions (right
    # singular vectors) are the *rows* of Vh, so take the first rows and
    # transpose; slicing Vh's columns would give unrelated vectors.
    # full_matrices=False avoids materialising an n_features x n_features Vh.
    _, _, Vh = torch.linalg.svd(centered, full_matrices=False)
    proj = Vh[:n_components].mH
    return proj
