In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
import torch
from torch import tensor

In [28]:
class Mask(torch.autograd.Function):
    @staticmethod
    def forward(ctx,i):
        # checking for is_cuda() is
        # a hack to work around torch.where
        # not knowing which device to put
        # the tensors on
        if i.is_cuda: 
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        N,_,H,W = i.shape
        n = 8
        L = 4
        qimp = i
        mask = torch.zeros(n, N*H*W).to(device)
        qimp_flat = qimp.view(1, N*H*W)
        for indx in range(n):
            mask[indx,:] = torch.where(indx < (n/L)*qimp_flat,torch.Tensor([1]).to(device),torch.Tensor([0]).to(device))
        mask = mask.view(n,N,H,W).permute((1,0,2,3))
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        N,C,H,W = grad_output.shape
        if grad_output.is_cuda: return torch.ones(N,1,H,W).cuda()
        else: return torch.ones(N,1,H,W)

def generate_mask(x):
    return Mask.apply(x)

In [23]:
class Quantizer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        p = i.clone()
        L = 4
        for l in range(L):
            p[(p>=l/L)*(p<(l+1)/L)] = l
        return p
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

def quantize_values(x):
    return Quantizer.apply(x)

In [24]:
imp_map = torch.rand(1,1,1,4,requires_grad=True)
print(imp_map)

tensor([[[[0.9232, 0.7170, 0.2459, 0.0291]]]], requires_grad=True)


In [25]:
qimp = quantize_values(imp_map)
print(imp_map)
print(qimp)

tensor([[[[0.9232, 0.7170, 0.2459, 0.0291]]]], requires_grad=True)
tensor([[[[3., 2., 0., 0.]]]], grad_fn=<QuantizerBackward>)


In [26]:
mask = generate_mask(qimp)
print(qimp)
print(mask)

torch.Size([8, 4])
tensor([[[[3., 2., 0., 0.]]]], grad_fn=<QuantizerBackward>)
tensor([[[[1., 1., 0., 0.]],

         [[1., 1., 0., 0.]],

         [[1., 1., 0., 0.]],

         [[1., 1., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 0., 0., 0.]],

         [[0., 0., 0., 0.]]]], grad_fn=<MaskBackward>)


In [21]:
a = torch.Tensor([1, 2, 3, 1])
z = torch.zeros(8,a.shape[0])
a,z

(tensor([1., 2., 3., 1.]),
 tensor([[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]))

In [9]:
b = torch.arange(0,32).view(8,(a.shape[0]))
b, b[0,:]

(tensor([[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31]]),
 tensor([0, 1, 2, 3]))

In [10]:
for indx in range(8):
    z[indx,:] = torch.where((indx<2*a),tensor(1),tensor(0))
z

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [0., 1., 1., 0.],
        [0., 1., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

In [11]:
z.shape

torch.Size([8, 4])

In [12]:
loss = mask.sum()
loss.backward()

In [13]:
imp_map.grad

tensor([[[[1., 1., 1., 1.]]]])

In [14]:
print(imp_map[0,0,0,1])
print(qimp[0,0,0,1])
print(mask[0,:,0,1])
print(mask[0,:,0,1].sum())

tensor(0.5358, grad_fn=<SelectBackward>)
tensor(2., grad_fn=<SelectBackward>)
tensor([1., 1., 1., 1., 0., 0., 0., 0.], grad_fn=<SelectBackward>)
tensor(4., grad_fn=<SumBackward0>)


In [15]:
mask.size()

torch.Size([1, 8, 1, 4])