## 3D Optical Flow Calculation with the Lucas-Kanade Method

In [1]:
## Importing the necessary modules

import numpy as np
import os
from tqdm import tqdm
import scipy as sp
import time
import scipy.io as sio
import hdf5storage
from multiprocessing import TimeoutError
from multiprocessing.pool import ThreadPool as Pool
from functools import partial
import matplotlib.pyplot as plt
from natsort import natsorted
import warnings
# Suppress every warning (all categories, all modules) for the rest of the
# kernel session. This also hides numpy RuntimeWarnings such as invalid values
# or division by zero; remove or narrow this filter when debugging.
warnings.filterwarnings("ignore")

In [2]:
# Lucas-Kanade method for 3D optical flow
# INPUTS:

# k: timeframe index; frames k and k+1 of the temporally smoothed movie are loaded from smooth_address
# sig: spread (sigma) of the Gaussian weights around a voxel; thresh: threshold for the reliability score

# OUTPUTS (stored as <k>.mat in save_path):
# vx: x-velocity, vy: y-velocity, vz: z-velocity
# reliabMat: reliability score for the velocity computed at each voxel (higher is better), saved under key 'rel'

def _load_frame(path):
    """Load one ``.npy`` frame, promoting non-floating dtypes to float64.

    Promotion prevents ``next_frame - current_frame`` from wrapping around
    when frames are stored as unsigned integers (e.g. uint16 camera data).
    Floating-point frames are returned unchanged.
    """
    frame = np.load(path)
    if not np.issubdtype(frame.dtype, np.floating):
        frame = frame.astype(np.float64)
    return frame


def _reliability_from_cubic_roots(a, b, c, d, e, f):
    """Smallest eigenvalue of the symmetric 3x3 matrix
    ``[[a, d, e], [d, b, f], [e, f, c]]``, evaluated elementwise.

    Uses the closed-form trigonometric solution of the characteristic cubic:
    with ``q = trace/3`` and ``p`` the scaled Frobenius norm of ``A - qI``, the
    eigenvalues are ``q + 2 p cos(theta + 2 pi j / 3)`` for j = 0, 1, 2.

    Where the matrix is a multiple of the identity (``p == 0``, e.g. a window
    with no image gradient) all eigenvalues equal ``q``. The shifted
    determinant is taken as 0 there instead of evaluating 0/0, which would
    produce NaN.

    All arguments are arrays of one common shape; the result has that shape.
    """
    q = (a + b + c) / 3
    x = (a + b - 2*c)**2/9 + (a - 2*b + c)**2/9 + (b - 2*a + c)**2/9 + 2*(d**2 + e**2 + f**2)
    p = np.sqrt(x / 6)

    # det(A - qI); dividing by p**3 gives det(B) for B = (A - qI) / p
    det_shifted = np.asarray((a-q)*(b-q)*(c-q) + 2*d*e*f - (a-q)*f**2 - (b-q)*e**2 - (c-q)*d**2)
    p3 = np.asarray(p**3)
    detB = np.divide(det_shifted, p3, out=np.zeros_like(det_shifted), where=p3 > 0)
    # det(B)/2 must lie in [-1, 1] for arccos; rounding can push it slightly out
    detB = np.clip(detB, -2, 2)

    theta = np.arccos(detB / 2) / 3

    # the three eigenvalues
    b1 = 2*np.cos(theta)*p + q
    b2 = 2*np.cos(theta + 2*np.pi/3)*p + q
    b3 = 2*np.cos(theta + 4*np.pi/3)*p + q

    # elementwise minimum; works for arrays of any rank
    return np.minimum(np.minimum(b1, b2), b3)


def LKxOptFlow(k, sig, thresh, load_dir=None, save_dir=None):
    """Dense 3D Lucas-Kanade optical flow between frames ``k`` and ``k + 1``.

    Loads ``{k}.npy`` and ``{k+1}.npy`` from ``load_dir``. For every voxel it
    solves the 3x3 Lucas-Kanade normal equations built from the
    Gaussian-weighted (sigma ``sig``) products of spatial and temporal
    gradients. It then zeroes velocities whose reliability (smallest
    eigenvalue of the weighted structure tensor) is not above ``thresh``.
    Finally it writes ``vx``, ``vy``, ``vz`` and ``rel`` to ``{k}.mat`` in
    ``save_dir``.

    Frames are assumed to be shaped (Z, Y, X). Gradients use unit grid spacing
    and ``dt`` is the plain frame difference, so velocities are in voxels per
    frame.

    Parameters
    ----------
    k : int
        Index of the first frame; frame ``k + 1`` must also exist.
    sig : float or sequence of float
        Sigma passed to ``scipy.ndimage.gaussian_filter`` for the weights.
    thresh : float
        Velocities at voxels with reliability ``<= thresh`` are set to 0.
    load_dir : str, optional
        Folder holding the smoothed frames. Defaults to the notebook-level
        ``smooth_address``.
    save_dir : str, optional
        Output folder. Defaults to the notebook-level ``save_path``.

    Returns
    -------
    list
        ``[0]``, a placeholder kept for the driver cell.
    """
    load_dir = smooth_address if load_dir is None else load_dir
    save_dir = save_path if save_dir is None else save_dir

    current_frame = _load_frame(os.path.join(load_dir, '{}.npy'.format(k)))      # temporally smoothed frame-1
    next_frame = _load_frame(os.path.join(load_dir, '{}.npy'.format(k + 1)))     # temporally smoothed frame-2

    # Spatial and temporal gradients, assuming the frames are shaped (Z, Y, X)
    dz = np.gradient(current_frame, axis=0, edge_order=1).astype(np.float32)
    dy = np.gradient(current_frame, axis=1, edge_order=1).astype(np.float32)
    dx = np.gradient(current_frame, axis=2, edge_order=1).astype(np.float32)
    dt = (next_frame - current_frame).astype(np.float32)

    def smooth(arr):
        # Gaussian weighting of a gradient product over each voxel's neighbourhood
        return sp.ndimage.gaussian_filter(arr, sig, mode='nearest')

    # Entries of the weighted structure tensor ...
    wdx2 = smooth(dx**2)
    wdy2 = smooth(dy**2)
    wdz2 = smooth(dz**2)
    wdxy = smooth(dx*dy)
    wdxz = smooth(dx*dz)
    wdyz = smooth(dy*dz)
    # ... and the weighted right-hand side
    wdtx = smooth(dx*dt)
    wdty = smooth(dy*dt)
    wdtz = smooth(dz*dt)

    # Solve the 3x3 system with the adjugate (Cramer's rule). eps only guards
    # against division by zero; it is absolute, so it matters only where the
    # determinant itself is of order 1e-6 (this depends on image intensity).
    eps = 1e-6
    determinant = wdx2*wdy2*wdz2 + 2*wdxy*wdyz*wdxz - wdx2*(wdyz**2) - wdy2*(wdxz**2) - wdz2*(wdxy**2)
    inv_det = 1.0 / (determinant + eps)

    vx = -inv_det * (
          wdtx*(wdy2*wdz2 - wdyz**2)
        + wdty*(wdxz*wdyz - wdxy*wdz2)
        + wdtz*(wdxy*wdyz - wdxz*wdy2))
    vy = -inv_det * (
          wdtx*(wdxz*wdyz - wdxy*wdz2)
        + wdty*(wdx2*wdz2 - wdxz**2)
        + wdtz*(wdxy*wdxz - wdx2*wdyz))
    vz = -inv_det * (
          wdtx*(wdxy*wdyz - wdxz*wdy2)
        + wdty*(wdxy*wdxz - wdx2*wdyz)
        + wdtz*(wdx2*wdy2 - wdxy**2))

    reliabMat = _reliability_from_cubic_roots(wdx2, wdy2, wdz2, wdxy, wdxz, wdyz)

    # Zero unreliable velocities outright (a NaN/inf velocity is also cleared,
    # which multiplying by the boolean mask would not do)
    unreliable = ~(reliabMat > thresh)
    vx[unreliable] = 0
    vy[unreliable] = 0
    vz[unreliable] = 0

    sio.savemat(os.path.join(save_dir, '{}.mat'.format(k)),
                {'vx': vx, 'vy': vy, 'vz': vz, 'rel': reliabMat},
                do_compression=False)
    return [0]

# Example usage: point data_folder at the processed dataset. Smoothed frames are
# read from <data_folder>/smoothed_frames and must be named 0.npy, 1.npy, ...
# (LKxOptFlow loads '{k}.npy'); flow results go to <data_folder>/Op_flow_threaded.

data_folder = 'E:\\Spandan\\3D_Lattice_Lightsheet\\Shen 1-30-23\\dicty_factin_pip3-06_processed'   # Root address where all the images are stored
smooth_address = os.path.join(data_folder, 'smoothed_frames')
save_path = os.path.join(data_folder, 'Op_flow_threaded')     # Create a local folder for storing OF results
os.makedirs(save_path, exist_ok=True)

# Count only .npy frames: stray entries (Thumbs.db, desktop.ini, notes, ...)
# would otherwise inflate n_frames and make the driver request a missing frame.
file_list = natsorted([name for name in os.listdir(smooth_address) if name.endswith('.npy')])    # sort numerically by frame-index
n_frames = len(file_list)

In [None]:
numProcessors = 16   # maybe change it to 8 (but that's it)
# NOTE: Pool is multiprocessing.pool.ThreadPool, so these are worker threads in
# this kernel process, not separate processes. Every worker holds two frames
# plus its own full-size float32 intermediates at the same time.

print('Post-processing in parallel with '+str(numProcessors)+' processors')

#store start time
stopwatchStart = time.time()
# The context manager shuts the pool down even if a frame raises, so a failed
# run does not leave idle worker threads behind in the kernel.
with Pool(processes=numProcessors) as pool:
    results = pool.map(partial(LKxOptFlow, sig=2, thresh=0.1), range(n_frames - 1))
# Each call returns [0]; collect those into `temp` as before. Unlike the former
# `[temp] = zip(*...)` unpacking, this also works when there is no frame pair.
temp = tuple(res[0] for res in results)

print('Wall time = '+str(np.round(time.time() - stopwatchStart,2))+' s')

print('Parallel post processing complete, switching to serial')

Post-processing in parallel with 16 processors
