In [1]:
import os

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

In [2]:
%%time

model_path = "/Users/shawon/Codes/llama-hf-converted/llama-2-7b"
model = AutoModel.from_pretrained(model_path)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

CPU times: user 1min 29s, sys: 1min 32s, total: 3min 2s
Wall time: 1min 3s


In [3]:
# Echo the loaded model so its module tree is shown as the cell output.
model

LlamaModel(
  (embed_tokens): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-31): 32 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLUActivation()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)

In [4]:
# Load the tokenizer from the same directory the model was loaded from
# (AutoTokenizer is imported in the notebook's top import cell).
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [5]:
# Sample passage that is tokenized and fed through the model below.
# Because of the triple-quote layout the string begins and ends with a newline.
text = """
Hurricane Otis was a compact yet devastating tropical cyclone which made landfall in October 2023 near Acapulco as a Category 5 hurricane. 
Otis was the first Pacific hurricane to make landfall at Category 5 intensity and surpassed Hurricane Patricia as the strongest landfalling Pacific hurricane on record. 
The fifteenth tropical storm, tenth hurricane, eighth major hurricane,[nb 1] and second Category 5 hurricane of the 2023 Pacific hurricane season, Otis originated from a disturbance several hundred miles south of the Gulf of Tehuantepec. 
Initially forecast to only be a weak tropical storm at peak intensity, Otis instead underwent explosive intensification to reach peak winds of 165 mph (270 km/h) and made landfall at peak intensity.
Once inland, the hurricane rapidly weakened, before dissipating the following day.
"""

In [8]:
# Tokenize the passage into PyTorch tensors. Calling the tokenizer directly is
# the supported entry point; `encode_plus` is the deprecated spelling of the
# same operation for a single text.
encoded = tokenizer(text, return_tensors="pt")

In [9]:
# Echo the tokenizer output to inspect the tensors it produced.
encoded

{'input_ids': tensor([[    1, 29871,    13, 29950,  1038, 26655,  8540,   275,   471,   263,
         11071,  3447,  2906,   579,  1218, 21881,  5094, 16513,   607,  1754,
          2982, 11950,   297,  5533, 29871, 29906, 29900, 29906, 29941,  2978,
           319,  5030,   352,  1111,   408,   263, 17943, 29871, 29945,   298,
          1038, 26655, 29889, 29871,    13, 29949, 28898,   471,   278,   937,
         14328,   298,  1038, 26655,   304,  1207,  2982, 11950,   472, 17943,
         29871, 29945, 26171,   322,  1190,  3364,   287,   379,  1038, 26655,
          4121,  2200,   423,   408,   278,  4549,   342,  2982, 11950,   292,
         14328,   298,  1038, 26655,   373,  2407, 29889, 29871,    13,  1576,
          8461, 19839, 21881, 14280, 29892,   260,  9097,   298,  1038, 26655,
         29892,   321, 18919,  4655,   298,  1038, 26655, 17094,  9877, 29871,
         29896, 29962,   322,  1473, 17943, 29871, 29945,   298,  1038, 26655,
           310,   278, 29871, 29906, 2

In [19]:
# Dimensions of the token-id tensor produced by the tokenizer.
encoded["input_ids"].shape

torch.Size([1, 232])

In [10]:
# Put the model in evaluation mode before running it. The call is the cell's
# last expression, so the returned module is echoed as the cell output.
model.eval()

LlamaModel(
  (embed_tokens): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-31): 32 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLUActivation()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)

In [11]:
# Run the model on the tokenized passage with gradient tracking disabled;
# every entry of `encoded` is passed as a keyword argument.
with torch.no_grad():
    out = model(**encoded)

In [13]:
# List the fields available on the model output.
out.keys()

odict_keys(['last_hidden_state', 'past_key_values'])

In [14]:
# Dimensions of the final hidden-state tensor returned by the model.
out.last_hidden_state.shape

torch.Size([1, 232, 4096])

In [18]:
# Number of entries in the returned key/value cache.
# NOTE(review): presumably one entry per decoder layer (the model repr lists 32 layers) - confirm.
len(out.past_key_values)

32

In [20]:
# Sizes of the pooling head. HIDDEN_SIZE matches the 4096-wide hidden states
# the model produces; POOLED_SIZE is the width of the final embedding.
HIDDEN_SIZE = 4096
POOLED_SIZE = 384
POOLER_SEED = 0

# The head is randomly initialised and never trained in this notebook, so the
# RNG is seeded right before construction: every run (and every Restart &
# Run All) then yields the same weights and therefore the same embedding.
torch.manual_seed(POOLER_SEED)

# Linear projection followed by tanh, applied to one pooled vector per sequence.
pooler = nn.Sequential(
    nn.Linear(HIDDEN_SIZE, POOLED_SIZE),
    nn.Tanh(),
)

In [21]:
# Keep a short handle on the final hidden states for the pooling step.
last_hidden = out.last_hidden_state

In [29]:
# Mean-pool the token states into one vector per sequence, then project it.
# The average is weighted by the attention mask so that padding positions are
# left out; for a single unpadded sequence this equals a plain mean over dim 1.
# NOTE(review): assumes the tokenizer output carries an "attention_mask" entry
# with one value per token - confirm for the tokenizer in use.
with torch.no_grad():
    token_mask = encoded["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
    # Guard against division by zero for a fully masked row.
    token_count = token_mask.sum(dim=1).clamp(min=1.0)
    mean_hidden = (last_hidden * token_mask).sum(dim=1) / token_count
    pooled = pooler(mean_hidden)

In [30]:
# Dimensions of the pooled embedding.
pooled.shape

torch.Size([1, 384])