
Generate with batched inputs #188

Open · qiqiApink opened this issue Apr 23, 2023 · 8 comments

@qiqiApink

It seems that the generation code doesn't support batched inputs. Can you give me some instructions?

@lantiga
Collaborator

lantiga commented Apr 23, 2023

We removed that to make the code simpler, with the idea to add a separate script in case it was of interest (it looks like there is interest :-) ).

In the meantime you can take a look at the PR where we removed batched generation: #162
If you look at the "before the changes" part, it will show you how to run with batched inputs.
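Roughly speaking, the batched path means padding a list of tokenized prompts into a single (B, T) tensor before feeding it to the model. A minimal sketch of that idea (the pad_prompts helper and the pad_id value are hypothetical placeholders, not the exact code from the PR):

import torch

def pad_prompts(prompts, pad_id=0):
    """Right-pad a list of 1-D token tensors into a single (B, T) batch.

    Hypothetical helper for illustration; the pre-#162 code may have done this differently.
    """
    max_len = max(p.size(0) for p in prompts)
    batch = torch.full(
        (len(prompts), max_len), pad_id, dtype=prompts[0].dtype, device=prompts[0].device
    )
    for i, p in enumerate(prompts):
        batch[i, : p.size(0)] = p
    return batch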

@qiqiApink
Author

qiqiApink commented Apr 24, 2023

I tried the code from before the changes, but it doesn't generate the right output. If I use the current code, the answer is correct. How can I fix this? I also found that, for the same prompt, the logits output by the model differ between single generation and batched generation.

@timothylimyl
Contributor

We can definitely add batch generation, but we will need to change all of the generation_x.py files. Both the old and new scripts expect the generation prompt to be a str rather than a list[str].
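One low-friction option might be to normalize the prompt argument at the top of each script, something like this (just a sketch; normalize_prompts is a hypothetical helper and the tokenizer.encode signature is an assumption):

from typing import List, Union

import torch

def normalize_prompts(prompt: Union[str, List[str]], tokenizer) -> List[torch.Tensor]:
    """Accept a single prompt or a list of prompts and return a list of encoded
    1-D token tensors. Hypothetical helper; assumes tokenizer.encode(str) returns a tensor."""
    prompts = [prompt] if isinstance(prompt, str) else list(prompt)
    return [tokenizer.encode(p) for p in prompts]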

@lucas-ventura
Contributor

lucas-ventura commented May 7, 2023

What do you think of this implementation? I still need to improve it and use max_new_tokens, which I'm not using at the moment.

from typing import List, Optional

import torch


@torch.no_grad()
def generate_batch(
    model: torch.nn.Module,
    idx: List[torch.Tensor],
    max_new_tokens: int,
    max_seq_length: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> List[torch.Tensor]:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    The implementation of this function is modified from A. Karpathy's nanoGPT.
    Args:
        model: The model to use.
        idx: List of 1-D tensors, each holding the token indices of one prompt (prompts may have different lengths).
        max_new_tokens: The number of new tokens to generate.
        max_seq_length: The maximum sequence length allowed.
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        eos_id: If specified, stop generating a sequence once this token is sampled.
    """

    # NOTE: max_new_tokens is currently unused; generation runs until max_seq_length or eos_id
    # create an empty tensor of the expected final shape and fill in the current tokens
    batch_size = len(idx)
    empty = torch.empty(
        batch_size, max_seq_length, dtype=idx[0].dtype, device=idx[0].device
    )
    indices = torch.zeros(batch_size, dtype=torch.int32, device=idx[0].device)
    for i in range(batch_size):
        empty[i, : idx[i].shape[0]] = idx[i]
        indices[i] = idx[i].shape[0]
    idx = empty

    min_idx = int(torch.min(indices))
    for t in range(min_idx, max_seq_length):
        # Select the rows whose next token should be generated at position t
        rows = indices == t
        if not rows.any():
            if (indices < t).all():
                break  # every remaining sequence has already ended with eos_id
            continue  # the other sequences have longer prompts and are not ready yet
        idx_cond = idx[rows, :t]

        # forward pass over the selected rows (the full prefix is re-processed every step)
        logits = model(idx_cond)
        logits = logits[:, -1] / temperature

        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, k=top_k, dim=-1)
            min_v = v[:, -1].unsqueeze(-1)
            logits[logits < min_v] = -float("Inf")

        probs = torch.nn.functional.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)

        # concatenate the new generation
        idx_next = idx_next.view(-1)
        idx[rows, t] = idx_next.int()

        indices[rows] += 1
        if eos_id is not None:
            # freeze rows that just produced eos_id: their index stays at the eos position
            indices[rows & (idx[:, t] == eos_id)] -= 1

    # trim each row to its valid tokens (prompt + generated, including a trailing eos if present)
    return [out_idx[: out_ind + 1] for out_idx, out_ind in zip(idx, indices)]
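For context, rough usage would look something like this (the tokenizer and model setup are whatever the script already provides; the encode/decode and eos_id calls are assumptions, not verified APIs):

# Rough usage sketch; placeholder prompts, assumed tokenizer API.
prompts = ["Hello, my name is", "The capital of France is"]
encoded = [tokenizer.encode(p) for p in prompts]  # move to the model's device if needed

outputs = generate_batch(
    model,
    encoded,
    max_new_tokens=50,  # currently unused, see note above
    max_seq_length=max(e.size(0) for e in encoded) + 50,
    temperature=0.8,
    top_k=200,
    eos_id=tokenizer.eos_id,
)
for out in outputs:
    print(tokenizer.decode(out))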

@AngainorDev

What do you think of this implementation?

Thanks for this!
I was trying to get batch inference working myself, hoping for a lower inference time.

I used your code and it works well (it could be improved further with batched encode/decode by also modifying the tokenizer), but I find it even slower than sequential inference.
Sequential: 33 tok/sec; batched: 22 tok/sec.

I was expecting batched inference to be faster.
Do you experience the same? Do you have an idea why? Padding, perhaps?

@carmocca
Contributor

@AngainorDev Did you update the token count to account for the fact that it's batched? This line: https://github.com/Lightning-AI/lit-llama/blob/main/generate.py#LL150

It currently computes the length of 1 sequence (T). With batched generation you would need to consider that each sequence can have a different T, and might want to discount padding (or not if you just care about raw speed).
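For example, a sketch of a batch-aware count (reusing the generate_batch proposal above; `encoded` stands for the list of prompt tensors):

import time

# `encoded` is the list of prompt tensors passed to generate_batch (sketch only).
t0 = time.perf_counter()
outputs = generate_batch(model, encoded, max_new_tokens=50, max_seq_length=512)
elapsed = time.perf_counter() - t0

# Count only tokens that were actually generated, per sequence (prompt and padding excluded).
new_tokens = sum(out.size(0) - prompt.size(0) for out, prompt in zip(outputs, encoded))
print(f"{new_tokens / elapsed:.2f} tokens/sec across the batch")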

@AngainorDev

Thanks,

Yes, I did use the real token counts; tokens/sec is consistent with the wall-clock generation time.
A single generation in my case takes 5 sec, while a batch of 4 (even 4 identical prompts) takes about 30 sec.

Side note that may matter: I'm testing on lit-parrot with the Falcon model.

Everything runs fine and the batching works, but generation time is much worse than sequential (and yes, I'm after the best possible generation time).

I also tested with the original batched code (which takes a tensor as input instead of a list of tensors) and it's the same: batched generation takes more time than N separate single generations.
(I've used a batch size of 4 so far.)

Did you compare on LLaMA? Do you get any performance improvement from batched generation, or could it be a Falcon-only issue?

@carmocca
Contributor

We don't have batched inference implemented for LLaMA either. If the problem were specific to the Falcon architecture, you could check by running with StableLM or Pythia weights.

One critical part of fast inference is KV caching. I'm not sure whether our implementation works out of the box with batched inference; we would need to check whether any changes are required.
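To illustrate what batching means for the cache (a generic sketch of the idea only, not our actual implementation):

import torch

# Generic per-layer KV cache with a leading batch dimension (illustrative shapes only):
batch_size, n_heads, max_seq_length, head_dim = 4, 32, 2048, 64
k_cache = torch.zeros(batch_size, n_heads, max_seq_length, head_dim)
v_cache = torch.zeros(batch_size, n_heads, max_seq_length, head_dim)

# At generation step `pos`, attention writes the new key/value into
# k_cache[:, :, pos] / v_cache[:, :, pos] and attends over cache[:, :, : pos + 1].
# With batching, sequences with different prompt lengths need per-sequence
# positions (or left padding) so that cache slots line up across the batch.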

The easiest way for us to help would be for you to open a PR with your changes so that we can take a look.
