In [None]:
import os
import cv2
from glob import glob
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from albumentations import *

In [None]:
import zipfile
# Unpack the dataset archive into ./BlastsOnline (relative to the current working directory).
local_zip = '/content/drive/MyDrive/BlastsOnline.zip'
# The context manager closes the archive even if extraction fails part-way.
with zipfile.ZipFile(local_zip, 'r') as zip_ref:
    zip_ref.extractall('./BlastsOnline')

In [None]:
# Root folder of the dataset; load_data() globs Images/, GT_TE/, GT_ICM/ and GT_ZP/ beneath it.
# NOTE(review): presumably this is where the archive above unpacks when the working directory is /content - confirm.
drive_path = '/content/BlastsOnline/BlastsOnline'

In [None]:
def create_dir(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    Calling this on an existing directory is a no-op. ``os.makedirs`` with
    ``exist_ok=True`` avoids the check-then-create race of
    ``os.path.exists`` + ``os.mkdir`` and also works for nested paths.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

In [None]:
def load_data(path, test_frac=0.1, valid_frac=0.1, seed=42):
    """Collect image/mask file paths under ``path`` and split them into train/test/valid.

    Files are globbed from ``Images/*.BMP``, ``GT_TE/*.bmp``, ``GT_ICM/*.bmp``
    and ``GT_ZP/*.bmp`` (note the upper-case extension for images) and sorted.
    Images and masks are paired purely by their position in the sorted lists.
    # NOTE(review): this assumes sorting yields the same sample order in all
    # four folders - confirm against the dataset's naming scheme.

    Args:
        path: Dataset root folder.
        test_frac: Fraction of ALL samples held out for the test set.
        valid_frac: Fraction of ALL samples (not of the remaining training
            samples) held out for the validation set.
        seed: ``random_state`` passed to both ``train_test_split`` calls.

    Returns:
        Four tuples, each ordered ``(train, test, valid)``: image paths,
        ZP mask paths, ICM mask paths and TE mask paths.

    Raises:
        ValueError: if no images are found or the four folders do not contain
            the same number of files.
    """
    images = sorted(glob(f"{path}/Images/*.BMP"))
    masks_te = sorted(glob(f"{path}/GT_TE/*.bmp"))
    masks_icm = sorted(glob(f"{path}/GT_ICM/*.bmp"))
    masks_zp = sorted(glob(f"{path}/GT_ZP/*.bmp"))
    print(len(masks_zp), len(masks_icm), len(masks_te), len(images))

    # Fail early with a readable message instead of deep inside train_test_split.
    if not images:
        raise ValueError(f"No images found under {path}/Images (pattern *.BMP)")
    if not (len(images) == len(masks_te) == len(masks_icm) == len(masks_zp)):
        raise ValueError(
            "Image/mask counts differ: "
            f"images={len(images)}, TE={len(masks_te)}, "
            f"ICM={len(masks_icm)}, ZP={len(masks_zp)}"
        )

    # DATA SPLIT: both hold-out sizes are absolute counts derived from the full dataset size.
    test_size = round(test_frac * len(images))
    valid_size = round(valid_frac * len(images))

    train_i, test_i, train_te, test_te, train_zp, test_zp, train_icm, test_icm = train_test_split(
        images, masks_te, masks_zp, masks_icm, test_size=test_size, random_state=seed)
    train_i, valid_i, train_te, valid_te, train_zp, valid_zp, train_icm, valid_icm = train_test_split(
        train_i, train_te, train_zp, train_icm, test_size=valid_size, random_state=seed)

    print("train, valid, test")
    print(len(train_zp), len(valid_zp),len(test_zp))

    return (train_i, test_i, valid_i), (train_zp, test_zp, valid_zp), (train_icm, test_icm, valid_icm), (train_te, test_te, valid_te)

In [None]:
def read_this(image_file):
    """Read ``image_file`` with OpenCV and return it converted to grayscale.

    Args:
        image_file: Path of the image to load.

    Returns:
        The result of ``cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)`` on the loaded image.

    Raises:
        OSError: if ``cv2.imread`` cannot load the file.
    """
    image_bgr = cv2.imread(image_file)
    if image_bgr is None:
        # cv2.imread signals a missing/unreadable file by returning None, which
        # would otherwise surface as an opaque error inside cvtColor.
        raise OSError(f"Could not read image: {image_file}")

    return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)

In [None]:
def GT_blastocoel(maskA,maskB,maskC,mergedmasks):
    """Derive the blastocoel ground truth from the three structure masks and the full-cell mask.

    The three masks are combined with ``cv2.add`` and the result is added to
    ``mergedmasks`` with plain ``+``; every pixel of that sum below 255 is then
    zeroed, so only pixels whose sum is exactly 255 survive.
    # NOTE(review): with uint8 inputs (as produced by read_this) the ``+``
    # wraps around, so 255 + 255 becomes 254 and is discarded - the result
    # depends on that wrap-around; confirm before changing dtypes.
    """
    structures = cv2.add(cv2.add(maskA, maskB), maskC)

    blastocoel = mergedmasks + structures
    blastocoel[blastocoel < 255] = 0
    return blastocoel

In [None]:
def addmasks(maskA, maskB, maskC):
    """Build the full-cell mask: the union of the three masks with enclosed holes filled.

    Steps:
      1. Combine the three masks with ``cv2.add`` and invert (``255 - sum``).
      2. Inverse-threshold at 220, so pixels whose inverted value is <= 220
         become 255 and the rest 0.
      3. Flood-fill a copy from pixel (0, 0) with 255; whatever stays 0 is a
         region the fill could not reach, i.e. a hole.
      4. OR the thresholded mask with those holes.

    # NOTE(review): step 3 assumes the top-left pixel is background; if it is
    # foreground the fill changes nothing and every 0 pixel is treated as a hole.
    """
    combined = cv2.add(cv2.add(maskA, maskB), maskC)
    inverted = 255 - combined
    _, foreground = cv2.threshold(inverted, 220, 255, cv2.THRESH_BINARY_INV)

    # cv2.floodFill needs an operation mask 2 pixels larger than the image in each dimension.
    rows, cols = foreground.shape[:2]
    fill_mask = np.zeros((rows + 2, cols + 2), np.uint8)

    flooded = foreground.copy()
    cv2.floodFill(flooded, fill_mask, (0, 0), 255)

    holes = cv2.bitwise_not(flooded)
    return foreground | holes

In [None]:
def generate_mask(zp, icm, te):
    """Merge the three masks into a single label image.

    Pixels equal to 255 are recoded to 32 (zp), 64 (icm) and 128 (te) on
    copies of the inputs, the recoded masks are summed with ``cv2.add`` and the
    sum is subtracted from 255. Pixels with any other value keep that value
    before the sum. The input arrays are not modified.
    """
    def _recode(mask, code):
        # Work on a copy so the caller's array stays intact.
        recoded = mask.copy()
        recoded[recoded == 255] = code
        return recoded

    combined = cv2.add(cv2.add(_recode(zp, 32), _recode(icm, 64)), _recode(te, 128))
    return 255 - combined

In [None]:
def resize_enhance(i, zp, icm, te,f,l,b, size=(256, 256)):
    """Resize an image and its six masks to a common size.

    The image is resized with OpenCV's default interpolation. The masks are
    resized with nearest-neighbour interpolation: interpolating between mask
    values would invent intermediate values at region borders, and for the
    label mask ``l`` such blended values can coincide with the code of a
    different class.

    Args:
        i: Image.
        zp, icm, te: Structure masks.
        f: Full-cell mask.
        l: Label mask.
        b: Blastocoel mask.
        size: Target ``(width, height)``; defaults to 256 x 256.

    Returns:
        ``(i, zp, icm, te, f, l, b)`` resized, with the image cast to uint8.
    """
    W, H = size

    def _resize_mask(mask):
        # Nearest neighbour keeps mask values limited to those already present.
        return cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST)

    i = cv2.resize(i, (W, H))
    i = i.astype('uint8')

    zp = _resize_mask(zp)
    icm = _resize_mask(icm)
    te = _resize_mask(te)
    f = _resize_mask(f)
    b = _resize_mask(b)
    l = _resize_mask(l)

    return i, zp, icm, te, f,l,b

In [None]:
import csv
from google.colab.patches import cv2_imshow

In [None]:
# Extra albumentations targets so that one call transforms the image and all six masks together.
_MASK_TARGETS = {'mask_zp': 'mask', 'mask_icm': 'mask', 'mask_te': 'mask',
                 'full': 'mask', 'label': 'mask', 'mask_blas': 'mask'}


def _transform_sample(pipeline, sample):
    """Apply one albumentations pipeline to an (image, zp, icm, te, full, label, blastocoel) tuple."""
    image, zp, icm, te, full, label, blas = sample
    out = pipeline(image=image, mask_zp=zp, mask_icm=icm, mask_te=te,
                   full=full, label=label, mask_blas=blas)
    return (out['image'], out['mask_zp'], out['mask_icm'], out['mask_te'],
            out['full'], out['label'], out['mask_blas'])


def _save_image(path, image):
    """Write ``image`` to ``path``; raise OSError when cv2.imwrite reports failure."""
    if not cv2.imwrite(path, image):
        raise OSError(f"Could not write image: {path}")


def augment_data(images, masks_zp, masks_icm, masks_te, new_path, train,
                 size_csv='/content/BlastsOnline/BlastsOnline/new_data_3z_tvt70/real_size.csv'):
    """Generate the derived masks for every sample, optionally augment, resize and save.

    For each (image, ZP, ICM, TE) path quadruple the files are read in
    grayscale, and the full-cell mask, blastocoel mask and label mask are
    derived. When ``train`` is truthy five variants are written per sample
    (suffix ``_0`` .. ``_4``): original, horizontal flip, vertical flip,
    transpose, and transpose followed by horizontal flip. Otherwise only the
    original is written, without a suffix. Every variant is passed through
    ``resize_enhance`` and saved under ``new_path`` in the sub-folders
    ``images/``, ``ROI/``, ``GT_ZP/``, ``GT_ICM/``, ``GT_TE/``, ``GT_full/``,
    ``label/`` and ``GT_blastocoel/``, which must already exist.

    One row ``[file name, shape[0], shape[1]]`` (rows, then columns, of the
    variant before resizing) is appended to ``size_csv`` per saved image. The
    file is opened in append mode, so re-running adds the rows again.

    Args:
        images, masks_zp, masks_icm, masks_te: Parallel lists of file paths.
        new_path: Output root for this split.
        train: Whether to write the augmented variants.
        size_csv: CSV file receiving the pre-resize sizes.

    Raises:
        OSError: if an output image cannot be written.
    """
    print(len(images))

    if train:
        # The transforms are deterministic (p=1), so build the pipelines once.
        flip_h = Compose([HorizontalFlip(p=1)], additional_targets=dict(_MASK_TARGETS))
        flip_v = Compose([VerticalFlip(p=1)], additional_targets=dict(_MASK_TARGETS))
        # p is passed by keyword: the meaning of the first positional argument
        # is not the same across albumentations releases.
        transpose = Compose([Transpose(p=1)], additional_targets=dict(_MASK_TARGETS))

    with open(size_csv, 'a', newline='') as size_file:
        size_writer = csv.writer(size_file)

        for i, zp, icm, te in tqdm(zip(images, masks_zp, masks_icm, masks_te), total=len(images)):
            name = i.split("/")[-1].split(".")[0]
            i = read_this(i)
            zp = read_this(zp)
            icm = read_this(icm)
            te = read_this(te)

            mergedmask = addmasks(zp, icm, te)
            b = GT_blastocoel(te, zp, icm, mergedmask)
            label_mask = generate_mask(zp, icm, te)

            original = (i, zp, icm, te, mergedmask, label_mask, b)
            if train:
                transposed = _transform_sample(transpose, original)
                variants = [
                    original,
                    _transform_sample(flip_h, original),
                    _transform_sample(flip_v, original),
                    transposed,
                    _transform_sample(flip_h, transposed),
                ]
            else:
                variants = [original]

            for index, variant in enumerate(variants):
                # shape[0] is the row count and shape[1] the column count; the
                # CSV column order of the original code is kept.
                n_rows, n_cols = variant[0].shape[0], variant[0].shape[1]

                i_res, zp_res, icm_res, te_res, f_res, l_res, b_res = resize_enhance(*variant)

                # Only augmented (training) output carries an index suffix.
                suffix = f"_{index}" if train else ""
                i_name = f"{name}{suffix}.bmp"
                zp_name = f"{name} ZP_Mask{suffix}.bmp"
                icm_name = f"{name} ICM_Mask{suffix}.bmp"
                te_name = f"{name} TE_Mask{suffix}.bmp"
                f_name = f"{name} full{suffix}.bmp"
                roi_name = f"{name} ROI{suffix}.bmp"
                l_name = f"{name} label{suffix}.bmp"
                b_name = f"{name} blastocoel{suffix}.bmp"

                i_path = os.path.join(new_path, "images/", i_name)
                zp_path = os.path.join(new_path, "GT_ZP/", zp_name)
                icm_path = os.path.join(new_path, "GT_ICM/", icm_name)
                te_path = os.path.join(new_path, "GT_TE/", te_name)
                f_path = os.path.join(new_path, "GT_full/", f_name)
                roi_path = os.path.join(new_path, "ROI/", roi_name)
                l_path = os.path.join(new_path, "label/", l_name)
                b_path = os.path.join(new_path, "GT_blastocoel/", b_name)

                # Exactly one row per saved image. Re-writing an accumulated
                # list on every iteration duplicated earlier rows.
                size_writer.writerow([i_name, n_rows, n_cols])

                # Keep the image inside the full-cell mask and zero the rest.
                # Multiplying uint8 arrays by a 0/255 mask overflows and yields
                # a negated image rather than a masked one.
                roi = i_res.copy()
                roi[f_res <= 127] = 0

                _save_image(i_path, i_res)
                _save_image(roi_path, roi)
                _save_image(zp_path, zp_res)
                _save_image(icm_path, icm_res)
                _save_image(te_path, te_res)
                _save_image(f_path, f_res)
                _save_image(l_path, l_res)
                _save_image(b_path, b_res)


In [None]:
# Load the dataset and split it into train / test / valid path lists.
(train_i, test_i, valid_i), (train_zp, test_zp, valid_zp), (train_icm, test_icm, valid_icm), (train_te, test_te, valid_te)= load_data(drive_path)

# Output tree: <drive_path>/new_data_3z_tvt70/{train,test,valid}/<sub-folder>/
# Parents are created before their children.
output_root = drive_path+"/new_data_3z_tvt70/"
create_dir(output_root)

test_path = output_root+"test/"
valid_path= output_root+"valid/"
train_path= output_root+"train/"

for split_path in (train_path, test_path, valid_path):
    create_dir(split_path)

# One sub-folder per saved artefact, in every split.
for folder in ("images/", "GT_ZP/", "GT_ICM/", "GT_TE/", "GT_full/", "label/", "ROI/", "GT_blastocoel/"):
    for split_path in (test_path, train_path, valid_path):
        create_dir(split_path+folder)

# Only the training split is augmented.
augment_data(train_i, train_zp, train_icm, train_te, train_path, train=True)
augment_data(test_i, test_zp, test_icm, test_te, test_path, train=False)
augment_data(valid_i, valid_zp, valid_icm, valid_te, valid_path, train=False)

249 249 249 249
train, valid, test
199 25 25
199


100%|██████████| 199/199 [00:05<00:00, 33.18it/s]


25


100%|██████████| 25/25 [00:00<00:00, 54.56it/s]


25


100%|██████████| 25/25 [00:00<00:00, 49.73it/s]
