In [11]:
from transformers import AutoTokenizer
import transformers
from transformers import AutoModelForCausalLM, GenerationConfig
import torch

from typing import List, Literal, Optional, Tuple, TypedDict

In [2]:
model_name = "meta-llama/Llama-2-7b-chat-hf"

In [3]:
use_tf_core = True

In [4]:
torch.backends.cudnn.allow_tf32 = use_tf_core
torch.backends.cuda.matmul.allow_tf32 = use_tf_core
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = use_tf_core
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = use_tf_core

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
generation_config = GenerationConfig.from_pretrained(model_name)
generation_config

GenerationConfig {
  "bos_token_id": 1,
  "eos_token_id": 2,
  "max_length": 4096,
  "pad_token_id": 0,
  "temperature": 0.9,
  "top_p": 0.6,
  "transformers_version": "4.31.0"
}

In [6]:
model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit=True, device_map='auto')
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
#model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, device_map='cpu')

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


In [7]:
model = torch.compile(model)

RuntimeError: Python 3.11+ not yet supported for torch.compile

In [12]:
Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    role: Role
    content: str


class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class ChatPrediction(TypedDict, total=False):
    generation: Message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


Dialog = List[Message]

B_INST, E_INST = "[INST]", "[/INST]"
BOS, EOS = '<s>', '</s>'
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

In [13]:
test_dialogs: List[Dialog] = [
    [Message(role='user', content='Briefly explain the difference between pandas and pyspark')],
]

dialogs = [
        [{"role": "user", "content": "what is the recipe of mayonnaise?"}],
        [
            {"role": "user", "content": "I am going to Paris, what should I see?"},
            {
                "role": "assistant",
                "content": """\
Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:

1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.
2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.
3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.

These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.""",
            },
            {"role": "user", "content": "What is so great about #1?"},
        ],
        [
            {"role": "system", "content": "Always answer with Haiku"},
            {"role": "user", "content": "I am going to Paris, what should I see?"},
        ],
        [
            {
                "role": "system",
                "content": "Always answer with emojis",
            },
            {"role": "user", "content": "How to go from Beijing to NY?"},
        ],
    ]

In [24]:
all_dialog = []

prompt_tokens = []
for dialog in dialogs:
    if dialog[0]["role"] != "system":
        dialog = [
            {
                "role": "system",
                "content": DEFAULT_SYSTEM_PROMPT,
            }
        ] + dialog
    dialog = [
        {
            "role": dialog[1]["role"],
            "content": B_SYS
            + dialog[0]["content"]
            + E_SYS
            + dialog[1]["content"],
        }
    ] + dialog[2:]
    assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
        [msg["role"] == "assistant" for msg in dialog[1::2]]
    ), (
        "model only supports 'system', 'user' and 'assistant' roles, "
        "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
    )

    dialog_tokens = [f"{BOS}{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} {EOS}"
            for prompt, answer in zip(
                dialog[::2],
                dialog[1::2],
            )
        ]
    dialog_tokens += [f"{BOS}{B_INST} {(dialog[-1]['content']).strip()} {E_INST}"]

    dialog_str = '\n'.join(dialog_tokens)
    all_dialog.append(dialog_str)

In [25]:
from pprint import pprint
pprint(all_dialog)

['<s>[INST] <<SYS>>\n'
 'You are a helpful, respectful and honest assistant. Always answer as '
 'helpfully as possible, while being safe. Your answers should not include any '
 'harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. '
 'Please ensure that your responses are socially unbiased and positive in '
 'nature.\n'
 '\n'
 'If a question does not make any sense, or is not factually coherent, explain '
 "why instead of answering something not correct. If you don't know the answer "
 "to a question, please don't share false information.\n"
 '<</SYS>>\n'
 '\n'
 'what is the recipe of mayonnaise? [/INST]',
 '<s>[INST] <<SYS>>\n'
 'You are a helpful, respectful and honest assistant. Always answer as '
 'helpfully as possible, while being safe. Your answers should not include any '
 'harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. '
 'Please ensure that your responses are socially unbiased and positive in '
 'nature.\n'
 '\n'
 'If a questio

In [26]:
model.device

device(type='cuda', index=0)

In [27]:
def generate(text: str) -> str:
    tokens = tokenizer(text, add_special_tokens=False, return_tensors='pt')

    # TODO: implement dialog control functionality using proper tokens
    output = model.generate(
        tokens['input_ids'].to(model.device),
        generation_config=generation_config,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=1000
    )
    
    return output, tokenizer.decode(output[0])

In [28]:
output, dialog = generate(all_dialog[1])
print(dialog)

<s> [INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

I am going to Paris, what should I see? [/INST] Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:

1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.
2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums,

In [22]:
all_dialog[1]

"<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\nI am going to Paris, what should I see? [/INST] Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:\n\n1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.\n2. The Louvre Museum: The Louvre is one of the world's largest and most famous

In [39]:
tokenizer

LlamaTokenizerFast(name_or_path='meta-llama/Llama-2-7b-chat-hf', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False)}, clean_up_tokenization_spaces=False)

In [38]:
generation_config

GenerationConfig {
  "bos_token_id": 1,
  "eos_token_id": 2,
  "max_length": 4096,
  "pad_token_id": 0,
  "temperature": 0.9,
  "top_p": 0.6,
  "transformers_version": "4.31.0"
}

In [None]:
tokenizer.eos_token

In [None]:
text = 'I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like?\n'

In [None]:
print(generate(text))

In [None]:
print(generate('I am a physicist on a research hunt'))

In [3]:
import llama2
from llama2 import Llama2ChatModel

In [4]:
model = Llama2ChatModel(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    model_resolution='int4'
)

model.model

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.09s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )


In [1]:
model.preprocess_dialog(test_dialogs[1])

NameError: name 'model' is not defined

In [4]:
model.generate('<s>Q: what is the circumference of a circle? A: ')



OutOfMemoryError: CUDA out of memory. Tried to allocate 26.00 MiB (GPU 0; 11.49 GiB total capacity; 10.88 GiB already allocated; 7.31 MiB free; 11.24 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF