In [None]:
# Standard Library Imports
import os
import sys

# Third Party Imports
import numpy as np 
from tifffile import imread, imsave
import matplotlib.pyplot as plt

In [None]:
# Specify the size of your Z-step (the scan direction).
dz = 0.2 # Use value straight from piezo translation

# Specify your lateral pixel size.
# NOTE(review): 6.5/50 looks like camera pixel size / objective magnification — confirm.
# dz and xypixelsize must be in the same length unit: deskew() divides dz by
# xypixelsize to turn each z-step into a lateral shift measured in pixels.
xypixelsize = 6.5/50
print("Lateral pixel size:", xypixelsize, "| Z step:", dz)

# Specify the angle of the illumination, in degrees
# (the deskew cells below pass 90 - angle to deskew(), which converts degrees to radians).
angle = 60.5


# Image Path, use format as r'[insert text path copied here]'
# Do not end the path with a backslash: every file name below already starts with
# one, and full paths are built by plain string concatenation (image_path + name).
image_path = r''
# Rename individual channels with your desired naming scheme
image_name_0 = r'\CH00_000000.tiff'
image_name_1 = r'\CH01_000000.tiff'
image_name_2 = r'\CH02_000000.tiff'
image_name_3 = r'\CH03_000000.tiff'
image_name_3 = r'\CH03_000000.tiff'

In [None]:
# Read the four channel stacks from disk. Each file name already begins with a
# backslash, so the full path is just image_path followed by the name.
image_0 = imread(image_path + image_name_0)
image_1 = imread(image_path + image_name_1)
image_2 = imread(image_path + image_name_2)
image_3 = imread(image_path + image_name_3)
# Report the dimensions of the first channel as a quick sanity check.
print("Image Dimensions:", np.shape(image_0))

### Define functions

In [None]:
def deskew(inArray, angle, dz, xypixelsize):
    """Shear a 3D stack by shifting each z-slice laterally along x.

    Slice ``k`` is translated along the last axis by
    ``k * cos(angle) * dz / xypixelsize`` pixels using the Fourier shift
    theorem, which allows sub-pixel shifts. A Fourier shift is circular, so
    the x-axis is first zero-padded wide enough that no slice wraps around.

    Parameters
    ----------
    inArray : np.ndarray
        3D image stack indexed as (z, y, x).
    angle : float
        Shear angle in degrees; the per-slice shift scales with ``cos(angle)``.
        A negative cosine shears towards -x instead of +x.
    dz : float
        Step between consecutive z-slices, in the same length unit as
        ``xypixelsize``.
    xypixelsize : float
        Lateral pixel size.

    Returns
    -------
    np.ndarray
        uint16 array of shape ``(z, y, x + widenBy)`` with
        ``widenBy = ceil(z * |cos(angle)| * dz / xypixelsize)``. Intensities are
        rounded to the nearest integer and clipped to the uint16 range so the
        result can be saved directly as a TIFF.
    """
    z_len, y_len, x_len = inArray.shape
    cos_angle = np.cos(angle * np.pi / 180)
    # Lateral shift, in pixels, added for every successive z-slice.
    trans = cos_angle * dz / xypixelsize
    # Padding so the largest shift never wraps around the FFT window.
    # Python int avoids the uint16 overflow/wrap of the previous cast.
    widen_by = int(np.ceil(z_len * abs(cos_angle) * dz / xypixelsize))
    width = x_len + widen_by
    # For a leftward shear, place the data on the right so it moves into the padding.
    x_start = widen_by if trans < 0 else 0

    # The shift acts only along x, so a 1D FFT per row is sufficient.
    freqs = np.fft.fftfreq(width)  # cycles per pixel, in unshifted FFT order
    uint16_max = np.iinfo(np.uint16).max

    padded = np.zeros((y_len, width))
    output = np.zeros((z_len, y_len, width), dtype=np.uint16)
    for k in range(z_len):
        padded[:, x_start:x_start + x_len] = inArray[k]
        phase_ramp = np.exp(-2j * np.pi * freqs * trans * k)
        shifted = np.fft.ifft(np.fft.fft(padded, axis=-1) * phase_ramp, axis=-1)
        # Round rather than truncate, so integer data survives FFT round-off
        # (e.g. 99.9999999 -> 100, not 99). Clip so sub-pixel ringing above
        # 65535 cannot wrap around in the uint16 conversion.
        output[k] = np.clip(np.rint(np.abs(shifted)), 0, uint16_max)
    return output

def plot_image(image, title=None):
    """ Plot maximum intensity projections of a 3D image along each axis.

    Three panels are stacked vertically in a new figure; panel ``i`` shows
    ``np.max(image, axis=i)``.

    Parameters
    ----------
    image : np.array
        3D image array.
    title : str, optional
        Figure title (e.g. the channel name) so stacked outputs can be told apart.
    """
    # A fresh figure avoids drawing into whatever figure happens to be current.
    fig, axes = plt.subplots(3, 1)
    for axis_index, ax in enumerate(axes):
        ax.imshow(np.max(image, axis=axis_index))
        ax.set_title("Max projection over axis {}".format(axis_index))
    if title is not None:
        fig.suptitle(title)
    fig.tight_layout()
    plt.show()

### Show raw data (maximum-intensity projections of each channel)

In [None]:
plot_image(image=image_0)  # raw channel 0, before deskewing

In [None]:
plot_image(image=image_1)  # raw channel 1, before deskewing

In [None]:
plot_image(image=image_2)  # raw channel 2, before deskewing

In [None]:
plot_image(image=image_3)  # raw channel 3, before deskewing

# Channel 0

In [None]:
# Deskew channel 0 (shear angle = 90 - illumination angle), preview it, then save it.
sheared_data_0 = deskew(
    inArray=image_0,
    angle=90 - angle,
    dz=dz,
    xypixelsize=xypixelsize,
)
print("Final image dimensions", np.shape(sheared_data_0))
plot_image(sheared_data_0)
deskewed_image_name_0 = r'\Deskewed_CH00.tiff'
imsave(image_path + deskewed_image_name_0, data=sheared_data_0)

# Channel 1

In [None]:
# Deskew channel 1 (shear angle = 90 - illumination angle), preview it, then save it.
sheared_data_1 = deskew(
    inArray=image_1,
    angle=90 - angle,
    dz=dz,
    xypixelsize=xypixelsize,
)
print("Final image dimensions", np.shape(sheared_data_1))
plot_image(sheared_data_1)
deskewed_image_name_1 = r'\Deskewed_CH01.tiff'
imsave(image_path + deskewed_image_name_1, data=sheared_data_1)

# Channel 2

In [None]:
# Deskew channel 2 (shear angle = 90 - illumination angle), preview it, then save it.
sheared_data_2 = deskew(
    inArray=image_2,
    angle=90 - angle,
    dz=dz,
    xypixelsize=xypixelsize,
)
print("Final image dimensions", np.shape(sheared_data_2))
plot_image(sheared_data_2)
deskewed_image_name_2 = r'\Deskewed_CH02.tiff'
imsave(image_path + deskewed_image_name_2, data=sheared_data_2)

# Channel 3

In [None]:
# Deskew channel 3 (shear angle = 90 - illumination angle), preview it, then save it.
sheared_data_3 = deskew(inArray=image_3, angle=90-angle, dz=dz, xypixelsize=xypixelsize)
# Report this channel's own result (previously printed channel 1's shape by mistake).
print("Final image dimensions", np.shape(sheared_data_3))
plot_image(sheared_data_3)
deskewed_image_name_3 = r'\Deskewed_CH03.tiff'
imsave((image_path+deskewed_image_name_3),
       data=sheared_data_3)