#### Batching as a means for optimization of LLM inference
Batching is a technique for increasing the throughput of LLM inference by supplying multiple input sequences to the model so that they are processed simultaneously in a single forward pass.

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

In [2]:
# Checkpoint name passed to both from_pretrained calls below.
model_name = "gpt2"
# Load the tokenizer and the causal language model from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

Here we reuse the KV-cache text generation functions from the text generation example.

In [5]:
# Single prompt used to try out the un-batched generation helpers defined below.
prompt = "The quick brown fox jumped over the"
# return_tensors="pt" asks the tokenizer for PyTorch tensors.
inputs = tokenizer(prompt, return_tensors="pt")


def generate_token_with_past(inputs):
    """Run one forward pass and greedily pick the next token for one sequence.

    Args:
        inputs: Mapping of keyword arguments forwarded to the module-level
            ``model`` (``generate`` passes ``input_ids``, ``attention_mask``
            and, after the first step, ``past_key_values``).

    Returns:
        A ``(next_token_id, past_key_values)`` pair: the argmax token id taken
        from the last position of batch row 0, and the cache object returned
        by the model so the next step can reuse it.
    """
    # Inference only: no gradients are needed.
    with torch.no_grad():
        model_out = model(**inputs)

    # Only row 0 and its final position matter for the next-token choice.
    final_step_logits = model_out.logits[0, -1, :]
    greedy_id = final_step_logits.argmax()
    return greedy_id, model_out.past_key_values


def generate(inputs, max_tokens):
    """Greedily generate ``max_tokens`` tokens for a single prompt using the KV cache.

    After the first step only the newly chosen token is fed to the model,
    together with the cache returned by the previous step and an attention
    mask extended by one position.

    Args:
        inputs: Tokenizer output for one prompt; must provide ``input_ids``
            and ``attention_mask`` with a batch dimension of 1.
        max_tokens: Number of tokens to generate.

    Returns:
        The generated continuation as a single string (``""`` when
        ``max_tokens`` is 0).
    """
    generated_ids = []
    next_inputs = inputs
    for _ in range(max_tokens):
        next_token_id, past_key_values = generate_token_with_past(next_inputs)
        attention_mask = next_inputs["attention_mask"]
        next_inputs = {
            "input_ids": next_token_id.reshape((1, 1)),
            # new_ones keeps the dtype and device of the existing mask, so the
            # concatenation does not depend on the mask living on the CPU.
            "attention_mask": torch.cat(
                [attention_mask, attention_mask.new_ones((1, 1))], dim=1
            ),
            "past_key_values": past_key_values,
        }
        generated_ids.append(next_token_id.item())

    if not generated_ids:
        return ""
    # Decode the whole id sequence at once: with a byte-level BPE vocabulary a
    # character can span several tokens, and decoding token by token would
    # turn each partial piece into a replacement character.
    return tokenizer.decode(generated_ids)

In [6]:
# Greedily generate 10 new tokens for the single prompt tokenized earlier.
tokens = generate(inputs, max_tokens=10)
print(tokens)

 fence and ran to the other side of the fence


Add a padding token to the tokenizer and the model config so that prompts of different lengths can be padded into a single batch.

In [7]:
# Reuse the end-of-sequence token as the padding token, on both the tokenizer
# and the model config, so prompts can be padded into a batch.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

In [8]:
# Pad and truncate on the left so that the last position of every row holds a
# real prompt token; the next-token logits are read from that last position.
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

In [9]:
# Prompts of different token lengths, so the shorter ones need padding.
prompts = [
    "The quick brown fox jumped over the",
    "The rain in Spain falls",
    "What coems up must",
]

# padding=True pads every row to the length of the longest prompt in the batch.
# NOTE: this rebinds `inputs`, replacing the single-prompt encoding created earlier.
inputs = tokenizer(prompts, padding=True, return_tensors="pt")

In [10]:
# Show the padded token ids; in the recorded output the padded slots hold 50256.
print("input_ids", inputs["input_ids"])
print("shape:", inputs["input_ids"].shape)

input_ids tensor([[  464,  2068,  7586, 21831, 11687,   625,   262],
        [50256, 50256,   464,  6290,   287,  8602,  8953],
        [50256, 50256,  2061,   763,  5232,   510,  1276]])
shape: torch.Size([3, 7])


In [11]:
# Show the attention mask: 1 marks a real token, 0 marks a padded slot.
print("attention_mask: ", inputs["attention_mask"])
print("shape: ", inputs["attention_mask"].shape)

attention_mask:  tensor([[1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1]])
shape:  torch.Size([3, 7])


In [13]:
# Derive per-row position ids from the attention mask: the running count of
# real tokens minus one gives 0, 1, 2, ... starting at the first real token,
# and every padded slot is then overwritten with the placeholder value 1.
attention_mask = inputs["attention_mask"]
position_ids = (attention_mask.long().cumsum(-1) - 1).masked_fill(attention_mask == 0, 1)
position_ids

tensor([[0, 1, 2, 3, 4, 5, 6],
        [1, 1, 0, 1, 2, 3, 4],
        [1, 1, 0, 1, 2, 3, 4]])

In [17]:
# One forward pass over the padded batch, passing the explicit position ids
# computed from the attention mask; no gradients are needed for inference.
with torch.no_grad():
    outputs = model(position_ids=position_ids, **inputs)
logits = outputs.logits

In [20]:
# Take the logits at the last position of every row and pick the
# highest-scoring vocabulary id per row (greedy choice).
last_logits = logits[:,-1,:]
next_token_ids = last_logits.argmax(dim=1)

In [21]:
# Decode each row's chosen id to text; the bare last expression displays the list.
next_tokens = tokenizer.batch_decode(next_token_ids)
next_tokens

[' fence', ' on', ' be']

In [22]:
def generate_batch_tokens_with_past(inputs):
    """Run one forward pass and greedily pick the next token for every row.

    Args:
        inputs: Mapping of keyword arguments forwarded to the module-level
            ``model``.

    Returns:
        A ``(next_token_ids, past_key_values)`` pair: one argmax token id per
        batch row, taken from the last position of that row, and the cache
        object returned by the model.
    """
    # Inference only: no gradients are needed.
    with torch.no_grad():
        model_out = model(**inputs)

    # Keep the batch dimension; reduce over the vocabulary axis of the last step.
    final_step_logits = model_out.logits[:, -1, :]
    greedy_ids = final_step_logits.argmax(dim=1)
    return greedy_ids, model_out.past_key_values

In [None]:
def generate_batch(inputs, max_tokens):
    generate_tokens = [
        [] for _ in range(inputs["input_ids"].shape[0])
    ]
    attention_mask = inputs["attention_mask"]
    position_ids = attention_mask.long().cumsum(-1) -1
    position_ids.msaked_fill_(attention_mask == 0, 1)
    
    next_inputs = {
        "position_ids" : position_ids,
        **inputs
    }
    
    for _ in range(max_tokens):
        next_token_ids, past_key_values = generate_token_with_past(next_inputs)
        
        next_inputs = {
            "input_ids": next_token_ids.reshape((-1, 1)),
            "position_ids" : next_inputs["position_ids"][:-1].unsqueeze(-1)+ 1,
            "attention_mask": torch.cat([next_inputs["attention_mask"], torch.ones((next_token_ids.shape[0], 1))], dim=1),
            "past_key_values": past_key_values,
        }