# Efficient LLM serving
By Ankush Chander

## Transformers: quick recap
<img src="img/llm_serving/decoder_only_model_flow.png" width="650" height="650">

[Image credits: Decoder only Large Language Models (LLM) for text generation - a primer](https://www.linkedin.com/pulse/decoder-only-large-language-models-llm-text-generation-nikhil-goel/)

## LLM hardware requirements
1. **Model size vs GPU memory:**
  LLMs are typically trained with full- or half-precision floating point numbers (float32 or float16). One float16 value takes 16 bits, or 2 bytes, so one billion parameters require about 2 gigabytes. Add roughly 25% for memory overhead as a minimum, or 50% for the ideal scenario (large-batch decoding).
    
| Model size (x)<br> in billions | Minimum GPU memory<br> 2\*x\*(1 + 0.25) | Recommended GPU memory<br> 2\*x\*(1 + 0.5) |
| ------------------ | ----------------------- | --------------------------- |
| 1B  | 2.5GB          | 3GB           |
| 7B  | 17.5GB          | 21GB                |
| 13B  | 32.5GB          | 39GB                |
| 70B  | 175GB          | 210GB                |
   
    

2. **CPU vs GPU memory:**
   A good rule of thumb is to have twice as much CPU RAM as GPU VRAM.
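
The sizing rule above can be written as a small helper. This is a minimal sketch: the function name `estimate_gpu_memory_gb` and its parameters are made up for illustration, and it only encodes the 2-bytes-per-parameter and 25%/50% overhead rules of thumb from this section.

```python
def estimate_gpu_memory_gb(params_billions, bytes_per_param=2, overhead=0.25):
    """Rule-of-thumb GPU memory estimate (hypothetical helper, not a library API).

    params_billions: model size in billions of parameters
    bytes_per_param: 2 for float16, 4 for float32
    overhead: 0.25 for the minimum, 0.5 for the recommended figure
    """
    weights_gb = params_billions * bytes_per_param  # 1B params * 2 bytes = 2 GB
    return weights_gb * (1 + overhead)


for size in (1, 7, 13, 70):
    minimum = estimate_gpu_memory_gb(size, overhead=0.25)
    recommended = estimate_gpu_memory_gb(size, overhead=0.5)
    print(f"{size}B: minimum {minimum} GB, recommended {recommended} GB")
```

Passing `bytes_per_param=4` gives the corresponding float32 estimate.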


## Key metrics
1. **Time To First Token (TTFT)**: *How quickly users start seeing the model's output after entering their query.* Low waiting times for a response are essential in real-time interactions, but less important in offline workloads. This metric is driven by the time required to process the prompt and then generate the first output token.
2. **Time Per Output Token (TPOT)**: *Time to generate an output token for each user that is querying our system.* This metric corresponds to how each user perceives the "speed" of the model. For example, a TPOT of 100 milliseconds/token is 10 tokens per second per user, or ~450 words per minute, which is faster than a typical person can read.
3. **Latency**: The overall time it takes for the model to generate the full response for a user. Overall response latency can be calculated using the previous two metrics: latency = *(TTFT)* + *(TPOT)* \* (the number of tokens to be generated).
4. **Throughput**: The number of output tokens per second an inference server can generate across all users and requests.
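
As a worked example of how these metrics combine, here is a minimal sketch. The TTFT of 0.5 s and the 200-token response are made-up numbers, and the 0.75 words-per-token ratio is an assumption chosen to match the ~450 words per minute figure above.

```python
def response_latency_s(ttft_s, tpot_s, num_output_tokens):
    # latency = TTFT + TPOT * number of tokens to be generated
    return ttft_s + tpot_s * num_output_tokens


def words_per_minute(tpot_s, words_per_token=0.75):
    # words_per_token is an assumed average, not a measured value
    tokens_per_second = 1 / tpot_s
    return tokens_per_second * 60 * words_per_token


# hypothetical request: 0.5 s to first token, 100 ms per token, 200 output tokens
latency = response_latency_s(ttft_s=0.5, tpot_s=0.1, num_output_tokens=200)
print(f"latency: {latency:.1f} s")
print(f"reading speed equivalent: {words_per_minute(0.1):.0f} words per minute")
```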

## Text generation
Text generation consists of two phases:
1. **Prefill phase**: The input prompt is tokenized and all of its tokens are processed in parallel.
2. **Generation phase**: The next token is generated based on the tokens seen so far. This works in an *autoregressive* manner: the output token of one step becomes part of the input of the next step.

In [15]:
# import necessary packages
import itertools
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import random
from tqdm.auto import tqdm
# add seed to get consistent results
random.seed(42)

In [16]:
# initialize model and tokenizer
device = "cpu"
model = AutoModelForCausalLM.from_pretrained("gpt2")
# move the model to the selected device (set device = "cuda" to use a GPU)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained("gpt2")


In [17]:
# form request queue
input_samples = [
    ("Quick brown fox", 15),
    ("Quick brown fox jumped", 15),
    ("Quick brown fox jumped on", 15),
    ("Quick brown fox jumped on the", 15),
]
queue_size = 12
request_queue = [random.choice(input_samples) for i in range(queue_size)]


def generate_token(next_inputs):
    with torch.no_grad():
        outputs = model(**next_inputs)
        logits = outputs.logits
        last_logits = logits[:, -1, :]
        # print(last_logits.shape)
        next_idx = last_logits.argmax()
        return next_idx


def generate_text(text: str, num_tokens):
    # tokenize input
    # prefill stage
    next_inputs = tokenizer(text, return_tensors="pt")
    next_inputs.to(device)
    # print(input_tokens)

    # generation phase
    generated_tokens = []
    for i in range(num_tokens):
        next_idx = generate_token(next_inputs)
        generated_tokens.append(next_idx)
        next_inputs = {
            "input_ids": torch.cat([next_inputs["input_ids"], next_idx.reshape(1, 1).to(device)], dim=1),
            "attention_mask": torch.cat([next_inputs["attention_mask"], torch.tensor([[1]]).to(device)], dim=1)
        }
        # append next_idx to inputs.input_ids
    print(f"{text}\x1b[31m{tokenizer.decode(generated_tokens)}\x1b[0m")


s_time = time.time()
for text, num_tokens in tqdm(request_queue):
    generate_text(text, num_tokens)
    # break
print(f"sequential generate_text took:{time.time() - s_time}")

  0%|          | 0/12 [00:00<?, ?it/s]

Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
Quick brown fox jumped on[31m top of her and started to run.

"I'm sorry,[0m
Quick brown fox jumped[31m up and down on the ground, and then he jumped up and down on[0m
Quick brown fox jumped[31m up and down on the ground, and then he jumped up and down on[0m
Quick brown fox jumped[31m up and down on the ground, and then he jumped up and down on[0m
Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
Quick brown fox jumped on the[31m back of the car and ran away.

"I'm sorry,[0m
Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
Qu

## KV caching
**KV caching** is a technique to speed up token generation by storing the key (K) and value (V) tensors computed in each attention layer, so that subsequent generation steps can reuse them instead of recomputing them for the whole sequence.


**How much memory does the KV cache take?**
- Formula: `2 x precision x layers x dimension x sequence_length x batch`
- Elements:
  - 2: one K matrix and one V matrix.
  - Precision: number of bytes per stored value.
  - Layers: total number of layers.
  - Dimension: size of embeddings per layer.
  - Sequence_length: length of the sequence (prompt plus generated tokens).
  - Batch: batch size.
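
The formula can be sketched as a small function. The helper name `kv_cache_bytes` is made up for illustration. The GPT-2 numbers (12 layers, 768-dimensional embeddings) and the Llama-2-70B numbers (80 layers) come from the published model configurations. One assumption worth flagging: for models that use grouped-query attention, such as Llama-2-70B, the dimension that matters is the key/value width (8 KV heads x 128 per head = 1024), not the full hidden size.

```python
def kv_cache_bytes(precision_bytes, layers, dimension, sequence_length, batch):
    # 2 accounts for storing both K and V
    return 2 * precision_bytes * layers * dimension * sequence_length * batch


MIB = 1024 ** 2
GIB = 1024 ** 3

# GPT-2 (small) in float32: 12 layers, 768-dimensional embeddings
gpt2 = kv_cache_bytes(precision_bytes=4, layers=12, dimension=768,
                      sequence_length=1024, batch=1)
print(f"GPT-2, 1024 tokens, batch 1: {gpt2 / MIB} MiB")

# Llama-2-70B in float16: 80 layers, KV width 8 * 128 = 1024 (grouped-query attention)
for batch in (1, 16, 32, 64):
    llama = kv_cache_bytes(precision_bytes=2, layers=80, dimension=1024,
                           sequence_length=1024, batch=batch)
    print(f"Llama-2-70B, 1024 tokens, batch {batch}: {llama / GIB} GiB")
```

The cache grows linearly with both batch size and sequence length, which is why it can end up competing with the model weights for GPU memory.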
			

In [18]:
def generate_token_with_past(inputs):
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    last_logits = logits[0, -1, :]
    next_token_id = last_logits.argmax()
    return next_token_id, outputs.past_key_values


# use past_key_values
def generate_text_with_kv_caching(text: str, num_tokens):
    # tokenize input; the first forward pass over it is the prefill stage
    next_inputs = tokenizer(text, return_tensors="pt")
    next_inputs.to(device)
    generated_words = []
    for i in range(num_tokens):            
        with torch.no_grad():
            next_token_id, past_key_values = generate_token_with_past(next_inputs)
            # past_key_values: Tuple of tuple(torch.FloatTensor) 
            # of length config.n_layers, 
            # with each tuple having 2 tensors of
            # shape (batch_size, num_heads, sequence_length, embed_size_per_head))
        generated_words.append(next_token_id)
        next_inputs = {
            "input_ids": next_token_id.reshape((1, 1)),
            "attention_mask": torch.cat([next_inputs["attention_mask"], torch.tensor([[1]]).to(device)], dim=1),
            "past_key_values": past_key_values
        }
        
    print(f"{text}\x1b[31m{tokenizer.decode(generated_words)}\x1b[0m")


s_time = time.time()
for text, num_tokens in tqdm(request_queue):
    generate_text_with_kv_caching(text, num_tokens)
    # print("\x1b[31m\"red\"\x1b[0m")
    # break
print(f"kv cached generate_text took:{time.time() - s_time}")

  0%|          | 0/12 [00:00<?, ?it/s]

Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
Quick brown fox jumped on[31m top of her and started to run.

"I'm sorry,[0m
Quick brown fox jumped[31m up and down on the ground, and then he jumped up and down on[0m
Quick brown fox jumped[31m up and down on the ground, and then he jumped up and down on[0m
Quick brown fox jumped[31m up and down on the ground, and then he jumped up and down on[0m
Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
Quick brown fox jumped on the[31m back of the car and ran away.

"I'm sorry,[0m
Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
Qu

## Batching
Batching in Large Language Model (LLM) inference refers to the practice of grouping multiple inference requests together to be processed simultaneously, rather than handling each request individually.  
**On GPUs**  
LLM inference is often memory-IO bound, meaning that the speed at which model parameters can be loaded from GPU memory into the compute units significantly impacts overall performance. With batching, the parameters loaded for one step are shared by every request in the batch, so they are loaded fewer times per generated token, which reduces the overhead of frequent memory accesses.  
**On CPUs**  
Modern frameworks benefit from batching even on CPUs because of vectorization and optimized matrix operations.
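
To make the memory-bound argument concrete, here is a deliberately simplified back-of-the-envelope sketch. It assumes each decoding step streams all model weights from GPU memory exactly once, and it ignores KV cache reads and the point where compute becomes the bottleneck. The function name, the 14 GB weight size (a 7B model in float16) and the 1000 GB/s memory bandwidth are illustrative assumptions, not measurements.

```python
def decode_tokens_per_second(weight_gb, bandwidth_gb_per_s, batch_size):
    # one decoding step reads the full set of weights once, whatever the batch size
    step_seconds = weight_gb / bandwidth_gb_per_s
    # each step produces one token per request in the batch
    return batch_size / step_seconds


for batch_size in (1, 8, 32):
    rate = decode_tokens_per_second(weight_gb=14, bandwidth_gb_per_s=1000,
                                    batch_size=batch_size)
    print(f"batch {batch_size}: ~{rate:.0f} tokens/s (idealized upper bound)")
```

Under these assumptions throughput scales linearly with batch size; real systems flatten out once KV cache traffic or compute starts to dominate.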

### Batching-related bookkeeping
**Padding:** When a batch contains inputs of different lengths, pad tokens are added to make them the same size, and the padding is reflected in `attention_mask` so that the pad positions are ignored.

**position_ids:** With a single input and no KV caching, we send the entire sequence so far as input, so each token's position is clear. With KV caching, we send only the last generated token as input, while the rest of the sequence is passed as `past_key_values`. In this case it is not clear to the LLM which position in the sequence the current input token belongs to, so `position_ids` is sent explicitly.
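
The position bookkeeping can be illustrated without a model. The sketch below is a pure-Python rendering of the "cumulative sum of the attention mask minus one" idea for a left-padded batch; the helper name `position_ids_from_mask` and the choice of `1` as the filler value for pad positions are assumptions made for this example.

```python
from itertools import accumulate


def position_ids_from_mask(mask_row, pad_position=1):
    # real tokens get 0, 1, 2, ...; pad positions get a filler value,
    # which is harmless because the attention mask hides them anyway
    running = accumulate(mask_row)
    return [count - 1 if m == 1 else pad_position
            for count, m in zip(running, mask_row)]


# left-padded attention mask for four prompts of 3, 4, 5 and 6 tokens
batch_mask = [
    [0, 0, 0, 1, 1, 1],
    [0, 0, 1, 1, 1, 1],
    [0, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1],
]
position_ids = [position_ids_from_mask(row) for row in batch_mask]
for row in position_ids:
    print(row)

# during generation with a KV cache, each row continues from its own last position
next_positions = [row[-1] + 1 for row in position_ids]
print(next_positions)
```

Each row continues from a different position even though all rows have the same padded length, which is why the position cannot be inferred from the tensor shape alone.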

In [19]:
# Define PAD Token = EOS Token = 50256
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# pad on the left so we can append new tokens on the right
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

print(input_samples)
texts = [tup[0] for tup in input_samples]
tokenizer(texts, padding=True, return_tensors="pt")

[('Quick brown fox', 15), ('Quick brown fox jumped', 15), ('Quick brown fox jumped on', 15), ('Quick brown fox jumped on the', 15)]


{'input_ids': tensor([[50256, 50256, 50256, 21063,  7586, 21831],
        [50256, 50256, 21063,  7586, 21831, 11687],
        [50256, 21063,  7586, 21831, 11687,   319],
        [21063,  7586, 21831, 11687,   319,   262]]), 'attention_mask': tensor([[0, 0, 0, 1, 1, 1],
        [0, 0, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1]])}

In [20]:
def generate_batch_tokens_with_past(inputs):
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    last_logits = logits[:, -1, :]
    # print(last_logits.shape)
    next_token_ids = last_logits.argmax(dim=1)
    # print(next_token_ids)
    return next_token_ids, outputs.past_key_values


# use past_key_values
def generate_text_with_batching(texts: list, max_tokens=10, show=1):
    # tokenize input
    next_inputs = tokenizer(texts, padding=True, return_tensors="pt")

    # print(next_inputs) 
    # prepare position ids as per batching
    attention_mask = next_inputs["attention_mask"]
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    next_inputs["position_ids"] = position_ids
    # print(f"position_ids:{position_ids}")

    generated_words = [[] for i in range(len(texts))]

    while max_tokens != 0:
        next_token_ids, past_key_values = generate_batch_tokens_with_past(next_inputs)
        # past_key_values: Tuple of tuple(torch.FloatTensor) 
        # of length config.n_layers, 
        # with each tuple having 2 tensors of
        # shape (batch_size, num_heads, sequence_length, embed_size_per_head))
        max_tokens -= 1
        next_inputs = {
            # pass latest generated tokens as input
            "input_ids": next_token_ids.unsqueeze(-1),
            # increment last positions by one and send
            "position_ids": next_inputs["position_ids"][:, -1].unsqueeze(-1) + 1,
            # append 1 to existing attention mask
            "attention_mask": torch.cat(
                [next_inputs["attention_mask"], torch.ones(next_inputs["attention_mask"].shape[0]).unsqueeze(-1)],
                dim=1),
            "past_key_values": past_key_values
        }

        for i, idx in enumerate(next_token_ids):
            generated_words[i].append(idx)
    if show:
        for i, idx in enumerate(generated_words):
            print(f"{i}.{texts[i]}\x1b[31m{tokenizer.decode(idx)}\x1b[0m")


s_time = time.time()
texts = [tup[0] for tup in request_queue]
max_tokens = 15
generate_text_with_batching(texts, max_tokens)
print(f"batch+kv_caching generate_text took:{time.time() - s_time}")

0.Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
1.Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
2.Quick brown fox jumped on[31m top of her and started to run.

"I'm sorry,[0m
3.Quick brown fox jumped[31m up and down on the ground, and then he jumped up and down on[0m
4.Quick brown fox jumped[31m up and down on the ground, and then he jumped up and down on[0m
5.Quick brown fox jumped[31m up and down on the ground, and then he jumped up and down on[0m
6.Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
7.Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
8.Quick brown fox jumped on the[31m back of the car and ran away.

"I'm sorry,[0m
9.Quick brown fox[31mes are the most common species of fox in the United States. They are[0m
10.Quick brown fox[31mes are the most common species of fox in the United 

### Continuous batching
Simple batching waits for the whole batch to finish before processing the next set of requests.  
**Continuous batching** swaps completed requests for new requests at the iteration level, so a finished sequence frees its slot immediately. It can achieve up to 10x-20x better throughput than simple batching, depending on the workload.
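
The scheduling difference can be seen in a toy simulation that only counts decoding iterations. This is a sketch under simplifying assumptions: every iteration costs the same regardless of how full the batch is, prefill is free, and requests are described only by how many tokens they need. The function names are made up for illustration.

```python
from collections import deque


def static_batching_steps(lengths, batch_size):
    # each batch runs until its longest request finishes
    steps = 0
    for start in range(0, len(lengths), batch_size):
        steps += max(lengths[start:start + batch_size])
    return steps


def continuous_batching_steps(lengths, batch_size):
    queue = deque(lengths)
    active = []  # remaining tokens for each in-flight request
    steps = 0
    while queue or active:
        # fill any free slots before the next iteration
        while queue and len(active) < batch_size:
            active.append(queue.popleft())
        # one decoding iteration produces one token for every active request
        active = [remaining - 1 for remaining in active]
        steps += 1
        # finished requests leave the batch immediately
        active = [remaining for remaining in active if remaining > 0]
    return steps


# 12 requests needing 100, 15 and 30 tokens in rotation, batch size 4
lengths = [100, 15, 30] * 4
print("static:    ", static_batching_steps(lengths, batch_size=4))
print("continuous:", continuous_batching_steps(lengths, batch_size=4))
```

For this queue the static scheduler needs 300 iterations and the continuous one 200, because short requests no longer hold a slot idle while a 100-token request finishes.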

In [21]:
import random
import itertools
    
# form request queue
input_samples = [
    ("Jack and Jill", 100),
    ("Quick brown fox jumped", 15),
    ("Humpty Dumpty sat", 30),
]

batch_size = 4
queue_size = 12

request_queue = [input_samples[i%len(input_samples)] for i in range(queue_size)]

print(request_queue)


# request_queue

def decode_tokens(tokens):
    return tokenizer.decode(tokens)


def filter_batch(next_inputs, incomplete_mask):
    """
    filter out completed requests from inputs corresponding to mask
    """
    next_inputs["position_ids"] = next_inputs["position_ids"][incomplete_mask]
    next_inputs["input_ids"] = next_inputs["input_ids"][incomplete_mask]
    next_inputs["attention_mask"] = next_inputs["attention_mask"][incomplete_mask]

    # past_key_values: Tuple of tuple(torch.FloatTensor) of length config.n_layers, with each tuple having 2 tensors of  shape (batch_size, num_heads, sequence_length, embed_size_per_head))
    # filter past_key_values using incomplete_mask
    next_inputs["past_key_values"] = [(layer_tup[0][incomplete_mask], layer_tup[1][incomplete_mask]) for layer_tup in
                                      next_inputs["past_key_values"]]
    return next_inputs


def prefill_batch(texts, max_len=None):
    """
    generate key values for the input text for later use 
    """
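    if max_len is None:
        # note: this is the longest text in characters, used as a loose upper bound on its token count
        max_len = max([len(text) for text in texts])
    # tokenize input, padding every row to max_len - 1 tokens so that, after the
    # prefill step appends one position, the attention mask has max_len columns
    next_inputs = tokenizer(texts, max_length=max_len-1, padding="max_length", return_tensors="pt")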
    # prepare position ids as per batching
    attention_mask = next_inputs["attention_mask"]
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    next_inputs["position_ids"] = position_ids
    # print(f"position_ids:{position_ids}")
    input_id_text_map = torch.IntTensor(range(len(texts)))
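    # the token predicted here becomes the next input; callers do not add it to
    # generated_words, so printed outputs start from the second generated token
    next_token_ids, past_key_values = generate_batch_tokens_with_past(next_inputs)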
    next_inputs = {
        # pass latest generated tokens as input
        "input_ids": next_token_ids.unsqueeze(-1),
        # increment last positions by one and send
        "position_ids": next_inputs["position_ids"][:, -1].unsqueeze(-1) + 1,
        # append 1 to existing attention mask
        "attention_mask": torch.cat(
            [next_inputs["attention_mask"], torch.ones(next_inputs["attention_mask"].shape[0]).unsqueeze(-1)],
            dim=1),
        "past_key_values": past_key_values
    }
    return next_inputs


def merge_batches(next_inputs, new_next_inputs):
    """
    take existing batch and merge with prefilled batch so that batch remains full
    """

    next_inputs["position_ids"] = torch.cat(
        [next_inputs["position_ids"], new_next_inputs["position_ids"]], dim=0)
    next_inputs["input_ids"] = torch.cat([next_inputs["input_ids"], new_next_inputs["input_ids"]], dim=0)
    
    next_inputs["attention_mask"] = torch.cat(
        [next_inputs["attention_mask"], new_next_inputs["attention_mask"]], dim=0)
    # print(f"next_inputs['attention_mask']: {next_inputs['attention_mask']}")

    assert len(new_next_inputs['past_key_values']) == len(next_inputs['past_key_values'])
    
    num_layers = len(next_inputs['past_key_values'])
    new_past_key_values= []
    for layer_i in range(num_layers):
        layerwise_kv= []
        for kv_i in (0,1):
            layerwise_kv.append(torch.cat([next_inputs["past_key_values"][layer_i][kv_i], new_next_inputs["past_key_values"][layer_i][kv_i]], dim=0))        
        # print(f"layer={layer_i}, layerwise_kv[0].shape={layerwise_kv[0].shape}, layerwise_kv[1].shape={layerwise_kv[1].shape}")
        new_past_key_values.append(layerwise_kv)
    next_inputs["past_key_values"] = tuple(new_past_key_values)
        
    return next_inputs


def generate_text_with_continous_batching(request_queue: list, batch_size=5):
    """
    generate text with continuous batching
    :param request_queue: list of tuple(text, max_tokens)
    :param batch_size: 
    :return: 
    """
    # initialize_batch

    current_texts= [tup[0] for tup in request_queue[:batch_size]] # maintain ongoing texts
    max_tokens = [tup[1] for tup in request_queue[:batch_size]] # maintain ongoing max_tokens
    generated_words = [[] for i, _ in enumerate(current_texts)] # maintain ongoing generated_words
    
    
    request_queue = request_queue[batch_size:]
    
    next_inputs = prefill_batch(current_texts)
    
    s_time = time.time()
    batch_capacity = batch_size
    
    while sum(max_tokens) >= 0:
        # track which sequences are complete/incomplete so that complete sequences can be taken out
        complete_indices = [i for i, remaining_count in enumerate(list(max_tokens)) if remaining_count == 0]

        # mask used to take out complete inputs
        incomplete_mask = [remaining_count > 0 for remaining_count in max_tokens]
        if complete_indices:
            batch_capacity = len(complete_indices)
            next_inputs = filter_batch(next_inputs, incomplete_mask)
            for i in complete_indices:
                print(
                    f"+{round(time.time() - s_time, 3)}{current_texts[i]}\x1b[31m{tokenizer.decode(generated_words[i])}\x1b[0m")
            
            
            current_texts = list(itertools.compress(current_texts, incomplete_mask))
            max_tokens = list(itertools.compress(max_tokens, incomplete_mask))
            generated_words = list(itertools.compress(generated_words, incomplete_mask))
            ###
            # Populate missing slots
            ###
            if request_queue:
                new_texts, new_max_tokens = [tup[0] for tup in request_queue[:batch_capacity]], [
                    tup[1] for tup in request_queue[:batch_capacity]]
                current_texts.extend(new_texts)
                max_tokens.extend(new_max_tokens)
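                # add one empty output list per newly admitted request
                # (the queue may hold fewer requests than there are free slots)
                generated_words.extend([[] for _ in range(len(new_texts))])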
            
                max_length = next_inputs["attention_mask"].shape[1]
    
                new_next_inputs = prefill_batch(new_texts, max_length)
                # print(f"next_inputs['input_ids'].shape:{next_inputs['input_ids'].shape}")
                
                next_inputs = merge_batches(next_inputs, new_next_inputs)
            
                # print(f"next_inputs['input_ids'].shape:{next_inputs['input_ids'].shape}")
                request_queue = request_queue[batch_capacity:]
            
            if next_inputs["input_ids"].size(0) <= 0:
                break
        # print(next_inputs)
        next_token_ids, past_key_values = generate_batch_tokens_with_past(next_inputs)
        max_tokens = [mt - 1 for mt in max_tokens]

        next_inputs = {
            # pass latest generated tokens as input
            "input_ids": next_token_ids.unsqueeze(-1),
            # increment last positions by one and send
            "position_ids": next_inputs["position_ids"][:, -1].unsqueeze(-1) + 1,
            # append 1 to existing attention mask
            "attention_mask": torch.cat(
                [next_inputs["attention_mask"], torch.ones(next_inputs["attention_mask"].shape[0]).unsqueeze(-1)],
                dim=1),
            "past_key_values": past_key_values
        }

        # print(f"generated_words: {generated_words}")
        for i, idx in enumerate(next_token_ids):
            generated_words[i].append(idx)


s_time = time.time()
generate_text_with_continous_batching(request_queue, batch_size)
print(f"continuous batching generate_text took:{time.time() - s_time}")


[('Jack and Jill', 100), ('Quick brown fox jumped', 15), ('Humpty Dumpty sat', 30), ('Jack and Jill', 100), ('Quick brown fox jumped', 15), ('Humpty Dumpty sat', 30), ('Jack and Jill', 100), ('Quick brown fox jumped', 15), ('Humpty Dumpty sat', 30), ('Jack and Jill', 100), ('Quick brown fox jumped', 15), ('Humpty Dumpty sat', 30)]
+0.726Quick brown fox jumped[31m and down on the ground, and then he jumped up and down on the[0m
+1.658Humpty Dumpty sat[31m the throne of the House of Commons, and he was a man of great courage and of great courage. He was a man of great courage and of[0m
+1.658Quick brown fox jumped[31m and down on the ground, and then he jumped up and down on the[0m
+3.327Humpty Dumpty sat[31m the throne of the House of Commons, and he was a man of great courage and of great courage. He was a man of great courage and of[0m
+4.362Quick brown fox jumped[31m and down on the ground, and then he jumped up and down on the[0m
+5.899Jack and Jill[31m who are both in th

## Quantization
Model quantization is a common way to reduce model hardware requirements. Reducing the precision of the model weights and activations reduces the GPU memory needed. For example, changing the model precision from float16 to int8 halves the VRAM required for the weights.
It also reduces the KV cache size.



### Floating point representation

| Representation | Mantissa bits<br>(decide the precision with which numbers can be represented) | Exponent bits<br>(decide the range of numbers that can be represented) | Sign bits | Example   |
| -------------- | ------------ | --------------------------------------------------- | ---- | --------- |
| FP32           | 23           | 8                                                   | 1    | 3.1415927 |
| FP16           | 10           | 5                                                   | 1    | 3.141     |
| FP8            | 2            | 5                                                   | 1    | 3         |
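
The loss of precision in the table can be reproduced with Python's standard `struct` module, which can pack a number as IEEE 754 single precision (`f`) or half precision (`e`). This is only an illustrative sketch: the helper name `round_trip` is made up, and `struct` has no FP8 format, so that row is not covered.

```python
import struct


def round_trip(value, fmt):
    # store the value in the given binary format and read it back
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


pi = 3.141592653589793  # Python floats are 64-bit
fp32 = round_trip(pi, "<f")  # 23 mantissa bits
fp16 = round_trip(pi, "<e")  # 10 mantissa bits

print(f"FP32: {fp32!r}  (error {abs(fp32 - pi):.2e})")
print(f"FP16: {fp16!r}  (error {abs(fp16 - pi):.2e})")
```

With only 10 mantissa bits, FP16 keeps about three significant decimal digits, which matches the `3.141` example in the table.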

### Effect of quantization on KV cache requirements



| Batch Size | KV cache memory (FP16) | KV cache memory (Int8) |
|------------|-----------------------------|----------------------------|
| 1          | 0.312 GiB                   | 0.156 GiB                  |
| 16         | 5 GiB                       | 2.5 GiB                    |
| 32         | 10 GiB                      | 5 GiB                      |
| 64         | 20 GiB                      | 10 GiB                     |

*KV cache size for Llama-2-70B at a sequence length of 1024.*


In [22]:
# model memory footprint
model_memory_footprint = model.get_memory_footprint()/(1024*1024)
print(f"model_memory_footprint: {model_memory_footprint}")

model_memory_footprint: 486.7002410888672


In [23]:
# define a function to generate text
def generate(model, tokenizer, prompt, max_length=20, num_return_sequences=1):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=max_length,
        temperature=1.0,
        num_return_sequences=num_return_sequences,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

In [24]:
# generate using original model
generate(model, tokenizer, "Quick brown fox jumped on", 40)



['Quick brown fox jumped on it\'s hind foot, and it moved outwards, moving toward Ruby.\n\nFluid eyes.\n\nBut she couldn\'t even say "Fluid eyes']

In [25]:
# define a quantization function
def quantize(t):
    min_val, max_val = t.min(), t.max()
    scale = (max_val - min_val) / 255
    zero_point = min_val
    state = (scale, zero_point)
    t_quant = (t - min_val) / scale
    t_quant = torch.clamp(t_quant, min=0, max=255)

    # cast to uint8 (values lie in [0, 255]); note the cast truncates rather than rounds
    t_quant = t_quant.type(torch.uint8)

    return t_quant, state

# define a dequantization function
def dequantize(t, state):
    scale, min_val = state
    # upcast to float
    t = t.to(torch.float32)
    t = t * scale + min_val
    return t

In [26]:
# quantize random tensor
# initialize random tensor
t = torch.rand(2, 3)
print(f"original: {t}")
t_quantized, scale = quantize(t)
print(f"quantized:{t_quantized}\nscale:{scale}")

t_recovered = dequantize(t_quantized, scale)
print(f"t_recovered: {t_recovered}")


original: tensor([[0.9623, 0.1279, 0.1236],
        [0.6088, 0.2284, 0.0868]])
quantized:tensor([[255,  11,  10],
        [152,  41,   0]], dtype=torch.uint8)
scale:(tensor(0.0034), tensor(0.0868))
t_recovered: tensor([[0.9623, 0.1246, 0.1211],
        [0.6087, 0.2276, 0.0868]])


In [27]:
# quantize entire model
def quantize_model(model):
    states = {}
    for name,param in model.named_parameters():
        param.requires_grad = False
        param.data , states[name] = quantize(param.data)
    return model, states

def dequantize_model(q_model, states):
    for name ,param in q_model.named_parameters():
        param.data = dequantize(param.data, states[name])
    return q_model

q_model,states = quantize_model(model)
q_model.get_memory_footprint()/ (1024 * 1024)

130.6750946044922

In [28]:
# generate text using recovered model
recovered_model = dequantize_model(model, states)
q_output = generate(recovered_model,tokenizer, "Quick brown fox jumped on", 40)
q_output

['Quick brown fox jumped on top of them but missed after landing. He ran away crying to his puppy. When he saw him lying naked naked naked naked naked naked naked naked naked naked naked naked naked naked']

## Adapter-based fine-tuning
LoRA, short for Low-Rank Adaptation, is a method designed to efficiently fine-tune large pre-trained models. It keeps the pre-trained weights frozen and trains only a pair of small low-rank matrices per adapted layer, based on the observation that the weight updates needed for fine-tuning can be approximated well by low-rank matrices.

![Title](img/llm_serving/lora_diagram.png)
Image credits: [huggingface](https://huggingface.co/docs/peft/main/en/conceptual_guides/lora)
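
The idea can be sketched in a few lines of plain Python. This is a toy illustration, not the PEFT implementation: the helper names are made up, matrices are nested lists, and the `alpha / r` scaling and zero-initialized `B` follow the convention described in the LoRA paper.

```python
def matvec(matrix, vector):
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, alpha, r):
    base = matvec(W, x)               # frozen pre-trained path: W x
    update = matvec(B, matvec(A, x))  # trainable low-rank path: B (A x)
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


def lora_param_counts(d_out, d_in, r):
    # (parameters in the full weight matrix, parameters in the adapter)
    return d_out * d_in, r * (d_in + d_out)


W = [[1, 0, 2], [0, 1, 0]]  # frozen weight, shape (2, 3)
A = [[1, 1, 1]]             # adapter down-projection, shape (r=1, 3)
x = [1, 2, 3]

B_init = [[0.0], [0.0]]     # B starts at zero, so the adapter is a no-op at first
print(lora_forward(W, A, B_init, x, alpha=2, r=1))

B_trained = [[0.5], [-1.0]]  # pretend these values were learned
print(lora_forward(W, A, B_trained, x, alpha=2, r=1))

full, adapter = lora_param_counts(d_out=768, d_in=768, r=8)
print(f"full: {full}, adapter: {adapter} ({adapter / full:.2%})")
```

Because the base weights never change, many such adapters can share a single copy of the model, which is what makes serving multiple adapters from one deployment practical.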


<img src="img/llm_serving/multi_adapter_serving.jpeg"  width="650" height="650">

## References
1. [DLAI - Efficiently Serving LLMs](https://learn.deeplearning.ai/courses/efficiently-serving-llms)
2. [LLM Inference Performance Engineering: Best Practices | Databricks Blog](https://www.databricks.com/blog/llm-inference-performance-engineering-best-practices)
3. [Hardware for LLMs - by Benjamin Marie](https://kaitchup.substack.com/p/hardware-for-llms)
4. [Orca: A Distributed Serving System for Transformer-Based Generative Models | USENIX](https://www.usenix.org/conference/osdi22/presentation/yu)
5. [QLoRA: Fine-Tune a Large Language Model on Your GPU](https://kaitchup.substack.com/p/qlora-fine-tune-a-large-language-model-on-your-gpu-27bed5a03e2b)
6. [Fundamentals of Data Representation: Floating point numbers - Wikibooks, open books for an open world](https://en.wikibooks.org/wiki/A-level_Computing/AQA/Paper_2/Fundamentals_of_data_representation/Floating_point_numbers)