# In [1]:

import torch
import torch.nn as nn
import torch.nn.init as init
from torchvision import models
import torchvision
import torch.nn.functional as F
import numpy as np
from IPython.core.debugger import set_trace

from torch.utils import model_zoo
#import deeplab_resnet
from torch.autograd import Variable
import scipy.misc
from PIL import Image



def initialize_weights(*models):
    """Re-initialize every Conv2d, Linear and BatchNorm2d layer in the given modules.

    Conv2d / Linear weights receive Kaiming-normal initialization (library
    defaults: ``mode='fan_in'``, ``nonlinearity='leaky_relu'`` with ``a=0``)
    and their biases, when present, are zeroed.  Affine BatchNorm2d layers get
    weight 1 and bias 0; non-affine BatchNorm2d layers have no parameters and
    are skipped.

    Args:
        *models: one or more ``nn.Module`` instances; all of their submodules
            (recursively, via ``Module.modules()``) are visited.
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # kaiming_normal_ is the in-place API; the old non-underscore
                # alias is deprecated and only emits a warning before delegating.
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d) and module.affine:
                # BatchNorm2d(affine=False) has weight/bias set to None.
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

class PyramidPool(nn.Module):
    """One branch of a pyramid pooling module.

    The input is adaptively average-pooled to ``pool_size`` x ``pool_size``,
    projected with a 1x1 convolution, batch-normalized and passed through
    ReLU.  The result is then bilinearly resized back to the input's spatial
    size.

    Args:
        in_features: channels of the incoming feature map.
        out_features: channels produced by this branch.
        pool_size: output size passed to ``nn.AdaptiveAvgPool2d``.
    """

    def __init__(self, in_features, out_features, pool_size):
        super(PyramidPool, self).__init__()

        self.features = nn.Sequential(
            nn.AdaptiveAvgPool2d(pool_size),
            nn.Conv2d(in_features, out_features, 1, bias=False),
            # NOTE(review): PyTorch's momentum weights the *current* batch
            # (running = (1 - m) * running + m * batch); 0.95 looks like a
            # Caffe-style value where ~0.05 may have been intended — confirm.
            nn.BatchNorm2d(out_features, momentum=.95),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the branch output resized to ``x``'s spatial size."""
        size = x.size()
        # F.upsample is deprecated; align_corners=False reproduces its
        # default bilinear behaviour without the deprecation warnings.
        return F.interpolate(self.features(x), size[2:], mode='bilinear',
                             align_corners=False)
    
class ConvBnRelu(nn.Module):
    """A 2-D convolution, optionally followed by normalization and a ReLU.

    Args:
        in_planes / out_planes: input and output channel counts.
        ksize, stride, pad, dilation, groups: forwarded to ``nn.Conv2d``.
        has_bn: when True, append ``norm_layer(out_planes, eps=bn_eps)``.
        norm_layer: normalization layer factory (default ``nn.BatchNorm2d``).
        bn_eps: epsilon passed to the normalization layer.
        has_relu: when True, append a ReLU.
        inplace: whether the ReLU operates in place.
        has_bias: whether the convolution has a bias term.
    """

    def __init__(self, in_planes, out_planes, ksize, stride, pad, dilation=1,
                 groups=1, has_bn=True, norm_layer=nn.BatchNorm2d, bn_eps=1e-5,
                 has_relu=True, inplace=True, has_bias=False):
        super(ConvBnRelu, self).__init__()
        conv_kwargs = dict(kernel_size=ksize, stride=stride, padding=pad,
                           dilation=dilation, groups=groups, bias=has_bias)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_kwargs)

        self.has_bn = has_bn
        if has_bn:
            self.bn = norm_layer(out_planes, eps=bn_eps)

        self.has_relu = has_relu
        if has_relu:
            self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        """Apply conv, then the optional norm, then the optional ReLU."""
        out = self.conv(x)
        out = self.bn(out) if self.has_bn else out
        return self.relu(out) if self.has_relu else out

class PSPNet(nn.Module):
    """PSPNet-style segmentation network on a torchvision ResNet-101 backbone.

    ``forward`` runs the backbone without its max-pool.  It then applies four
    pyramid pooling branches (bins 1, 2, 3, 6) to the ``layer4`` features and
    reduces the concatenation to ``num_classes`` channels with ``final``.
    Those logits are resized to the ``layer2`` resolution and concatenated
    with the auxiliary logits from ``aux_branch``.  The result is fused by
    ``final2`` and bilinearly upsampled to the input's spatial size.

    Args:
        num_classes: number of output channels.
        pretrained: forwarded to ``torchvision.models.resnet101``.
    """

    def __init__(self, num_classes, pretrained=False):
        super(PSPNet, self).__init__()
        print("initializing model")
        self.resnet = torchvision.models.resnet101(pretrained=pretrained)

        # Pyramid pooling branches over the 2048-channel layer4 output.
        self.layer5a = PyramidPool(2048, 512, 1)
        self.layer5b = PyramidPool(2048, 512, 2)
        self.layer5c = PyramidPool(2048, 512, 3)
        self.layer5d = PyramidPool(2048, 512, 6)

        # NOTE(review): PyTorch's BatchNorm momentum weights the *current*
        # batch; 0.95 looks like a Caffe-style value — confirm intent.
        # 4096 input channels = 2048 (layer4) + 4 branches * 512.
        self.final = nn.Sequential(
            nn.Conv2d(4096, 512, 3, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(.1),
            nn.Conv2d(512, num_classes, 1),
        )
        # Input is cat([main logits, aux logits]); each has num_classes
        # channels.  A hard-coded 4 would only work for num_classes == 2.
        self.final2 = nn.Sequential(
            nn.Conv2d(2 * num_classes, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(.1),
            nn.Conv2d(256, 512, 3, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(.1),
            nn.Conv2d(512, num_classes, 1),
        )

        initialize_weights(self.layer5a, self.layer5b, self.layer5c,
                           self.layer5d, self.final)

        # NOTE(review): conv6 is never used by forward(); it is kept only so
        # parameter names in state_dict stay compatible with saved checkpoints.
        pool_scales = [1, 2, 3, 6]
        self.conv6 = nn.Sequential(
            ConvBnRelu(4096 + len(pool_scales) * 512, 512, 3, 1, 1,
                       has_bn=True,
                       has_relu=True, has_bias=False, norm_layer=nn.BatchNorm2d),
            nn.Dropout2d(0.1, inplace=False),
            nn.Conv2d(512, 2, kernel_size=1)
        )

        chann = 2048
        # Auxiliary head on the layer2 output (chann // 4 = 512 channels for
        # ResNet-101).
        self.aux_branch = nn.Sequential(
            nn.Conv2d(chann // 4, chann // 2, 3, padding=1, bias=False),
            nn.BatchNorm2d(chann // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(.1),
            nn.Conv2d(chann // 2, num_classes, 1),
        )

    def forward(self, x):
        """Return per-pixel logits of shape (N, num_classes, H, W) for input (N, C, H, W)."""
        input_size = (x.size()[2], x.size()[3])

        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        # resnet.maxpool is intentionally not applied here.
        x = self.resnet.layer1(x)
        x_aux = self.resnet.layer2(x)
        x = self.resnet.layer3(x_aux)
        x = self.resnet.layer4(x)

        x = self.final(torch.cat([
            x,
            self.layer5a(x),
            self.layer5b(x),
            self.layer5c(x),
            self.layer5d(x),
        ], 1))

        # Bring the main logits to the auxiliary (layer2) resolution and fuse.
        aux = self.aux_branch(x_aux)
        aux_size = (aux.size()[2], aux.size()[3])
        x = F.interpolate(x, size=aux_size, mode='bilinear', align_corners=True)
        x = self.final2(torch.cat([x, aux], 1))

        return F.interpolate(x, size=input_size, mode='bilinear',
                             align_corners=True)