In [20]:
import sys
sys.path.append('./llama_architecture')

import torch
from model_trnsfmrs import LlamaForCausalLM
from config import LlamaConfig
from datasets import load_dataset
from safetensors import torch as sftorch
from huggingface_hub import HfApi, HfFolder, upload_file

In [21]:
def count_parameters(model, trainable_only=False):
    """Count the number of parameters in a model.

    Args:
        model: a ``torch.nn.Module``.
        trainable_only: if True, count only parameters whose ``requires_grad``
            flag is set (e.g. to exclude frozen layers). Defaults to False,
            which counts every parameter.

    Returns:
        Total number of scalar elements across the selected parameters.
        ``Module.parameters()`` yields each shared/tied tensor once, so tied
        weights are not double-counted.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

In [23]:
# Build a ~51.5M-parameter Llama-style decoder with random initial weights.
# Seed torch's global RNG first: the random weights are what gets saved and
# uploaded below, so they should be reproducible after a kernel restart.
SEED = 42
torch.manual_seed(SEED)

EMB_DIM = 256
# The previous config used n_heads=256 with emb_dim=256, i.e. a per-head
# dimension of 1 (q_proj has 256 outputs = n_heads * head_dim), so every
# attention score was a product of two scalars. 8 query heads / 2 KV groups
# keep the same 4:1 grouping ratio and should keep identical projection shapes
# (q_proj: 8 x 32 = 256, k_proj/v_proj: 2 x 32 = 64) and parameter count, with
# a usable head_dim of 32 -- re-check against the printed model below.
N_HEADS = 8
N_KV_GROUPS = 2
assert EMB_DIM % N_HEADS == 0, "emb_dim must split evenly across heads"
assert N_HEADS % N_KV_GROUPS == 0, "query heads must divide evenly into KV groups"
# NOTE(review): assumes model_trnsfmrs applies RoPE per head, which rotates
# dimension pairs and so needs an even head_dim -- confirm.
assert (EMB_DIM // N_HEADS) % 2 == 0, "head_dim must be even"

llama_config = LlamaConfig(
    vocab_size=32768,
    emb_dim=EMB_DIM,
    context_length=512,
    n_heads=N_HEADS,
    n_layers=20,
    n_kv_groups=N_KV_GROUPS,
    hidden_dim=2048,  # MLP intermediate width (gate_proj/up_proj: 256 -> 2048)
)

llama_model = LlamaForCausalLM(llama_config)
print(count_parameters(llama_model))
llama_model

51521792


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32768, 256)
    (layers): ModuleList(
      (0-19): 20 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=256, out_features=256, bias=False)
          (k_proj): Linear(in_features=256, out_features=64, bias=False)
          (v_proj): Linear(in_features=256, out_features=64, bias=False)
          (o_proj): Linear(in_features=256, out_features=256, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=256, out_features=2048, bias=False)
          (up_proj): Linear(in_features=256, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=256, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=256, out_features=32768, bias=False)
)

In [None]:
# Serialise every tensor of the freshly initialised model to a safetensors file.
model_state = llama_model.state_dict()
sftorch.save_file(model_state, "llama-50M.safetensors")

In [None]:
# Publish the randomly initialised checkpoint to the Hugging Face Hub.
# HfApi() is given no explicit token, so it relies on the locally stored
# login (or HF_TOKEN) -- never hardcode a token in the notebook.
api = HfApi()

repo_id = "aliarda/llama-50M-randParams"
checkpoint_file = "llama-50M.safetensors"  # must match the file written by the save cell

# exist_ok=True makes the cell safe to re-run: without it, create_repo raises
# an HTTP 409 conflict once the repository already exists.
api.create_repo(repo_id, private=False, exist_ok=True)

# Use the same client instance for the upload rather than the module-level helper.
api.upload_file(
    path_or_fileobj=checkpoint_file,
    path_in_repo=checkpoint_file,
    repo_id=repo_id,
)