In [13]:
import torch
import torch.nn.functional as F

In [21]:
class SelectionMask:
    """Random real-valued selection mask backed by standard-normal samples.

    The mask is drawn with ``torch.randn``. Any entry that is exactly zero is
    replaced with ``torch.finfo(dtype).tiny`` (the smallest positive normal
    value of the mask's dtype), so every entry is non-zero.

    Args:
        shape: Size of the mask tensor (anything ``torch.randn`` accepts).
        dtype: Floating-point dtype of the mask; ``None`` uses torch's default dtype.
        device: Device on which to allocate the mask; ``None`` uses torch's default.
        generator: Optional ``torch.Generator`` to draw samples from, for masks
            that are reproducible independently of the global RNG state.
    """

    def __init__(self, shape, dtype=None, device=None, generator=None):
        self.mask = torch.randn(shape, generator=generator, dtype=dtype, device=device)
        # Look up `tiny` from the tensor's actual dtype: `dtype` may be None, and
        # torch.finfo(None) raises instead of falling back to the default dtype.
        self.mask[self.mask == 0] = torch.finfo(self.mask.dtype).tiny

    def interpolate(self, new_size, mode="nearest"):
        """Resample the mask to ``new_size`` using ``F.interpolate``.

        Args:
            new_size: Target spatial size, one entry per mask dimension
                (``F.interpolate`` handles 1-D, 2-D and 3-D spatial inputs).
            mode: Interpolation mode forwarded to ``F.interpolate``.

        Returns:
            A new tensor of shape ``new_size`` with the mask's dtype;
            ``self.mask`` itself is left unchanged.
        """
        # F.interpolate expects (batch, channel, *spatial): add both singleton dims.
        batched = self.mask.unsqueeze(0).unsqueeze(0)
        # Upcast only reduced-precision dtypes; forcing float64 down to float32
        # would discard precision that the final cast back cannot restore.
        if batched.dtype not in (torch.float32, torch.float64):
            batched = batched.float()

        resized = F.interpolate(batched, size=new_size, mode=mode)
        # Drop the batch and channel dims added above, then restore the mask dtype.
        return resized.squeeze(0).squeeze(0).to(self.mask.dtype)

    def binarize(self, treshold=0.5):
        """Return a 0/1 float32 tensor marking entries with ``|mask| >= treshold``.

        Args:
            treshold: Magnitude cutoff (parameter spelling kept for backward
                compatibility with existing keyword callers).
        """
        return (torch.abs(self.mask) >= treshold).float()

In [22]:
# Seed the global torch RNG so `x` (and every random draw in later cells)
# is reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)
x = torch.rand((2, 2, 3))  # uniform samples in [0, 1)
x

tensor([[[0.7548, 0.2481, 0.7681],
         [0.9982, 0.8777, 0.0892]],

        [[0.0799, 0.2758, 0.3649],
         [0.8682, 0.9950, 0.3230]]])

In [23]:
# Build a (4, 4, 2) mask using the same dtype as the sample tensor `x`.
mask = SelectionMask(dtype=x.dtype, shape=(4, 4, 2))
mask.mask

tensor([[[ 0.6520, -0.5271],
         [ 0.0185, -0.9295],
         [-0.1075,  0.6466],
         [ 0.5247, -0.5613]],

        [[-0.3560,  0.0277],
         [ 0.7789, -0.0232],
         [-1.1976,  0.1776],
         [-1.6394, -0.3975]],

        [[ 0.0519,  1.9313],
         [ 0.6923, -0.4437],
         [ 0.8182,  0.3519],
         [-2.4245, -0.3636]],

        [[-1.3998, -1.1266],
         [ 0.6156,  1.0660],
         [ 1.2461,  1.2641],
         [-1.0975,  1.2361]]])

In [24]:
mask.binarize(treshold=0.5)  # 1.0 where |mask| >= 0.5, otherwise 0.0

tensor([[[1., 1.],
         [0., 1.],
         [0., 1.],
         [1., 1.]],

        [[0., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.]],

        [[0., 1.],
         [1., 0.],
         [1., 0.],
         [1., 0.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]]])