In [1]:
import torch

In [5]:
def apply_masks(x, masks):
    """
    :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
    :param masks: list of tensors containing indices of patches in [N] to keep
    """
    all_x = []
    for m in masks:
        mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
        all_x += [torch.gather(x, dim=1, index=mask_keep)]
    return torch.cat(all_x, dim=0)

In [3]:
x = torch.tensor([
    [
        [1,2],
        [1,2]
    ],
    [
        [3, 4],
        [5, 6]
    ]
])

In [4]:
x.shape

torch.Size([2, 2, 2])

In [19]:
masks = torch.tensor([
    [0,1],
    [1,1]
])

for m in masks:
    print(m.unsqueeze(-1).repeat(1,1,5))


tensor([[[0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1]]])
tensor([[[1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1]]])


In [6]:
x = torch.tensor([
    [
        [1, 1, 1],
        [2, 2, 2], 
        [3, 3, 3], 
        [4, 4, 4]
    ],  # Batch 1 patches
    [
        [5, 5, 5], 
        [6, 6, 6], 
        [7, 7, 7], 
        [8, 8, 8]
    ]   # Batch 2 patches
])

# masks: list of tensors with indices of patches to keep in each batch
masks = [
    torch.tensor([
        [1,2],
        [2,3]
    ])
]

In [7]:
apply_masks(x, masks)

tensor([[[2, 2, 2],
         [3, 3, 3]],

        [[7, 7, 7],
         [8, 8, 8]]])

In [38]:
x.shape

torch.Size([2, 4, 3])

In [39]:
apply_masks(x, masks)

tensor([[[1, 1, 1],
         [3, 3, 3]],

        [[2, 2, 2],
         [4, 4, 4]]])

In [27]:
torch.gather(x, dim=1, index=torch.tensor([[[1, 0, 0],
         [2, 2, 2]]]))

tensor([[[2, 1, 1],
         [3, 3, 3]]])

In [27]:
x = torch.tensor([
    [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],  # Image 1
    [[13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24]]  # Image 2
])  # Shape: [2 (B), 4 (N), 3 (D)]


# Mask to keep patches 0 and 2 from Image 1 and 1 and 3 from Image 2
mask_1 = torch.tensor([0,2])  # Keep patches 0 and 2 from Image 1
mask_2 = torch.tensor([1, 1])  # Keep patches 1 and 3 from Image 2

masks = [mask_1, mask_2]


result = apply_masks(x, masks)
print(result)
tensor([
    [[ 1,  2,  3],
         [ 7,  8,  9]],

        [[13, 14, 15],
         [19, 20, 21]],

        [[ 4,  5,  6],
         [ 4,  5,  6]],

        [[16, 17, 18],
         [16, 17, 18]]
])


tensor([[[ 1,  2,  3],
         [ 7,  8,  9]],

        [[13, 14, 15],
         [19, 20, 21]],

        [[ 4,  5,  6],
         [ 4,  5,  6]],

        [[16, 17, 18],
         [16, 17, 18]]])


NameError: name 'tensor' is not defined

In [19]:
x = torch.randn(2, 8, 3)
masks = [torch.tensor([1, 3, 5]), torch.tensor([0, 2, 6])]

# Apply masks
result = apply_masks(x, masks)

print(result.shape)  # torch.Size([4, 3, 3])

torch.Size([4, 3, 3])


In [16]:
x

tensor([[[ 0.3405,  1.0730,  0.8639],
         [-0.6915, -0.8233,  2.2888],
         [-0.6145, -0.4392,  0.7010],
         [-1.0011,  0.4583,  1.0431],
         [ 0.3076, -0.8169, -0.8629],
         [ 0.5445,  2.0003, -1.6975],
         [-1.8074,  0.6449,  1.0224],
         [ 0.1282,  1.1422, -0.9853]],

        [[ 0.1313, -1.7302,  0.8081],
         [-0.1802,  0.7888,  0.5702],
         [-1.0691, -0.4204,  0.0851],
         [ 1.3423,  1.1664, -0.9405],
         [-1.6101, -0.9445,  0.6677],
         [ 2.7764,  0.1288, -0.6273],
         [-0.0263, -0.6501,  0.2953],
         [-0.5329, -0.3462,  0.7117]]])

In [17]:
result

tensor([[[-0.6915, -0.8233,  2.2888],
         [-1.0011,  0.4583,  1.0431],
         [ 0.5445,  2.0003, -1.6975]],

        [[ 0.3405,  1.0730,  0.8639],
         [-0.6145, -0.4392,  0.7010],
         [-1.8074,  0.6449,  1.0224]]])

In [24]:
a = torch.tensor([[1,2,3],[4,5,6]])

In [25]:
a.repeat(5,1,1)

tensor([[[1, 2, 3],
         [4, 5, 6]],

        [[1, 2, 3],
         [4, 5, 6]],

        [[1, 2, 3],
         [4, 5, 6]],

        [[1, 2, 3],
         [4, 5, 6]],

        [[1, 2, 3],
         [4, 5, 6]]])