In [1]:
import numpy as np
import torch
from tensordict import TensorDict
import numba

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

batch_size = 32
seq_len = 200

In [3]:
mask_sensorial_data = 0.5

In [4]:
generator_numpy = np.random.default_rng(0)

In [8]:
def generate_mask():
    sensorial_masks = TensorDict(
                            device=device, batch_size=batch_size)

    for name in ["M", "C"]:
        sensorial_masks[name] = torch.ones(
        (batch_size, seq_len), dtype=bool, device=device)
        
    for sensorial_dim in sensorial_masks.keys():
        for batch_idx in range(batch_size):
            mask = sensorial_masks[sensorial_dim][batch_idx]

            masked_count = (
                mask.shape[0] - mask.sum())/mask.shape[0]

            if masked_count < mask_sensorial_data:
                idxs = torch.argwhere(mask != 0).flatten()
                idxs = idxs.cpu().numpy()

                to_mask_count = (mask.shape[0] *
                                    (mask_sensorial_data-masked_count))
                to_mask_count = int(
                    np.ceil(to_mask_count.cpu().item()))

                to_mask = generator_numpy.choice(
                    idxs, to_mask_count, replace=False)

                sensorial_masks[sensorial_dim][batch_idx][to_mask] = 0

    return sensorial_masks

In [16]:
%timeit generate_mask()

45.1 ms ± 3.53 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [41]:
@numba.njit(parallel=True)
def mask_mask(masks:np.ndarray, mask_pct:float, ):
    for batch_idx in numba.prange(batch_size):
        mask = masks[batch_idx]

        masked_count = (
            mask.shape[0] - mask.sum())/mask.shape[0]

        if masked_count < mask_pct:
            idxs = np.argwhere(mask != 0).flatten()

            to_mask_count = (mask.shape[0] *
                                (mask_pct-masked_count))
            to_mask_count = int(
                np.ceil(to_mask_count))

            to_mask = np.random.choice(
                idxs, to_mask_count, replace=False)
            
            masks[batch_idx][to_mask] = 0

    return masks

In [37]:
def generate_mask_2():
    sensorial_masks = TensorDict(
                            device=device, batch_size=batch_size)

    for name in ["M", "C"]:
        sensorial_masks[name] = torch.ones(
        (batch_size, seq_len), dtype=bool, device=device)
        
    for sensorial_dim in sensorial_masks.keys():
        masks = sensorial_masks[sensorial_dim].cpu().numpy()
        sensorial_masks[sensorial_dim] = torch.tensor(mask_mask(masks, mask_sensorial_data), device=device)

    return sensorial_masks

In [42]:
sensorial_masks = generate_mask_2()

In [43]:
%timeit generate_mask_2()

916 μs ± 81.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [50]:
def generate_mask3():
    sensorial_masks = TensorDict(
                            device=device, batch_size=batch_size)

    for name in ["M", "C"]:
        sensorial_masks[name] = torch.ones(
        (batch_size, seq_len), dtype=bool, device=device)
        
    for sensorial_dim in sensorial_masks.keys():
        for batch_idx in range(batch_size):
            mask = sensorial_masks[sensorial_dim][batch_idx]

            masked_count = (
                mask.shape[0] - mask.sum())/mask.shape[0]

            if masked_count < mask_sensorial_data:
                idxs = torch.argwhere(mask != 0).flatten()

                to_mask_count = (mask.shape[0] *
                                    (mask_sensorial_data-masked_count))
                to_mask_count = torch.ceil(to_mask_count).int()

                to_mask_idxs = torch.randperm(len(idxs))[:to_mask_count]
                to_mask = idxs[to_mask_idxs]

                sensorial_masks[sensorial_dim][batch_idx][to_mask] = 0
    return sensorial_masks

In [54]:
generate_mask3()["M"].sum(axis=1)

tensor([100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,
        100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,
        100, 100, 100, 100], device='cuda:0')

In [None]:
generate_mask3()["M"].device

device(type='cuda', index=0)

In [55]:
%timeit generate_mask3()

55.4 ms ± 3.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:


def generate_mask4():
    sensorial_masks = TensorDict(batch_size=batch_size,
                            )#device=device)

    for name in ["M", "C"]:
        sensorial_masks[name] = torch.ones(
        (batch_size, seq_len), dtype=bool)#, device=device)
        
    for sensorial_dim in sensorial_masks.keys():
        masks = sensorial_masks[sensorial_dim]

        masked_count = (
            masks.shape[1] - masks.sum(axis=1))/masks.shape[1]
        
        

        idxs = torch.argwhere(masks != 0)

        to_mask_count = (masks.shape[1] *
                            (mask_sensorial_data-masked_count))
        to_mask_count = torch.ceil(to_mask_count).int()
        
        

        to_mask_idxs = torch.randperm(len(idxs))
        
        to_mask = idxs[to_mask_idxs]

        print("masks", masks.shape)
        print("masked_count", masked_count.shape)
        print("masks != 0", (masks != 0).shape)
        print("idxs", idxs.shape)
        print("to_mask_count", to_mask_count.shape)
        print("to_mask_idxs", to_mask_idxs.shape)
        print("to_mask", to_mask.shape)

        sensorial_masks[sensorial_dim][to_mask] = 0
    return sensorial_masks

In [111]:
sensorial_masks = TensorDict(batch_size=batch_size,
                            )#device=device)

for name in ["M", "C"]:
    sensorial_masks[name] = torch.ones(
    (batch_size, seq_len), dtype=bool)#, device=device)
    
for sensorial_dim in sensorial_masks.keys():
    masks = sensorial_masks[sensorial_dim]

    masked_count = (
        masks.shape[1] - masks.sum(axis=1))/masks.shape[1]
    
    

    idxs = torch.argwhere(masks != 0)

    to_mask_count = (masks.shape[1] *
                        (mask_sensorial_data-masked_count))
    to_mask_count = torch.ceil(to_mask_count).int()
    

    to_mask_idxs = torch.randperm(len(idxs))
    
    to_mask = idxs[to_mask_idxs]

    print("masks", masks.shape)
    print("masked_count", masked_count.shape)
    print("masks != 0", (masks != 0).shape)
    print("idxs", idxs.shape)
    print("to_mask_count", to_mask_count.shape)
    print("to_mask_idxs", to_mask_idxs.shape)
    print("to_mask", to_mask.shape)

    sensorial_masks[sensorial_dim][to_mask] = 0

masks torch.Size([32, 200])
masked_count torch.Size([32])
masks != 0 torch.Size([32, 200])
idxs torch.Size([6400, 2])
to_mask_count torch.Size([32])
to_mask_idxs torch.Size([6400])
to_mask torch.Size([6400, 2])


IndexError: index 80 is out of bounds for dimension 0 with size 32

In [120]:
sensorial_masks["C"].shape

torch.Size([32, 200])

In [110]:
generate_mask4()

masks torch.Size([32, 200])
masked_count torch.Size([32])
masks != 0 torch.Size([32, 200])
idxs torch.Size([6400, 2])
to_mask_count torch.Size([32])
to_mask_idxs torch.Size([6400])
to_mask torch.Size([6400, 2])


IndexError: index 162 is out of bounds for dimension 0 with size 32

## Apply mask

In [9]:
masks = generate_mask()

In [12]:
masks["M"].shape

torch.Size([32, 200])

In [13]:
conditioning_mask = masks["M"]

In [25]:
def create_scales():
    gamma1 = torch.rand((32, 200, 6), device=device) 
    gamma2 = torch.rand((32, 200, 6), device=device)
    alpha1 = torch.rand((32, 200, 6), device=device) 
    alpha2 = torch.rand((32, 200, 6), device=device)

    scales = [gamma1, gamma2, alpha1, alpha2]


    return scales

scales = create_scales()

In [26]:
for scale in scales:
    scale[torch.bitwise_not(conditioning_mask)] = 0.0

In [27]:
scale[0][3], conditioning_mask[0][3]

(tensor([0., 0., 0., 0., 0., 0.], device='cuda:0'),
 tensor(False, device='cuda:0'))

In [32]:
scales = create_scales()

for scale in scales:
    scale *= conditioning_mask[..., None]

scale[0][3], conditioning_mask[0][3]

(tensor([0., 0., 0., 0., 0., 0.], device='cuda:0'),
 tensor(False, device='cuda:0'))

In [None]:
def apply1():
    for scale in scales:
        scale[torch.bitwise_not(conditioning_mask)] = 0.0

def apply2():
    for scale in scales:
        scale *= conditioning_mask[..., None]

%timeit apply1()
%timeit apply2()

182 μs ± 22.2 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
66.7 μs ± 7.69 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [35]:
def apply3():
    for scale in scales:
        scale *= conditioning_mask.unsqueeze(-1)

%timeit apply3()

62.7 μs ± 2.4 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [36]:
scales = create_scales()

apply3()

scale[0][3], conditioning_mask[0][3]

(tensor([0., 0., 0., 0., 0., 0.], device='cuda:0'),
 tensor(False, device='cuda:0'))

In [37]:
conditioning_dim = 6
conditioning = torch.nn.Linear(conditioning_dim, 6*conditioning_dim, bias=True)

In [83]:
conditioning_data = conditioning(torch.rand(32, 200, 6)).to(device)

conditioning_data *= conditioning_mask.unsqueeze(-1)



In [90]:
conditioning_data = conditioning(torch.rand(32, 200, 6)).to(device)
def apply4():
    global conditioning_data
    conditioning_data *= conditioning_mask.unsqueeze(-1)

%timeit apply4()

22.8 μs ± 5.1 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [91]:
182 / 22.8

7.982456140350877

In [None]:
gamma1, beta1, alpha1, gamma2, beta2, alpha2 = conditioning_data.chunk(6, dim=2)

In [57]:
alpha2[0][3], conditioning_mask[0][3]

(tensor([0., -0., 0., -0., 0., -0.], device='cuda:0', grad_fn=<SelectBackward0>),
 tensor(False, device='cuda:0'))

In [93]:
%timeit gamma1.clone()

31.2 μs ± 4.81 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [121]:
time = 31.2*10e-6
time *= 6 #params
time *= 250 #steps
time *= 5 #epochs
time *= 5 #runs

time

11.7

In [123]:
time = (182-22.8)*10e-6
time *= 250 #steps
time *= 5 #epochs
time *= 5 #runs

time

9.950000000000001