In [11]:
import torch
import torch.nn as nn
import torch.onnx

from torchsummary import summary

In [12]:
from deephar.layers import *

In [17]:
class ReceptionBlock(nn.Module):
    def __init__(self):
        super(ReceptionBlock, self).__init__()
        
        self.acb = ACB(input_filters=576, output_filters=288, kernel_size=(1,1), stride=(1,1), padding=0)
        
        # calculating padding using
        # padding_zeroes = (kernel_size - 1 ) / 2
        self.sacb1 = Residual(Sep_ACB(input_filters=288, output_filters=288, kernel_size=(5,5), stride=(1,1), padding=2))
        self.sacb2 = Residual(Sep_ACB(input_filters=288, output_filters=288, kernel_size=(5,5), stride=(1,1), padding=2))
        self.sacb3 = Residual(Sep_ACB(input_filters=288, output_filters=288, kernel_size=(5,5), stride=(1,1), padding=2))
        self.sacb4 = Residual(Sep_ACB(input_filters=288, output_filters=288, kernel_size=(5,5), stride=(1,1), padding=2))
        self.sacb5 = Residual(Sep_ACB(input_filters=288, output_filters=288, kernel_size=(5,5), stride=(1,1), padding=2))
        self.sacb6 = Residual(Sep_ACB(input_filters=576, output_filters=576, kernel_size=(5,5), stride=(1,1), padding=2))
        self.sacb7 = Sep_ACB(input_filters=288, output_filters=576, kernel_size=(5,5), stride=(1,1), padding=2)
        
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2,2))
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2,2))
                
    def forward(self, x):
        a = self.maxpool1(x)
        a = self.acb(a)
        print(a.shape)
        a = self.sacb1(a)
        
        b = self.maxpool2(a)
        b = self.sacb2(b)
        b = self.sacb3(b)
        b = self.sacb4(b)
        b = nn.functional.interpolate(b, scale_factor=2, mode="nearest") # Maybe align_corners needs to be set?
        
        b = b + self.sacb5(a)
        b = self.sacb7(b)
        b = nn.functional.interpolate(b, scale_factor=2, mode="nearest") # Maybe align_corners needs to be set?
        
        return b + self.sacb6(x)
        

In [18]:
summary(ReceptionBlock(), input_size=(576, 32, 32))

torch.Size([2, 288, 16, 16])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         MaxPool2d-1          [-1, 576, 16, 16]               0
              ReLU-2          [-1, 576, 16, 16]               0
            Conv2d-3          [-1, 288, 16, 16]         165,888
       BatchNorm2d-4          [-1, 288, 16, 16]             576
              ReLU-5          [-1, 288, 16, 16]               0
            Conv2d-6          [-1, 288, 16, 16]           2,592
            Conv2d-7          [-1, 288, 16, 16]          82,944
   SeparableConv2D-8          [-1, 288, 16, 16]               0
       BatchNorm2d-9          [-1, 288, 16, 16]             576
         Residual-10          [-1, 288, 16, 16]               0
        MaxPool2d-11            [-1, 288, 8, 8]               0
             ReLU-12            [-1, 288, 8, 8]               0
           Conv2d-13            [-1, 288, 8, 8]           2,592
          