In [1]:
import os
import numpy as np
import torch
import time
import math
import torch
import collections
from torch import nn
from PIL import Image
from argparse import ArgumentParser
from torch.nn.modules.utils import _pair
from collections import OrderedDict
from torch.autograd import Variable

import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torchvision import datasets, models
# Let cuDNN benchmark and auto-select convolution algorithms; this pays off when
# input sizes stay fixed, as in the timing runs below.
cudnn.benchmark = True

from torch.nn import init
# Compatibility shim: if this PyTorch build lacks torch._utils._rebuild_tensor_v2,
# install a fallback built on the older torch._utils._rebuild_tensor.
# NOTE(review): presumably this lets checkpoints pickled by newer PyTorch load
# under 0.3.x (the version reported in the next cell) -- confirm.
try:
    torch._utils._rebuild_tensor_v2
except AttributeError:
    def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
        # Rebuild with the v1 helper, then restore the two extra fields that the
        # v2 signature carries (requires_grad flag and backward hooks).
        tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
        tensor.requires_grad = requires_grad
        tensor._backward_hooks = backward_hooks
        return tensor
    torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2

In [2]:
# Record the PyTorch version this notebook was run with (the output below shows '0.3.1').
torch.__version__

'0.3.1'

In [3]:
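# MobileNetV2 backbone definition. NOTE: the original comment said "load pretrained MobileNetV2",
# but no pretrained weights are loaded in this notebook -- the backbone is randomly initialised.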
def conv_bn(inp, oup, stride):
    """Build a 3x3 conv (padding 1, no bias) -> BatchNorm2d -> in-place ReLU6 stack.

    Args:
        inp: input channels.
        oup: output channels.
        stride: convolution stride.
    Returns:
        nn.Sequential with the three layers in that order.
    """
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)


def conv_1x1_bn(inp, oup):
    """Build a pointwise 1x1 conv (no bias) -> BatchNorm2d -> in-place ReLU6 stack.

    Args:
        inp: input channels.
        oup: output channels.
    Returns:
        nn.Sequential with the three layers in that order.
    """
    layers = [
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)

class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block with a dilated depthwise conv.

    Layers: pointwise expand (inp -> inp * expand_ratio) -> BN -> ReLU6 ->
    depthwise 3x3 (stride ``stride``, padding = dilation = ``dalited``) -> BN ->
    ReLU6 -> pointwise linear projection to ``oup`` -> BN.
    An identity shortcut is added when ``stride == 1`` and ``inp == oup``.

    Args:
        inp: input channels.
        oup: output channels.
        stride: depthwise stride, must be 1 or 2.
        expand_ratio: width multiplier of the hidden (expanded) layer.
        dalited: dilation (and padding) of the depthwise conv.
    """

    def __init__(self, inp, oup, stride, expand_ratio, dalited):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        self.d = dalited
        assert stride in [1, 2]

        # Identity shortcut only when the block keeps both resolution and width.
        self.use_res_connect = self.stride == 1 and inp == oup

        hidden = inp * expand_ratio
        layers = [
            # pw: expand inp -> hidden
            nn.Conv2d(inp, hidden, 1, 1, 0, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
            # dw: per-channel 3x3; padding == dilation keeps the size at stride 1
            nn.Conv2d(hidden, hidden, 3, stride, padding=self.d, dilation=self.d, groups=hidden, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
            # pw-linear: project hidden -> oup with no activation
            nn.Conv2d(hidden, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out

class MobileNetV2(nn.Module):
    """MobileNetV2 with a dilated tail (output stride 8), used as a segmentation backbone.

    ``features`` holds, in order: the stem ``conv_bn``, one ``InvertedResidual``
    per repeat in ``interverted_residual_setting`` (17 with the default setting),
    a 1x1 expansion conv and a global average pool. ``M2_semantic`` indexes into
    these children directly, so their order must not change.

    Args:
        n_class: number of classifier outputs.
        input_size: square input resolution the classifier head is built for;
            must be divisible by 32.
        width_mult: channel width multiplier.
    """

    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        # setting of inverted residual blocks
        self.interverted_residual_setting = [
            # t (expansion), c (out channels), n (repeats), s (stride), d (dilation)
            [1, 16, 1, 1, 1],    # 1/2
            [6, 24, 2, 2, 1],    # 1/4
            [6, 32, 3, 2, 1],    # 1/8
            [6, 64, 4, 1, 2],    # 1/8
            [6, 96, 3, 1, 3],    # 1/8
            [6, 160, 3, 1, 5],   # 1/8
            [6, 320, 1, 1, 7],   # 1/8
        ]

        # building first layer
        assert input_size % 32 == 0
        input_channel = int(32 * width_mult)
        self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280
        self.features = [conv_bn(3, input_channel, 2)]
        output_stride = 2  # the stem conv has stride 2
        # building inverted residual blocks
        for t, c, n, s, d in self.interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                # Only the first block of each stage may downsample.
                stride = s if i == 0 else 1
                self.features.append(InvertedResidual(input_channel, output_channel, stride, t, d))
                input_channel = output_channel
            output_stride *= s
        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # Global pooling over the final feature map. The later stages use dilation
        # instead of striding, so the total stride is 8 (not the usual 32). The
        # kernel must therefore be input_size // output_stride: the old
        # `input_size/32` left a 4x4 map and was a float kernel on Python 3.
        self.features.append(nn.AvgPool2d(input_size // output_stride))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        # (N, last_channel, 1, 1) -> (N, last_channel). Using the batch size
        # rather than -1 makes a spatial-size mismatch raise an error instead of
        # silently folding spatial positions into the batch dimension.
        x = x.view(x.size(0), self.last_channel)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """He-style normal init for convs, unit/zero BN, N(0, 0.01) for Linear."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
                
                
# Randomly initialised MobileNetV2; M2_semantic reuses its feature layers by index.
net = MobileNetV2(n_class=1000)
mobilenet_v2 = list(net.features.children())
# M2_semantic indexes mobilenet_v2[0..17]: the stem conv and the 17 inverted-residual
# blocks. Entries 18 and 19 (1x1 conv and pooling) are unused there.
# Fail fast if the layout changes, instead of leaving a no-op len() mid-cell.
assert len(mobilenet_v2) == 20, len(mobilenet_v2)

def channel_shuffle(x, groups):
    """Interleave channels across ``groups`` groups (ShuffleNet channel shuffle).

    Args:
        x: 4-D tensor (N, C, H, W); C must be divisible by ``groups``.
        groups: number of channel groups.
    Returns:
        Tensor of the same shape whose channel order is the transpose of a
        (groups, C // groups) grid.
    Raises:
        AssertionError: if C is not divisible by ``groups``.
    """
    n, c, h, w = x.data.size()
    assert (c % groups == 0)
    per_group = c // groups
    grid = x.view(n, groups, per_group, h, w)
    swapped = torch.transpose(grid, 1, 2).contiguous()
    return swapped.view(n, -1, h, w)

class ShuffleBlock(nn.Module):
    """nn.Module wrapper around a channel shuffle with a fixed number of groups."""

    def __init__(self, groups):
        super(ShuffleBlock, self).__init__()
        self.groups = groups

    def forward(self, x):
        """Shuffle channels: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]."""
        n, c, h, w = x.size()
        g = self.groups
        split = x.view(n, g, int(c / g), h, w)
        interleaved = split.permute(0, 2, 1, 3, 4).contiguous()
        return interleaved.view(n, c, h, w)


class ABN(nn.Sequential):
    """BatchNorm2d followed by a per-channel PReLU, as children named "bn" and "act".

    Args:
        num_features: number of channels for both layers.
    """

    def __init__(self, num_features):
        super(ABN, self).__init__()
        self.add_module("bn", nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True))
        self.add_module("act", nn.PReLU(num_features))

class DSP(nn.Module):
    """Dilated spatial pyramid block.

    A grouped 1x1 conv reduces ``inplanes`` to ``out_c = round(c_tag * outplanes)``
    channels. Four depthwise-separable 3x3 branches (one per entry of ``dilation``)
    and one global-pool branch then run in parallel on that reduced tensor.
    Their outputs (5 * out_c channels) are concatenated and channel-shuffled with
    ``groups`` groups. The input is added when the shapes match, and the result
    goes through a PReLU.

    NOTE(review): the PReLU is sized ``outplanes`` while the output has
    ``5 * out_c`` channels. These agree only when 5 * round(c_tag * outplanes)
    == outplanes, e.g. c_tag=0.2 with outplanes a multiple of 5.

    Args:
        inplanes: input channels.
        outplanes: nominal output channels.
        c_tag: fraction of ``outplanes`` used as the per-branch width.
        groups: groups of the reduction conv and of the channel shuffle.
        dilation: four dilation rates, one per 3x3 branch.
    """

    def __init__(self, inplanes, outplanes, c_tag=0.2, groups=4, dilation=(1,2,3,4)):
        super(DSP, self).__init__()

        self.out_c = round(c_tag * outplanes)
        mid = self.out_c

        # Grouped 1x1 reduction shared by every branch.
        self.down = nn.Sequential(
            nn.Conv2d(inplanes, mid, 1, stride=1, groups=groups, bias=False),
            ABN(mid)
        )

        # Image-level context: global average pool -> 1x1 conv -> BN.
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(mid, mid, 1, stride=1),
            nn.BatchNorm2d(mid, eps=1e-05, momentum=0.1, affine=True)
        )

        # One depthwise-separable branch per dilation rate (created in order so
        # parameter registration and RNG consumption match branch_1..branch_4).
        self.branch_1 = self._make_branch(mid, dilation[0])
        self.branch_2 = self._make_branch(mid, dilation[1])
        self.branch_3 = self._make_branch(mid, dilation[2])
        self.branch_4 = self._make_branch(mid, dilation[3])

        self.groups = groups

        self.module_act = nn.PReLU(outplanes)

    @staticmethod
    def _make_branch(channels, rate):
        """Depthwise 3x3 (dilation/padding ``rate``) -> ABN -> pointwise 1x1 -> BN."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=rate, dilation=rate,
                      groups=channels, bias=False),
            ABN(channels),
            nn.Conv2d(channels, channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(channels, eps=1e-05, momentum=0.1, affine=True)
        )

    def forward(self, x):
        identity = x
        spatial = x.size()[2:]
        reduced = self.down(x)
        pieces = [branch(reduced) for branch in (self.branch_1, self.branch_2, self.branch_3, self.branch_4)]
        # Broadcast the 1x1 pooled context back to the input resolution.
        pieces.append(F.upsample(self.pool(reduced), spatial, mode="bilinear"))
        out = channel_shuffle(torch.cat(pieces, 1), self.groups)

        if out.size() == identity.size():
            out = out + identity
        return self.module_act(out)
    
class MCIM(nn.Module):
    """Multi-scale context module: three cascaded DSP blocks plus an image-pool branch.

    The small/middle/large DSP outputs (increasing dilation rates, each fed the
    previous one's output) are summed. A global-pool branch with ``pool_chs``
    channels is then concatenated. The result has out_chs + pool_chs channels
    (320 + 80 = 400 with the defaults, matching M2_semantic's 400-channel head).

    Args:
        in_chs: input channels.
        out_chs: output channels of each DSP block.
        pool_chs: channels of the image-pool branch (previously hard-coded to 80).
    """

    def __init__(self, in_chs = 320, out_chs = 320, pool_chs = 80):
        super(MCIM, self).__init__()

        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_chs, pool_chs, 1, stride=1),
            nn.BatchNorm2d(pool_chs),
            nn.LeakyReLU(0.1)
        )

        self.conv_small = DSP(in_chs, out_chs, dilation = (1,2,3,5))

        self.conv_middle = DSP(out_chs, out_chs, dilation = (7,9,11,13))

        self.conv_larger = DSP(out_chs, out_chs, dilation = (17,19,21,23))

    def forward(self, x):
        x_size = x.size()
        small = self.conv_small(x)
        middle = self.conv_middle(small)
        larger = self.conv_larger(middle)
        # Image-level context, broadcast back to the input resolution.
        pool = F.upsample(self.pool(x), x_size[2:], mode= "bilinear")
        output = small + middle + larger
        output = torch.cat([output, pool], 1)
        return output
    
    
# Semantic segmentation model built on the MobileNetV2 backbone
def initialize_weights(*models):
    """Re-initialise the conv, linear and batch-norm layers of each model in place.

    Conv2d / Linear weights get Kaiming-normal init and zero bias. BatchNorm2d
    gets weight 1 and bias 0 when it has affine parameters. Other module types
    (e.g. PReLU) are left untouched.

    Args:
        *models: any number of nn.Module instances (none is a no-op).
    """
    # PyTorch >= 0.4 provides the in-place `kaiming_normal_` and deprecates the
    # old name. Fall back to `kaiming_normal` for 0.3.x, the version this ran on.
    kaiming_normal = getattr(nn.init, 'kaiming_normal_', None) or nn.init.kaiming_normal
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                kaiming_normal(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has weight/bias set to None.
                if module.weight is not None:
                    module.weight.data.fill_(1)
                if module.bias is not None:
                    module.bias.data.zero_()

class Attention_Block(nn.Module):
    """Squeeze-and-excitation style channel gate.

    Global-average-pools the input, passes the channel vector through
    Linear -> ReLU -> Linear -> Sigmoid, and returns the gates reshaped to
    (N, C, 1, 1) so they broadcast over a feature map. It does not multiply
    the input itself.

    Args:
        channel: number of input channels C.
        reduction: bottleneck reduction factor (hidden width C // reduction).
    """

    def __init__(self, channel, reduction=16):
        super(Attention_Block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        squeezed = int(channel // reduction)
        self.channel_excitation = nn.Sequential(
            nn.Linear(channel, squeezed),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gates = self.channel_excitation(pooled)
        return gates.view(batch, channels, 1, 1)

                
class M2_semantic(nn.Module):
    """Semantic segmentation network on a MobileNetV2 (output stride 8) backbone.

    Pipeline:
      1. Backbone stages mod1-mod4 bring the input to 1/8 resolution with 32 channels.
      2. A local refinement branch ``LRM`` lifts these 32 channels to 160.
      3. The deeper stages mod5-mod7 produce 160-channel features.
      4. ``attention`` computes per-channel gates from the deep features; the
         gated LRM output is added to them.
      5. mod8 (320 channels) -> ``multi_scale`` (MCIM, 400 channels) -> ``final``
         classifier, then bilinear upsampling back to the input resolution.

    NOTE: the backbone stages are the module objects in the global
    ``mobilenet_v2`` list (children of ``net.features``), not copies. Every
    instance therefore shares parameters with ``net`` and with each other.

    Args:
        num_classes: number of output channels (classes) of the logits.
    """

    def __init__(self, num_classes=19):
        super(M2_semantic, self).__init__()

        backbone = mobilenet_v2
        # Backbone stages, one attribute per MobileNetV2 setting row.
        self.mod1 = backbone[0]                        # stem conv: 32 ch, 1/2
        self.mod2 = backbone[1]                        # 16 ch, 1/2
        self.mod3 = nn.Sequential(*backbone[2:4])      # 24 ch, 1/4
        self.mod4 = nn.Sequential(*backbone[4:7])      # 32 ch, 1/8
        self.mod5 = nn.Sequential(*backbone[7:11])     # 64 ch, 1/8 (dilated)
        self.mod6 = nn.Sequential(*backbone[11:14])    # 96 ch, 1/8
        self.mod7 = nn.Sequential(*backbone[14:17])    # 160 ch, 1/8
        self.mod8 = backbone[17]                       # 320 ch, 1/8

        # Local refinement: depthwise 3x3 then pointwise 1x1 lifting 32 -> 160 ch.
        self.LRM = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1, groups=32),
            ABN(32),
            nn.Conv2d(32, 160, kernel_size=1),
            ABN(160)
        )

        self.multi_scale = MCIM(320, 320)

        self.attention = Attention_Block(160)

        # 400 input channels = MCIM output; grouped 1x1 conv + shuffle, then logits.
        self.final = nn.Sequential(
            nn.Conv2d(400, 256, kernel_size=1, stride=1, groups=16),
            ShuffleBlock(16),
            nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.2),
            nn.Conv2d(256, num_classes, kernel_size=1)
        )
        initialize_weights(self.LRM, self.attention, self.multi_scale, self.final)

    def forward(self, x):
        out_hw = x.size()[2:]

        # Encoder down to 1/8 resolution (32 channels).
        shallow = self.mod4(self.mod3(self.mod2(self.mod1(x))))
        local_ctx = self.LRM(shallow)                      # 160 ch, 1/8
        deep = self.mod7(self.mod6(self.mod5(shallow)))    # 160 ch, 1/8

        # Channel gates from the deep features select the useful LRM channels.
        gate = self.attention(deep)                        # (N, 160, 1, 1)
        fused = torch.mul(gate, local_ctx) + deep

        logits = self.final(self.multi_scale(self.mod8(fused)))
        return F.upsample(logits, out_hw, mode='bilinear')

In [4]:
def main(input_shape=(1, 3, 1024, 512), iterations=30, warmup=1, cooldown=1.0):
    """Benchmark M2_semantic inference latency on the GPU.

    Runs ``iterations`` forward passes of a freshly built M2_semantic(num_classes=19)
    in eval mode, on random input of ``input_shape``. For every pass after the
    first ``warmup`` ones (which include setup cost), prints the per-image time
    and the running mean. Requires a CUDA device.

    Args:
        input_shape: (batch, channels, height, width) of the random input. Sizes
            benchmarked in this notebook include (1, 3, 640, 360), (1, 3, 512, 512),
            (1, 3, 713, 713), (1, 3, 1024, 448) and the default.
        iterations: total number of forward passes, including warm-up.
        warmup: number of initial passes excluded from the statistics.
        cooldown: seconds to sleep after each pass, to avoid overheating the GPU.
    """
    model = M2_semantic(num_classes=19).cuda()
    model.eval()

    images = torch.randn(*input_shape).cuda()  # .half()
    batch = input_shape[0]

    # PyTorch >= 0.4 removed `volatile` (it is silently ignored, so autograd
    # would record the whole graph); use torch.no_grad() there instead.
    has_no_grad = hasattr(torch, 'no_grad')

    time_train = []
    for i in range(iterations):
        start_time = time.time()

        if has_no_grad:
            with torch.no_grad():
                model(images)
        else:
            model(Variable(images, volatile=True))

        torch.cuda.synchronize()    # wait for CUDA to finish (CUDA is asynchronous!)

        if i >= warmup:    # the first runs include cuDNN/CUDA setup time
            fwt = time.time() - start_time
            time_train.append(fwt)
            print("Forward time per img (b=%d): %.3f (Mean: %.3f)" % (batch, fwt / batch, sum(time_train) / len(time_train) / batch))

        time.sleep(cooldown)

In [5]:
# Timing run for input torch.randn(1, 3, 640, 360). main() was presumably edited
# to use this size before running; the main() shown above uses (1, 3, 1024, 512).
main()

Forward time per img (b=1): 0.017 (Mean: 0.017)
Forward time per img (b=1): 0.017 (Mean: 0.017)
Forward time per img (b=1): 0.016 (Mean: 0.016)
Forward time per img (b=1): 0.016 (Mean: 0.016)
Forward time per img (b=1): 0.016 (Mean: 0.016)
Forward time per img (b=1): 0.016 (Mean: 0.016)
Forward time per img (b=1): 0.016 (Mean: 0.016)
Forward time per img (b=1): 0.013 (Mean: 0.016)
Forward time per img (b=1): 0.013 (Mean: 0.016)
Forward time per img (b=1): 0.013 (Mean: 0.015)
Forward time per img (b=1): 0.013 (Mean: 0.015)
Forward time per img (b=1): 0.013 (Mean: 0.015)
Forward time per img (b=1): 0.013 (Mean: 0.015)
Forward time per img (b=1): 0.016 (Mean: 0.015)
Forward time per img (b=1): 0.016 (Mean: 0.015)
Forward time per img (b=1): 0.016 (Mean: 0.015)
Forward time per img (b=1): 0.016 (Mean: 0.015)
Forward time per img (b=1): 0.016 (Mean: 0.015)
Forward time per img (b=1): 0.016 (Mean: 0.015)
Forward time per img (b=1): 0.017 (Mean: 0.015)
Forward time per img (b=1): 0.030 (Mean:

In [5]:
# Timing run for input torch.randn(1, 3, 512, 512). main() was presumably edited
# to use this size before running; the main() shown above uses (1, 3, 1024, 512).
main()

Forward time per img (b=1): 0.017 (Mean: 0.017)
Forward time per img (b=1): 0.017 (Mean: 0.017)
Forward time per img (b=1): 0.017 (Mean: 0.017)
Forward time per img (b=1): 0.017 (Mean: 0.017)
Forward time per img (b=1): 0.017 (Mean: 0.017)
Forward time per img (b=1): 0.031 (Mean: 0.019)
Forward time per img (b=1): 0.017 (Mean: 0.019)
Forward time per img (b=1): 0.017 (Mean: 0.019)
Forward time per img (b=1): 0.017 (Mean: 0.018)
Forward time per img (b=1): 0.017 (Mean: 0.018)
Forward time per img (b=1): 0.017 (Mean: 0.018)
Forward time per img (b=1): 0.017 (Mean: 0.018)
Forward time per img (b=1): 0.017 (Mean: 0.018)
Forward time per img (b=1): 0.017 (Mean: 0.018)
Forward time per img (b=1): 0.017 (Mean: 0.018)
Forward time per img (b=1): 0.017 (Mean: 0.018)
Forward time per img (b=1): 0.017 (Mean: 0.018)
Forward time per img (b=1): 0.017 (Mean: 0.018)
Forward time per img (b=1): 0.017 (Mean: 0.018)
Forward time per img (b=1): 0.017 (Mean: 0.018)
Forward time per img (b=1): 0.017 (Mean:

In [5]:
# Timing run for input torch.randn(1, 3, 713, 713). main() was presumably edited
# to use this size before running; the main() shown above uses (1, 3, 1024, 512).
main()

Forward time per img (b=1): 0.031 (Mean: 0.031)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean: 0.030)
Forward time per img (b=1): 0.030 (Mean:

In [5]:
# Timing run for input torch.randn(1, 3, 1024, 448). main() was presumably edited
# to use this size before running; the main() shown above uses (1, 3, 1024, 512).
main()

Forward time per img (b=1): 0.028 (Mean: 0.028)
Forward time per img (b=1): 0.026 (Mean: 0.027)
Forward time per img (b=1): 0.026 (Mean: 0.027)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.026 (Mean: 0.026)
Forward time per img (b=1): 0.028 (Mean:

In [5]:
# Timing run for input torch.randn(1, 3, 1024, 512), the size used by main() as shown above.
main()

Forward time per img (b=1): 0.031 (Mean: 0.031)
Forward time per img (b=1): 0.029 (Mean: 0.030)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.029 (Mean: 0.029)
Forward time per img (b=1): 0.031 (Mean: