In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import PIL
import torch.nn.functional as F
from skimage import filters
from os import listdir
from os.path import isfile, join
import densenet as ds
import pspmodule as PSP
densenet = ds.densenet121(pretrained=True)
densenet

DenseNet(
  (features): Sequential(
    (conv0): Conv2d (3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (relu0): ReLU(inplace)
    (pool0): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1))
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm.1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
        (relu.1): ReLU(inplace)
        (conv.1): Conv2d (64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm.2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
        (relu.2): ReLU(inplace)
        (conv.2): Conv2d (128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm.1): InstanceNorm2d(96, eps=1e-05, momentum=0.1, affine=True)
        (relu.1): ReLU(inplace)
        (conv.1): Conv2d (96, 128, kernel_size=(1, 1), stride=(1, 1), bias=

In [2]:

class Dense_Net_Feature(nn.Module):
    def __init__(self):
        super(Dense_Net_Feature, self).__init__()
        
        
        self.layer0 = nn.Sequential(densenet.features.conv0, densenet.features.norm0, densenet.features.relu0, densenet.features.pool0)
        self.denseblock1 = densenet.features.denseblock1
        self.transition1 = densenet.features.transition1
        self.denseblock2 = densenet.features.denseblock2
        self.transition2 = densenet.features.transition2
        self.denseblock3 = densenet.features.denseblock3
        self.transition3 = densenet.features.transition3
        self.denseblock4 = densenet.features.denseblock4
        self.transition4 = densenet.features.transition4
        self.denseblock5 = densenet.features.denseblock5
        
        self.final = nn.Sequential(
            nn.Conv2d(20, 10, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(10, 2, kernel_size=3,padding=1)
        )
        
#         self.BR0 = PSP.BoundaryRefine(50,50)
        self.BR1 = PSP.BoundaryRefine(64,64)
        self.BR2 = PSP.BoundaryRefine(128+64,128+64)
        self.BR3 = PSP.BoundaryRefine(128,128)
        
        
        self.reducedim1 = nn.Conv2d(64,20,kernel_size=3,padding=1,bias=False)
        self.reducedim2 = nn.Conv2d(128+64,32,kernel_size=3,padding=1,bias=False)
        self.reducedim3 = nn.Conv2d(1024+1024,128,kernel_size=3,padding=1,bias=False)
        self.reducedim3 = nn.Conv2d(1024+1024,128,kernel_size=3,padding=1,bias=False)
#         self.reducedim = nn.Conv2d(256+256,50,kernel_size=3,padding=1,bias=False)
        self.reducedim11 = nn.Conv2d(256,32,kernel_size=3,padding=1,bias=False)
        self.reducedim22 = nn.Conv2d(512,64,kernel_size=3,padding=1,bias=False)
        self.reducedim33 = nn.Conv2d(1024,64,kernel_size=3,padding=1,bias=False)
        self.reducedim44 = nn.Conv2d(1024,64,kernel_size=3,padding=1,bias=False)
        

    def forward(self, x):
        x_size = x.size()
        
        out = self.layer0(x)
        
        
        out = self.denseblock1(out)#256
        out4 = self.reducedim11(out)#32
        
        out = self.transition1(out)
        out = self.denseblock2(out)#512
        out8 = self.reducedim22(out)#64
        
        out = self.transition2(out)
        out = self.denseblock3(out)#1024
        out16 = self.reducedim33(out)#64
        
        out = self.transition3(out)
        out = self.denseblock4(out)#1024
        out32 = self.reducedim44(out)#64
        
        
        out = torch.cat((F.upsample(out32,(out16.size()[2],out16.size()[3]),mode='bilinear'),out16),1)#64+64
#         out = self.reducedim3(out)
        out = self.BR3(out)
#         out = self.reducedim3(out)
        
        out = torch.cat((F.upsample(out,(out8.size()[2],out8.size()[3]),mode='bilinear'),out8),1)#128+64
#       
        out = self.BR2(out)
        out = self.reducedim2(out)
        
        out = torch.cat((F.upsample(out,(out4.size()[2],out4.size()[3]),mode='bilinear'),out4),1)#32+32
        
        out = self.BR1(out)
        out = self.reducedim1(out)
        
        out = F.upsample(out,(x_size[2],x_size[3]),mode='bilinear')
    
        out = self.final(out)
        
        return out


In [3]:
densenet_f = Dense_Net_Feature()
densenet_f.load_state_dict(torch.load('DenseNet121_instance_BR1.pt'))
use_gpu = torch.cuda.is_available()
if use_gpu:
    densenet_f = densenet_f.cuda()

In [5]:
test_path = 'D:/DataSet/AerialImageDataset/counting/original/'
filelist = [f for f in listdir('D:/DataSet/AerialImageDataset/counting/original/') if isfile(join('D:/DataSet/AerialImageDataset/counting/original/', f))]
save_path = 'D:/DataSet/AerialImageDataset/counting/bin/'

test_transform = transforms.Compose([
    transforms.ToTensor()
])
CROP_SIZE = 1000
for name in filelist:
    ori = test_transform(PIL.Image.open(test_path+name))
    out_bin = np.zeros((5000,5000),dtype=np.uint8)
    for i in range(5000//CROP_SIZE):
        for j in range(5000//CROP_SIZE):
            
            test = ori[:,i*CROP_SIZE:i*CROP_SIZE+CROP_SIZE,j*CROP_SIZE:j*CROP_SIZE+CROP_SIZE].unsqueeze(0)
        
            test = Variable(test.cuda(),volatile=True)
            densenet_f.eval()
            out = densenet_f(test)
            m = nn.Softmax2d()
            out = m(out)
            out = (out.cpu()[0].data.numpy()).transpose((1, 2, 0))
            
    
    
#             a = out[:,:,1].copy()
#             a[a<0] = 0
#             a /= a.max()
            a = out[:,:,1].copy()
            a[a<0.5] = 0
            a /= a.max()
            a *= 255
            a = a.astype(np.uint8)
            out_bin[CROP_SIZE*i:CROP_SIZE*i+CROP_SIZE,CROP_SIZE*j:CROP_SIZE*j+CROP_SIZE] = a.copy()
    PIL.Image.fromarray(out_bin).save(save_path+name)



