In [None]:
# import libraries
import os

import numpy as np
import matplotlib.pyplot as plt
import scipy as sc
import numpy as np
from scipy.ndimage import gaussian_filter

In [None]:
# functions 
def calculate_oct_snr(tom, background_fraction=4):
    """
    Calculate the signal-to-noise ratio (SNR) of an OCT tomogram.

    Args:
        tom: A 4D numpy array of shape (depth, height, width, 2) holding the
            real part in ``tom[..., 0]`` and the imaginary part in ``tom[..., 1]``.
        background_fraction: The background region is the top-left corner
            block of every depth slice spanning ``height // background_fraction``
            rows and ``width // background_fraction`` columns (at least one
            pixel in each direction). That corner is assumed to contain only
            noise -- confirm for each dataset.

    Returns:
        The linear SNR: the mean intensity of the whole volume divided by the
        standard deviation of the intensity inside the background region.
        Intensity is ``|real + 1j*imag|**2``, so this is a ratio of power
        quantities (convert to dB with ``10*log10``).
    """
    # Intensity (squared magnitude) of the complex tomogram.
    oct_image = np.abs(tom[:, :, :, 0] + 1j * tom[:, :, :, 1]) ** 2
    _, height, width = oct_image.shape

    # Signal: mean intensity over the entire volume.
    signal = np.mean(oct_image)

    # Noise: spread of the intensity in the background corner. max(1, ...)
    # keeps the slice non-empty for tiny images (np.std([]) would be NaN).
    bg_rows = max(1, height // background_fraction)
    bg_cols = max(1, width // background_fraction)
    background = oct_image[:, :bg_rows, :bg_cols]
    noise = np.std(background)

    return signal / noise

def calculate_oct_speckle(tom, threshold_fraction=0.5):
    """
    Calculate the size and distribution of speckle in an OCT tomogram.

    A per-pixel deviation map is built as the root-mean-square, across depth,
    of each en-face slice's deviation from that slice's own mean intensity.
    Pixels whose deviation exceeds ``threshold_fraction`` times the map's
    maximum are grouped into connected components with
    ``scipy.ndimage.label`` (default connectivity). Each component is counted
    as one speckle.

    Args:
        tom: A 4D numpy array of shape (depth, height, width, 2) with the real
            part in ``tom[..., 0]`` and the imaginary part in ``tom[..., 1]``.
        threshold_fraction: Fraction of the maximum deviation used as the
            binarization threshold (default 0.5).

    Returns:
        A tuple ``(mean_size, hist, std_image)``:
            mean_size: mean component size in pixels (NaN when no component
                exceeds the threshold, e.g. for a constant image).
            hist: ``hist[k]`` is the number of components of exactly k pixels.
            std_image: the (height, width) deviation map used for thresholding.
    """
    oct_image = np.abs(tom[:, :, :, 0] + 1j * tom[:, :, :, 1]) ** 2
    depth, height, width = oct_image.shape

    # Accumulate slice by slice (instead of vectorizing over depth) so no
    # second full-volume temporary is allocated.
    std_image = np.zeros((height, width))
    for i in range(depth):
        en_face = oct_image[i, :, :]
        std_image += (en_face - np.mean(en_face)) ** 2
    std_image = np.sqrt(std_image / depth)

    # Binarize relative to the strongest deviation.
    thresh = threshold_fraction * np.max(std_image)
    binary_image = std_image > thresh

    labels, num_labels = sc.ndimage.label(binary_image)

    if num_labels == 0:
        # Nothing exceeded the threshold (e.g. an all-zero deviation map):
        # np.max on an empty size array would raise, so report "no speckle".
        return np.nan, np.zeros(1, dtype=np.intp), std_image

    # Pixel count per label in one pass; label 0 is the background, drop it.
    sizes = np.bincount(labels.ravel(), minlength=num_labels + 1)[1:].astype(float)

    mean_size = np.mean(sizes)

    # Unit-width integer bins, so hist[k] counts components of size k.
    hist, _ = np.histogram(sizes, bins=range(int(np.max(sizes)) + 2))

    return mean_size, hist, std_image



def calculate_oct_std(tom, neighborhood_size=5):
    """
    Calculate the local (Gaussian-weighted) standard deviation of an OCT tomogram.

    Each en-face slice (fixed index along the first axis) is processed
    independently. With G the Gaussian filter of
    ``sigma = neighborhood_size / 6``, the local variance is
    ``G(I**2) - G(I)**2`` and the result is its square root.

    Args:
        tom: A 4D numpy array of shape (depth, height, width, 2) with the real
            part in ``tom[..., 0]`` and the imaginary part in ``tom[..., 1]``.
        neighborhood_size: Approximate width in pixels of the local
            neighborhood. The Gaussian sigma is ``neighborhood_size / 6``
            (presumably so that +-3 sigma spans the neighborhood).

    Returns:
        A 3D float64 numpy array of shape (depth, height, width) holding the
        local standard deviation of the intensity ``|real + 1j*imag|**2``.
    """
    oct_image = np.abs(tom[:, :, :, 0] + 1j * tom[:, :, :, 1]) ** 2
    sigma = neighborhood_size / 6

    std_image = np.zeros(oct_image.shape)
    for i in range(oct_image.shape[0]):
        # float64 limits cancellation error in E[I^2] - E[I]^2 (float32 input possible).
        en_face = oct_image[i, :, :].astype(np.float64)
        local_mean = gaussian_filter(en_face, sigma=sigma)
        local_mean_sq = gaussian_filter(en_face ** 2, sigma=sigma)
        # Rounding can push the variance slightly below zero on flat regions.
        local_var = np.clip(local_mean_sq - local_mean ** 2, 0.0, None)
        std_image[i, :, :] = np.sqrt(local_var)

    return std_image

def mps(tom):
    """
    Calculate the mean power spectrum of a complex OCT tomogram.

    Parameters:
        tom (ndarray): 4D array of shape (depth, height, width, 2) with the
            real part in ``tom[..., 0]`` and the imaginary part in ``tom[..., 1]``.

    Returns:
        tuple ``(mean_power_spectrum, tdmps)``:
            mean_power_spectrum (ndarray): 1D array of length ``width``. It is
                the 3D power spectrum ``|fftshift(fftn(tomogram))|**2``
                averaged over the first two axes.
            tdmps (ndarray): 2D array of shape (height, width). It is the power
                spectrum averaged over the first axis and scaled so its maximum
                is 1. It is left as zeros if the spectrum is identically zero.

    After fftshift, the zero frequency sits at index ``n // 2`` of each axis.
    """
    tomogram = tom[:, :, :, 0] + 1j * tom[:, :, :, 1]

    # Full 3D FFT with the zero frequency moved to the centre of every axis.
    tomogram_ft = np.fft.fftshift(np.fft.fftn(tomogram))
    power_spectrum = np.abs(tomogram_ft) ** 2

    mean_power_spectrum = np.mean(power_spectrum, axis=(0, 1))
    tdmps = np.mean(power_spectrum, axis=0)

    # Normalize to a peak of 1; an all-zero spectrum would otherwise become NaN.
    peak = np.max(tdmps)
    if peak > 0:
        tdmps = tdmps / peak

    return mean_power_spectrum, tdmps

In [None]:
""" Load tomograms"""
# --- Tomogram location ---------------------------------------------------------
# The root folder can be overridden with the DLOCT_ROOT environment variable so
# the notebook also runs on machines where the data is not under D:/.
rootFolder = os.environ.get('DLOCT_ROOT', 'D:/DLOCT/TomogramsDataAcquisition/')  # porcine cornea
fnameTom = 'ExperimentalTomogram/ExperimentalROI_corrected5'  # porcine cornea

# Alternative dataset (s.eye_swine):
# rootFolder = 'D:/DLOCT/TDG/OCT_Real/nueva data/'
# fnameTom = '[p.SHARP][s.Eye2a][10-09-2019_13-14-42]_TomJones_z=(586)_x=(512)_y=(512)'

# --- Shape of each tomogram, as tuples (Z, X, Y) -------------------------------
tomShape = [(350, 384, 384)]  # porcine cornea
# tomShape = [(586, 512, 512)]  # s.eye_swine

# os.path.join tolerates a root folder with or without a trailing separator.
fname = os.path.join(rootFolder, fnameTom)
# Names of all real and imaginary .bin files
fnameTomReal = [fname + '_real.bin']
fnameTomImag = [fname + '_imag.bin']

In [None]:
# np.fromfile needs the on-disk element type: the default float64 for the
# porcine cornea data, np.float32 ('single') for the s.eye_swine data.
# Both parts are read with this one setting so they cannot get out of sync.
tomDtype = np.float64


def _read_tomogram_part(path, shape, dtype):
    """Read one raw .bin part (real or imaginary) and reshape it to `shape`."""
    data = np.fromfile(path, dtype=dtype)
    expected = int(np.prod(shape))
    if data.size != expected:
        raise ValueError(
            f'{path}: read {data.size} values of {np.dtype(dtype).name}, '
            f'expected {expected} for shape {shape}; check tomShape and tomDtype'
        )
    # Fortran (column-major) order to import according to MATLAB.
    return data.reshape(shape, order='F')


tomReal = _read_tomogram_part(fnameTomReal[0], tomShape[0], tomDtype)
tomImag = _read_tomogram_part(fnameTomImag[0], tomShape[0], tomDtype)

# Last axis: 0 = real part, 1 = imaginary part.
tomData = np.stack((tomReal, tomImag), axis=3)
# Optional normalization:
# tomData = tomData/np.max(abs(tomData))

In [None]:
z = 128  # en-face slice index along the first (Z) axis
enface_db = 10 * np.log10(np.abs(tomData[z, :, :, 0] + 1j * tomData[z, :, :, 1]) ** 2)
fig, ax = plt.subplots()
im = ax.imshow(enface_db)
ax.set_title(f'En face intensity at z = {z} (dB)')
fig.colorbar(im, ax=ax)
print(tomData.shape)

In [None]:

# Calculate the SNR of the OCT image.
# calculate_oct_snr divides a mean intensity (|field|**2) by an intensity
# standard deviation, i.e. a ratio of power quantities, so decibels are
# 10*log10 (20*log10 applies to amplitude ratios and would double the value).
snr = calculate_oct_snr(tomData)
snrdB = 10 * np.log10(snr)
print('SNR:', snr, 'SNRdB:', snrdB)

In [None]:

# Calculate the size and distribution of speckle in the OCT image.
# The third return value is the (height, width) deviation map, not a list of
# speckle sizes; `sizes` is kept as an alias because later cells use that name.
mean_size, hist, speckle_std_image = calculate_oct_speckle(tomData)
sizes = speckle_std_image

print('Mean speckle size:', mean_size)

fig, ax = plt.subplots()
ax.bar(range(len(hist)), hist)
ax.set_xlabel('Speckle size')
ax.set_ylabel('Number of speckles')
ax.set_title('Speckle size histogram (pixels per component)')
plt.show()

In [None]:
# `sizes` is the (height, width) deviation map returned by calculate_oct_speckle.
fig, ax = plt.subplots()
im = ax.imshow(sizes)
ax.set_title('Speckle deviation map')
fig.colorbar(im, ax=ax)
print(np.max(sizes))

In [None]:
# Calculate the local standard deviation of the OCT image
std_image = calculate_oct_std(tomData)
mean_std = np.mean(std_image)
z = 128  # en-face slice index along the first (Z) axis

# Display an en-face slice (fixed z) next to its local standard deviation
fig, axs = plt.subplots(1, 2)
im0 = axs[0].imshow(10 * np.log10(np.abs(tomData[z, :, :, 0] + 1j * tomData[z, :, :, 1]) ** 2))
axs[0].set_title('en face plane')
fig.colorbar(im0, ax=axs[0])
im1 = axs[1].imshow(std_image[z, :, :])
axs[1].set_title('Local standard deviation')
fig.colorbar(im1, ax=axs[1])
plt.show()

In [None]:
# Compare the square root of the depth-mean of std_image (from calculate_oct_std)
# with the speckle deviation map `sizes`, each scaled to a maximum of 1.
std_image2 = np.sqrt(np.mean(std_image, axis=0))
std_image2 = std_image2 / np.max(std_image2)
sizes2 = sizes / np.max(sizes)
dif = std_image2 - sizes2

fig, ax = plt.subplots()
im = ax.imshow(dif)
ax.set_title('Normalized local std minus normalized speckle deviation')
fig.colorbar(im, ax=ax)
plt.show()



In [None]:

# Do not unpack into `mps`: that would shadow the function and make this cell
# fail ('numpy.ndarray' object is not callable) whenever it is re-run.
mean_power_spectrum, power = mps(tomData)



In [None]:
# fftshift puts the zero frequency at index n // 2 (192 for 384 columns, 256
# for 512), so derive the centre column from the data instead of hardcoding it.
center_col = power.shape[1] // 2

fig, axs = plt.subplots(1, 2)
axs[0].plot(power[:, center_col])
axs[0].set_title(f'Normalized power, column {center_col}')
axs[1].imshow(power)
axs[1].set_title('Power spectrum averaged over z (normalized)')
plt.show()