# **Convert StratifiedKFold Images To TFRecord**

### **If you're wondering what a TFRecord is, check out the link below!**

* [Tensorflow.org](https://www.tensorflow.org/tutorials/load_data/tfrecord)

## **Reference**
#### This notebook is based on the **nice kernels** below!
#### Thank you all for sharing them!

* [Kaveh Shahhosseini](https://www.kaggle.com/kavehshahhosseini/sartorius-convert-images-and-masks-to-tfrecord)
* [Chris Deotte](https://www.kaggle.com/cdeotte/how-to-create-tfrecords)

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm

import cv2
import seaborn as sns
import matplotlib.pyplot as plt

import tensorflow as tf
from sklearn.model_selection import StratifiedKFold

In [None]:
# Size of the original competition images, in pixels.
IMAGE_WIDTH, IMAGE_HEIGHT = 704, 520

# One row per annotated cell instance.
df = pd.read_csv("../input/sartorius-cell-instance-segmentation/train.csv")

# **Convert RLE Encoded Masks!**

## What is RLE Encoding?

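#### It is a method that stores each run of consecutive identical values as a pair: the run length followed by the value.
#### EXAMPLE : 11133333555 → 3 1 5 3 3 5
#### The `annotation` column in train.csv uses a related form: space-separated pairs of (1-based start pixel, run length) over the flattened image, where every listed pixel belongs to the cell mask (e.g. `1 3 10 5`).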

* [Wikipedia](https://en.wikipedia.org/wiki/Run-length_encoding)
* [dcode converter](https://www.dcode.fr/rle-compression)

#### The links above make it easy to understand!

In [None]:
def rle_decode(mask_rle, shape):
    """Decode one run-length-encoded annotation into a mask array.

    ``mask_rle`` is a whitespace-separated string of ``start length`` pairs,
    where ``start`` is a 1-based index into the flattened image. The flat
    buffer is reshaped in NumPy's default (row-major) order.

    Args:
        mask_rle: RLE string, e.g. ``"1 3 10 5"``.
        shape: ``(height, width, channels)`` of the returned mask.

    Returns:
        ``np.uint8`` array of ``shape`` with 255 inside the runs and 0 elsewhere.
    """
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0::2], s[1::2])]
    starts -= 1  # RLE starts are 1-based
    ends = starts + lengths
    img = np.zeros((shape[0] * shape[1], shape[2]), dtype=np.uint8)
    for start, end in zip(starts, ends):
        img[start:end] = 255
    return img.reshape(shape)


def build_masks(labels, input_shape):
    """Merge all instance annotations of one image into a single binary mask.

    Args:
        labels: iterable of RLE strings, one per instance.
        input_shape: ``(height, width)`` of the image.

    Returns:
        ``np.uint8`` array of shape ``(height, width)``: 255 where any
        instance covers the pixel, 0 elsewhere.
    """
    height, width = input_shape
    mask = np.zeros((height, width, 1), dtype=np.uint8)
    for label in labels:
        # Union via element-wise max. Summing the 255-valued masks of
        # overlapping instances would exceed the uint8 range and wrap
        # when cast back to uint8.
        np.maximum(mask, rle_decode(label, shape=(height, width, 1)), out=mask)
    # Use the requested size rather than a hard-coded 520x704.
    return mask.reshape((height, width))

# **Split the image ids with StratifiedKFold!**

### **Check whether each id has one cell_type!**

#### **If 'grouped' and 'ids' have the same length, each id has exactly one cell_type.**

In [None]:
# If every id maps to exactly one cell_type, there are as many
# (id, cell_type) pairs in `grouped` as there are unique ids.
ids = df["id"].unique()
grouped = df.groupby('id')['cell_type'].value_counts()
print("Same!" if len(ids) == len(grouped) else 'Nope!')

### **Make a new list of folded ids!**

In [None]:
# One row per (id, cell_type) pair, plus one split column per fold
# initialised to 'na' until the folds are assigned.
cell_data = (
    grouped.index.to_frame()
    .reset_index(drop=True)
    .assign(**{f'{i}_type': 'na' for i in range(5)})
)
cell_data

In [None]:
# Stratify on cell_type so each fold keeps the class balance; the fixed
# seed makes the shuffle reproducible.
stkf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2021)
split_iter = stkf.split(cell_data['id'], cell_data['cell_type'])

# The split yields positional indices; they can be used with .loc because
# cell_data has a default RangeIndex (built with reset_index(drop=True)).
for fold, (train_index, valid_index) in enumerate(split_iter):
    for split_name, rows in (('train', train_index), ('valid', valid_index)):
        cell_data.loc[rows, f'{fold}_type'] = split_name

cell_data

In [None]:
# Per-fold cell_type distribution of the training split.
fig, ax = plt.subplots(1, 5, figsize=(12, 6))

for i in range(5):
    counts = cell_data[cell_data[f'{i}_type'] == 'train']['cell_type'].value_counts()
    # x/y must be keyword arguments: seaborn >= 0.12 accepts only `data`
    # positionally (older versions emitted a FutureWarning).
    sns.barplot(x=counts.index, y=counts.values, ax=ax[i]).set_title(f'train Fold {i}')
plt.tight_layout()

In [None]:
# Per-fold cell_type distribution of the validation split.
fig, ax = plt.subplots(1, 5, figsize=(12, 6))

for i in range(5):
    counts = cell_data[cell_data[f'{i}_type'] == 'valid']['cell_type'].value_counts()
    # x/y must be keyword arguments: seaborn >= 0.12 accepts only `data`
    # positionally (older versions emitted a FutureWarning).
    sns.barplot(x=counts.index, y=counts.values, ax=ax[i]).set_title(f'valid Fold {i}')
plt.tight_layout()

### **Seems Done!**

# **Convert Images to TFRecords!!**

In [None]:
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``.

    If ``value`` is an eager tensor, its NumPy value is used instead.
    """
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def serialize_example(image, label):
    """Serialize an (image, label) array pair into a ``tf.train.Example``.

    Both arrays are stored as raw bytes (``ndarray.tobytes()``), so their
    shape and dtype are not recorded in the example.
    """
    features = tf.train.Features(feature={
        'image': _bytes_feature(image.tobytes()),
        'label': _bytes_feature(label.tobytes()),
    })
    return tf.train.Example(features=features).SerializeToString()

In [None]:
# Create one output directory per fold (no-op if it already exists, like mkdir -p).
for fold_dir in [f"./tfrecords/fold_{k}" for k in range(5)]:
    os.makedirs(fold_dir, exist_ok=True)

In [None]:
%%time
for fold in range(5):
    print('Fold ', fold)
    outpath = f"./tfrecords/fold_{fold}"
    train_path = os.path.join(outpath,'train.tfrec')
    valid_path = os.path.join(outpath,'valid.tfrec')
    
    train_ids = cell_data[cell_data[f'{fold}_type'] == "train"]['id'].values
    valid_ids = cell_data[cell_data[f'{fold}_type'] == "valid"]['id'].values
    
    # Train Data TFRecord
    with tf.io.TFRecordWriter(train_path, options=tf.io.TFRecordOptions(compression_type="GZIP")) as writer:
        for i in tqdm(train_ids, colour="#73d315", ncols=100):
            img_path = os.path.join("../input/sartorius-cell-instance-segmentation/train", f"{i}.png")
            img = cv2.imread(img_path)
            img = cv2.resize(img, (256, 256))
            img = (img/255.).astype('float32')
            labels = df[df["id"] == i]["annotation"].tolist()
            mask = build_masks(labels, input_shape=(520, 704))
            mask = cv2.resize(mask, (256, 256))
            example = serialize_example(img, mask)
            writer.write(example)
    
    # Valid Data TFRecord
    with tf.io.TFRecordWriter(valid_path, options=tf.io.TFRecordOptions(compression_type="GZIP")) as writer:
        for i in tqdm(valid_ids, colour="#73d315", ncols=100):
            img_path = os.path.join("../input/sartorius-cell-instance-segmentation/train", f"{i}.png")
            img = cv2.imread(img_path)
            img = (img/255.).astype('float32')
            labels = df[df["id"] == i]["annotation"].tolist()
            mask = build_masks(labels, input_shape=(520, 704))
            example = serialize_example(img, mask)
            writer.write(example)

# **Thanks!**