In [None]:
#Install dependencies
!pip install -U motion-correction ipywidgets
#Or for gpu support (not supported on mac)
#!pip install -U motion-correction[gpu] ipywidgets

In [None]:
# load packages

import matplotlib.pyplot as plt
import numpy as np

from motion_correction.desktop.flim_aligner import *


### Step 1. Construct a flim aligner object. 
Specifying the arguments of the flim aligner during construction will probably be supported in the future


In [None]:
# Build the aligner with its default constructor arguments; it is configured
# step by step through the setter calls in the cells below.
flim_aligner = FlimAligner()

### Step 2. Set the alignment method. 

Currently, the following methods are supported:

**Global Methods**: Phase \
**Local Methods**: Morphic, OpticalPoly, OpticalTVL1, OpticalILK

You can set the global_method or local_method to 'None' to only apply local or global correction.

In [None]:
# Import every available correction algorithm so any of them can be swapped in below.
from motion_correction.algorithms import (
    Phase,
    Morphic,
    OpticalPoly,
    OpticalILK,
    OpticalTVL1,
)

# Global correction uses Phase; local correction uses Morphic.
phase = Phase()
morphic = Morphic(sigma_diff=20, radius=15)
flim_aligner.set_methods(global_method=phase, local_method=morphic)

### Step 3. Set similarity metric
Similarity metric will help to evaluate the correction performance. The following metrics are supported:
**SimMetric.NCC**: Normalized Cross-Correlation or Pearson Correlation \
**SimMetric.MSE**: Mean Square Error\
**SimMetric.NRM**: Normalized Root MSE\
**SimMetric.SSI**: Structure Similarity Index

In [None]:
# Evaluate the correction with mean squared error (SimMetric.MSE).
flim_aligner.set_sim_metric(sim=SimMetric.MSE)

In [None]:
# Select channel 1 on the aligner.
# NOTE(review): presumably the detector channel whose frames are corrected;
# whether the index is 0- or 1-based is not visible here — confirm against the FlimAligner docs.
flim_aligner.set_channel(1)

### Step 4. Apply correction based on intensity 
Apply correction based on the intensity images/frames from pt_file_path. After correction, the original and the corrected intensity images/frames are stored in ***flim_aligner.flim_frames*** and ***flim_aligner.flim_frames_corrected***, respectively. The transformation/correction matrix is stored in ***flim_aligner.transforms***.

In [None]:
# Input .pt3 file for the intensity-based correction.
# NOTE(review): the absolute path below is specific to one machine — point it
# at your own .pt3 file before running. The commented-out lines show the
# alternative of loading a file from the current working directory.
#name = "02b pancreas 1000hz zoom=6 _10_1.pt3"
#pt_file_path = os.path.join(os.getcwd(), name)
pt_file_path = "/Users/tkallady/Downloads/RhoA ms881 intenstine 1000Hz unidirectional.pt3"
# Run the correction on the intensity frames read from pt_file_path.
flim_aligner.apply_correction_intensity(pt_file_path)

### Step 5. Export visualization results
You can export the intensity frames and the accumulated intensity image, as well as similarity plots, to visualize the correction results. Optionally, you can specify the save_dir where the visualization results will be stored. By default, the results will be stored in the 'save_dir' folder under the current working directory.

In [None]:
#Requires ffmpeg: https://support.audacityteam.org/basics/installing-ffmpeg
# save_dir=None requests the default output location; pass a directory path
# instead to choose where the visualization results are written.
flim_aligner.export_results(save_dir=None)

In [None]:
# Optionally, you can visualize the corrected data within notebook
%matplotlib inline
from motion_correction.desktop.utility import plot_sequence_images, display_images
import matplotlib.pyplot as plt

plot_sequence_images(flim_aligner.flim_frames.transpose(2, 1, 0))
plot_sequence_images(flim_aligner.flim_frames_corrected.transpose(2, 1, 0))

display_images([flim_aligner.flim_frames.sum(axis=-1), 
                flim_aligner.flim_frames_corrected.sum(axis=-1)])

fig, axes = plt.subplots()
axes.plot(flim_aligner.old_sim, label="original")
axes.plot(flim_aligner.new_sim, label="corrected")
axes.set_ylabel(flim_aligner.sim_metric)
axes.set_xlabel('Frame')
plt.legend(loc="best")
plt.show()

### Step 6. Apply correction to raw flim data
This step may take a few minutes as it involves reading the raw flim data into a sparse matrix, applying correction for all nanotimes data, and saving the corrected data into pt3 file. Note that due to the nature of correction, the corrected data is stored as floating values but converted into uint16 data type before saving to pt3 file. Apart from the pt3 file, the following three matrices might be of your further interest for downstream analyses:

**Original histogrammed data**: ***flim_aligner.curve_fit***\
**Corrected histogrammed data based on integer values**: ***flim_aligner.curve_fit_corrected_int***\
**Corrected histogrammed data based on floating-point values**: ***flim_aligner.curve_fit_corrected***

In [None]:
# Apply the correction to the raw FLIM data. No arguments are passed here;
# per the notes above, this step may take a few minutes.
flim_aligner.apply_correction_flim()

In [None]:
# Compare the histograms within a square block of blk_sz at (row_idx, col_idx) 
# Pay attention the range of y-axis for fair comparison. A larger y-axis range makes the curve look more smoother.
%matplotlib inline
col_idx = 105
row_idx = 180
blk_sz = 5
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].plot(flim_aligner.curve_fit[:, row_idx:row_idx+blk_sz, col_idx:col_idx+blk_sz].sum(axis=(1,2)))
axes[0].set_title("original")
axes[1].plot(flim_aligner.curve_fit_corrected[:, row_idx:row_idx+blk_sz, col_idx:col_idx+blk_sz].sum(axis=(1,2)))
axes[1].set_title("corrected")
axes[2].plot(flim_aligner.curve_fit_corrected_int[:, row_idx:row_idx+blk_sz, col_idx:col_idx+blk_sz].sum(axis=(1,2)))
axes[2].set_title("corrected int")
plt.show()

In [None]:
# Use the following code to interactively inspect the decay curve data
%matplotlib notebook

class RectangularROI:
    """Mouse-driven rectangular ROI selector linking an image to a decay plot.

    Pressing and releasing the mouse button inside the image axes defines two
    opposite corners of a rectangle. On release, the decay data inside the
    rectangle is summed over both spatial axes, the decay line is updated, and
    the rectangle outline is painted onto a copy of the displayed image.

    Parameters
    ----------
    fig : figure whose canvas is redrawn after each selection.
    decay_data : array indexed as ``[row, col, bin]``.
    image_AX : image artist (as returned by ``imshow``) to draw the ROI on.
    decay_AX : line artist (as returned by ``plot``) showing the decay curve.
    decay_fig : axes that contain the decay line (its y-limits are rescaled).
    tau_resolution : factor applied to the bin index to obtain the x values.
    """

    def __init__(self, fig, decay_data, image_AX, decay_AX, decay_fig, tau_resolution):
        self.fig = fig
        self.decay_data     = decay_data
        self.image_ax       = image_AX
        # Keep the pristine image so every new ROI is drawn on a clean copy.
        self.ori_img = self.image_ax.get_array()
        self.decay_ax       = decay_AX
        self.decay_fig      = decay_fig
        self.tau_resolution = tau_resolution
        # Corner coordinates; None until a press/release lands in the image axes.
        self.x0 = self.y0 = None
        self.x1 = self.y1 = None
        self.image_ax.figure.canvas.mpl_connect('button_press_event', self.on_press)
        self.image_ax.figure.canvas.mpl_connect('button_release_event', self.on_release)
        self.xs  = None
        self.ys  = None

    def on_press(self, event):
        """Record the first ROI corner; ignore presses outside the image axes."""
        # The canvas is shared with other axes (decay plot, colorbar), whose
        # data coordinates are meaningless as image indices.
        if event.inaxes is not self.image_ax.axes:
            self.x0 = self.y0 = None
            return
        self.x0 = event.xdata
        self.y0 = event.ydata

    def on_release(self, event):
        """Record the second corner, then refresh the decay curve and outline."""
        self.x1 = event.xdata
        self.y1 = event.ydata
        # Skip incomplete gestures: no valid press, or released outside the image.
        if event.inaxes is not self.image_ax.axes:
            return
        if None in (self.x0, self.y0, self.x1, self.y1):
            return

        self.x_indices = np.int_(np.ceil(np.abs(np.sort(np.array([self.x1, self.x0]))))) # [x1, x2]
        self.y_indices = np.int_(np.ceil(np.abs(np.sort(np.array([self.y0, self.y1]))))) # [y1, y2]
        x_lo, x_hi = self.x_indices
        y_lo, y_hi = self.y_indices

        # Sum the decay histograms of all pixels inside the rectangle. Use the
        # instance's own data rather than a same-named notebook global.
        self.ys = np.sum(self.decay_data[y_lo:y_hi, x_lo:x_hi, :], axis=(0, 1))
        n_bins = self.decay_data.shape[2]
        # Builtin `int`: the `np.int` alias was removed in NumPy 1.24.
        self.xs = np.linspace(0, n_bins, n_bins, dtype=int) * self.tau_resolution
        self.decay_ax.set_data(self.xs, self.ys)
        self.decay_fig.set_ylim(ymin = 0, ymax = np.max(self.ys)*10)

        # Paint the outline (value 1000) on a fresh copy of the original image.
        # Border indices are clamped so a drag that ends at the last row or
        # column does not index one past the array.
        self.shown_img = self.ori_img.copy()
        n_rows, n_cols = self.shown_img.shape[:2]
        top, bottom = min(y_lo, n_rows - 1), min(y_hi, n_rows - 1)
        left, right = min(x_lo, n_cols - 1), min(x_hi, n_cols - 1)
        self.shown_img[top, x_lo:x_hi] = 1000
        self.shown_img[bottom, x_lo:x_hi] = 1000
        self.shown_img[y_lo:y_hi, left] = 1000
        self.shown_img[y_lo:y_hi, right] = 1000
        self.image_ax.set_array(self.shown_img)
        # Figure has no `redraw`; request a repaint through its canvas.
        self.fig.canvas.draw_idle()
        

# Data for the interactive viewer: the accumulated corrected intensity image
# and the corrected histogram data reordered from [bin, row, col] to [row, col, bin].
intensity_image = flim_aligner.flim_frames_corrected.sum(axis=-1)
decay_data = flim_aligner.curve_fit_corrected.transpose((1, 2, 0))
# NOTE(review): the x-axis is labelled in ns below, so this is presumably
# ns per histogram bin — set it to the real bin width of your acquisition.
tau_resolution = 1.0

fig = plt.figure(figsize=(9, 4))

# Left panel: intensity image on which the ROI is drawn.
image_ax = fig.add_subplot(121)
image_AX = image_ax.imshow(intensity_image, cmap="viridis")
fig.colorbar(image_AX)
image_ax.set_aspect('auto')
plt.title('Draw a rectangle on intensity image')

# Plot decay here
# default decay data Pixel (0,0)
plot_decay_data = decay_data[0,0,:]
# Builtin `int` instead of `np.int`: the alias was removed in NumPy 1.24
# and raises AttributeError there.
tau = np.linspace(0, decay_data.shape[2], decay_data.shape[2], dtype = int)*tau_resolution
# Right panel: decay histogram of the selected ROI on a log y-axis.
decay_fig = fig.add_subplot(122)
decay_AX, = decay_fig.plot(tau, plot_decay_data, 'k-', label='Selected ROI Histogram', linewidth=1)
plt.yscale(value="log")
#plt.autoscale(enable=True, axis=1)
plt.axis([0, np.max(tau), 0, np.max(plot_decay_data)*10])
plt.xlabel('Time [ns]')
plt.ylabel('Intensity [counts]')
plt.title('TCSPC Decay')
plt.grid(True)
plt.legend()

plt.sca(decay_fig)
# Keep a reference to the ROI handler so its callbacks stay registered.
linebuilder = RectangularROI(fig, decay_data,image_AX,decay_AX,decay_fig,tau_resolution)
plt.tight_layout()
plt.show()