In [1]:
import torch
from torch import nn

class Bottleneck(nn.Module):

    def __init__(self, inplanes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, bias=False, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes, kernel_size=3, bias=False, padding=1)
        self.bn3 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # SE
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_down = nn.Conv2d(
            planes, planes // 4, kernel_size=1, bias=False)
        self.conv_up = nn.Conv2d(
            planes // 4, planes, kernel_size=1, bias=False)
        self.sig = nn.Sigmoid()
        # Downsample
        self.downsample = nn.Conv2d(inplanes, planes,
                          kernel_size=1, stride=stride, bias=False)

        self.bn4 = nn.BatchNorm2d(planes)
        self.stride = stride


    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out1 = self.global_pool(out)
        out1 = self.conv_down(out1)
        out1 = self.relu(out1)
        out1 = self.conv_up(out1)
        out1 = self.sig(out1)

        if self.downsample is not None:
            residual = self.downsample(x)
            residual = self.bn4(residual)

        res = out1 * out + residual

        return res


In [74]:
class SEResNet(nn.Module):
    def __init__(self, in_channels):
        super(SEResNet,self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=7, stride=2, padding=3,
                                bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self,x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        return x

In [9]:
from thop import profile

dummy_images = torch.rand(1, 32, 320, 80)
flops, params = profile(Bottleneck(64, 128, 2), inputs=(dummy_images,))
print('FLOPs = ' + str(flops*2/1000**3) + 'G')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.activation.Sigmoid'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class '__main__.Bottleneck'>. Treat it as zero Macs and zero Params.[00m
FLOPs = 1.43986304G


In [112]:
dummy_images = torch.rand(1, 10, 240, 320)
flops, params = profile(SEResNet(10), inputs=(dummy_images,))
print('FLOPs = ' + str(flops*2/1000**3) + 'G')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[91m[WARN] Cannot find rule for <class '__main__.SEResNet'>. Treat it as zero Macs and zero Params.[00m
FLOPs = 0.6045696G


In [31]:
import math
from torch.autograd import Variable

class ConvLSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim, kernel_size, stride):
        """
        Initialize ConvLSTM cell.
        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel for both cnn and rnn.
        cnn_dropout, rnn_dropout: float
            cnn_dropout: dropout rate for convolutional input.
            rnn_dropout: dropout rate for convolutional state.
        bias: bool
            Whether or not to add the bias.
        peephole: bool
            add connection between cell state to gates
        layer_norm: bool
            layer normalization 
        """

        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = int(self.kernel_size/2)
        self.stride = stride
        self.bias = True
        
        self.input_conv = nn.Conv2d(in_channels=self.input_dim, out_channels=4*self.hidden_dim,
                                  kernel_size=self.kernel_size,
                                  stride = self.stride,
                                  padding=self.padding,
                                  bias=self.bias)
        self.rnn_conv = nn.Conv2d(self.hidden_dim, out_channels=4*self.hidden_dim, 
                                  kernel_size = self.kernel_size,
                                  padding=math.floor(self.kernel_size/2),
                                  bias=self.bias)
        
    
    def forward(self, x):

        x_conv = self.input_conv(x)

        self.h = Variable(torch.zeros((x_conv.shape[0], self.hidden_dim, x_conv.shape[2], x_conv.shape[3])).to(x.device))
        self.c = Variable(torch.zeros((x_conv.shape[0], self.hidden_dim, x_conv.shape[2], x_conv.shape[3])).to(x.device))
        
        h_cur, c_cur = self.h, self.c

        # separate i, f, c o
        x_i, x_f, x_c, x_o = torch.split(x_conv, self.hidden_dim, dim=1)
        
        h_conv = self.rnn_conv(h_cur)
        # separate i, f, c o
        h_i, h_f, h_c, h_o = torch.split(h_conv, self.hidden_dim, dim=1)
        
        f = torch.sigmoid((x_f + h_f))  #H/2 * W/2 * 256
        i = torch.sigmoid((x_i + h_i))  #H/2 * W/2 * 256
        
        g = torch.tanh((x_c + h_c))  #H/2 * W/2 * 256
        c_next = f * c_cur + i * g  #3 * H/2 * W/2 * 256

        o = torch.sigmoid((x_o + h_o))  #H/2 * W/2 * 256

        h_next = o * torch.tanh(c_next)  #H/2 * W/2 * 256

        self.h = h_next
        self.c = c_next
        self.has_memory = True

        return h_next   #H/2 * W/2 * 256 * 8

In [38]:
dummy_images = torch.rand(1, 256, 2, 3)
flops, params = profile(ConvLSTMCell(256,256,3,2), inputs=(dummy_images,))
print('FLOPs = ' + str(flops*2/1000**3) + 'G')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[91m[WARN] Cannot find rule for <class '__main__.ConvLSTMCell'>. Treat it as zero Macs and zero Params.[00m
FLOPs = 0.01888256G


In [50]:
class BoxPredictor(nn.Module):
    def __init__(self):
        super().__init__()
        self.cls_header = nn.Conv2d(256, 6 * 3, kernel_size=3, stride=1, padding=1)
        self.reg_header = nn.Conv2d(256, 6 * 4, kernel_size=3, stride=1, padding=1)
    
    def forward(self,x):
        x1 = self.cls_header(x)
        x2 = self.reg_header(x)

In [55]:
dummy_images = torch.rand(1, 256, 1, 1)
flops, params = profile(BoxPredictor(), inputs=(dummy_images,))
print('FLOPs = ' + str(flops*2/1000**3) + 'G')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[91m[WARN] Cannot find rule for <class '__main__.BoxPredictor'>. Treat it as zero Macs and zero Params.[00m
FLOPs = 0.00019362G


In [56]:
class Focus(nn.Module):
    """Focus width and height information into channel space."""

    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="gelu"):
        super().__init__()
        pad = (ksize - 1) // 2
        self.conv = nn.Conv2d(
            in_channels * 4,
            out_channels,
            kernel_size=ksize,
            stride=stride,
            padding=pad,
            groups=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def patch_and_conv(self, x, conv):
        # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
        patch_top_left = x[..., ::2, ::2]
        patch_top_right = x[..., ::2, 1::2]
        patch_bot_left = x[..., 1::2, ::2]
        patch_bot_right = x[..., 1::2, 1::2]
        x = torch.cat(
            (
                patch_top_left,
                patch_bot_left,
                patch_top_right,
                patch_bot_right,
            ),
            dim=1,
        )
        return self.bn(conv(x))
    
    def forward(self, x):
        return self.patch_and_conv(x, self.conv)

In [107]:
dummy_images = torch.rand(1, 10, 512, 640)
flops, params = profile(Focus(10, 64, 3), inputs=(dummy_images,))
print('FLOPs = ' + str(flops*2/1000**3) + 'G')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[91m[WARN] Cannot find rule for <class '__main__.Focus'>. Treat it as zero Macs and zero Params.[00m
FLOPs = 3.79584512G


In [68]:
class dark1(nn.Module):
    """Focus width and height information into channel space."""

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        pad = (3 - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=pad,
            groups=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels//2,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels//2)
        self.conv3 = nn.Conv2d(
            out_channels//2,
            out_channels,
            kernel_size=3,
            stride=1,
            bias=False,
        )
        self.bn3 = nn.BatchNorm2d(out_channels)

    
    def forward(self, x):
        return self.bn3(self.conv3(self.bn2(self.conv2(self.bn(self.conv(x))))))

In [70]:
dummy_images = torch.rand(1, 256, 16, 20)
flops, params = profile(dark1(256,256,2), inputs=(dummy_images,))
print('FLOPs = ' + str(flops*2/1000**3) + 'G')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[91m[WARN] Cannot find rule for <class '__main__.dark1'>. Treat it as zero Macs and zero Params.[00m
FLOPs = 0.128098304G


In [71]:
class dark2(nn.Module):
    """Focus width and height information into channel space."""

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        pad = (3 - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=pad,
            groups=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels//2,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels//2)
        self.conv3 = nn.Conv2d(
            out_channels//2,
            out_channels,
            kernel_size=3,
            stride=1,
            bias=False,
        )
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.conv4 = nn.Conv2d(
            out_channels,
            out_channels//2,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.bn4 = nn.BatchNorm2d(out_channels//2)
        self.conv5 = nn.Conv2d(
            out_channels//2,
            out_channels,
            kernel_size=3,
            stride=1,
            bias=False,
        )
        self.bn5 = nn.BatchNorm2d(out_channels)

    
    def forward(self, x):
        return self.bn5(self.conv5(self.bn4(self.conv4(self.bn3(self.conv3(self.bn2(self.conv2(self.bn(self.conv(x))))))))))

In [73]:
dummy_images = torch.rand(1, 256, 32, 40)
flops, params = profile(dark2(256,256,2), inputs=(dummy_images,))
print('FLOPs = ' + str(flops*2/1000**3) + 'G')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[91m[WARN] Cannot find rule for <class '__main__.dark2'>. Treat it as zero Macs and zero Params.[00m
FLOPs = 0.677931008G


In [78]:
class spp_block(nn.Module):
    """Focus width and height information into channel space."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        hidden_channels = out_channels // 2
        self.bconv1 = nn.Conv2d(
            in_channels,
            hidden_channels,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.bbn1 = nn.BatchNorm2d(hidden_channels)

        self.p1 = nn.MaxPool2d(kernel_size=5, stride=1, padding=5 // 2)
        self.p2 = nn.MaxPool2d(kernel_size=9, stride=1, padding=9 // 2)
        self.p3 = nn.MaxPool2d(kernel_size=13, stride=1, padding=13 // 2)

        conv2_channels = hidden_channels * 4

        self.bconv2 = nn.Conv2d(
            conv2_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.bbn2 = nn.BatchNorm2d(out_channels)
        
        self.conv4 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            bias=False,
        )
        self.bn4 = nn.BatchNorm2d(out_channels)
        self.conv5 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.bn5 = nn.BatchNorm2d(out_channels)

    
    def forward(self, x):
        x = self.bn(self.conv(x))
        x = self.bn2(self.conv2(x))
        x = self.bbn1(self.bconv1(x))
        x1 = self.p1(x)
        x2 = self.p2(x)
        x3 = self.p3(x)
        x = self.bbn2(self.bconv2(torch.cat([x, x1, x2, x3], dim=1)))
        x = self.conv2(x)
        return x

In [79]:
dummy_images = torch.rand(1, 256, 8, 10)
flops, params = profile(spp_block(256,256), inputs=(dummy_images,))
print('FLOPs = ' + str(flops*2/1000**3) + 'G')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[91m[WARN] Cannot find rule for <class '__main__.spp_block'>. Treat it as zero Macs and zero Params.[00m
FLOPs = 0.111353856G


In [121]:
class fpn(nn.Module):
    """Focus width and height information into channel space."""

    def __init__(self):
        super().__init__()
        self.lc0 = nn.Conv2d(
            256,
            256,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.lc0bn = nn.BatchNorm2d(256)
        
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

        self.c3p4c1 = nn.Conv2d(
            512,
            256,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.c3p4c1bn = nn.BatchNorm2d(256)
        self.c3p4c2 = nn.Conv2d(
            512,
            256,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.c3p4c2bn = nn.BatchNorm2d(256)
        self.c3p4bc1 = nn.Conv2d(
            256,
            256,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.c3p4bc1bn = nn.BatchNorm2d(256)
        self.c3p4bc2 = nn.Conv2d(
            256,
            256,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.c3p4bc2bn = nn.BatchNorm2d(256)
        self.c3p4c3 = nn.Conv2d(
            512,
            256,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.c3p4c3bn = nn.BatchNorm2d(256)
        
        self.rc1 = nn.Conv2d(
            256,
            256,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.rc1bn = nn.BatchNorm2d(256)

        self.buc2 = nn.Conv2d(
            256,
            256,
            kernel_size=3,
            padding=1,
            stride=2,
            bias=False,
        )
        self.buc2bn = nn.BatchNorm2d(256)

        self.buc1 = nn.Conv2d(
            256,
            256,
            kernel_size=3,
            padding=1,
            stride=2,
            bias=False,
        )
        self.buc1bn = nn.BatchNorm2d(256)

    
    def forward(self, features):
        [x2, x1, x0] = features

        fpn_out0 = self.lc0bn(self.lc0(x0))  # 
        f_out0 = self.upsample(fpn_out0)  # 
        f_out0 = torch.cat([f_out0, x1], 1)  # 
        x_1 = self.c3p4c1bn(self.c3p4c1(f_out0))
        x_2 = self.c3p4c2bn(self.c3p4c2(f_out0))
        x_1 = self.c3p4bc1bn(self.c3p4bc1(x_1))
        x_1 = self.c3p4bc2bn(self.c3p4bc2(x_1))
        x = torch.cat((x_1, x_2), dim=1)
        f_out0 = self.c3p4c3bn(self.c3p4c3(x))

        fpn_out1 = self.rc1bn(self.rc1(f_out0))  # 
        f_out1 = self.upsample(fpn_out1)  # 
        f_out1 = torch.cat([f_out1, x2], 1)  # 
        x_1 = self.c3p4c1bn(self.c3p4c1(f_out1))
        x_2 = self.c3p4c2bn(self.c3p4c2(f_out1))
        x_1 = self.c3p4bc1bn(self.c3p4bc1(x_1))
        x_1 = self.c3p4bc2bn(self.c3p4bc2(x_1))
        x = torch.cat((x_1, x_2), dim=1)
        pan_out2 = self.c3p4c3bn(self.c3p4c3(x))

        p_out1 = self.buc2bn(self.buc2(pan_out2))  # 
        p_out1 = torch.cat([p_out1, fpn_out1], 1)  # 
        x_1 = self.c3p4c1bn(self.c3p4c1(p_out1))
        x_2 = self.c3p4c2bn(self.c3p4c2(p_out1))
        x_1 = self.c3p4bc1bn(self.c3p4bc1(x_1))
        x_1 = self.c3p4bc2bn(self.c3p4bc2(x_1))
        x = torch.cat((x_1, x_2), dim=1)
        pan_out1 = self.c3p4c3bn(self.c3p4c3(x))

        p_out0 = self.buc1bn(self.buc1(pan_out1))  # 
        p_out0 = torch.cat([p_out0, fpn_out0], 1)  # 
        x_1 = self.c3p4c1bn(self.c3p4c1(p_out0))
        x_2 = self.c3p4c2bn(self.c3p4c2(p_out0))
        x_1 = self.c3p4bc1bn(self.c3p4bc1(x_1))
        x_1 = self.c3p4bc2bn(self.c3p4bc2(x_1))
        x = torch.cat((x_1, x_2), dim=1)
        pan_out0 = self.c3p4c3bn(self.c3p4c3(x))

        outputs = [pan_out2, pan_out1, pan_out0]
        return outputs

In [122]:
dummy_images = [torch.rand(1, 256, 32, 40),torch.rand(1, 256, 16, 20),torch.rand(1, 256, 8, 10)]
flops, params = profile(fpn(), inputs=[dummy_images])
print('FLOPs = ' + str(flops*2/1000**3) + 'G')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
[91m[WARN] Cannot find rule for <class '__main__.fpn'>. Treat it as zero Macs and zero Params.[00m
FLOPs = 4.7296512G


In [115]:
class fpn(nn.Module):
    """Focus width and height information into channel space."""

    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(
            256,
            256,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.stembn = nn.BatchNorm2d(256)
        
        self.cls_conv1 = nn.Conv2d(
            256,
            256,
            kernel_size=3,
            stride=1,
            bias=False,
        )
        self.cls_conv1bn = nn.BatchNorm2d(256)
        self.cls_conv2 = nn.Conv2d(
            256,
            256,
            kernel_size=3,
            stride=1,
            bias=False,
        )
        self.cls_conv2bn = nn.BatchNorm2d(256)

        self.reg_conv1 = nn.Conv2d(
            256,
            256,
            kernel_size=3,
            stride=1,
            bias=False,
        )
        self.reg_conv1bn = nn.BatchNorm2d(256)
        self.reg_conv2 = nn.Conv2d(
            256,
            256,
            kernel_size=3,
            stride=1,
            bias=False,
        )
        self.reg_conv2bn = nn.BatchNorm2d(256)

        self.cls_pred = nn.Conv2d(
            256,
            3,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.reg_pred = nn.Conv2d(
            256,
            4,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.obj_pred = nn.Conv2d(
            256,
            1,
            kernel_size=1,
            stride=1,
            bias=False,
        )

    
    def forward(self, x):
        x = self.stembn(self.stem(x))
        cls_x = x
        reg_x = x

        cls_feat = self.cls_conv2bn(self.cls_conv2(self.cls_conv1bn(self.cls_conv1(cls_x))))
        cls_output = self.cls_pred(cls_feat)

        reg_feat = self.reg_conv2bn(self.reg_conv2(self.cls_conv1bn(self.cls_conv1(reg_x))))
        reg_output = self.reg_pred(reg_feat)

        obj_output = self.obj_pred(reg_feat)

In [116]:
dummy_images = [torch.rand(1, 256, 32, 40),torch.rand(1, 256, 16, 20),torch.rand(1, 256, 8, 10)]
flops, params = profile(fpn(), inputs=(dummy_images[0],))
print('FLOPs = ' + str(flops*2/1000**3) + 'G')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[91m[WARN] Cannot find rule for <class '__main__.fpn'>. Treat it as zero Macs and zero Params.[00m
FLOPs = 5.24537856G


In [108]:
class Temporal_Active_Focus(Focus):
    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="gelu"):
        super().__init__(in_channels, out_channels, ksize, stride, act)
        #self.conv2 = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)

        self.conv3 = nn.Conv2d(in_channels * 4, in_channels * 2, 1, 1)
        self.relu = nn.ReLU()
        self.conv4 = nn.Conv2d(in_channels * 2, in_channels * 4, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.conv4(self.relu(self.conv3(torch.exp(x)))))

In [111]:
dummy_images = torch.rand(1, 40, 128, 160)
flops, params = profile(Temporal_Active_Focus(10, 64, 3), inputs=(dummy_images,))
print('FLOPs = ' + str(flops*2/1000**3) + 'G')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.activation.Sigmoid'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class '__main__.Temporal_Active_Focus'>. Treat it as zero Macs and zero Params.[00m
FLOPs = 0.0679936G
