In [1]:
import os

import torch
from transformers import LlamaForCausalLM, LlamaConfig, LlamaModel

In [2]:
# Seed before building the model so the random initial weights -- and hence the
# loss printed in the forward-pass cell -- are reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

# Output directory for the saved config. Overridable through the LLAMA_CONFIG_DIR
# environment variable so the notebook is not tied to one user's home directory;
# the default is the original location.
LLAMA_CONFIG_DIR = os.environ.get(
    "LLAMA_CONFIG_DIR", "/home/t-zeqianju/yuancwang/AmphionOpen/data/llama_config"
)

# Small LLaMA over a 256-entry vocabulary. Ids 250/251/252 (inside that vocab)
# are used as the pad/BOS/EOS tokens by the batching helpers.
llama_config = LlamaConfig(vocab_size=256,
                           hidden_size=1024,
                           intermediate_size=4096,
                           num_hidden_layers=12,
                           num_attention_heads=16,
                           pad_token_id=250,
                           bos_token_id=251,
                           eos_token_id=252,)
# save config
llama_config.save_pretrained(LLAMA_CONFIG_DIR)

llama_model = LlamaForCausalLM(llama_config)
print(llama_model)
# print number of parameters
print(f"Number of parameters: {llama_model.num_parameters()/1e6:.2f}M")

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(256, 1024, padding_idx=250)
    (layers): ModuleList(
      (0-11): 12 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=1024, out_features=4096, bias=False)
          (up_proj): Linear(in_features=1024, out_features=4096, bias=False)
          (down_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (l

In [17]:
def add_padding(input_ids, pad_token_id, max_length=None, padding_side="right"):
    """Pad every sequence in a batch to a common length.

    Args:
        input_ids: List[List[int]] -- batch of token-id sequences (not modified).
        pad_token_id: int -- id inserted into sequences shorter than the target.
        max_length: int or None -- target length; defaults to the length of the
            longest sequence in the batch. Sequences already at least this long
            are returned unchanged (they are NOT truncated).
        padding_side: "right" (default, append padding) or "left" (prepend it).

    Returns:
        List[List[int]] -- new, padded lists. An empty batch yields [].

    Raises:
        ValueError: if ``padding_side`` is neither "right" nor "left".
    """
    if padding_side not in ("right", "left"):
        raise ValueError(f"padding_side must be 'right' or 'left', got {padding_side!r}")
    if max_length is None:
        # default=0 so an empty batch returns [] instead of max() raising ValueError.
        max_length = max((len(ids) for ids in input_ids), default=0)
    padded_input_ids = []
    for ids in input_ids:
        # Clamp at 0: over-long sequences receive no padding and are left as-is.
        padding = [pad_token_id] * max(max_length - len(ids), 0)
        ids = list(ids)  # copy; also accepts tuples or other sequences
        padded_input_ids.append(ids + padding if padding_side == "right" else padding + ids)
    return padded_input_ids

def add_bos(input_ids, bos_token_id):
    """Return a new batch with ``bos_token_id`` placed at the start of every sequence.

    Args:
        input_ids: List[List[int]] -- batch of token-id sequences (left unmodified).
        bos_token_id: int -- beginning-of-sequence id to prepend.

    Returns:
        List[List[int]] -- fresh lists, each one element longer than its input.
    """
    prefixed = []
    for seq in input_ids:
        prefixed.append([bos_token_id] + seq)
    return prefixed

def add_eos(input_ids, eos_token_id):
    """Return a new batch with ``eos_token_id`` placed at the end of every sequence.

    Args:
        input_ids: List[List[int]] -- batch of token-id sequences (left unmodified).
        eos_token_id: int -- end-of-sequence id to append.

    Returns:
        List[List[int]] -- fresh lists, each one element longer than its input.
    """
    def _with_eos(seq):
        return seq + [eos_token_id]

    return list(map(_with_eos, input_ids))

def add_mask(input_ids, pad_token_id):
    """Build an attention mask: 1 for real tokens, 0 wherever the id equals ``pad_token_id``.

    Args:
        input_ids: List[List[int]] -- (already padded) batch of token ids.
        pad_token_id: int -- id treated as padding.

    Returns:
        List[List[int]] -- same nesting/lengths as ``input_ids``, values in {0, 1}.

    Note: the mask is derived from token values, so any non-padding token that
    happens to share ``pad_token_id`` would also be masked out. This is safe only
    while the pad id is reserved for padding.
    """
    mask = []
    for seq in input_ids:
        mask.append([0 if tok == pad_token_id else 1 for tok in seq])
    return mask

def add_labels(input_ids, pad_token_id, ignore_index=-100):
    """Build language-model labels: a copy of the ids with padding replaced by ``ignore_index``.

    Args:
        input_ids: List[List[int]] -- (already padded) batch of token ids.
        pad_token_id: int -- id treated as padding.
        ignore_index: int -- value written at padding positions. The default -100
            matches the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``,
            which is the value Hugging Face causal-LM heads skip in the loss.

    Returns:
        List[List[int]] -- same nesting/lengths as ``input_ids``; the sequences are
        not shifted here (any next-token shift is left to the model).

    Note: labels are derived from token values, so a non-padding token equal to
    ``pad_token_id`` would also be ignored. If pad and EOS ids ever coincide, EOS
    would silently drop out of the loss.
    """
    return [
        [ignore_index if token_id == pad_token_id else token_id for token_id in ids]
        for ids in input_ids
    ]

raw_token_ids = [[1, 4, 5, 6, 3], [1, 2, 3], [4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14]]

# Frame each sequence as [BOS] ... [EOS], then pad the batch to its longest row
# (add_padding pads on the right by default).
framed_ids = add_eos(add_bos(raw_token_ids, llama_config.bos_token_id), llama_config.eos_token_id)
padded_ids = add_padding(framed_ids, llama_config.pad_token_id)
mask_rows = add_mask(padded_ids, llama_config.pad_token_id)
label_rows = add_labels(padded_ids, llama_config.pad_token_id)

print(padded_ids)
print(mask_rows)
print(label_rows)

# Python-list stages above use their own names; only the final tensors take the
# names the forward-pass cell reads, so one name never means both a list and a tensor.
input_token_ids = torch.tensor(padded_ids)
attention_mask = torch.tensor(mask_rows)
labels = torch.tensor(label_rows)

# ids, mask and labels must line up position-for-position for the forward pass.
assert input_token_ids.shape == attention_mask.shape == labels.shape, "batch tensors misaligned"

print(input_token_ids.shape)
print(attention_mask.shape)
print(labels.shape)

[[251, 1, 4, 5, 6, 3, 252], [251, 1, 2, 3, 252, 250, 250], [251, 4, 5, 252, 250, 250, 250], [251, 6, 7, 8, 9, 10, 252], [251, 11, 12, 13, 14, 252, 250]]
[[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0]]
[[251, 1, 4, 5, 6, 3, 252], [251, 1, 2, 3, 252, -100, -100], [251, 4, 5, 252, -100, -100, -100], [251, 6, 7, 8, 9, 10, 252], [251, 11, 12, 13, 14, 252, -100]]
torch.Size([5, 7])
torch.Size([5, 7])
torch.Size([5, 7])


In [18]:
# Use the GPU when present, otherwise fall back to CPU so the notebook still runs
# on machines without CUDA (a hard-coded "cuda" raises there).
device = "cuda" if torch.cuda.is_available() else "cpu"
llama_model.to(device)
input_token_ids = input_token_ids.to(device)
attention_mask = attention_mask.to(device)
labels = labels.to(device)

# Padding positions have attention_mask == 0 and label -100; per the Hugging Face
# causal-LM docs, -100 labels are ignored when computing the loss.
out = llama_model(input_token_ids, attention_mask=attention_mask, labels=labels, return_dict=True)
print(out.loss.item())
print(out.logits.shape)
print(out.logits.shape)

5.326286792755127
torch.Size([5, 7, 256])
