In [1]:
import torch
import torch.nn.functional as F
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed every RNG this notebook draws from (torch for the dummy data, Python's
# `random` for patch sampling) so Restart & Run All reproduces the same
# tensors and the same sampled patch order.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Create a dummy batch of image patches with random data
# Batch size = 1, Channels = 3 (RGB), Height = 10, Width = 10
dummy_patches = torch.rand(1, 3, 10, 10)

In [3]:
def get_patch_images_info(patch_images, sample_patch_num, device, image_bucket_size=10):
    """Flatten a batch of feature maps into per-patch embeddings plus metadata.

    Debugging stand-in: a full pipeline would first run ``patch_images``
    through an image backbone (``embed_images``); here the input tensor is
    used as the embedding directly.

    Args:
        patch_images: 4-D tensor indexed as ``(batch, channels, height, width)``.
        sample_patch_num: if not ``None``, keep only this many patches per
            batch element, chosen without replacement with Python's
            ``random.sample`` (so ``random.seed`` controls the choice).
        device: device for the position ids and the sampling index. When
            sampling, this must match ``patch_images.device``; otherwise the
            ``gather`` calls fail on mixed devices.
        image_bucket_size: row stride of the position ids. The patch at
            ``(row, col)`` gets id ``row * image_bucket_size + col + 1``.
            Defaults to 10, the value previously hard-coded.

    Returns:
        Tuple ``(image_embed, image_num_patches, image_padding_mask,
        image_position_ids)``:
          - image_embed: ``(batch, num_patches, channels)``
          - image_num_patches: ``height * width``, or ``sample_patch_num``
            when sampling
          - image_padding_mask: bool ``(batch, num_patches)``, all ``False``
          - image_position_ids: ``(batch, num_patches)`` integer ids

    Raises:
        ValueError: if ``sample_patch_num`` is negative or larger than
            ``height * width``.
    """
    # Assuming embed_images is a method that applies a CNN (like ResNet)
    image_embed = patch_images  # Direct assignment for debugging
    batch_size = patch_images.size(0)
    h, w = image_embed.shape[2], image_embed.shape[3]
    image_num_patches = h * w

    # Nothing is padded here, so every mask entry is False.
    image_padding_mask = patch_images.new_zeros(
        (batch_size, image_num_patches)).bool()

    # Row-major ids offset by 1: (row, col) -> row * image_bucket_size + col + 1.
    image_position_idx = (
        torch.arange(w, device=device).unsqueeze(0).expand(h, w)
        + torch.arange(h, device=device).unsqueeze(1) * image_bucket_size
        + 1
    )
    image_position_idx = image_position_idx.view(-1)
    image_position_ids = image_position_idx[None, :].expand(
        batch_size, image_num_patches)

    # (B, C, H, W) -> (B, H*W, C): one row per spatial patch.
    image_embed = image_embed.flatten(2).transpose(1, 2)

    if sample_patch_num is not None:
        if not 0 <= sample_patch_num <= image_num_patches:
            raise ValueError(
                f"sample_patch_num must be in [0, {image_num_patches}], "
                f"got {sample_patch_num}")
        patch_orders = [
            random.sample(range(image_num_patches), k=sample_patch_num)
            for _ in range(batch_size)
        ]
        # Reshape explicitly: an empty batch would otherwise produce a 1-D
        # empty tensor, and the unsqueeze(2) below would fail on it.
        patch_orders = torch.tensor(
            patch_orders, dtype=torch.long, device=device
        ).view(batch_size, sample_patch_num)

        # Apply the same per-row index to embeddings, mask and position ids
        # so the three outputs stay aligned patch-for-patch.
        image_embed = image_embed.gather(
            1, patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2)))
        image_num_patches = sample_patch_num
        image_padding_mask = image_padding_mask.gather(1, patch_orders)
        image_position_ids = image_position_ids.gather(1, patch_orders)

    return image_embed, image_num_patches, image_padding_mask, image_position_ids


In [4]:
# Set the device to 'cpu' or 'cuda' if you are using GPU
device = 'cpu'
sample_patch_num = 5  # Optional: number of patches to sample; None keeps every patch

# Call the function
embeds, num_patches, padding_mask, position_ids = get_patch_images_info(dummy_patches, sample_patch_num, device)

# Sanity-check the outputs before printing: each output has one row per batch
# element and one column per kept patch (all H*W patches when not sampling).
dummy_batch, dummy_channels, dummy_h, dummy_w = dummy_patches.shape
expected_patches = dummy_h * dummy_w if sample_patch_num is None else sample_patch_num
assert num_patches == expected_patches
assert embeds.shape == (dummy_batch, expected_patches, dummy_channels)
assert padding_mask.shape == (dummy_batch, expected_patches)
assert position_ids.shape == (dummy_batch, expected_patches)

# Print the outputs to observe changes
print("Embedded Patches Shape:", embeds.shape)
print("Number of Patches:", num_patches)
print("Padding Mask Shape:", padding_mask.shape)
print("Position IDs Shape:", position_ids.shape)

Embedded Patches Shape: torch.Size([1, 5, 3])
Number of Patches: 5
Padding Mask Shape: torch.Size([1, 5])
Position IDs Shape: torch.Size([1, 5])
