In [1]:
import transformers
import torch

In [38]:
# Checkpoint id of the large ("big") model; the tokenizer cell reuses MODEL_1.
MODEL_1 = "meta-llama/Llama-2-70b-chat-hf"

# The loader is asked for 8-bit weights and for automatic device placement.
# NOTE(review): the greedy output recorded for this model further down is not
# coherent text -- TODO confirm this loading configuration is sound.
big = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_1,
    device_map="auto",
    load_in_8bit=True,
)

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

In [30]:
# Small model, later passed to generate() as `assistant_model`.
# Loaded in bfloat16 with every module mapped to device index 1.
small = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    device_map={"": 1},
    torch_dtype=torch.bfloat16,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/569 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/2.09G [00:00<?, ?B/s]

In [31]:
# One tokenizer, loaded from the big model's checkpoint, is used for the
# prompt and for decoding every model's output.
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_1)

# Reuse EOS as the padding token, by id and by string. The original second
# assignment was `tokenizer.pad_token = tokenizer.pad_token`, a self-assignment
# that did nothing; the intended right-hand side is the EOS token.
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

Downloading (…)okenizer_config.json:   0%|          | 0.00/394 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

In [39]:
import time
import tqdm


# Prompt shared by every benchmark run.
text = "Who is the 45'th president of the United States of America? It is"

# Keyword arguments forwarded to `model.generate` by the benchmark helper.
# The tokenized prompt is moved to device 0 and unpacked into the same dict,
# so its entries (e.g. `input_ids`, which the helper reads) sit next to the
# decoding options.
gen_kwargs = {
    "max_new_tokens": 50,
    "pad_token_id": tokenizer.pad_token_id,
    "do_sample": False,  # greedy decoding
    **tokenizer(text, return_tensors="pt").to(0),
}

N = 10  # default number of timed generate() calls per benchmark


def _sync_cuda():
    """Wait until queued work on every visible CUDA device has finished."""
    # device_count() is 0 when CUDA is unavailable, making this a no-op.
    for device_index in range(torch.cuda.device_count()):
        torch.cuda.synchronize(device_index)


def test(message, model, kwargs, n_runs=None):
    """Time repeated `model.generate(**kwargs)` calls and print the results.

    Prints the mean wall-clock time per call, followed by the decoded
    continuation (prompt tokens stripped, newlines flattened) of every run.

    Args:
        message: Label used for the progress bar and the summary line.
        model: Object exposing `generate(**kwargs)`.
        kwargs: Keyword arguments for `generate`; must contain `input_ids`,
            whose last dimension is used as the prompt length to strip.
        n_runs: Number of timed calls. Defaults to the module-level `N`,
            looked up at call time.

    Raises:
        ValueError: If the resolved number of runs is smaller than 1.
    """
    if n_runs is None:
        n_runs = N
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    prompt_length = kwargs["input_ids"].shape[-1]
    times = []
    outputs = []
    for _ in tqdm.tqdm(range(n_runs), desc=message):
        # CUDA kernels are launched asynchronously; synchronise on both sides
        # of the call so the interval covers exactly this generation.
        _sync_cuda()
        start = time.perf_counter()
        output = model.generate(**kwargs)
        _sync_cuda()
        times.append(time.perf_counter() - start)
        # Keep only the newly generated tokens.
        outputs.append(output[:, prompt_length:])

    print(message, sum(times) / len(times))
    # Decode run by run: concatenating all runs into one tensor would fail as
    # soon as two runs produced a different number of new tokens.
    for generated in outputs:
        for line in tokenizer.batch_decode(generated):
            print(" - ", line.replace("\n", " ").strip())


# Baseline 1: the small model on its own.
test(message="small", model=small, kwargs=gen_kwargs)

# Baseline 2: the big model without an assistant.
test(message="big no assistant", model=big, kwargs=gen_kwargs)

# Big model with the small model passed to generate() as `assistant_model`.
# NOTE(review): in the recorded run this call aborted with a CUDA device-side
# assert (`srcIndex < srcSelectDimSize`), and the preceding big-model run
# decoded to incoherent text. The cause is not visible here -- TODO investigate.
test(
    message="big with assistant",
    model=big,
    kwargs={**gen_kwargs, "assistant_model": small},
)



small:   0%|          | 0/10 [00:00<?, ?it/s]

small: 100%|██████████| 10/10 [00:09<00:00,  1.05it/s]


small 0.9516997948288918
 -  the president of the United States of America.  The 45'th president of the United States of America is the president of the United States of America.  The 45'th president of the United States of America is the president of the
 -  the president of the United States of America.  The 45'th president of the United States of America is the president of the United States of America.  The 45'th president of the United States of America is the president of the
 -  the president of the United States of America.  The 45'th president of the United States of America is the president of the United States of America.  The 45'th president of the United States of America is the president of the
 -  the president of the United States of America.  The 45'th president of the United States of America is the president of the United States of America.  The 45'th president of the United States of America is the president of the
 -  the president of the United States of America. 

big no assistant: 100%|██████████| 10/10 [03:30<00:00, 21.04s/it]


big no assistant 21.03499587532133
 -  mpod plum предantoyal histories m abovera Bothead Sl February Sh messagingack reportig significance m simple \ DIS historiesrect histories Sinfo sudoormWhyig m ext scientific is τ submittedill IFN S shaking histories S \ m appeal with S
 -  mpod plum предantoyal histories m abovera Bothead Sl February Sh messagingack reportig significance m simple \ DIS historiesrect histories Sinfo sudoormWhyig m ext scientific is τ submittedill IFN S shaking histories S \ m appeal with S
 -  mpod plum предantoyal histories m abovera Bothead Sl February Sh messagingack reportig significance m simple \ DIS historiesrect histories Sinfo sudoormWhyig m ext scientific is τ submittedill IFN S shaking histories S \ m appeal with S
 -  mpod plum предantoyal histories m abovera Bothead Sl February Sh messagingack reportig significance m simple \ DIS historiesrect histories Sinfo sudoormWhyig m ext scientific is τ submittedill IFN S shaking histories S \ m appeal with S
 

big with assistant:   0%|          | 0/10 [00:00<?, ?it/s]../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [211,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [211,0,0], thread: [65,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [211,0,0], thread: [66,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [211,0,0], thread: [67,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [211,0,0], thread: [68,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [211,0,0], thread: [69,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
big with assistant:   0%

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
