# Generate with `top_p` (nucleus) sampling, stopping at the end-of-text token, for `n_rep` parallel samples

In [None]:
def generate_top_p(model, tokenizer, prompt, n_rep=3, max_seq_len=256, T=0.9, top_p=0.9, device='cuda', seed=42):
    """Generate ``n_rep`` continuations of ``prompt`` with top-p (nucleus) sampling.

    All ``n_rep`` copies of the prompt are advanced in one batch. At every
    step the next-token distribution is restricted to the smallest set of
    most-likely tokens whose cumulative probability reaches ``top_p`` and a
    token is drawn from that set. A sequence stops as soon as it samples
    ``'<|endoftext|>'``; generation ends when every sequence has stopped or
    the sequences are ``max_seq_len`` tokens long (prompt included).

    Args:
        model: Callable mapping a ``[n_rep, seq_len]`` tensor of token ids to
            logits of shape ``[n_rep, seq_len, vocab_size]``. It is switched
            to eval mode via ``model.eval()`` and left that way.
        tokenizer: Object providing ``encode(text).ids``, ``token_to_id`` and
            ``decode_batch``.
        prompt: Text every generated sequence starts from.
        n_rep: Number of sequences to sample in parallel.
        max_seq_len: Maximum total sequence length in tokens.
        T: Softmax temperature; must be positive.
        top_p: Cumulative-probability threshold of the nucleus. The most
            likely token is always kept, so ``top_p=0`` is greedy decoding.
        device: Device on which tensors and the random generator are created.
        seed: Seed of the sampling generator.

    Returns:
        A list of ``n_rep`` decoded strings. Each one is cut just before the
        first end-of-text token that was generated (tokens that belong to the
        prompt are never cut).

    Raises:
        ValueError: If ``T`` is not positive or ``prompt`` encodes to no tokens.
    """
    if T <= 0:
        raise ValueError(f"T must be positive, got {T}")

    prompt_ids = tokenizer.encode(prompt).ids
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token")
    prompt_len = len(prompt_ids)

    # int64 from the start: the sampled indices are int64, so concatenating
    # them onto an int32 prompt would silently change the dtype the model
    # sees after the first step.
    inputs = torch.tensor(prompt_ids, dtype=torch.long, device=device)  # Shape: [T]
    inputs = inputs.unsqueeze(0).repeat(n_rep, 1)  # Shape: [n_rep, T]

    # May be None when the tokenizer has no such token; sequences then only
    # stop at max_seq_len.
    end_token_id = tokenizer.token_to_id('<|endoftext|>')
    has_end_token = end_token_id is not None

    model.eval()
    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(seed)

    # Track which sequences have finished (hit end token)
    finished = torch.zeros(n_rep, dtype=torch.bool, device=device)

    with torch.no_grad():
        while inputs.shape[-1] < max_seq_len and not bool(finished.all()):
            logits = model(inputs)  # Shape: [n_rep, T, vocab_size]

            # Next-token distribution; float32 keeps the cumulative sum
            # accurate even if the model emits half-precision logits.
            probs = torch.softmax(logits[:, -1, :].float() / T, dim=-1)  # Shape: [n_rep, vocab_size]

            sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

            # Keep a token when the mass *before* it is still below top_p, so
            # the token that crosses the threshold is part of the nucleus.
            keep = (cumulative_probs - sorted_probs) < top_p
            # Ensure we always have at least one token
            keep[:, 0] = True

            # torch.multinomial accepts unnormalised rows, so zeroing the
            # rejected tokens is enough; one batched draw covers every row.
            nucleus_probs = sorted_probs.masked_fill(~keep, 0.0)
            choice = torch.multinomial(nucleus_probs, num_samples=1, generator=sample_rng)
            next_tokens = sorted_indices.gather(-1, choice)  # Shape: [n_rep, 1]

            if has_end_token:
                # Finished sequences are padded with the end token.
                next_tokens = next_tokens.masked_fill(finished.unsqueeze(-1), end_token_id)
                finished |= next_tokens.squeeze(-1) == end_token_id

            inputs = torch.cat((inputs, next_tokens), dim=-1)

    # Cut each sequence just before the first *generated* end token; searching
    # from prompt_len keeps an end token inside the prompt from truncating it.
    final_outputs = []
    for sequence in inputs.tolist():
        generated = sequence[prompt_len:]
        if has_end_token and end_token_id in generated:
            sequence = sequence[:prompt_len + generated.index(end_token_id)]
        final_outputs.append(sequence)

    # Decode all generated sequences
    generated_texts = tokenizer.decode_batch(final_outputs)
    return generated_texts

In [None]:
# Demo cell: sample three continuations of a short prompt and render each one
# as colour-coded HTML (label, prompt, generated continuation).
from html import escape  # local import: text must be escaped before it is embedded in HTML

prompt = 'in last'
generated_texts = generate_top_p(model, tokenizer, prompt, n_rep=3, max_seq_len=256, T=0.9, top_p=0.9, device='cuda', seed=42)
print('Generate top_p End_of_token:')
print()
for i, text in enumerate(generated_texts):
    # NOTE(review): slicing by len(prompt) assumes the decoded text starts with
    # the prompt verbatim — confirm the tokenizer round-trips the prompt.
    continuation = text[len(prompt):]
    # Escape both parts so that '<', '>' and '&' in the text are shown
    # literally instead of being parsed (and hidden) as HTML markup.
    display(HTML(f"<span style='color: yellow;'>Generated {i+1}:</span> <span style='color: cyan;'>{escape(prompt)}</span><span style='color: White;'>{escape(continuation)}</span>"))
    print('-'*150)

Generate top_p End_of_token:



------------------------------------------------------------------------------------------------------------------------------------------------------


------------------------------------------------------------------------------------------------------------------------------------------------------


------------------------------------------------------------------------------------------------------------------------------------------------------
