# Thin Plate Splines Transformations

**Purpose:** The purpose of this experiment is to implement thin plate spline (TPS) transformations on images. The goal is to help normalize images of charts that have been folded or bent. TPS registration will be performed post-homography.

## What are Thin Plate Splines

For our purposes, a thin plate spline transformation can be defined as follows:  
Given a set of "source" points $P_s$ and a corresponding set of "destination" points $P_d$, the thin plate spline transformation is the function $f_{tps}$ such that  
$f_{tps}(p_s) = p_d$ for every $p_s \in P_s$ and its corresponding $p_d \in P_d$, and which "bends" all other points the least among all such differentiable functions (meaning it minimizes a bending energy function).  
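
For reference, in the standard two-dimensional formulation (stated here as background, not taken from this notebook's code), each output coordinate of $f_{tps}$ is an affine part plus a weighted sum of radial basis functions centred on the $n$ control points $p_i \in P_s$, with $p = (x, y)$:

$$
f(p) = a_0 + a_x x + a_y y + \sum_{i=1}^{n} w_i \, U(\lVert p - p_i \rVert), \qquad U(r) = r^2 \log r
$$

The "bending" being minimized is the energy

$$
E(f) = \iint \left( \left(\frac{\partial^2 f}{\partial x^2}\right)^2 + 2\left(\frac{\partial^2 f}{\partial x \, \partial y}\right)^2 + \left(\frac{\partial^2 f}{\partial y^2}\right)^2 \right) dx \, dy
$$

The coefficients $a_0, a_x, a_y$ and $w_i$ are chosen so that the interpolation constraints at the control points hold.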

![thin_plate_spline_example](https://github.com/user-attachments/assets/9f269c79-6d5d-4090-988c-39ee9e928ff0)

In this image there are two sets of points which form fish shapes. The red '+' points represent the destination points, and the green 'o' points represent the source points. On the right we see the thin plate spline deformation which makes all the green points exactly match all the red points, and we can also see a grid that shows where all other points on the plane will be mapped as well.  

## Benefits for Image Registration
While the homography is a very reliable transformation, it is limited to projective transformations (e.g. rotation, scaling, shear, perspective), all of which keep straight lines straight.  
Very commonly, pages will be creased or folded, the edges of pages will curl inwards or outwards, or some cameras will introduce barrel/pincushion distortion.  
This leads to a whole class of non-linear distortions that a homography cannot correct, but that must still be accounted for.
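
The limitation can be seen numerically: a homography always maps straight lines to straight lines, so a creased edge stays creased. The sketch below is illustrative only. The matrix `H` and the sample points are made up, and `apply_homography` and `is_collinear` are hypothetical helpers, not part of this project.

```python
import numpy as np


def apply_homography(H, pts):
    """Apply a 3x3 homography to an (N, 2) array of points."""
    pts_h = np.hstack([pts, np.ones((len(pts), 1))])  # to homogeneous coordinates
    mapped = pts_h @ H.T
    return mapped[:, :2] / mapped[:, 2:3]  # back to Cartesian coordinates


def is_collinear(pts, tol=1e-6):
    """True if three 2D points lie on one straight line (zero triangle area)."""
    (x0, y0), (x1, y1), (x2, y2) = pts
    return abs((x1 - x0) * (y2 - y0) - (x2 - x0) * (y1 - y0)) < tol


# Made-up homography with rotation-like, shear, translation and perspective terms.
H = np.array([[0.9, 0.2, 30.0],
              [-0.1, 1.1, 10.0],
              [1e-4, 2e-4, 1.0]])

straight = np.array([[0.0, 0.0], [100.0, 50.0], [200.0, 100.0]])  # points on one line
creased = np.array([[0.0, 0.0], [100.0, 60.0], [200.0, 100.0]])   # middle point pushed off the line

print(is_collinear(apply_homography(H, straight)))  # the straight edge stays straight
print(is_collinear(apply_homography(H, creased)))   # the creased edge is not straightened
```

No choice of `H` changes this outcome, which is why a non-linear model such as TPS is applied after the homography.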

![smaller_RC_0033_preoperative_postoperative](https://github.com/user-attachments/assets/f22374e1-75b4-4a07-816f-d3ad51e84921)
An example of a paper that has been folded and unfolded, causing non-linear distortions. This is a *very* modest example. In practice, we have seen photographs of papers that are raised at least an inch off the table in certain areas.

In [2]:
# imports
import os
import cv2
import numpy as np
import json
from pathlib import Path
from typing import Dict, List, Tuple
from utils.annotations import BoundingBox
from collections import Counter
from PIL import Image
from scipy.interpolate import Rbf
import matplotlib.pyplot as plt

Load the data for testing:
- "intraop_document_landmarks.json"
    - Used for the destination points in the transformation. We are using the landmarks from the unified image
    - Also has landmarks for each of the images. We currently are not using them
- "yolo_data.json"
    - Used for the source points in the transformation.
    - May need to be replaced with the landmarks from the other file

In [3]:
# Load yolo_data.json which will be used as src_points
PATH_TO_YOLO_DATA = "../../data/yolo_data.json"
PATH_TO_REGISTERED_IMAGES = "../../data/registered_images"

with open(PATH_TO_YOLO_DATA) as json_file:
    yolo_data = json.load(json_file)

print(f"Found {len(yolo_data)} sheets in yolo_data.json")

# load intraop_document_landmarks.json which will be used as dst_points
PATH_TO_LANDMARKS = "../../data/intraop_document_landmarks.json"

DESIRED_IMAGE_WIDTH = 800
DESIRED_IMAGE_HEIGHT = 600

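def label_studio_to_bboxes(path_to_json_data: Path) -> Dict[str, List[BoundingBox]]:
    """Map each image name to its Label Studio boxes, scaled from percentages to pixels."""
    with open(path_to_json_data) as json_file:
        json_data: List[Dict] = json.load(json_file)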
    return {
        sheet_data['data']['image'].split("-")[-1]:[
            BoundingBox(
                category=label['value']['rectanglelabels'][0],
                left=label['value']['x']/100*DESIRED_IMAGE_WIDTH,
                top=label['value']['y']/100*DESIRED_IMAGE_HEIGHT,
                right=(label['value']['x']/100+label['value']['width']/100)*DESIRED_IMAGE_WIDTH,
                bottom=(label['value']['y']/100+label['value']['height']/100)*DESIRED_IMAGE_HEIGHT,
            )
            for label in sheet_data['annotations'][0]['result']
        ]
        for sheet_data in json_data
    }

landmark_location_data: Dict[str, List[BoundingBox]] = label_studio_to_bboxes(PATH_TO_LANDMARKS)

landmarks = landmark_location_data['unified_intraoperative_preoperative_flowsheet_v1_1_front.png']


Found 19 sheets in yolo_data.json
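
The scaling inside `label_studio_to_bboxes` converts percentages into pixels. Below is a minimal sketch of that arithmetic, assuming (as the division by 100 in the code implies) that `x`, `y`, `width` and `height` are stored as percentages of the image size. `percent_box_to_pixels` is a hypothetical helper used only for illustration, and the sample box is made up.

```python
def percent_box_to_pixels(x, y, width, height, image_width, image_height):
    """Convert a box given in percent of the image size to pixel edges."""
    left = x / 100 * image_width
    top = y / 100 * image_height
    right = (x / 100 + width / 100) * image_width
    bottom = (y / 100 + height / 100) * image_height
    return left, top, right, bottom


# A box starting a quarter of the way across and halfway down an 800x600 image,
# spanning half the width and a quarter of the height.
box = percent_box_to_pixels(25.0, 50.0, 50.0, 25.0, 800, 600)
print(box)
```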


**TPS Transformation**

Steps:
1. Filter to keep only the relevant bounding boxes
    - remove all bounding boxes from the source points that do not match a category in the destination points
    - Find all of the duplicates in the source and destination points and remove them
    - sort the source and destination points alphabetically via their category
2. Get lists of the x and y coordinates for both the source and destination points
    - Primary purpose is to enable the use of scipy's Rbf function
    - We are using the top left corner of the bounding boxes
3. Estimate the transformation
    - Use the Rbf function to fit the TPS mapping from the destination points to the source points
4. Apply the transformation and warp the image
    - Create a pixel grid covering the full width and height of the image
    - Apply the transformation to the grid
    - Clip the transformed coordinates so that they stay within the image bounds
    - Use the transformed grid to warp the original image
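
Steps 1 to 3 can be sketched on synthetic data as follows. This is a minimal illustration, not the notebook's implementation: `Box` is a hypothetical stand-in for `BoundingBox`, `paired_corners` is a made-up helper, and all coordinates are invented. It pairs the points by category, fits the splines from destination to source, and checks that the fitted functions reproduce the control points.

```python
from collections import Counter, namedtuple

import numpy as np
from scipy.interpolate import Rbf

# Hypothetical stand-in for utils.annotations.BoundingBox (only the fields used here).
Box = namedtuple("Box", ["category", "left", "top"])


def paired_corners(src, dst):
    """Pair top-left corners by category.

    Keeps categories that occur exactly once in both lists and returns them
    alphabetically, with matching (N, 2) arrays of source and destination corners.
    """
    src_counts = Counter(b.category for b in src)
    dst_counts = Counter(b.category for b in dst)
    keep = sorted(c for c in src_counts if src_counts[c] == 1 and dst_counts.get(c) == 1)
    src_by_cat = {b.category: (b.left, b.top) for b in src}
    dst_by_cat = {b.category: (b.left, b.top) for b in dst}
    src_xy = np.array([src_by_cat[c] for c in keep], dtype=float)
    dst_xy = np.array([dst_by_cat[c] for c in keep], dtype=float)
    return keep, src_xy, dst_xy


# Made-up detections: "1" is duplicated, "extra" and "only_dst" have no partner.
src = [Box("ecg", 12, 8), Box("spo2", 95, 14), Box("nibp", 10, 92), Box("total", 90, 88),
       Box("fio2", 52, 47), Box("1", 5, 5), Box("1", 60, 60), Box("extra", 1, 1)]
dst = [Box("ecg", 10, 10), Box("spo2", 90, 10), Box("nibp", 10, 90), Box("total", 90, 90),
       Box("fio2", 50, 50), Box("1", 7, 7), Box("only_dst", 3, 3)]

categories, src_xy, dst_xy = paired_corners(src, dst)

# Fit destination -> source: the warp later asks "which source pixel belongs here?"
rbf_x = Rbf(dst_xy[:, 0], dst_xy[:, 1], src_xy[:, 0], function="thin_plate")
rbf_y = Rbf(dst_xy[:, 0], dst_xy[:, 1], src_xy[:, 1], function="thin_plate")

# With the default smooth=0 the spline passes through every control point.
recovered = np.column_stack([rbf_x(dst_xy[:, 0], dst_xy[:, 1]),
                             rbf_y(dst_xy[:, 0], dst_xy[:, 1])])
print(categories)
print(np.allclose(recovered, src_xy))
```

Fitting in the destination-to-source direction matches how `cv2.remap` works: for every output pixel it needs the location in the input image to sample from.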

*Note:* Several debugging outputs are currently commented out in the code and can be re-enabled when needed:
- Print the lists of duplicate keys and the categories being used in the source and destination points
- Plot of the source and destination points on the image
- View the bounds of the transformed points
- View the distribution of the transformed points

In [None]:
def tps_transform(image, src_points: List[BoundingBox], dst_points: List[BoundingBox]):

    # get the categories from dst_points
    landmark_cats = [bb.category for bb in dst_points]
    # remove all bbs in src that are not in those categories
    src_points = [bb for bb in src_points if bb.category in landmark_cats]
    # get list of duplicate keys
    duplicate_count_src = dict(Counter([bb.category for bb in src_points]))
    duplicates = [k for k, v in duplicate_count_src.items() if v > 1]
    duplicate_count_dst = dict(Counter([bb.category for bb in dst_points]))
    duplicates.extend([k for k, v in duplicate_count_dst.items() if v > 1])
    duplicates = list(set(duplicates)) 
    # print(duplicates)
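    # remove duplicates
    src_points = [bb for bb in src_points if bb.category not in duplicates]
    dst_points = [bb for bb in dst_points if bb.category not in duplicates]
    # drop destination categories with no source match so the points pair up one-to-one
    src_cats = {bb.category for bb in src_points}
    dst_points = [bb for bb in dst_points if bb.category in src_cats]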
    # sort categories alphabetically
    src_points = sorted(src_points, key = lambda bb: bb.category)
    # print([bb.category for bb in src_points])
    dst_points = sorted(dst_points, key = lambda bb: bb.category)
    # print([bb.category for bb in dst_points])

    # get lists of the x and y coordinates
    src_x , src_y = zip(*[(bb.left, bb.top) for bb in src_points])
    dst_x, dst_y = zip(*[(bb.left, bb.top) for bb in dst_points])

    # plt.imshow(image)
    # plt.scatter(src_x, src_y, color='blue', label='Source Points')
    # plt.scatter(dst_x, dst_y, color='red', label='Destination Points')
    # plt.legend()
    # plt.title("Source and Destination Control Points")
    # plt.show()

    # use RBF function to do the thin plate splines
    rbf_x = Rbf(dst_x, dst_y, src_x, function="thin_plate")
    rbf_y = Rbf(dst_x, dst_y, src_y, function="thin_plate")


    # Alter the image according to the transformation
    h, w, _ = image.shape
    # create grid
    x = np.linspace(0, w-1, w)
    y = np.linspace(0, h-1, h)
    grid_x, grid_y = np.meshgrid(x, y)

    # evaluate the fitted splines at every output pixel; each value is the
    # source coordinate to sample for that pixel (inverse mapping for cv2.remap)
    transformed_x = rbf_x(grid_x, grid_y).astype(np.float32)
    transformed_y = rbf_y(grid_x, grid_y).astype(np.float32)

    # keep the sampling coordinates inside the image
    transformed_x = np.clip(transformed_x, 0, w - 1)
    transformed_y = np.clip(transformed_y, 0, h - 1)


    # print(f"Transformed X range: {np.min(transformed_x)}, {np.max(transformed_x)}")
    # print(f"Transformed Y range: {np.min(transformed_y)}, {np.max(transformed_y)}")

    # plt.imshow(transformed_x, cmap='coolwarm', interpolation='nearest')
    # plt.title("Transformed X Coordinates")
    # plt.show()

    # plt.imshow(transformed_y, cmap='coolwarm', interpolation='nearest')
    # plt.title("Transformed Y Coordinates")
    # plt.show()

    # warp the image
    warp_img = cv2.remap(image, transformed_x, transformed_y, interpolation=cv2.INTER_LINEAR)

    return warp_img
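
SciPy's documentation marks `Rbf` as legacy and points to `RBFInterpolator` for new code. The sketch below shows how the same destination-to-source maps could be built with it, assuming SciPy 1.7 or newer. It is an alternative illustration, not what this notebook uses, and the control points are made up. With `kernel="thin_plate_spline"`, `RBFInterpolator` adds a degree-1 polynomial (affine) term by default.

```python
import numpy as np
from scipy.interpolate import RBFInterpolator

# Made-up control points: destination (template) corners and where they appear in the photo.
dst_xy = np.array([[10, 10], [90, 10], [10, 90], [90, 90], [50, 50]], dtype=float)
src_xy = np.array([[12, 8], [95, 14], [10, 92], [90, 88], [52, 47]], dtype=float)

# One interpolator handles both output coordinates because the values are (N, 2).
tps = RBFInterpolator(dst_xy, src_xy, kernel="thin_plate_spline")

# Evaluate on every pixel of a small 100x100 output grid.
h, w = 100, 100
grid_y, grid_x = np.mgrid[0:h, 0:w]
grid = np.column_stack([grid_x.ravel(), grid_y.ravel()]).astype(float)
mapped = tps(grid)

# float32 maps in the layout expected by cv2.remap(image, map_x, map_y, ...).
map_x = mapped[:, 0].reshape(h, w).astype(np.float32)
map_y = mapped[:, 1].reshape(h, w).astype(np.float32)

# Output pixel (x=90, y=10) samples from the matching source corner.
print(map_x[10, 90], map_y[10, 90])
```

Because the two fits differ in their polynomial term, the resulting warps can differ away from the control points, so results should be compared visually before switching.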

In [21]:
for sheet, yolo_bbs in yolo_data.items():
    # get path to current image
    full_image_path = os.path.join(PATH_TO_REGISTERED_IMAGES, sheet)
    image = cv2.imread(full_image_path)
    resized_img = cv2.resize(image, (DESIRED_IMAGE_WIDTH, DESIRED_IMAGE_HEIGHT))
    # get the sheet's bounding boxes
    sheet_bbs = [BoundingBox.from_yolo(bb, DESIRED_IMAGE_WIDTH, DESIRED_IMAGE_HEIGHT) for bb in yolo_bbs]
    transformed_img = tps_transform(resized_img, sheet_bbs, landmarks)
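    # OpenCV loads images as BGR; convert to RGB before handing off to PIL
    transformed_img = Image.fromarray(cv2.cvtColor(transformed_img, cv2.COLOR_BGR2RGB))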
    transformed_img.show()


['2', '9', '8', '1', '3', 'code', '7', '0', 'mg', 'ml', 'minute', '4', 'trendeleburg', 'pcnt', 'temperature', '5', '6', 'hour_24hr']
['BPM', 'airway', 'airway_device', 'anesthesia_end', 'anesthesia_start', 'blood_loss', 'bronchoscope', 'capnography', 'central_iv_line', 'degree_C', 'des', 'diastolic', 'difficult_ventilation', 'direct_laryngoscopy', 'dl_view', 'drug_name', 'easy_ventilation', 'ecg', 'etco2', 'ett_n', 'eye_protection', 'fentanyl', 'fio2', 'fluid_blood_product', 'fowler', 'gastric_tube', 'halo', 'heart_rate', 'inhaled_exhaled', 'inhaled_volatile', 'iso', 'lateral', 'lithotomy', 'lma_n', 'mask_ventilation', 'micro_g', 'mmHg', 'monitoring_details', 'natural', 'nibp', 'other_airway_device', 'peripheral_iv_line', 'position', 'procedure_details', 'prone', 'propofol', 'respiratory_rate', 'reverse_trendelenburg', 'rocuronium', 'safety_checklist', 'sev', 'sitting', 'spo2', 'supine', 'surgery_end', 'surgery_start', 'systolic', 'ted_stockings', 'tidal_volume', 'total', 'tubes_and_li