In [None]:
import sys

# Keep the working tree clean: do not write __pycache__/ bytecode files for the
# local helper modules this notebook imports.
sys.dont_write_bytecode = True

# Print exceptions without traceback frames (just the error line).
# NOTE(review): IPython's own traceback renderer may not honor this — confirm.
sys.tracebacklimit = 0

In [None]:
import matplotlib.pyplot as plt
%config InlineBackend.figure_format='retina'
plt.rcParams.update({'font.size': 12})
from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size

import skimage as skimage

In [None]:
from tqdm.notebook import tqdm

In [None]:
from read_image import *
from make_binarization import *
from make_convolution import *
from make_image_gradients import *
from make_structure_tensor_2d import *
from make_coherence import *
from make_orientation import *
from make_vxvy import *

In [None]:
# Analysis parameters, as consumed later in this notebook:
#   FilterKey         - sigma of the Gaussian pre-filter applied to each chunk
#   LocalDensityKey   - kernel size for the (currently disabled) local-density step
#   LocalSigmaKey     - sigma argument passed to make_structure_tensor_2d
#   ThresholdValueKey - threshold passed to make_coherence and make_orientation
FilterKey = 1
LocalDensityKey = 10
LocalSigmaKey = 50
ThresholdValueKey = 20

# Chunk side = the larger of the two window scales (presumably so that every
# chunk spans at least one full analysis window).
largest_window = max([LocalDensityKey, LocalSigmaKey])
each_chunk_size = int(largest_window)

In [None]:
# Colorbar geometry for the overlay figures:
#   bar width            = axes height / aspect
#   gap between image/bar = pad_fraction * bar width
pad_fraction = 0.5
aspect = 20

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def show_mosaic(chunks, cmap = 'viridis'):
    """
    Display a list of image chunks as an n x n grid of subplots with narrow gaps.

    The grid side is ``n = ceil(sqrt(len(chunks)))`` and chunks are placed in
    row-major order (left to right, then top to bottom), matching the order
    produced by ``split_into_chunks``. Grid cells left over when ``len(chunks)``
    is not a perfect square are hidden.

    Parameters:
        chunks (list): 2D arrays to display, one per subplot.
        cmap (str): Matplotlib colormap name used for every chunk.

    Raises:
        ValueError: If ``chunks`` is empty.
    """
    if len(chunks) == 0:
        raise ValueError('show_mosaic() needs at least one chunk')

    n = int(np.ceil(np.sqrt(len(chunks))))

    # squeeze=False keeps `axs` a 2D array even when n == 1, so indexing is uniform.
    fig, axs = plt.subplots(n, n, figsize=(5, 5), squeeze=False)
    fig.subplots_adjust(wspace=0.05, hspace=0.05)

    # axs.flat walks the grid row-major, i.e. index i -> (i // n, i % n).
    for index, ax in enumerate(axs.flat):
        if index < len(chunks):
            ax.imshow(chunks[index], cmap = cmap)
            ax.set_xticks([])
            ax.set_yticks([])
        else:
            # Unused grid cell: hide it instead of drawing an empty framed axis.
            ax.axis('off')

    plt.show()

In [None]:
import numpy as np

def split_into_chunks(img, chunk_size):
    """
    Split a 2D image into non-overlapping square tiles of side ``chunk_size``.

    Tiles are returned in row-major order (left to right, then top to bottom).
    Only complete tiles are produced: trailing rows/columns that do not fill a
    whole tile are silently dropped, so callers normally pad the image first
    (e.g. with ``make_padded_image``) so both dimensions are multiples of
    ``chunk_size``.

    Parameters:
        img (numpy.ndarray): The input 2D image.
        chunk_size (int): Side length, in pixels, of each square tile.

    Returns:
        list: Slices of ``img`` (numpy views, not copies), each of shape
        (chunk_size, chunk_size).

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')

    # Number of complete tiles along each axis; any remainder is discarded.
    n_rows = img.shape[0] // chunk_size
    n_cols = img.shape[1] // chunk_size

    return [
        img[r * chunk_size:(r + 1) * chunk_size, c * chunk_size:(c + 1) * chunk_size]
        for r in range(n_rows)
        for c in range(n_cols)
    ]

In [None]:
def make_padded_image(img, chunk_size, fill_value = 0.0):
    """
    Embed a 2D image in the top-left corner of a square canvas whose side is a
    multiple of ``chunk_size``.

    The canvas side is the larger image dimension rounded up to the next
    multiple of ``chunk_size``; pixels outside the original image are set to
    ``fill_value``.

    Parameters:
        img (numpy.ndarray): The input 2D image.
        chunk_size (int): Tile size the padded side must be divisible by.
        fill_value (float): Value written to the padding pixels. Defaults to 0
            (the previous behavior); ``np.inf`` or ``np.nan`` can be used to make
            padding distinguishable from genuine zero-valued pixels.

    Returns:
        numpy.ndarray: A float64 array of shape (padded_size, padded_size).

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``img`` is not 2D.
    """
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')
    if img.ndim != 2:
        raise ValueError(f'expected a 2D image, got shape {img.shape}')

    max_size = max(img.shape)
    # Round up to the next multiple of chunk_size (no-op when already a multiple).
    padded_size = max_size + (chunk_size - max_size % chunk_size) % chunk_size

    padded_img = np.full((padded_size, padded_size), fill_value, dtype = np.float64)
    padded_img[:img.shape[0], :img.shape[1]] = img

    return padded_img

In [None]:
def stitch_chunks(analyzed_chunk_list, padded_img, img, chunk_size, fill_value = np.inf):
    """
    Reassemble per-chunk results into one image cropped to the original size.

    This is the inverse of ``split_into_chunks``: chunk ``i`` is written to grid
    row ``i // n_cols`` and column ``i % n_cols`` (row-major), where ``n_cols``
    is the number of chunks across the padded image's width.

    Parameters:
        analyzed_chunk_list (list): 2D arrays, each (chunk_size, chunk_size), in
            the row-major order produced by ``split_into_chunks``.
        padded_img (numpy.ndarray): The padded image the chunks were cut from;
            only its shape is used.
        img (numpy.ndarray): The original (unpadded) image; only its shape is
            used, to crop the result.
        chunk_size (int): Side length of each chunk.
        fill_value (float): Value for grid cells that no chunk covers.
            Defaults to ``np.inf`` (the previous behavior).

    Returns:
        numpy.ndarray: Float64 array cropped to ``img.shape[0] x img.shape[1]``.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or there are more
            chunks than fit in ``padded_img``.
    """
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')

    n_rows = padded_img.shape[0] // chunk_size
    # Column count comes from the width (not the height) so that non-square
    # padded images are stitched in the same order split_into_chunks cut them.
    n_cols = padded_img.shape[1] // chunk_size

    if len(analyzed_chunk_list) > n_rows * n_cols:
        raise ValueError(
            f'{len(analyzed_chunk_list)} chunks do not fit in a '
            f'{n_rows} x {n_cols} grid of {chunk_size}-pixel tiles')

    reconstructed_img = np.full(padded_img.shape, fill_value, dtype = np.float64)

    for i, chunk in enumerate(analyzed_chunk_list):
        row, col = divmod(i, n_cols)
        start_row = row * chunk_size
        start_col = col * chunk_size
        reconstructed_img[start_row:start_row + chunk_size,
                          start_col:start_col + chunk_size] = chunk

    # Drop the padding so the result lines up pixel-for-pixel with `img`.
    return reconstructed_img[:img.shape[0], :img.shape[1]]

In [None]:
import os

# Input image. Set the IMAGE_PATH environment variable (e.g. to 'TestImage1.tif')
# to analyze a different file without editing this machine-specific default.
IMAGE_PATH = os.environ.get('IMAGE_PATH', '/Users/ajinkyakulkarni/Desktop/cl02.tif')

raw_image = convert_to_8bit_grayscale(IMAGE_PATH)

fig, ax = plt.subplots()
raw_im = ax.imshow(raw_image, cmap='cividis')
fig.colorbar(raw_im, ax=ax)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(f'raw image, shape {raw_image.shape}')
plt.show()

In [None]:
# Chunk side must be a positive whole number of pixels; fail fast here rather than
# with a ZeroDivisionError / empty chunk list further down.
each_chunk_size = int(each_chunk_size)
assert each_chunk_size > 0, f'each_chunk_size must be positive, got {each_chunk_size}'

In [None]:
# Square, zero-padded copy of the raw image whose side is a multiple of the chunk size.
padded_raw_image = make_padded_image(img = raw_image, chunk_size = each_chunk_size)

In [None]:
# Visual check of the padded canvas (padding appears in the bottom/right margins).
fig, ax = plt.subplots()
padded_im = ax.imshow(padded_raw_image, cmap='cividis')
fig.colorbar(padded_im, ax=ax)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(padded_raw_image.shape)
plt.show()

In [None]:
# Cut the padded canvas into square, non-overlapping tiles (row-major order).
chunks = split_into_chunks(img = padded_raw_image, chunk_size = each_chunk_size)

len(chunks)

In [None]:
# show_mosaic(chunks, cmap = 'cividis')

In [None]:
# Per-chunk structure-tensor analysis: smooth each chunk, compute its image
# gradients and structure tensor, then derive coherence and orientation maps.
# Results are collected in the same row-major order as `chunks`, so that
# stitch_chunks can reassemble them.
#
# NOTE(review): the previous version of this loop also called binarize_image
# (its output only fed the commented-out local-density code) and make_vxvy
# (vx, vy were never used). Both results were discarded, so the calls were
# removed; restore them if local density or the vector field is needed again.

Image_Coherance_list = []
Image_Orientation_list = []

for current_chunk in tqdm(chunks, desc='analyzing chunks'):

    # Gaussian pre-smoothing; preserve_range keeps the input's intensity scale
    # instead of rescaling to skimage's float convention.
    filtered_chunk = skimage.filters.gaussian(current_chunk,
                                              sigma = FilterKey,
                                              mode = 'nearest',
                                              preserve_range = True)

    # Image gradients in X and Y directions.
    image_gradient_x, image_gradient_y = make_image_gradients(filtered_chunk)

    # Structure tensor with its eigen-decomposition and components.
    Structure_Tensor, EigenValues, EigenVectors, Jxx, Jxy, Jyy = make_structure_tensor_2d(
        image_gradient_x, image_gradient_y, LocalSigmaKey)

    Image_Coherance = make_coherence(filtered_chunk, EigenValues, Structure_Tensor, ThresholdValueKey)
    Image_Orientation = make_orientation(filtered_chunk, Jxx, Jxy, Jyy, ThresholdValueKey)

    Image_Coherance_list.append(Image_Coherance)
    Image_Orientation_list.append(Image_Orientation)

In [None]:
# Reassemble the per-chunk orientation maps into one image the size of raw_image.
Image_Orientation = stitch_chunks(analyzed_chunk_list = Image_Orientation_list,
                                  padded_img = padded_raw_image,
                                  img = raw_image,
                                  chunk_size = each_chunk_size)

In [None]:
# Reassemble the per-chunk coherence maps into one image the size of raw_image.
Image_Coherance = stitch_chunks(analyzed_chunk_list = Image_Coherance_list,
                                padded_img = padded_raw_image,
                                img = raw_image,
                                chunk_size = each_chunk_size)

In [None]:
fig, ax = plt.subplots(figsize = (8, 4), constrained_layout = True)

# Normalize intensities to [0, 1] for the gray colormap; fall back to 1 for an
# all-zero image so we never divide 0 by 0.
intensity_scale = raw_image.max() or 1

# Modulate the coherence colors by image brightness (elementwise RGBA product).
coherence_overlay = plt.cm.gray(raw_image / intensity_scale) * plt.cm.Spectral_r(Image_Coherance)

# The overlay is already RGBA, so imshow would ignore cmap/vmin/vmax; the
# colorbar is therefore driven by an explicit mappable for coherence on [0, 1].
ax.imshow(coherence_overlay)
ax.set_xticks([])
ax.set_yticks([])
coherence_mappable = plt.cm.ScalarMappable(norm = plt.Normalize(vmin = 0, vmax = 1),
                                           cmap = 'Spectral_r')

# Colorbar as tall as the image: width = height / aspect, gap = pad_fraction * width.
divider = make_axes_locatable(ax)
width = axes_size.AxesY(ax, aspect=1./aspect)
pad = axes_size.Fraction(pad_fraction, width)
cax = divider.append_axes("right", size=width, pad=pad)
cbar = fig.colorbar(coherence_mappable, cax=cax)
cbar.ax.tick_params(labelsize = 12)

plt.show()

In [None]:
fig, ax = plt.subplots(figsize = (8, 4), constrained_layout = True)

# Normalize intensities to [0, 1] for the gray colormap; fall back to 1 for an
# all-zero image so we never divide 0 by 0.
intensity_scale = raw_image.max() or 1

# NOTE(review): orientation is assumed to be in degrees on [0, 180]; dividing by
# 180 maps it onto the hsv colormap's [0, 1] input range — confirm against make_orientation.
orientation_overlay = plt.cm.gray(raw_image / intensity_scale) * plt.cm.hsv(Image_Orientation / 180)

# The overlay is already RGBA, so imshow would ignore cmap/vmin/vmax; the
# colorbar is therefore driven by an explicit mappable on the normalized [0, 1] scale.
ax.imshow(orientation_overlay)
ax.set_xticks([])
ax.set_yticks([])
orientation_mappable = plt.cm.ScalarMappable(norm = plt.Normalize(vmin = 0, vmax = 1),
                                             cmap = 'hsv')

# Colorbar as tall as the image: width = height / aspect, gap = pad_fraction * width.
divider = make_axes_locatable(ax)
width = axes_size.AxesY(ax, aspect=1./aspect)
pad = axes_size.Fraction(pad_fraction, width)
cax = divider.append_axes("right", size=width, pad=pad)
cbar = fig.colorbar(orientation_mappable, cax = cax, ticks = np.linspace(0, 1, 5))

# Label the five evenly spaced ticks in degrees and size them in a single call.
cbar.ax.set_yticklabels([r'$0^{\circ}$', r'$45^{\circ}$', r'$90^{\circ}$', r'$135^{\circ}$', r'$180^{\circ}$'],
                        fontsize = 12)

plt.show()