In [6]:
import os
import cv2
import numpy as np
from PIL import Image
from sklearn.cluster import KMeans

In [7]:
def process_image(image_path: str, label_path: str, output_directory: str, color: tuple =(255, 255, 255), threshold: int = 220) -> None:
    """
    Cut every labelled polygon out of an image and save each one as an RGBA PNG.

    The image is first passed through ``find_dominant_color`` (which whitens the
    dominant colour range) and that intermediate result is saved as
    ``image_without_dominant_color.png``. For every line of the label file a
    filled polygon mask is drawn, the pixels inside it are copied onto a
    transparent canvas, near-white pixels are made transparent via
    ``white_to_transparent`` and the result is written to
    ``extracted_polyline_<idx>.png``, where ``<idx>`` is the zero-based line
    number in the label file.

    Args:
        image_path: str: Path to the image file.
        label_path: str: Path to the label file. Each line holds a class label
            followed by whitespace-separated ``x y`` pairs that are scaled by the
            image width and height (i.e. treated as normalised coordinates).
        output_directory: str: Existing directory where the output images are saved.
        color: tuple: RGB fill colour used when drawing the mask. NOTE: pixels are
            selected where the grayscale mask equals 255, so in practice only the
            default white yields a usable mask.
        threshold: int: Pixels whose R, G and B values all exceed this value are
            made transparent in the extracted crops.

    Returns:
        None

    Raises:
        FileNotFoundError: If the image cannot be read by OpenCV.
    """

    # Load image; cv2.imread signals failure by returning None instead of raising
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert from BGR to RGB

    # Whiten the dominant colour range; the dominant colour itself is not needed here
    _, image = find_dominant_color(image)

    # Save the intermediate image
    image_no_dominant_color = Image.fromarray(image)
    image_no_dominant_color.save(os.path.join(output_directory, 'image_without_dominant_color.png'))

    # Load labels; split() tolerates repeated spaces, tabs and trailing newlines
    with open(label_path, 'r') as f:
        lines = [line.split() for line in f]

    height, width = image.shape[0], image.shape[1]

    # Process each polygon
    for idx, line in enumerate(lines):
        # Skip the class label (at index 0) and scale the remaining x/y pairs to pixels
        coords = line[1:]
        points = [[int(float(coords[i]) * width), int(float(coords[i + 1]) * height)] for i in range(0, len(coords), 2)]

        # Check if we have enough points to form a polygon
        if len(points) > 1:
            # Convert points into the (N, 1, 2) int32 layout used for OpenCV polygons
            polyline_points = np.array(points, dtype=np.int32).reshape((-1, 1, 2))

            # Draw the filled polygon on an empty mask
            mask = np.zeros_like(image)
            cv2.fillPoly(mask, [polyline_points], color=color)

            # Convert mask to grayscale and select the fully white pixels
            mask_gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
            inside = mask_gray == 255

            # Copy colour where the mask is white and make those pixels opaque;
            # everything else stays (0, 0, 0, 0), i.e. fully transparent
            extracted = np.zeros((height, width, 4), dtype=np.uint8)
            extracted[inside, :3] = image[inside]
            extracted[inside, 3] = 255

            # Convert near-white pixels to transparent
            final_image = white_to_transparent(extracted, threshold)

            # Save the processed image
            output_path = os.path.join(output_directory, f'extracted_polyline_{idx}.png')
            final_image.save(output_path)
        else:
            print(f"Not enough points to form a polyline for line {idx}.")


def white_to_transparent(image: np.ndarray, threshold: int = 220) -> Image.Image:
    """
    Convert white pixels in an image to transparent.

    A pixel counts as white when its R, G and B values are all strictly greater
    than ``threshold``; the alpha channel is ignored for that decision. Such
    pixels are replaced by ``(255, 255, 255, 0)``. The input array is not modified.

    Args:
        image: np.ndarray: RGBA image as a NumPy array of shape (height, width, 4).
        threshold: int: Threshold value for white pixels in the image.

    Returns:
        Image: PIL image with white pixels converted to transparent.

    Raises:
        ValueError: If the array does not have four channels.
    """
    pixels = np.asarray(image)

    # Validate explicitly: an assert would be stripped when running with -O
    if pixels.ndim != 3 or pixels.shape[2] != 4:
        raise ValueError(
            f"Expected an RGBA array of shape (height, width, 4), got shape {pixels.shape}"
        )

    # Checking the RGB channels for whiteness, ignore the alpha channel
    near_white = np.all(pixels[:, :, :3] > threshold, axis=2)

    # Work on a copy so the caller's array stays untouched
    result = pixels.copy()
    result[near_white] = (255, 255, 255, 0)  # Full transparency

    return Image.fromarray(result)


def find_dominant_color(image: Image, k: int = 5) -> tuple:
    """
    Find the dominant colour of an image and paint pixels close to it white.

    The pixels are clustered with k-means; the centre of the most populated
    cluster is taken as the dominant colour. Every pixel whose Euclidean
    distance to that colour is below the standard deviation of all distances
    is set to (255, 255, 255). The input is copied first, so it is not modified.

    Args:
        image: Image data convertible with ``np.array`` whose values can be
            reshaped into rows of three colour components.
        k: int: Number of k-means clusters.

    Returns:
        tuple: ``(dominant_color, result_img_array)`` - the cluster centre of the
        largest cluster and the whitened image in the original array shape.
    """
    # Copy into an array and view it as one row of three components per pixel
    source = np.array(image)
    pixels = source.reshape((-1, 3))

    # Cluster the pixel colours
    model = KMeans(n_clusters=k, random_state=0)
    model.fit(pixels)

    # The dominant colour is the centre of the cluster with the most members
    cluster_sizes = np.bincount(model.labels_)
    dominant_color = model.cluster_centers_[np.argmax(cluster_sizes)]

    # Euclidean distance of every pixel to the dominant colour
    offsets = pixels - dominant_color
    distances = np.sqrt(np.sum(offsets ** 2, axis=1))
    close_to_dominant = distances < np.std(distances)

    # Whiten everything inside the dominant colour range
    pixels[close_to_dominant] = [255, 255, 255]

    return dominant_color, pixels.reshape(source.shape)

def remove_object(image, segmentation_points):
    """
    Paint every given polygon white on the image.

    The image is modified in place and also returned.

    Args:
        image: Image array; its width (``shape[1]``) and height (``shape[0]``)
            are used to scale the coordinates to pixels.
        segmentation_points: Iterable of flat ``[x0, y0, x1, y1, ...]`` sequences
            whose values are converted with ``float`` and scaled by the image size.

    Returns:
        The same image object with the polygons filled with (255, 255, 255).
    """
    white = (255, 255, 255)
    for coords in segmentation_points:
        # Pair up the flat coordinate list and scale each pair to pixel positions
        vertices = []
        for i in range(0, len(coords), 2):
            x_px = int(float(coords[i]) * image.shape[1])
            y_px = int(float(coords[i + 1]) * image.shape[0])
            vertices.append([x_px, y_px])

        # Reshape to the (N, 1, 2) int32 layout before filling the polygon
        polygon = np.array(vertices, dtype=np.int32).reshape((-1, 1, 2))
        cv2.fillPoly(image, [polygon], color=white)
    return image

def rotate_crop(crop: np.ndarray, angle: int) -> Image:
    """
    Rotate an image array by ``angle`` degrees and return it as a PIL image.

    The rotation is done with ``expand=False``, so the output keeps the size
    of the input.

    Args:
        crop: np.ndarray: Image data accepted by ``Image.fromarray``.
        angle: int: Rotation angle in degrees, passed to ``Image.rotate``.

    Returns:
        Image: The rotated PIL image.
    """
    return Image.fromarray(crop).rotate(angle, expand=False)

def rotate_points(points, angle: int, image_width, image_height):
    """
    Rotate normalised points around the image centre and return pixel coordinates.

    Each ``(x, y)`` pair is scaled by the image width/height, shifted so the
    image centre becomes the origin, multiplied by the standard 2-D rotation
    matrix ``[[cos, -sin], [sin, cos]]`` and shifted back. With the y axis
    growing downward (the usual image convention) a positive angle therefore
    turns points clockwise on screen.

    Args:
        points: Sequence of ``(x, y)`` pairs that are multiplied by the image
            size (i.e. treated as normalised coordinates). Not modified.
        angle: int: Rotation angle in degrees.
        image_width: Image width in pixels.
        image_height: Image height in pixels.

    Returns:
        np.ndarray: Float array of shape (N, 2) with the rotated points in
        pixel coordinates (origin at the top-left); shape (0, 2) for empty input.
    """
    # Force float: with an integer array the centred coordinates would be
    # truncated as soon as they are written back
    pts = np.array(points, dtype=float)
    if pts.size == 0:
        return np.empty((0, 2), dtype=float)

    # Compute the rotation matrix
    angle_radians = np.radians(angle)
    cos_a, sin_a = np.cos(angle_radians), np.sin(angle_radians)
    rotation_matrix = np.array([
        [cos_a, -sin_a],
        [sin_a, cos_a]
    ])

    # Scale to pixels and move the origin to the image centre
    center = np.array([image_width / 2, image_height / 2])
    centered = pts * np.array([image_width, image_height]) - center

    # Rotate the row vectors, then move the origin back to the top-left
    return np.dot(centered, rotation_matrix.T) + center

def paste_rotated_object(image_background, rotated_object, top_left):
    """
    Paste an object onto a background image, using its last band as the mask.

    :param image_background: Background image as an array accepted by ``Image.fromarray``.
    :param rotated_object: PIL image to paste; its last band (the alpha channel
        for an RGBA image) controls which pixels are copied.
    :param top_left: Position of the object's top-left corner on the background.
    :return: The composited image as a NumPy array.
    """
    # Work on a PIL copy of the background so paste() can be used
    canvas = Image.fromarray(image_background)

    # The last band of the object acts as the paste mask
    alpha_band = rotated_object.split()[-1]
    canvas.paste(rotated_object, box=top_left, mask=alpha_band)

    return np.array(canvas)

Crop out the segmented objects

In [17]:
image_name = '187196037_jpg.rf.0623711bac39d692612564a04adf2564'
image_path = f'../data/Ntnu_segmentation-24/train/images/{image_name}.jpg'
label_path = f'../data/Ntnu_segmentation-24/train/labels/{image_name}.txt'
ANGLE = -22

# Pick the next free run directory. The suffixes are compared numerically
# (a lexicographic sort would rank test_9 after test_10) and the numbering
# starts at test_0 when no run directory exists yet.
existing_runs = [
    int(entry.split('_')[-1])
    for entry in os.listdir('.')
    if entry.startswith('test_') and entry.split('_')[-1].isdigit()
]
output_path = f'./test_{max(existing_runs, default=-1) + 1}'
os.makedirs(output_path, exist_ok=False)

# Writes extracted_polyline_<idx>.png for every usable label line (returns None)
process_image(image_path, label_path, output_path)

# Read the polygons again, dropping the class label in front of each line
with open(label_path, 'r') as f:
    lines = [line.split()[1:] for line in f]

image = cv2.imread(image_path)
if image is None:
    # cv2.imread returns None instead of raising when the file cannot be read
    raise FileNotFoundError(f"Could not read image: {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Background: dominant colour whitened and every labelled polygon painted white
image_255_bg = find_dominant_color(image)[1]
image_no_object = remove_object(image_255_bg.copy(), [line for line in lines if line])

top_left_position = (0, 0)
image_size = np.array([image.shape[1], image.shape[0]])

# Running composite: each rotated crop is pasted on top of the previous result
canvas = image_no_object

with open(os.path.join(output_path, f'{image_name}.txt'), 'a') as label_file:
    for idx, line in enumerate(lines):

        crop_name = f'extracted_polyline_{idx}.png'
        crop = cv2.imread(os.path.join(output_path, crop_name), cv2.IMREAD_UNCHANGED)
        if crop is None:
            # process_image writes no crop for lines with fewer than two points
            print(f"No extracted crop for line {idx}; skipping.")
            continue
        crop = cv2.cvtColor(crop, cv2.COLOR_BGRA2RGBA)

        rotated_crop_image = rotate_crop(crop, angle=ANGLE)

        points = [[float(line[i]), float(line[i+1])] for i in range(0, len(line), 2)]

        # The points are rotated by -ANGLE while the crop is rotated by ANGLE,
        # then scaled back from pixels to normalised coordinates
        rotated_points = rotate_points(np.array(points), -ANGLE, image.shape[1], image.shape[0])
        rotated_points = rotated_points / image_size

        # Paste onto the running composite and save one snapshot per object
        canvas = paste_rotated_object(canvas, rotated_crop_image, top_left_position)
        image_with_object_pasted = Image.fromarray(canvas)
        image_with_object_pasted.save(os.path.join(output_path, f'{image_name}_{idx}.png'))

        # One label line per object: class 0 followed by the rotated x/y pairs
        coordinates = ''.join(f' {point[0]} {point[1]}' for point in rotated_points)
        label_file.write(f'0{coordinates}\n')