In [1]:
from PIL import Image
import os
import random
import math
import numpy as np
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables import Keypoint, KeypointsOnImage

In [2]:
base_path = "./seg-14"
destination = "./seg-14-augmented"
crop_size = 224

In [3]:
seq = iaa.Sequential([
    iaa.Fliplr(0.5),  # Horizontal flip with 50% probability
    iaa.Flipud(0.5),  # Vertical flip with 50% probability
    iaa.OneOf([  # Either contrast normalization or color augmentation
        iaa.ContrastNormalization((0.75, 1.25)),
        iaa.Multiply((0.8, 1.2)),
        iaa.Grayscale(alpha=(0.0, 0.1))
    ]),
    iaa.OneOf([  # Either blur or noise transformation
        iaa.GaussianBlur(sigma=(0.0, 1)),
        iaa.AverageBlur(k=(2, 3)),
        iaa.MedianBlur(k=(3)),
        iaa.BilateralBlur(d=(3, 5), sigma_color=(10, 50), sigma_space=(10, 50)),
        iaa.AdditiveGaussianNoise(scale=(0.0, 0.01 * 255)),
        iaa.ElasticTransformation(alpha=2, sigma=1)
    ])
])

rotate = iaa.Affine(rotate=(-45, 45))

  warn_deprecated(msg, stacklevel=3)


In [4]:
def get_keypoints_from_line(line, height, width):
    data = line.strip().split()
    
    coords = list(map(float, data[1:]))
    
    points = [(coords[i], coords[i+1]) for i in range(0, len(coords), 2)]
    
    keypoints = [Keypoint(x=x*width, y=y*height) for x, y in points]

    return keypoints
def augment_image_and_keypoints(image, keypoints, sequence):
    image_np = np.array(image, dtype=np.uint8)

    keypointsOnImage = KeypointsOnImage(keypoints, shape=image_np.shape)
    
    image_aug_np, keypoints_aug = sequence(image=image_np, keypoints=keypointsOnImage)
    
    image_aug_pil = Image.fromarray(image_aug_np)
    
    return image_aug_pil, keypoints_aug

In [5]:
mask_paths = []
for root, dirs, files in os.walk(base_path):
    for file in files:
        if file.endswith(".txt"):
            mask_paths.append(os.path.join(root, file))

In [6]:
count = 0
image_and_paths = []
for root, dirs, files in os.walk(base_path):
    for file in files:
        if file.endswith(".jpg"):
            image_path = os.path.join(root, file)
            image = Image.open(image_path)
            mask_path = os.path.join(base_path, image_path.split("/")[2], "labels", os.path.basename(image_path).replace(".jpg", ".txt"))
            # mask_path = os.path.join(base_path, "labels", os.path.basename(image_path).replace(".jpg", ".txt"))
            try:
                with open(mask_path, "r") as mask_file:
                    lines = mask_file.readlines()
                    empty = len(lines) == 0
                    image_and_paths.append((image, lines, file, empty))
            except FileNotFoundError:
                print(f"File not found: {mask_path}")
                lines = []
                empty = True
                image_and_paths.append((image, lines, file, empty))
                count += 1
print(f"Files not found: {count}")

Files not found: 0


In [7]:
count_dict = {}

for image_and_path in image_and_paths:
    _, _, s, _ = image_and_path
    number = s.split('_')[0]
    try:
        number = int(number)
    except:
        print(number)
        if -1 in count_dict:
            count_dict[-1] += 1
        else:
            count_dict[-1] = 1
        continue

    if number in count_dict:
        count_dict[number] += 1
    else:
        count_dict[number] = 1

sorted_counts = sorted(count_dict.items(), key=lambda x: x[1], reverse=True)

# total = 0
# for number, count in sorted_counts:
#     print(f"Number {number}: {count} instance(s)")

2022-02-07
2019-09-01
2021-09-01
2021-10-10
2019-10-10
2021-12-29
2021-03-19
2022-09-01
2021-11-19
2019-12-29


In [8]:
test_numbers = []
total = 0
for number, count in sorted_counts:
    if (count > 1 and count < 6) and number not in test_numbers:
        test_numbers.append(number)
        total += count
print(f"Total: {total}")

Total: 163


In [9]:
import copy
val_numbers = copy.deepcopy(test_numbers)

In [10]:
# val_numbers = []
# total = 0
# for number, count in sorted_counts:
#     if (count > 2 and count < 6) and number not in val_numbers:
#         val_numbers.append(number)
#         total += count
# print(f"Total: {total}")

In [11]:
test_image_and_paths = []
for image_and_path in image_and_paths:
    image, lines, file, empty = image_and_path
    try:
        number = int(file.split('_')[0])
    except:
        number = -1
    if number in test_numbers:
        test_image_and_paths.append(image_and_path)

In [12]:
val_image_and_paths = []
for image_and_path in image_and_paths:
    image, lines, file, empty = image_and_path
    try:
        number = int(file.split('_')[0])
    except:
        number = -1
    if number in val_numbers:
        val_image_and_paths.append(image_and_path)

In [13]:
train_image_and_paths = []
for image_and_path in image_and_paths:
    image, lines, file, empty = image_and_path
    try:
        number = int(file.split('_')[0])
    except:
        number = -1
    if number not in test_numbers and number not in val_numbers:
        train_image_and_paths.append(image_and_path)

In [14]:
train_rotated = []
for crop in train_image_and_paths:
    for _ in range(2):
        image, lines, _, _ = crop
        points = []
        for line in lines:
            points += get_keypoints_from_line(line, image.height, image.width)
        image, points = augment_image_and_keypoints(image, points, rotate)
        train_rotated.append((image, points))

In [15]:
def crop(image, keypoints):
    width, height = image.size
    empty = False
    if len(keypoints.keypoints) == 0:
        empty = True
        x_min_valid = 0
        x_max_valid = width - crop_size
        y_min_valid = 0
        y_max_valid = height - crop_size
    else:
        relative_x = [keypoint.x for keypoint in keypoints.keypoints]
        relative_y = [keypoint.y for keypoint in keypoints.keypoints]

        x_min = min(relative_x)
        x_max = max(relative_x)
        y_min = min(relative_y)
        y_max = max(relative_y)

        x_min_valid = max(0, x_max - crop_size)
        x_max_valid = min(width - crop_size, x_min)
        y_min_valid = max(0, y_max - crop_size)
        y_max_valid = min(height - crop_size, y_min)

        if x_min_valid >= x_max_valid or y_min_valid >= y_max_valid or x_max_valid - x_min_valid < 0.5 or y_max_valid - y_min_valid < 0.5:
            print("Invalid crop for image " + image_path)
            return None

    crop_x = random.randint(math.ceil(x_min_valid), math.floor(x_max_valid))
    crop_y = random.randint(math.ceil(y_min_valid), math.floor(y_max_valid))

    crop = image.crop((crop_x, crop_y, crop_x + crop_size, crop_y + crop_size))
    
    if not empty:
        relative_x = [(x - crop_x) for x in relative_x]
        relative_y = [(y - crop_y) for y in relative_y]
        
        string = '0 '
        for x, y in zip(relative_x, relative_y):
            string += str(x / crop_size) + ' ' + str(y / crop_size) + ' '
    else:
        string = ''
    return (crop, string)

In [16]:
train_cropped = []
for i, rotated in enumerate(train_rotated):
    for _ in range(10):
        image, keypoints = rotated
        cropped = crop(image, keypoints)
        if cropped is not None:
            train_cropped.append(cropped)

Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-

In [17]:
test_cropped = []
for i, rotated in enumerate(test_image_and_paths):
    image, lines, _, _ = rotated
    points = []
    for line in lines:
        points += get_keypoints_from_line(line, image.height, image.width)
    cropped = crop(image, KeypointsOnImage(points, shape=image.size))
    if cropped is not None:
        test_cropped.append(cropped)

Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-

In [18]:
val_cropped = []
for i, rotated in enumerate(val_image_and_paths):
    image, lines, _, _ = rotated
    points = []
    for line in lines:
        points += get_keypoints_from_line(line, image.height, image.width)
    cropped = crop(image, KeypointsOnImage(points, shape=image.size))
    if cropped is not None:
        val_cropped.append(cropped)

Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-12_2024-04-21_png_jpg.rf.be71c9743cc84457ba82abb6733ef628.jpg
Invalid crop for image ./seg-14/train/images/10_2024-03-

In [19]:
for i in range(len(train_cropped)):
    for _ in range(2):
        image, line = train_cropped[i]
        try:
            image, keypoints = augment_image_and_keypoints(image, line, seq)
        except:
            continue
        line = ''
        if len(keypoints.keypoints) > 0:
            line = '0 '
        for keypoint in keypoints.keypoints:
            if keypoint.x < 0 or keypoint.y < 0 or keypoint.x > crop_size or keypoint.y > crop_size:
                print("Invalid keypoint: " + str(keypoint.x/crop_size) + " " + str(keypoint.y/crop_size))
            line += str(keypoint.x/crop_size) + ' ' + str(keypoint.y/crop_size) + ' '
        train_cropped[i] = (image, line)

  augmenter_active = np.zeros((nb_rows, len(self)), dtype=np.bool)


In [20]:
print("Train size: " + str(len(train_cropped)))
print("Val size: " + str(len(val_cropped)))
print("Test size: " + str(len(test_cropped)))

Train size: 11830
Val size: 146
Test size: 146


In [21]:
# from matplotlib import pyplot as plt
# for idx, (image, data) in enumerate(train_cropped):
#     # Convert PIL Image to NumPy array
#     img_array = np.array(image)
    
#     # Create a Matplotlib figure and axis
#     plt.figure(figsize=(8, 6))
#     plt.imshow(img_array)
#     plt.axis('off')  # Hide axis
    
#     # Parse the data string to extract coordinates
#     points = data.strip().split()
    
#     # Assuming the first element is a label, skip it
#     coords = points[1:]
    
#     # Ensure that there is an even number of coordinates
#     if len(coords) % 2 != 0:
#         raise ValueError(f"Odd number of coordinates in data: {data}")
    
#     # Extract x and y coordinates, scaling them by crop_size
#     x_coords = []
#     y_coords = []
#     for i in range(0, len(coords), 2):
#         x = float(coords[i]) * crop_size
#         y = float(coords[i+1]) * crop_size
#         x_coords.append(x)
#         y_coords.append(y)
    
#     # Plot the points on the image
#     plt.scatter(x_coords, y_coords, c="red", s=0.5, marker='o', edgecolors='white')
    
#     # Optionally, add titles or annotations
#     plt.title(f'Image {idx+1} with Points')
    
#     plt.show()

In [22]:
folder = "train"
os.makedirs(os.path.join(destination, folder, "images"), exist_ok=True)
os.makedirs(os.path.join(destination, folder, "labels"), exist_ok=True)
for image, data in train_cropped:
    name = str(random.randint(0, 1000000))
    image.save(os.path.join(destination, folder, "images", name + ".jpg"))
    with open(os.path.join(destination, folder, "labels", name + ".txt"), "w") as f:
        f.write(data)

In [23]:
folder = "test"
os.makedirs(os.path.join(destination, folder, "images"), exist_ok=True)
os.makedirs(os.path.join(destination, folder, "labels"), exist_ok=True)
for image, data in test_cropped:
    name = str(random.randint(0, 1000000))
    image.save(os.path.join(destination, folder, "images", name + ".jpg"))
    with open(os.path.join(destination, folder, "labels", name + ".txt"), "w") as f:
        f.write(data)

In [24]:
folder = "valid"
os.makedirs(os.path.join(destination, folder, "images"), exist_ok=True)
os.makedirs(os.path.join(destination, folder, "labels"), exist_ok=True)
for image, data in val_cropped:
    name = str(random.randint(0, 1000000))
    image.save(os.path.join(destination, folder, "images", name + ".jpg"))
    with open(os.path.join(destination, folder, "labels", name + ".txt"), "w") as f:
        f.write(data)