In [1]:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="SparseLLM/ReluLLaMA-7B")

# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

tokenizer = AutoTokenizer.from_pretrained(
    "SparseLLM/ReluLLaMA-7B", use_fast=False
)
# output hidden state
config = AutoConfig.from_pretrained("SparseLLM/ReluLLaMA-7B", output_hidden_states=True)
model = AutoModelForCausalLM.from_pretrained("SparseLLM/ReluLLaMA-7B", config=config)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
type(model)

transformers.models.llama.modeling_llama.LlamaForCausalLM

In [3]:
[module for module in model.named_modules()]

[('',
  LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32000, 4096, padding_idx=0)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
            (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
            (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
            (act_fn): ReLU()
          )
          (input_layernorm): LlamaRMSNorm()
          (post_attention_layernorm): LlamaRMSNor

In [4]:
[module for module in model.named_modules()][449]
[module for module in model.named_modules()][448]

('model.layers.31.mlp.down_proj',
 Linear(in_features=11008, out_features=4096, bias=False))

In [5]:
model.get_submodule("model.layers.31.mlp.act_fn")

ReLU()

In [6]:
from torchknickknacks import modelutils
layer = model.get_submodule("model.layers.18.mlp.act_fn")
layer2 = model.get_submodule("model.layers.31.mlp.down_proj")
recorder = modelutils.Recorder(layer, record_output=True, backward=False)
recorder2 = modelutils.Recorder(layer2, record_output=True, backward=False)

In [7]:
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors="pt")
output = model(**encoded_input)

In [8]:
print(recorder.recording)
print(recorder2.recording)
sum(sum(sum(recorder.recording == 0)))

tensor([[[0.0931, 0.0139, 0.0314,  ..., 0.0000, 0.0000, 0.0000],
         [0.1245, 0.0000, 0.0000,  ..., 0.0265, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.2297, 0.0000,  ..., 0.0000, 0.0000, 0.5524],
         [0.0000, 0.0022, 0.0000,  ..., 0.0000, 0.0000, 0.2228],
         [0.1003, 0.0000, 0.0274,  ..., 0.0000, 0.0000, 0.0000]]],
       grad_fn=<ReluBackward0>)
tensor([[[-1.1766,  2.3352,  6.6595,  ...,  1.9904, -1.7416,  3.1122],
         [-0.8362,  1.2037, -0.8024,  ..., -1.0820, -0.5340, -0.5148],
         [ 1.0643, -0.9274, -1.4926,  ..., -0.8088, -1.6298, -1.8066],
         ...,
         [ 4.2380, -4.3821, -1.6420,  ..., -3.2197, -4.1690, -4.2523],
         [ 3.2101, -2.4574,  0.7382,  ..., -2.9415,  0.8972, -2.0524],
         [ 0.8878,  2.2958,  4.5667,  ...,  1.1654, -2.0586,  1.2388]]],
       grad_fn=<UnsafeViewBackward0>)


tensor(74572)

In [9]:
output.hidden_states[1]

tensor([[[ 0.0008, -0.0216,  0.0394,  ..., -0.0171, -0.0351,  0.0442],
         [ 0.0371, -0.0065, -0.0127,  ...,  0.0360, -0.0056,  0.0015],
         [ 0.0006, -0.0350,  0.0098,  ..., -0.0288,  0.0186, -0.0051],
         ...,
         [ 0.0021, -0.0060, -0.0308,  ...,  0.0009,  0.0104,  0.0362],
         [ 0.0015,  0.0100,  0.0065,  ..., -0.0196,  0.0180, -0.0051],
         [ 0.0113, -0.0060, -0.0134,  ..., -0.0091,  0.0101, -0.0051]]],
       grad_fn=<AddBackward0>)

In [10]:
output.attentions
# output.pooler_output is not available for 'CausalLMOutputWithPast' object

In [11]:
wrapped_model = model.base_model
wrapped_model.__dict__

{'training': False,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_pre_hooks': OrderedDict(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_hooks_always_called': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('embed_tokens',
               Embedding(32000, 4096, padding_idx=0)),
              ('layers',
               ModuleList(
                 (0-31): 32 x LlamaDecoderLayer(
                   (self_attn): LlamaSdpaAttention(
                     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
                     (k_proj): Linear(in_features=

In [12]:
for idx, layer in enumerate(wrapped_model.layers):
    if idx == 0:
        print(layer.__dict__)

{'training': False, '_parameters': OrderedDict(), '_buffers': OrderedDict(), '_non_persistent_buffers_set': set(), '_backward_pre_hooks': OrderedDict(), '_backward_hooks': OrderedDict(), '_is_full_backward_hook': None, '_forward_hooks': OrderedDict(), '_forward_hooks_with_kwargs': OrderedDict(), '_forward_hooks_always_called': OrderedDict(), '_forward_pre_hooks': OrderedDict(), '_forward_pre_hooks_with_kwargs': OrderedDict(), '_state_dict_hooks': OrderedDict(), '_state_dict_pre_hooks': OrderedDict(), '_load_state_dict_pre_hooks': OrderedDict(), '_load_state_dict_post_hooks': OrderedDict(), '_modules': OrderedDict([('self_attn', LlamaSdpaAttention(
  (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (rotary_emb): LlamaRotaryEmbedding()
)), ('mlp', LlamaMLP(
  (gate_pro