**Preprocess OCT Images**
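Provide the input image path (`oct_file_path`) and the output folder path (`output_dir`). The preprocessed OCT stack is saved as `oct_preprocessed.tif`; the threshold segmentation and the 3D edge volume are saved in the same folder as `segmentation_lm1-oct.tif` and `edge_lm1-oct.tif`.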

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage import filters, exposure, morphology
from skimage.restoration import denoise_nl_means, estimate_sigma
from skimage.segmentation import clear_border
from tifffile import TiffFile, imwrite
from scipy.ndimage import gaussian_filter, sobel
from tqdm import tqdm

# -------------------------
# File paths and output directories
# -------------------------
# oct_file_path: TIFF to read; output_dir: folder that receives every TIFF written below.
oct_file_path = '/Users/atesfet/Library/CloudStorage/GoogleDrive-atesfet@stanford.edu/Shared drives/de la Zerda Lab/Lab Members/Ates/Initial Dataset/LM-01/LM-01_OCT_XY.tif'
output_dir = '/Users/atesfet/Library/CloudStorage/GoogleDrive-atesfet@stanford.edu/Shared drives/de la Zerda Lab/Lab Members/Ates/OCTHE/octhe_train/oct_preprocessed_0326'
# exist_ok=True creates the folder (and any missing parents) in one call,
# avoiding the check-then-create race of a separate os.path.exists() test.
os.makedirs(output_dir, exist_ok=True)

# -------------------------
# Load OCT Stack
# -------------------------
# Read the whole TIFF into a single array; the context manager closes the file.
# NOTE(review): later code iterates over axis 0, so a multi-page stack is assumed.
with TiffFile(oct_file_path) as tif:
    oct_stack = tif.asarray()
print(f"Loaded OCT stack with shape: {oct_stack.shape}")

# -------------------------
# Preprocessing Function
# -------------------------
def preprocess_image(img):
    """
    Preprocess a single OCT image.

    Steps:
      1. Convert to float32.
      2. Normalize by the image maximum and apply CLAHE.
      3. Estimate noise sigma (using channel_axis=None for grayscale).
      4. Denoise using non-local means with h = 1.15 * sigma.

    A blank (all-zero) frame is returned unchanged as float32 zeros: dividing
    by its maximum would be 0/0 and would feed NaNs into the later steps.

    Parameters
    ----------
    img : numpy.ndarray
        Single grayscale frame (the code passes channel_axis=None).

    Returns
    -------
    numpy.ndarray
        The equalized and denoised frame, same shape as the input.
    """
    img = img.astype(np.float32)
    # Guard against division by zero on frames that contain no signal at all.
    if not img.any():
        return img
    # Normalize and apply adaptive histogram equalization (CLAHE)
    img = exposure.equalize_adapthist(img / img.max())
    # Estimate noise sigma; for grayscale use channel_axis=None
    sigma = estimate_sigma(img, channel_axis=None)
    # Denoise using non-local means; the filter strength scales with the noise estimate
    img = denoise_nl_means(img, h=1.15 * sigma, fast_mode=True, channel_axis=None)
    return img

# -------------------------
# Preprocess the OCT Stack
# -------------------------
# Run every slice along axis 0 through preprocess_image and restack the results.
preprocessed_stack = np.array([preprocess_image(frame) for frame in tqdm(oct_stack)])

# Visualize a raw vs. preprocessed frame
# Slice 100 is the preferred preview; clamp to the last slice so stacks with
# 100 or fewer slices do not raise IndexError here or in later previews.
idx = min(100, oct_stack.shape[0] - 1)  # adjust index as needed
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(oct_stack[idx], cmap='gray')
plt.title("Raw OCT Frame")
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(preprocessed_stack[idx], cmap='gray')
plt.title("Preprocessed OCT Frame")
plt.axis('off')
plt.tight_layout()
plt.show()

# -------------------------
# Threshold-Based Segmentation (per slice)
# -------------------------
def segment_threshold(img):
    """
    Segment one image with a global Otsu threshold.

    Pixels above the Otsu threshold form the foreground mask; the mask is then
    passed through remove_small_objects (size argument 64) and clear_border
    before being returned.
    """
    foreground = img > filters.threshold_otsu(img)
    cleaned = morphology.remove_small_objects(foreground, 64)
    return clear_border(cleaned)

# Segment each preprocessed slice independently and restack into a volume.
threshold_segmented = np.array([segment_threshold(frame) for frame in tqdm(preprocessed_stack)])

# Visualize threshold segmentation on one slice
fig, (ax_pre, ax_seg) = plt.subplots(1, 2, figsize=(12, 6))
# Left panel: the preprocessed slice used as segmentation input.
ax_pre.imshow(preprocessed_stack[idx], cmap='gray')
ax_pre.set_title('Preprocessed')
ax_pre.axis('off')
# Right panel: the binary mask produced for the same slice.
ax_seg.imshow(threshold_segmented[idx], cmap='gray')
ax_seg.set_title('Threshold Segmentation')
ax_seg.axis('off')
plt.show()

# -------------------------
# Save the Preprocessed OCT Stack as a Single Multi-Page TIFF
# -------------------------
preprocessed_output_path = os.path.join(output_dir, "oct_preprocessed.tif")
# Scale to 0-255 and convert to uint8 before writing
preprocessed_uint8 = (preprocessed_stack * 255).astype(np.uint8)
imwrite(preprocessed_output_path, preprocessed_uint8)
print("Saved preprocessed OCT stack as a single TIFF:", preprocessed_output_path)

# -------------------------
# Save the 3D Threshold Segmentation as a Single Multi-Page TIFF
# -------------------------
segmentation_output_path = os.path.join(output_dir, "segmentation_lm1-oct.tif")
# Map the boolean mask to 0 / 255 so foreground is white in the saved TIFF
segmentation_uint8 = threshold_segmented.astype(np.uint8) * 255
imwrite(segmentation_output_path, segmentation_uint8)
print("Saved 3D threshold segmentation as a single TIFF:", segmentation_output_path)

# -------------------------
# 3D Segmentation and Smooth Edge Generation
# -------------------------
# Smoothing needs a float volume, so cast the binary segmentation first
threshold_float = threshold_segmented.astype(np.float32)
# 3D Gaussian blur of the binary volume (tune sigma as needed)
smoothed_seg = gaussian_filter(threshold_float, sigma=1)

# Sobel derivative along each array axis; note sobel_x/y/z correspond to
# array axes 0/1/2 (axis 0 is the slice index used elsewhere in this script)
sobel_x, sobel_y, sobel_z = (sobel(smoothed_seg, axis=ax) for ax in range(3))
# Gradient magnitude from the three directional derivatives
edges_3d = np.sqrt(np.square(sobel_x) + np.square(sobel_y) + np.square(sobel_z))

# Optionally threshold the gradient to yield a binary edge map
edge_threshold = 0.05  # adjust threshold as needed
edges_binary = np.greater(edges_3d, edge_threshold)

# Visualize one slice of the 3D edge volume for inspection
edge_fig, edge_ax = plt.subplots(figsize=(6, 6))
edge_ax.imshow(edges_binary[idx], cmap='gray')
edge_ax.set_title("3D Smooth Edge Segmentation (Slice {})".format(idx))
edge_ax.axis('off')
plt.show()

# -------------------------
# Save 3D Smooth Edges as a Single Multi-Page TIFF
# -------------------------
edge_output_path = os.path.join(output_dir, "edge_lm1-oct.tif")
# Map the boolean edge mask to 0 / 255 so edges are white in the saved TIFF
edges_uint8 = edges_binary.astype(np.uint8) * 255
imwrite(edge_output_path, edges_uint8)
print("Saved 3D smooth edges as a single TIFF:", edge_output_path)