In [1]:
import glob
import os
import shutil

import rasterio
import torch
from torchvision.transforms import RandomRotation, RandomHorizontalFlip
from tqdm import tqdm

In [2]:
# Project root. Defaults to the original author's machine; set the DITCHES_BASEPATH
# environment variable to run the notebook elsewhere without editing this cell.
basepath = os.environ.get("DITCHES_BASEPATH", "D:/users/holgerv/Ditches")

In [3]:
# Root of the finetuning dataset: the training/ inputs and the augmented output live below it
finetuning_dir = "/".join((basepath, "working", "deep_learning", "data", "finetuning"))

In [4]:
# Input HPMF files (backslashes normalized so Windows glob results use "/" throughout)
fp_hpmf_list = [fp.replace("\\", "/") for fp in sorted(glob.glob(f"{finetuning_dir}/training/hpmf/*.tif"))]
hpmf_files = [os.path.basename(fp_hpmf) for fp_hpmf in fp_hpmf_list]

# Input labels
fp_labels_list = [fp.replace("\\", "/") for fp in sorted(glob.glob(f"{finetuning_dir}/training/labels/*.tif"))]
labels_files = [os.path.basename(fp_labels) for fp_labels in fp_labels_list]

# Pair HPMF and label rasters by file name. Keys are taken from fp_hpmf_list itself,
# so they are guaranteed to match the HPMF paths used elsewhere, and the dict lookup
# keeps pairing linear instead of scanning a list for every label file.
hpmf_by_name = dict(zip(hpmf_files, fp_hpmf_list))
image_pairs = {
    hpmf_by_name[name]: fp_labels
    for name, fp_labels in zip(labels_files, fp_labels_list)
    if name in hpmf_by_name
}

# Surface files that have no counterpart instead of dropping them silently
unpaired = sorted(set(hpmf_files).symmetric_difference(labels_files))
if unpaired:
    print(f"Warning: {len(unpaired)} file(s) have no matching HPMF/label counterpart: {unpaired}")

In [5]:
# Output directories. Any output from a previous run is deleted first, so a rerun
# never mixes stale augmented tiles with new ones.
out_dir = f"{finetuning_dir}/training_augmented"
out_dir_hpmf = f"{out_dir}/hpmf"
out_dir_labels = f"{out_dir}/labels"

if os.path.exists(out_dir):
    shutil.rmtree(out_dir)
for new_dir in (out_dir, out_dir_hpmf, out_dir_labels):
    os.mkdir(new_dir)

In [6]:
# Generate augmented data from input image
def generate_augmented_data(fp: str, out_dir: str, degrees: tuple = (90, 180, 270)) -> None:
    """Write rotated and horizontally flipped copies of a raster into ``out_dir``.

    For an input named ``<stem>.tif`` this writes ``<stem>_rot<degree>.tif`` for
    each angle in ``degrees`` and ``<stem>_flip.tif``. Every band of the input is
    written. The source profile (dtype, CRS, nodata, geotransform, ...) is reused,
    with height/width taken from the augmented array, so they swap for quarter turns.

    Rotations use ``torch.rot90``, an exact pixel permutation. ``RandomRotation``
    with its default ``expand=False`` crops and zero-fills non-square tiles, which
    ``torch.rot90`` avoids.

    Parameters
    ----------
    fp : str
        Path to the input raster.
    out_dir : str
        Existing directory the augmented rasters are written to.
    degrees : tuple of int, optional
        Counter-clockwise rotation angles; each must be a multiple of 90.

    Raises
    ------
    ValueError
        If an angle in ``degrees`` is not a multiple of 90.
    """
    bad = [d for d in degrees if d % 90 != 0]
    if bad:
        raise ValueError(f"Rotation angles must be multiples of 90 degrees, got {bad}")

    # splitext strips only the final extension, so stems that contain dots stay whole
    stem = os.path.splitext(os.path.basename(fp))[0]

    with rasterio.open(fp) as src:
        img = torch.from_numpy(src.read())  # all bands: (bands, rows, cols)
        profile = src.profile

    def _write(out_img: torch.Tensor, suffix: str) -> None:
        # NOTE: the geotransform is copied unchanged, so only pixel content is
        # augmented; the outputs' georeferencing does not follow the rotation/flip.
        out_profile = {**profile, "height": out_img.shape[-2], "width": out_img.shape[-1]}
        with rasterio.open(f"{out_dir}/{stem}_{suffix}.tif", "w", **out_profile) as dst:
            # rot90 may return a non-contiguous tensor; give rasterio a contiguous array
            dst.write(out_img.contiguous().numpy())

    for degree in degrees:
        _write(torch.rot90(img, k=int(degree // 90), dims=(-2, -1)), f"rot{degree}")

    # Horizontal flip = reverse the column axis (same as RandomHorizontalFlip(p=1))
    _write(torch.flip(img, dims=(-1,)), "flip")

In [7]:
%%time

# Copy every HPMF/label pair into the augmented directories and add rotated/flipped
# versions beside it. Iterating image_pairs (not every HPMF file) means an HPMF tile
# without a label is skipped instead of raising KeyError.
for fp_hpmf, fp_labels in tqdm(image_pairs.items(), position=0, leave=True):
    for fp, dst_dir in ((fp_hpmf, out_dir_hpmf), (fp_labels, out_dir_labels)):
        # Keep the original tile alongside its augmented copies
        shutil.copy(fp, dst_dir)
        generate_augmented_data(fp, dst_dir)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 58/58 [00:12<00:00,  4.55it/s]

CPU times: total: 1min 37s
Wall time: 12.7 s



