In [2]:
import os
import socket
import utils
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt

# Root of the shared datasets directory per lab server, keyed by server number
# (host "server<N>" -> DATA_DIR[N]). server3/server4 read it through a
# /mnt/server5 mount; server5 has it locally under /data1.
DATA_DIR = {
    3 : "/mnt/server5/sdi/datasets",
    4 : "/mnt/server5/sdi/datasets",
    5 : "/data1/sdi/datasets"
}

def get_datadir(hostname=None):
    """Return the datasets root directory for the current (or given) host.

    Parameters
    ----------
    hostname : str, optional
        Host name to resolve. Defaults to ``socket.gethostname()``.

    Returns
    -------
    str
        ``DATA_DIR[N]`` for a host named ``server<N>``.

    Raises
    ------
    NotImplementedError
        If the host has no entry in ``DATA_DIR``.
    """
    if hostname is None:
        # Resolve the hostname once instead of once per comparison.
        hostname = socket.gethostname()
    # Derived from DATA_DIR so adding a server only requires editing the table.
    host_to_key = {f"server{key}": key for key in DATA_DIR}
    if hostname not in host_to_key:
        raise NotImplementedError(
            f"No dataset directory configured for host {hostname!r}; "
            f"known hosts: {sorted(host_to_key)}"
        )
    return DATA_DIR[host_to_key[hostname]]

# Input locations: full-size source images and their per-image masks, both
# under CPN_all on this host's dataset root.
datadir, maskdir = (os.path.join(get_datadir(), sub) for sub in ('CPN_all/Images', 'CPN_all/Masks'))

In [49]:
# Crop every image/mask pair in CPN_all to a CROP x CROP window that contains
# the mask foreground (as centrally as possible) and save the pair under
# CPN_trim/Images and CPN_trim/Masks on the same dataset root.
CROP = 256
SHOW_PLOTS = False  # set True to display each crop while debugging


def _mask_name(fname):
    """Return the mask file name for an image file: ``<stem>_mask<ext>``."""
    # splitext keeps multi-dot stems intact ("a.b.png" -> "a.b_mask.png").
    stem, ext = os.path.splitext(fname)
    return f"{stem}_mask{ext}"


def _load_gray(path):
    """Load an image file as a 2-D uint8 grayscale array, closing the file."""
    with Image.open(path) as img:
        return np.array(img.convert("L"), dtype=np.uint8)


def _crop_window(lo, hi, extent, crop=CROP):
    """Return ``(start, stop)`` of a ``crop``-long window along one axis.

    ``lo``/``hi`` are the first/last (inclusive) foreground indices of the mask
    on this axis and ``extent`` is the image size on that axis (must be
    >= ``crop``). If the foreground fits, it is padded equally on both sides
    (an odd leftover pixel goes after it); if it is longer than ``crop`` the
    window is centred on the foreground centre. The window is then shifted so
    it lies entirely inside ``[0, extent)``.
    """
    span = hi - lo + 1
    if span > crop:
        start = (lo + hi) // 2 - crop // 2
    else:
        start = lo - (crop - span) // 2
    start = int(min(max(start, 0), extent - crop))
    return start, start + crop


trim_root = get_datadir()
out_image_dir = os.path.join(trim_root, 'CPN_trim/Images')
out_mask_dir = os.path.join(trim_root, 'CPN_trim/Masks')
os.makedirs(out_image_dir, exist_ok=True)
os.makedirs(out_mask_dir, exist_ok=True)

# Sorted for a deterministic processing/log order.
for fname in sorted(os.listdir(datadir)):
    image_path = os.path.join(datadir, fname)
    # Skip sub-directories (e.g. .ipynb_checkpoints) and hidden files.
    if fname.startswith('.') or not os.path.isfile(image_path):
        continue
    mname = _mask_name(fname)
    mask_path = os.path.join(maskdir, mname)
    if not os.path.exists(mask_path):
        raise FileNotFoundError(f"No mask for {image_path}: expected {mask_path}")

    image_arr = _load_gray(image_path)
    mask_arr = _load_gray(mask_path)
    if image_arr.shape != mask_arr.shape:
        raise ValueError(f"{fname}: image shape {image_arr.shape} != mask shape {mask_arr.shape}")
    if image_arr.shape[0] < CROP or image_arr.shape[1] < CROP:
        raise ValueError(f"{fname}: image shape {image_arr.shape} is smaller than the {CROP}x{CROP} crop")

    rows, cols = np.nonzero(mask_arr)
    if rows.size == 0:
        raise ValueError(f"{fname}: mask {mask_path} has no foreground pixels")
    top, bottom = int(rows.min()), int(rows.max())
    left, right = int(cols.min()), int(cols.max())
    hp, wp = (top + bottom) // 2, (left + right) // 2  # foreground bbox centre

    i, h = _crop_window(top, bottom, image_arr.shape[0])
    j, w = _crop_window(left, right, image_arr.shape[1])
    print(f'(i,j,h,w):({i},{j},{h},{w}), shape: {image_arr.shape}, ({hp}, {wp})')

    image_crop = image_arr[i:h, j:w]
    mask_crop = mask_arr[i:h, j:w]
    assert image_crop.shape == (CROP, CROP), (fname, image_crop.shape)
    Image.fromarray(image_crop).save(os.path.join(out_image_dir, fname))
    Image.fromarray(mask_crop).save(os.path.join(out_mask_dir, mname))

    if SHOW_PLOTS:
        fig, axes = plt.subplots(1, 2)
        axes[0].imshow(image_crop, cmap='gray')
        axes[1].imshow(mask_crop, cmap='gray')
        plt.show()
        plt.close(fig)  # avoid accumulating figures across the loop


(i,j,h,w):(2,96,258,352), shape: (438, 413), (129, 223)
(i,j,h,w):(0,28,256,284), shape: (487, 419), (81, 155)
(i,j,h,w):(0,76,256,332), shape: (487, 417), (66, 203)
(i,j,h,w):(0,88,256,344), shape: (487, 414), (126, 215)
(i,j,h,w):(0,93,256,349), shape: (483, 419), (60, 220)
(i,j,h,w):(0,0,256,256), shape: (478, 417), (116, 123)
(i,j,h,w):(0,80,256,336), shape: (476, 612), (92, 207)
(i,j,h,w):(0,179,256,435), shape: (503, 601), (120, 306)
(i,j,h,w):(0,389,256,645), shape: (568, 715), (106, 516)
(i,j,h,w):(0,4,256,260), shape: (518, 415), (66, 131)
(i,j,h,w):(72,235,328,491), shape: (478, 609), (199, 374)
(i,j,h,w):(0,140,256,396), shape: (521, 396), (100, 272)
(i,j,h,w):(0,76,256,332), shape: (498, 417), (63, 203)
(i,j,h,w):(4,114,260,370), shape: (510, 417), (131, 241)
(i,j,h,w):(7,32,263,288), shape: (522, 413), (134, 159)
(i,j,h,w):(18,104,274,360), shape: (567, 421), (145, 231)
(i,j,h,w):(0,50,256,306), shape: (491, 406), (120, 177)
(i,j,h,w):(57,28,313,284), shape: (481, 420), (1