# Thin Plate Spline Transformations

**Purpose:** The purpose of this experiment is to implement thin plate spline (TPS) transformations on images. The goal is to help normalize images of charts that have been folded or bent. TPS registration will be performed after homography registration.


## What are Thin Plate Splines

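For our purpose, a thin plate spline transformation can be defined as follows:  
Given a set of "source" points $P_s = \{p_s^{(1)}, \dots, p_s^{(n)}\}$ and a corresponding set of "destination" points $P_d = \{p_d^{(1)}, \dots, p_d^{(n)}\}$, the thin plate spline transformation is the function $f_{tps}: \mathbb{R}^2 \to \mathbb{R}^2$ that satisfies $f_{tps}(p_s^{(i)}) = p_d^{(i)}$ for every $i$ and, among all functions that do so, "bends" all other points the least. Formally, each coordinate function $f$ of $f_{tps}$ minimizes the bending energy  
$$E[f] = \iint_{\mathbb{R}^2} \left( f_{xx}^2 + 2 f_{xy}^2 + f_{yy}^2 \right) \, dx \, dy.$$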
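The minimizer has a known closed form: each coordinate is an affine part plus a weighted sum of radial kernels centered on the source points, $f(p) = a_0 + a_1 x + a_2 y + \sum_{i=1}^{n} w_i \, U(\lVert p - p_s^{(i)} \rVert)$ with $U(r) = r^2 \log r$, where the weights satisfy $\sum_i w_i = 0$ and $\sum_i w_i \, p_s^{(i)} = 0$. The NumPy sketch below solves that linear system directly, only to make the definition concrete. It is illustrative: `fit_tps` and `apply_tps` are hypothetical names, and the notebook itself relies on scipy's `Rbf(..., function="thin_plate")`.

```python
import numpy as np
from typing import Tuple

TPSModel = Tuple[np.ndarray, np.ndarray, np.ndarray]


def tps_kernel(r: np.ndarray) -> np.ndarray:
    """Thin plate spline radial basis U(r) = r^2 log r, with U(0) defined as 0."""
    with np.errstate(divide="ignore", invalid="ignore"):
        u = r**2 * np.log(r)
    return np.where(r == 0, 0.0, u)


def fit_tps(src: np.ndarray, dst: np.ndarray) -> TPSModel:
    """Fit f with f(src[i]) == dst[i]; src and dst are (n, 2) arrays, n >= 3 non-collinear."""
    n = len(src)
    # Kernel matrix K[i, j] = U(|src_i - src_j|)
    K = tps_kernel(np.linalg.norm(src[:, None, :] - src[None, :, :], axis=-1))
    # Affine basis: one row [1, x, y] per control point
    P = np.hstack([np.ones((n, 1)), src])
    # Block system [[K, P], [P^T, 0]] @ [w; a] = [dst; 0]
    A = np.zeros((n + 3, n + 3))
    A[:n, :n] = K
    A[:n, n:] = P
    A[n:, :n] = P.T
    b = np.zeros((n + 3, 2))
    b[:n] = dst
    params = np.linalg.solve(A, b)
    return src, params[:n], params[n:]  # control points, kernel weights w, affine coeffs a


def apply_tps(model: TPSModel, pts: np.ndarray) -> np.ndarray:
    """Evaluate the fitted TPS at (m, 2) points."""
    ctrl, w, a = model
    U = tps_kernel(np.linalg.norm(pts[:, None, :] - ctrl[None, :, :], axis=-1))
    return a[0] + pts @ a[1:] + U @ w


# Four fixed corners and one center point nudged by (1, -0.5)
src = np.array([[0, 0], [10, 0], [0, 10], [10, 10], [5, 5]], dtype=float)
dst = src + np.array([[0, 0], [0, 0], [0, 0], [0, 0], [1.0, -0.5]])
model = fit_tps(src, dst)
print(np.allclose(apply_tps(model, src), dst))  # True: control points are matched exactly
```

If the destination points are an exact affine image of the source points, the kernel weights come out as zero and the TPS reduces to that affine map. Note that scipy's legacy `Rbf` fits only the kernel part (no affine term), whereas `scipy.interpolate.RBFInterpolator` with `kernel="thin_plate_spline"` adds a degree-1 polynomial by default.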

![thin_plate_spline_example](https://github.com/user-attachments/assets/9f269c79-6d5d-4090-988c-39ee9e928ff0)

In this image there are two sets of points which form fish shapes. The red '+' points represent the destination points, and the green 'o' points represent the source points. On the right we see the thin plate spline deformation which makes all the green points exactly match all the red points, and we can also see a grid that shows where all other points on the plane will be mapped as well.

## Benefits for Image Registration

While the homography is a very reliable transformation, it is limited to projective transformations (e.g. rotation, scaling, shear, and perspective), which always map straight lines to straight lines.  
Very commonly, pages will be creased or folded, the edges of pages will curl inwards or outwards, or the camera lens will introduce barrel/pincushion distortion.  
These distortions bend straight lines into curves, so they cannot be corrected by a homography, but they still must be accounted for.
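The limitation follows from one property: a homography maps collinear points to collinear points, and so does its inverse. A printed row that was photographed along a curl is no longer collinear, so no homography can make it straight again. The small NumPy check below illustrates this; the matrix `H` and the 5-pixel sine "curl" are made-up values, not measurements from our charts.

```python
import numpy as np


def apply_homography(H: np.ndarray, pts: np.ndarray) -> np.ndarray:
    """Map (n, 2) points through a 3x3 homography using homogeneous coordinates."""
    homog = np.hstack([pts, np.ones((len(pts), 1))]) @ H.T
    return homog[:, :2] / homog[:, 2:3]


def are_collinear(pts: np.ndarray, tol: float = 1e-6) -> bool:
    """True if every point lies on the line through the first and last points."""
    direction = pts[-1] - pts[0]
    rel = pts - pts[0]
    # 2-D cross product of each point's offset with the line direction
    cross = direction[0] * rel[:, 1] - direction[1] * rel[:, 0]
    return bool(np.all(np.abs(cross) < tol))


# Hypothetical homography with a mild perspective component
H = np.array([[1.0, 0.2, 5.0], [0.05, 0.9, -3.0], [1e-4, 2e-4, 1.0]])
xs = np.linspace(0.0, 400.0, 5)
printed_row = np.column_stack([xs, np.full_like(xs, 100.0)])  # a straight row of text
curled_row = np.column_stack([xs, 100.0 + 5.0 * np.sin(np.pi * xs / 400.0)])  # same row, curled

print(are_collinear(apply_homography(H, printed_row)))  # straight row stays straight
print(are_collinear(apply_homography(H, curled_row)))   # curled row stays curled
```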

![smaller_RC_0033_preoperative_postoperative](https://github.com/user-attachments/assets/f22374e1-75b4-4a07-816f-d3ad51e84921)
An example of a paper that has been folded and unfolded, causing non-linear distortions. This is a _very_ modest example. In practice, we have seen photographs of papers that are raised at least an inch off the table in certain areas.


In [1]:
# imports
import os
import cv2
import numpy as np
import json
from pathlib import Path
from typing import Dict, List, Tuple
from utils.annotations import BoundingBox
from collections import Counter
from PIL import Image
from scipy.interpolate import Rbf

from memory_profiler import profile

In [2]:
%load_ext memory_profiler

Load the data for testing:

- "intraop_document_landmarks.json"
  - Used for the destination points in the transformation. We are using the landmarks from the unified image.
  - Also has landmarks for each of the individual images. We are currently not using them.
- "yolo_data.json"
  - Used for the source points in the transformation.
  - May need to be replaced with the landmarks from the other file.


In [3]:
# Load yolo_data.json which will be used as src_bbs
PATH_TO_YOLO_DATA = "../../data/yolo_data.json"
PATH_TO_REGISTERED_IMAGES = "../../data/registered_images"

with open(PATH_TO_YOLO_DATA) as json_file:
    yolo_data = json.load(json_file)

print(f"Found {len(yolo_data)} sheets in yolo_data.json")

# Load intraop_document_landmarks.json, which will be used as dst_points
PATH_TO_LANDMARKS = "../../data/intraop_document_landmarks.json"

DESIRED_IMAGE_WIDTH = 800
DESIRED_IMAGE_HEIGHT = 600


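def label_studio_to_bboxes(path_to_json_data: Path) -> Dict[str, List[BoundingBox]]:
    """
    Convert the json data exported from Label Studio to BoundingBox objects, grouped by image.

    Label Studio stores rectangle coordinates as percentages of the image size, so they are
    rescaled to DESIRED_IMAGE_WIDTH x DESIRED_IMAGE_HEIGHT pixels.

    Args:
        path_to_json_data (Path): Path to the json data from Label Studio

    Returns:
        Dict[str, List[BoundingBox]]: Mapping from image file name to its BoundingBox objects
    """
    with open(path_to_json_data) as json_file:
        json_data: List[Dict] = json.load(json_file)
    return {
        # Key: the text after the last "-" in the stored image path (the original file name)
        sheet_data["data"]["image"].split("-")[-1]: [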
            BoundingBox(
                category=label["value"]["rectanglelabels"][0],
                left=label["value"]["x"] / 100 * DESIRED_IMAGE_WIDTH,
                top=label["value"]["y"] / 100 * DESIRED_IMAGE_HEIGHT,
                right=(label["value"]["x"] / 100 + label["value"]["width"] / 100)
                * DESIRED_IMAGE_WIDTH,
                bottom=(label["value"]["y"] / 100 + label["value"]["height"] / 100)
                * DESIRED_IMAGE_HEIGHT,
            )
            for label in sheet_data["annotations"][0]["result"]
        ]
        for sheet_data in json_data
    }


landmark_location_data: Dict[str, List[BoundingBox]] = label_studio_to_bboxes(
    PATH_TO_LANDMARKS
)

landmarks = landmark_location_data[
    "unified_intraoperative_preoperative_flowsheet_v1_1_front.png"
]

Found 22 sheets in yolo_data.json


**TPS Transformation**
Steps:

1. Filter to keep only the relevant bounding boxes
   - Remove all bounding boxes from the source points that do not match a category in the destination points
   - Find all categories that are duplicated in the source or destination points and remove them
   - Sort the source and destination points alphabetically by category, so the i-th source point corresponds to the i-th destination point
2. Get lists of the x and y coordinates for both the source and destination points
   - Primary purpose is to enable the use of scipy's Rbf function
   - `tps_transform_ransac` uses the top-left corner of each bounding box; `tps_transform` uses the box center
   - Both functions then discard outlier correspondences with RANSAC (`cv2.findHomography`)
3. Estimate the transformation
   - Use the Rbf function with the `thin_plate` kernel to fit the TPS mapping from the destination points to the source points
4. Apply the transformation and warp the image
   - Create a grid of pixel coordinates from 0 to the image width/height minus one
   - Apply the transformation to the grids
   - Clip the transformed coordinates so they stay within the image bounds
   - Use those grids to warp the original image with `cv2.remap`
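Step 1 is what guarantees that the i-th source point and the i-th destination point refer to the same landmark. Sorting alone is not enough: if one list has a category the other lacks, the two sorted lists drift out of alignment. A minimal sketch of the matching, using a hypothetical `Box` dataclass in place of `utils.annotations.BoundingBox`:

```python
from collections import Counter
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Box:
    """Hypothetical stand-in for utils.annotations.BoundingBox."""
    category: str
    left: float
    top: float


def match_by_category(src: List[Box], dst: List[Box]) -> Tuple[List[Box], List[Box]]:
    """Return aligned source/destination lists for categories unique to both lists."""
    src_counts = Counter(b.category for b in src)
    dst_counts = Counter(b.category for b in dst)
    # A category is usable only if it occurs exactly once in each list
    usable = {c for c in src_counts if src_counts[c] == 1 and dst_counts.get(c) == 1}
    # Sorting by category aligns the two lists index by index
    src_kept = sorted((b for b in src if b.category in usable), key=lambda b: b.category)
    dst_kept = sorted((b for b in dst if b.category in usable), key=lambda b: b.category)
    return src_kept, dst_kept


source = [Box("b", 1, 1), Box("a", 0, 0), Box("c", 2, 2), Box("c", 3, 3), Box("x", 9, 9)]
landmarks_demo = [Box("a", 0, 0), Box("b", 1, 1), Box("c", 2, 2), Box("d", 4, 4)]
s, d = match_by_category(source, landmarks_demo)
print([bb.category for bb in s], [bb.category for bb in d])
```

Here "c" is dropped because it is duplicated in the source, "x" because it has no landmark, and "d" because it was never detected.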

_Note:_ There are a lot of different outputs that are currently commented out that can be used for debugging purposes:

- Print the lists of duplicate keys and the categories being used in the source and destination points
- Plot of the source and destination points on the image
- View the bounds of the transformed points
- View the distribution of the transformed points


In [4]:
def tps_transform(image: np.ndarray, src_bbs: List[BoundingBox], dst_bbs: List[BoundingBox], scale_factor: float = 0.25) -> np.ndarray:
    """
    Perform a memory-efficient thin plate spline transformation by computing the warp field at low resolution.

    Args:
        image (np.ndarray): The image to be transformed.
        src_bbs (List[BoundingBox]): List of source BoundingBox objects.
        dst_bbs (List[BoundingBox]): List of destination BoundingBox objects.
        scale_factor (float): Factor to downscale the image for computing warp (default: 0.25).

    Returns:
        np.ndarray: The transformed image at original size.
    """

    # Image dimensions
    h, w = image.shape[:2]
    small_h, small_w = int(h * scale_factor), int(w * scale_factor)

    # Scale bounding boxes
    def scale_bbs(bbs, factor):
        return [
            BoundingBox(
                category=bb.category,
                left=bb.left * factor,
                top=bb.top * factor,
                right=bb.right * factor,
                bottom=bb.bottom * factor,
            ) for bb in bbs
        ]

    scaled_src_bbs = scale_bbs(src_bbs, scale_factor)
    scaled_dst_bbs = scale_bbs(dst_bbs, scale_factor)

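    # Keep only categories that appear in both lists, so every source box has a landmark partner
    shared_cats = {bb.category for bb in scaled_src_bbs} & {bb.category for bb in scaled_dst_bbs}

    # Remove categories that appear more than once in either list (ambiguous correspondences)
    duplicate_cats = {k for k, v in Counter([bb.category for bb in scaled_src_bbs]).items() if v > 1}
    duplicate_cats |= {k for k, v in Counter([bb.category for bb in scaled_dst_bbs]).items() if v > 1}
    keep_cats = shared_cats - duplicate_cats

    scaled_src_bbs = [bb for bb in scaled_src_bbs if bb.category in keep_cats]
    scaled_dst_bbs = [bb for bb in scaled_dst_bbs if bb.category in keep_cats]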

    # Sort bounding boxes by category
    scaled_src_bbs.sort(key=lambda bb: bb.category)
    scaled_dst_bbs.sort(key=lambda bb: bb.category)

    # Extract bounding box centers
    def get_centers(bbs):
        return np.array(
            [[(bb.left + bb.right) / 2, (bb.top + bb.bottom) / 2] for bb in bbs], dtype=np.float32
        )

    src_points = get_centers(scaled_src_bbs)
    dst_points = get_centers(scaled_dst_bbs)

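    # RANSAC homography fitting needs at least 4 correspondences
    if len(src_points) < 4:
        return image

    # Apply RANSAC filtering
    _, mask = cv2.findHomography(dst_points, src_points, method=cv2.RANSAC, ransacReprojThreshold=5.0)
    if mask is None:
        return image  # RANSAC could not find a model
    inlier_mask = mask.ravel() == 1
    filtered_src = src_points[inlier_mask]
    filtered_dst = dst_points[inlier_mask]

    if len(filtered_src) < 4:
        return image  # Not enough inliers for TPS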

    # Apply Thin Plate Splines using RBF
    rbf_x = Rbf(filtered_dst[:, 0], filtered_dst[:, 1], filtered_src[:, 0] - filtered_dst[:, 0], function="thin_plate")
    rbf_y = Rbf(filtered_dst[:, 0], filtered_dst[:, 1], filtered_src[:, 1] - filtered_dst[:, 1], function="thin_plate")

    # Generate a low-res grid for memory-efficient TPS
    small_grid_x, small_grid_y = np.meshgrid(np.linspace(0, small_w - 1, small_w), np.linspace(0, small_h - 1, small_h))

    # Compute displacement fields at low resolution
    disp_x_small = rbf_x(small_grid_x, small_grid_y).astype(np.float32)
    disp_y_small = rbf_y(small_grid_x, small_grid_y).astype(np.float32)

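    # Upscale displacement fields to full resolution.
    # The displacements were measured in low-resolution pixels, so their magnitudes
    # must also be divided by scale_factor to express them in full-resolution pixels.
    disp_x_large = cv2.resize(disp_x_small, (w, h), interpolation=cv2.INTER_CUBIC) / scale_factor
    disp_y_large = cv2.resize(disp_y_small, (w, h), interpolation=cv2.INTER_CUBIC) / scale_factor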

    # Compute final absolute positions
    grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
    transformed_x = np.clip(grid_x + disp_x_large, 0, w - 1)
    transformed_y = np.clip(grid_y + disp_y_large, 0, h - 1)

    # Apply transformation to the original image
    warped_large = cv2.remap(image, transformed_x.astype(np.float32), transformed_y.astype(np.float32), interpolation=cv2.INTER_LINEAR)

    return warped_large
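`cv2.remap` performs a backward warp: for every output pixel it looks up where to sample in the input image. That is why the RBFs above are fit on the destination (landmark) points and predict the offset back to the source points (`src - dst`): the field is evaluated on the output grid. The sketch below reproduces that convention in plain NumPy with nearest-neighbor sampling; `remap_nearest` is a hypothetical helper, and `cv2.remap` additionally interpolates and handles borders.

```python
import numpy as np


def remap_nearest(image: np.ndarray, map_x: np.ndarray, map_y: np.ndarray) -> np.ndarray:
    """Backward warp: output[y, x] = image[map_y[y, x], map_x[y, x]] (nearest neighbor)."""
    h, w = image.shape[:2]
    # Round to the nearest source pixel and clamp to the image, like the clipping above
    xs = np.clip(np.rint(map_x).astype(int), 0, w - 1)
    ys = np.clip(np.rint(map_y).astype(int), 0, h - 1)
    return image[ys, xs]


image = np.arange(12).reshape(3, 4)
gy, gx = np.mgrid[0:3, 0:4]
# A displacement of +1 in x makes each output pixel sample its right neighbor,
# so the content appears shifted one pixel to the left
shifted = remap_nearest(image, gx + 1.0, gy.astype(float))
print(shifted)
```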


In [None]:
%%memit
for sheet, yolo_bbs in yolo_data.items():
    # get path to current image
    full_image_path = os.path.join(PATH_TO_REGISTERED_IMAGES, sheet)
    image = cv2.imread(full_image_path)
    resized_img = cv2.resize(image, (DESIRED_IMAGE_WIDTH, DESIRED_IMAGE_HEIGHT))
    # get the sheet's bounding boxes
    sheet_bbs = [
        BoundingBox.from_yolo(bb, DESIRED_IMAGE_WIDTH, DESIRED_IMAGE_HEIGHT)
        for bb in yolo_bbs
    ]

    print(len(sheet_bbs))
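    transformed_img = tps_transform(resized_img, sheet_bbs, landmarks)
    # OpenCV arrays are BGR; convert to RGB before handing them to PIL
    transformed_img = Image.fromarray(cv2.cvtColor(transformed_img, cv2.COLOR_BGR2RGB))
    transformed_img.show()
    # Show the original (resized) image for comparison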
    cv2.imshow("image", resized_img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


304


In [66]:
def tps_transform_ransac(image: np.ndarray, src_bbs: List[BoundingBox], dst_bbs: List[BoundingBox]) -> np.ndarray:
    """
    Perform a thin plate spline transformation on the image using the src_bbs and dst_bbs, using RANSAC to filter out outliers.

    Args:
        image (np.ndarray): The image to be transformed
        src_bbs (List[BoundingBox]): List of BoundingBox objects
        dst_bbs (List[BoundingBox]): List of BoundingBox objects

    Returns:
        np.ndarray: The transformed image
    """
    # get the categories from dst_bbs
    landmark_cats = [bb.category for bb in dst_bbs]
    # remove all bbs in src that are not in those categories
    src_bbs = [bb for bb in src_bbs if bb.category in landmark_cats]
    # get list of duplicate keys
    duplicate_count_src = dict(Counter([bb.category for bb in src_bbs]))
    duplicates = [k for k, v in duplicate_count_src.items() if v > 1]
    duplicate_count_dst = dict(Counter([bb.category for bb in dst_bbs]))
    duplicates.extend([k for k, v in duplicate_count_dst.items() if v > 1])
    duplicates = list(set(duplicates))
    # print(duplicates)
    # remove duplicates
    src_bbs = [bb for bb in src_bbs if bb.category not in duplicates]
    dst_bbs = [bb for bb in dst_bbs if bb.category not in duplicates]
    # sort categories alphabetically
    src_bbs = sorted(src_bbs, key=lambda bb: bb.category)
    # print([bb.category for bb in src_bbs])
    dst_bbs = sorted(dst_bbs, key=lambda bb: bb.category)
    # print([bb.category for bb in dst_bbs])

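    src_bbs = np.array([[bb.left, bb.top] for bb in src_bbs], dtype=np.float32)
    dst_bbs = np.array([[bb.left, bb.top] for bb in dst_bbs], dtype=np.float32)

    # RANSAC homography fitting needs at least 4 correspondences
    if len(src_bbs) < 4:
        return image

    _, mask = cv2.findHomography(
        dst_bbs,
        src_bbs,
        method=cv2.RANSAC,
        ransacReprojThreshold=10.0,
        maxIters=5000,
        confidence=0.99,
    )
    if mask is None:
        return image  # RANSAC could not find a model
    inlier_mask = mask.ravel() == 1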

    filtered_src = src_bbs[inlier_mask]
    filtered_dst = dst_bbs[inlier_mask]

    new_src_x, new_src_y = filtered_src[:, 0], filtered_src[:, 1]
    new_dst_x, new_dst_y = filtered_dst[:, 0], filtered_dst[:, 1]

    if len(new_src_x) < 4:
        return image

    # use RBF function to do the thin plate splines
    rbf_x = Rbf(new_dst_x, new_dst_y, new_src_x, function="thin_plate")
    rbf_y = Rbf(new_dst_x, new_dst_y, new_src_y, function="thin_plate")

    # Alter the image according to the transformation
    h, w, _ = image.shape
    # create grid
    x = np.linspace(0, w - 1, w)
    y = np.linspace(0, h - 1, h)
    grid_x, grid_y = np.meshgrid(x, y)

    # Evaluate the TPS at every pixel; the result keeps the (h, w) shape of the grid
    transformed_x = rbf_x(grid_x, grid_y).astype(np.float32)
    transformed_y = rbf_y(grid_x, grid_y).astype(np.float32)

    transformed_x = np.clip(transformed_x, 0, image.shape[1] - 1)
    transformed_y = np.clip(transformed_y, 0, image.shape[0] - 1)

    # warp the image
    warp_img = cv2.remap(
        image, transformed_x, transformed_y, interpolation=cv2.INTER_LINEAR
    )

    return warp_img

In [77]:
%%memit
for sheet, yolo_bbs in yolo_data.items():
    if sheet != "RC_0031_intraoperative.JPG":
        continue
    # get path to current image
    full_image_path = os.path.join(PATH_TO_REGISTERED_IMAGES, sheet)
    image = cv2.imread(full_image_path)
    resized_img = cv2.resize(image, (DESIRED_IMAGE_WIDTH, DESIRED_IMAGE_HEIGHT))
    # get the sheet's bounding boxes
    sheet_bbs = [
        BoundingBox.from_yolo(bb, DESIRED_IMAGE_WIDTH, DESIRED_IMAGE_HEIGHT)
        for bb in yolo_bbs
    ]
    transformed_img_arr = tps_transform_ransac(resized_img, sheet_bbs, landmarks)

    # Overlay the images on the unified template as a sanity check that the transformation is correct.
    # OpenCV arrays are BGR: blend in BGR, then convert to RGB only when handing the result to PIL.
    unified = cv2.imread("../../data/unified_intraoperative_preoperative_flowsheet_v1_1_front.png")
    resized_unified = cv2.resize(unified, (DESIRED_IMAGE_WIDTH, DESIRED_IMAGE_HEIGHT))
    # Overlay before the TPS transformation
    overlay = cv2.addWeighted(resized_unified, 0.5, resized_img, 0.5, 0)
    Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)).show(title="Original Overlay")
    # Overlay after the TPS transformation
    overlay = cv2.addWeighted(resized_unified, 0.5, transformed_img_arr, 0.5, 0)
    Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)).show(title="Transformed Overlay")


peak memory: 902.11 MiB, increment: 743.00 MiB
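One plausible contributor to this memory use is that evaluating scipy's `Rbf` on a full-resolution grid builds a dense matrix of distances between every query pixel and every control point (the `_call_norm` and `_h_thin_plate` entries in the profiles further down). The back-of-the-envelope function below estimates the size of one such float64 matrix; the count of 50 control points is an illustrative assumption, not a measured value, and the real peak also includes temporaries and the separate x and y evaluations.

```python
def rbf_distance_matrix_mib(width: int, height: int, n_points: int, bytes_per_value: int = 8) -> float:
    """Size in MiB of one dense (pixels x control points) matrix of float64 values."""
    return width * height * n_points * bytes_per_value / 2**20


# Assumed: 50 control points (illustrative only)
full_res = rbf_distance_matrix_mib(800, 600, 50)   # full 800x600 grid
low_res = rbf_distance_matrix_mib(200, 150, 50)    # grid at scale_factor = 0.25
print(f"{full_res:.1f} MiB vs {low_res:.1f} MiB")
```

Because the matrix grows with the number of pixels, evaluating the warp on a grid downscaled by a factor $s$ shrinks it by $s^2$, which is the motivation for the low-resolution displacement field used in `tps_transform`.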


### Merging these two methods and preparing them for the extractor

I'm going to create private helper functions that do the RANSAC filtering and the distance filtering.
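RANSAC works by repeatedly fitting a model to a minimal random subset of correspondences, counting how many of all the correspondences agree with that model within a threshold, and keeping the largest agreeing (inlier) set. `cv2.findHomography(..., method=cv2.RANSAC)` does this with a homography and a 4-point minimal sample. For brevity, the sketch below swaps in a 2-D affine model with a 3-point sample; `ransac_affine_inliers` is a hypothetical name and the point data is made up.

```python
import numpy as np


def ransac_affine_inliers(src: np.ndarray, dst: np.ndarray, threshold: float = 5.0,
                          iters: int = 200, seed: int = 0) -> np.ndarray:
    """Boolean inlier mask for dst ~= affine(src), found with RANSAC (affine, not homography)."""
    rng = np.random.default_rng(seed)
    n = len(src)
    X = np.hstack([src, np.ones((n, 1))])  # one row [x, y, 1] per point
    best_mask = np.zeros(n, dtype=bool)
    for _ in range(iters):
        sample = rng.choice(n, size=3, replace=False)
        # Fit a 3x2 affine parameter matrix to the minimal sample
        params, *_ = np.linalg.lstsq(X[sample], dst[sample], rcond=None)
        # Count every correspondence the candidate model explains within the threshold
        residuals = np.linalg.norm(X @ params - dst, axis=1)
        mask = residuals < threshold
        if mask.sum() > best_mask.sum():
            best_mask = mask
    return best_mask


src = np.array([[3, 7], [95, 12], [18, 88], [102, 97], [47, 53],
                [25, 61], [78, 29], [36, 14], [64, 72], [9, 45]], dtype=float)
A = np.array([[1.02, 0.01], [-0.02, 0.98]])
dst = src @ A.T + np.array([3.0, -2.0])
dst[3] += [40.0, 0.0]   # two gross mismatches, e.g. wrong detections
dst[7] += [0.0, -35.0]
print(ransac_affine_inliers(src, dst))
```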


In [4]:
# Start with turning Hannah's code into a function
def __filter_by_distance(
    src_bbs: List[BoundingBox],
    dst_bbs: List[BoundingBox],
    threshold: float,
) -> Tuple[List[BoundingBox], List[BoundingBox]]:
    """
    Filter out source/destination pairs whose top-left corners differ by `threshold` or more along either axis.
    Assumes src_bbs[i] and dst_bbs[i] describe the same landmark.
    Large displacements are likely to be outliers or erroneous matches; after homography we only expect small
    tweaks from the thin plate spline transformation. Homography should already be completed prior to TPS.

    Args:
        src_bbs (List[BoundingBox]): The source points
        dst_bbs (List[BoundingBox]): The destination points
        threshold (float): The threshold distance

    Returns:
        Tuple[List[BoundingBox], List[BoundingBox]]: The filtered source and destination points

    """
    filtered_points = [
        (src_bb, dst_bb)
        for src_bb, dst_bb in zip(src_bbs, dst_bbs)
        if abs(src_bb.top - dst_bb.top) < threshold
        and abs(src_bb.left - dst_bb.left) < threshold
    ]

    new_src_bbs, new_dst_bbs = zip(*filtered_points) if filtered_points else ([], [])

    return list(new_src_bbs), list(new_dst_bbs)


# Now we can turn Matt's RANSAC code into a function
def __filter_by_RANSAC(
    src_bbs: List[BoundingBox],
    dst_bbs: List[BoundingBox],
    threshold: float,
    max_iters: int = 5000,
    confidence_limit: float = 0.99,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
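    """
    Filter out source and destination points that are not inliers according to RANSAC.
    Correspondences are the centers of the bounding boxes; src_bbs[i] and dst_bbs[i] must describe the same landmark.

    Args:
        src_bbs (List[BoundingBox]): The source bounding boxes
        dst_bbs (List[BoundingBox]): The destination bounding boxes
        threshold (float): The RANSAC reprojection threshold in pixels
        max_iters (int, optional): The maximum number of iterations for RANSAC. Defaults to 5000.
        confidence_limit (float, optional): The confidence limit for RANSAC. Defaults to 0.99.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: The filtered source and destination points as numpy arrays in this order (src_x, src_y, dst_x, dst_y).
            All four arrays are empty if there are fewer than 4 points or RANSAC fails.
    """
    # Turn the BoundingBox objects into numpy arrays of center coordinates
    src_points = np.array([[(bb.left + bb.right) / 2, (bb.top + bb.bottom) / 2] for bb in src_bbs], dtype=np.float32)
    dst_points = np.array([[(bb.left + bb.right) / 2, (bb.top + bb.bottom) / 2] for bb in dst_bbs], dtype=np.float32)

    # A homography needs at least 4 correspondences
    empty = np.empty(0, dtype=np.float32)
    if len(src_points) < 4:
        return empty, empty, empty, empty

    # Complete RANSAC
    _, mask = cv2.findHomography(
        dst_points,
        src_points,
        method=cv2.RANSAC,
        ransacReprojThreshold=threshold,
        maxIters=max_iters,
        confidence=confidence_limit,
    )
    if mask is None:
        return empty, empty, empty, empty  # RANSAC could not find a model
    inlier_mask = mask.ravel() == 1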

    # Apply the mask to the source and destination points
    filtered_src = src_points[inlier_mask]
    filtered_dst = dst_points[inlier_mask]

    # Get the x and y coordinates of the filtered points
    src_x, src_y = filtered_src[:, 0], filtered_src[:, 1]
    dst_x, dst_y = filtered_dst[:, 0], filtered_dst[:, 1]

    return src_x, src_y, dst_x, dst_y

In [None]:
def transform_thin_plate_splines(
    image: np.ndarray,
    src_bbs: List[BoundingBox],
    dst_bbs: List[BoundingBox],
    threshold: float = 10.0,
    downscale_factor: float = 0.5,
    ransac_threshold: float = 5.0,
    max_ransac_iters: int = 5000,
    confidence_limit: float = 0.99,
) -> np.ndarray:
    """
    Perform a memory-efficient thin plate spline transformation with distance and RANSAC filtering.

    Args:
        image (np.ndarray): The image to be transformed.
        src_bbs (List[BoundingBox]): List of source BoundingBox objects.
        dst_bbs (List[BoundingBox]): List of destination BoundingBox objects.
        threshold (float): The threshold distance for filtering out points (default: 10.0).
        downscale_factor (float): Factor to downscale the image for computing warp (default: 0.5).
        ransac_threshold (float): The threshold distance for RANSAC filtering (default: 5.0).
        max_ransac_iters (int): The maximum number of iterations for RANSAC (default: 5000).
        confidence_limit (float): The confidence limit for RANSAC (default: 0.99).

    Returns:
        np.ndarray: The transformed image at original size
    """

    # Get image dimensions and downscale to new dimensions
    h, w = image.shape[:2]
    small_h, small_w = int(h * downscale_factor), int(w * downscale_factor)

    # Helper function to scale bounding boxes down by a factor
    #   Return bounding boxes as a list of BoundingBox objects
    def scale_bbs(bbs, factor):
        return [
            BoundingBox(
                category=bb.category,
                left=bb.left * factor,
                top=bb.top * factor,
                right=bb.right * factor,
                bottom=bb.bottom * factor,
            ) for bb in bbs
        ]

    # Scale source and destination bounding boxes (destinations are landmarks)
    scaled_src_bbs = scale_bbs(src_bbs, downscale_factor)
    scaled_dst_bbs = scale_bbs(dst_bbs, downscale_factor)

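    # Keep only categories present in both lists, so every source box has a landmark partner
    shared_cats = {bb.category for bb in scaled_src_bbs} & {bb.category for bb in scaled_dst_bbs}

    # Remove duplicates: categories that appear more than once in either list are ambiguous
    duplicate_cats = {k for k, v in Counter([bb.category for bb in scaled_src_bbs]).items() if v > 1}
    duplicate_cats |= {k for k, v in Counter([bb.category for bb in scaled_dst_bbs]).items() if v > 1}

    # Filter out unmatched and duplicated categories
    keep_cats = shared_cats - duplicate_cats
    scaled_src_bbs = [bb for bb in scaled_src_bbs if bb.category in keep_cats]
    scaled_dst_bbs = [bb for bb in scaled_dst_bbs if bb.category in keep_cats]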

    # Sort bounding boxes by category
    scaled_src_bbs.sort(key=lambda bb: bb.category)
    scaled_dst_bbs.sort(key=lambda bb: bb.category)

    # Apply distance filtering
    # Remove points where the distance between source and destination is greater than the threshold
    #   Defaults to 10 pixels, measured in the downscaled image (10 / downscale_factor pixels at full resolution)
    scaled_src_bbs, scaled_dst_bbs = __filter_by_distance(scaled_src_bbs, scaled_dst_bbs, threshold)

    # Convert to numpy arrays
    # The points are the centers of the bounding boxes
    src_points = np.array([[(bb.left + bb.right) / 2, (bb.top + bb.bottom) / 2] for bb in scaled_src_bbs], dtype=np.float32)
    dst_points = np.array([[(bb.left + bb.right) / 2, (bb.top + bb.bottom) / 2] for bb in scaled_dst_bbs], dtype=np.float32)

    # Apply RANSAC filtering only if there are at least 4 points to use
    if len(src_points) >= 4:
        # Filter using ransac
        src_x, src_y, dst_x, dst_y = __filter_by_RANSAC(
            scaled_src_bbs, scaled_dst_bbs, 
            threshold=ransac_threshold, 
            max_iters=max_ransac_iters, 
            confidence_limit=confidence_limit
        )

        # If not enough inliers remain, return the original image
        if len(src_x) < 4:
            return image

        # Update src and dst points
        src_points = np.column_stack((src_x, src_y))
        dst_points = np.column_stack((dst_x, dst_y))
    else:
        return image  # Not enough points for TPS

    # Compute displacement field using TPS
    # This uses a radial basis function with a thin plate spline kernel to compute the displacement field
    #   rbf_x and rbf_y are the radial basis functions for the x and y displacements respectively
    # Evaluating these functions at every pixel gives the displacement used to warp the image

    rbf_x = Rbf(dst_points[:, 0], dst_points[:, 1], src_points[:, 0] - dst_points[:, 0], function='thin_plate')
    rbf_y = Rbf(dst_points[:, 0], dst_points[:, 1], src_points[:, 1] - dst_points[:, 1], function='thin_plate')

    # Generate grid for transformation
    # This has a grid for the location of each pixel in the image
    y, x = np.mgrid[0:small_h, 0:small_w].astype(np.float32)

    # Compute displacements for each pixel
    # This has data on the amount of movement of *each pixel* in the x and y directions
    dx = rbf_x(x, y).astype(np.float32)
    dy = rbf_y(x, y).astype(np.float32)

    # Final coordinate mapping
    # Here we do the actual movement of each pixel
    map_x = (x + dx)
    map_y = (y + dy)

    # Scale back to original resolution: upsample the maps, then convert the
    # low-resolution coordinates they contain into full-resolution pixel coordinates
    map_x = cv2.resize(map_x, (w, h)) * (1 / downscale_factor)
    map_y = cv2.resize(map_y, (w, h)) * (1 / downscale_factor)

    # Clip to image boundaries so no pixel samples from outside the image
    #   Pixels that would sample outside the image take the nearest edge value instead
    map_x = np.clip(map_x, 0, w - 1)
    map_y = np.clip(map_y, 0, h - 1)

    # Apply warping using bilinear interpolation
    #   cv2.remap looks up each output pixel at (map_x, map_y) in the input and interpolates its value
    #   BORDER_REFLECT controls how neighbors past the image edge are filled during interpolation
    warped = cv2.remap(image, map_x.astype(np.float32), map_y.astype(np.float32), interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)

    return warped


# Function to visualize the transformation
def visualize_transformation(image, transformed_image):
    cv2.imshow("Original Image", image)
    cv2.imshow("Transformed Image", transformed_image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


In [None]:
# Now we need to merge the two functions into one that uses RANSAC to filter out outliers as well as those far from the destination points
def transform_thin_plate_splines_old(
    image: np.ndarray,
    src_bbs: List[BoundingBox],
    dst_bbs: List[BoundingBox],
    max_dist: float = 4.0,
    threshold: float = 10.0,
    max_iters: int = 5000,
    confidence_limit: float = 0.99,
) -> np.ndarray:
    """
    Perform a thin plate spline transformation on the image using the src_bbs and dst_bbs, using RANSAC to filter out outliers.
    We assume that homography was completed prior to calling this function.
    We start by filtering by points that are too far from their destination counterparts, then we use RANSAC to filter out outliers.

    Args:
        image (np.ndarray): The image to be transformed
        src_bbs (List[BoundingBox]): List of BoundingBox objects
        dst_bbs (List[BoundingBox]): List of BoundingBox objects
        max_dist (float, optional): The maximum distance for filtering out points. Defaults to 4.0.
        threshold (float, optional): The threshold distance for RANSAC. Defaults to 10.0.
        max_iters (int, optional): The maximum number of iterations for RANSAC. Defaults to 5000.
        confidence_limit (float, optional): The confidence limit for RANSAC. Defaults to 0.99.

    Returns:
        np.ndarray: The transformed image
    """
    # get the categories from dst_bbs
    landmark_cats = [bb.category for bb in dst_bbs]
    # remove all bbs in src that are not in those categories
    src_bbs = [bb for bb in src_bbs if bb.category in landmark_cats]
    # get list of duplicate keys
    duplicate_count_src = dict(Counter([bb.category for bb in src_bbs]))
    duplicates = [k for k, v in duplicate_count_src.items() if v > 1]
    duplicate_count_dst = dict(Counter([bb.category for bb in dst_bbs]))
    duplicates.extend([k for k, v in duplicate_count_dst.items() if v > 1])
    duplicates = list(set(duplicates))
    # print(duplicates)
    # remove duplicates
    src_bbs = [bb for bb in src_bbs if bb.category not in duplicates]
    dst_bbs = [bb for bb in dst_bbs if bb.category not in duplicates]

    # sort categories alphabetically
    src_bbs = sorted(src_bbs, key=lambda bb: bb.category)
    dst_bbs = sorted(dst_bbs, key=lambda bb: bb.category) # These are your landmarks

    # remove source points with suspiciously high distances to their destination counterparts
    src_bbs, dst_bbs = __filter_by_distance(src_bbs, dst_bbs, max_dist)

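    # Now let's use RANSAC to filter out outliers
    src_x, src_y, dst_x, dst_y = __filter_by_RANSAC(
        src_bbs, dst_bbs, threshold, max_iters, confidence_limit
    )

    # TPS needs enough inliers to be meaningful; otherwise return the image unchanged
    if len(src_x) < 4:
        return image

    # use RBF function to do the thin plate splines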
    rbf_x = Rbf(dst_x, dst_y, src_x, function="thin_plate")
    rbf_y = Rbf(dst_x, dst_y, src_y, function="thin_plate")

    # Alter the image according to the transformation
    h, w, _ = image.shape
    # create grid
    x = np.linspace(0, w - 1, w)
    y = np.linspace(0, h - 1, h)
    grid_x, grid_y = np.meshgrid(x, y)

    # Evaluate the TPS at every pixel; the result keeps the (h, w) shape of the grid
    transformed_x = rbf_x(grid_x, grid_y).astype(np.float32)
    transformed_y = rbf_y(grid_x, grid_y).astype(np.float32)

    transformed_x = np.clip(transformed_x, 0, image.shape[1] - 1)
    transformed_y = np.clip(transformed_y, 0, image.shape[0] - 1)

    # warp the image
    warp_img = cv2.remap(
        image, transformed_x, transformed_y, interpolation=cv2.INTER_LINEAR
    )

    return warp_img

In [19]:
import cProfile
import pstats
from pstats import SortKey
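The profiling cells below repeat the same enable/disable/print pattern inline. A small wrapper such as this hypothetical `profile_to_text` returns the function's result together with the report as a string (via `pstats.Stats(..., stream=...)`), which makes it easier to keep the old and new TPS reports side by side.

```python
import cProfile
import io
import pstats
from pstats import SortKey


def profile_to_text(func, *args, top_n: int = 10, **kwargs):
    """Run func under cProfile; return (result, report of the top_n entries by cumulative time)."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        result = func(*args, **kwargs)
    finally:
        profiler.disable()
    stream = io.StringIO()
    pstats.Stats(profiler, stream=stream).sort_stats(SortKey.CUMULATIVE).print_stats(top_n)
    return result, stream.getvalue()


def slow_sum(n: int) -> int:
    """Toy workload standing in for a TPS call."""
    return sum(i * i for i in range(n))


total, report = profile_to_text(slow_sum, 1000)
print(report)
```

In the notebook this would be called as, for example, `profile_to_text(transform_thin_plate_splines, resized_img, sheet_bbs, landmarks, top_n=20)`.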

In [None]:
%%memit
# Runs the new function
for sheet, yolo_bbs in yolo_data.items():
    if sheet != "RC_0031_intraoperative.JPG":
        continue
    # get path to current image
    full_image_path = os.path.join(PATH_TO_REGISTERED_IMAGES, sheet)
    image = cv2.imread(full_image_path)
    resized_img = cv2.resize(image, (DESIRED_IMAGE_WIDTH, DESIRED_IMAGE_HEIGHT))

    # Convert the resized image from BGR (OpenCV) to RGB (PIL) and show it
    resized_img_ = Image.fromarray(cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB))
    resized_img_.show()

    # get the sheet's bounding boxes
    sheet_bbs = [
        BoundingBox.from_yolo(bb, DESIRED_IMAGE_WIDTH, DESIRED_IMAGE_HEIGHT)
        for bb in yolo_bbs
    ]

    # Run the profiler
    profiler = cProfile.Profile()
    profiler.enable()
    transformed_img_arr = transform_thin_plate_splines(resized_img, sheet_bbs, landmarks)
    profiler.disable()

    # Print stats sorted by cumulative time
    stats = pstats.Stats(profiler).sort_stats(SortKey.CUMULATIVE)
    stats.print_stats(20)  # Show top 20 functions by time
    
    # OpenCV arrays are BGR; convert to RGB before handing them to PIL
    Image.fromarray(cv2.cvtColor(transformed_img_arr, cv2.COLOR_BGR2RGB)).show(title="Transformed Image")

    # Overlay the images on the unified template as a sanity check that the transformation is correct.
    # Blend in BGR, then convert to RGB only for display.
    unified = cv2.imread("../../data/unified_intraoperative_preoperative_flowsheet_v1_1_front.png")
    resized_unified = cv2.resize(unified, (DESIRED_IMAGE_WIDTH, DESIRED_IMAGE_HEIGHT))
    # Overlay before the TPS transformation
    overlay = cv2.addWeighted(resized_unified, 0.5, resized_img, 0.5, 0)
    Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)).show(title="Original Overlay")
    # Overlay after the TPS transformation
    overlay = cv2.addWeighted(resized_unified, 0.5, transformed_img_arr, 0.5, 0)
    Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)).show(title="Transformed Overlay")


         1792 function calls in 0.252 seconds

   Ordered by: cumulative time
   List reduced from 95 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.012    0.012    0.251    0.251 C:\Users\15406\AppData\Local\Temp\ipykernel_12712\1017396507.py:1(transform_thin_plate_splines)
        2    0.007    0.004    0.217    0.109 c:\Users\15406\Coding-Projects\Paper Chart Extraction\Supplements\.venv\Lib\site-packages\scipy\interpolate\_rbf.py:280(__call__)
        4    0.165    0.041    0.165    0.041 c:\Users\15406\Coding-Projects\Paper Chart Extraction\Supplements\.venv\Lib\site-packages\scipy\interpolate\_rbf.py:169(_h_thin_plate)
        2    0.000    0.000    0.043    0.021 c:\Users\15406\Coding-Projects\Paper Chart Extraction\Supplements\.venv\Lib\site-packages\scipy\interpolate\_rbf.py:277(_call_norm)
        2    0.000    0.000    0.043    0.021 c:\Users\15406\Coding-Projects\Paper Chart Extraction\Supplements\.venv\

In [None]:
%%memit
# Runs the old version of the function
for sheet, yolo_bbs in yolo_data.items():
    if sheet != "RC_0031_intraoperative.JPG":
        continue
    # get path to current image
    full_image_path = os.path.join(PATH_TO_REGISTERED_IMAGES, sheet)
    image = cv2.imread(full_image_path)
    resized_img = cv2.resize(image, (DESIRED_IMAGE_WIDTH, DESIRED_IMAGE_HEIGHT))

    # Convert the resized image from BGR (OpenCV) to RGB (PIL) and show it
    resized_img_ = Image.fromarray(cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB))
    resized_img_.show()

    # get the sheet's bounding boxes
    sheet_bbs = [
        BoundingBox.from_yolo(bb, DESIRED_IMAGE_WIDTH, DESIRED_IMAGE_HEIGHT)
        for bb in yolo_bbs
    ]

    # Run the profiler
    profiler = cProfile.Profile()
    profiler.enable()
    transformed_img_arr = transform_thin_plate_splines_old(resized_img, sheet_bbs, landmarks)
    profiler.disable()

    # Print stats sorted by cumulative time
    stats = pstats.Stats(profiler).sort_stats(SortKey.CUMULATIVE)
    stats.print_stats(20)  # Show top 20 functions by time
    
    # OpenCV arrays are BGR; convert to RGB before handing them to PIL
    Image.fromarray(cv2.cvtColor(transformed_img_arr, cv2.COLOR_BGR2RGB)).show(title="Transformed Image")

    # Overlay the images on the unified template as a sanity check that the transformation is correct.
    # Blend in BGR, then convert to RGB only for display.
    unified = cv2.imread("../../data/unified_intraoperative_preoperative_flowsheet_v1_1_front.png")
    resized_unified = cv2.resize(unified, (DESIRED_IMAGE_WIDTH, DESIRED_IMAGE_HEIGHT))
    # Overlay before the TPS transformation
    overlay = cv2.addWeighted(resized_unified, 0.5, resized_img, 0.5, 0)
    Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)).show(title="Original Overlay")
    # Overlay after the TPS transformation
    overlay = cv2.addWeighted(resized_unified, 0.5, transformed_img_arr, 0.5, 0)
    Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)).show(title="Transformed Overlay")


         584 function calls in 0.825 seconds

   Ordered by: cumulative time
   List reduced from 105 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.013    0.013    0.825    0.825 C:\Users\15406\AppData\Local\Temp\ipykernel_12712\4190479871.py:2(transform_thin_plate_splines_old)
        2    0.026    0.013    0.797    0.399 c:\Users\15406\Coding-Projects\Paper Chart Extraction\Supplements\.venv\Lib\site-packages\scipy\interpolate\_rbf.py:280(__call__)
        4    0.600    0.150    0.600    0.150 c:\Users\15406\Coding-Projects\Paper Chart Extraction\Supplements\.venv\Lib\site-packages\scipy\interpolate\_rbf.py:169(_h_thin_plate)
        2    0.000    0.000    0.162    0.081 c:\Users\15406\Coding-Projects\Paper Chart Extraction\Supplements\.venv\Lib\site-packages\scipy\interpolate\_rbf.py:277(_call_norm)
        2    0.000    0.000    0.162    0.081 c:\Users\15406\Coding-Projects\Paper Chart Extraction\Supplements\.v