In [27]:
import os
import io
import cv2
import numpy as np
import pandas as pd

from PIL import Image
from skimage.io import imread, imsave
from skimage.transform import resize

import tensorflow as tf

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.patches as patches

In [28]:
TRAIN_PATH = '../data/stage1_train/'

# classes.csv has one row per image; rows whose 'background' column is 'white'
# are the images that get upsampled.
dat = pd.read_csv('../data/classes.csv')
white_background = dat['background'] == 'white'
to_upsample = [name.replace('.png', '') for name in dat.loc[white_background, 'filename']]
# every selected id is listed twice, i.e. two extra copies per image
to_upsample = to_upsample * 2
to_upsample

['ed5be4b63e9506ad64660dd92a098ffcc0325195298c13c815a73773f1efc279',
 '94519eb45cbe1573252623b7ea06a8b43c19c930f5c9b685edb639d0db719ab0',
 '79fe419488ba98494e3baa35c6fef9662eda1efe325d0ab0ac002f5383245d96',
 '353ab00e964f71aa720385223a9078b770b7e3efaf5be0f66e670981f68fe606',
 'f4b7c24baf69b8752c49d0eb5db4b7b5e1524945d48e54925bff401d5658045d',
 '17b9bf4356db24967c4677b8376ac38f826de73a88b93a8d73a8b452e399cdff',
 '1631352dbafb8a90f11219fffd3bea368a30bc3bad3bbe0e84e19bd720df4945',
 '2c61fdcb36fd1b2944895af6204279e9f6c164ba894198b40c8b7a3c9bf500ea',
 '0f1f896d9ae5a04752d3239c690402c022db4d72c0d2c087d73380896f72c466',
 'a0325cb7aa59e9c0a75e64ba26855d8032c46161aa4bca0c01bac5e4a836485e',
 '420f43d21dbaba42bf8c0995b3a2c85537876d594433770c6c6f3d6b779ec15f',
 'ef3ef194e5657fda708ecbd3eb6530286ed2ba23c88efb9f1715298975c73548',
 'bb61fc17daf8bdd4e16fdcf50137a8d7762bec486ede9249d92e511fcb693676',
 '0e4c2e2780de7ec4312f0efcd86b07c3738d21df30bb4643659962b4da5505a3',
 '259b35151d4a7a5ffdd7ab7f171b142d

In [None]:
import scipy.ndimage as ndi

def find_bbox(img, pad=5):
    '''
    Finds the bounding box for a binary segmentation mask.

    The first and last row/column that contain at least one positive pixel
    are located and the box spanning them is grown by `pad` pixels on each side.

    Args:
        img: mask array whose first two axes are (rows, cols). Any further
            axes (e.g. channels) are collapsed: a pixel counts as foreground
            if any of its entries is > 0.
        pad: margin in pixels added around the tight box (default 5).

    Returns:
        Tuple (xmin, ymin, xmax, ymax) of Python ints. xmin/ymin are clipped
        at 0; xmax/ymax are NOT clipped to the image size, so callers must
        clip them if they need in-bounds coordinates. A mask without any
        foreground yields (0, 0, 0, 0).
    '''
    # Boolean test instead of summing pixel values: summing narrow integer
    # types (e.g. uint8 masks with value 255) can wrap around to 0.
    foreground = np.asarray(img) > 0
    if foreground.ndim > 2:
        foreground = foreground.any(axis=tuple(range(2, foreground.ndim)))

    rows = np.flatnonzero(foreground.any(axis=1))
    cols = np.flatnonzero(foreground.any(axis=0))
    if rows.size == 0:
        # no foreground at all (rows empty implies cols empty as well)
        return 0, 0, 0, 0

    # rows/cols are sorted, so [0] / [-1] are the first / last occupied index;
    # this also covers index 0, which a backwards range(n-1, 0, -1) scan skips.
    ymin = max(int(rows[0]) - pad, 0)
    xmin = max(int(cols[0]) - pad, 0)
    ymax = int(rows[-1]) + pad
    xmax = int(cols[-1]) + pad

    return xmin, ymin, xmax, ymax

def make_crops(img_path, msk_path, bbox, n,
               save_path='/media/florian/Neumann/Kaggle/Data_Science_Bowl_2018/data/image_mask_crops'):
    '''
    Crops an image and one of its masks to a bounding box and saves both crops.

    The crops are written to `save_path` as `<image name>_img_<n>.png` and
    `<image name>_msk_<n>.png`, where `<image name>` is the base name of
    `img_path` without its extension. Nothing is read or written when the
    box has no area.

    Args:
        img_path: path of the image file; only the first three channels are kept.
        msk_path: path of the mask file belonging to the image.
        bbox: sequence starting with (xmin, ymin, xmax, ymax) in pixels.
        n: index used to tell apart several crops of the same image.
        save_path: output directory; created if it does not exist yet.
    '''
    xmin, ymin, xmax, ymax = bbox[:4]
    if ymax <= ymin or xmax <= xmin:
        # degenerate box: skip before doing any file I/O
        return

    img = imread(img_path)[:,:,:3]
    msk = imread(msk_path)

    img_crop = img[ymin:ymax, xmin:xmax]
    msk_crop = msk[ymin:ymax, xmin:xmax]

    # Both names derive from the image name so that image and mask crops pair up.
    # splitext (instead of replacing '.png') keeps the two names distinct even
    # if the input does not end in '.png'.
    stem = os.path.splitext(os.path.basename(img_path))[0]
    img_flnm = '{}_img_{}.png'.format(stem, n)
    msk_flnm = '{}_msk_{}.png'.format(stem, n)

    os.makedirs(save_path, exist_ok=True)
    imsave(os.path.join(save_path, img_flnm), img_crop)
    imsave(os.path.join(save_path, msk_flnm), msk_crop)
    

def int64_feature(value):
    '''Wraps a single integer in a `tf.train.Feature` (int64 list with one entry).'''
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

def int64_list_feature(value):
    '''Wraps an iterable of integers in a `tf.train.Feature` (int64 list).'''
    int_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int_list)

def bytes_feature(value):
    '''Wraps a single bytes value in a `tf.train.Feature` (bytes list with one entry).'''
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)

def bytes_list_feature(value):
    '''Wraps an iterable of bytes values in a `tf.train.Feature` (bytes list).'''
    byte_list = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=byte_list)

def float_list_feature(value):
    '''Wraps an iterable of floats in a `tf.train.Feature` (float list).'''
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

def make_tfrecord(img_id, aug=False):
    '''
    Builds a `tf.train.Example` holding one training image and a bounding box
    for each of its mask files.

    The image is read from `../data/stage1_train/<img_id>/images/<img_id>.png`
    and every file in `../data/stage1_train/<img_id>/masks/` is treated as the
    mask of one object of class 'Nuclei' (label 1). Box coordinates are stored
    relative to the image size and clipped to [0, 1].

    Args:
        img_id: name of the image folder inside the training directory.
        aug: if True, the image is flipped top/bottom with probability 0.5 and
            rotated by a random whole angle in [0, 360). The masks receive the
            same flip and rotation angle, and the augmented image (not the
            file on disk) is what gets stored in the example. Random numbers
            come from the global `np.random` state.

    Returns:
        The assembled `tf.train.Example`. Masks that are empty after the
        augmentation (e.g. rotated out of the frame) contribute no box.
    '''
    flip = False
    rotation = 0
    path = '../data/stage1_train/' + img_id
    img_path = path + '/images/' + img_id + '.png'

    with tf.gfile.GFile(img_path, 'rb') as fid:
        encoded_png = fid.read()
    image = Image.open(io.BytesIO(encoded_png))

    if aug:
        if np.random.random() > 0.5:
            flip = True
            image = image.transpose(Image.FLIP_TOP_BOTTOM)

        rotation = np.random.randint(0, 360)
        image = image.rotate(rotation, expand=0)

        # Store the augmented pixels. Without re-encoding, the example would
        # contain the original image next to boxes of the transformed masks.
        augmented_png = io.BytesIO()
        image.save(augmented_png, format='PNG')
        encoded_png = augmented_png.getvalue()

    width, height = image.size

    filename = img_path.encode('utf8')
    image_format = b'png'
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    masks_dir = path + '/masks/'
    for mask_fl in next(os.walk(masks_dir))[2]:
        mask_ = imread(masks_dir + mask_fl)
        if flip:
            mask_ = np.flip(mask_, 0)

        if aug:
            # order=0 (nearest neighbour) keeps the mask values as they are;
            # spline interpolation can produce values outside the original
            # range, which would show up as spurious foreground in the box search.
            mask_ = ndi.rotate(mask_, angle=rotation, reshape=False, order=0)

        xmin, ymin, xmax, ymax = find_bbox(mask_)
        if xmax <= xmin or ymax <= ymin:
            # nothing of this object is left inside the frame
            continue

        # Crops for the UNet are not saved here (it is augmented on the fly).
        # To export them, call make_crops(img_path, masks_dir + mask_fl, bbox, i)
        # for un-augmented examples only.

        xmins.append(max(0, xmin / width))
        ymins.append(max(0, ymin / height))
        xmaxs.append(min(1, xmax / width))
        ymaxs.append(min(1, ymax / height))
        classes_text.append('Nuclei'.encode('utf8'))
        classes.append(1)

    tf_example = tf.train.Example(features=tf.train.Features(feature={
           'image/height': int64_feature(height),
           'image/width': int64_feature(width),
           'image/filename': bytes_feature(filename),
           'image/source_id': bytes_feature(filename),
           'image/encoded': bytes_feature(encoded_png),
           'image/format': bytes_feature(image_format),
           'image/object/bbox/xmin': float_list_feature(xmins),
           'image/object/bbox/xmax': float_list_feature(xmaxs),
           'image/object/bbox/ymin': float_list_feature(ymins),
           'image/object/bbox/ymax': float_list_feature(ymaxs),
           'image/object/class/text': bytes_list_feature(classes_text),
           'image/object/class/label': int64_list_feature(classes)
        }))

    return tf_example

In [None]:
from sklearn.model_selection import train_test_split

SEED = 42                # fixes the train/validation split and the random augmentations
N_AUGMENTATIONS = 10     # augmented copies written per training image
TRAIN_RECORD = '../data/train_nuclei_bbox_augmentations.record'
VALID_RECORD = '../data/valid_nuclei_bbox.record'

np.random.seed(SEED)

# sorted: os.walk order depends on the file system, which would otherwise
# change the split even with a fixed seed
img_ids = sorted(next(os.walk(TRAIN_PATH))[1])

train_ids, test_ids = train_test_split(img_ids, test_size=0.1, random_state=SEED)
print(len(train_ids))

# Upsample only ids that are part of the training split: this keeps
# validation images out and drops ids without a folder in TRAIN_PATH.
train_split = set(train_ids)
train_ids += [x for x in to_upsample if x in train_split]
print(len(train_ids))

# `with` closes the writers even if building an example fails half-way
with tf.python_io.TFRecordWriter(TRAIN_RECORD) as writer:
    for img_id in train_ids:
        # the un-augmented example
        tf_example = make_tfrecord(img_id, aug=False)
        writer.write(tf_example.SerializeToString())
        # plus N_AUGMENTATIONS randomly flipped/rotated versions
        for _ in range(N_AUGMENTATIONS):
            tf_example = make_tfrecord(img_id, aug=True)
            writer.write(tf_example.SerializeToString())

with tf.python_io.TFRecordWriter(VALID_RECORD) as writer:
    for img_id in test_ids:
        tf_example = make_tfrecord(img_id)
        writer.write(tf_example.SerializeToString())

print('Created {} records'.format(sum(1 for _ in tf.python_io.tf_record_iterator(TRAIN_RECORD))))

603
705


In [None]:
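# Sanity check (disabled): parse the examples of a training record and print the last one.
# NOTE: the path below is not the '..._augmentations.record' file written above; adjust before use.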
#for example in tf.python_io.tf_record_iterator('../data/train_nuclei_bbox.record'):
#    result = tf.train.Example.FromString(example)
    
#print(result)

In [None]:
import glob

%matplotlib inline
import matplotlib.pyplot as plt

img_files = glob.glob('../data/image_mask_crops/*_img*')

# Show the first 20 crops: image in the left panel, its mask in the right one.
# The mask path is the image path with 'img' replaced by 'msk'.
for fl in img_files[:20]:
    img, msk = imread(fl), imread(fl.replace('img', 'msk'))

    plt.figure()
    for slot, pixels in ((121, img), (122, msk)):
        plt.subplot(slot)
        plt.imshow(pixels.squeeze())
    plt.show()

In [None]:

# Repeats the crop preview for the first 20 entries of `img_files`
# (image left, mask right); `img_files` must already be defined.
for fl in img_files[:20]:
    msk_fl = fl.replace('img', 'msk')
    img = imread(fl)
    msk = imread(msk_fl)

    plt.figure()
    for position, pixels in enumerate((img, msk), start=121):
        plt.subplot(position)
        plt.imshow(pixels.squeeze())
    plt.show()