In [1]:
import lightning as L
import torch
import time
from pathlib import Path

In [2]:
from lit_gpt import GPT, Tokenizer

In [3]:
def check_model_device(model):
    return next(model.parameters()).device

In [4]:
checkpoint_path = "/data/aniket/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"
checkpoint_dir = "/data/aniket/meta-llama/Llama-2-7b-chat-hf"
model_name = "Llama-2-7b-chat-hf"
checkpoints = torch.load(checkpoint_path)

In [5]:
fabric = L.Fabric(accelerator="gpu", precision="bf16-true")

print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
t0 = time.time()
with fabric.init_module():
    model = GPT.from_name(model_name)
    model.load_state_dict(checkpoints)
print(f"Time to load model: {time.time() - t0:.02f} seconds.")
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

Memory used: 0.00 GB
Time to load model: 14.08 seconds.
Memory used: 13.48 GB


In [6]:
model.eval()

GPT(
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
  (transformer): ModuleDict(
    (wte): Embedding(32000, 4096)
    (h): ModuleList(
      (0-31): 32 x Block(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (attn): Linear(in_features=4096, out_features=12288, bias=False)
          (proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (norm_2): RMSNorm()
        (mlp): LLaMAMLP(
          (fc_1): Linear(in_features=4096, out_features=11008, bias=False)
          (fc_2): Linear(in_features=4096, out_features=11008, bias=False)
          (proj): Linear(in_features=11008, out_features=4096, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
)

In [7]:
tokenizer = Tokenizer(Path(checkpoint_dir))

In [8]:
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

Memory used: 13.48 GB


In [9]:
prompt = """<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

There's a llama in my garden 😱 What should I do? [/INST]
"""

In [10]:
with fabric.device:
    encoded = tokenizer.encode([prompt])

prompt_length = encoded.size(0)
encoded.size()

torch.Size([1, 159])

In [11]:
logits = model(encoded, max_seq_length=512)
model.reset_cache()
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

Memory used: 14.49 GB


In [19]:
decoded = tokenizer.decode(logits[0].softmax(-1).argmax(-1))
print(decoded)

nobody:00 What1MBOP What
' a  A friendlyful, knowledge assistant. How ready questions iffully and possible, and also mind and I I goal will be be any harmful or unethical, dangerousist, toist, toxic, dangerous or or illegal content. I ref that your responses are freeally unbiased and positive in nature, I</I you user is not make sense sense or please is n clearual correcterent, please why in of providing it that relevant.
 a are't know the answer to a question, say say't make an information or Instead
</sysS></
Howfores a lotama named the kitchen. Can��� What should I do?</�ss Thank


In [13]:
logits = logits[..., :-1, :]
logits.shape

torch.Size([2, 49, 32000])

In [82]:
logits = logits.reshape(-1, logits.size(-1))
logits.shape

torch.Size([98, 32000])

In [51]:
# logits[..., :-1, :], targets[..., 1:]