# Efficient Initialization of Large Models

* [Blog](https://lightning.ai/pages/community/efficient-initialization-of-large-models/)

We will use `meta-llama/Llama-2-7b-chat-hf` as the example model; the same steps apply to smaller checkpoints such as `EleutherAI/pythia-1b`.

In [6]:
import lightning as L
import torch
import time
from pathlib import Path

In [1]:
from lit_gpt import GPT, Tokenizer

In [2]:
def check_model_device(model):
    """Report the device a model lives on.

    Only the first item yielded by ``model.parameters()`` is inspected and
    its ``device`` attribute is returned; later parameters are not checked.
    A model that yields no parameters makes ``next`` raise ``StopIteration``.
    """
    first_param = next(model.parameters())
    return first_param.device

In [5]:
# Location of the converted lit-gpt checkpoint and the matching model name.
# NOTE(review): this is an absolute, machine-specific path; adjust it for your
# own environment.
checkpoint_dir = "/data/aniket/meta-llama/Llama-2-7b-chat-hf"
model_name = "Llama-2-7b-chat-hf"
# Built from `checkpoint_dir` instead of being repeated as a second literal,
# so the directory and the weights file cannot drift apart.
checkpoint_path = str(Path(checkpoint_dir) / "lit_model.pth")
# map_location="cpu" keeps the loaded tensors in host memory no matter which
# device they were saved from, so the GPU memory readings further down only
# reflect the model that is being built.
# NOTE(review): torch.load unpickles the file; only load checkpoints that come
# from a trusted source.
checkpoints = torch.load(checkpoint_path, map_location="cpu")

In [4]:
# Baseline: build the model the ordinary way, then copy the checkpoint in.
# perf_counter() is the clock intended for measuring elapsed time; unlike
# time.time() it is not affected by system clock adjustments.
t0 = time.perf_counter()
model = GPT.from_name(model_name)
model.load_state_dict(checkpoints)
print(f"Time to load model: {time.perf_counter() - t0:.02f} seconds.")
print("This is your old nn.Module:", isinstance(model, torch.nn.Module))
# Use the notebook's helper rather than repeating the lookup inline.
print("device", check_model_device(model))

Time to load model: 39.23 seconds.
This is your old nn.Module: True
device cpu


In [5]:
# Display the device of the model's first parameter via the shared helper.
check_model_device(model)

device(type='cpu')

In [7]:
fabric = L.Fabric(accelerator="gpu", precision="bf16-true")

# max_memory_reserved() is the peak amount reserved by the CUDA caching
# allocator so far, not the amount currently in use.
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
t0 = time.perf_counter()
# empty_init=True skips the random weight initialisation: every parameter is
# overwritten by load_state_dict() right after, so initialising first is
# wasted work.
with fabric.init_module(empty_init=True):
    model = GPT.from_name(model_name)
# Needs `checkpoints` from the checkpoint-loading cell. The NameError in the
# saved output of this cell came from running it before that cell.
model.load_state_dict(checkpoints)
print(f"Time to load model: {time.perf_counter() - t0:.02f} seconds.")
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

Memory used: 0.00 GB


NameError: name 'checkpoints' is not defined


A 7B-parameter model in 32-bit precision:

32 bits => 4 bytes per parameter

7 x 10^9 parameters x 4 bytes => 28 x 10^9 bytes

28 GB of memory

---

16 bits => 2 bytes per parameter

7 x 10^9 parameters x 2 bytes => 14 x 10^9 bytes

14 GB of memory

## Let's play with non-Lightning models now

In [7]:
# Log in to the Hugging Face Hub from inside the notebook (renders a widget).
# NOTE(review): presumably needed because the meta-llama repositories loaded
# below require an authenticated account — confirm.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

In [5]:
# Peak reserved CUDA memory before the Hugging Face model is loaded.
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
# The empty-string key addresses the whole model, so everything is placed on
# device index 0.
whole_model_on_gpu0 = {"": 0}
model_hf = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    device_map=whole_model_on_gpu0,
)
# Same reading again, now including whatever the load reserved.
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

Memory used: 0.00 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Memory used: 4.42 GB


In [16]:
# Display the device of the Hugging Face model's first parameter.
check_model_device(model_hf)

device(type='cuda', index=0)

In [6]:
# Fabric handle for GPU execution with the "bf16-true" precision setting.
fabric = L.Fabric(
    accelerator="gpu",
    precision="bf16-true",
)

In [7]:
# Peak reserved CUDA memory before the load.
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
# Build the Hugging Face model inside Fabric's init_module() context so that
# Fabric's device and precision settings apply while it is constructed.
with fabric.init_module():
    model_hf = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf"
    )
# Peak reserved CUDA memory after the load.
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

Memory used: 0.00 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Memory used: 13.57 GB


In [13]:
# Display the module tree of the lit-gpt model.
model

GPT(
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
  (transformer): ModuleDict(
    (wte): Embedding(32000, 4096)
    (h): ModuleList(
      (0-31): 32 x Block(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (attn): Linear(in_features=4096, out_features=12288, bias=False)
          (proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (norm_2): RMSNorm()
        (mlp): LLaMAMLP(
          (fc_1): Linear(in_features=4096, out_features=11008, bias=False)
          (fc_2): Linear(in_features=4096, out_features=11008, bias=False)
          (proj): Linear(in_features=11008, out_features=4096, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
)

In [21]:
# Display the module tree of the Hugging Face model.
# NOTE(review): the saved output lists Linear4bit layers, but neither load
# cell in this notebook passes a quantization argument, so that output looks
# stale. Re-run after a fresh load to confirm.
model_hf

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )


In [15]:
# Load the tokenizer from the same Hub repository as the model weights
# ("Llama-2-7b-chat-hf") rather than the base "Llama-2-7b-hf" repository, so
# model and tokenizer always come from one checkpoint and only one gated
# repository has to be accessible.
tokenizer_hf = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

In [38]:
# Tokenize one short sentence with the Hugging Face tokenizer: truncate at
# 512 tokens and return PyTorch tensors.
encoded = tokenizer_hf(
    "I am Aniket",
    max_length=512,
    truncation=True,
    return_tensors="pt",
)
encoded

{'input_ids': tensor([[  1, 306, 626, 530, 638, 300]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [8]:
# Build the lit-gpt tokenizer from the local checkpoint directory.
tokenizer = Tokenizer(Path(checkpoint_dir))

In [9]:
# Peak CUDA memory reserved by the caching allocator so far, in GB (1e9 bytes).
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

Memory used: 13.48 GB


In [10]:
# One prompt made of the question repeated ten times, duplicated into a batch
# of two identical rows.
prompt = "What is my name?" * 10
# Encode under Fabric's device context — presumably so the resulting tensor
# is created on that device; confirm against the Tokenizer implementation.
with fabric.device:
    encoded = tokenizer.encode([prompt] * 2)
encoded.size()

torch.Size([2, 50])

In [14]:
# Run the lit-gpt model on the encoded batch.
logits = model(encoded, max_seq_length=512)
# NOTE(review): reset_cache() presumably drops caches built during the forward
# pass — confirm against the lit_gpt version in use.
model.reset_cache()
# Peak reserved CUDA memory, now including the forward pass.
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

Memory used: 14.92 GB


In [15]:
# Inspect the shape of the model output.
logits.shape

torch.Size([2, 50, 32000])

In [13]:
# Drop the last entry along the second-to-last dimension.
# NOTE(review): this rebinds `logits`, so the cell is not idempotent —
# re-running it removes one more position each time. Re-run the forward-pass
# cell first if you need the original tensor.
logits = logits[..., :-1, :]
logits.shape

torch.Size([2, 49, 32000])

In [82]:
# Collapse every leading dimension into one, keeping the last dimension as
# the row width.
last_dim = logits.shape[-1]
logits = logits.reshape(-1, last_dim)
logits.shape

torch.Size([98, 32000])

In [51]:
# Scratch note (presumably the next-token shift for a loss — confirm): logits[..., :-1, :], targets[..., 1:]