In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init
from torchvision import models
import torchvision
import torch.nn.functional as F
import numpy as np
from torch.utils import model_zoo






In [None]:
class PyramidPool(nn.Module):
    """One branch of the pyramid pooling module.

    Adaptive-average-pools the input to a ``pool_size`` x ``pool_size`` grid,
    reduces channels with a 1x1 conv + BatchNorm + ReLU, then bilinearly
    resizes the result back to the input's spatial size so it can be
    concatenated with the input along the channel axis.

    Args:
        pool_size: output size handed to ``nn.AdaptiveAvgPool2d``.
        in_channels: channels of the incoming feature map (default 2048,
            the value that used to be hard-coded).
        out_channels: channels produced by the 1x1 reduction conv
            (default 512, the value that used to be hard-coded).
    """

    def __init__(self, pool_size, in_channels=2048, out_channels=512):
        super(PyramidPool, self).__init__()
        self.pooldown = nn.Sequential(
            nn.AdaptiveAvgPool2d(pool_size),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            # NOTE(review): PyTorch's BatchNorm momentum weights the *new* batch
            # statistic (running = (1 - m) * running + m * batch), the opposite of
            # the Caffe/TF convention; 0.95 may have been ported from that
            # convention — confirm it is intended.
            nn.BatchNorm2d(out_channels, momentum=.95),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Pool, reduce and resize a 4-D batch ``x`` back to its own (H, W).

        Returns a tensor of shape (N, out_channels, H, W).
        """
        # F.upsample is deprecated; F.interpolate with align_corners=False
        # reproduces what F.upsample(mode='bilinear') did by default.
        return F.interpolate(self.pooldown(x), size=x.shape[2:], mode='bilinear',
                             align_corners=False)

In [None]:
def initialize_weights(*models):
    """Re-initialise the parameters of the given modules in place.

    Walks every sub-module (recursively, via ``Module.modules()``) of each
    model and applies:

    * ``nn.Conv2d`` / ``nn.Linear``: Kaiming-normal weights, zero bias if the
      layer has one.
    * ``nn.BatchNorm2d``: scale set to 1 and shift set to 0, when the layer has
      affine parameters (``affine=False`` layers have none and are skipped).

    Args:
        *models: any number of ``nn.Module`` instances.
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # kaiming_normal (no trailing underscore) is a deprecated alias.
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d) and module.affine:
                # With affine=False weight/bias are None; touching them would crash.
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

In [None]:
class PSPNet(nn.Module):
    """PSPNet-style segmentation network built on a torchvision ResNet-101.

    Pipeline (see ``forward``):

    1. ResNet-101 stem and ``layer1`` .. ``layer4``.
    2. Pyramid pooling of the ``layer4`` features at scales 1, 2, 3 and 6.
    3. ``layer_cat`` maps the concatenation to ``classes`` channels.
    4. ``layer_aux`` maps the ``layer2`` features to ``classes`` channels.
    5. The main maps are resized to the auxiliary resolution and concatenated
       with the auxiliary maps. ``outpot_layer`` refines the result, which is
       then resized to the input resolution.

    Args:
        classes: number of output channels (classes).
        poolchannl: channels of the ``layer4`` features. ``layer_cat`` expects
            ``poolchannl * 2`` inputs: the backbone features plus four
            ``PyramidPool`` branches built with their default 2048 -> 512
            channels. ``layer_aux`` expects ``poolchannl // 4`` channels from
            ``layer2``. The counts therefore only line up for
            ``poolchannl == 2048``.
        outpoolchannl: width of the hidden convolutions in ``layer_cat`` and
            ``outpot_layer``.

    Note: the constructor calls ``torchvision.models.resnet101(pretrained=True)``.
    This loads pretrained backbone weights and may need network access the first
    time.
    """

    def __init__(self, classes, poolchannl, outpoolchannl):
        super(PSPNet, self).__init__()

        # Pyramid pooling branches. The attribute names and registration order
        # are kept so state_dict keys stay compatible with existing checkpoints.
        self.pool1 = PyramidPool(1)
        self.pool2 = PyramidPool(2)
        self.pool3 = PyramidPool(3)
        self.pool6 = PyramidPool(6)

        # Main head: [layer4 features, 4 pooled branches] -> class maps.
        self.layer_cat = nn.Sequential(
            nn.Conv2d(poolchannl * 2, outpoolchannl, 3, padding=1, bias=False),
            nn.BatchNorm2d(outpoolchannl, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(.1),
            nn.Conv2d(outpoolchannl, classes, 1),
        )

        # Auxiliary head on the layer2 features.
        self.layer_aux = nn.Sequential(
            nn.Conv2d(poolchannl // 4, poolchannl // 2, 3, padding=1, bias=False),
            nn.BatchNorm2d(poolchannl // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(.1),
            nn.Conv2d(poolchannl // 2, classes, 1),
        )

        # Fusion head. Its input is cat([main maps, aux maps]), i.e. 2 * classes
        # channels. This used to be hard-coded to 4, which only worked for
        # classes == 2.
        self.outpot_layer = nn.Sequential(
            nn.Conv2d(2 * classes, outpoolchannl // 2, 3, padding=1, bias=False),
            nn.BatchNorm2d(outpoolchannl // 2, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(.1),
            nn.Conv2d(outpoolchannl // 2, outpoolchannl, 3, padding=1, bias=False),
            nn.BatchNorm2d(outpoolchannl, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(.1),
            nn.Conv2d(outpoolchannl, classes, 1),
        )

        # NOTE(review): only the pooling branches and layer_cat get Kaiming init.
        # layer_aux and outpot_layer keep PyTorch's default init — confirm this
        # is intended.
        initialize_weights(self.pool1, self.pool2, self.pool3, self.pool6, self.layer_cat)

        self.resnet = torchvision.models.resnet101(pretrained=True)

    def forward(self, x):
        """Return per-pixel class scores at the input's spatial size.

        Args:
            x: image batch (N, C, H, W) as accepted by the ResNet stem.

        Returns:
            Tensor of shape (N, classes, H, W).
        """
        input_size = (x.size(2), x.size(3))

        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        # NOTE(review): self.resnet.maxpool is not applied here, although
        # torchvision's own ResNet forward applies it after relu. The features
        # therefore stay at twice the usual resolution — confirm this is
        # intentional.
        x = self.resnet.layer1(x)
        x_aux = self.resnet.layer2(x)
        x = self.resnet.layer3(x_aux)
        x = self.resnet.layer4(x)

        # Every branch sees the same layer4 features; they are concatenated on
        # the channel axis.
        x = torch.cat([x, self.pool1(x), self.pool2(x), self.pool3(x), self.pool6(x)], 1)
        x = self.layer_cat(x)

        aux = self.layer_aux(x_aux)
        aux_size = (aux.size(2), aux.size(3))

        # Bring the main maps to the (larger) layer2 resolution before fusing.
        x = F.interpolate(x, size=aux_size, mode='bilinear', align_corners=True)
        x = self.outpot_layer(torch.cat([x, aux], 1))
        return F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)
         