In [1]:
# import helpers
# from helpers import init_batch, generate_next_token
# from helpers import merge_batches, filter_batch

In [2]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import torch
import time
import random
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [3]:
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [4]:
# Define PAD Token = EOS Token = 50256
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# pad on the left so we can append new tokens on the right
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"


In [5]:
# multiple prompts of varying lengths to send to the model at once
prompts = [
    "The quick brown fox jumped over the",
    "The rain in Spain falls",
    "What comes up must",
]

# note: padding=True ensures the padding token
# will be inserted into the tokenized tensors
inputs = tokenizer(prompts, padding=True, return_tensors="pt")

In [6]:
# def generate_batch_tokens_with_past(inputs):
#     with torch.no_grad():
#         outputs = model(**inputs)
#     logits = outputs.logits
#     last_logits = logits[:, -1, :]
#     next_token_ids = last_logits.argmax(dim=1)
#     return next_token_ids, outputs.past_key_values


def generate_batch(inputs, max_tokens):
    # create a list of tokens for every input in the batch
    generated_tokens = [[] for _ in range(inputs["input_ids"].shape[0])]

    attention_mask = inputs["attention_mask"]
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    next_inputs = {
        "position_ids": position_ids,
        **inputs
    }

    for _ in range(max_tokens):
        next_token_ids, past_key_values = generate_batch_tokens_with_past(next_inputs)
        next_inputs = {
            "input_ids": next_token_ids.reshape(-1, 1),
            "position_ids": next_inputs["position_ids"][:, -1].unsqueeze(-1) + 1,
            "attention_mask": torch.cat(
                [next_inputs["attention_mask"], torch.ones((next_token_ids.shape[0], 1))],
                dim=1
            ),
            "past_key_values": past_key_values
        }

        next_tokens = tokenizer.batch_decode(next_token_ids)
        for i, token in enumerate(next_tokens):
            generated_tokens[i].append(token)

    return ["".join(tokens) for tokens in generated_tokens]


In [7]:
# seed the random number generator, so our results are deterministic
random.seed(42)

# constants
queue_size = 32
batch_size = 8

# requests waiting to be processed
# requests are tuples of (prompt, max_tokens)
request_queue = [
    (prompts[i % len(prompts)], 100 if i % batch_size == 0 else 10)
    for i in range(queue_size)
]



In [8]:
request_queue[:8]

[('The quick brown fox jumped over the', 100),
 ('The rain in Spain falls', 10),
 ('What comes up must', 10),
 ('The quick brown fox jumped over the', 10),
 ('The rain in Spain falls', 10),
 ('What comes up must', 10),
 ('The quick brown fox jumped over the', 10),
 ('The rain in Spain falls', 10)]

In [9]:
batches = [
    request_queue[i: i + batch_size]
    for i in range(0, len(request_queue), batch_size)
]

In [10]:
# t0 = time.time()
# with tqdm(total=len(batches), desc=f"bs={batch_size}") as pbar:
#     for i, batch in enumerate(batches):
#         batch_max_tokens = [bs[1] for bs in batch]
#         max_tokens = max(batch_max_tokens)
#         pbar.set_postfix({'max_tokens': max_tokens})
#
#         batch_prompts = [b[0] for b in batch]
#         inputs = tokenizer(
#             batch_prompts, padding=True, return_tensors="pt")
#         generate_batch(inputs, max_tokens=max_tokens)
#
#         pbar.update(1)
#
# duration_s = time.time() - t0
# print("duration: ", duration_s)



## Let's try continuous batching


In [35]:
# seed the random number generator, so out results are deterministic
random.seed(42)

# constants
queue_size = 32
batch_size = 8

# requests waiting to be processed
# this time requests are tuples (prompt, max_tokens)
request_queue = [
    (prompts[i % len(prompts)], 100 if i % batch_size == 0 else 10)
    for i in range(queue_size)
]
print(request_queue)



[('The quick brown fox jumped over the', 100), ('The rain in Spain falls', 10), ('What comes up must', 10), ('The quick brown fox jumped over the', 10), ('The rain in Spain falls', 10), ('What comes up must', 10), ('The quick brown fox jumped over the', 10), ('The rain in Spain falls', 10), ('What comes up must', 100), ('The quick brown fox jumped over the', 10), ('The rain in Spain falls', 10), ('What comes up must', 10), ('The quick brown fox jumped over the', 10), ('The rain in Spain falls', 10), ('What comes up must', 10), ('The quick brown fox jumped over the', 10), ('The rain in Spain falls', 100), ('What comes up must', 10), ('The quick brown fox jumped over the', 10), ('The rain in Spain falls', 10), ('What comes up must', 10), ('The quick brown fox jumped over the', 10), ('The rain in Spain falls', 10), ('What comes up must', 10), ('The quick brown fox jumped over the', 100), ('The rain in Spain falls', 10), ('What comes up must', 10), ('The quick brown fox jumped over the', 1

In [36]:
def generate_batch_tokens_with_past(inputs):
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    last_logits = logits[:, -1, :]
    next_token_ids = last_logits.argmax(dim=1)
    return next_token_ids, outputs.past_key_values


def init_batch(request_queue):
    """
    加工为模型输入的批量。
    进行batch内对齐
    :param request_queue:
    :return:
    """
    prompts, max_tokens = zip(*request_queue)
    inputs = tokenizer(list(prompts), padding=True, return_tensors="pt")
    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_tokens": torch.tensor(max_tokens).long(),
    }


def pad_to_max_length(tensor, max_length, pad_value=0):
    pad_size = max_length - tensor.size(1)
    if pad_size > 0:
        padding = torch.full((tensor.size(0), pad_size), pad_value, dtype=tensor.dtype, device=tensor.device)
        tensor = torch.cat([tensor, padding], dim=1)
    return tensor


def merge_batches(cached_batch, new_batch):
    """
    合并新的请求到缓存批量。
    根据更大的batch进行对齐
    :param cached_batch:
    :param new_batch:
    :return:
    """
    max_length = max(cached_batch["input_ids"].size(1), new_batch["input_ids"].size(1))
    # Pad input_ids, attention_mask, and position_ids to the max_length
    cached_batch["input_ids"] = pad_to_max_length(cached_batch["input_ids"], max_length, pad_value=tokenizer.pad_token_id)
    new_batch["input_ids"] = pad_to_max_length(new_batch["input_ids"], max_length, pad_value=tokenizer.pad_token_id)

    cached_batch["attention_mask"] = pad_to_max_length(cached_batch["attention_mask"], max_length, pad_value=0)
    new_batch["attention_mask"] = pad_to_max_length(new_batch["attention_mask"], max_length, pad_value=0)

    cached_batch["position_ids"] = pad_to_max_length(cached_batch["position_ids"], max_length, pad_value=0)
    new_batch["position_ids"] = pad_to_max_length(new_batch["position_ids"], max_length, pad_value=0)

    # Concatenate input_ids, attention_mask, position_ids, and past_key_values
    merged_batch = {
        "input_ids": torch.cat([cached_batch["input_ids"], new_batch["input_ids"]], dim=0),
        "attention_mask": torch.cat([cached_batch["attention_mask"], new_batch["attention_mask"]], dim=0),
        "position_ids": torch.cat([cached_batch["position_ids"], new_batch["position_ids"]], dim=0),
        "past_key_values": None,
        "max_tokens": torch.cat([cached_batch["max_tokens"], new_batch["max_tokens"]], dim=0)
    }
        # 合并 past_key_values
    if cached_batch["past_key_values"] is None:
        # 如果 cached_batch 没有 past_key_values，直接使用 new_batch 的 past_key_values
        merged_batch["past_key_values"] = new_batch["past_key_values"]
    elif new_batch["past_key_values"] is None:
        # 如果 new_batch 没有 past_key_values，直接使用 cached_batch 的 past_key_values
        merged_batch["past_key_values"] = cached_batch["past_key_values"]
    else:
        # 如果两者都有 past_key_values，需要正确合并
        merged_past_key_values = []
        for cached_layer, new_layer in zip(cached_batch["past_key_values"], new_batch["past_key_values"]):
            # 合并 key 和 value
            merged_key = torch.cat([cached_layer[0], new_layer[0]], dim=2)  # 在 batch_size 维度上合并
            merged_value = torch.cat([cached_layer[1], new_layer[1]], dim=2)  # 在 batch_size 维度上合并
            merged_past_key_values.append((merged_key, merged_value))
        merged_batch["past_key_values"] = tuple(merged_past_key_values)
    return merged_batch


def filter_batch(cached_batch):
    eos_token_id = tokenizer.eos_token_id
    finished_indices = []
    for i in range(cached_batch["input_ids"].size(0)):
        # 检查是否生成 EOS token 或达到 max_tokens
        if (
                cached_batch["input_ids"][i, -1] == eos_token_id or
                cached_batch["attention_mask"][i].sum() >= cached_batch["max_tokens"][i]
        ):
            finished_indices.append(i)

    # 移除已完成生成的请求
    remaining_indices = [
        i for i in range(cached_batch["input_ids"].size(0))
        if i not in finished_indices
    ]
    cached_batch["input_ids"] = cached_batch["input_ids"][remaining_indices]
    cached_batch["position_ids"] = cached_batch["position_ids"][remaining_indices]
    cached_batch["attention_mask"] = cached_batch["attention_mask"][remaining_indices]
    cached_batch["max_tokens"] = cached_batch["max_tokens"][remaining_indices]
    cached_batch["past_key_values"] = None
    # if cached_batch["past_key_values"] is not None:
    #     cached_batch["past_key_values"] = None
        # cached_batch["past_key_values"] = [
        #     [kv[remaining_indices] for kv in layer_kvs]
        #     for layer_kvs in cached_batch["past_key_values"]
        # ]

    return cached_batch, finished_indices


In [37]:
def generate_next_token(batch):
    if "past_key_values" in batch:
        sequence_length = batch["input_ids"].shape[1]
        print(sequence_length)
        # 依次打印 past_key_values 的维度
        # for layer_index, (key_tensor, value_tensor) in enumerate(batch["past_key_values"]):
        #     print(f"Layer {layer_index + 1}:")
        #     print(f"  Key tensor shape: {key_tensor.shape}")
        #     print(f"  Value tensor shape: {value_tensor.shape}")
    # 依次打印 input_ids 的维度
    inputs = batch
    # create a list of tokens for every input in the batch
    generated_tokens = [[] for _ in range(inputs["input_ids"].shape[0])]

    attention_mask = inputs["attention_mask"]
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    next_inputs = {
        "position_ids": position_ids,
        **inputs
    }

    next_token_ids, past_key_values = generate_batch_tokens_with_past(next_inputs)
    # 依次打印 past_key_values 的维度
    # for layer_index, (key_tensor, value_tensor) in enumerate(past_key_values):
    #     print(f"Layer {layer_index + 1}:")
    #     print(f"  Key tensor shape: {key_tensor.shape}")
    #     print(f"  Value tensor shape: {value_tensor.shape}")
    next_inputs = {
        "input_ids": torch.cat(
            [next_inputs["input_ids"], next_token_ids.reshape(-1, 1)],
            dim=1
        ),
        "position_ids": torch.cat(
            [next_inputs["position_ids"],
             next_inputs["position_ids"][:, -1].unsqueeze(-1) + 1],
            dim=1
        ),
        "attention_mask": torch.cat(
            [next_inputs["attention_mask"], torch.ones((next_token_ids.shape[0], 1))],
            dim=1
        ),
        "past_key_values": None, # 我无法实现 past_key_values 的合并/拆分 相关逻辑，所以这里直接设置为 None。否则会报错
        "max_tokens": batch["max_tokens"]
    }


    next_tokens = tokenizer.batch_decode(next_token_ids)
    for i, token in enumerate(next_tokens):
        generated_tokens[i].append(token)
    # print("".join(["".join(tokens) for tokens in generated_tokens]))
    return next_inputs

t0 = time.time()

with tqdm(total=len(request_queue), desc=f"bs={batch_size}") as pbar:
    # first, let's seed the initial cached_batch
    # with the first `batch_size` inputs
    # and run the initial prefill step
    batch = init_batch(request_queue[:batch_size])
    cached_batch = generate_next_token(batch)
    request_queue = request_queue[batch_size:]

    # continue until both the request queue is
    # fully drained and every input
    # within the cached_batch has completed generation
    while (
            len(request_queue) > 0 or
            cached_batch["input_ids"].size(0) > 0
    ):
        batch_capacity = (
                batch_size - cached_batch['input_ids'].size(0)
        )
        if batch_capacity > 0 and len(request_queue) > 0:
            # prefill
            new_batch = init_batch(request_queue[:batch_capacity])
            new_batch = generate_next_token(new_batch)
            request_queue = request_queue[batch_capacity:]

            # merge
            cached_batch = merge_batches(cached_batch, new_batch)

        # decode
        cached_batch = generate_next_token(cached_batch)

        # remove any inputs that have finished generation
        cached_batch, removed_indices = filter_batch(cached_batch)
        pbar.update(len(removed_indices))

duration_s = time.time() - t0
print("duration: ", duration_s)


bs=8:   6%|▋         | 2/32 [00:00<00:01, 18.63it/s]

 fence on be fence on be fence on
8
 and the a and the a and the
9
 ran first good ran first good ran first
 be fence
10
 to day idea day idea day,,
11
 the of. of. of then and


bs=8:  25%|██▌       | 8/32 [00:00<00:00, 35.35it/s]

 on be fence on
12
 other

,,,,,
 be fence
13
 side is and then and and,,
 on
14
 of the the, the then and,


bs=8:  38%|███▊      | 12/32 [00:00<00:00, 25.13it/s]

 be
15
 the fact rain is rain, and,
 fence on
16
 fence that the is the then,,
 be
17


bs=8:  56%|█████▋    | 18/32 [00:01<00:00, 21.58it/s]

. the the rain, and and,
 fence on
18
 He government in is the then,,
19
 was is the the rain, and and
 be fence on
20
 about not United is the,,,
21
 to going States the rain then and and
 be fence
22
 run to falls, the the,,
23


bs=8:  69%|██████▉   | 22/32 [00:01<00:00, 16.60it/s]

 when be on is fox rain then and
 on be
24
 he able, the was,,,
 fence
25
 saw to and about is and then,
26


bs=8:  84%|████████▍ | 27/32 [00:01<00:00, 14.22it/s]

 the do the to the the, and
 on
27
 fox anything rain run rain is,
28
. about in away the and
29
 He it the when the
30
 ran. United the rain
31
 to
 States fox
32
 the
 falls jumped
33
 otherThe on over
34
 side government, the
35
 of is and fence
36
 the going the.
37
 fence to rain

38
 and have in

39
 ran to the"
40
 to do UnitedI
41
 the something States'm
42
 other about falls not
43
 side it on going
44
 of., to
45
 the
 and let
46
 fence
 the you
47
.The rain run
48
 He government in away
49
 ran is the,"
50
 to going United the
51
 the to States fox
52
 other have falls said
53
 side to on.
54
 of do,

55
 the something and

56
 fence about the"
57
 and it rainI
58
 ran. in'm
59
 to
 the not
60
 the
 United going
61
 otherThe States to
62
 side government falls let
63
 of is on you
64
 the going, run
65
 fence to and away
66
. have the,"
67
 He to rain the
68
 ran do in fox
69
 to something the said
70
 the about United.
71
 other it States

72
 side. falls

73
 of
 on"
74
 t

bs=8:  91%|█████████ | 29/32 [00:08<00:02,  1.27it/s]

 the about in run
99
 other it the away
100
. United,"
101

 States the
102

 falls fox
103
The on said
104
 government,.
105
 and

106
 the

107


bs=8:  97%|█████████▋| 31/32 [00:08<00:00,  1.45it/s]

 rain"
108
I
109
'm
110
 not
111


bs=8: 100%|██████████| 32/32 [00:09<00:00,  3.53it/s]

 going
duration:  9.074449062347412



