<a href="https://colab.research.google.com/github/WinetraubLab/coregister-xy/blob/main/coregister_xy_2.ipynb" target="_blank">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/></a>
<a href="https://github.com/WinetraubLab/coregister-xy/blob/main/coregister_xy_2.ipynb" target="_blank">
  <img src="https://img.shields.io/badge/view%20in-GitHub-blue" alt="View in GitHub"/>
</a>

# Overview
Use this notebook to get alignment information from ImageJ image registration. Print stats for individual barcodes and calculate mapping from u,v pixels to x,y,z physical space.

In [1]:
# @title Notebook Inputs { display-mode: "form" }
# @markdown How to use this notebook: [See Instructions](https://docs.google.com/document/d/1G2AME1q6XQhxQ1A2FhkfpktaSFpNXSNQB6mEWwM0YM0/edit?usp=sharing)
import numpy as np
from google.colab import drive
from google.colab import files
drive.mount('/content/drive/')

# @markdown Input Paths:
# @markdown Leave either path blank to load a file from local file system.
trakem_xml_path = "/content/drive/Shareddrives/Yolab - Current Projects/_Datasets/2024-09-04 Multiple Barcode Alignment/align8.xml" # @param {type:"string"}
fluorescent_patch_number = 8 # @param {type:"integer"}
# @markdown For the alignment of multiple templates to one fluorescent image, specify the patch numbers of each template in the TrakEM stack.
# template_patch_1 = 11 # @param {type:"integer"}
# template_patch_2 = 14 # @param {type:"integer"}
# template_patch_3 = 17 # @param {type:"integer"}

# @markdown Enter template patch IDs in order as a comma-separated list. Example: [11, 14, 17]
template_patch_list = [11, 14, 17, 20, 23, 26, 29, 32] # @param

# @markdown Z-depth of each template, in um, as a comma-separated list. Example: [50, 52, 54]
template_z_list = [32, 34, 32, 32, 34, 34, 34, 36] # @param

# @markdown Real (x,y) locations of photobleach barcode centers, as specified by script used to photobleach. Enter in format: (1000,0), (0, 1000), (0,0target_centersters = [0,1000], [1000, 1000], [0, 0] # @partarget_centersenters = np.target_centers_ctarget_centerset_centers = np.colutarget_centersrget_centertarget_centerstarget_centers.shape[0])))
target_centers = [0,0], [1000, 0], [2000, 0], [0,1000], [1000, 1000], [2000,1000],  [1000, 2000], [2000,2000] # @param

target_centers = np.array(target_centers)
assert target_centers.shape[1] == 2, "Points in target_centers should be in format [x,y]"

template_size = 401
um_per_pixel = 2

assert len(template_patch_list) == len(template_z_list), "Number of elements in template patch list and template z list must match"

if not trakem_xml_path:
  print("Upload saved TrakEM project:")
  uploaded = files.upload()
  trakem_xml_path = list(uploaded.keys())[0]
  trakem_xml_path = os.path.join(os.getcwd(), trakem_xml_path)


Mounted at /content/drive/


In [2]:
# @title Environment Setup
!git clone https://github.com/WinetraubLab/coregister-xy.git
%cd coregister-xy

from plane.fit_template import FitTemplate
from plane.fit_plane import FitPlane
import matplotlib.pyplot as plt
import os
from google.colab import files
import math
import pandas as pd
import numpy as np

%cd ..

Cloning into 'coregister-xy'...
remote: Enumerating objects: 701, done.[K
remote: Counting objects: 100% (171/171), done.[K
remote: Compressing objects: 100% (95/95), done.[K
remote: Total 701 (delta 125), reused 98 (delta 76), pack-reused 530 (from 1)[K
Receiving objects: 100% (701/701), 1.09 MiB | 6.60 MiB/s, done.
Resolving deltas: 100% (432/432), done.
/content/coregister-xy
/content


In [3]:
# @title Print Template Stats

# Setup code
fts = []
templates = [int(x) for x in template_patch_list]

for i in range(0, len(templates)):
    ft = FitTemplate.from_imagej_xml(trakem_xml_path, fluorescent_patch_number, templates[i], None, True)
    fts.append(ft)

uv_px = []
for i, ft in enumerate(fts):
    uv_px.append((ft.tx + template_size/2, ft.ty + template_size/2))

zs = np.array([float(x) for x in template_z_list])
template_centers_xyz = np.column_stack((target_centers, zs))
template_centers_uv = np.array(uv_px)

fp = FitPlane.from_template_centers(template_centers_uv, template_centers_xyz, print_inputs = False)


"""
Print stats for each FitPlane as a table: shrinkage, rotation, shear, and mean/stdev for each
Units: um
"""
num_templates = len(fts)
projects_data = {
"Template ID": [i for i in range(1, num_templates+1)],
"Center (x, pix)": [project.tx + template_size/2 for project in fts],
"Center (y, pix)": [project.ty + template_size/2 for project in fts],
"Rotation (deg)": [project.theta_deg for project in fts],
"Scaling": [project.scale for project in fts],
"Shear magnitude": [project.shear_magnitude for project in fts],
"Shear unit vector (x)": [project.shear_vector[0] for project in fts],
"Shear unit vector (y)": [project.shear_vector[1] for project in fts]
}

columns_to_summarize = ["Rotation (deg)", "Scaling", "Shear magnitude", "Shear unit vector (x)", "Shear unit vector (y)"]

# Create DataFrame
df = pd.DataFrame(projects_data)

# Compute mean and standard deviation for selected columns only
mean_row = df[columns_to_summarize].mean()
std_row = df[columns_to_summarize].std()

# Append mean and std as new rows for selected columns only
summary_df = df.copy()
summary_df.loc['Mean', columns_to_summarize] = mean_row
summary_df.loc['StDev', columns_to_summarize] = std_row
summary_df = summary_df.round(2)
summary_df = summary_df.replace(np.nan, '', regex=True)

summary_df


Unnamed: 0,Template ID,"Center (x, pix)","Center (y, pix)",Rotation (deg),Scaling,Shear magnitude,Shear unit vector (x),Shear unit vector (y)
0,1.0,644.31,1062.96,11.78,1.66,0.04,0.8,0.6
1,2.0,1454.47,1166.28,10.43,1.77,0.03,0.73,0.68
2,3.0,2425.81,1332.18,11.99,1.83,0.14,0.99,0.11
3,4.0,532.06,1851.14,13.16,1.79,0.07,1.0,0.02
4,5.0,1360.98,2003.33,9.54,1.88,0.07,0.96,0.28
5,6.0,2300.31,2129.16,11.4,1.93,0.04,0.95,0.32
6,7.0,1190.89,2833.23,13.24,1.86,0.09,1.0,0.08
7,8.0,2125.52,3012.87,11.73,1.97,0.06,0.95,-0.3
Mean,,,,11.66,1.84,0.07,0.92,0.22
StDev,,,,1.25,0.1,0.04,0.1,0.32


In [4]:
# @title Plane Mapping Stats
# Print stats
print('Fit Statistics:')
print('---------------')
print("Pixel Size: |u| = %.3fum, |v| = %.3fum" %
  (fp.u_norm_mm(), fp.v_norm_mm() ) )
print("Signed Distance from Origin: %.3fmm" % fp.distance_from_origin_mm())
print("X-Y Rotation: %.2f degrees" % fp.xy_rotation_deg())
print("Tilt: %.2f degrees" % fp.tilt_deg())

# Serialize to JSON
print("")
print('Data to Keep:')
print('---------------')
print("fp = FitPlane.from_json('" + fp.to_json() + "')")

# Write data to JSON for easy loading

outfile = "save_mapping.json" # Change filepath as needed

json_data = fp.to_json()
with open(outfile, 'w') as f:
    f.write(json_data)

Fit Statistics:
---------------
Pixel Size: |u| = 1.107um, |v| = 1.186um
Signed Distance from Origin: 31.890mm
X-Y Rotation: 9.84 degrees
Tilt: 89.94 degrees

Data to Keep:
---------------
fp = FitPlane.from_json('{"u": [1.0910388967059503, -0.18930164426478685, 0.000538955869956731], "v": [0.1907877197614561, 1.1707269179641029, 0.0012787894742597704], "h": [-883.2992011346798, -1092.5888019513498, 30.228996490115296]}')


In [5]:
uv_to_xyz = [fp.get_xyz_from_uv(p) for p in template_centers_uv]
uv_to_xyz = np.array(uv_to_xyz)

in_plane_err = fp.avg_in_plane_projection_error(uv_to_xyz, template_centers_xyz)
print(f"Average in-plane mapping error: {in_plane_err:.2f} um")

out_plane_err = fp.avg_out_of_plane_projection_error(uv_to_xyz, template_centers_xyz)
print(f"Average out-of-plane mapping error: {out_plane_err:.2f} um")


Average in-plane mapping error: 37.70 um
Average out-of-plane mapping error: 0.70 um


In [19]:
# @title Visualize Plane with Points
import plotly.graph_objects as go

# plane coeffs
normal_vector = np.cross(fp.u, fp.v)
a,b,c = normal_vector
d = -1 * (a*fp.h[0] + b*fp.h[1] + c*fp.h[2])

# Generate grid for the plane
x = np.linspace(-10, 2700, 100)
y = np.linspace(-10, 2700, 100)
X, Y = np.meshgrid(x, y)
Z = (-d - a * X - b * Y) / c  # Solve for Z

# Create the plane surface
plane = go.Surface(x=X, y=Y, z=Z, colorscale='Viridis', opacity=0.7, name='Plane',showscale=False)

# Create the points
uv = go.Scatter3d(
    x=uv_to_xyz[:, 0],
    y=uv_to_xyz[:, 1],
    z=uv_to_xyz[:, 2],
    mode='markers',
    marker=dict(size=4, color='blue'),
    name='UV to XYZ'
)
xyz = go.Scatter3d(
    x=template_centers_xyz[:, 0],
    y=template_centers_xyz[:, 1],
    z=template_centers_xyz[:, 2],
    mode='markers',
    marker=dict(size=4, color='red'),
    name='True XYZ'
)

fig = go.Figure()
fig.add_trace(plane)
fig.add_trace(uv)
fig.add_trace(xyz)
fig.show()