#  Image Segmentation

### Import modules

In [None]:
import tensorflow as tf

from IPython.display import clear_output
import matplotlib.pyplot as plt

import time

import pandas as pd
# General
from glob import glob
import resource
from tqdm.notebook import tqdm

# Data Handling
import numpy as np
import pandas as pd
import json

# Plotting
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as patches


## Data loader 

In [None]:
# Per-image metadata table; later cells read its image_file, width_pixels
# and height_pixels columns.
info_df = pd.read_csv('hubmap-kidney-segmentation/HuBMAP-20-dataset_information.csv')

In [None]:
train_df = pd.read_csv('hubmap-kidney-segmentation/train.csv')

# Unroll the run-length 'encoding' column into one row per (start, run) pair.
# Each encoding is a whitespace-separated string "start run start run ...".
# The per-image frames are collected in a list and concatenated once:
# DataFrame.append was removed in pandas 2.0 and re-copies the accumulated
# frame on every call.
sections = []
for _, (image_id, encoding) in train_df.iterrows():
    # reshape(-1, 2) pairs each start with its run length; an odd number of
    # tokens (malformed encoding) raises ValueError instead of being
    # silently truncated, and an empty encoding yields zero rows.
    pairs = np.array(encoding.split()).astype(int).reshape(-1, 2)
    sections.append(pd.DataFrame({'id': image_id,
                                  'start': pairs[:, 0],
                                  'run': pairs[:, 1]}))

if sections:
    train_df = pd.concat(sections, ignore_index=True)
else:
    # No training rows at all: keep the expected schema.
    train_df = pd.DataFrame(columns=['id', 'start', 'run'])

In [None]:
# Maps image id -> {'anat': anatomical-structure annotations,
#                   'glom': glomeruli annotations}, for training images only.
image_data_json = {}

# Build the membership set once instead of recomputing train_df.id.unique()
# (a scan over the whole unrolled table) for every metadata row.
train_ids = set(train_df.id.unique())

for image_file in info_df.image_file:
    image_id = image_file.split('.')[0]

    # Images without training annotations have no JSON files to load.
    if image_id not in train_ids:
        continue

    with open(f'hubmap-kidney-segmentation/train/{image_id}-anatomical-structure.json', 'r') as json_file:
        anat_data = json.load(json_file)
    with open(f'hubmap-kidney-segmentation/train/{image_id}.json', 'r') as json_file:
        glom_data = json.load(json_file)

    image_data_json[image_id] = {'anat': anat_data, 'glom': glom_data}

In [None]:
# Maps image id -> list of bounding-box areas (height * width), one entry per
# annotated glomerulus polygon of that image.
glom_size_dict = {}
for image_id, data in image_data_json.items():

    areas = []
    for glom in data['glom']:
        # Flatten the annotation's coordinates into an (N, 2) array of (x, y).
        polygon = np.array(glom['geometry']['coordinates']).reshape(-1, 2)

        # Axis-aligned bounding box extent: column 0 is x, column 1 is y.
        width, height = polygon.max(axis=0) - polygon.min(axis=0)
        areas.append(height * width)

    # The unused per-image glob for the .tiff path was dropped: it scanned
    # the filesystem on every iteration for a value that was never read.
    glom_size_dict[image_id] = areas

    print(f'{image_id} glom size in squared pixels:{np.mean(glom_size_dict[image_id]):.3f}')

In [None]:
def get_mask(image_id, window=None, out_shape=None):
    """Decode the run-length encoded annotation of an image into a boolean mask.

    Reads the image dimensions from the module-level ``info_df`` and the
    (start, run) pairs from the module-level ``train_df``.

    Args:
        image_id: Image identifier (the image file name without ``.tiff``).
        window: Optional ``(min_y, max_y, min_x, max_x)`` crop applied to the
            full-resolution mask.
        out_shape: Optional ``(height, width)``; when given, the (cropped)
            mask is resampled with ``array_resize``.

    Returns:
        Boolean array of shape ``(height_pixels, width_pixels)``, or of the
        window / ``out_shape`` size when those are given.
    """
    w, h = info_df[info_df.image_file == image_id +
                   '.tiff'][['width_pixels', 'height_pixels']].values.flatten()

    mask = np.zeros((w*h,), dtype=bool)

    rows = train_df[train_df.id == image_id]
    # Kaggle RLE starts are 1-indexed, so position `start` is flat index
    # `start - 1`, and a run of length `run` covers exactly `run` pixels.
    # (The previous per-pixel loop began one pixel late and marked run + 1
    # pixels, which could also index past the end of the buffer.)
    starts = rows['start'].values.astype(np.int64) - 1
    runs = rows['run'].values.astype(np.int64)
    for first, length in zip(starts, runs):
        mask[first:first + length] = True

    # The encoding numbers pixels column by column, hence reshape to
    # (w, h) and transpose to obtain a (h, w) image.
    mask = mask.reshape(w, h).transpose()

    if window:
        min_y, max_y, min_x, max_x = window
        mask = mask[min_y:max_y, min_x:max_x]

    if out_shape:
        mask = array_resize(mask, out_shape)

    return mask

In [None]:
def array_resize(array, out_shape):
    """Resample the first two axes of ``array`` by picking evenly spaced entries.

    Args:
        array: Array with at least two dimensions.
        out_shape: ``(height, width)`` of the result.

    Returns:
        Array whose first two axes have sizes ``height`` and ``width``; any
        trailing axes are kept unchanged. No interpolation is done: each
        output entry is a copy of one source entry.
    """
    target_rows, target_cols = out_shape

    # Evenly spaced source positions from the first to the last index,
    # rounded to the nearest valid integer index.
    picked_rows = np.linspace(0, array.shape[0] - 1, target_rows).round().astype(int)
    picked_cols = np.linspace(0, array.shape[1] - 1, target_cols).round().astype(int)

    # Broadcasting a column of row indices against the column indices
    # selects the full rows x cols grid in one indexing step.
    return array[picked_rows[:, np.newaxis], picked_cols]

In [None]:
def get_cortex_mask(image_id, sample, size=None, x_anchor=0, y_anchor=0):
    """Rasterise the 'Cortex' annotation of an image into a 256x256 mask.

    Args:
        image_id: Key into the module-level ``image_data_json``.
        sample: One-row frame with ``width_pixels`` and ``height_pixels``.
        size: Optional side length of a square crop; when given it replaces
            the image width and height as the scaling reference.
        x_anchor: X offset of the crop's origin, subtracted from vertices.
        y_anchor: Y offset of the crop's origin, subtracted from vertices.

    Returns:
        Boolean array of shape (256, 256), True inside the cortex outline.

    Raises:
        IndexError: If the image has no structure classified as 'Cortex'.
    """
    anat_data = image_data_json[image_id]['anat']

    w, h = sample[['width_pixels', 'height_pixels']].values[0]

    if size:
        w = h = size

    # Keep only the cortex information (first matching structure).
    cortex_data = [tissue for tissue in anat_data if tissue['properties']
                   ['classification']['name'] == 'Cortex'][0]

    # Collect the linear rings explicitly. A GeoJSON Polygon is a list of
    # rings and a MultiPolygon a list of such lists; pushing the ragged
    # nesting through np.array(...).reshape(-1, 2) fails on recent NumPy
    # and for odd polygon counts.
    geometry = cortex_data['geometry']
    if geometry['type'] == 'Polygon':
        rings = geometry['coordinates']
    else:
        rings = [ring for part in geometry['coordinates'] for ring in part]

    cortex_mask_shape = (256, 256)
    cortex_mask = np.zeros(cortex_mask_shape, dtype=bool)

    for ring in rings:
        # Work on a float copy so the rescaling below is not truncated by
        # in-place assignment into an integer array.
        vertices = np.asarray(ring, dtype=float).reshape(-1, 2)

        # Shift into crop coordinates, then rescale to the 256-pixel grid.
        vertices[:, 0] = (vertices[:, 0] - x_anchor) / w * 256
        vertices[:, 1] = (vertices[:, 1] - y_anchor) / h * 256

        # Vertices are (x, y); the columns are reversed to (row, col).
        # NOTE(review): polygon2mask is not imported in the visible cells;
        # presumably skimage.draw.polygon2mask -- confirm and import it.
        cortex_mask = cortex_mask | polygon2mask(cortex_mask_shape, vertices[:, ::-1])

    return cortex_mask

In [None]:
def get_input_target():
    """Draw one random training example as an ``(input_image, mask)`` pair.

    Picks a random training image, crops a random square window from it,
    reads the window resampled to 256x256, builds the matching glomeruli
    mask and cortex mask, applies a random flip and stacks the cortex mask
    onto the image as an extra last channel.

    Relies on the module-level ``info_df``, ``train_df`` and
    ``glom_size_dict``, and draws from NumPy's global random state.

    NOTE(review): ``rasterio`` and ``Window`` are not imported in the visible
    cells -- confirm they are brought into scope elsewhere.

    Returns:
        Tuple ``(input_image, mask)``, both cast to float32. ``input_image``
        is the image divided by 255 with the cortex mask appended along the
        last axis; ``mask`` is the resized glomeruli mask.
    """
    # Pick a sample from the training dataset.
    train_mask = info_df.image_file.isin(
        [image+'.tiff' for image in train_df.id.unique()])
    sample = info_df[train_mask].sample()

    image_file, w, h = sample[['image_file',
                               'width_pixels', 'height_pixels']].values[0]

    image_id = image_file.split('.')[0]

    # Decide the side length of the square crop: between one and three times
    # the square root of the largest glomerulus bounding-box area.
    max_side = int(np.sqrt(np.array(glom_size_dict[image_id])).max())
    size = np.random.randint(max_side, high=max_side*3)

    # Decide where the crop sits inside the image (top-left corner).
    # randint raises ValueError if the crop is not smaller than the image.
    x_anchor = np.random.randint(0, w-size)
    y_anchor = np.random.randint(0, h-size)

    # Locate the image file: first path three levels below the working
    # directory whose name contains '<image_id>.tiff'.
    path = [p for p in glob('./*/*/*')if f'{image_id}.tiff' in p][0]

    with rasterio.open(path) as src:
        sample_image = src.read(out_shape=(256, 256),
                                window=Window.from_slices((y_anchor, y_anchor+size),
                                                          (x_anchor, x_anchor+size)),
                                resampling=rasterio.enums.Resampling.cubic)

    # Move the first axis to the end (presumably bands-first to
    # bands-last -- TODO confirm against rasterio's read() layout).
    sample_image = np.moveaxis(sample_image, 0, -1)

    # Get the corresponding glomeruli mask for the same window.
    window = (y_anchor, y_anchor+size, x_anchor, x_anchor+size)
    mask = get_mask(image_id, window, out_shape=(256, 256))

    # Get the cortex mask
    cortex_mask = get_cortex_mask(image_id, sample, size=size, x_anchor=x_anchor, y_anchor=y_anchor)

    # Randomly flip the arrays along their first two axes; a step of -1
    # reverses an axis, a step of 1 leaves it unchanged.
    flip_vertically = np.random.choice([-1, 1])
    flip_horizontally = np.random.choice([-1, 1])
    result = []
    for channel in [sample_image, mask, cortex_mask]:
        # Apply the same transformation to image and both masks.
        result.append(channel[::flip_vertically, ::flip_horizontally])
    sample_image, mask, cortex_mask = result
    
    # Scale pixel values by 1/255 (presumably 8-bit input -- TODO confirm).
    sample_image = sample_image/255
    
    # Append the cortex mask as an additional last channel.
    input_image = np.concatenate([sample_image, cortex_mask[..., np.newaxis]], axis=-1)
    
    
    # TODO include a parameter to generate n inputs and targets
#     input_image = input_image[np.newaxis, ...].astype(np.float32)
#     mask = mask[np.newaxis, ...].astype(np.float32)
    
    input_image = input_image.astype(np.float32)
    mask = mask.astype(np.float32)
    
    return (input_image, mask)

In [None]:
def _bytes_feature(value):
    """Wrap a string / byte value in a ``tf.train.Feature`` holding a bytes_list."""
    eager_tensor_type = type(tf.constant(0))
    # BytesList cannot unpack a string from an EagerTensor, so pull the
    # underlying value out first.
    payload = value.numpy() if isinstance(value, eager_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

In [None]:
def image_example():
    """Generate one random sample and pack it into a ``tf.train.Example``.

    The raw bytes of the input array are stored under 'image' and those of
    the mask under 'target'.
    """
    input_image, mask = get_input_target()

    features = tf.train.Features(feature={
        'image': _bytes_feature(input_image.tobytes()),
        'target': _bytes_feature(mask.tobytes()),
    })
    return tf.train.Example(features=features)

In [None]:
# Number the new record file one past the highest numeric suffix found among
# the existing '*.tfrecords' files (or 1 when there are none).
existing_numbers = [int(name.split('.')[0].split('_')[-1])
                    for name in glob('*.tfrecords')]
max_num = max([0] + existing_numbers)
record_file = f"image_{max_num+1}.tfrecords"

# Write 2000 freshly sampled examples into the new file.
with tf.io.TFRecordWriter(record_file) as writer:
    for _ in tqdm(range(2000)):
        writer.write(image_example().SerializeToString())

In [None]:
input_image, mask = get_input_target()

# Show, side by side: the first three input channels, the target mask and
# the fourth input channel (the cortex mask).
plt.figure(figsize=(15, 5))
preview_panels = (input_image[:, :, :3], mask, input_image[:, :, 3])
for position, panel in enumerate(preview_panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(panel)

plt.show()

### Input pipeline

In [None]:
# Stream serialized examples from every file in the working directory whose
# name contains 'tfrecord'.
dataset = tf.data.TFRecordDataset(filenames = glob("*tfrecord*"))

In [None]:
def get_image_from_example(raw_example):
    """Decode one serialized example back into its arrays.

    Args:
        raw_example: Eager string tensor holding a serialized
            ``tf.train.Example`` with 'image' and 'target' byte features.

    Returns:
        Tuple ``(mask, input_image)`` -- note the order -- where ``mask`` has
        shape (256, 256, 1) and ``input_image`` shape (256, 256, 4), both
        read-only float32 views over the stored bytes.
    """
    parsed = tf.train.Example.FromString(raw_example.numpy())
    features = parsed.features.feature

    target_bytes_string = features['target'].bytes_list.value[0]
    image_bytes_string = features['image'].bytes_list.value[0]

    # The buffers were swapped before: the 4-channel 'image' bytes were
    # reshaped to (256, 256, 1), which raises, and the 'target' bytes were
    # treated as the image. Each feature now feeds its own array.
    # '<f4' reads little-endian float32, matching the float32 arrays written
    # with tobytes() (assumes a little-endian writer -- TODO confirm).
    mask = np.frombuffer(target_bytes_string, dtype='<f4').reshape(256, 256, 1)
    input_image = np.frombuffer(image_bytes_string, dtype='<f4').reshape(256, 256, 4)

    return mask, input_image

In [None]:
def batch_generator(epochs, batch_size):
    """Yield ``(images, targets)`` batches, looping over the dataset ``epochs`` times.

    Each serialized example of the module-level ``dataset`` is decoded with
    ``get_image_from_example`` and the decoded arrays are stacked along a
    new leading batch axis.
    """
    batches = dataset.batch(batch_size)
    for _ in range(epochs):
        for serialized_batch in batches:
            # Each decoded pair is (target, input_image).
            decoded = [get_image_from_example(serialized)
                       for serialized in serialized_batch]

            images = np.stack([image for _, image in decoded])
            targets = np.stack([target for target, _ in decoded])

            yield (images, targets)

## Define the model

In [None]:
# NOTE(review): not referenced by any visible cell; Generator() hard-codes a
# single output channel, so this value of 0 looks stale -- confirm before
# relying on it.
OUTPUT_CHANNELS = 0

In [None]:
def downsample(filters, size, apply_batchnorm=True):
    """Build an encoder block: stride-2 Conv2D, optional BatchNorm, LeakyReLU.

    Args:
        filters: Number of convolution filters.
        size: Convolution kernel size.
        apply_batchnorm: Whether to insert a BatchNormalization layer after
            the convolution.

    Returns:
        A ``tf.keras.Sequential`` holding the layers in that order.
    """
    conv = tf.keras.layers.Conv2D(
        filters, size, strides=2, padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False)

    layers = [conv]
    if apply_batchnorm:
        layers.append(tf.keras.layers.BatchNormalization())
    layers.append(tf.keras.layers.LeakyReLU())

    return tf.keras.Sequential(layers)

In [None]:
def upsample(filters, size, apply_dropout=False):
    """Build a decoder block: stride-2 Conv2DTranspose, BatchNorm, optional Dropout, ReLU.

    Args:
        filters: Number of transposed-convolution filters.
        size: Kernel size.
        apply_dropout: Whether to insert a Dropout(0.5) layer after the
            batch normalization.

    Returns:
        A ``tf.keras.Sequential`` holding the layers in that order.
    """
    deconv = tf.keras.layers.Conv2DTranspose(
        filters, size, strides=2, padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False)

    layers = [deconv, tf.keras.layers.BatchNormalization()]
    if apply_dropout:
        layers.append(tf.keras.layers.Dropout(0.5))
    layers.append(tf.keras.layers.ReLU())

    return tf.keras.Sequential(layers)

In [None]:
def Generator():
    """Build the encoder-decoder segmentation model with skip connections.

    The model takes a (256, 256, 4) input, halves the spatial size eight
    times down to 1x1, then doubles it back up seven times, concatenating
    each decoder output with the encoder output of matching size. A final
    stride-2 transposed convolution with a sigmoid activation produces a
    single-channel (256, 256, 1) output.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    model_input = tf.keras.layers.Input(shape=[256, 256, 4])

    # Encoder: 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1.
    # Only the first block skips batch normalization.
    encoder_filters = (64, 128, 256, 512, 512, 512, 512, 512)
    encoder = [downsample(filters, 4, apply_batchnorm=(position > 0))
               for position, filters in enumerate(encoder_filters)]

    # Decoder: 1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128.
    # The three deepest blocks use dropout.
    decoder_config = ((512, True), (512, True), (512, True),
                      (512, False), (256, False), (128, False), (64, False))
    decoder = [upsample(filters, 4, apply_dropout=use_dropout)
               for filters, use_dropout in decoder_config]

    # Output layer: 128 -> 256 with one sigmoid channel, i.e. (bs, 256, 256, 1).
    output_layer = tf.keras.layers.Conv2DTranspose(
        1, 4, strides=2, padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        activation='sigmoid')

    # Run the encoder, remembering every intermediate output.
    features = model_input
    encoder_outputs = []
    for block in encoder:
        features = block(features)
        encoder_outputs.append(features)

    # Walk the encoder outputs from deepest to shallowest, leaving out the
    # 1x1 bottleneck itself, and concatenate each onto the decoder path.
    for block, skip in zip(decoder, encoder_outputs[-2::-1]):
        features = block(features)
        features = tf.keras.layers.Concatenate()([features, skip])

    return tf.keras.Model(inputs=model_input, outputs=output_layer(features))

In [None]:
# Instantiate the model and render its layer graph with output shapes.
generator = Generator()
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

## Train the model

In [None]:
def generate_images(input_image, mask):
    """Predict a mask for one input, print its Dice loss and plot the result.

    The 2x2 figure shows the first three input channels, the target mask,
    the fourth input channel and the prediction.

    Args:
        input_image: Array of shape (256, 256, 4).
        mask: Target mask compared against the (256, 256, 1) prediction.

    Returns:
        The unchanged ``(input_image, mask)`` pair.
    """
    # The model expects a leading batch axis; predict on a copy.
    batch = input_image.copy()[np.newaxis, ...]
    pred = generator.predict(batch).reshape(256, 256, 1)

    print(dice_coef(pred, mask))

    # Plot
    plt.figure(figsize=(15, 15))
    panels = (input_image[:, :, :3], mask, input_image[:, :, 3], pred)
    for position, panel in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.imshow(panel)

    plt.show()

    return input_image, mask

In [None]:
def dice_coef(a, b, smooth=1e-5):
    """Return the Dice loss between ``a`` and ``b``.

    Despite the name, the value returned is one minus the smoothed Dice
    coefficient, so it is 0 for a perfect overlap and lower is better.

    Args:
        a: First tensor or array.
        b: Second tensor or array, multiplied element-wise with ``a``.
        smooth: Small constant added to numerator and denominator so the
            ratio stays defined when both inputs sum to zero.

    Returns:
        Scalar tensor ``1 - (2 * sum(a * b) + smooth) / (sum(a) + sum(b) + smooth)``.
    """
    overlap = tf.reduce_sum(tf.multiply(a, b))
    total = tf.reduce_sum(a) + tf.reduce_sum(b)

    return 1 - (overlap * 2 + smooth) / (total + smooth)

In [None]:
# Adam optimizer (learning rate 2e-3, beta_1 0.5), used both by compile()
# and by the manual train_step.
generator_optimizer = tf.keras.optimizers.Adam(2e-3, beta_1=0.5)

In [None]:
# Train with dice_coef as the loss; it returns 1 - Dice, so lower is better.
generator.compile(generator_optimizer,dice_coef)

In [None]:
# Training hyper-parameters.
batch_size, epochs = 16, 10
# Whole batches per epoch; 2000 is the number of examples the writer cell
# stores in one TFRecord file.
steps_per_epoch = 2000 // batch_size

In [None]:
# Train from the Python generator; it yields batches for `epochs` passes
# over the dataset, and steps_per_epoch tells Keras where each epoch ends.
generator.fit(batch_generator(epochs, batch_size),
              steps_per_epoch=steps_per_epoch, epochs=epochs)

In [None]:
# checkpoint_dir = './training_checkpoints'
# checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
#                                  generator=generator)

In [None]:
# @tf.function
def train_step(input_image, target, step):
    """Run one optimisation step of the generator on a single example or batch.

    Args:
        input_image: Input of shape (256, 256, 4), or an already batched
            (batch, 256, 256, 4) array/tensor.
        target: Target mask; the model output is reshaped to its shape
            before the loss is computed.
        step: Unused; kept so existing callers keep working.

    Returns:
        The scalar Dice loss computed for this step.
    """
    input_image = tf.convert_to_tensor(input_image)
    # The model expects a leading batch axis; callers pass single examples.
    if input_image.shape.rank == 3:
        input_image = input_image[tf.newaxis, ...]
    target = tf.convert_to_tensor(target)

    with tf.GradientTape() as gen_tape:
        gen_output = generator(input_image, training=True)

        # Match the target's shape exactly. Reshaping to (256, 256) against
        # a (256, 256, 1) target made the product inside dice_coef
        # broadcast to (256, 256, 256) and distorted the loss.
        gen_output = tf.reshape(gen_output, tf.shape(target))
        gen_total_loss = dice_coef(gen_output, target)

    generator_gradients = gen_tape.gradient(gen_total_loss,
                                            generator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients,
                                            generator.trainable_variables))

    return gen_total_loss

In [None]:
# Scratch cell: evaluates np.newaxis for display only; no effect on state.
np.newaxis

In [None]:
# Scratch cell: displays the shape of the module-level input_image.
input_image.shape

In [None]:
def fit(epochs):
    """Train the generator with the manual loop, one example per step.

    For every epoch the module-level ``dataset`` is traversed once; each
    serialized example is decoded and passed to ``train_step``. A dot is
    printed per example (line break every 100) and the wall-clock time of
    the epoch is reported at the end.

    Args:
        epochs: Number of passes over the dataset.
    """
    for epoch in range(epochs):
        start = time.time()

        # Replace the previous epoch's progress output in the notebook.
        clear_output(wait=True)

        print("Epoch: ", epoch)

        # Train
        for n, raw_example in dataset.enumerate():
            print('.', end='')
            if (n+1) % 100 == 0:
                print()
            target, input_image = get_image_from_example(raw_example)

            train_step(input_image, target, epoch)

        print()

        # The timing line referenced an undefined `step`, which raised
        # NameError at the end of the first epoch; it reports the epoch.
        print('Time taken for epoch {} is {} sec\n'.format(
            epoch + 1, time.time()-start))
#     checkpoint.save(file_prefix=checkpoint_prefix)

In [None]:
# Number of passes over the dataset for the manual training loop.
EPOCHS = 100

fit(EPOCHS)

In [None]:
# Step through the stored examples one at a time: plot the prediction for
# each, wait for Enter, then clear the cell output before the next one.
for raw_example in dataset:
    target, input_image = get_image_from_example(raw_example)
    generate_images(input_image, target)

    input()
    clear_output(wait=False)

In [None]:
# NOTE(review): fit() is called here without any training data or arguments;
# this looks like an unfinished scratch cell -- confirm intent.
generator.fit()