# In [None]:
import torch
import numpy as np

def top_p_sampling(probabilities, top_p=0.90):
    """Sample one token index using nucleus (top-p) sampling.

    The nucleus is the smallest set of highest-probability tokens whose
    cumulative probability reaches ``top_p``. The token whose probability
    pushes the running total up to or past ``top_p`` is part of that set,
    and the most likely token is always kept.

    Args:
        probabilities: 1-D, non-empty tensor of token probabilities. It is
            expected to be a valid distribution (non-negative, summing to
            1); this is not checked.
        top_p: Cumulative probability mass the nucleus must reach.

    Returns:
        A tuple ``(index, probability)`` where ``index`` is the position of
        the sampled token in ``probabilities`` (a Python int) and
        ``probability`` is that token's original, un-renormalized
        probability rounded to two decimals (a Python float).

    Raises:
        ValueError: If ``probabilities`` is empty or not 1-D.
    """
    if probabilities.dim() != 1 or probabilities.numel() == 0:
        raise ValueError(
            "probabilities must be a non-empty 1-D tensor, got shape "
            f"{tuple(probabilities.shape)}"
        )

    # step 1: rank tokens from most to least likely
    sorted_probs, sorted_indices = torch.sort(probabilities, descending=True)

    # step 2: running total of probability mass in ranked order
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # step 3: keep a token while the mass of the tokens ranked *before* it is
    # still below top_p. Comparing the preceding mass (rather than the mass
    # including the token itself) keeps the token that crosses the threshold,
    # so the kept set actually covers at least top_p. The first token has no
    # predecessors and is therefore always kept.
    # NOTE: the comparison is done on floating-point sums, so a total that
    # lands exactly on top_p may fall on either side after rounding.
    keep_mask = torch.ones_like(sorted_probs, dtype=torch.bool)
    keep_mask[1:] = cumulative_probs[:-1] < top_p

    allowed_probs = sorted_probs[keep_mask]
    allowed_indices = sorted_indices[keep_mask]

    # step 4: renormalize the kept probabilities so they sum to 1
    allowed_probs = allowed_probs / allowed_probs.sum()

    # sample a position within the nucleus, then map it back to the
    # token's index in the original (unsorted) tensor
    sample_idx = torch.multinomial(allowed_probs, num_samples=1)
    sample_idx = allowed_indices[sample_idx].item()

    return sample_idx, np.round(probabilities[sample_idx].item(), decimals=2).item()


# Demo: run nucleus sampling once over a toy six-token distribution.
prob = torch.tensor([0.15, 0.10, 0.40, 0.25, 0.06, 0.04])
# The second return value is the sampled token's rounded probability, not a
# token string; the variable name `token` is kept for any later code that
# reads it, but the printed label says what the value actually is.
idx, token = top_p_sampling(prob)
print(f'Token index: {idx}, Token probability: {token}')