In [None]:
# standard library
import time
from pathlib import Path

# third-party: data management
import numpy as np
import pandas as pd
import tifffile as tf

# third-party: plotting
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns

# local: data management and custom functions
from util import pil_imread
# NOTE(review): star import; chromatic_corr_offsets comes from here. numpy/pandas are
# imported explicitly above so later cells do not depend on this module re-exporting them.
from chromatic_aberration_correction import *

%config InlineBackend.figure_format='retina'

# Path to the reference bead image used for testing

In [None]:
# Reference bead image (position 0) passed to chromatic_corr_offsets and pil_imread.
# Plain string literal: the previous f-prefix had no placeholders to interpolate.
# NOTE(review): absolute cluster path — adjust when running outside that filesystem.
ref_img = "/groups/CaiLab/personal/Lex/raw/Linus_10k_cleared_080918_NIH3T3/pyfish_tools/output/z_matched_images/beads/MMStack_Pos0.ome.tif"

# Get offsets and corrected image

In [None]:
# Estimate the chromatic offsets from the bead image, apply them, and report the runtime.
# The three results are used below as: corrected image, alignment error, transformation.
# NOTE(review): threshold_abs / max_dist / ransac_threshold / num_channels are passed
# straight through to chromatic_corr_offsets — see that function for their meaning.
start = time.perf_counter()  # monotonic clock; time.time() can jump if the system clock changes
transformed_image, error, tform = chromatic_corr_offsets(
    ref_img,
    threshold_abs=800,
    max_dist=1.5,
    ransac_threshold=0.20,
    num_channels=4,
)
elapsed_min = (time.perf_counter() - start) / 60
print(f"This task took {elapsed_min:.2f} min")

In [None]:
# Alignment error: second value returned by chromatic_corr_offsets in the offsets cell
error

In [None]:
# Transformation matrix: third value returned by chromatic_corr_offsets in the offsets cell
tform

# Compare corrected and original images

In [None]:
def plot_2d_image(img_2d, zmax=1000, animation=True, width=700, height=700):
    """Display an image with plotly express, optionally as a slider over axis 0.

    Parameters
    ----------
    img_2d : array-like
        Image to show. When ``animation`` is True, axis 0 is used as the
        animation (slider) frame axis, so a stack of 2D images is expected.
    zmax : float, default 1000
        Upper bound of the colour/intensity range passed to ``px.imshow``.
    animation : bool, default True
        If True, pass ``animation_frame=0`` to ``px.imshow``.
    width, height : int, default 700
        Figure size in pixels.

    Returns
    -------
    None
        The figure is displayed via ``fig.show()``. Nothing is returned, so a
        call on the last line of a cell does not render the figure twice.
    """
    imshow_kwargs = dict(
        width=width,
        height=height,
        # Send the image as a compressed PNG string (PIL backend) rather than
        # raw values, which per the plotly docs keeps the figure payload small.
        binary_string=True,
        binary_compression_level=4,
        binary_backend="pil",
        zmax=zmax,
    )
    if animation:
        imshow_kwargs["animation_frame"] = 0

    fig = px.imshow(img_2d, **imshow_kwargs)
    fig.show()

In [None]:
# Load the raw reference stack, then take a max-intensity projection over axis 0
# of both the original and the chromatic-corrected stacks.
# NOTE(review): axis 0 is presumably z after swapaxes=True — confirm against pil_imread.
original = pil_imread(ref_img, swapaxes=True)
original_max, transformed_max = (
    np.max(stack, axis=0) for stack in (original, transformed_image)
)

In [None]:
# Uncorrected max projection; plot_2d_image animates over axis 0 by default
plot_2d_image(img_2d=original_max, zmax=5000)

In [None]:
# Chromatic-corrected max projection; plot_2d_image animates over axis 0 by default
plot_2d_image(img_2d=transformed_max, zmax=5000)

# Check average error across positions

In [None]:
import seaborn as sns

In [None]:
# Load the per-position error tables stored in the same directory as ref_img.
# - The directory is derived from ref_img so the absolute path lives in one place.
# - Each table is bound to `pos_error`, not `error`, so this loop no longer overwrites
#   the alignment error returned by chromatic_corr_offsets (re-running the `error`
#   display cell afterwards would otherwise show the last position's table).
n_positions = 7  # positions 0..6
beads_dir = Path(ref_img).parent
error_list = []
for pos in range(n_positions):
    src = beads_dir / f"MMStack_Pos{pos}_error.txt"
    pos_error = pd.read_csv(src, sep=" ", header=None)
    error_list.append(pos_error)

In [None]:
# Stack the per-position error tables into one frame and convert FWHM from pixels to nm.
PIXEL_SIZE_NM = 108  # nm per pixel
comb_error = pd.concat(error_list, ignore_index=True)
comb_error.columns = ["channel", "percent improvement", "fwhm"]
# Item assignment rather than `comb_error.fwhm = ...`: pandas attribute access only
# updates existing columns and would create a plain attribute (with a warning) otherwise.
comb_error["fwhm"] = comb_error["fwhm"] * PIXEL_SIZE_NM

In [None]:
# Distribution of FWHM (nm) per channel, pooled over all loaded positions.
# NOTE(review): tick labels are hard-coded and applied in seaborn's category order;
# the check below stops a silent mislabel if the number of channels ever differs.
channel_labels = ["Channel 561 nm", "Channel 488 nm"]
n_channels = comb_error["channel"].nunique()
assert n_channels == len(channel_labels), (
    f"expected {len(channel_labels)} channels, found {n_channels}; update channel_labels"
)

fig, ax = plt.subplots()
# NOTE(review): seaborn >= 0.13 warns when `palette` is given without `hue`.
sns.boxplot(data=comb_error, x="channel", y="fwhm", palette="Set2", ax=ax)
ax.set_ylim(0, 50)  # values above 50 nm fall outside the visible range
ax.set_xticks(range(n_channels))
ax.set_xticklabels(channel_labels)
ax.set(xlabel="", ylabel="FWHM (nm)", title="FWHM per channel, all positions")
plt.show()

In [None]:
# Per-channel mean of each error column
comb_error.groupby("channel").agg("mean")

In [None]:
# Per-channel sample standard deviation (pandas default ddof=1) of each error column
comb_error.groupby("channel").agg("std")