In [None]:
import copy
import cv2
import os
import json
import random
import numpy as np
import shutil
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from tqdm.notebook import tqdm


# NumPy 2.0 removed ``np.sctypes``, but imgaug still reads it while it is being
# imported. The compatibility table therefore has to be installed BEFORE the
# imgaug imports below, otherwise the import itself fails on NumPy >= 2.
# It is only installed when missing so that older NumPy keeps its real table.
if not hasattr(np, "sctypes"):
    np.sctypes = {
        "float": [np.float16, np.float32, np.float64],
        "int": [np.int8, np.int16, np.int32, np.int64],
        "uint": [np.uint8, np.uint16, np.uint32, np.uint64],
        "complex": [np.complex64, np.complex128]
    }

import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.polys import Polygon, PolygonsOnImage
from tqdm.notebook import tqdm

from PIL import Image

In [None]:
# Dataset and output locations.
# The two roots can be overridden through environment variables so the notebook
# runs on another machine without editing code; the defaults reproduce the
# original hard-coded paths exactly.
APKLOT_DATASET_ROOT = os.environ.get(
    "APKLOT_DATASET_ROOT",
    "/Users/joshneutel/code/APKLOT/1. Satellite/Dataset/World",
)
APKLOT_OUTPUT_ROOT = os.environ.get("APKLOT_OUTPUT_ROOT", "/Users/joshneutel/Desktop")

# Raw APKLOT images (.png) and their polygon annotations (.json)
apklot_training_path = f"{APKLOT_DATASET_ROOT}/training"
apklot_testing_path = f"{APKLOT_DATASET_ROOT}/testing"

# Converted image/mask pairs for the un-augmented splits
train_output_path = f"{APKLOT_OUTPUT_ROOT}/APKLOT/training"
test_output_path = f"{APKLOT_OUTPUT_ROOT}/APKLOT/testing"

# Intermediate augmented png/json pairs, and the final augmented image/mask export
aug_output_path = f"{APKLOT_OUTPUT_ROOT}/augmented"
final_aug_output_path = f"{APKLOT_OUTPUT_ROOT}/final_augmented"

## Visualize APKLOT

In [None]:
# Unique sample IDs (file stems) of the raw training set.
# Only .png/.json entries are considered so stray files (e.g. .DS_Store) do not
# become bogus IDs, and the result is sorted because set iteration order for
# strings changes between kernel sessions -- RAW_FILES[0] must be stable.
RAW_FILES = sorted({
    os.path.splitext(file)[0]
    for file in os.listdir(apklot_training_path)
    if file.endswith((".json", ".png"))
})

# The augmentation step writes here; create it up front.
os.makedirs(aug_output_path, exist_ok=True)

print(f"Total number of RAW_FILES: {len(RAW_FILES)}")

In [None]:
def load_image(image_str, with_annotation=True):
    """Display a raw training image, optionally with its annotation polygons.

    Args:
        image_str: File stem shared by ``<stem>.png`` and ``<stem>.json``
            inside ``apklot_training_path``.
        with_annotation: When true, every entry of the JSON ``shapes`` list is
            drawn as a semi-transparent red polygon.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can invoke ``.show()``.
    """
    stem_path = f"{apklot_training_path}/{image_str}"

    image = Image.open(f"{stem_path}.png")
    with open(f"{stem_path}.json") as f:
        data = json.load(f)

    fig, ax = plt.subplots()
    ax.imshow(image)

    # The annotation is always parsed, but 'shapes' is only read when drawing.
    shapes = data['shapes'] if with_annotation else []
    for shape in shapes:
        outline = patches.Polygon(
            shape['points'], closed=True, edgecolor='red', facecolor='red', linewidth=2, alpha=0.5
        )
        ax.add_patch(outline)

    plt.axis('off')
    return plt

In [None]:
# Preview the first raw training sample with its annotation polygons drawn on top.
file = RAW_FILES[0]
# load_image returns the pyplot module itself, so this rebinding is a no-op.
plt = load_image(file, with_annotation=True)
plt.show()

# Data Augmentation

In [None]:
# Seed every random number generator used below (Python's `random`, NumPy's
# legacy global RNG, and imgaug) with the same value for reproducibility.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
ia.seed(RANDOM_SEED)

# Number of augmented versions to create per original image
# (passed to augment_training_dataset as `samples_per_image`).
SAMPLES_PER_IMAGE = 100

## Helper functions 

In [None]:
# Define the data augmentation sequence for jittering
def create_aug_seq():
    """Build the imgaug jitter pipeline: crop, flips, then rotation.

    Returns:
        An ``iaa.Sequential`` that applies, in order, a random 0-50 px crop,
        a horizontal flip (probability 0.5), a vertical flip (probability 0.5)
        and a random rotation between -45 and 45 degrees.
    """
    random_crop = iaa.Crop(px=(0, 50))
    flip_horizontal = iaa.Fliplr(0.5)
    flip_vertical = iaa.Flipud(0.5)
    random_rotation = iaa.Affine(rotate=(-45, 45))

    pipeline = [random_crop, flip_horizontal, flip_vertical, random_rotation]
    return iaa.Sequential(pipeline)

def load_image_and_json(image_path, json_path):
    """Load an image as a NumPy array together with its JSON annotation.

    Args:
        image_path: Path of the image file to read with PIL.
        json_path: Path of the JSON annotation file.

    Returns:
        Tuple ``(image_array, json_data)``: the pixel data as returned by
        ``np.array`` on the PIL image, and the parsed JSON object.
    """
    # PIL opens files lazily and keeps the handle around; the context manager
    # releases it as soon as the pixels have been copied into the array, which
    # matters when this is called once per file over a whole dataset.
    with Image.open(image_path) as image:
        image_array = np.array(image)

    with open(json_path, 'r') as f:
        json_data = json.load(f)

    return image_array, json_data

def extract_polygons(json_data):
    """Turn every annotated shape into an imgaug ``Polygon``.

    Args:
        json_data: Parsed annotation; must contain a ``'shapes'`` list whose
            items each carry a ``'points'`` entry.

    Returns:
        List of ``Polygon`` objects, one per shape, in annotation order.
    """
    return [Polygon(shape['points']) for shape in json_data['shapes']]

def apply_augmentation(image, polygons, seq):
    """Run one augmenter over an image and its polygons together.

    Args:
        image: Image array; its ``.shape`` is attached to the polygons.
        polygons: List of imgaug ``Polygon`` objects belonging to ``image``.
        seq: Callable augmenter accepting ``image=`` and ``polygons=``.

    Returns:
        Tuple ``(image_aug, polys_aug)`` as produced by ``seq``.
    """
    # Bind the polygons to the image geometry so both are transformed alike.
    bound_polygons = PolygonsOnImage(polygons, shape=image.shape)
    image_aug, polys_aug = seq(image=image, polygons=bound_polygons)
    return image_aug, polys_aug

def update_json_with_augmented_polygons(json_data, polys_aug):
    """Return a copy of the annotation whose shapes carry the augmented points.

    Args:
        json_data: Original annotation with a ``'shapes'`` list; not modified.
        polys_aug: Indexable, sized collection of augmented polygons, each
            exposing an ``exterior`` array.

    Returns:
        Deep copy of ``json_data`` in which shape ``i`` has its ``'points'``
        replaced by ``polys_aug[i].exterior`` as nested lists. Shapes with no
        matching polygon keep their original points.
    """
    updated = copy.deepcopy(json_data)

    for idx, shape in enumerate(updated['shapes']):
        if idx >= len(polys_aug):
            # No augmented counterpart for this shape; leave it untouched.
            continue
        shape['points'] = polys_aug[idx].exterior.tolist()

    return updated

def save_augmented_image_and_json(image_aug, json_aug, output_img_path, output_json_path):
    """Write an augmented image and its annotation to disk.

    Args:
        image_aug: Image array, converted with ``Image.fromarray``.
        json_aug: JSON-serialisable annotation.
        output_img_path: Destination path of the image.
        output_json_path: Destination path of the annotation (indent of 2).
    """
    pil_image = Image.fromarray(image_aug)
    pil_image.save(output_img_path)

    with open(output_json_path, 'w') as json_file:
        json.dump(json_aug, json_file, indent=2)

# Define function to load and visualize images with annotations
def load_image(image_str, path=None, with_annotation=True):
    """Display an image from ``path``, optionally with its annotation polygons.

    This definition replaces the earlier two-argument ``load_image``. ``path``
    defaults to ``apklot_training_path`` so that calls written for the earlier
    signature (``load_image(stem, with_annotation=True)``) still work after
    this cell has run.

    Args:
        image_str: File stem shared by ``<stem>.png`` and ``<stem>.json``.
        path: Directory holding both files; ``None`` means the raw training
            directory.
        with_annotation: When true, every entry of the JSON ``shapes`` list is
            drawn as a semi-transparent red polygon.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can invoke ``.show()``.
    """
    if path is None:
        path = apklot_training_path

    # Load image
    image = Image.open(f"{path}/{image_str}.png")

    # Load JSON annotations
    with open(f"{path}/{image_str}.json") as f:
        data = json.load(f)

    # Plot image
    fig, ax = plt.subplots()
    ax.imshow(image)

    if with_annotation:
        # Plot polygons
        for shape in data['shapes']:
            points = shape['points']
            polygon = patches.Polygon(points, closed=True, edgecolor='red', facecolor='red', linewidth=2, alpha=0.5)
            ax.add_patch(polygon)

    # Act on the axes just created rather than on pyplot's "current axes".
    ax.axis('off')
    return plt

# Function to visualize original and augmented images
def visualize_comparison(original_img, original_json, augmented_img, augmented_json, title="Original vs Augmented"):
    """Show the original and augmented image side by side with their polygons.

    Args:
        original_img: Image shown in the left panel.
        original_json: Annotation with a ``'shapes'`` list for the left panel.
        augmented_img: Image shown in the right panel.
        augmented_json: Annotation with a ``'shapes'`` list for the right panel.
        title: Figure-level title.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    panels = (
        ("Original", original_img, original_json),
        ("Augmented", augmented_img, augmented_json),
    )
    for ax, (panel_title, img, annotation) in zip(axes, panels):
        ax.imshow(img)
        for shape in annotation['shapes']:
            outline = patches.Polygon(
                shape['points'], closed=True, edgecolor='red', facecolor='red', linewidth=2, alpha=0.5
            )
            ax.add_patch(outline)
        ax.set_title(panel_title)
        ax.axis('off')

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

# Main function to perform data augmentation
def augment_training_dataset(input_path, aug_output_path, samples_per_image=100):
    """Augment the training dataset with jittering transformations.

    Every sample in ``input_path`` that has both a ``.png`` and a ``.json`` is
    copied unchanged into ``aug_output_path`` and then augmented
    ``samples_per_image`` times; results are written as
    ``<id>_aug_<n>.png`` / ``<id>_aug_<n>.json``.

    Args:
        input_path: Directory with the original png/json pairs.
        aug_output_path: Directory receiving originals and augmented pairs;
            created if it does not exist.
        samples_per_image: Number of augmented versions per original image.

    Returns:
        List of the IDs (file stems) of all augmented samples that were
        written successfully.
    """
    # Do not rely on another cell having created the output directory.
    os.makedirs(aug_output_path, exist_ok=True)

    # Strip only the trailing '.json', and sort: os.listdir order is arbitrary,
    # so without sorting the seeded augmentations would not be reproducible.
    json_suffix = '.json'
    train_files = sorted(
        f[:-len(json_suffix)] for f in os.listdir(input_path) if f.endswith(json_suffix)
    )

    print(f"Found {len(train_files)} original training images")
    print(f"Generating {samples_per_image} augmented versions for each")
    print(f"Target total: {len(train_files) * (samples_per_image)} augmented images")

    # Validate BEFORE copying: a JSON without its PNG used to crash in
    # shutil.copy before the missing-file check could skip it.
    complete_files = []
    for file_id in train_files:
        img_path = os.path.join(input_path, f"{file_id}.png")
        json_path = os.path.join(input_path, f"{file_id}.json")

        if not (os.path.exists(img_path) and os.path.exists(json_path)):
            print(f"Warning: Missing files for {file_id}, skipping...")
            continue
        complete_files.append(file_id)

        # Copy the original pair to the output directory (existing files are kept)
        for src in (img_path, json_path):
            dst = os.path.join(aug_output_path, os.path.basename(src))
            if not os.path.exists(dst):
                shutil.copy(src, dst)

    # IDs of successfully written augmentations
    augmented_files = []

    # Process each complete training sample
    for i, file_id in enumerate(tqdm(complete_files)):
        img_path = os.path.join(input_path, f"{file_id}.png")
        json_path = os.path.join(input_path, f"{file_id}.json")

        # Load image and annotation
        image, json_data = load_image_and_json(img_path, json_path)

        # Extract polygons
        polygons = extract_polygons(json_data)

        # Create augmented versions
        for j in range(samples_per_image):
            # Best effort: one failing sample must not abort the whole run.
            try:
                # Create a new augmentation sequence for each sample
                seq = create_aug_seq().to_deterministic()

                # Apply augmentation
                image_aug, polys_aug = apply_augmentation(image, polygons, seq)

                # Update JSON with augmented polygons
                json_aug = update_json_with_augmented_polygons(json_data, polys_aug)

                # Generate augmented file ID
                aug_id = f"{file_id}_aug_{j+1}"

                # Save augmented image and JSON
                aug_img_path = os.path.join(aug_output_path, f"{aug_id}.png")
                aug_json_path = os.path.join(aug_output_path, f"{aug_id}.json")

                save_augmented_image_and_json(image_aug, json_aug, aug_img_path, aug_json_path)

                augmented_files.append(aug_id)

                # Visualize the first augmentation of the first few images
                if i < 3 and j == 0:
                    visualize_comparison(image, json_data, image_aug, json_aug,
                                         title=f"Original vs Augmented: {file_id}")

            except Exception as e:
                print(f"Error augmenting {file_id} (sample {j+1}): {e}")

    print("Data augmentation complete!")
    print(f"Created {len(augmented_files)} augmented images")
    # Count only the originals that were actually copied, not the skipped ones.
    print(f"Total images in output directory: {len(complete_files) + len(augmented_files)}")

    return augmented_files

## Augment

In [None]:
# Run the augmentation over the raw training set.
augmented_files = augment_training_dataset(apklot_training_path, aug_output_path, SAMPLES_PER_IMAGE)

# Verify the result by counting the PNGs that ended up in the output directory.
output_files = [name.replace('.png', '') for name in os.listdir(aug_output_path) if name.endswith('.png')]
original_outputs = [name for name in output_files if '_aug_' not in name]
augmented_outputs = [name for name in output_files if '_aug_' in name]
print(f"Total files in output directory: {len(output_files)}")
print(f"Original training images: {len(original_outputs)}")
print(f"Augmented training images: {len(augmented_outputs)}")

# Show one original and one augmented sample, when available.
print("\nVisualization examples:")
if output_files:
    original_sample = original_outputs[0] if original_outputs else None
    if original_sample:
        print("Original sample:")
        plt = load_image(original_sample, aug_output_path)
        plt.show()

    augmented_sample = augmented_outputs[0] if augmented_outputs else None
    if augmented_sample:
        print("Augmented sample:")
        plt = load_image(augmented_sample, aug_output_path)
        plt.show()

In [None]:
# Count samples, not directory entries: every sample is stored as a .png plus a
# .json (and the OS may add files such as .DS_Store), so counting every entry
# would report roughly twice the real amount of training data.
n_train_actual = sum(1 for f in os.listdir(aug_output_path) if f.endswith('.png'))
print(f"Number of training data: {n_train_actual}")

# Convert data to png

In [None]:
def convert_poly_to_mask(polygons, image_shape):
    """Rasterise polygons into a uint8 mask.

    Args:
        polygons: Iterable of objects exposing an ``exterior`` array of
            ``(x, y)`` vertices.
        image_shape: Shape of the mask to allocate.

    Returns:
        ``np.uint8`` array of shape ``image_shape``: 1 inside any polygon,
        0 elsewhere.
    """
    mask = np.zeros(image_shape, dtype=np.uint8)
    for poly in polygons:
        # cv2.fillPoly expects int32 vertices shaped (num_points, 1, 2).
        contour = np.array(poly.exterior, dtype=np.int32).reshape((-1, 1, 2))
        # One call per polygon, each painting value 1 into the mask.
        cv2.fillPoly(mask, [contour], 1)
    return mask

In [None]:
def visualize_overlay(image_path, mask_path, alpha=0.4, mask_color=(255, 0, 0)):
    """Plot an image with its mask blended on top.

    Args:
        image_path: Path of the image; read as RGB.
        mask_path: Path of the mask; read as grayscale, any value > 0 counts
            as foreground.
        alpha: Blend weight of ``mask_color`` on foreground pixels.
        mask_color: RGB colour used to tint foreground pixels.

    Returns:
        The ``matplotlib.pyplot`` module with the figure prepared.
    """
    image_np = np.array(Image.open(image_path).convert("RGB"))
    mask_np = np.array(Image.open(mask_path).convert("L"))

    foreground = mask_np > 0

    # Solid-colour layer that is non-zero only on the foreground.
    color_mask = np.zeros_like(image_np)
    color_mask[foreground] = mask_color

    # Blend on foreground pixels, keep the image untouched elsewhere.
    blended = (1 - alpha) * image_np + alpha * color_mask
    overlay = np.where(foreground[..., None], blended, image_np).astype(np.uint8)

    plt.figure(figsize=(8, 8))
    plt.imshow(overlay)
    plt.axis("off")
    plt.title("Image with Mask Overlay")
    return plt

## APKLOT training to png 

In [None]:
# Unique sample IDs (file stems) of the training split.
# Sorted so the index assigned to each sample during conversion is the same on
# every run (set iteration order for strings changes between kernel sessions),
# and restricted to .png/.json so stray files such as .DS_Store are ignored.
TRAIN_FILES = sorted({
    os.path.splitext(file)[0]
    for file in os.listdir(apklot_training_path)
    if file.endswith((".json", ".png"))
})

In [None]:
# Show how many unique training sample IDs were collected.
len(TRAIN_FILES)

In [None]:
# Convert every training sample into an image/mask PNG pair named by its index.
# PIL does not create missing directories, so make sure both targets exist.
os.makedirs(f"{train_output_path}/images", exist_ok=True)
os.makedirs(f"{train_output_path}/masks", exist_ok=True)

for i, file_id in enumerate(TRAIN_FILES):
    # `except Exception` keeps the skip-and-continue behaviour but, unlike a
    # bare `except`, lets KeyboardInterrupt stop the loop; the cause is printed.
    try:
        image = Image.open(f"{apklot_training_path}/{file_id}.png")
    except Exception as exc:
        print(f"Issue with {file_id}: {exc}")
        continue

    # Load JSON annotations
    try:
        with open(f"{apklot_training_path}/{file_id}.json") as f:
            json_data = json.load(f)
    except Exception as exc:
        print(f"Issue with {file_id}: {exc}")
        continue

    polygons = extract_polygons(json_data)
    # PIL reports (width, height); the mask array is indexed (row, column).
    width, height = image.size
    image_shape = (height, width)
    mask_array = convert_poly_to_mask(polygons, image_shape)
    mask_image = Image.fromarray(mask_array).convert("L")

    image.save(f"{train_output_path}/images/{i}.png")
    mask_image.save(f"{train_output_path}/masks/{i}.png")

In [None]:
# Spot-check the conversion by overlaying one exported mask on its image
# (black tint, blend weight 0.7).
# NOTE(review): index 370 is hard-coded -- confirm that sample was written.
plt = visualize_overlay(
    f"{train_output_path}/images/{370}.png", 
    f"{train_output_path}/masks/{370}.png",
    alpha=0.7, mask_color=(0, 0, 0)
)

## APKLOT testing to png 

In [None]:
# Unique sample IDs (file stems) of the testing split.
# Sorted so the index assigned to each sample during conversion is the same on
# every run (set iteration order for strings changes between kernel sessions),
# and restricted to .png/.json so stray files such as .DS_Store are ignored.
TEST_FILES = sorted({
    os.path.splitext(file)[0]
    for file in os.listdir(apklot_testing_path)
    if file.endswith((".json", ".png"))
})

In [None]:
# Convert every testing sample into an image/mask PNG pair named by its index.
# PIL does not create missing directories, so make sure both targets exist.
os.makedirs(f"{test_output_path}/images", exist_ok=True)
os.makedirs(f"{test_output_path}/masks", exist_ok=True)

for i, file_id in enumerate(TEST_FILES):
    # `except Exception` keeps the skip-and-continue behaviour but, unlike a
    # bare `except`, lets KeyboardInterrupt stop the loop; the cause is printed.
    try:
        image = Image.open(f"{apklot_testing_path}/{file_id}.png")
    except Exception as exc:
        print(f"Issue with {file_id}: {exc}")
        continue

    # Load JSON annotations
    try:
        with open(f"{apklot_testing_path}/{file_id}.json") as f:
            json_data = json.load(f)
    except Exception as exc:
        print(f"Issue with {file_id}: {exc}")
        continue

    polygons = extract_polygons(json_data)
    # PIL reports (width, height); the mask array is indexed (row, column).
    width, height = image.size
    image_shape = (height, width)
    mask_array = convert_poly_to_mask(polygons, image_shape)
    mask_image = Image.fromarray(mask_array).convert("L")

    image.save(f"{test_output_path}/images/{i}.png")
    mask_image.save(f"{test_output_path}/masks/{i}.png")

In [None]:
# Spot-check the conversion by overlaying one exported mask on its image
# (black tint, blend weight 0.7).
# NOTE(review): index 80 is hard-coded -- confirm that sample was written.
plt = visualize_overlay(
    f"{test_output_path}/images/{80}.png", 
    f"{test_output_path}/masks/{80}.png",
    alpha=0.7, mask_color=(0, 0, 0)
)

## Augmentation to png

In [None]:
# Unique sample IDs (file stems) in the augmentation output directory.
# Sorted so the index assigned to each sample during conversion is the same on
# every run (set iteration order for strings changes between kernel sessions),
# and restricted to .png/.json so stray files such as .DS_Store are ignored.
AUG_FILES = sorted({
    os.path.splitext(file)[0]
    for file in os.listdir(aug_output_path)
    if file.endswith((".json", ".png"))
})

In [None]:
# Show how many unique sample IDs are waiting in the augmentation directory.
len(AUG_FILES)

In [None]:
# Offset added to every index when naming the files written by the augmented
# export below.
# NOTE(review): 27488 is presumably the last index used by an earlier export
# batch -- confirm before re-running, otherwise existing files may be overwritten.
END_NUM = 27488 + 1

In [None]:
# Convert every augmented sample into an image/mask PNG pair, then delete the
# intermediate png/json pair. Output names are offset by END_NUM.
# PIL does not create missing directories, so make sure both targets exist.
os.makedirs(f"{final_aug_output_path}/images", exist_ok=True)
os.makedirs(f"{final_aug_output_path}/masks", exist_ok=True)

for i, file_id in enumerate(AUG_FILES):
    # `except Exception` keeps the skip-and-continue behaviour but, unlike a
    # bare `except`, lets KeyboardInterrupt stop the loop; the cause is printed.
    try:
        image = Image.open(f"{aug_output_path}/{file_id}.png")
    except Exception as exc:
        print(f"Issue with {file_id}: {exc}")
        continue

    # Load JSON annotations
    try:
        with open(f"{aug_output_path}/{file_id}.json") as f:
            json_data = json.load(f)
    except Exception as exc:
        print(f"Issue with {file_id}: {exc}")
        continue

    polygons = extract_polygons(json_data)
    # PIL reports (width, height); the mask array is indexed (row, column).
    width, height = image.size
    image_shape = (height, width)
    mask_array = convert_poly_to_mask(polygons, image_shape)
    mask_image = Image.fromarray(mask_array).convert("L")

    image.save(f"{final_aug_output_path}/images/{i + END_NUM}.png")
    mask_image.save(f"{final_aug_output_path}/masks/{i + END_NUM}.png")

    # The sources are removed only after both outputs were written, so a
    # failure above leaves the intermediate pair in place.
    os.remove(f"{aug_output_path}/{file_id}.png")
    os.remove(f"{aug_output_path}/{file_id}.json")

    if (i % 1000 == 0):
        print(i)

In [None]:
#plt = visualize_overlay(
#    f"{final_aug_output_path}/images/{10000}.png", 
#    f"{final_aug_output_path}/masks/{10000}.png",
#    alpha=0.7, mask_color=(0, 0, 0)
#)

In [None]:
# Sorted numeric IDs of the exported augmented images.
# Entries are filtered by name instead of calling list.remove('.DS_Store'),
# which raises ValueError when that file is absent; filtering also keeps any
# other non-numeric entry (or sub-folder) away from int().
FINAL_FILES = sorted({
    int(os.path.splitext(file)[0])
    for file in os.listdir(f"{final_aug_output_path}/images/")
    if file.endswith(".png") and os.path.splitext(file)[0].isdecimal()
})

In [None]:
# Write one JSON object per line, pairing each image with its mask through
# paths relative to the export root.
metadata_path = os.path.join(final_aug_output_path, "metadata.jsonl")
with open(metadata_path, "w") as metadata_file:
    metadata_file.writelines(
        json.dumps({"image": f"images/{file}.png", "mask": f"masks/{file}.png"}) + "\n"
        for file in FINAL_FILES
    )

## Move to shards of `shard_size` files (currently 8000)

In [None]:
# Flat image and mask folders produced by the augmented export.
images_dir = os.path.join(final_aug_output_path, "images")
masks_dir = os.path.join(final_aug_output_path, "masks")
# Metadata file listing the image/mask paths after they were moved into shards.
output_metadata_path = os.path.join(final_aug_output_path, "metadata_sharded.jsonl")
# Width of the index range covered by one shard sub-folder (see shard_path).
shard_size = 8000

In [None]:
def shard_path(file_index, folder_base):
    """Return the shard sub-folder that a file index belongs to.

    The folder is named after the first index of the shard (``file_index``
    rounded down to a multiple of the module-level ``shard_size``), zero-padded
    to at least five digits.

    Args:
        file_index: Integer index of the file.
        folder_base: Directory under which the shard folders live.

    Returns:
        Path of the shard sub-folder; nothing is created on disk.
    """
    shard_start = file_index - file_index % shard_size
    return os.path.join(folder_base, f"{shard_start:05d}")

In [None]:
# Make subfolders and move files
with open(output_metadata_path, "w") as outfile:
    for i, filename in enumerate(sorted(os.listdir(images_dir))):
        if not filename.endswith(".png"):
            continue
        base = os.path.splitext(filename)[0]

        # Compute shard paths
        image_subdir = shard_path(int(base), images_dir)
        mask_subdir = shard_path(int(base), masks_dir)

        os.makedirs(image_subdir, exist_ok=True)
        os.makedirs(mask_subdir, exist_ok=True)

        # Move files
        src_image = os.path.join(images_dir, filename)
        dst_image = os.path.join(image_subdir, filename)
        shutil.move(src_image, dst_image)

        src_mask = os.path.join(masks_dir, filename)
        dst_mask = os.path.join(mask_subdir, filename)
        shutil.move(src_mask, dst_mask)

        # Write updated metadata line
        relative_image_path = os.path.relpath(dst_image, final_aug_output_path)
        relative_mask_path = os.path.relpath(dst_mask, final_aug_output_path)
        metadata_entry = {"image": relative_image_path, "mask": relative_mask_path}
        outfile.write(json.dumps(metadata_entry) + "\n")

        if i % 1000 == 0:
            print(f"Processed {i} files...")

In [None]:
def unshard_dataset(root_dir):
    """Flatten the shard sub-folders of ``images/`` and ``masks/``.

    Every ``.png`` inside a sub-folder of ``<root_dir>/images`` or
    ``<root_dir>/masks`` is moved up into that folder. A sub-folder is removed
    once it is empty; if it still holds other files (for example ``.DS_Store``)
    it is left in place with a warning instead of raising from ``os.rmdir``.

    Args:
        root_dir: Dataset root containing the ``images`` and ``masks`` folders.
    """
    for label in ("images", "masks"):
        base_dir = os.path.join(root_dir, label)
        print(f"Unsharding {label}...")

        for subdir in sorted(os.listdir(base_dir)):
            subdir_path = os.path.join(base_dir, subdir)
            if not os.path.isdir(subdir_path):
                continue

            for file in os.listdir(subdir_path):
                if file.endswith(".png"):
                    src = os.path.join(subdir_path, file)
                    dst = os.path.join(base_dir, file)
                    shutil.move(src, dst)

            # os.rmdir only removes empty directories; keep going when
            # leftovers prevent removal so the remaining shards are still
            # flattened.
            try:
                os.rmdir(subdir_path)
            except OSError:
                print(f"Warning: {subdir_path} is not empty after moving PNGs; leaving it in place")

In [None]:
# Flatten the shard sub-folders back into images/ and masks/.
# NOTE(review): metadata_sharded.jsonl still lists the shard sub-folder paths
# after this runs -- confirm which layout downstream consumers expect.
unshard_dataset(final_aug_output_path)