In [1]:
import torch
from torch import nn

class Bottleneck(nn.Module):
    """Residual block of three 3x3 convolutions gated by squeeze-and-excitation.

    Main path: conv1 -> bn1 -> relu -> conv2 (strided) -> bn2 -> relu ->
    conv3 -> bn3. The result is rescaled per channel by an SE gate (global
    average pool -> 1x1 conv down -> relu -> 1x1 conv up -> sigmoid) and added
    to a projected shortcut (1x1 strided conv followed by batch norm).

    Note that, despite the name, every convolution on the main path is 3x3,
    and no activation is applied after the residual addition.

    Parameters
    ----------
    inplanes: int
        Number of channels of the input tensor.
    planes: int
        Number of channels produced by the block.
    stride: int
        Stride of ``conv2`` and of the shortcut projection.
    reduction: int
        Channel reduction ratio of the SE squeeze layer. The squeeze width is
        ``planes // reduction``, clamped to at least one channel.

    Raises
    ------
    ValueError
        If ``reduction`` is smaller than 1.
    """

    def __init__(self, inplanes, planes, stride=1, reduction=4):
        super().__init__()
        if reduction < 1:
            raise ValueError("reduction must be >= 1, got {}".format(reduction))

        # Submodules are created in a fixed order so that seeded weight
        # initialisation and state_dict layout stay stable.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, bias=False, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes, kernel_size=3, bias=False, padding=1)
        self.bn3 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        # SE: keep at least one squeeze channel, otherwise a narrow block
        # (planes < reduction) would get a zero-channel squeeze layer.
        squeeze_planes = max(1, planes // reduction)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_down = nn.Conv2d(
            planes, squeeze_planes, kernel_size=1, bias=False)
        self.conv_up = nn.Conv2d(
            squeeze_planes, planes, kernel_size=1, bias=False)
        self.sig = nn.Sigmoid()

        # Downsample: the shortcut is always projected, even when the shapes
        # of input and output already match.
        self.downsample = nn.Conv2d(inplanes, planes,
                                    kernel_size=1, stride=stride, bias=False)
        self.bn4 = nn.BatchNorm2d(planes)
        self.stride = stride

    def forward(self, x):
        """Apply the block to ``x`` and return the gated residual sum."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # SE gate: one weight in (0, 1) per channel, broadcast over H x W.
        gate = self.global_pool(out)
        gate = self.conv_down(gate)
        gate = self.relu(gate)
        gate = self.conv_up(gate)
        gate = self.sig(gate)

        # The check only matters if a caller has replaced the projection with
        # None after construction; __init__ always creates it.
        if self.downsample is not None:
            residual = self.downsample(x)
            residual = self.bn4(residual)

        return gate * out + residual


In [12]:
class SEResNet(nn.Module):
    """Network stem: a strided 7x7 convolution, batch norm and ReLU.

    Maps ``in_channels`` input channels to 32 feature channels; the
    convolution uses stride 2 with padding 3.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            32,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run the stem on ``x`` and return the activated feature map."""
        return self.relu(self.bn1(self.conv1(x)))

In [9]:
from thop import profile

# Profile a single Bottleneck(64 -> 128 channels, stride 2) on one random
# input of shape (1, 64, 60, 80).
dummy_images = torch.rand(1, 64, 60, 80)
flops, params = profile(Bottleneck(64, 128, 2), inputs=(dummy_images,))
# NOTE(review): thop's own log lines refer to "Macs", so the value bound to
# `flops` is presumably a multiply-accumulate count and the `* 2` converts it
# to FLOPs -- confirm against the thop documentation.
# Dividing by 1000**3 reports the figure in units of 10^9 (the 'G' suffix).
print('FLOPs = ' + str(flops*2/1000**3) + 'G')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.activation.Sigmoid'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class '__main__.Bottleneck'>. Treat it as zero Macs and zero Params.[00m
FLOPs = 1.43986304G


In [13]:
# Profile the SEResNet stem with 10 input channels on one random input of
# shape (1, 10, 480, 640).
dummy_images = torch.rand(1, 10, 480, 640)
flops, params = profile(SEResNet(10), inputs=(dummy_images,))
# NOTE(review): thop's own log lines refer to "Macs", so the value bound to
# `flops` is presumably a multiply-accumulate count and the `* 2` converts it
# to FLOPs -- confirm against the thop documentation.
# Dividing by 1000**3 reports the figure in units of 10^9 (the 'G' suffix).
print('FLOPs = ' + str(flops*2/1000**3) + 'G')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[91m[WARN] Cannot find rule for <class '__main__.SEResNet'>. Treat it as zero Macs and zero Params.[00m
FLOPs = 2.4182784G


In [31]:
import math
from torch.autograd import Variable

class ConvLSTMCell(nn.Module):
    """One convolutional LSTM step over a 4-D feature map.

    The input is projected by ``input_conv`` (optionally strided) and the
    hidden state by ``rnn_conv``. Each projection yields ``4 * hidden_dim``
    channels, split in order into the input, forget, candidate and output
    gate pre-activations.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, stride):
        """
        Initialize ConvLSTM cell.
        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: int
            Side length of the square convolutional kernel, shared by the
            input and the recurrent convolution. Padding is
            ``kernel_size // 2`` for both.
        stride: int
            Stride of the input convolution. The recurrent convolution always
            uses stride 1, so the hidden state has the spatial size of the
            strided input projection.
        """

        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = self.kernel_size // 2
        self.stride = stride
        self.bias = True

        # State of the most recent forward pass; defined up front so the
        # attributes can be inspected before the first call.
        self.h = None
        self.c = None
        self.has_memory = False

        self.input_conv = nn.Conv2d(in_channels=self.input_dim, out_channels=4*self.hidden_dim,
                                  kernel_size=self.kernel_size,
                                  stride = self.stride,
                                  padding=self.padding,
                                  bias=self.bias)
        self.rnn_conv = nn.Conv2d(self.hidden_dim, out_channels=4*self.hidden_dim,
                                  kernel_size = self.kernel_size,
                                  padding=self.padding,
                                  bias=self.bias)

    def forward(self, x, state=None):
        """
        Run one LSTM step.
        Parameters
        ----------
        x: torch.Tensor
            4-D input with ``input_dim`` channels in dimension 1.
        state: (torch.Tensor, torch.Tensor) or None
            Previous ``(h, c)`` pair, each shaped like the returned hidden
            state. When None (the default) the step starts from an all-zero
            state.
        Returns
        -------
        torch.Tensor
            The new hidden state. It is also stored on ``self.h``, and the
            new cell state on ``self.c``.
        """

        x_conv = self.input_conv(x)

        if state is None:
            batch, _, height, width = x_conv.shape
            # Allocate directly on the right device, in the dtype of the
            # recurrent weights, so non-float32 modules also work.
            h_cur = torch.zeros(
                (batch, self.hidden_dim, height, width),
                dtype=self.rnn_conv.weight.dtype,
                device=x_conv.device,
            )
            c_cur = torch.zeros_like(h_cur)
        else:
            h_cur, c_cur = state

        # separate i, f, c, o
        x_i, x_f, x_c, x_o = torch.split(x_conv, self.hidden_dim, dim=1)

        h_conv = self.rnn_conv(h_cur)
        # separate i, f, c, o
        h_i, h_f, h_c, h_o = torch.split(h_conv, self.hidden_dim, dim=1)

        f = torch.sigmoid(x_f + h_f)  # forget gate
        i = torch.sigmoid(x_i + h_i)  # input gate

        g = torch.tanh(x_c + h_c)     # candidate cell update
        c_next = f * c_cur + i * g

        o = torch.sigmoid(x_o + h_o)  # output gate

        h_next = o * torch.tanh(c_next)

        self.h = h_next
        self.c = c_next
        self.has_memory = True

        return h_next

In [32]:
# Profile one ConvLSTMCell (256 -> 256 channels, 3x3 kernel, stride 2) on one
# random input of shape (1, 256, 15, 20).
dummy_images = torch.rand(1, 256, 15, 20)
flops, params = profile(ConvLSTMCell(256,256,3,2), inputs=(dummy_images,))
# NOTE(review): only Conv2d gets a counting rule in thop's log for this run,
# so the gate activations and element-wise products are presumably not
# included in the total -- confirm against the thop documentation.
# `* 2` presumably converts a MAC count to FLOPs; 1000**3 scales to 10^9 ('G').
print('FLOPs = ' + str(flops*2/1000**3) + 'G')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[91m[WARN] Cannot find rule for <class '__main__.ConvLSTMCell'>. Treat it as zero Macs and zero Params.[00m
FLOPs = 0.7553024G
