In [1]:
import os
import numpy as np
import cv2
from glob import glob
from tqdm import tqdm
import imageio
from albumentations import HorizontalFlip, VerticalFlip, ElasticTransform, GridDistortion, OpticalDistortion
from utils import load_data

In [2]:
def create_dir(path):
    """Create directory `path` (including missing parents) if it does not exist.

    ``exist_ok=True`` makes the call idempotent and avoids the race between a
    separate ``os.path.exists`` check and the creation. If `path` exists but is
    a regular file, ``os.makedirs`` raises instead of silently doing nothing.
    """
    os.makedirs(path, exist_ok=True)

def _file_stem(path):
    """Return the base file name of `path`, truncated at its first '.'.

    Both backslash and forward slash are treated as separators, so the result
    is the same whether the path uses Windows or POSIX separators.
    """
    return os.path.basename(path.replace("\\", "/")).split(".")[0]


def _read_array(path, file_format):
    """Load `path` as an array; raise FileNotFoundError if it cannot be decoded.

    When `file_format` is 'gif' the first frame is read with imageio; anything
    else is read with ``cv2.imread`` in colour mode (3-channel BGR).
    """
    if file_format == 'gif':
        # NOTE(review): imageio returns RGB channel order; if an *image* (not a
        # mask) is ever read this way, cv2.imwrite will save it with R/B swapped.
        return imageio.mimread(path)[0]
    array = cv2.imread(path, cv2.IMREAD_COLOR)
    if array is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f"cv2 could not read {path!r}")
    return array


def augment_data(images, masks, save_path, i_format, m_format, augment=True,
                 size=512, mask_ext="jpg"):
    """Resize (and optionally augment) image/mask pairs and write them to disk.

    Each image is paired with the mask at the same position. For every pair the
    original plus, when `augment` is true, five augmented variants (horizontal
    flip, vertical flip, elastic transform, grid distortion, optical distortion)
    are resized to `size` x `size` and written to ``<save_path>/image`` and
    ``<save_path>/mask``. Outputs are named ``<stem>.jpg`` when only the
    original is written, otherwise ``<stem>_<k>.jpg`` with k = 0 for the
    original and 1..5 for the variants, in the order listed above.

    Args:
        images: Paths of the input images.
        masks: Paths of the matching masks, same length and order as `images`.
        save_path: Output root; its ``image`` and ``mask`` subfolders must exist.
        i_format: 'gif' to read images with imageio; any other value uses cv2.
        m_format: Same as `i_format`, for the masks.
        augment: Also write the five augmented variants when true.
        size: Side length in pixels of the square output images and masks.
        mask_ext: File extension (and so encoding) of the mask files. The
            default "jpg" keeps the historical file names; JPEG is lossy, so
            "png" is the choice when mask label values must stay exact.

    Raises:
        ValueError: If `images` and `masks` differ in length.
        FileNotFoundError: If an input file cannot be decoded by cv2.
        OSError: If an output file cannot be written (e.g. a missing folder).
    """
    images = list(images)
    masks = list(masks)
    if len(images) != len(masks):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError(
            f"got {len(images)} images but {len(masks)} masks; they must pair up 1:1"
        )

    # Built once instead of once per pair.
    # NOTE(review): the elastic/grid/optical transforms are random and no seed
    # is set here, so re-running produces different augmented files.
    transforms = [
        HorizontalFlip(p=1.0),
        VerticalFlip(p=1.0),
        ElasticTransform(p=1.0),
        GridDistortion(p=1.0),
        OpticalDistortion(p=1.0),
    ] if augment else []

    for image_src, mask_src in tqdm(zip(images, masks), total=len(images)):
        name = _file_stem(image_src)
        x = _read_array(image_src, i_format)
        y = _read_array(mask_src, m_format)

        # Index 0 is always the un-augmented pair.
        pairs = [(x, y)]
        for transform in transforms:
            augmented = transform(image=x, mask=y)
            pairs.append((augmented['image'], augmented['mask']))

        for index, (image, mask) in enumerate(pairs):
            image = cv2.resize(image, (size, size))
            # Nearest-neighbour keeps mask pixels within the original label
            # values; the default bilinear filter blends labels along edges.
            mask = cv2.resize(mask, (size, size), interpolation=cv2.INTER_NEAREST)

            suffix = "" if len(pairs) == 1 else f"_{index}"
            image_out = os.path.join(save_path, "image", f"{name}{suffix}.jpg")
            mask_out = os.path.join(save_path, "mask", f"{name}{suffix}.{mask_ext}")
            # cv2.imwrite reports failure by returning False rather than raising.
            if not cv2.imwrite(image_out, image):
                raise OSError(f"could not write {image_out!r}")
            if not cv2.imwrite(mask_out, mask):
                raise OSError(f"could not write {mask_out!r}")

In [3]:
# Root folders of the five source datasets; each is passed to `load_data` below.
data_path_1, data_path_2, data_path_3, data_path_4, data_path_5 = (
    "DRIVE",
    "CHASE",
    "RETA",
    "HRF",
    "STARE",
)

In [4]:
# Output roots for the processed training and test splits; `augment_data`
# writes into the `image/` and `mask/` subfolders of whichever root it is given.
save_path_train, save_path_test = "new_data_2/training", "new_data_2/test"

# Augment DRIVE Dataset

In [5]:
# DRIVE: images read as 'tif', masks taken from the `1st_manual` folder as 'gif'.
train_X, train_Y, test_X, test_Y = load_data(
    data_path_1,
    images_folder="images",
    masks_folder="1st_manual",
    i_format='tif',
    m_format='gif',
)

In [6]:
# Create the `image/` and `mask/` folders that augment_data writes into. The
# folders are derived from save_path_train / save_path_test (instead of
# repeating the literal paths), so they cannot drift apart if a root changes.
for split_root in (save_path_train, save_path_test):
    for subfolder in ("image", "mask"):
        create_dir(os.path.join(split_root, subfolder))

In [7]:
# DRIVE training split: write the originals plus their augmented variants.
augment_data(train_X, train_Y, save_path_train, i_format='tif', m_format='gif', augment=True)

20it [00:09,  2.11it/s]


In [8]:
# DRIVE test split: resize and copy only, no augmentation.
augment_data(test_X, test_Y, save_path_test, i_format='tif', m_format='gif', augment=False)

20it [00:00, 60.72it/s]


# Augment CHASE Dataset

In [9]:
# CHASE: images read as 'jpg', masks from the `masks` folder as 'png'.
train_X, train_Y, test_X, test_Y = load_data(
    data_path_2,
    images_folder="images",
    masks_folder="masks",
    i_format='jpg',
    m_format='png',
)

In [10]:
# CHASE training split, augmented into the shared training folder.
augment_data(train_X, train_Y, save_path_train, i_format='jpg', m_format='png', augment=True)

28it [00:35,  1.28s/it]


# Augment RETA, HRF and STARE Datasets (all written to the training folder; STARE is read from its test split)

In [11]:
# RETA: 'jpg' images with 'png' masks; the training split is augmented into
# the shared training folder.
train_X, train_Y, test_X, test_Y = load_data(
    data_path_3,
    images_folder="images",
    masks_folder="masks",
    i_format='jpg',
    m_format='png',
)
augment_data(train_X, train_Y, save_path_train, i_format='jpg', m_format='png', augment=True)

54it [01:17,  1.43s/it]


In [12]:
# HRF: 'jpg' images with 'tif' masks; the training split is augmented into
# the shared training folder.
train_X, train_Y, test_X, test_Y = load_data(
    data_path_4,
    images_folder="images",
    masks_folder="masks",
    i_format='jpg',
    m_format='tif',
)
augment_data(train_X, train_Y, save_path_train, i_format='jpg', m_format='tif', augment=True)

45it [07:56, 10.59s/it]


In [13]:
# STARE: both images and masks are read as 'ppm'.
train_X, train_Y, test_X, test_Y = load_data(
    data_path_5,
    images_folder="images",
    masks_folder="masks",
    i_format='ppm',
    m_format='ppm',
)


In [14]:
# NOTE(review): unlike the other datasets, this augments STARE's *test* split
# and writes it into the *training* folder — confirm this is intended (e.g.
# load_data places all STARE pairs in the test split) and not a mix-up.
augment_data(test_X, test_Y, save_path_train, i_format='ppm', m_format='ppm', augment=True)

20it [00:12,  1.63it/s]
