In [1]:
import torch
import torch.nn as nn
from typing import List
import safetensors
import safetensors.torch
from pathlib import Path
import bitsandbytes as bnb
from bitsandbytes.functional import dequantize_4bit, QuantState
from bitsandbytes.nn.modules import Params4bit, Linear4bit
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from transformers.generation.configuration_utils import GenerationConfig
from accelerate import init_empty_weights
from glob import glob
import os
from fastcore.parallel import parallel
import copy
from tqdm import tqdm
import transformers

In [2]:
transformers.__version__

'4.39.3'

In [3]:
from vllm.sequence import SequenceGroupMetadata, SequenceData

In [4]:
from vllm import LLM, SamplingParams

In [5]:
def replace_linear(model:nn.Module, linear_replacement:nn.Module, quant_config:dict|None=None,
                   skip_modules:List[str]=["lm_head"], **kwargs):
    """
    Replace linear modules with a new Linear module.
    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`torch.nn.Module`):
            The linear module that replaces the old one. Only expects standard arguments.
            If other arguments need to be passed, use a lambda.
        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
            List of modules names not to convert. Defaults to `lm_head`.
    """
    for name, module in model.named_children():
        if name in skip_modules:
            print(f"Skipping {name}")
            continue
        
        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, quant_config, skip_modules, **kwargs)

        if isinstance(module, torch.nn.Linear):
            if issubclass(linear_replacement, Linear4bit):
                model._modules[name] = linear_replacement(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    **kwargs
                )
            # elif issubclass(linear_replacement, HQQLinear):
            #     model._modules[name] = linear_replacement(module, quant_config, **kwargs)
            else:
                raise ValueError(f"Unsupported linear replacement: {type(linear_replacement)}")
    return model

In [6]:
def load_and_quantize(module:nn.Module, name:str, value:torch.Tensor, device:torch.device=None, dtype:torch.dtype=None,
                      skip_names:list[str]=[], is_meta_rank:bool=False, low_memory:bool=True, verbose:bool=False,
                      quant_method:str='bnb', is_dora:bool=False):
    """
    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.

    Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
    """
    def place_on_device(value):
        if is_meta_rank:
            device = 'meta'
        elif low_memory:
            device = 'cpu'
        return value.to(device=device, dtype=dtype)

    if any([skip_name in name for skip_name in skip_names]):
        if verbose:
            print(f"Skipping {name} because it is in skip_names")
        return

    module_key, _, value_key = name.rpartition('.')
    try:
        submodule = module.get_submodule(module_key)
    except AttributeError as e:
        print(f"Module {module_key} not found:\n{e}")
        return

    try:
        if quant_method=='bnb':
            param = submodule.get_parameter(value_key)
            if isinstance(param, Params4bit):
                # With `sync_module_states=True`, a meta device Params4bit needs to be the same
                # shape as the quantized Params4bit with an initialized quant_state. However,
                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
                # workaround quantizes Params4bit to initialize quant_state on all ranks, then
                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
                if is_dora:
                    setattr(submodule, "dora_scale", value.norm(p=2, dim=1).to(dtype=dtype).to("cpu"))                
                    print("DORA scale initialized")
                value = type(param)(value.to(device=device, dtype=dtype).data, **param.__dict__).cuda(device)
                if is_meta_rank:
                    value = type(param)(value.data.to("meta"), **value.__dict__)
                elif low_memory:
                    value = type(param)(value.data.to("cpu"), **value.__dict__)
                # print("Loaded quantized layer")
            else:
                value = type(param)(place_on_device(value).data)
                # print("Loaded regular layer")
    except AttributeError:
        # it's a buffer
        value = place_on_device(value)
        pass
    setattr(submodule, value_key, value)

def load_and_quantize_parallel(name_param, model, **kwargs):
    name, param = name_param
    load_and_quantize(model, name, param, **kwargs)

In [7]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token_id = tokenizer.unk_token_id
tokenizer.pad_token = tokenizer.unk_token

#### Load HF Model

In [8]:
orca_math_model_dir = "/home/ubuntu/models/llama-7b-orca-math-100k-full"

In [9]:
MODEL_NAME = "meta-llama/Llama-2-7b-hf"
cfg = AutoConfig.from_pretrained(MODEL_NAME)
cfg._attn_implementation = "flash_attention_2"
cfg._attn_implementation_internal = "flash_attention_2"
skip_modules = ["lm_head"]
load_param_skip_names = ['inv_freq']
compute_dtype = torch_dtype = torch.bfloat16

In [10]:
# !pip install -U transformers

In [11]:
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
# https://github.com/Dao-AILab/flash-attention/issues/742
# !pip install transformers==4.33.1

with init_empty_weights():
    model = AutoModelForCausalLM.from_config(cfg)
    model.model = replace_linear(model.model, Linear4bit, compute_dtype=compute_dtype,
                                 quant_type='nf4', compress_statistics=False,
                                 quant_storage=torch.uint8, skip_modules=skip_modules)
# For some reason slower.
#     for layer in model.model.layers: 
#         m = getattr(layer, 'self_attn')
#         setattr(layer, 'self_attn', LlamaFlashAttention2(m.config, m.layer_idx))
# model.config._attn_implementation = "flash_attention_2"
# model.config._attn_implementation_internal = "flash_attention_2"
model.is_loaded_in_4bit = True

In [12]:
weights = safetensors.torch.load_file(glob(os.path.join(orca_math_model_dir, "*.safetensors"))[0])

In [13]:
parallel(load_and_quantize_parallel, 
         iter(weights.items()), 
         n_workers=8, 
         threadpool=True,
         model=model, 
         dtype=torch_dtype, 
         device=torch.cuda.current_device(),
         skip_names=load_param_skip_names,
         is_meta_rank=False,
         verbose=True,
         quant_method="bnb",
         is_dora=False)

(#291) [None,None,None,None,None,None,None,None,None,None...]

In [14]:
model.cuda();

#### Load VLLM Model

In [8]:
orca_math_model_dir = "/workspace/models/llama-7b-orca-math-100k-full-quantized"
# orca_math_model_dir = "/workspace/models/llama-7b-orca-math-100k-bnb-qdora-vllm"
llm = LLM(model=orca_math_model_dir, tokenizer="meta-llama/Llama-2-7b-hf", dtype="bfloat16", 
          tensor_parallel_size=4, enforce_eager=True, quantization="bnb", gpu_memory_utilization=0.9)



2024-04-09 11:54:06,744	INFO worker.py:1752 -- Started a local Ray instance.


INFO 04-09 11:54:07 llm_engine.py:70] Initializing an LLM engine (v0.3.3) with config: model='/workspace/models/llama-7b-orca-math-100k-full-quantized', tokenizer='meta-llama/Llama-2-7b-hf', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=4, disable_custom_all_reduce=True, quantization=bnb, enforce_eager=True, kv_cache_dtype=auto, device_config=cuda, seed=0)
INFO 04-09 11:54:15 pynccl_utils.py:13] vLLM is using nccl==2.18.1
INFO 04-09 11:54:15 selector.py:44] flash_attn is not found.
INFO 04-09 11:54:15 selector.py:20] Using XFormers backend.
[36m(RayWorkerVllm pid=214327)[0m INFO 04-09 11:54:16 pynccl_utils.py:13] vLLM is using nccl==2.18.1
[36m(RayWorkerVllm pid=214327)[0m INFO 04-09 11:54:17 selector.py:44] flash_attn is not found.
[36m(RayWorkerVllm pid=214327)[0m INFO 04-09 11:54:17 selector.py:20] Using XFormers backend.
INFO 04-09 11:54:20

In [10]:
# llm = LLM(model="robertgshaw2/llama-2-7b-chat-marlin", tensor_parallel_size=2, 
#           enforce_eager=False, quantization="marlin", gpu_memory_utilization=0.9)

#### Benchmark 

In [9]:
from datasets import load_dataset

In [10]:
def timed(fn, *args, **kwargs):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

In [11]:
dataset = load_dataset("microsoft/orca-math-word-problems-200k")['train'].shuffle(seed=42)
# train with 10k for starters. Then 100k.
# dataset = dataset.select(range(0,100000))

# select last 5k as validation
dataset = dataset.select(range(len(dataset)-5000,len(dataset)))

In [12]:
import re

def extract_last_number_or_ratio(s):
    # Find all sequences of digits, possibly with leading currency symbols, decimal points, and ratios
    patterns = re.findall(r'[\$€£]?\d+(?:\.\d+)?(?:\:\d+(?:\.\d+)?)?', s)
    
    # Return the last pattern found, or None if there are no matches
    if patterns:
        return patterns[-1]
    else:
        return None

# Example usage
examples = [
    "The item costs $123.45, but with a discount of $10.00, the final price is $113.45.",
    "The ratio of water to concentrate is 5.5:1 for the mixture.",
    "The investment return was 10:1.",
    "Answer is 42.3.\nAnswer is 42"
]

for s in examples:
    print(f"The last occurring number or ratio in \"{s}\" is: {extract_last_number_or_ratio(s)}")


The last occurring number or ratio in "The item costs $123.45, but with a discount of $10.00, the final price is $113.45." is: $113.45
The last occurring number or ratio in "The ratio of water to concentrate is 5.5:1 for the mixture." is: 5.5:1
The last occurring number or ratio in "The investment return was 10:1." is: 10:1
The last occurring number or ratio in "Answer is 42.3.
Answer is 42" is: 42


In [13]:
short_answers_gt = parallel(extract_last_number_or_ratio, dataset['answer'], progress=True)

In [14]:
inputs = [f"###Question:\n{question}\n###Answer:\n" for question in dataset[:500]['question']]

In [15]:
len(inputs)

500

In [16]:
valid_dataset = dataset

In [23]:
answers_pred = []
short_answers_pred = []
bs = 8
for i in tqdm(range(0,len(valid_dataset.select(range(50))),bs)):
    
    inputs = [f"###Question:\n{question}\n###Answer:\n" for question in valid_dataset[i:i+bs]['question']]
    input_ids = tokenizer(inputs)['input_ids']
    
    max_toks = max(len(toks) for toks in input_ids)
    b = torch.stack([torch.tensor(((max_toks-len(toks))*[tokenizer.unk_token_id])+toks) for toks in input_ids])
    input_lens = [len(toks) for toks in input_ids]
    
    output = model.generate(b.cuda(), 
                            do_sample=False, 
                            use_cache=True,
                            pad_token_id=tokenizer.unk_token_id, 
                            eos_token_id=tokenizer.eos_token_id, 
                            max_new_tokens=1024).cpu()
    
    pred = [tokenizer.decode(o[o!=tokenizer.unk_token_id][n:]) for o,n in zip(output,input_lens)]
    short_pred = [extract_last_number_or_ratio(p) for p in pred]
    
    answers_pred.extend(pred)
    short_answers_pred.extend(short_pred)

100%|██████████████████████████████████████████████████████████████████████| 7/7 [07:05<00:00, 60.72s/it]


In [24]:
sum(p==g for p,g in zip(short_answers_pred, short_answers_gt))/len(short_answers_pred)

0.17857142857142858

In [16]:
outputs = llm.generate(inputs[:50], SamplingParams(temperature=0.0, stop_token_ids=[tokenizer.eos_token_id], max_tokens=1024))

Processed prompts: 100%|██████████| 50/50 [00:39<00:00,  1.26it/s]


In [17]:
short_answers_pred = [extract_last_number_or_ratio(o.outputs[0].text) for o in outputs]

In [18]:
sum(p==g for p,g in zip(short_answers_pred, short_answers_gt))/len(short_answers_pred)

0.18

In [19]:
sum([len(o.outputs[0].token_ids) for o in outputs])

15923

In [19]:
50 /(39/60)

76.92307692307692

In [20]:
token_per_sec = []
for inp in tqdm(inputs[:3]):
    time_taken = timed(llm.generate,[inp], SamplingParams(temperature=0.0, stop_token_ids=[tokenizer.eos_token_id], max_tokens=1024), use_tqdm=False)
    token_per_sec.append(len(time_taken[0][0].outputs[0].token_ids) / time_taken[1])

100%|██████████| 3/3 [00:26<00:00,  8.68s/it]


In [21]:
token_per_sec

[30.62790650031754, 30.732849847282548, 31.08896707737432]

In [20]:
tokenizer("hello world this is")

{'input_ids': [1, 22172, 3186, 445, 338], 'attention_mask': [1, 1, 1, 1, 1]}

In [21]:
time_taken = timed(llm.generate,["hello world this is"], SamplingParams(temperature=0.0, stop_token_ids=[tokenizer.eos_token_id], max_tokens=256), use_tqdm=False)

In [22]:
len(time_taken[0][0].outputs[0].token_ids) / time_taken[1]

29.96343040495391

In [38]:
# dollar / sec - sending 1 request without batching
1e6 * (0.74 / 60 / 27.5) 

448.4848484848485

In [35]:
# dollar / sec - sending requests with continuous batching
1e6 * ((0.74 / 60) / (10604 / 12)) 

13.956997359486987

In [28]:
# dollar / sec - sending 1 request without batching
1e6 * (1.782 / 2 / 60 / 27.2) 

545.9558823529412

In [43]:
# dollar / sec - sending requests with continuous batching
1e6 * ((1.782 / 60) / (15555 / 39)) 

74.46480231436837

In [25]:
# dollar / sec - sending 1 request without batching (if non-eager mode fixed)
1e6 * (0.74 / 60 / 65.2) 

189.16155419222903

In [22]:
# dollar / sec - sending requests with continuous batching (if non-eager mode fixed)
1e6 * ((0.74 / 60) / (12605 / 8)) 

7.827581647494381

In [34]:
def single_infer(question):
    inputs = [f"###Question:\n{question}\n###Answer:\n"]
    input_ids = tokenizer(inputs)['input_ids']
    
    max_toks = max(len(toks) for toks in input_ids)
    b = torch.stack([torch.tensor(((max_toks-len(toks))*[tokenizer.unk_token_id])+toks) for toks in input_ids])
    input_lens = [len(toks) for toks in input_ids]
    
    output = model.generate(b.cuda(), 
                            do_sample=False, 
                            use_cache=True,
                            pad_token_id=tokenizer.unk_token_id, 
                            eos_token_id=tokenizer.eos_token_id, 
                            max_new_tokens=1024).cpu()

    tokens = [o[o!=tokenizer.unk_token_id][n:] for o,n in zip(output,input_lens)][0]
    pred = [tokenizer.decode(tokens)]
    return pred, tokens, len(tokens)

In [41]:
token_per_sec = []
for question in tqdm(dataset[:10]['question']):
    time_taken = timed(single_infer, question)
    token_per_sec.append(time_taken[0][-1] / time_taken[1])

100%|████████████████████████████████████████████████████████████████████| 10/10 [01:47<00:00, 10.72s/it]


In [42]:
token_per_sec

[18.71630187320748,
 19.143428841315547,
 19.14635015085839,
 19.184236823150655,
 19.221320212171463,
 19.072148874737856,
 19.108014302928094,
 19.104715232776528,
 18.919215197450832,
 19.106046480172743]