In [None]:
# Standard Library Imports
import os
import sys

# Third Party Imports
import numpy as np 
from tifffile import imread, imsave
import matplotlib.pyplot as plt

In [None]:
# Size of the Z-step along the scan direction (use the value straight from the piezo).
dz = 0.2

# Lateral pixel size. Presumably camera pixel pitch (6.5) divided by magnification (50);
# it must be in the same units as dz -- TODO confirm.
xypixelsize = 6.5/50
print("Pixel Size", xypixelsize, dz)

# Angle of the illumination, in degrees.
angle = 60.5


# Folder containing the raw data. Paste it as a raw string, e.g. r'C:\data\run1'.
image_path = r''
# File name of the channel to load. It is appended directly to image_path,
# hence the leading backslash; rename to match your naming scheme.
image_name_0 = r'\CH00_000000.tiff'
if not image_path:
    raise ValueError("Set image_path to the folder containing the raw TIFF before running this cell.")
image = imread(image_path + image_name_0)
print("Image Dimensions:", np.shape(image))


### Define functions

In [None]:
def deskew(inArray, angle, dz, xypixelsize):
    """Shear (deskew) a 3D stack by Fourier-shifting each z-slice along x.

    Slice ``k`` is translated along the last axis by
    ``k * cos(angle) * dz / xypixelsize`` pixels. The translation is applied as a
    phase ramp in the Fourier domain. The x axis is zero-padded first, so the
    shifted content does not wrap around.

    Parameters
    ----------
    inArray : np.ndarray
        3D stack indexed as (z, y, x).
    angle : float
        Shear angle in degrees (converted with ``angle * pi / 180``).
    dz : float
        Step between consecutive z-slices.
    xypixelsize : float
        Lateral pixel size, in the same units as ``dz``.

    Returns
    -------
    np.ndarray
        uint16 array of shape ``(z, y, x + widenBy)``, where
        ``widenBy = ceil(z * cos(angle) * dz / xypixelsize)``. Values are rounded
        to the nearest integer and clipped to the uint16 range.

    Raises
    ------
    ValueError
        If the shear is negative (``cos(angle) * dz / xypixelsize < 0``). In that
        case the zero-padding on the right cannot hold the shifted data.
    """
    (z_len, y_len, x_len) = inArray.shape
    # Lateral shift, in pixels, accumulated per z-slice.
    Trans = np.cos(angle * np.pi / 180) * dz / xypixelsize
    # Plain Python int: an np.uint16 here could overflow in x_len + widenBy under
    # NumPy 2 promotion rules, and would silently wrap a negative value.
    widenBy = int(np.ceil(z_len * np.cos(angle * np.pi / 180) * dz / xypixelsize))
    if widenBy < 0:
        raise ValueError("deskew requires a non-negative shear (cos(angle) * dz / xypixelsize >= 0)")
    n_cols = x_len + widenBy

    # Column index used in the phase ramp. It indexes the fftshift-ed spectrum, so
    # it differs from the signed frequency by a constant. That only adds a global
    # phase per slice, which np.abs below removes.
    cols = np.arange(n_cols)
    uint16_max = np.iinfo(np.uint16).max
    output = np.empty((z_len, y_len, n_cols), dtype=np.uint16)

    for k in range(z_len):
        # s=(y_len, n_cols) zero-pads the slice on the right (same as a widened copy).
        sliceFFT = np.fft.fftshift(np.fft.fft2(inArray[k], s=(y_len, n_cols)))
        # The ramp depends only on the column, so it broadcasts across all rows.
        ramp = np.exp(-1j * 2 * np.pi * cols * Trans * k / n_cols)
        shifted = np.abs(np.fft.ifft2(np.fft.ifftshift(sliceFFT * ramp)))
        # Round rather than truncate, so 4.9999999 -> 5. Clip so that overshoot
        # saturates at 65535 instead of wrapping around.
        output[k] = np.clip(np.rint(shifted), 0, uint16_max)

    return output  # uint16 data, ready to save as tiff

def plot_image(image):
    """Plot maximum intensity projections of a 3D image.

    Draws three vertically stacked panels on the current figure. Panel ``i``
    shows ``np.max(image, i)``, i.e. the maximum over axis ``i``, and is titled
    accordingly. The figure is then displayed with ``plt.show()``.

    Parameters
    ----------
    image : np.array
        3D image array (assumed (z, y, x) ordering, as unpacked in ``deskew``).
    """
    for axis in range(3):
        ax = plt.subplot(3, 1, axis + 1)
        ax.margins(0.05)
        ax.imshow(np.max(image, axis))
        # Without a title the three stacked projections are indistinguishable.
        ax.set_title(f"Max projection over axis {axis}")
    # Keep the titles from overlapping the panel above.
    plt.tight_layout()
    plt.show()

### Show Raw Data

In [None]:
# Maximum-intensity projections of the raw stack, before deskewing.
plot_image(image)

### Deskew, Inspect and Save

In [None]:
# Shear the raw stack. deskew receives the complement of the illumination angle (90 - angle).
sheared_data = deskew(inArray=image, angle=90-angle, dz=dz, xypixelsize=xypixelsize)
print("Final image dimensions", np.shape(sheared_data))
# Maximum-intensity projections of the deskewed stack.
plot_image(sheared_data)
# Saved into image_path; the leading backslash joins the name onto the folder (Windows-style separator).
deskewed_image_name = r'\Deskewed.tiff'
# NOTE(review): newer tifffile releases deprecate imsave in favour of imwrite -- confirm against the installed version.
imsave((image_path+deskewed_image_name),
       data=sheared_data)