# Augmentasi Multiple Layer

In [2]:
import os 
import cv2 
from tqdm import tqdm 
from glob import glob 
from albumentations import Transpose, RandomRotate90, GridDistortion, HorizontalFlip, VerticalFlip
import rasterio as rio 
from osgeo import gdal
import numpy as np 
from numpy import moveaxis
from numpy import asarray 
from numpy import expand_dims 
from osgeo import gdal_array 

  check_for_updates()


### Load data

In [None]:
def load_data(path):
    """Collect image and mask file paths from the sub-folders of *path*.

    Everything matching ``<path>/Image/*`` and ``<path>/Mask/*`` is
    returned, each list sorted lexicographically. Images and masks are
    later paired by position, so file names in both folders should line
    up (assumption — TODO confirm with the data layout).

    Returns:
        A ``(images, masks)`` tuple of sorted path lists.
    """
    image_pattern = os.path.join(path, "Image/*")
    mask_pattern = os.path.join(path, "Mask/*")
    image_paths = sorted(glob(image_pattern))
    mask_paths = sorted(glob(mask_pattern))
    return image_paths, mask_paths

In [None]:
# Create the output (storage) directory.
def create_dir(path):
    """Create directory *path*, including any missing parents.

    Does nothing if the directory already exists. Raises FileExistsError
    if *path* exists but is not a directory.
    """
    # exist_ok=True replaces the racy os.path.exists() + os.makedirs() pair.
    os.makedirs(path, exist_ok=True)

### Fungsi Augmentasi Data

In [None]:
def _read_pair(image_path, mask_path, H, W):
    """Read bands 1-4 of the image and band 1 of the mask, plus output metadata.

    Returns ``(image, mask, image_meta, mask_meta)``: ``image`` is
    ``(rows, cols, 4)``, ``mask`` is ``(rows, cols)``, and the metas are the
    sources' rasterio metadata overridden for a ``W`` x ``H`` GTiff output.
    """
    # Context managers close the datasets; the previous rio.open / gdal.Open
    # handles were never released inside the loop.
    with rio.open(image_path) as src:
        image = np.ascontiguousarray(np.moveaxis(src.read([1, 2, 3, 4]), 0, -1))
        image_meta = src.meta.copy()
    with rio.open(mask_path) as src:
        # NOTE(review): masks are assumed single-band (output count is 1).
        mask = src.read(1)
        mask_meta = src.meta.copy()

    # NOTE(review): 'transform'/'crs' are inherited unchanged from the source,
    # so georeferencing is only meaningful for un-augmented W x H inputs.
    image_meta.update({'driver': 'GTiff',
                       'width': W,
                       'height': H,
                       'count': 4,
                       'dtype': 'float32',
                       'nodata': -32768})
    mask_meta.update({'driver': 'GTiff',
                      'width': W,
                      'height': H,
                      'count': 1,
                      'dtype': 'int16',
                      'nodata': -32768})
    return image, mask, image_meta, mask_meta


def _augment_pair(image, mask):
    """Apply each transform independently to the original pair; return (images, masks)."""
    transforms = [
        Transpose(p=1.0),
        RandomRotate90(p=1.0),
        GridDistortion(p=1.0),
        HorizontalFlip(p=1.0),
        VerticalFlip(p=1.0),
    ]
    out_images, out_masks = [], []
    for aug in transforms:
        augmented = aug(image=image, mask=mask)
        out_images.append(augmented["image"])
        out_masks.append(augmented["mask"])
    return out_images, out_masks


def augment_data(images, masks, save_path, augment=True):
    """Augment paired 4-band images and masks, resize them and save as GeoTIFF.

    For every ``(image, mask)`` pair, bands 1-4 of the image and band 1 of
    the mask are read. If *augment* is true, five albumentations transforms
    (Transpose, RandomRotate90, GridDistortion, HorizontalFlip, VerticalFlip)
    are each applied to the original pair (not chained). Every result is
    resized to 1024x1024 and written to ``<save_path>/image`` and
    ``<save_path>/mask``.

    Args:
        images: Input image paths (each needs at least 4 bands).
        masks: Input mask paths, in the same order as ``images``.
        save_path: Output root; ``image`` and ``mask`` sub-folders are
            created if missing.
        augment: If true, write five variants per pair suffixed ``_0``..``_4``;
            otherwise write the resized pair once under its original name.
    """
    H = 1024
    W = 1024

    image_dir = os.path.join(save_path, "image")
    mask_dir = os.path.join(save_path, "mask")
    # The output folders were never created, so rasterio could not write.
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    # NOTE: zip() stops at the shorter list and pairs purely by position.
    for x_path, y_path in tqdm(zip(images, masks), total=len(images)):
        # basename/splitext handle Windows separators and dotted file names,
        # which the old split("/") / split(".") parsing got wrong.
        image_name, image_extn = os.path.splitext(os.path.basename(x_path))
        mask_name, mask_extn = os.path.splitext(os.path.basename(y_path))

        x, y, image_meta, mask_meta = _read_pair(x_path, y_path, H, W)

        if augment:
            save_images, save_masks = _augment_pair(x, y)
        else:
            save_images, save_masks = [x], [y]

        for idx, (img, msk) in enumerate(zip(save_images, save_masks)):
            img = cv2.resize(img, (W, H))
            # Nearest-neighbour keeps mask labels discrete; the default bilinear
            # interpolation produced in-between (invalid) class values.
            msk = cv2.resize(msk, (W, H), interpolation=cv2.INTER_NEAREST)

            # Suffix an index only when this pair yields several outputs. The
            # old check used len(images), so one augmented input overwrote
            # its own outputs.
            suffix = "" if len(save_images) == 1 else f"_{idx}"
            image_path = os.path.join(image_dir, f"{image_name}{suffix}{image_extn}")
            mask_path = os.path.join(mask_dir, f"{mask_name}{suffix}{mask_extn}")

            with rio.open(image_path, 'w', **image_meta) as dst:
                # (rows, cols, bands) -> (bands, rows, cols) as rasterio expects.
                dst.write(np.ascontiguousarray(np.moveaxis(img, -1, 0),
                                               dtype=image_meta['dtype']))
            with rio.open(mask_path, 'w', **mask_meta) as dst:
                dst.write(np.ascontiguousarray(msk, dtype=mask_meta['dtype']), 1)

### Setting folder sumber data (Image/ dan Mask/)

In [None]:
# Root folder that must contain the "Image" and "Mask" sub-folders read by
# load_data(); adjust to the local data location.
path = "D:/Kanwil Riau/Updating NPGT Perkebunan/OLAH TIM/RIAU/Run" 
images, masks = load_data(path) 
print(f"Original Images: {len(images)} - Original Masks: {len(masks)}") 
# augment_data() pairs images and masks with zip() by sorted position, so
# unequal counts mean some images would be matched with the wrong mask.
if len(images) != len(masks):
    print("WARNING: image and mask counts differ; check the Image/ and Mask/ folders before augmenting.")

In [None]:
### Run augmentation: write augmented image/mask pairs under ./test
# augment_data() writes into test/image and test/mask, so make sure both exist.
create_dir(os.path.join("test", "image"))
create_dir(os.path.join("test", "mask"))
augment_data(images, masks, "test", augment=True) 