In [None]:
import os
from sklearn.model_selection import train_test_split


def _append_pair_lines(list_path, pairs):
    """Append one "'image_path': ..., 'annotation_path': ..." line per pair to ``list_path``."""
    # Append mode is kept from the original code, so running this again adds
    # the same lines a second time.
    # NOTE(review): presumably intentional so several data directories can
    # accumulate into one list - confirm, or delete the files before a re-run.
    with open(list_path, "a") as list_file:
        for pair in pairs:
            list_file.write(f"'image_path': {pair['image']}, 'annotation_path': {pair['annotation']}\n")


def create_image_mask_pairs(data_dir):
    """Pair image tiles with same-named mask tiles and split them into train/test sets.

    Files are matched purely by file name: only names that exist in both
    ``<data_dir>/image_tiles_/`` and ``<data_dir>/mask_tiles_/`` are kept.
    The pairs are split 80/20 with a fixed seed (``random_state=42``) and each
    split is appended to ``train_data.txt`` / ``test_data.txt`` in the current
    working directory.

    Args:
        data_dir: Directory containing the ``image_tiles_`` and ``mask_tiles_``
            sub-directories.

    Returns:
        tuple: ``(train_pairs, test_pairs)``. Each is a list of dicts with the
        keys ``"image"`` and ``"annotation"`` holding the joined file paths.

    Raises:
        FileNotFoundError: If either sub-directory does not exist (from ``os.listdir``).
        ValueError: If fewer than two matching pairs are found, because an
            80/20 split would leave the training set empty.
    """
    images_dir = os.path.join(data_dir, "image_tiles_/")
    masks_dir = os.path.join(data_dir, "mask_tiles_/")

    image_names = os.listdir(images_dir)
    mask_names = set(os.listdir(masks_dir))

    # One sorted pass is enough: directory listings contain unique names, so
    # filtering the image names by membership yields the common set in order.
    common_names = sorted(name for name in image_names if name in mask_names)

    all_pairs = [
        {"image": os.path.join(images_dir, name), "annotation": os.path.join(masks_dir, name)}
        for name in common_names
    ]

    print(f"Total valid pairs: {len(all_pairs)}")

    # Fail early with a readable message; train_test_split would raise a
    # ValueError here as well, but without naming the directory. Nothing has
    # been written to disk at this point.
    if len(all_pairs) < 2:
        raise ValueError(
            f"Need at least two image/mask pairs to split into train and test sets, "
            f"found {len(all_pairs)} in {data_dir!r}"
        )

    train_pairs, test_pairs = train_test_split(all_pairs, test_size=0.2, random_state=42)

    _append_pair_lines("train_data.txt", train_pairs)
    _append_pair_lines("test_data.txt", test_pairs)

    return train_pairs, test_pairs

# NOTE(review): machine-specific absolute path - this cell only runs where that
# directory exists; consider moving it to a configurable constant.
data_dir = "/media/usama/SSD/Data_for_SAM2_model_Finetuning/Cities/fl_bowling_city/output/image_tiles_and_masks_tiles_for_fl_bowling_city/"
# Builds the image/mask pairs and, as a side effect, appends them to
# train_data.txt and test_data.txt in the current working directory.
train_pairs,test_pairs = create_image_mask_pairs(data_dir)
# Quick sanity check: size of the training split.
print(len(train_pairs))

In [None]:
from shapely.geometry import Polygon
import cv2
from shapely.validation import make_valid
import numpy as np
import random
import matplotlib.pyplot as plt

In [None]:
def get_representative_points_within_contours(contours, contours_1,mask):
    """Collect interior sample points for the ``contours`` polygons that intersect each ``contours_1`` polygon.

    For every contour in ``contours_1`` a shapely polygon is built (and passed
    through ``make_valid``) and compared against every contour in ``contours``
    that has more than 3 points:

    * an intersecting polygon with area <= 200 contributes one representative
      point, appended to the result immediately as ``[(x, y)]``;
    * a larger intersecting polygon contributes up to four quadrant
      representative points, filtered to those lying on a non-zero ``mask``
      pixel, collected as one group.

    What is appended after the inner loop depends on how many polygons intersected:

    * more than one: two random points from each of the first two groups (as
      two separate lists), or - if only one group exists - the points of that
      group added individually as bare ``(x, y)`` tuples;
    * exactly one: the first group as a single list (if any group exists);
    * none: the representative point of the ``contours_1`` polygon as a flat ``[x, y]``.

    Args:
        contours: Iterable of contours whose points are indexed as
            ``point[0][0]`` / ``point[0][1]`` (presumably OpenCV
            ``findContours`` output - confirm against callers).
        contours_1: Contours in the same layout; each one drives one outer iteration.
        mask: 2-D array (unpacked as ``rows, cols``); non-zero pixels count as foreground.

    Returns:
        list: Mixed-shape entries as described above - lists of ``(x, y)``
        tuples, bare ``(x, y)`` tuples, and flat ``[x, y]`` lists.
        NOTE(review): consumers have to cope with all three shapes; confirm
        whether the mix is intended.
    """
    representative_points = []

    def get_quadrant_representative_points(polygon):
        """Return one representative point per bounding-box quadrant that overlaps ``polygon``.

        The bounding box is split at its centre into four rectangles; for each
        rectangle with a non-empty intersection with ``polygon`` the
        intersection's ``representative_point()`` is returned as ``(x, y)``.
        """
        min_x, min_y, max_x, max_y = polygon.bounds
        center_x = (min_x + max_x) / 2
        center_y = (min_y + max_y) / 2

        # Four axis-aligned rectangles tiling the bounding box around its centre.
        quadrants = [
            Polygon([(min_x, min_y), (center_x, min_y), (center_x, center_y), (min_x, center_y)]),
            Polygon([(center_x, min_y), (max_x, min_y), (max_x, center_y), (center_x, center_y)]),
            Polygon([(min_x, center_y), (center_x, center_y), (center_x, max_y), (min_x, max_y)]),
            Polygon([(center_x, center_y), (max_x, center_y), (max_x, max_y), (center_x, max_y)])
        ]

        temp_points = []  # Temporary list to hold quadrant representative points

        for quadrant in quadrants:
            if quadrant.intersects(polygon):
                intersection = quadrant.intersection(polygon)
                if not intersection.is_empty:
                    rep_point = intersection.representative_point()
                    temp_points.append((rep_point.x, rep_point.y))

        return temp_points
    
    def is_foreground_pixel(x, y, mask):
        """Return whether ``(x, y)``, truncated to ints, is inside ``mask`` and on a non-zero pixel."""
        rows, cols = mask.shape
        # Points outside the mask bounds are treated as background.
        if 0 <= int(y) < rows and 0 <= int(x) < cols:
#             return mask[int(y), int(x)] == 255  # Adjust based on foreground label
            return mask[int(y), int(x)]>0
        return False

    for contour_1 in contours_1:
        try:
            shapely_polygon = Polygon([(point[0][0], point[0][1]) for point in contour_1])
            shapely_polygon = make_valid(shapely_polygon)  # Ensure the polygon is valid
            count = 0  # number of `contours` polygons intersecting this polygon
            tmp_pts = []  # one list of foreground quadrant points per large intersecting polygon

            for contour in contours:
                # shapely_polygon_1 = Polygon([(point[0][0], point[0][1]) for point in contour])
                coordinates = []
                for cont_point in contour:
                    x = cont_point[0][0]
                    y = cont_point[0][1]
                    coordinates.append((x, y))
                tmp_pts_1 =[]
                # Contours with 3 or fewer points are ignored entirely.
                if len(coordinates)>3:
                # Create the polygon using the list of coordinates
                    shapely_polygon_1 = Polygon(coordinates)
                    shapely_polygon_1 = make_valid(shapely_polygon_1)  # Ensure the polygon is valid
                    # plot_polygon
                 

                    if shapely_polygon.intersects(shapely_polygon_1):
                        count += 1

                        if shapely_polygon_1.area <= 200:
                            # Small polygon: a single point, added to the result right away
                            # (it never enters tmp_pts, but it still increments count).
                            rep_point = shapely_polygon_1.representative_point()
                            representative_points.append(([(rep_point.x, rep_point.y)]))
                            # Prints the whole accumulated result list, not just the new point.
                            print("representative point after area is less than 200",representative_points)
                        else:
                            pts = get_quadrant_representative_points(shapely_polygon_1)
                            # print("points11",points)
                            # Keep only quadrant points that land on a foreground pixel.
                            for pt in pts:
                                if is_foreground_pixel(pt[0],pt[1],mask):
                                    tmp_pts_1.append(pt)
                            # The group is recorded even when no point survived the filter.
                            tmp_pts.append(tmp_pts_1)

                            # tmp_pts.append(get_quadrant_representative_points(shapely_polygon_1))

            if count > 1:
                print("length of tmp_pts",len(tmp_pts))
                if len(tmp_pts) >= 2:
                    # random.sample raises ValueError when a group holds fewer than 2 points;
                    # that lands in the `except ValueError` below and is reported as a
                    # polygon error. If only the second call fails, the first sample has
                    # already been appended. NOTE(review): confirm this is acceptable.
                    representative_points.append(list(random.sample(tmp_pts[0], 2)))
                    representative_points.append(list(random.sample(tmp_pts[1], 2)))
                elif tmp_pts:
                    # extend() adds each (x, y) tuple as its own entry, unlike the append()
                    # calls elsewhere, which add one list of points per entry.
                    representative_points.extend(list(tmp_pts[0]))
            elif count==1:
#                 rep_point = shapely_polygon.representative_point()
#                 representative_points.append((rep_point.x, rep_point.y))  # To tackle the case where intersection is not present
                if tmp_pts:
                # Single intersecting polygon: keep all of its foreground quadrant points
                    
                    representative_points.append(list(tmp_pts[0]))
                    # print(representative_points)
            else:
                # No intersecting polygon: fall back to the outer polygon's own point.
                rep_point = shapely_polygon.representative_point()
                representative_points.append(list((rep_point.x, rep_point.y)))  # list((x, y)) is a flat [x, y], not [(x, y)]

                # if tmp_pts:
                    
                # # If no multiple intersections, still get quadrant points
                #     representative_points.extend(tmp_pts[0])


        except ValueError as e:
            # Catches any ValueError raised in the block above (polygon construction
            # or random.sample) and moves on to the next outer contour.
            print(f"Error creating polygon: {e}")
            continue

    return representative_points

    

In [None]:
def read_batch(data, visualize_data=False, output_base_dir="output", num_copies=3):
    """Write image/mask copies plus a point-prompt text file for each usable entry.

    For every entry the annotation mask is rescaled so that the image's longer
    side would become 1024 px, contours are extracted from the mask (and from
    an eroded version of it), and representative points are obtained from
    ``get_representative_points_within_contours``. Entries are skipped when
    either file cannot be read, when no points are returned, or when any
    returned point group holds fewer than ``num_copies`` points.

    For each accepted entry, ``num_copies`` copies of the image and mask are
    written under ``<output_base_dir>/images`` and ``<output_base_dir>/masks``.
    Copy ``i`` gets a text file under ``<output_base_dir>/txt_files`` holding
    the ``i``-th point of every group, one ``x, y`` line per group.

    Args:
        data: Iterable of dicts with the keys ``"image"`` and ``"annotation"`` (file paths).
        visualize_data: Currently unused; kept so existing callers keep working.
        output_base_dir: Root directory for the generated files.
        num_copies: Number of copies per entry, and the minimum number of
            points each group must provide.

    Returns:
        list: One dict per written copy with the keys ``"image"`` and
        ``"annotation"`` (the original paths), ``"txt_file"`` and ``"points"``
        (list of integer ``(x, y)`` tuples).
    """
    images_dir = os.path.join(output_base_dir, "images")
    masks_dir = os.path.join(output_base_dir, "masks")
    txt_files_dir = os.path.join(output_base_dir, "txt_files")
    for directory in (images_dir, masks_dir, txt_files_dir):
        os.makedirs(directory, exist_ok=True)

    erosion_kernel = np.ones((5, 5), np.uint8)
    results = []

    for ent in data:
        image_path = ent["image"]
        mask_path = ent["annotation"]

        # Check for None BEFORE touching the arrays: cv2.imread returns None for
        # an unreadable file, and subscripting that would raise TypeError
        # instead of reaching the intended skip below.
        original_image = cv2.imread(image_path)
        ann_map = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if original_image is None or ann_map is None:
            print(f"Error: Could not read image or mask from path {ent['image']} or {ent['annotation']}")
            continue

        # Scale factor that brings the image's longer side to 1024 px.
        image_height, image_width = original_image.shape[:2]
        scale = min(1024 / image_width, 1024 / image_height)
        # Default (bilinear) interpolation is kept; the thresholds below re-binarise the mask.
        ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * scale), int(ann_map.shape[0] * scale)))

        _, binary_mask = cv2.threshold(ann_map, 127, 255, cv2.THRESH_BINARY)
        contours_1, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        eroded_mask = cv2.erode(ann_map, erosion_kernel, iterations=2)
        _, binary_mask_eroded = cv2.threshold(eroded_mask, 127, 255, cv2.THRESH_BINARY)
        contours_2, _ = cv2.findContours(binary_mask_eroded, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Use the eroded mask only when erosion did not reduce the contour count
        # (it kept or split regions); otherwise fall back to the original mask.
        final_mask = eroded_mask if len(contours_2) >= len(contours_1) else ann_map

        _, binary_mask_final = cv2.threshold(final_mask, 100, 255, cv2.THRESH_BINARY)
        contours, _ = cv2.findContours(binary_mask_final, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Get representative points with intersection logic
        rep_points = get_representative_points_within_contours(contours, contours_1, ann_map)
        print("points", rep_points)

        if not rep_points:
            continue
        # Copy i consumes the i-th point of every group, so each group needs
        # at least num_copies points; otherwise the whole entry is skipped.
        if not all(len(group) >= num_copies for group in rep_points):
            continue

        # The copies are identical for every i, so load the sources once. The
        # mask is re-read with default flags because that is what gets written
        # out, rather than the grayscale, resized working mask.
        # NOTE(review): the points are in the coordinate frame of the resized
        # mask while the copies are written at original resolution, so they
        # only line up when scale == 1 - confirm the tiles are 1024 px on
        # their longer side.
        mask_for_copy = cv2.imread(mask_path)
        image_stem = os.path.basename(image_path).split('.')[0]
        mask_stem = os.path.basename(mask_path).split('.')[0]

        for i in range(num_copies):
            copy_stem = f"{image_stem}_copy{i}"

            copy_image_path = os.path.join(images_dir, f"{copy_stem}.jpg")
            print("copy image path", copy_image_path)
            cv2.imwrite(copy_image_path, original_image)

            copy_mask_path = os.path.join(masks_dir, f"{mask_stem}_copy{i}.jpg")
            cv2.imwrite(copy_mask_path, mask_for_copy)

            copy_txt_path = os.path.join(txt_files_dir, f"{copy_stem}.txt")

            # i-th point of every group, truncated to integer pixel coordinates.
            points_pair = [tuple(map(int, group[i])) for group in rep_points]

            # Write the file once, after all points are known.
            with open(copy_txt_path, 'w') as file:
                for (cX, cY) in points_pair:
                    file.write(f'{cX}, {cY}\n')
            print("point pairs", points_pair)

            results.append({
                "image": ent["image"],
                "annotation": ent["annotation"],
                "txt_file": copy_txt_path,
                "points": points_pair,
            })

    return results

# Single hand-picked sample entry (machine-specific absolute paths).
# NOTE(review): `train_data` is not what is passed to read_batch below.
train_data =   [{'image': '/media/usama/SSD/Data_for_SAM2_model_Finetuning/Cities/ct_sprague_city/output/step6_outputs/image_tiles_/60_ct_sprague_64.jpg', 'annotation': '/media/usama/SSD/Data_for_SAM2_model_Finetuning/Cities/ct_sprague_city/output/step6_outputs/mask_tiles_/60_ct_sprague_64.jpg'}]
# Runs the batch over `train_pairs`, which must already exist from the cell that
# calls create_image_mask_pairs. visualize_data is passed as True.
results = read_batch(train_pairs, visualize_data=True)