In [21]:
from mini_llms.llama2 import Llama2P, ModelArgs
from transformers import LlamaConfig, LlamaModel
from dataclasses import dataclass
import torch

In [22]:
config = LlamaConfig()
config.vocab_size = 128
config.max_position_embeddings = 2048
config.num_attention_heads = 16
config.num_key_value_heads = 16 # For MHA   
config.num_hidden_layers = 24
config.hidden_size = 128
config.intermediate_size = 128 * 4
config.attention_dropout = 0.0

llama2_hf = LlamaModel(config)

In [23]:
@dataclass
class TransformerConfig:
    vocab_size: int = 128
    n_positions: int = 2048
    n_head: int = 16 
    n_layer: int = 24
    n_embd: int = 128
    attn_pdrop: int = 0.0
    embd_pdrop: int = 0.0
    resid_pdrop: float = 0.0
    flash_attention: bool = False

transformer_config = TransformerConfig()

llama2_p = Llama2P(transformer_config)

In [30]:
# Copy the weights from the llama2_hf model to the llama2_p model
sd_hf = llama2_hf.state_dict()
sd_keys_hf = sd_hf.keys()

sd = llama2_p.state_dict()
sd_keys = sd.keys()
sd_keys = [k for k in sd_keys if not k.endswith('mask')]

for k_hf, k in zip(sd_keys_hf, sd_keys):
    assert sd[k].shape == sd_hf[k_hf].shape
    with torch.no_grad():
        sd[k].copy_(sd_hf[k_hf])

In [31]:
batch = torch.ones(2, 72) * 2
batch[1, -10:] *= 0
batch = batch.long()

In [32]:
op = llama2_p(batch)

In [33]:
o = llama2_hf(batch,
              attention_mask=(batch != 0).long(),
).last_hidden_state

In [34]:
o

tensor([[[ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         [ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         [ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         ...,
         [ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         [ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         [ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001]],

        [[ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         [ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         [ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         ...,
         [ 0.2310, -1.7526, -1.1323,  ...,  0.7353,  0.7956, -1.0362],
         [ 0.2310, -1.7526, -1.1323,  ...,  0.7353,  0.7956, -1.0362],
         [ 0.2310, -1.7526, -1.1323,  ...,  0.7353,  0.7956, -1.0362]]],
       grad_fn=<MulBackward0>)

In [35]:
op

tensor([[[ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         [ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         [ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         ...,
         [ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         [ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         [ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001]],

        [[ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         [ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         [ 0.2050, -1.7073, -1.2383,  ...,  0.8298,  0.7385, -0.9001],
         ...,
         [ 0.2310, -1.7526, -1.1323,  ...,  0.7353,  0.7956, -1.0362],
         [ 0.2310, -1.7526, -1.1323,  ...,  0.7353,  0.7956, -1.0362],
         [ 0.2310, -1.7526, -1.1323,  ...,  0.7353,  0.7956, -1.0362]]],
       grad_fn=<MulBackward0>)