# Module Dependencies

In [1]:
import matplotlib.pyplot as plt 
import os
import sys
import torch
from model_old import PSPModule, PSPNet, PSPUpsample
import numpy as np
from skimage import io, measure, transform
from torch.autograd import Variable
from torchvision import transforms

# Parameter Settings

In [2]:
# Full-resolution tile dimensions, in pixels.
high = width = 5000
# Side of the square border patches and of the network input window.
edge_size = crop_size = 512
# Side of the central region of each interior window that is kept when stitching.
confident_size = 400
USE_GPU = torch.cuda.is_available()
# File names are appended directly to these, so keep the trailing '/'.
input_path, save_path = './input2/', './save/'

# Function Definitions

## 1. Test Image Cropping

In [3]:
def crop(image, crop_size, confident_size, edge_size=512):
    """Cut a large image into the overlapping patches fed to the network.

    The returned list has two parts, in this order (``combine`` reads them
    back in exactly the same order):

    1. Border patches of ``edge_size`` x ``edge_size`` covering the top row,
       then left/right column pairs, then the bottom row. The last patch of
       each run is anchored to the far image edge with a negative slice, so it
       may overlap its neighbour rather than run past the image.
    2. Interior windows of side ``confident_size + 2 * atoll_size``, stepped by
       ``confident_size`` in both directions, where
       ``atoll_size = (crop_size - confident_size) // 2`` is the margin that is
       discarded when stitching.

    Args:
        image: array indexed as ``image[row, col, ...]`` (2-D or H x W x C).
        crop_size: side of each interior window.
        confident_size: side of the trusted central region of each window.
        edge_size: side of the border patches (default 512, the value that
            was previously hard-coded).

    Returns:
        list of array views into ``image``.

    Raises:
        ValueError: if ``image`` is smaller than ``edge_size`` along either of
            its first two axes.
    """
    high, width = image.shape[0], image.shape[1]
    if high < edge_size or width < edge_size:
        raise ValueError(
            'image must be at least {0} x {0} pixels, got {1} x {2}'.format(
                edge_size, high, width))

    atoll_size = (crop_size - confident_size) // 2
    crop_list = []

    # --- Border patches ---------------------------------------------------
    for col in range(edge_size, width, edge_size):
        crop_list.append(image[0:edge_size, col - edge_size:col])
    # NOTE: rows 0:crop_size (not edge_size) for the top-right patch;
    # combine() pastes it back into the same rows, so keep the two in sync.
    crop_list.append(image[0:crop_size, -edge_size:])
    for row in range(edge_size * 2, high, edge_size):
        crop_list.append(image[row - edge_size:row, 0:edge_size])
        crop_list.append(image[row - edge_size:row, -edge_size:])
    for col in range(edge_size, width, edge_size):
        crop_list.append(image[-edge_size:, col - edge_size:col])
    crop_list.append(image[-edge_size:, -edge_size:])

    # --- Interior windows, row-major ---------------------------------------
    window = confident_size + 2 * atoll_size
    num_row = (high - 2 * atoll_size) // confident_size
    num_col = (width - 2 * atoll_size) // confident_size
    for row in range(num_row):
        top = row * confident_size
        for col in range(num_col):
            left = col * confident_size
            crop_list.append(image[top:top + window, left:left + window])
    return crop_list

## 2. NumPy Array to Tensor

In [4]:
def toTensor(image):
    """Turn an image array into a normalized, batched model input tensor.

    The image goes through PIL conversion, a resize to 512 x 512, conversion
    to a float tensor, and per-channel normalization. A leading batch axis is
    then added, so the result can be passed straight to the model.

    Assumes ``image`` is an H x W x 3 uint8 array (the caller passes
    ``astype('uint8')``), since ``Normalize`` is given three channel
    statistics. TODO confirm for other inputs.
    """
    # Widely used ImageNet channel statistics; presumably match training.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return pipeline(image).unsqueeze(0)

## 3. Combining Patch Predictions into a Full Mask

In [5]:
def combine(crop_list, crop_size, confident_size, high= 5000, width= 5000, edge_size= 512):
    """Stitch per-patch predictions back into one ``high`` x ``width`` map.

    ``crop_list`` must hold 2-D patches in exactly the order produced by
    ``crop()``:
    - first the ``edge_size`` border patches: top row, then left/right column
      pairs, then bottom row;
    - then the interior windows, row-major.

    Border patches are pasted whole. From each interior window only the
    central ``confident_size`` square is pasted, discarding a margin of
    ``atoll_size = (crop_size - confident_size) // 2`` on every side.

    Args:
        crop_list: sequence of 2-D arrays, one per patch (previously this
            function ignored it and read a global ``lbl_list``).
        crop_size: side of each interior window.
        confident_size: side of the trusted centre of each interior window.
        high: number of rows of the output map.
        width: number of columns of the output map.
        edge_size: side of the border patches (default 512, matching crop()).

    Returns:
        float64 array of shape ``(high, width)``. Pixels that no patch
        covers stay 0.

    Raises:
        IndexError: if ``crop_list`` holds fewer patches than this layout needs.
    """
    # Rows first: (high, width). np.zeros so uncovered pixels are
    # deterministic rather than leftover memory.
    comb = np.zeros([high, width])
    atoll_size = (crop_size - confident_size) // 2
    count = 0

    # --- Border patches, same order as crop() -------------------------------
    for col in range(edge_size, width, edge_size):
        comb[0:edge_size, col - edge_size:col] = crop_list[count]
        count += 1
    # NOTE: rows 0:crop_size, matching the top-right slice taken by crop().
    comb[0:crop_size, -edge_size:] = crop_list[count]
    count += 1
    for row in range(edge_size * 2, high, edge_size):
        comb[row - edge_size:row, 0:edge_size] = crop_list[count]
        comb[row - edge_size:row, -edge_size:] = crop_list[count + 1]
        count += 2
    for col in range(edge_size, width, edge_size):
        comb[-edge_size:, col - edge_size:col] = crop_list[count]
        count += 1
    comb[-edge_size:, -edge_size:] = crop_list[count]
    count += 1

    # --- Interior windows ---------------------------------------------------
    # Same grid counts as crop(), so patch k always lands where it was cut.
    num_row = (high - 2 * atoll_size) // confident_size
    num_col = (width - 2 * atoll_size) // confident_size
    centre = slice(atoll_size, atoll_size + confident_size)
    for row in range(num_row):
        top = atoll_size + row * confident_size
        for col in range(num_col):
            left = atoll_size + col * confident_size
            comb[top:top + confident_size,
                 left:left + confident_size] = crop_list[count][centre, centre]
            count += 1

    return comb

# Forward Pass (Inference)

In [18]:
# Run the trained network on one test image, threshold the prediction,
# upsample it to the configured full resolution (high x width), save it, and
# append the connected-component count to count.txt.
test_list = os.listdir(input_path)  # available inputs; only im_name is processed below

device = torch.device('cuda' if USE_GPU else 'cpu')

# torch.load unpickles a whole model object: only load checkpoints you trust.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# Earlier checkpoint: './model-0113-test-40ep-0.9607.pt'
model = torch.load('./model-0.9548.pt', map_location=device)
# The model must live on the same device as its input.
model = model.to(device)
model.eval()

im_name = 'input5.tif'

image = io.imread(input_path + im_name)
# resize returns floats (presumably in [0, 1] for integer input); rescale to
# 0..255 so toTensor() receives uint8.
image = transform.resize(image, [512, 512, 3])
image *= 255
test = toTensor(image.astype('uint8')).to(device)

# Inference only: no_grad skips building the autograd graph.
with torch.no_grad():
    output = model(test)
# Assumes a single-channel probability map in [0, 1]; ToPILImage scales it
# to 0..255. TODO confirm the model's output head.
output = transforms.ToPILImage()(output[0].cpu())
output = np.array(output)

comb = output
comb[comb >= 127] = 255
comb[comb < 127] = 0
label, count = measure.label(comb, return_num=True)

comb = transform.resize(comb, [high, width])
comb *= 255

os.makedirs(save_path, exist_ok=True)
io.imsave(save_path + im_name, comb.astype('uint8'))

with open(save_path + 'count.txt', 'a') as f:
    f.write('{}:{}\n'.format(im_name, count))


  warn("The default mode, 'constant', will be changed to 'reflect' in "
