The Python function `collate_data_and_cast` takes a list of samples, stacks their global and local crops into batch tensors, generates patch masks for a fraction of the global crops, and returns everything in a single dictionary.

**Calculate Number of Masked Samples**:
   - The number of global crops to mask, `n_samples_masked`, is computed as `int(B * mask_probability)`, where `B` is the total number of collated global crops (number of samples times global crops per sample).
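
As a small worked example of this count, the sketch below assumes a hypothetical batch of 4 samples with 2 global crops each and a `mask_probability` of 0.5; the values are for illustration only.

```python
# Hypothetical values chosen for illustration only.
n_samples = 4
n_global_crops = 2
mask_probability = 0.5

# B counts global crops, not samples, because all global crops
# are stacked along the batch axis before masking.
B = n_samples * n_global_crops
n_samples_masked = int(B * mask_probability)  # int() truncates toward zero

print(B, n_samples_masked)  # 8 4
```

Because `int()` truncates, an odd `B` rounds down: with `B = 7` and the same probability, only 3 crops would be masked.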

**Generate Masks**:
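   - `torch.linspace` splits the range given by `mask_ratio_tuple` into `n_samples_masked` equal segments. For each segment, a mask ratio is drawn uniformly from that segment and `mask_generator` is called with the corresponding number of patches, `int(N * ratio)`, where `N` is `n_tokens`.
   - The variable `upperbound` accumulates `int(N * prob_max)` over the segments, which bounds the total number of masked patches requested for the batch.
   - The remaining `B - n_samples_masked` crops receive empty masks from `mask_generator(0)`.
   - The function then shuffles the list of masks, so the masked crops land at random positions in the batch.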
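
The segment logic can be sketched in isolation as follows. The ratio range, the number of masked crops, and the token count `N` are hypothetical values, and the sketch only computes how many patches would be requested per crop rather than calling a real mask generator.

```python
import random

import torch

# Hypothetical settings for illustration only.
mask_ratio_tuple = (0.1, 0.5)  # (min, max) fraction of patches to mask
n_samples_masked = 4
N = 16                         # patch tokens per crop, e.g. a 4x4 grid

# n_samples_masked + 1 boundaries define n_samples_masked equal segments.
probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)

upperbound = 0
n_masked_per_crop = []
for i in range(n_samples_masked):
    prob_min, prob_max = probs[i].item(), probs[i + 1].item()
    # Each masked crop draws its ratio from its own segment.
    n_masked_per_crop.append(int(N * random.uniform(prob_min, prob_max)))
    upperbound += int(N * prob_max)

print(n_masked_per_crop, upperbound)
```

Giving each crop its own segment spreads the mask ratios across the whole range instead of letting them cluster by chance, and the sum of the requested counts never exceeds `upperbound`.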

**Collate Masks**:
   - The list of masks is stacked into a single boolean tensor and each mask is flattened with `flatten(1)`, giving `collated_masks` with one row per global crop and one column per patch.

**Generate Mask Indices**:
   - The function flattens `collated_masks` completely and takes the positions of its nonzero (`True`) elements, so `mask_indices_list` holds indices into the flattened batch of patches.
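
A minimal sketch of these two steps, using two hand-written 2x2 masks as assumed inputs, shows how the per-crop masks turn into batch-wide flat indices:

```python
import torch

# Two hypothetical 2x2 patch masks (True = masked patch).
masks_list = [
    torch.tensor([[True, False], [False, True]]),
    torch.tensor([[False, True], [False, False]]),
]

# Stack to (2, 2, 2), then flatten each mask: one row per crop.
collated_masks = torch.stack(masks_list).flatten(1)

# Flatten across the whole batch and keep the positions of True entries.
mask_indices_list = collated_masks.flatten().nonzero().flatten()

print(collated_masks.shape)  # torch.Size([2, 4])
print(mask_indices_list)     # tensor([0, 3, 5])
```

The second crop's masked patch sits at column 1 of row 1, so its flat index is `1 * 4 + 1 = 5`.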

**Calculate Mask Weights**:
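   - Each masked patch gets the weight `1 / k`, where `k` is the number of masked patches in its crop (clamped to at least 1 to avoid division by zero for unmasked crops). `masks_weight` holds one weight per masked patch, in the same order as `mask_indices_list`.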
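
The weighting can be checked on a small hand-written mask tensor. The three rows below are assumed example data: one crop with two masked patches, one with a single masked patch, and one unmasked crop.

```python
import torch

# Hypothetical collated masks: 3 crops, 4 patches each.
collated_masks = torch.tensor([
    [True, False, False, True],     # 2 masked patches -> weight 1/2 each
    [False, True, False, False],    # 1 masked patch   -> weight 1
    [False, False, False, False],   # unmasked crop    -> contributes nothing
])

# Same expression as in the function: reciprocal of the per-crop count,
# broadcast over the patches, then selected at the masked positions.
masks_weight = (
    (1 / collated_masks.sum(-1).clamp(min=1.0))
    .unsqueeze(-1)
    .expand_as(collated_masks)[collated_masks]
)

print(masks_weight)  # tensor([0.5000, 0.5000, 1.0000])
```

With these weights every masked crop contributes a total weight of 1, regardless of how many of its patches are masked.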

**Return Data Dictionary**:
   - The function returns a dictionary with the following entries:
     - `"collated_global_crops"`: The stacked global crops, cast to `dtype`.
     - `"collated_local_crops"`: The stacked local crops, cast to `dtype`.
     - `"collated_masks"`: The boolean mask tensor, one row per global crop.
     - `"mask_indices_list"`: Flat indices of the masked patches.
     - `"masks_weight"`: One weight per masked patch.
     - `"upperbound"`: The upper bound on the number of masked patches, as a Python integer.
     - `"n_masked_patches"`: The actual number of masked patches, as a one-element `torch.long` tensor.
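
The two crop tensors are stacked crop index first, which affects how the batch axis should be read downstream. The sketch below reproduces the stacking on a hypothetical batch of 3 samples, where each crop is filled with the value `10 * sample + crop` so the order is visible; the sample layout (a tuple whose first element is the crop dictionary) follows the indexing used in the function.

```python
import torch

# Hypothetical batch: 3 samples with 2 global crops each, 3x4x4 per crop.
samples_list = [
    ({"global_crops": [torch.full((3, 4, 4), float(10 * s + g)) for g in range(2)]}, None)
    for s in range(3)
]

n_global_crops = len(samples_list[0][0]["global_crops"])

# Outer loop over the crop index, inner loop over the samples.
collated_global_crops = torch.stack(
    [s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list]
)

order = [int(x) for x in collated_global_crops[:, 0, 0, 0]]
print(collated_global_crops.shape)                 # torch.Size([6, 3, 4, 4])
print(order)                                       # [0, 10, 20, 1, 11, 21]
print(collated_global_crops.to(torch.half).dtype)  # torch.float16
```

The first crop of every sample comes first, followed by the second crop of every sample, so crops of the same sample are `len(samples_list)` positions apart.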

In short, the function takes a batch of samples, generates patch masks for a specified proportion of the global crops, and packages the result for further processing, such as training a machine learning model. Masking parts of the input is a common technique in computer vision, for example in self-supervised learning and image inpainting.

In [1]:
import torch
import random


def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
    # dtype = torch.half  # TODO: Remove

    n_global_crops = len(samples_list[0][0]["global_crops"])
    n_local_crops = len(samples_list[0][0]["local_crops"])

    collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list])

    collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list])

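    B = len(collated_global_crops)  # total number of global crops in the batch
    N = n_tokens  # number of patch tokens per global crop
    n_samples_masked = int(B * mask_probability)
    # Split the mask-ratio range into one equal segment per masked crop.
    probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
    upperbound = 0
    masks_list = []
    for i in range(0, n_samples_masked):
        prob_min = probs[i]
        prob_max = probs[i + 1]
        # Draw a mask ratio from this segment and mask that many patches.
        masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
        upperbound += int(N * prob_max)
    # The remaining crops get empty masks.
    for i in range(n_samples_masked, B):
        masks_list.append(torch.BoolTensor(mask_generator(0)))

    # Randomize which crops in the batch are masked.
    random.shuffle(masks_list)

    # One row per crop, one column per patch.
    collated_masks = torch.stack(masks_list).flatten(1)
    # Positions of the masked patches in the fully flattened batch.
    mask_indices_list = collated_masks.flatten().nonzero().flatten()

    # Weight each masked patch by 1 / (number of masked patches in its crop).
    masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]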

    return {
        "collated_global_crops": collated_global_crops.to(dtype),
        "collated_local_crops": collated_local_crops.to(dtype),
        "collated_masks": collated_masks,
        "mask_indices_list": mask_indices_list,
        "masks_weight": masks_weight,
        "upperbound": upperbound,
        "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
    }
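
The function relies on a `mask_generator` callable that is not defined here. From the way it is called, it appears to take the number of patches to mask and return a 2D boolean grid. The sketch below is a simplified stand-in written for illustration, not the generator used by the original code: the name `make_random_mask_generator` and the uniform random placement of masked patches are assumptions.

```python
import random

import torch


def make_random_mask_generator(grid_h, grid_w, seed=0):
    """Hypothetical stand-in: masks exactly k patches at random positions."""
    rng = random.Random(seed)

    def mask_generator(num_masking_patches=0):
        flat = [False] * (grid_h * grid_w)
        # Pick distinct patch positions to mask.
        for idx in rng.sample(range(grid_h * grid_w), num_masking_patches):
            flat[idx] = True
        # Reshape the flat list into grid_h rows of grid_w entries.
        return [flat[r * grid_w:(r + 1) * grid_w] for r in range(grid_h)]

    return mask_generator


mask_generator = make_random_mask_generator(4, 4)
mask = torch.BoolTensor(mask_generator(5))
empty = torch.BoolTensor(mask_generator(0))

print(mask.shape, int(mask.sum()), int(empty.sum()))  # torch.Size([4, 4]) 5 0
```

With such a generator, `n_tokens` should equal `grid_h * grid_w` so that the requested patch counts match the size of the grid.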