In [3]:
import torch
from torch.utils.data import WeightedRandomSampler

In [149]:
def geometric_sampling(p, size=1, max_n=10):
    # Generate probabilities for integers 1 to n
    probs = (1 - p) * p ** torch.arange(max_n, dtype=torch.float32)
    probs /= probs.sum()  # Normalize the probabilities to sum to 1
    
    # Calculate the total number of samples required
    num_samples = torch.prod(torch.tensor(size)).item()
    
    # Sample integers with the calculated probabilities
    samples = torch.multinomial(probs, num_samples=num_samples, replacement=True) + 1
    
    # Reshape the samples to the specified size
    samples = samples.view(size)
    return samples

# Example usage
p = 0.6
samples = geometric_sampling(p, size=(2, 1, 10, 10))
# Count the occurrences of each integer
counts = torch.bincount(samples, minlength=21)[1:]
print(counts)

tensor([402, 240, 144,  74,  62,  35,  23,   9,   6,   5,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0])


In [64]:
rates = torch.rand(4, 4) * 0.95
torch.poisson(rates)

tensor([[0., 1., 1., 1.],
        [0., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 1., 1., 0.]])

In [78]:
# torch.distributions.binomial.Binomial(total_count=5, probs=0.9).sample()
torch.round(torch.exp(torch.rand(4, 4)))

tensor([[2., 2., 2., 2.],
        [1., 1., 2., 2.],
        [3., 2., 1., 2.],
        [2., 2., 1., 2.]])

In [122]:
def build_sharekv_mask(weight, N):
    bs, _, _, L = weight.shape
    
    # Create a random width tensor for each row (values between 1 and N)
    random_widths = torch.randint(0, N, (bs, L))
    
    # Create a mask for the diagonal stripes
    index = torch.arange(L).unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, L)
    row_indices = torch.arange(L).unsqueeze(1).unsqueeze(0)  # Shape: (1, L, 1)

    # Create a mask using broadcasting to set the diagonal stripes properly
    mask = (index >= row_indices - random_widths.unsqueeze(2)).unsqueeze(1)

    # Convert the mask to a lower triangular matrix
    mask = torch.tril(mask)

    return mask

In [120]:
import torch

# Define the size of the mask (e.g., L x L)
batch_size = 2
L = 10
N = 5  # Maximum width of the diagonal stripe

# Create a random width tensor for each row (values between 1 and N)
random_widths = torch.randint(0, N, (batch_size, L))

# mask = torch.arange(L).unsqueeze(0) >= torch.arange(L).unsqueeze(1) - random_widths.unsqueeze(1)
index = torch.arange(L).unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, L)
row_indices = torch.arange(L).unsqueeze(1).unsqueeze(0)  # Shape: (1, L, 1)

# Create a mask using broadcasting to set the diagonal stripes properly
mask = (index >= row_indices - random_widths.unsqueeze(2)).unsqueeze(1)

# Convert the mask to a lower triangular matrix
mask = torch.tril(mask)

print(mask.shape)
print(mask)

torch.Size([2, 1, 10, 10])
tensor([[[[ True, False, False, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False, False, False],
          [ True,  True,  True, False, False, False, False, False, False, False],
          [ True,  True,  True,  True, False, False, False, False, False, False],
          [False, False, False,  True,  True, False, False, False, False, False],
          [False, False, False,  True,  True,  True, False, False, False, False],
          [False, False, False, False, False, False,  True, False, False, False],
          [False, False, False, False, False, False,  True,  True, False, False],
          [False, False, False, False, False,  True,  True,  True,  True, False],
          [False, False, False, False, False, False,  True,  True,  True,  True]]],


        [[[ True, False, False, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False, 

In [1]:
(None, )

(None,)

In [121]:
torch.arange(L).unsqueeze(0).unsqueeze(0).shape, torch.arange(L).unsqueeze(1).unsqueeze(0).shape

(torch.Size([1, 1, 10]), torch.Size([1, 10, 1]))

In [104]:
import torch

# Define the size of the mask (e.g., L x L)
L = 10
N = 10  # Maximum width of the diagonal stripe

# Create a random width tensor for each row (values between 1 and N)
random_widths = torch.randint(0, N, (L,))

# Create a mask using broadcasting to set the diagonal stripes properly
mask = torch.arange(L).unsqueeze(0) >= (torch.arange(L) - random_widths).unsqueeze(1)

# Convert the mask to a lower triangular matrix
mask = torch.tril(mask)

print(mask)

tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False],
        [False,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False],
        [False, False, False, False,  True,  True, False, False, False, False],
        [False,  True,  True,  True,  True,  True,  True, False, False, False],
        [False, False, False,  True,  True,  True,  True,  True, False, False],
        [False, False, False,  True,  True,  True,  True,  True,  True, False],
        [False, False, False, False,  True,  True,  True,  True,  True,  True]])


In [90]:
torch.arange(L).unsqueeze(0)

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [89]:
torch.arange(L).unsqueeze(1) - random_widths.unsqueeze(1)

tensor([[-4],
        [ 0],
        [-5],
        [ 1],
        [ 4],
        [ 0],
        [ 6],
        [ 2],
        [ 2],
        [ 5]])

In [91]:
torch.arange(L).unsqueeze(0) >= torch.arange(L).unsqueeze(1) - random_widths.unsqueeze(1)

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True]])

In [67]:
# Example usage:
bs = 1
num_heads = 2
hidden_dim = 5 # number of tokens per bin
sharedkv_N = 3

# Random tensor for demonstration purposes
attn_weights = torch.rand(bs, num_heads, hidden_dim, hidden_dim)
v_weights = torch.rand(bs, num_heads, hidden_dim, 8)

sharedkv_mask = torch.ones((bs, 1, hidden_dim, hidden_dim), dtype=torch.bool)
sharedkv_mask = torch.tril(sharedkv_mask, diagonal=-sharedkv_N) ^ torch.tril(sharedkv_mask, diagonal=0)
# sharedkv_mask = torch.triu(sharedkv_mask, diagonal=1-sharedkv_N) ^ torch.triu(sharedkv_mask, diagonal=1)
inv_sharedkv_mask = ~sharedkv_mask
print(sharedkv_mask)
print(inv_sharedkv_mask)

tensor([[[[ True, False, False, False, False],
          [ True,  True, False, False, False],
          [ True,  True,  True, False, False],
          [False,  True,  True,  True, False],
          [False, False,  True,  True,  True]]]])
tensor([[[[False,  True,  True,  True,  True],
          [False, False,  True,  True,  True],
          [False, False, False,  True,  True],
          [ True, False, False, False,  True],
          [ True,  True, False, False, False]]]])


In [26]:
print(sharedkv_mask)
print(inv_sharedkv_mask)

tensor([[[[ True, False, False, False, False],
          [False,  True, False, False, False],
          [False, False,  True, False, False],
          [False, False, False,  True, False],
          [False, False, False, False,  True]]]])
tensor([[[[False,  True,  True,  True,  True],
          [ True, False,  True,  True,  True],
          [ True,  True, False,  True,  True],
          [ True,  True,  True, False,  True],
          [ True,  True,  True,  True, False]]]])


In [21]:
triu_inv_sharedkv_mask

tensor([[0., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1.],
        [1., 0., 0., 1., 1., 1., 1., 1.],
        [1., 1., 0., 0., 1., 1., 1., 1.],
        [1., 1., 1., 0., 0., 1., 1., 1.]])

In [7]:
import torch

def count_top_k_bins(output_prob, k):
    # n: number of bins, vocab_dim: number of tokens in each bin
    n, vocab_dim = output_prob.shape
    
    # Flatten the tensor to treat it as a single list of tokens
    flattened_probs = output_prob.flatten()  # Shape: (n * vocab_dim)
    
    # Get the indices of the top k values
    topk_values, topk_indices = torch.topk(flattened_probs, k)
    
    # Convert flat indices back to bin indices
    topk_bin_ids = topk_indices // vocab_dim  # Dividing by vocab_dim gives the bin index

    # Count occurrences of each bin
    bin_counts = torch.bincount(topk_bin_ids, minlength=n)
    
    return bin_counts

# Example usage:
n = 5  # number of bins
vocab_dim = 10  # number of tokens per bin
k = 8  # top k tokens to select

# Random tensor for demonstration purposes
output_prob = torch.rand(n, vocab_dim)

# Get the count of top-k tokens from each bin
bin_counts = count_top_k_bins(output_prob, k)
print(bin_counts)

tensor([0, 1, 3, 2, 2])


In [2]:
def assign_balls_to_baskets(P: torch.Tensor, k: int) -> torch.Tensor:
    """
    Assign k balls into n baskets with probability P for each basket.

    Args:
        P (torch.Tensor): A 1D tensor of size n representing the probability distribution of baskets.
        k (int): The number of balls to assign.

    Returns:
        torch.Tensor: A 1D tensor of size n representing the number of balls in each basket.
    """
    # Ensure the probabilities sum to 1
    P = P / P.sum()
    
    # Sample `k` balls according to the probability distribution
    ball_assignments = torch.multinomial(P, k, replacement=True)
    
    # Count the number of balls in each basket
    ball_counts = torch.bincount(ball_assignments, minlength=P.size(0))
    
    return ball_counts

# Example usage
P = torch.tensor([0.1, 0.2, 0.3, 0.25, 0.15])  # Probability distribution for the baskets
n = P.size(0)  # number of baskets
k = 10  # number of balls

ball_distribution = assign_balls_to_baskets(P, k)
print(ball_distribution)

tensor([0, 3, 4, 3, 0])


In [6]:
%%time
ball_distribution = assign_balls_to_baskets(P, k)

CPU times: user 194 μs, sys: 297 μs, total: 491 μs
Wall time: 351 μs


In [14]:
ball_distribution

tensor([1, 3, 3, 3, 0])

In [13]:
ball_distribution + ball_distribution.sum() - ball_distribution.cumsum(dim=0)

tensor([10,  9,  6,  3,  0])

In [10]:
ball_distribution.cumsum(dim=0)

tensor([ 1,  4,  7, 10, 10])

In [5]:
P_new = torch.softmax(P, dim=0)

In [4]:
print(P_new)
print(P)

tensor([0.1805, 0.1995, 0.2205, 0.2097, 0.1898])
tensor([0.1000, 0.2000, 0.3000, 0.2500, 0.1500])


In [12]:
# x = torch.rand(2, 5, 512) # torch.Size([2, 5, 512])


In [5]:
mask = torch.tensor([[1, 0, 1, 0, 1], [1, 1, 1, 0, 1]], dtype=torch.bool)
print(mask.shape) # torch.Size([2, 5])

x = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]], [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]]], dtype=torch.float32)
print(x.shape)
bs, seq_len, dim = x.shape

x = x[mask]
print(x)



torch.Size([2, 5])
torch.Size([2, 5, 3])
tensor([[ 1.,  2.,  3.],
        [ 7.,  8.,  9.],
        [13., 14., 15.],
        [ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [13., 14., 15.]])


In [6]:
x.shape

torch.Size([6, 512])