# Detection of Axonal Fragments
---

As part of the Mayoral lab's investigation into how oligodendrocyte precursor cells influence neurodegeneration, we have fluorescent images with the following channels:

- GFP (sparse axonal labeling)
- DegenoTag
- Myelin staining
- DAPI

This program's goal is to provide an automated pipeline that detects and measures axonal fragments.


In [None]:
#%% Import block

# System
import os
import os.path
from datetime import datetime


# Plotting
import matplotlib.pyplot as plt 


# Image handling
import tifffile as tf 
import imageio as img # tiff writing
import czifile
from tifffile import imsave, imread, imwrite
from enum import Enum


# Numerical and statistics
import seaborn as sb
import pandas as pd 
import numpy as np
from PIL import Image
import math
import csv



from scipy.signal import find_peaks, peak_prominences
from scipy.ndimage import gaussian_filter

# Image analysis
import cv2 as cv
import skimage as sk
from skimage import io, morphology
from skimage.color import gray2rgb
from skimage import data
from skimage import img_as_float
from skimage.morphology import reconstruction

In [None]:
'''
This cell determines where matplotlib figures are displayed.
The value 1 shows plots in separate (Qt) windows, whereas a value of 0 shows plots inline in the Jupyter notebook.
'''
# Change the 1 below to 0 to switch from separate windows to inline notebook output.
# NOTE: the %matplotlib lines are IPython magics, so this cell only runs inside Jupyter/IPython.
if 1:
    # Separate-window (Qt) backend, with matplotlib interactive mode turned on.
    %matplotlib qt
    plt.ion()
else:
    # Inline backend, with matplotlib interactive mode turned on.
    %matplotlib inline
    plt.ion()


# Pre-processing stage
### The cells below, up to the next markdown cell, prepare the directories and files that feed into the automated pipeline.
This section does the following:
 - Convert CZI files to TIFF
 - Read the images into respective image arrays
 - Crop the image arrays to the respective region of interest


In [None]:
'''
This cell converts every .czi file in an input directory into a .tiff file in an output directory.
You can change main_directory to a directory that contains your data. If the input or output directory
does not exist in your working directory, this cell creates it. The layout used by this notebook is shown
below; you may use other names, but the paths in later cells would then need to be updated:

->current working directory
    -> data
        -> channels_directory
        -> cropped_images
        -> czi_images_directory
        -> tiff_images_directory

Upload your .czi files into the czi_images_directory and then run this cell.
Each output file keeps the original name with '.tiff' appended (e.g. sample.czi -> sample.czi.tiff).
'''

main_directory = "data"
input_directory = os.path.join(main_directory, "czi_images_directory")
output_directory = os.path.join(main_directory, "tiff_images_directory")

# Check if input directory exists, create if not
if not os.path.exists(input_directory):
    print("Input Directory not found, creating one \n")
    os.makedirs(input_directory)
    print("Input Directory successfully created \n")

# Check if output directory exists, create if not
if not os.path.exists(output_directory):
    print("Output Directory not found, creating one \n")
    os.makedirs(output_directory)
    print("Output Directory successfully created \n")


# Loop through files in the input directory
for filename in os.listdir(input_directory):
    full_file_path = os.path.join(input_directory, filename)

    # Only convert regular files with a .czi extension (case-insensitive)
    if os.path.isfile(full_file_path) and filename.lower().endswith('.czi'):
        with czifile.CziFile(full_file_path) as czi:
            image_array = czi.asarray()
        # NOTE(review): split_channels() later treats axis 0 of the written TIFF as the
        # channel axis; confirm this holds for the array shape returned by czi.asarray().

        # Save the image as TIFF in the output directory.
        # tifffile.imsave is a deprecated alias; imwrite is the supported writer.
        output_file_path = os.path.join(output_directory, filename + '.tiff')
        tf.imwrite(output_file_path, image_array, imagej=True)
        print(f"Successfully converted: {filename} to {output_directory} \n")

In [None]:
def crop_image(channel_array, x_min=2200, x_max=3200, y_min=1100, y_max=3000):
    '''
    This function takes in:
        A 2D array (channel_array, ideally a single channel of an image) and, optionally,
        the crop bounds: x_min/x_max (columns) and y_min/y_max (rows). The defaults are
        the region of interest used for the current data set; pass other values to crop
        a different region, or skip calling this function to keep the full image.

    This function:
        Crops the image to rows [y_min, y_max) and columns [x_min, x_max) by slicing.
        With NumPy slicing, bounds that run past the image edge are clipped rather than
        raising an error.

    This function returns:
        A cropped 2D array (a view into channel_array when it is a NumPy array)
    '''
    return channel_array[y_min:y_max, x_min:x_max]
    

In [None]:
# Module-level accumulator: split_channels() appends one sorted, flattened copy of every
# channel it reads, so the pixel-value distributions can be analysed statistically later.
channel_data_array = list()

In [None]:
def save_array_before_crop(image_array):
    '''
    This function takes in:
        A NumPy array of pixel values (image_array), typically one 2D channel

    This function:
        Collapses every pixel value into a single dimension and orders them from
        smallest to largest, leaving image_array itself untouched

    This function returns:
        A new sorted 1D array with the same dtype as image_array
    '''
    # With axis=None, np.sort flattens the input before sorting.
    return np.sort(image_array, axis=None)
    

In [None]:
def split_channels(tiff_image):
    '''
    This function takes in:
        The path to a TIFF image (tiff_image)

    This function:
        Reads the image with tifffile and treats axis 0 as the channel axis. Each channel
        plane is collected into a list. A sorted, flattened copy of every channel (from
        save_array_before_crop) is appended to the module-level channel_data_array.

    This function returns:
        A list holding one array per channel, in channel order
    '''
    stack = tf.imread(tiff_image)  # the full multi-channel image
    per_channel = [stack[idx] for idx in range(stack.shape[0])]
    # Record the pre-crop pixel distribution of each channel, in channel order.
    channel_data_array.extend(save_array_before_crop(plane) for plane in per_channel)
    return per_channel
    

    

## Splitting and cropping images
The cell directly below splits each generated TIFF image in the specified directory into its channels and saves them into the channels_directory, so each channel can be viewed and processed separately if needed. We are particularly focused on the first channel. The same cell also crops each channel and saves the cropped images into the cropped_images directory. These cropped images serve as training data for the algorithm, letting us focus on one region of the image to better evaluate and refine it.


In [None]:
'''
This section splits every image in tiff_images_directory (or another directory you point it at) into its
channels, saves each channel, and then crops and saves each channel as well. The crop step may be removed
or commented out in order to analyze the full image: select the code and press (CMD + /) to comment it out.
'''
tiff_images_directory = 'data/tiff_images_directory'

channels_directory = "data/channels_directory"
cropped_images_directory = "data/cropped_images"

# Make sure both output directories exist before anything is written.
for required_directory in (channels_directory, cropped_images_directory):
    if not os.path.exists(required_directory):
        os.makedirs(required_directory)
        print(f"Created directory: {required_directory}")

# Only files ending in .tiff (as produced by the conversion cell) are processed.
tiff_filenames = [name for name in os.listdir(tiff_images_directory) if name.endswith(".tiff")]

for tiff_filename in tiff_filenames:
    stem = os.path.splitext(tiff_filename)[0]
    channel_stack = split_channels(os.path.join(tiff_images_directory, tiff_filename))

    # 1) Save every full-size channel.
    for number, channel_plane in enumerate(channel_stack, start=1):
        destination = os.path.join(channels_directory, f'{stem}_channel_{number}.tiff')
        tf.imwrite(destination, channel_plane)
        ## Feel free to remove these print lines if necessary
        print(f"Saved channel {number} for {tiff_filename} to {destination} \n")

    # 2) Crop every channel to the region of interest and save it in ImageJ format.
    cropped_planes = [crop_image(channel_plane) for channel_plane in channel_stack]
    for number, cropped_plane in enumerate(cropped_planes, start=1):
        destination = os.path.join(cropped_images_directory, f'{stem}_cropped_channel_{number}.tiff')
        tf.imwrite(destination, cropped_plane, imagej=True)
        ## Feel free to remove these print lines if necessary
        print(f"Successfully saved cropped channel {number} to {destination} \n")

## This section is the start of the automated pipeline. The pipeline may require fine-tuning to improve the algorithm; this is the current minimum viable product.



In [None]:
def find_local_maximums(image_array, *, height=(1500, 23000), distance=3,
                        prominence=3000, width=50, rel_height=100):
    '''
    find_peaks() may require further tuning to apply across all images, so its main
    parameters are exposed as keyword arguments. The defaults are the values this
    pipeline was developed with.
    This function takes in:
        A 2D array (image_array) and, optionally, the find_peaks() parameters
        height, distance, prominence, width and rel_height.
    This function:
        Runs scipy.signal.find_peaks() on every row of the image to find the local
        peaks of that row.
    This function returns:
        A uint16 2D array with the same shape as image_array. It holds the original
        pixel value at every detected peak and 0 everywhere else.
    Note: find_peaks() never reports the first or last sample of a row as a peak.
    '''
    image = np.asarray(image_array)
    rows, cols = image.shape
    new_image_array = np.zeros((rows, cols), dtype=np.uint16)
    # Every row is processed (the previous loop stopped one row short).
    for row in range(rows):
        # NOTE(review): rel_height > 1 places the width-evaluation line below the peak's
        # base, which appears to measure the width out to the peak's bases; kept as tuned.
        peaks, _ = find_peaks(image[row], height=height, threshold=None, distance=distance,
                              prominence=prominence, width=width, wlen=None,
                              rel_height=rel_height, plateau_size=None)
        # Copy all peak values of this row at once instead of searching per pixel.
        new_image_array[row, peaks] = image[row, peaks]
    return new_image_array
    

## The current algorithm reads its input images from the directory set in `image_dir` (for example data/cropped_images or data/channels_directory).
----


##### The multiple variable:
- multiple = False : This will only read a single image specified by changing the single_image_path string value.
- multiple = True : This will read all matching .tiff files in image_dir (see all_channels below).
##### The image_dir variable:
- image_dir = "data/some_directory" : This will work on files in this specific directory, for example if you would like the algorithm to work on a cropped image or the full image.
##### The all_channels variable: 
- all_channels = False: This will only look at channel_1 files (you may change this by changing the variable string value to channel_n.tiff where n is the channel number)
- all_channels = True: This will run the algorithm on every channel.
##### The print_image variable: 
- print_image = False: This will not print the images.
- print_image = True: This will display the images; it may increase the run time of the algorithm.

In [None]:
# Pipeline configuration; the markdown cell above explains each flag.
multiple, image_dir, all_channels, print_image = (
    True,                       # process every matching .tiff in image_dir
    "data/channels_directory",  # directory the images are read from
    False,                      # only files ending in channel_1.tiff
    True,                       # display figures (default/tested is False)
)

In [None]:
'''
This section reads the image(s) to analyze.
- multiple = True: every file in image_dir whose name ends with `ending` is read into image_arrays.
- multiple = False: only single_image_path is read (into im) and displayed.
'''

# Every file this notebook writes ends in ".tiff"; without the extension the path could never
# match a saved file. Change the file name here to pick a different image.
# NOTE: *_cropped_channel_* files are written to data/cropped_images, and *_channel_* files to
# data/channels_directory, so make sure this name exists inside image_dir.
single_image_path = os.path.join(image_dir, "CTRL_INT_4.czi_cropped_channel_1.tiff")
image_arrays = []
filenames = sorted(os.listdir(image_dir))

# Restrict to channel 1 unless every channel was requested.
ending = ".tiff" if all_channels else "channel_1.tiff"

for file in filenames:
    if file.endswith(ending):  # Check if the file is a TIFF image
        print(f"Inspecting: {file}")

if multiple:
    for filename in filenames:
        if filename.endswith(ending):  # Check if the file is a TIFF image
            # Construct the full path to the image file
            filepath = os.path.join(image_dir, filename)
            # Read the image and add it to the list
            image = plt.imread(filepath)
            print(f"Added {filename} to list")
            image_arrays.append(image)
else:
    im = plt.imread(single_image_path)
    plt.figure()
    plt.imshow(im)



In [None]:
directory = 'data/channels_directory'

# Directory contents in name order.
# To sort by modification date instead, use:
# files.sort(key=lambda x: os.path.getmtime(os.path.join(directory, x)))
files = sorted(os.listdir(directory))

print("Files sorted by name:")
for file in files:
    if not file.endswith(".tiff"):
        continue  # only report TIFF images
    print(file)

In [None]:
'''
This section finds the local maximums of each image and displays them
'''
maximums_list = []

if not multiple:
    # Single-image mode: the result is always shown.
    maximums_list.append(find_local_maximums(im))
    plt.figure()
    plt.imshow(maximums_list[0])
else:
    for source_image in image_arrays:
        peak_map = find_local_maximums(source_image)
        maximums_list.append(peak_map)
        if print_image:
            plt.figure()
            plt.imshow(peak_map)

In [None]:
def find_non_zero_neighbors(image_array, start_y, start_x, global_visited):
    '''
    This function takes in:
        A 2D array (image_array); the y coordinate of the starting pixel (start_y); the x coordinate
        of the starting pixel (start_x); and a set of (y, x) tuples shared across calls that records
        which pixels have already been assigned to a line (global_visited). The set is updated in place.
    This function:
        Starting from a non-zero pixel, walks depth-first over adjacent non-zero pixels until no
        unvisited neighbor remains. Only the five directions below are followed (the three upward
        directions are not); P is the pixel whose neighbors are checked:

                    ← P →
                    ↙ ↓ ↘

            P: Original pixel value
            →: Right
            ←: Left
            ↙: Bottom-left
            ↓: Bottom
            ↘: Bottom-right

        Pixels already in global_visited (walked by this call or an earlier one) are never revisited
        and never chosen as the end point.
    This function returns:
       4 integer values, the start and end coordinates of the line (start_y, start_x, end_y, end_x).
       The end point is the last unvisited neighbor discovered during the walk, or the start pixel
       itself when it has no unvisited non-zero neighbors.
    '''
    # (dy, dx) offsets, in the order: left, bottom-left, bottom, bottom-right, right.
    neighbor_offsets = ((0, -1), (1, -1), (1, 0), (1, 1), (0, 1))
    rows = len(image_array)
    cols = len(image_array[0])

    stack = [(start_y, start_x)]
    end_y, end_x = start_y, start_x

    while stack:
        y, x = stack.pop()
        if (y, x) in global_visited:
            continue
        global_visited.add((y, x))

        for dy, dx in neighbor_offsets:
            ny, nx = y + dy, x + dx
            if not (0 <= ny < rows and 0 <= nx < cols):
                continue
            if image_array[ny][nx] == 0:
                continue
            # Skip pixels already walked, so the end point cannot jump back to where the walk
            # came from or onto a pixel that belongs to a previously traced line.
            if (ny, nx) in global_visited:
                continue
            stack.append((ny, nx))
            end_y, end_x = ny, nx

    return start_y, start_x, end_y, end_x


In [None]:
def find_lines(test_array, *, min_length=5):
    '''
    This function takes in:
        A 2D array (test_array), e.g. the local-maxima map from find_local_maximums(). It also takes
        an optional minimum line length in pixels (min_length, default 5); only lines strictly
        longer than min_length are kept.
    This function:
        Visits every non-zero pixel in row-major order. For each pixel not yet assigned to a line, it
        calls find_non_zero_neighbors() to walk the connected non-zero pixels. It then computes the
        Euclidean distance between the returned start and end points and records lines longer than
        min_length. The endpoints and lengths are printed, and the lengths are written one per row to
        a timestamped CSV file in the current working directory.
    This function returns:
       A list of [start_y, start_x, end_y, end_x] lists (line_endpoints); a list of floats (line_lengths)
    '''
    global_visited = set()
    line_endpoints = []
    line_lengths = []
    # np.argwhere lists non-zero pixels in the same row-major order as a nested y/x loop,
    # and .tolist() yields plain Python ints (convenient for OpenCV drawing calls later).
    for y, x in np.argwhere(np.asarray(test_array) != 0).tolist():
        if (y, x) in global_visited:
            continue
        start_y, start_x, end_y, end_x = find_non_zero_neighbors(test_array, y, x, global_visited)
        distance = math.sqrt((end_y - start_y) ** 2 + (end_x - start_x) ** 2)
        if distance > min_length:
            line_endpoints.append([start_y, start_x, end_y, end_x])
            line_lengths.append(distance)

    print("Line endpoints (start_y, start_x, end_y, end_x):", line_endpoints)
    print("Line lengths:", line_lengths)

    # Microseconds keep file names unique when several images finish within the same second.
    # ':' is avoided because it is not allowed in Windows file names.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f'file_at_time_{timestamp}.csv'

    # Write the 1D array of lengths to a CSV file, one value per row
    with open(filename, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerows([item] for item in line_lengths)

    print(f"The 1D array has been written to {filename}")

    return line_endpoints, line_lengths

   

In [None]:
'''
This section finds the lines in each local-maxima map and draws them on top of that map.
'''


def _lines_overlay(peak_map, line_endpoints):
    '''
    Return an RGB uint8 image of peak_map with every line drawn in red.

    peak_map is a 2D local-maxima map: non-zero pixels are shown white and all others black.
    line_endpoints is a list of [start_y, start_x, end_y, end_x] entries as returned by find_lines().
    '''
    # Build an 8-bit binary mask. Multiplying the raw uint16 peak values by 255 overflowed and
    # wrapped, and the uint8 cast then truncated them, so peak brightness was effectively random.
    mask = (np.asarray(peak_map) != 0).astype(np.uint8) * 255
    color_image = cv.cvtColor(mask, cv.COLOR_GRAY2RGB)
    for start_y, start_x, end_y, end_x in line_endpoints:
        # OpenCV points are (x, y). The result is shown with plt.imshow (RGB order), so
        # (255, 0, 0) is red; the line thickness is 2.
        cv.line(color_image, (start_x, start_y), (end_x, end_y), (255, 0, 0), 2)
    return color_image


if multiple:
    for peak_map in maximums_list:
        line_endpoints, line_lengths = find_lines(peak_map)

        ## Only draw when print_image is True
        if print_image:
            plt.figure()
            plt.imshow(_lines_overlay(peak_map, line_endpoints))
            plt.show()
else:
    # Trace lines on the single image's local-maxima map (as the multiple branch does),
    # not on the raw image `im`.
    line_endpoints, line_lengths = find_lines(maximums_list[0])
    plt.figure()
    plt.imshow(_lines_overlay(maximums_list[0], line_endpoints))




# END OF CURRENT PIPELINE (August 2024)
#### Future work could be taking these endpoints and line lengths to validate and refine the algorithm.
***
