In [32]:
import numpy as np
import torch
from tensordict import TensorDict
import numba

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

batch_size = 32
seq_len = 200

In [8]:
mask_sensorial_data = 0.5

In [14]:
generator_numpy = np.random.default_rng(0)

In [None]:
def generate_mask():
    sensorial_masks = TensorDict(
                            device=device, batch_size=batch_size)

    for name in ["M", "C"]:
        sensorial_masks[name] = torch.ones(
        (batch_size, seq_len), dtype=bool, device=device)
        
    for sensorial_dim in sensorial_masks.keys():
        for batch_idx in range(batch_size):
            mask = sensorial_masks[sensorial_dim][batch_idx]

            masked_count = (
                mask.shape[0] - mask.sum())/mask.shape[0]

            if masked_count < mask_sensorial_data:
                idxs = torch.argwhere(mask != 0).flatten()
                idxs = idxs.cpu().numpy()

                to_mask_count = (mask.shape[0] *
                                    (mask_sensorial_data-masked_count))
                to_mask_count = int(
                    np.ceil(to_mask_count.cpu().item()))

                to_mask = generator_numpy.choice(
                    idxs, to_mask_count, replace=False)

                sensorial_masks[sensorial_dim][batch_idx][to_mask] = 0

In [16]:
%timeit generate_mask()

45.1 ms ± 3.53 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [41]:
@numba.njit(parallel=True)
def mask_mask(masks:np.ndarray, mask_pct:float, ):
    for batch_idx in numba.prange(batch_size):
        mask = masks[batch_idx]

        masked_count = (
            mask.shape[0] - mask.sum())/mask.shape[0]

        if masked_count < mask_pct:
            idxs = np.argwhere(mask != 0).flatten()

            to_mask_count = (mask.shape[0] *
                                (mask_pct-masked_count))
            to_mask_count = int(
                np.ceil(to_mask_count))

            to_mask = np.random.choice(
                idxs, to_mask_count, replace=False)
            
            masks[batch_idx][to_mask] = 0

    return masks

In [37]:
def generate_mask_2():
    sensorial_masks = TensorDict(
                            device=device, batch_size=batch_size)

    for name in ["M", "C"]:
        sensorial_masks[name] = torch.ones(
        (batch_size, seq_len), dtype=bool, device=device)
        
    for sensorial_dim in sensorial_masks.keys():
        masks = sensorial_masks[sensorial_dim].cpu().numpy()
        sensorial_masks[sensorial_dim] = torch.tensor(mask_mask(masks, mask_sensorial_data), device=device)

    return sensorial_masks

In [42]:
sensorial_masks = generate_mask_2()

In [43]:
%timeit generate_mask_2()

916 μs ± 81.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [50]:
def generate_mask3():
    sensorial_masks = TensorDict(
                            device=device, batch_size=batch_size)

    for name in ["M", "C"]:
        sensorial_masks[name] = torch.ones(
        (batch_size, seq_len), dtype=bool, device=device)
        
    for sensorial_dim in sensorial_masks.keys():
        for batch_idx in range(batch_size):
            mask = sensorial_masks[sensorial_dim][batch_idx]

            masked_count = (
                mask.shape[0] - mask.sum())/mask.shape[0]

            if masked_count < mask_sensorial_data:
                idxs = torch.argwhere(mask != 0).flatten()

                to_mask_count = (mask.shape[0] *
                                    (mask_sensorial_data-masked_count))
                to_mask_count = torch.ceil(to_mask_count).int()

                to_mask_idxs = torch.randperm(len(idxs))[:to_mask_count]
                to_mask = idxs[to_mask_idxs]

                sensorial_masks[sensorial_dim][batch_idx][to_mask] = 0
    return sensorial_masks

In [54]:
generate_mask3()["M"].sum(axis=1)

tensor([100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,
        100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,
        100, 100, 100, 100], device='cuda:0')

In [None]:
generate_mask3()["M"].device

device(type='cuda', index=0)

In [55]:
%timeit generate_mask3()

55.4 ms ± 3.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [66]:
def generate_mask4():
    sensorial_masks = TensorDict(
                            device=device, batch_size=batch_size)

    for name in ["M", "C"]:
        sensorial_masks[name] = torch.ones(
        (batch_size, seq_len), dtype=bool, device=device)
        
    for sensorial_dim in sensorial_masks.keys():
        masks = sensorial_masks[sensorial_dim]

        masked_count = (
            masks.shape[1] - masks.sum(axis=1))/masks.shape[1]
        
        print(masks.shape, masked_count.shape)

        idxs = torch.argwhere(masks != 0)

        to_mask_count = (masks.shape[1] *
                            (mask_sensorial_data-masked_count))
        to_mask_count = torch.ceil(to_mask_count).int()
        
        print(to_mask_count.shape)

        to_mask_idxs = torch.randperm(len(idxs))
        print(to_mask_idxs.shape)
        to_mask = idxs[to_mask_idxs]

        sensorial_masks[sensorial_dim][to_mask] = 0
    return sensorial_masks

In [67]:
generate_mask4()

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
