In [1]:
# Reload edited modules (e.g. exp.nb_01) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_01 import *

In [3]:
#export
class DepthToSpace(torch.nn.Module):
    """Move channel blocks into spatial blocks (ONNX "DCR" ordering).

    Maps (N, C, H, W) -> (N, C // bs**2, H * bs, W * bs) with
    out[n, c, h*bs + i, w*bs + j] = x[n, (i*bs + j) * (C // bs**2) + c, h, w].
    C is expected to be divisible by bs**2; otherwise the reshape fails.

    Args:
        block_size: upscaling factor ``bs`` applied to both H and W.
    """

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        bs = self.bs
        n, c, h, w = x.size()
        c_out = c // (bs * bs)
        # Split the channel axis into (row offset, column offset, output channel).
        split = x.view(n, bs, bs, c_out, h, w)
        # Reorder to (n, c_out, h, row offset, w, column offset) so each offset
        # sits directly after the spatial axis it refines.
        interleaved = split.permute(0, 3, 4, 1, 5, 2).contiguous()
        return interleaved.view(n, c_out, h * bs, w * bs)

In [4]:
#export
class Decoder(nn.Module):
    """Decode a 64-channel code map into a 3-channel output (presumably RGB).

    Two ResBlock stages run at the input resolution. DepthToSpace(2) and then
    DepthToSpace(4) give an overall 8x spatial upscale, e.g. (64, 16, 16) ->
    (3, 128, 128). `conv`, `relu` and `ResBlock` are presumably provided by
    exp.nb_01 (star-imported above).
    """

    def __init__(self):
        super().__init__()
        layers = [
            conv(64, 512, 1, 1, 0), relu,   # 1x1 channel expansion
            ResBlock(512), relu,
            ResBlock(512), relu,
            DepthToSpace(2),                # 512 ch -> 128 ch, 2x spatial
            conv(128, 256), relu,
            ResBlock(256), relu,
            DepthToSpace(4),                # 256 ch -> 16 ch, 4x spatial
            conv(16, 32), relu,
            conv(32, 3),
        ]
        # Unpacking the list keeps the same Sequential indices (and state_dict keys).
        self.decoder = nn.Sequential(*layers)

    def extra_repr(self):
        n_params = 0
        for weight in self.parameters():
            n_params += weight.numel()
        return f'Total Params: {n_params}'

    def forward(self, x):
        out = self.decoder(x)
        return out

In [5]:
# Use the GPU when one is available, so this cell also runs on CPU-only kernels
# (torchsummary's `summary` defaults to device="cuda").
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dec = Decoder().to(device)
from torchsummary import summary
summary(dec, (64, 16, 16), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 512, 16, 16]          33,280
              ReLU-2          [-1, 512, 16, 16]               0
            Conv2d-3          [-1, 128, 16, 16]         589,952
            Conv2d-4          [-1, 512, 16, 16]         590,336
          ResBlock-5          [-1, 512, 16, 16]               0
              ReLU-6          [-1, 512, 16, 16]               0
            Conv2d-7          [-1, 128, 16, 16]         589,952
            Conv2d-8          [-1, 512, 16, 16]         590,336
          ResBlock-9          [-1, 512, 16, 16]               0
             ReLU-10          [-1, 512, 16, 16]               0
     DepthToSpace-11          [-1, 128, 32, 32]               0
           Conv2d-12          [-1, 256, 32, 32]         295,168
             ReLU-13          [-1, 256, 32, 32]               0
           Conv2d-14          [-1, 128,

In [6]:
decoder = Decoder()
decoder

Decoder(
  Total Params: 3284739
  (decoder): Sequential(
    (0): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
    (2): ResBlock(
      (conv1): Conv2d(512, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (3): ReLU()
    (4): ResBlock(
      (conv1): Conv2d(512, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (5): ReLU()
    (6): DepthToSpace()
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): ResBlock(
      (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (10): ReLU()
    (11): DepthToSpace()
    (12): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14):

In [7]:
# (Leftover interactive `torch.where??` source inspection removed; see the torch.where API docs.)

Entropy-Based Layers

In [8]:
#export
class Quantizer(torch.autograd.Function):
    """Quantize an importance map in [0, 1] to integer levels 0..LEVELS-1.

    forward maps each value x to floor(x * LEVELS), clamped to [0, LEVELS - 1].
    Each bucket [l/LEVELS, (l+1)/LEVELS) therefore becomes l. Values >= 1
    map to LEVELS - 1 and negative values map to 0, so every output is a
    valid level. The output keeps the input's dtype and device.

    backward is a straight-through estimator: the incoming gradient is
    returned unchanged.
    """
    LEVELS = 16  # number of quantization levels (L)

    @staticmethod
    def forward(ctx, i):
        levels = Quantizer.LEVELS
        # LEVELS is a power of two, so i * LEVELS is exact in floating point and
        # floor() reproduces the [l/L, (l+1)/L) bucket boundaries exactly.
        return torch.floor(i * levels).clamp(0, levels - 1)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: treat the quantizer as identity for gradients.
        return grad_output

def quantize_values(x):
    """Quantize `x` with Quantizer (levels 0..15, straight-through gradient)."""
    return Quantizer.apply(x)

In [9]:
#export
class Mask(torch.autograd.Function):
    """Expand a quantized importance map into a binary channel mask.

    forward:
        Takes importance levels `i` of shape (N, 1, H, W), e.g. the output of
        `quantize_values`. Returns an (N, N_CHANNELS, H, W) tensor with the
        same dtype and device as `i`. Channel k at a pixel with level q is 1
        if k < (N_CHANNELS / LEVELS) * q and 0 otherwise, so a pixel at
        level q keeps its first 4 * q channels.
        Raises ValueError if `i` is not 4-D with a single channel.

    backward:
        Returns a tensor of ones shaped (N, 1, H, W).
    """
    N_CHANNELS = 64  # mask channels produced per pixel
    LEVELS = 16      # importance levels; same L = 16 as in Quantizer

    @staticmethod
    def forward(ctx, i):
        if i.dim() != 4 or i.shape[1] != 1:
            raise ValueError(
                f'Mask expects an importance map of shape (N, 1, H, W), got {tuple(i.shape)}')
        # Build the mask from the argument `i` itself (never a notebook global)
        # and on i's own device, so no CPU/CUDA special-casing is needed.
        channel = torch.arange(Mask.N_CHANNELS, device=i.device).view(1, -1, 1, 1)
        # (1, N_CHANNELS, 1, 1) vs (N, 1, H, W) broadcasts to (N, N_CHANNELS, H, W).
        keep = channel < (Mask.N_CHANNELS / Mask.LEVELS) * i
        return keep.to(i.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        N, _, H, W = grad_output.shape
        # NOTE(review): grad_output is ignored, so the importance map always
        # receives a gradient of exactly 1 per pixel, whatever the loss. A
        # straight-through alternative would be grad_output.sum(1, keepdim=True).
        # Confirm the intended estimator before changing it.
        # new_ones matches grad_output's device and dtype.
        return grad_output.new_ones(N, 1, H, W)

def generate_mask(x):
    """Apply Mask to a quantized importance map `x` of shape (N, 1, H, W)."""
    return Mask.apply(x)

In [10]:
# Seed torch's global RNG so the random demo tensors are reproducible on Restart & Run All.
torch.manual_seed(0)
a = torch.rand(4, 4, requires_grad=True)
a

tensor([[0.6299, 0.1945, 0.6043, 0.6919],
        [0.5468, 0.9871, 0.0042, 0.4288],
        [0.2492, 0.4456, 0.4644, 0.1328],
        [0.5835, 0.0848, 0.6075, 0.4402]], requires_grad=True)


In [11]:
b = bin_values(a)
print(a, b, sep='\n')

tensor([[0.6299, 0.1945, 0.6043, 0.6919],
        [0.5468, 0.9871, 0.0042, 0.4288],
        [0.2492, 0.4456, 0.4644, 0.1328],
        [0.5835, 0.0848, 0.6075, 0.4402]], requires_grad=True)
tensor([[1., 0., 1., 1.],
        [1., 1., 0., 0.],
        [0., 0., 0., 0.],
        [1., 0., 1., 0.]], grad_fn=<BinarizerBackward>)


In [12]:
loss = torch.sum(b)
print('Loss:{}'.format(loss))

Loss:7.0


In [13]:
torch.autograd.backward(loss)
a.grad

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])

In [14]:
imp_map = torch.rand((1, 1, 4, 4), requires_grad=True)
print(imp_map)

tensor([[[[0.1982, 0.5915, 0.4897, 0.7062],
          [0.4668, 0.7811, 0.8191, 0.5416],
          [0.1989, 0.2273, 0.0747, 0.9160],
          [0.6714, 0.8860, 0.6340, 0.7865]]]], requires_grad=True)


In [15]:
qimp = quantize_values(imp_map)
print(imp_map, qimp, sep='\n')

tensor([[[[0.1982, 0.5915, 0.4897, 0.7062],
          [0.4668, 0.7811, 0.8191, 0.5416],
          [0.1989, 0.2273, 0.0747, 0.9160],
          [0.6714, 0.8860, 0.6340, 0.7865]]]], requires_grad=True)
tensor([[[[ 3.,  9.,  7., 11.],
          [ 7., 12., 13.,  8.],
          [ 3.,  3.,  1., 14.],
          [10., 14., 10., 12.]]]], grad_fn=<QuantizerBackward>)


In [16]:
mask = generate_mask(qimp)
print(qimp, mask, sep='\n')

tensor([[[[ 3.,  9.,  7., 11.],
          [ 7., 12., 13.,  8.],
          [ 3.,  3.,  1., 14.],
          [10., 14., 10., 12.]]]], grad_fn=<QuantizerBackward>)
tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         ...,

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], grad_fn=<MaskBackward>)


In [17]:
loss = torch.sum(mask)
torch.autograd.backward(loss)

In [18]:
# Mask.backward returns ones and Quantizer.backward passes gradients through unchanged,
# so every entry of the importance-map gradient is 1 here.
imp_map.grad

tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])

In [19]:
print(imp_map[0, 0, 0, 0],
      qimp[0, 0, 0, 0],
      mask[0, :, 0, 0],
      mask[0, :, 0, 0].sum(),
      sep='\n')

tensor(0.1982, grad_fn=<SelectBackward>)
tensor(3., grad_fn=<SelectBackward>)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<SelectBackward>)
tensor(12., grad_fn=<SumBackward0>)


In [20]:
mask.shape

torch.Size([1, 64, 4, 4])

In [21]:
# Convert this notebook to exp/nb_02.py (presumably only the #export cells are written out).
!python notebook2script.py 02_decoder.ipynb

Converted 02_decoder.ipynb to exp/nb_02.py
