In [2]:
import torch
from modeling_llama import LlamaForCausalLM
from tokenization_llama_fast import LlamaTokenizerFast

def load_model(
    model_name: str,
    cache_dir: str,
    torch_dtype: torch.dtype = torch.float16,
    attn_implementation: str = "flash_attention_2",
):
    """Load a Llama causal language model together with its fast tokenizer.

    Args:
        model_name: Checkpoint identifier handed to ``from_pretrained`` for
            both the tokenizer and the model.
        cache_dir: Directory handed to both ``from_pretrained`` calls as
            ``cache_dir``.
        torch_dtype: dtype requested for the model weights. Defaults to
            ``torch.float16``, the value this helper previously hard-coded.
        attn_implementation: Attention backend name forwarded to the model's
            ``from_pretrained``. Defaults to ``"flash_attention_2"``, the
            value this helper previously hard-coded.

    Returns:
        A ``(model, tokenizer)`` tuple, model first.

    Note:
        No device placement happens here; the caller is expected to move the
        returned model to the target device (e.g. ``model.to(device)``).
    """
    tokenizer = LlamaTokenizerFast.from_pretrained(model_name, cache_dir=cache_dir)
    model = LlamaForCausalLM.from_pretrained(
        model_name,
        cache_dir=cache_dir,
        torch_dtype=torch_dtype,
        attn_implementation=attn_implementation,
    )
    return model, tokenizer

In [3]:
# Everything below runs on the first CUDA device.
device = torch.device("cuda", 0)

# Fetch the chat checkpoint (weights + tokenizer) through the local cache directory.
model, tokenizer = load_model(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    cache_dir="/workspace/hf_home/",
)

# Switch to inference mode, then move the weights to the GPU. Both calls hand
# back the module itself, so the chained expression is still the cell's
# displayed value (the module repr).
model.eval().to(device)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


my LlamaModel
_attn_implementation flash_attention_2


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.25s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head

In [4]:
# Sampling is stochastic (do_sample=True); fix the seed so that rerunning this
# cell yields the same continuation.
torch.manual_seed(0)

# Tokenize the prompt into PyTorch tensors (input_ids + attention_mask).
x = tokenizer("Hello I am the llama, ", return_tensors="pt")
print("x", x)

# The encoded batch starts on CPU; move it next to the model before generating.
x = x.to(device)

# Pass the attention mask explicitly alongside the ids instead of leaving
# generate() to infer which positions are real tokens.
y = model.generate(
    input_ids=x.input_ids,
    attention_mask=x.attention_mask,
    min_new_tokens=1,
    max_new_tokens=128,
    do_sample=True,
    temperature=1.0,
    repetition_penalty=1.2,
)
print("y", y)

x {'input_ids': tensor([[    1, 15043,   306,   626,   278, 11148,  3304, 29892, 29871]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}
y tensor([[    1, 15043,   306,   626,   278, 11148,  3304, 29892, 29871, 29906,
         29900, 29896, 29929,    13, 29930, 15918,  1596,   373,  5650, 19239,
           373,  8112,  9451,   411,  1274,   719,   506, 10675,   322, 18187,
         29889,    13, 10994,   306,  1913,   450,   365, 29880,  3304,   313,
         16432,   511, 29871, 29906, 29900, 29896, 29929,    13, 27103,  1596,
           373,  5650, 19239,   373,  8112,  9451,   411,  1274,   719,   506,
         10675,   322, 18187, 29889,     2]], device='cuda:0')


In [5]:
# Map the generated id tensor back to text, one string per sequence in the
# batch. Special tokens are left in, so <s> and </s> show up in the result.
tokenizer.batch_decode(y)

['<s> Hello I am the llama, 2019\n* Digital print on paper mounted on wood panel with acrylic paint and fabric.\nHello I Am The Llama (detail), 2019\nDigital print on paper mounted on wood panel with acrylic paint and fabric.</s>']