In [1]:
import os
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from farmdataset import FarmDataset
from nestedunet import NestedUNet
from vggunet import VggUNet
import segmentation_models_pytorch as smp
import time
from PIL import Image
from torch.autograd import Variable
import numpy as np
import pydensecrf.densecrf as dcrf

In [2]:
use_cuda = True

# Only select the GPU when it was requested AND is actually present; otherwise
# fall back to the CPU instead of failing on the first `.to(device)` call.
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

# On the GPU path the checkpoints are loaded exactly as before (no remapping).
# On the CPU path, storages saved from a GPU are remapped to host memory so
# that loading does not fail.
map_location = None if device.type == "cuda" else "cpu"

# Load the ensemble members. The loaded objects are used directly as modules
# (`.to()` / `.eval()` below), so the files hold whole pickled models;
# presumably that is why NestedUNet / VggUNet are imported above -- verify.
# NOTE(security): torch.load unpickles arbitrary Python objects. Only load
# checkpoints that come from a trusted source.
model1 = torch.load('./ensemble/nestedunet_pretrained_model40', map_location=map_location)
model2 = torch.load('./ensemble/se_resnextunet_pretrained_model3', map_location=map_location)
# model3 (vggpspnet_pretrained_model20) is currently disabled:
# model3 = torch.load('./ensemble/vggpspnet_pretrained_model20')
model4 = torch.load('./ensemble/VggUnet_model47', map_location=map_location)
model5 = torch.load('./ensemble/vggunet_pretrained_model19', map_location=map_location)

# Move every ensemble member to the selected device and switch it to
# evaluation mode. nn.Module.to() moves the module in place.
for model in [model1, model2, model4, model5]:
    model.to(device)
    model.eval()

# Dataset built with the train / augmentation / validation flags all off.
ds = FarmDataset(istrain=False, isaug=False, isval=False)



In [3]:
def dense_crf(img, output_probs):
    """Refine per-pixel class probabilities with a fully connected CRF.

    Args:
        img: image array of shape (H, W, 3), passed to the bilateral pairwise
            term. pydensecrf expects ``uint8`` here -- TODO confirm the raster
            dtype at the call site.
        output_probs: array of shape (n_classes, H, W) with one probability
            map per class (e.g. a softmax output).

    Returns:
        An (H, W) integer array holding, for every pixel, the index of the
        class with the highest value after CRF inference.
    """
    # Take the class count from the input instead of hard-coding it, so the
    # CRF size and the unary reshape always agree with `output_probs`.
    n_classes, h, w = output_probs.shape

    d = dcrf.DenseCRF2D(w, h, n_classes)

    # Unary energy is the negative log-probability. Floor the probabilities
    # first: a value of exactly 0 would give an infinite energy.
    min_prob = 1e-8
    probs = np.maximum(np.asarray(output_probs, dtype=np.float32), min_prob)
    U = -np.log(probs)
    # One row per class, one column per pixel; float32 and C-contiguous.
    U = np.ascontiguousarray(U.reshape((n_classes, -1)), dtype=np.float32)
    img = np.ascontiguousarray(img)

    d.setUnaryEnergy(U)

    # Position-only smoothness term.
    d.addPairwiseGaussian(sxy=20, compat=3)
    # Position + colour term driven by the image.
    d.addPairwiseBilateral(sxy=30, srgb=20, rgbim=img, compat=10)

    Q = d.inference(5)  # 5 inference iterations
    Q = np.argmax(np.array(Q), axis=0).reshape((h, w))

    return Q

In [4]:
def predict(d, outputname='tmp.bmp'):
    """Run tiled ensemble inference over a raster and write the label maps.

    The raster is read in ``blocksize`` x ``blocksize`` tiles whose centres
    advance by ``step`` pixels. Each tile is downsampled by 4, passed through
    every ensemble model, upsampled back, combined as a weighted sum, turned
    into probabilities with softmax and refined by ``dense_crf``. Only the
    central ``step`` x ``step`` part of each tile's result is stored.

    Args:
        d: raster object exposing ``RasterXSize``, ``RasterYSize`` and
            ``ReadAsArray(xoff, yoff, xsize, ysize)`` (GDAL-style dataset --
            presumably what FarmDataset yields; verify).
        outputname: file name used for both output images.

    Side effects:
        Writes ``./tmp/upload/<outputname>`` (class index per pixel) and
        ``./tmp/obvious/<outputname>`` (same map multiplied by 60), creating
        the directories if needed.

    Raises:
        IOError: if OpenCV reports that an output image could not be written.
    """
    wx = d.RasterXSize
    wy = d.RasterYSize
    print(wx,wy)
    od = np.zeros((wy,wx),np.uint8)
    blocksize = 1024
    step = 512
    half = step // 2        # half-width of the region written per tile
    inner = blocksize // 4  # offset of that region inside the tile result

    # Ensemble members and their weights, in summation order.
    models = [model1, model2, model4, model5]
    weights = [1.5, 0.9, 0.9, 0.9]

    # NOTE(review): the loops start at `step` and stop before
    # `size - blocksize`, and each tile only writes its central region, so the
    # outer border of `od` (and any all-zero tile) keeps the initial value 0.
    for cy in range(step,wy-blocksize,step):
        print('current cy is: {}'.format(cy))
        for cx in range(step,wx-blocksize,step):
            # Tile centred on (cx, cy); keep the first three bands (C, H, W).
            img = d.ReadAsArray(cx-step,cy-step,blocksize,blocksize)[0:3,:,:]
            if img.sum() == 0:
                continue

            # Inference only: skip autograd bookkeeping to save memory/time.
            with torch.no_grad():
                x = torch.from_numpy(img/255.0).float()
                x = x.unsqueeze(0).to(device)
                # The downsampled input is identical for every model, so it
                # is computed once per tile.
                small = F.avg_pool2d(x, 4)

                ensemble = None
                for weight, model in zip(weights, models):
                    scores = model(small)
                    # align_corners=False is what the implicit default
                    # resolves to; stating it avoids the upsampling warning.
                    scores = F.interpolate(scores, scale_factor=4,
                                           mode='bilinear', align_corners=False)
                    weighted = weight * scores
                    ensemble = weighted if ensemble is None else ensemble + weighted

                probs = F.softmax(ensemble, dim=1)[0].cpu().numpy()

            # dense_crf receives the image as (H, W, C).
            labels = dense_crf(img.transpose(1,2,0), probs)

            od[cy-half:cy+half, cx-half:cx+half] = labels[inner:step+inner, inner:step+inner]

    # cv2.imwrite returns False instead of raising when the target cannot be
    # written (e.g. missing directory), so create the folders and check it.
    upload_dir = './tmp/upload/'
    obvious_dir = './tmp/obvious/'
    os.makedirs(upload_dir, exist_ok=True)
    os.makedirs(obvious_dir, exist_ok=True)
    if not cv2.imwrite(upload_dir + outputname, od):
        raise IOError('could not write ' + upload_dir + outputname)
    # Scaled copy so the class indices are distinguishable when viewed.
    if not cv2.imwrite(obvious_dir + outputname, od*60):
        raise IOError('could not write ' + obvious_dir + outputname)
    return

In [5]:
# Predict the first two dataset entries and report the total wall-clock time.
start = time.time()
for message, index, filename in (
    ("start predict.....", 0, 'image_5_predict.png'),
    ("start predict 2 .....", 1, 'image_6_predict.png'),
):
    print(message)
    predict(ds[index], filename)
end = time.time()
print('prediction time: {}'.format(end - start))

start predict.....
43073 20115
current cy is: 512


  "See the documentation of nn.Upsample for details.".format(mode))
  import sys


current cy is: 1024
current cy is: 1536
current cy is: 2048
current cy is: 2560
current cy is: 3072
current cy is: 3584
current cy is: 4096
current cy is: 4608
current cy is: 5120
current cy is: 5632
current cy is: 6144
current cy is: 6656
current cy is: 7168
current cy is: 7680
current cy is: 8192
current cy is: 8704
current cy is: 9216
current cy is: 9728
current cy is: 10240
current cy is: 10752
current cy is: 11264
current cy is: 11776
current cy is: 12288
current cy is: 12800
current cy is: 13312
current cy is: 13824
current cy is: 14336
current cy is: 14848
current cy is: 15360
current cy is: 15872
current cy is: 16384
current cy is: 16896
current cy is: 17408
current cy is: 17920
current cy is: 18432
current cy is: 18944
start predict 2 .....
62806 21247
current cy is: 512
current cy is: 1024
current cy is: 1536
current cy is: 2048
current cy is: 2560
current cy is: 3072
current cy is: 3584
current cy is: 4096
current cy is: 4608
current cy is: 5120
current cy is: 5632
current c