In [None]:
#basic analysis package
import numpy as np
import pandas as pd
from pathlib import Path
import tifffile as tf
from importlib import reload
#enable relative import
import sys 
sys.path.append("..")
from helpers.util import pil_imread
#cusom packages
import fiducial_alignment_affine as fa_affine
%config InlineBackend.figure_format='retina'

In [None]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np

def plot_2d_image(img_2d, zmax=1000, animation = True):
    """Display an image (or image stack) with plotly's ``px.imshow``.

    Parameters
    ----------
    img_2d : array-like
        Image data handed straight to ``px.imshow``. With ``animation=True``
        axis 0 drives the animation slider (``animation_frame=0``), so this is
        expected to be a stack of 2D planes; with ``animation=False`` it is
        rendered as a single image.
    zmax : float, default 1000
        Upper end of the intensity range mapped onto the display.
    animation : bool, default True
        Whether to add a slider over axis 0.

    Returns
    -------
    None
        The figure is shown via ``fig.show()`` and not returned (returning it
        would make Jupyter render it a second time as the cell's output).
    """
    #settings shared by both modes; only the slider over axis 0 differs
    imshow_kwargs = dict(
        width=700,
        height=700,
        binary_string=True,
        binary_compression_level=4,
        binary_backend='pil',
        zmax=zmax,
    )
    if animation:
        imshow_kwargs["animation_frame"] = 0

    fig = px.imshow(img_2d, **imshow_kwargs)
    fig.show()

In [None]:
def plot_2d_locs_on_2d_image(df_locs_2d_1, img_2d, zmax=1000):
    """Overlay 2D spot locations as crosses on top of a single 2D image.

    Parameters
    ----------
    df_locs_2d_1 : DataFrame-like
        Must expose ``x`` and ``y`` columns (attribute access) holding the
        spot coordinates; presumably in image pixel units — TODO confirm.
    img_2d : array-like
        Background image passed to ``px.imshow``.
    zmax : float, default 1000
        Upper end of the intensity range mapped onto the display.
    """
    #background image
    figure = px.imshow(
        img_2d,
        width=700, height=700,
        binary_string=True, binary_compression_level=4, binary_backend='pil',
        zmax=zmax,
    )

    #spot locations drawn as small crosses
    spot_trace = go.Scattergl(
        x=df_locs_2d_1.x,
        y=df_locs_2d_1.y,
        mode='markers',
        marker_symbol='cross',
        marker=dict(size=4),
        name="Gaussian",
    )
    figure.add_trace(spot_trace)

    figure.show()

In [None]:
#get image paths (the experiment root is defined once so switching datasets is a one-line edit)
Pos = 0
exp_dir = Path("/groups/CaiLab/personal/Lex/raw/Linus_10k_cleared_080918_NIH3T3")
tiff_name = f"MMStack_Pos{Pos}.ome.tif"

#bead (fiducial) stack for this position
bead_src = exp_dir / "beads" / tiff_name
#DAPI-aligned image stack for HybCycle_0 at this position
tiff_src = exp_dir / "pyfish_tools" / "output" / "dapi_aligned" / "HybCycle_0" / tiff_name

In [None]:
#read in the bead stack (4 channels, read with swapaxes=True) and the image stack
beads = pil_imread(str(bead_src), num_channels=4, swapaxes=True)
raw = pil_imread(str(tiff_src), num_channels=None, swapaxes=False)

In [None]:
#make sure shapes match before overlaying/aligning the two stacks; fail loudly instead of just displaying False
assert beads.shape == raw.shape, f"bead stack shape {beads.shape} != image stack shape {raw.shape}"
beads.shape

In [None]:
#look at beads: animate over axis 0 of beads[0]
plot_2d_image(img_2d=beads[0], zmax=2000)

In [None]:
#look at raw: animate over axis 0 of raw[0]
plot_2d_image(img_2d=raw[0], zmax=2000)

In [None]:
#check how off: flip between the first 2D plane of raw and of beads (2-frame animation)
plot_2d_image(
    img_2d=np.array([raw[0][0], beads[0][0]]),
    zmax=1000,
)

# Test fiducial alignment on one position

In [None]:
import time

#perf_counter is monotonic, so the elapsed time is unaffected by system clock adjustments
start = time.perf_counter()

#set bead channel to None if there are beads in all channels
#write=False: results are kept in memory for inspection below (presumably nothing is saved — confirm)
image, error = fa_affine.fiducial_alignment_single(
    tiff_src, bead_src,
    threshold_abs=800,
    max_dist=1,
    ransac_threshold=0.20,
    bead_channel_single=None,
    include_dapi=False,
    use_ref_coord=True,
    num_channels=4,
    write=False,
)
print(f"This task took {(time.perf_counter() - start) / 60:.2f} min")

In [None]:
#look at displacement ([channel, percent change, displacement])
#`error` is the second value returned by fiducial_alignment_single in the cell above
#NOTE(review): confirm the exact container type/shape returned for `error`
error

In [None]:
#sanity-check the transformed stack: animate over axis 0 of image[0]
plot_2d_image(img_2d=image[0], zmax=3000)

In [None]:
#check how off after alignment: flip between the first 2D plane of beads and of the aligned image
plot_2d_image(
    img_2d=np.array([beads[0][0], image[0][0]]),
    zmax=5000,
)

# Check FWHM across all hybs

In [None]:
import glob
import matplotlib.pyplot as plt

In [None]:
#grab all files: per-position error values for every hyb cycle, converted from pixels to nm
import re

pixel_size_nm = 108
n_hybs = 80
aligned_root = "/groups/CaiLab/personal/Lex/raw/Linus_10k_cleared_080918_NIH3T3/pyfish_tools/output/fiducial_aligned"


def _natural_key(path):
    """Sort key comparing embedded integers numerically (so Pos2 sorts before Pos10)."""
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", Path(path).name)]


hyb_all = []
for hyb in range(n_hybs):
    path = f"{aligned_root}/HybCycle_{hyb}/*_error.txt"
    #glob.glob returns matches in arbitrary order, but the reformat cell indexes rows by
    #position; sort so row i is the same position in every hyb cycle
    files = sorted(glob.glob(path), key=_natural_key)
    if not files:
        raise FileNotFoundError(f"No *_error.txt files found for HybCycle_{hyb}: {path}")
    #column 2 of each space-separated error file; a comprehension is used so the
    #alignment cell's global `error` is not overwritten
    error_list = [pd.read_csv(file, sep = " ", header=None)[2].values for file in files]
    final_error = np.array(error_list) * pixel_size_nm
    hyb_all.append(final_error)

In [None]:
#reformat: one DataFrame per position (rows = hyb cycles, columns = channels)
#derive the number of positions from the data instead of hard-coding it, and
#make sure every hyb cycle has the same number of positions
pos_counts = {len(hyb) for hyb in hyb_all}
assert len(pos_counts) == 1, f"Hyb cycles have differing numbers of positions: {sorted(pos_counts)}"
n_pos = pos_counts.pop()

final = []
for pos in range(n_pos):
    df = pd.DataFrame([hyb[pos] for hyb in hyb_all], columns=["Ch1", "Ch2", "Ch3"])
    final.append(df)

#mean across positions for each channel (one value per hyb cycle); NaNs propagate as before
mean_ch1 = np.mean([df.Ch1.values for df in final], axis=0)
mean_ch2 = np.mean([df.Ch2.values for df in final], axis=0)
mean_ch3 = np.mean([df.Ch3.values for df in final], axis=0)

In [None]:
#plot: faint per-position traces plus the across-position mean for each channel
#NOTE(review): the loaded values come from column 2 of the *_error.txt files, which the
#alignment cell describes as "displacement"; confirm "FWHM" is the intended axis label
channel_styles = [
    ("Ch1", mean_ch1, "red", "Channel 647 nm"),
    ("Ch2", mean_ch2, "orange", "Channel 561 nm"),
    ("Ch3", mean_ch3, "green", "Channel 488 nm"),
]
#x axis follows the data length rather than a hard-coded hyb count
hyb_cycles = np.arange(len(mean_ch1))

fig, ax = plt.subplots()
for col, mean_trace, color, label in channel_styles:
    for df in final:
        ax.plot(hyb_cycles, df[col], alpha=0.10, lw=1, color=color)
    ax.plot(hyb_cycles, mean_trace, color=color, label=label)
ax.set_ylim(0, 50)
ax.set_ylabel("FWHM (nm)")
ax.set_xlabel("HybCycles")
ax.set_title("FWHM across hyb cycles")
ax.legend()
plt.show()