In [1]:
# Automatically reload edited modules (e.g. satellitepy) before each cell executes.
%load_ext autoreload
%autoreload 2

# Split each image into a test patch and a train remainder

In [2]:
import os
# Raise OpenCV's decoded-image pixel limit to 2**40 so very large satellite scenes can be read.
# Set it before cv2 is imported (directly or via satellitepy) in case OpenCV reads it at load time.
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = str(2 ** 40)
from satellitepy.data.labels import read_label, init_satellitepy_label, set_image_keys, get_all_satellitepy_keys
from satellitepy.utils.path_utils import get_file_paths, create_folder
from satellitepy.data.utils import get_satellitepy_dict_values, count_unique_values, get_satellitepy_table, read_img, set_satellitepy_dict_values
from satellitepy.data.bbox import BBox
from satellitepy.data.patch import is_truncated, shift_bboxes, create_patch_polygon, get_intersection
from satellitepy.data.tools import show_labels_on_images

import cv2
import random
import numpy as np
import itertools
import sys
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import json



In [None]:
# img_folder = Path("/mnt/2tb-1/satellitepy/data/FR24_dataset/images/")
# label_folder = Path("/mnt/2tb-1/satellitepy/data/FR24_dataset/labels_fineair/role_th_50/")
# label_format = 'fineair' # 
# test_img_sz = 8000
# intersection_th = 0.91
# test_sum_ratio_th = 0.15
# bbox_for_intersection = 'dbboxes'
# delete_output_files = True
# input_img_ext = 'tif'
# img_read_module = 'rasterio'

# train_label_folder = Path("/mnt/2tb-1/satellitepy/data/FR24_sets/train/labels_alpha")
# train_img_folder = Path("/mnt/2tb-1/satellitepy/data/FR24_sets/train/images_alpha")
# test_label_folder = Path("/mnt/2tb-1/satellitepy/data/FR24_sets/test/labels_alpha")
# test_img_folder = Path("/mnt/2tb-1/satellitepy/data/FR24_sets/test/images_alpha")


# for folder in [train_img_folder, train_label_folder, test_img_folder, test_label_folder]:
#     assert create_folder(folder)


In [None]:
# Re-split the existing train split: each image gives one test patch to `val`
# and the rest of the image (patch blacked out) to `only_train`.
img_folder = Path("/mnt/2tb-1/satellitepy/data/FR24_sets/train/images_alpha")
label_folder = Path("/mnt/2tb-1/satellitepy/data/FR24_sets/train/labels_alpha")
label_format = 'satellitepy'
test_img_sz = 6000  # side length (px) of the square test patch cut from each image
intersection_th = 0.91  # a patch overlapping any object by > 0 but < this gets no test objects
test_sum_ratio_th = 0.15  # target share of all objects that should land in the test split
bbox_for_intersection = 'obboxes'  # label key used to place patches and check truncation
img_read_module = 'cv2'
input_img_ext = 'png'
delete_output_files = True  # WARNING: when True, a later cell empties all four output folders

train_label_folder = Path("/mnt/2tb-1/satellitepy/data/FR24_sets/only_train/labels_alpha")
train_img_folder = Path("/mnt/2tb-1/satellitepy/data/FR24_sets/only_train/images_alpha")
test_label_folder = Path("/mnt/2tb-1/satellitepy/data/FR24_sets/val/labels_alpha")
test_img_folder = Path("/mnt/2tb-1/satellitepy/data/FR24_sets/val/images_alpha")

# Create the output folders. The call is not wrapped in `assert`: asserts are stripped
# under `python -O`, which would silently skip folder creation.
for folder in [train_img_folder, train_label_folder, test_img_folder, test_label_folder]:
    if not create_folder(folder):
        raise RuntimeError(f"Could not create output folder: {folder}")

## WARNING! If `delete_output_files` is True, the next cell deletes all images and labels in the output folders

In [4]:
def remove_all_files(folder_path):
    """Delete every regular file directly inside ``folder_path``.

    Sub-directories and their contents are left untouched (non-recursive).
    """
    for entry in Path(folder_path).iterdir():
        if not entry.is_file():
            continue
        entry.unlink()

if delete_output_files:
    # Empty every output folder so no stale split from an earlier run survives.
    for output_folder in (test_img_folder, test_label_folder, train_img_folder, train_label_folder):
        remove_all_files(output_folder)

## Compute candidate test patches (one per object) for each image

In [5]:
def get_patch_dict(label, test_img_size, margin=100, intersection_th=0.95, bbox_key=None):
    """Build one candidate test patch per object and split the image's objects for each.

    For every object, a square patch of side ``test_img_size`` is anchored ``margin`` px
    above/left of the object's box (clamped at 0). Objects fully inside a patch become its
    test objects; all others are its train objects. A patch that cuts any object (overlap
    > 0 but < ``intersection_th``) is rejected: its test set is emptied.

    Args:
        label: satellitepy label dict, as returned by ``read_label``.
        test_img_size: side length in pixels of the square test patch.
        margin: pixels subtracted from an object's min x/y to get the patch's top-left corner.
        intersection_th: minimum overlap value for an intersecting object to count as inside.
        bbox_key: label key whose boxes are used for placement and intersection tests.
            Defaults to the notebook-level ``bbox_for_intersection``.

    Returns:
        dict whose lists all have one entry per candidate patch:
        ``test_indices`` / ``train_indices`` (object indices), ``test_fac`` / ``train_fac``
        (fineair-class values), ``test_fac_count`` / ``train_fac_count`` (per-class counts,
        indexed by the fineair-class table) and ``start_coords`` ([x_0, y_0] corners).

    NOTE(review): patches near the right/bottom edge may extend past the image; the image
    size is not known here, so this is not checked.
    """
    if bbox_key is None:
        bbox_key = bbox_for_intersection

    satellitepy_fac = get_satellitepy_table()['fineair-class']
    n_fac = len(set(satellitepy_fac.values()))
    all_bboxes = label[bbox_key]
    n_objects = len(all_bboxes)
    # Loop-invariant: the fineair class of every object (only needed if there are objects).
    fineair_classes = get_satellitepy_dict_values(label, task='fineair-class') if n_objects else []

    # One patch start per object, anchored `margin` px before the object's top-left limit.
    patch_start_coords = []
    for bbox_corners in all_bboxes:
        bbox = BBox(corners=bbox_corners)
        x_min, x_max, y_min, y_max = bbox.get_bbox_limits(bbox.corners)
        patch_start_coords.append([np.maximum(x_min - margin, 0), np.maximum(y_min - margin, 0)])

    n_patches = len(patch_start_coords)
    patch_dict = {
        'test_indices': [[] for _ in range(n_patches)],
        'train_indices': [[] for _ in range(n_patches)],
        'test_fac': [[] for _ in range(n_patches)],
        'test_fac_count': [[0] * n_fac for _ in range(n_patches)],
        'train_fac': [[] for _ in range(n_patches)],
        'train_fac_count': [[0] * n_fac for _ in range(n_patches)],
        'start_coords': patch_start_coords,
    }

    for i, (x_0, y_0) in enumerate(patch_start_coords):
        patch_polygon = create_patch_polygon(x_0=x_0, y_0=y_0, patch_size=test_img_size)
        test_indices = []
        for j, bbox_corners in enumerate(all_bboxes):
            intersection = get_intersection(bbox_corners=bbox_corners, patch_polygon=patch_polygon)
            if intersection == 0:
                continue  # object lies entirely outside the patch
            if 0 < intersection < intersection_th:
                # The patch border cuts this object: reject the patch as a test candidate.
                test_indices = []
                break
            test_indices.append(j)

        test_index_set = set(test_indices)
        train_indices = [j for j in range(n_objects) if j not in test_index_set]

        patch_dict['test_indices'][i] = test_indices
        patch_dict['test_fac'][i] = [fineair_classes[j] for j in test_indices]
        patch_dict['train_indices'][i] = train_indices
        patch_dict['train_fac'][i] = [fineair_classes[j] for j in train_indices]

        # Per-class object counts on each side of the split.
        for fac in patch_dict['train_fac'][i]:
            patch_dict['train_fac_count'][i][satellitepy_fac[fac]] += 1
        for fac in patch_dict['test_fac'][i]:
            patch_dict['test_fac_count'][i][satellitepy_fac[fac]] += 1
    return patch_dict
        

In [6]:
label_paths = get_file_paths(label_folder)

# Map each label file name to the candidate-patch summary built by get_patch_dict.
patch_dicts = {}
for label_path in label_paths:
    label_file_name = label_path.name
    label = read_label(label_path=label_path, label_format=label_format)
    patch_dict = get_patch_dict(
        label=label,
        test_img_size=test_img_sz,
        intersection_th=intersection_th,
    )
    patch_dicts[label_file_name] = patch_dict



  return lib.intersection(a, b, **kwargs)


## Get one patch from each original image for the test set

In [7]:
def unique_lists_with_indices(list_of_lists):
    """Deduplicate a list of lists, keeping first occurrences in their original order.

    Returns:
        ``(unique_lists, unique_indices)``: the distinct sublists (as lists) and, for each,
        the position at which it first appears in ``list_of_lists``.
    """
    first_position = {}
    for position, sublist in enumerate(list_of_lists):
        # Tuples are hashable dict keys; setdefault keeps only the first position seen.
        first_position.setdefault(tuple(sublist), position)
    return [list(key) for key in first_position], list(first_position.values())

# Test case
# input_list = [[1, 1], [2, 2], [1, 1]]
# unique_lists, unique_indices = unique_lists_with_indices(input_list)
# print("Unique Lists:", unique_lists)
# print("Indices of First Occurrences:", unique_indices)

In [8]:
def compute_test_train_pairs(patch_dicts):
    """Deduplicate each image's candidate patches and pair their test/train statistics.

    Args:
        patch_dicts: mapping of label file name -> dict returned by ``get_patch_dict``.

    Returns:
        ``(all_test_train_pairs, all_test_train_pair_indices, orig_test_train_pair_indices)``,
        three lists with one entry per image, in ``patch_dicts`` iteration order:

        - ``all_test_train_pairs[k]``: ``(test_fac_count, train_fac_count)`` per kept patch
        - ``all_test_train_pair_indices[k]``: ``(test_indices, train_indices)`` per kept patch
        - ``orig_test_train_pair_indices[k]``: index of each kept patch in the image's
          original per-object patch lists (e.g. into ``start_coords``)

        Patches with identical test sets are collapsed to the first one, and the patch with an
        empty test set (rejected or object-free) is dropped. An image can therefore end up with
        no kept patches.
    """
    orig_test_train_pair_indices = []
    all_test_train_pairs = []
    all_test_train_pair_indices = []
    for img_name, patch_dict in patch_dicts.items():
        unique_test_indices, unique_patch_indices = unique_lists_with_indices(patch_dict['test_indices'])
        if not unique_patch_indices:
            print(img_name)  # image has no candidate patches at all

        # Keep only patches with a non-empty test set; filtering (instead of popping while
        # iterating) keeps test lists and patch indices aligned.
        kept = [
            (test_ind, patch_i)
            for test_ind, patch_i in zip(unique_test_indices, unique_patch_indices)
            if test_ind
        ]
        test_indices = [test_ind for test_ind, _ in kept]
        kept_patch_indices = [patch_i for _, patch_i in kept]

        train_indices = [patch_dict['train_indices'][p] for p in kept_patch_indices]
        test_fac_count = [patch_dict['test_fac_count'][p] for p in kept_patch_indices]
        train_fac_count = [patch_dict['train_fac_count'][p] for p in kept_patch_indices]

        all_test_train_pairs.append(list(zip(test_fac_count, train_fac_count)))
        all_test_train_pair_indices.append(list(zip(test_indices, train_indices)))
        orig_test_train_pair_indices.append(kept_patch_indices)
    return all_test_train_pairs, all_test_train_pair_indices, orig_test_train_pair_indices


In [9]:
test_train_pairs, test_train_pair_indices, orig_test_train_pair_indices = compute_test_train_pairs(patch_dicts)

# Every image needs at least one valid test patch; otherwise the random patch search
# below fails with an opaque `randint(0, -1)` error. Report the offenders and stop here.
image_names = list(patch_dicts.keys())
images_without_test_patch = [
    (i, image_names[i]) for i, pairs in enumerate(test_train_pairs) if len(pairs) == 0
]
for i, name in images_without_test_patch:
    print(f"No valid test patch for image #{i}: {name}")
if images_without_test_patch:
    raise ValueError(
        f"{len(images_without_test_patch)} image(s) have no valid test patch; "
        "adjust intersection_th / test_img_sz or remove these images before sampling."
    )


## Search combinations of test patches across images (Cartesian product)

In [10]:
# Calculate the number of elements in each set
# lengths = [len(s) for s in test_train_pairs]

def cartesian_with_indices(all_test_train_pairs):
    """Exhaustively yield every combination of one element per set, with element positions.

    Yields:
        ``(combination, indices)``: ``combination`` is a tuple with one element from each
        set, and ``indices[k]`` is the position of ``combination[k]`` within set ``k``.

    Positions are tracked directly instead of recovered with ``list.index``. ``list.index``
    returns the first equal element, which mislabels duplicates and costs O(n) per lookup.
    """
    enumerated_sets = [list(enumerate(s)) for s in all_test_train_pairs]
    for picked in itertools.product(*enumerated_sets):
        indices = [position for position, _ in picked]
        combination = tuple(element for _, element in picked)
        yield combination, indices

def random_cartesian_with_indices(sets, num_samples, rng=None):
    """Yield ``num_samples`` random combinations (one element per set, with replacement).

    Args:
        sets: sequence of sequences to sample from.
        num_samples: number of combinations to draw.
        rng: object with a ``randint(a, b)`` method, e.g. ``random.Random(seed)``, for
            reproducible draws. Defaults to the global ``random`` module.

    Yields:
        ``(sample, indices)`` with ``sample[k] == sets[k][indices[k]]``.

    Raises:
        ValueError: if ``num_samples > 0`` and any set is empty (nothing to pick).
    """
    if rng is None:
        rng = random
    empty_positions = [k for k, s in enumerate(sets) if len(s) == 0]
    if num_samples > 0 and empty_positions:
        raise ValueError(f"Cannot sample from empty set(s) at position(s) {empty_positions}")
    for _ in range(num_samples):
        # Randomly choose an index for each set, then gather the chosen elements.
        indices = [rng.randint(0, len(s) - 1) for s in sets]
        sample = [sets[i][index] for i, index in enumerate(indices)]
        yield sample, indices

In [None]:

def get_cartesian_test_train_pairs(sets, test_sum_ratio_th, num_samples=300000, seed=None):
    """Randomly search for one test patch per image that best balances the split.

    Each element of ``sets[k]`` is a ``(test_fac_count, train_fac_count)`` pair for one
    candidate patch of image ``k``. A candidate combination is accepted when two conditions
    hold. First, its test share of all objects is within 0.001 of the best deviation from
    ``test_sum_ratio_th`` seen so far. Second, its summed per-class distribution difference
    between test and train is no larger than the best so far.

    Args:
        sets: per-image lists of ``(test_fac_count, train_fac_count)`` pairs.
        test_sum_ratio_th: target fraction of all objects that should be in the test split.
        num_samples: number of random combinations to evaluate.
        seed: if given, seeds the global ``random`` module first for a reproducible search.

    Returns:
        ``best_indices``: list with ``len(sets)`` entries; ``best_indices[k]`` indexes
        ``sets[k]``. Stays ``[0] * len(sets)`` if no combination is accepted.

    NOTE(review): each accepted combination may be up to 0.001 farther from the target than
    the previous best, and that looser deviation becomes the new reference. Over many
    accepts, the test share can therefore drift away from ``test_sum_ratio_th``. Confirm
    this is intended.
    """
    if seed is not None:
        random.seed(seed)
    best_indices = [0] * len(sets)
    best_sum_ratio_dif = np.inf
    best_ratio_dev = np.inf  # best |test share - target| accepted so far
    for combination, indices in random_cartesian_with_indices(sets, num_samples):
        counts = np.array(combination)  # (n_images, 2, n_classes): [test, train] class counts
        test_sum = counts[:, 0].sum(axis=0)
        train_sum = counts[:, 1].sum(axis=0)
        total_sum = np.array([test_sum.sum(), train_sum.sum()])

        # A combination with zero test objects yields NaN ratios, which never pass the
        # comparisons below.
        test_sum_ratio = test_sum / total_sum[0]
        train_sum_ratio = train_sum / total_sum[1]
        test_to_all_ratio = total_sum[0] / np.sum(total_sum)
        sum_ratio_dif = np.sum(np.abs(test_sum_ratio - train_sum_ratio))  # class-distribution mismatch
        ratio_dev = np.abs(test_to_all_ratio - test_sum_ratio_th)

        if (best_ratio_dev + 0.001 >= ratio_dev) and (sum_ratio_dif <= best_sum_ratio_dif):
            print('# Test instances: ', test_sum)
            print('# Train instances: ', train_sum)
            print('# Total instances (ratio): ', total_sum, test_to_all_ratio)
            print('Absolute sum diff: ', sum_ratio_dif)
            print(indices)
            best_sum_ratio_dif = sum_ratio_dif
            best_ratio_dev = ratio_dev
            best_indices = indices
    return best_indices

In [13]:
# result[k] is the index into test_train_pairs[k] of the patch chosen for image k.
result = get_cartesian_test_train_pairs(
    sets=test_train_pairs,
    test_sum_ratio_th=test_sum_ratio_th,
)

# Test instances:  [10 17 77 67 20  0  6  0 74 90  3 10 16 23 19  1  2  7  0 11  0 84 15]
# Train instances:  [ 66 164 540 438 188   0 127   0 764 780  50  53 133 269 128  56  59  95  53  86   0 710 209]
# Total instances (ratio):  [ 552 4968] 0.1
Absolute sum diff:  0.2101449275362319
[26, 1, 18, 12, 28, 36, 10, 24, 18, 2, 21, 7, 36, 9, 5, 4, 9, 23, 12, 23, 16, 24, 17, 21, 5, 10, 1, 9, 19, 12, 12, 3, 1, 20, 19, 7, 6, 12, 3, 6, 13]
# Test instances:  [ 15  22  83  62  26   0  27   0 110 138  12   4  16  53  20   8  16  12  19  18   0 109  31]
# Train instances:  [ 61 159 534 443 182   0 106   0 728 732  41  59 133 239 127  49  45  90  34  79   0 685 193]
# Total instances (ratio):  [ 801 4719] 0.1451086956521739
Absolute sum diff:  0.17742443687285364
[10, 12, 16, 7, 14, 23, 13, 16, 12, 0, 6, 9, 32, 22, 23, 19, 18, 18, 6, 11, 16, 5, 13, 31, 7, 11, 1, 5, 0, 22, 34, 20, 19, 22, 15, 0, 11, 11, 5, 25, 9]
# Test instances:  [ 15  28  89  83  29   0  20   0 107 131   2   8  17  42  27  15   

In [14]:
# Chosen patch index per image, in patch_dicts order.
print(f"{result}")

[19, 3, 3, 20, 9, 13, 17, 5, 7, 3, 17, 20, 44, 12, 5, 19, 6, 8, 16, 17, 5, 23, 18, 31, 9, 3, 35, 10, 5, 28, 31, 2, 1, 19, 16, 6, 16, 3, 1, 24, 8]


## Save the splits

In [16]:

def _split_label_by_indices(label, tasks, test_indices, train_indices):
    """Return ``(test_label, train_label)`` holding, per task, the values of the given objects."""
    test_index_set = set(test_indices)
    train_index_set = set(train_indices)
    test_label = init_satellitepy_label()
    train_label = init_satellitepy_label()
    for task in tasks:
        test_task_values = []
        train_task_values = []
        for task_value_i, task_value in enumerate(get_satellitepy_dict_values(label, task=task)):
            if task_value_i in test_index_set:
                test_task_values.append(task_value)
            elif task_value_i in train_index_set:
                train_task_values.append(task_value)
        test_label = set_satellitepy_dict_values(test_label, task=task, value=test_task_values)
        train_label = set_satellitepy_dict_values(train_label, task=task, value=train_task_values)
    return test_label, train_label


def _shift_label_bboxes(label, x_0, y_0):
    """Translate the obboxes/hbboxes of ``label`` by ``(-x_0, -y_0)`` into patch coordinates."""
    # NOTE(review): only these two keys are shifted; confirm no other geometric keys need it.
    for bbox_task in ['obboxes', 'hbboxes']:
        bbox_values = get_satellitepy_dict_values(label, task=bbox_task)
        shifted_bbox_values = (np.array(bbox_values) - [x_0, y_0]).tolist()
        label = set_satellitepy_dict_values(label, task=bbox_task, value=shifted_bbox_values)
    return label


def _write_json(path, data):
    """Write ``data`` to ``path`` as indented JSON."""
    with open(str(path), 'w') as f:
        json.dump(data, f, indent=4)


def save_sets(patch_dicts, best_indices, test_train_pair_indices, orig_test_train_pair_indices):
    """Write the chosen test patch and the masked train image (plus labels) for every image.

    For image ``i`` (in ``patch_dicts`` order), ``best_indices[i]`` selects one kept patch.
    The patch region is cropped into the test set, with its labels shifted into patch
    coordinates. The same region is zeroed in the train image, whose labels keep the
    remaining objects.

    Reads the notebook config: ``label_folder``, ``label_format``, ``img_folder``,
    ``input_img_ext``, ``img_read_module`` and ``test_img_sz``. Writes into
    ``train_img_folder``, ``train_label_folder``, ``test_img_folder`` and
    ``test_label_folder``.

    An image whose train PNG already exists is skipped, so an interrupted run can be resumed.
    For that reason the train PNG is written last, after the labels and the test PNG.

    NOTE(review): patches near the right/bottom edge can be smaller than ``test_img_sz``,
    and shifted boxes are not clipped. Confirm this is acceptable.
    """
    file_names = list(patch_dicts.keys())
    tasks = get_all_satellitepy_keys()
    for i, ind in enumerate(best_indices):
        # i: original image index; ind: index of the chosen patch among the kept patches.
        orig_ind = orig_test_train_pair_indices[i][ind]
        file_name = Path(file_names[i]).stem
        print(f'Processing {file_name}')
        train_img_path = train_img_folder / f"{file_name}.png"
        if train_img_path.is_file():
            print(f'{file_name} exists in the destination train folder, skipped...')
            continue

        patch_dict = patch_dicts[file_names[i]]
        label = read_label(label_path=label_folder / file_names[i], label_format=label_format)
        # Start coords may be numpy scalars; plain ints are needed for slicing and file names.
        x_0, y_0 = map(int, patch_dict['start_coords'][orig_ind])

        test_indices, train_indices = test_train_pair_indices[i][ind]
        print(f"Test image has {len(test_indices)} airplanes.")

        test_label, train_label = _split_label_by_indices(label, tasks, test_indices, train_indices)
        test_label = _shift_label_bboxes(test_label, x_0, y_0)

        test_file_name = f'{file_name}_x0_{x_0}_y0_{y_0}_sz_{test_img_sz}'
        _write_json(test_label_folder / f"{test_file_name}.json", test_label)
        _write_json(train_label_folder / f"{file_name}.json", train_label)

        print(f"Saving train image to {str(train_img_path)}...")
        img = read_img(img_path=str(img_folder / f"{file_name}.{input_img_ext}"), module=img_read_module)
        patch_rows = slice(y_0, y_0 + test_img_sz)
        patch_cols = slice(x_0, x_0 + test_img_sz)
        test_img = img[patch_rows, patch_cols]  # works for both 2-D and multi-channel images
        train_img = img.copy()
        train_img[patch_rows, patch_cols] = 0  # black out the test region
        cv2.imwrite(str(test_img_folder / f"{test_file_name}.png"), test_img)
        cv2.imwrite(str(train_img_path), train_img)  # written last: acts as the "done" marker


In [17]:
# Write the test patches and masked train images chosen by the patch search.
save_sets(
    patch_dicts=patch_dicts,
    best_indices=result,
    test_train_pair_indices=test_train_pair_indices,
    orig_test_train_pair_indices=orig_test_train_pair_indices,
)

Processing Amsterdam_23MAR14104929
Test image has 5 airplanes.
Saving train image to /mnt/2tb-1/satellitepy/data/FR24_sets/train/images_alpha/Amsterdam_23MAR14104929.png...
Processing Bangkok_23FEB21040003
Test image has 1 airplanes.
Saving train image to /mnt/2tb-1/satellitepy/data/FR24_sets/train/images_alpha/Bangkok_23FEB21040003.png...
Processing Beijing_23SEP13031225
Test image has 5 airplanes.
Saving train image to /mnt/2tb-1/satellitepy/data/FR24_sets/train/images_alpha/Beijing_23SEP13031225.png...
Processing Beijing_Capital_International_22DEC04031345
Test image has 12 airplanes.
Saving train image to /mnt/2tb-1/satellitepy/data/FR24_sets/train/images_alpha/Beijing_Capital_International_22DEC04031345.png...
Processing Cairo_24JAN04084426
Test image has 14 airplanes.
Saving train image to /mnt/2tb-1/satellitepy/data/FR24_sets/train/images_alpha/Cairo_24JAN04084426.png...
Processing Cairo_24JAN29084910
Test image has 47 airplanes.
Saving train image to /mnt/2tb-1/satellitepy/data

In [34]:
# show_labels_on_images(image_folder=test_img_folder,
#         label_folder=test_label_folder,
#         mask_folder=None,
#         label_format='satellitepy',
#         img_read_module='cv2',
#         out_folder=Path('/home/murat/Projects/satellitepy/docs/temp_fineair_set'),
#         tasks=['coarse-class','obboxes'],
#         rescaling=1.0,
#         interpolation_method=None)