In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
share = {}
device_count = torch.cuda.device_count()

In [3]:
model_name = "Llama-2-13b-chat-hf"
model_path = f"/gpfs/jsh/pretrained-models/{model_name}"

share["tokenizer"] = AutoTokenizer.from_pretrained(model_path)
share["model"] = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
memories = [torch.cuda.memory_reserved(i) / 1024 ** 3 for i in range(device_count)]
print(f"{sum(memories)} GiB")

24.45703125 GiB


In [5]:
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = share["tokenizer"](prompt, return_tensors="pt").to(f"cuda:{device_count-1}")
inputs

{'input_ids': tensor([[    1, 18637, 29892,   526,   366, 19861, 29973,  1815,   366,  5193,
           304,   592, 29973]], device='cuda:7'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:7')}

In [6]:
outputs = share["model"].generate(
    **inputs,
    do_sample=True,
    top_k=50,
    top_p=0.92,
    min_length=20,
    max_length=200,
    temperature=0.9,
    repetition_penalty=1.5,
    no_repeat_ngram_size=3,
)

memories = [torch.cuda.memory_reserved(i) / 1024 ** 3 for i in range(device_count)]
print(f"{sum(memories)} GiB")

sentence = share["tokenizer"].decode(outputs[0], skip_special_tokens=True)
sentence = sentence.split("\n")[-1].strip()
print(sentence)

24.677734375 GiB
(Note: This is a weird question. I don't think it would be appropriate for most conversations.)


In [9]:
import requests

port = 1047
url = f"http://127.0.0.1:{port}/infer"
rs = requests.post(url, json={"sentence": "Hey, are you conscious? Can you talk to me?"})
rs.json()

{'sentence': '37. The only thing that is infinite in the universe of numbers: its decimals! I am curious if any number theory experts might be able or willing enough to help explain why this factoid holds true (and how it could possibly even work).',
 'elapsed_time': 1.9909100532531738,
 'used_memory': 12.873046875}

In [10]:
rs.json()["sentence"]

'37. The only thing that is infinite in the universe of numbers: its decimals! I am curious if any number theory experts might be able or willing enough to help explain why this factoid holds true (and how it could possibly even work).'