In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_01 import *

In [3]:
#export
class DepthToSpace(torch.nn.Module):
    """Move channel blocks into spatial blocks (inverse of space-to-depth).

    An input of shape (N, C, H, W) becomes (N, C // bs**2, H * bs, W * bs),
    where ``bs`` is ``block_size``. The channel axis is read as
    (block_row, block_col, out_channel), so output pixel (h*bs + r, w*bs + s)
    of channel c comes from input channel (r*bs + s) * (C // bs**2) + c.
    """

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        block = self.bs
        batch, channels, height, width = x.size()
        out_channels = channels // (block * block)
        # (N, r, s, C', H, W) -> (N, C', H, r, W, s) -> (N, C', H*bs, W*bs)
        split = x.view(batch, block, block, out_channels, height, width)
        interleaved = split.permute(0, 3, 4, 1, 5, 2)
        return interleaved.reshape(batch, out_channels, height * block, width * block)

In [4]:
#export
class Decoder(nn.Module):
    """Convolutional decoder from a 64-channel code to a 3-channel output.

    Resolution is increased by DepthToSpace(2) and then DepthToSpace(4). For a
    (1, 64, 16, 16) input, the torchinfo summary reports a (1, 3, 128, 128) output.
    """

    def __init__(self):
        super().__init__()
        # Modules are created in the same order as the Sequential indices, so
        # parameter initialisation and state_dict keys ("decoder.<idx>...") are stable.
        layers = [conv(64, 512, 1, 1, 0), nn.ReLU()]
        for _ in range(2):
            layers += [ResBlock(512), nn.ReLU()]
        # 512 channels -> 128 channels at 2x resolution
        layers.append(DepthToSpace(2))
        layers += [conv(128, 256), nn.ReLU(), ResBlock(256), nn.ReLU()]
        # 256 channels -> 16 channels at 4x resolution
        layers.append(DepthToSpace(4))
        layers += [conv(16, 32), nn.ReLU(), conv(32, 3)]
        self.decoder = nn.Sequential(*layers)

    def extra_repr(self):
        total = sum(param.numel() for param in self.parameters())
        return 'Total Params: {}'.format(total)

    def forward(self, x):
        out = self.decoder(x)
        return out

In [5]:
from torchinfo import summary

dec = Decoder()
summary(dec, input_size=(1, 64, 16, 16), device="cpu")

Layer (type:depth-idx)                   Output Shape              Param #
Decoder                                  --                        --
├─Sequential: 1-1                        [1, 3, 128, 128]          --
│    └─Conv2d: 2-1                       [1, 512, 16, 16]          33,280
│    └─ReLU: 2-2                         [1, 512, 16, 16]          --
│    └─ResBlock: 2-3                     [1, 512, 16, 16]          --
│    │    └─Conv2d: 3-1                  [1, 128, 16, 16]          589,952
│    │    └─Conv2d: 3-2                  [1, 512, 16, 16]          590,336
│    └─ReLU: 2-4                         [1, 512, 16, 16]          --
│    └─ResBlock: 2-5                     [1, 512, 16, 16]          --
│    │    └─Conv2d: 3-3                  [1, 128, 16, 16]          589,952
│    │    └─Conv2d: 3-4                  [1, 512, 16, 16]          590,336
│    └─ReLU: 2-6                         [1, 512, 16, 16]          --
│    └─DepthToSpace: 2-7                 [1, 128, 32, 32]    

Entropy-Based Layers

In [6]:
#export
class Quantizer(torch.autograd.Function):
    """Quantize values to integer levels 0..LEVELS-1 with a straight-through gradient.

    Forward computes floor(x * LEVELS) clamped to [0, LEVELS - 1]. For x in [0, 1)
    this is the interval index l with l/LEVELS <= x < (l+1)/LEVELS. Inputs outside
    [0, 1) (e.g. exactly 1.0) are clamped onto the first/last level instead of
    passing through unquantized. Backward returns the incoming gradient unchanged.
    """
    LEVELS = 16

    @staticmethod
    def forward(ctx, i):
        levels = Quantizer.LEVELS
        # With a power-of-two LEVELS, i * levels is exact, so floor() reproduces the
        # half-open interval test precisely. clamp() keeps out-of-range inputs from
        # aliasing a real level (1.0 would otherwise look like level 1).
        return torch.floor(i * levels).clamp(0, levels - 1)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat quantization as the identity.
        return grad_output


def quantize_values(x):
    """Apply :class:`Quantizer` to ``x``; differentiable via a straight-through gradient."""
    return Quantizer.apply(x)

In [7]:
#export
class Mask(torch.autograd.Function):
    """Expand a quantized importance map into a 64-channel binary mask.

    The input ``i`` must have shape (N, 1, H, W). It is expected to hold quantized
    levels such as the output of ``quantize_values``. For a pixel with level q,
    channels k with k < (64 / 16) * q = 4q are 1 and the rest are 0, giving an
    (N, 64, H, W) output in the default floating dtype.
    """

    @staticmethod
    def forward(ctx, i):
        N, C, H, W = i.shape
        if C != 1:
            raise RuntimeError(
                f"Mask expects a single-channel map of shape (N, 1, H, W), got {tuple(i.shape)}")
        n = 64  # number of mask channels
        L = 16  # number of quantization levels in the input
        # Broadcast (1, n, 1, 1) channel indices against (N, 1, H, W) thresholds
        # -> (N, n, H, W) in one comparison instead of a per-channel Python loop.
        channel_idx = torch.arange(n, device=i.device).view(1, n, 1, 1)
        return (channel_idx < (n / L) * i).to(torch.get_default_dtype())

    @staticmethod
    def backward(ctx, grad_output):
        N, C, H, W = grad_output.shape
        # NOTE(review): this returns a constant all-ones gradient and ignores
        # grad_output, so the gradient reaching the importance map does not depend
        # on the downstream loss. Kept unchanged to preserve current behaviour;
        # confirm this is intended.
        return torch.ones(N, 1, H, W, device=grad_output.device)


def generate_mask(x):
    """Apply :class:`Mask` to a (N, 1, H, W) quantized importance map."""
    return Mask.apply(x)

In [8]:
# Seed the RNG so the random demo tensors in this notebook reproduce on re-run.
torch.manual_seed(0)
a = torch.rand(4, 4, requires_grad=True)
print(a)

tensor([[0.4283, 0.8598, 0.1396, 0.5758],
        [0.5413, 0.2879, 0.2739, 0.4917],
        [0.2935, 0.0594, 0.9540, 0.7086],
        [0.8742, 0.8271, 0.9119, 0.7585]], requires_grad=True)


In [9]:
# Binarize the demo tensor and show input and output together.
b = bin_values(a)
print(a, b, sep="\n")

tensor([[0.4283, 0.8598, 0.1396, 0.5758],
        [0.5413, 0.2879, 0.2739, 0.4917],
        [0.2935, 0.0594, 0.9540, 0.7086],
        [0.8742, 0.8271, 0.9119, 0.7585]], requires_grad=True)
tensor([[0., 1., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 1., 1.],
        [1., 1., 1., 1.]], grad_fn=<BinarizerBackward>)


In [10]:
loss = torch.sum(b)
print('Loss:{}'.format(loss))

Loss:9.0


In [11]:
# Backpropagate the summed binarized output to `a`; in the recorded output a.grad
# is all ones, i.e. the gradient passes through bin_values unchanged.
loss.backward()
a.grad

  Variable._execution_engine.run_backward(


tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])

In [12]:
# A random (N=1, C=1, 4x4) importance map with values in [0, 1).
imp_map = torch.rand((1, 1, 4, 4), requires_grad=True)
print(imp_map)

tensor([[[[0.0903, 0.7760, 0.5508, 0.6340],
          [0.0921, 0.9516, 0.9845, 0.4161],
          [0.1060, 0.4921, 0.5821, 0.0448],
          [0.1448, 0.3033, 0.6726, 0.7270]]]], requires_grad=True)


In [13]:
# Quantize the importance map and show the raw and quantized values together.
qimp = quantize_values(imp_map)
print(imp_map, qimp, sep="\n")

tensor([[[[0.0903, 0.7760, 0.5508, 0.6340],
          [0.0921, 0.9516, 0.9845, 0.4161],
          [0.1060, 0.4921, 0.5821, 0.0448],
          [0.1448, 0.3033, 0.6726, 0.7270]]]], requires_grad=True)
tensor([[[[ 1., 12.,  8., 10.],
          [ 1., 15., 15.,  6.],
          [ 1.,  7.,  9.,  0.],
          [ 2.,  4., 10., 11.]]]], grad_fn=<QuantizerBackward>)


In [14]:
# Expand the quantized map into the channel mask and show both.
mask = generate_mask(qimp)
print(qimp, mask, sep="\n")

tensor([[[[ 1., 12.,  8., 10.],
          [ 1., 15., 15.,  6.],
          [ 1.,  7.,  9.,  0.],
          [ 2.,  4., 10., 11.]]]], grad_fn=<QuantizerBackward>)
tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 0.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 0.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 0.],
          [1., 1., 1., 1.]],

         ...,

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], grad_fn=<MaskBackward>)


In [15]:
# Backpropagate through generate_mask and quantize_values back to imp_map.
loss = mask.sum()
loss.backward()

In [16]:
# Gradient that reached the raw importance map (Mask.backward returns all ones,
# and the quantizer passes it straight through).
imp_map.grad

tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])

In [17]:
# Trace pixel (0, 0): raw importance, its quantized level, its mask channels,
# and the number of active channels.
print(imp_map[0, 0, 0, 0], qimp[0, 0, 0, 0], mask[0, :, 0, 0], mask[0, :, 0, 0].sum(), sep="\n")

tensor(0.0903, grad_fn=<SelectBackward0>)
tensor(1., grad_fn=<SelectBackward0>)
tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<SelectBackward0>)
tensor(4., grad_fn=<SumBackward0>)


In [18]:
mask.shape

torch.Size([1, 64, 4, 4])

In [19]:
# Convert this notebook into exp/nb_02.py (presumably from the cells marked #export).
!python notebook2script.py 02_decoder.ipynb

Converted 02_decoder.ipynb to exp/nb_02.py
