In [1]:
import numpy as np
# MAPE of each run, grouped by evaluation split.
seen = [0.1185, 0.1253, 0.1257]
unseen = [0.1872, 0.1873, 0.1863]

# Report mean ± standard deviation per split (np.std default, i.e. ddof=0).
for split_name, split_scores in (("seen", seen), ("unseen", unseen)):
    print(f'{split_name}| {np.mean(split_scores)}±{np.std(split_scores)}')

seen| 0.12316666666666666±0.0033038697848970346
unseen| 0.1869333333333333±0.0004496912521077373


In [13]:
import torch
import torch.nn.functional as F
from layers.MOE_dispatcher import MOEDispatcher
import pickle
import torch

def top_p_mask(logits, p=0.9):
    """
    Build a 0/1 mask marking the top-p (nucleus) entries along the last dimension of logits.

    Entries are ranked by softmax probability. An entry is selected when the
    probability mass of all higher-ranked entries is still below ``p``; this
    keeps the entry that pushes the cumulative probability past ``p``, so the
    selected set covers at least ``p`` of the mass (up to floating-point
    rounding). The highest-probability entry is always selected.

    Parameters:
    - logits: A tensor of shape [B, N] where B is batch size and N is the number of logits.
      Any shape with at least one (non-empty) dimension is accepted; the
      selection is made along the last dimension.
    - p: The cumulative probability threshold.

    Returns:
    - mask: An int32 tensor with the same shape as logits, with 1s for selected
      indices and 0s for unselected indices.
    """
    # Calculate probabilities using softmax
    probabilities = torch.softmax(logits, dim=-1)

    # Sort probabilities and corresponding indices in descending order
    sorted_probs, sorted_indices = torch.sort(probabilities, dim=-1, descending=True)

    # Calculate cumulative probabilities
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Probability mass accumulated strictly BEFORE each ranked entry. Thresholding
    # on this (instead of on the inclusive cumulative sum) keeps the first entry
    # whose inclusion makes the cumulative probability reach or exceed p.
    mass_before = cumulative_probs - sorted_probs
    mask = mass_before < p

    # Ensure we always select at least one element, even when p <= 0
    mask[..., 0] = True

    # Initialize a mask tensor with zeros
    output_mask = torch.zeros_like(logits, dtype=torch.int)

    # Map the decisions from sorted order back to the original positions
    output_mask.scatter_(-1, sorted_indices, mask.int())

    return output_mask

# Demo: run top_p_mask on random logits (batch of 3 rows, 10 classes each).
# Note: `indices` holds the 0/1 mask returned by top_p_mask, not index values.
p = 0.9
logits = torch.randn(3, 10)

indices = top_p_mask(logits, p)
print("Selected indices:", indices, sep="\n")

Selected indices:
tensor([[1, 1, 0, 1, 0, 1, 1, 0, 0, 1],
        [0, 1, 1, 0, 1, 1, 1, 1, 0, 1],
        [1, 0, 1, 0, 1, 1, 1, 0, 0, 1]], dtype=torch.int32)


In [3]:
import math
def compute_pooled_length(stride=2, kernel_size=2, padding=0, dilation=1,
                          input_length=100, num_layers=2):
    """
    Compute the sequence length left after repeatedly applying the same 1-D window.

    Each layer maps a length L to
        floor((L + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
    and the result of one layer is fed into the next.

    Parameters:
    - stride, kernel_size, padding, dilation: window hyper-parameters, shared by every layer.
    - input_length: length before the first layer (default 100).
    - num_layers: how many times the window is applied (default 2).

    Returns:
    - The length after ``num_layers`` applications; ``input_length`` itself
      when ``num_layers`` is 0.
    """
    length = input_length
    for _ in range(num_layers):
        length = math.floor((length + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1)
    return length

compute_pooled_length()

25

In [6]:
# Valid lengths of the 4 sequences; every sequence is padded to 100 positions.
lengths = torch.tensor([1, 20, 30, 40])
# Position ids 0..99, repeated (as a view) for each of the 4 rows.
range_tensor = torch.arange(100).expand(4, 100)
# 1 where the position index is >= the row's length, 0 otherwise.
attention_mask = (range_tensor >= lengths[:, None]).int()
attention_mask

tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [2]:
# Top-k gating over the expert dimension (dim=1).
# NOTE(review): `logits`, `top_k` and `eps` must already exist in the kernel;
# none of them is defined in this cell.
_, indices = logits.topk(top_k, dim=1)
# Boolean mask that is True at the top-k positions of every row
# (False marks an entry that gets zeroed out below).
mask = torch.zeros_like(logits, dtype=torch.bool)
mask.scatter_(1, indices, 1)
# Softmax over all experts first; keep an untouched copy of the full distribution.
logits = torch.softmax(logits, dim=1)  # [B, num_experts]
raw_logits = logits.clone()
# Zero the non-selected experts, then renormalise each row so the kept
# weights sum to (almost) 1; eps guards the division.
logits = mask * logits
de_norm = logits.sum(dim=1) + eps
logits = logits / de_norm[:, None]
logits

tensor([[0.4841, 0.5159, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.4994, 0.5006],
        [0.5037, 0.0000, 0.0000, 0.4963]])

In [3]:
# Build a dispatcher from the `logits` tensor currently in the kernel namespace.
# NOTE(review): the meaning of the two integer arguments (2, 2) is defined by
# MOEDispatcher in layers/MOE_dispatcher.py, which is not visible here — confirm.
my_dispatcher = MOEDispatcher(2, logits, 2)

In [4]:
# Show the dispatcher's internal bookkeeping attributes, then the result of dispatch().
for attr_name in ("_batch_index", "_nonzero_gates", "_part_sizes"):
    print(attr_name, getattr(my_dispatcher, attr_name))
print(my_dispatcher.dispatch())

_batch_index tensor([0, 2, 0, 1, 1, 2])
_nonzero_gates tensor([[0.4841],
        [0.5037],
        [0.5159],
        [0.4994],
        [0.5006],
        [0.4963]])
_part_sizes [2, 1, 1, 2]
(tensor([0, 2]), tensor([0]), tensor([1]), tensor([1, 2]))


In [5]:
def rearrange_tensor(X, _batch_index, k=2):
    """
    Group the rows of X by the sample id stored in _batch_index.

    Row j of X belongs to sample ``_batch_index[j]``. Every sample id in
    ``0 .. B-1`` (with ``B = len(_batch_index) // k``) must occur exactly
    ``k`` times. The rows of each sample are gathered together, keeping the
    order in which they appear in X.

    Parameters:
    - X: tensor whose first dimension has one row per entry of _batch_index.
    - _batch_index: 1-D tensor (or sequence) of sample ids, one per row of X.
    - k: number of rows per sample (default 2).

    Returns:
    - Y: tensor of shape [B, k, *X.shape[1:]] where Y[i] stacks the k rows of sample i.

    Raises:
    - ValueError: if k is not positive, _batch_index is not 1-D, or the ids are
      not exactly k copies of each of 0 .. B-1. (The previous loop silently
      duplicated a row when an id occurred only once.)
    """
    _batch_index = torch.as_tensor(_batch_index)
    if k <= 0:
        raise ValueError(f"k must be a positive integer, got {k}")
    if _batch_index.dim() != 1:
        raise ValueError("_batch_index must be a 1-D tensor")
    if len(_batch_index) % k != 0:
        raise ValueError(
            f"len(_batch_index)={len(_batch_index)} is not a multiple of k={k}"
        )

    # Get the batch size B
    B = len(_batch_index) // k

    # A single stable sort groups equal ids together while preserving the
    # original row order inside each group; this replaces one torch.where
    # scan per sample.
    sorted_ids, order = torch.sort(_batch_index, dim=0, stable=True)

    # After sorting, the ids must read 0,..,0,1,..,1,...,B-1 (k copies each).
    expected_ids = torch.arange(B, device=sorted_ids.device).repeat_interleave(k)
    if not bool((sorted_ids == expected_ids).all()):
        raise ValueError(
            f"_batch_index must contain each id in 0..{B - 1} exactly {k} times"
        )

    # indices[i] holds the k row positions of sample i
    indices = order.view(B, k).to(X.device)

    # Use advanced indexing to gather the groups
    return X[indices]

# Smoke test: 6 random rows belonging to 3 samples (two rows per sample);
# only the shape of the regrouped tensor is displayed.
X = torch.rand(6, 12)
_batch_index = torch.tensor([1, 0, 2, 2, 0, 1], dtype=torch.float32)
rearrange_tensor(X, _batch_index).shape

torch.Size([3, 2, 12])