In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class BlockDropout(nn.Module):
  """Dropout that zeroes contiguous blocks of channels along the last dim.

  The last dimension of the input (``C``) is split into ``C // block_size``
  consecutive blocks of ``block_size`` channels; every channel in a block
  shares one keep/drop decision. The trailing ``C % block_size`` channels
  that do not fill a whole block are kept or dropped independently,
  element-wise. Kept values are scaled by ``1 / (1 - p)``. In eval mode the
  input is returned unchanged.

  Args:
    p: probability of dropping each block (or leftover channel); must
      satisfy ``0 <= p < 1``.
    block_size: number of consecutive channels dropped together; must be a
      positive integer (an integral float such as ``2.0`` is accepted and
      converted to ``int``).

  Raises:
    ValueError: if ``p`` or ``block_size`` is out of range.
  """

  def __init__(self, p: float = 0.5, block_size: int = 2):
    super().__init__()
    if not 0 <= p < 1:
      raise ValueError(f"p must satisfy 0 <= p < 1, got {p}")
    if block_size < 1 or int(block_size) != block_size:
      raise ValueError(
          f"block_size must be a positive integer, got {block_size}")
    self.p = p
    # Stored as int: it is used for slicing and repeat_interleave, which
    # reject float arguments.
    self.block_size = int(block_size)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply block dropout to ``x`` in training mode; identity otherwise.

    Args:
      x: tensor with at least one dimension; blocks are formed along the
        last dimension.

    Returns:
      ``x`` itself in eval mode, otherwise ``x * mask / (1 - p)`` with the
      same shape as ``x``.
    """
    if not self.training:
      return x
    num_channels = x.shape[-1]
    n_blocks = num_channels // self.block_size
    blocked = n_blocks * self.block_size
    with torch.no_grad():
      # One uniform draw per element; True means "keep". ``>=`` (not ``>``)
      # guarantees nothing is dropped when p == 0, because torch.rand can
      # return exactly 0.0.
      keep = torch.rand(x.shape, device=x.device) >= self.p
      # Reuse the first n_blocks draws of each row as the per-block
      # decisions and spread each one over its block_size channels.
      # Leftover channels (index >= blocked) keep their own draws.
      keep[..., :blocked] = keep[..., :n_blocks].repeat_interleave(
          self.block_size, dim=-1)
    return x * keep / (1 - self.p)


In [None]:
demo_tensor = torch.ones(3, 5, 7)  # B, N, C
# Use ones (not zeros) so the dropped channels are visible in the output;
# with an all-zero input every output value is 0 regardless of the mask.
# With C = 7 and block_size = 5, channels 0-4 are kept or zeroed together
# (kept values become 1 / (1 - 0.1)), while channels 5-6 are decided
# independently. A freshly constructed module is in training mode, so the
# dropout is applied.
BlockDropout(0.1, block_size=5)(demo_tensor)