In [None]:
# Mount Google Drive (Colab-only) so the input/output folders under
# /content/drive/MyDrive used below are reachable from this notebook.
from google.colab import drive
drive.mount('/content/drive')

import os
import numpy as np
from skimage import io, filters, exposure, restoration, img_as_float, img_as_uint
import glob
import tifffile
from pathlib import Path
import matplotlib.pyplot as plt

# Source folder (denoised 4-D stacks) and destination folder (2-D projections)
input_dir = '/content/drive/MyDrive/knowledge/University/Master/Thesis/denoised/Static-A-1'
output_dir = '/content/drive/MyDrive/knowledge/University/Master/Thesis/Projected/Static-A-1'

# Human-readable names for the channel axis, in channel-index order
channel_names = ['Cadherins', 'Nuclei', 'Golgi']

# One output sub-folder per channel; folders that already exist are left alone
for folder_name in channel_names:
    Path(output_dir, folder_name).mkdir(parents=True, exist_ok=True)

def normalize_image(img):
    """Linearly rescale an array so its values span the 0-1 range.

    Parameters:
    -----------
    img : numpy.ndarray
        Array of any shape (a single slice or a whole stack).

    Returns:
    --------
    numpy.ndarray
        ``(img - min) / (max - min)``. Float inputs keep their dtype; integer
        inputs come back as float64. A constant array maps to all zeros, so
        the result always lies in [0, 1]. Empty arrays, and arrays whose
        min/max are NaN, are returned unchanged because no range is defined.
    """
    img = np.asarray(img)
    if img.size == 0:
        # np.min/np.max raise on empty input; there is nothing to rescale.
        return img

    img_min = np.min(img)
    img_max = np.max(img)
    if img_max > img_min:
        return (img - img_min) / (img_max - img_min)
    if img_max == img_min:
        # A flat input has no range to stretch. Map it to 0 so the output
        # still honours the [0, 1] contract. Returning it as-is would leave
        # e.g. a constant 4.5 outside that range.
        out_dtype = img.dtype if np.issubdtype(img.dtype, np.floating) else np.float64
        return np.zeros(img.shape, dtype=out_dtype)
    # Neither comparison held, so min/max are NaN: leave the data untouched.
    return img

def contrast_weighted_projection(image_stack, sigma=2):
    """
    Create a weighted average projection where weights are proportional to local contrast

    Each z-slice is weighted, pixel by pixel, by the Gaussian-smoothed absolute
    Laplacian of that slice, so high-contrast slices dominate the average.
    Where no slice has any measurable contrast the weighting is undefined, and
    the plain mean over z is used instead.

    Parameters:
    -----------
    image_stack : numpy.ndarray
        Stack of images with shape (z, y, x)
    sigma : float, optional
        Standard deviation of the Gaussian used to smooth the contrast map
        (default 2, the value previously hard-coded).

    Returns:
    --------
    numpy.ndarray
        Contrast-weighted projected image with shape (y, x), in the floating
        dtype produced by ``img_as_float``.

    Raises:
    -------
    ValueError
        If the stack is not 3-D or contains no slices.
    """
    stack = img_as_float(image_stack)
    if stack.ndim != 3 or stack.shape[0] == 0:
        raise ValueError(f"Expected a non-empty (z, y, x) stack, got shape {stack.shape}")

    # Local contrast per slice: smoothed magnitude of the Laplacian (edge strength).
    weights = np.empty_like(stack)
    for i, plane in enumerate(stack):
        weights[i] = filters.gaussian(np.abs(filters.laplace(plane)), sigma=sigma)

    weight_sum = weights.sum(axis=0)

    # Pixels where every slice has zero weight would otherwise collapse to 0
    # (black), even inside bright uniform regions. Give every slice equal
    # weight there, which turns the weighted average into the ordinary mean.
    flat = weight_sum == 0
    if np.any(flat):
        weights[:, flat] = 1.0
        weight_sum[flat] = stack.shape[0]

    return (stack * weights).sum(axis=0) / weight_sum

def process_4d_microscopy_image(image, filename):
    """
    Process a 4D microscopy image with shape (channels, z-slices, height, width)
    Apply contrast-weighted projection to each channel

    Float images whose values fall outside [0, 1] are first rescaled to [0, 1]
    channel by channel, over each channel's whole z-stack at once. Each channel
    is then projected with ``contrast_weighted_projection``, contrast-enhanced
    with ``exposure.equalize_adapthist`` and sharpened with
    ``filters.unsharp_mask``.

    Parameters:
    -----------
    image : numpy.ndarray
        4D input image with shape (channels, z-slices, height, width).
        The caller's array is not modified.
    filename : str
        Original filename; currently unused, kept for interface compatibility.

    Returns:
    --------
    dict
        Maps each name in ``channel_names`` to its processed 2-D projection.
        Only the first ``len(channel_names)`` channels are processed. A channel
        whose processing raises is reported and left out of the result.
    """
    print(f"  Processing 4D image with shape {image.shape}")

    if np.issubdtype(image.dtype, np.floating):
        img_min, img_max = np.min(image), np.max(image)
        # skimage's float conversions (e.g. img_as_uint when saving) expect data
        # in [0, 1] / [-1, 1]. Rescaling anything outside [0, 1] also covers the
        # small negative values a denoiser can leave behind.
        if img_max > 1.0 or img_min < 0.0:
            print(f"  Normalizing image values from range [{img_min}, {img_max}] to [0, 1]")
            # Normalize each channel over its entire z-stack rather than slice by
            # slice. Per-slice stretching inflates dim, out-of-focus slices and
            # distorts the contrast weights used by the projection. np.stack
            # builds a new array, so the caller's `image` is left intact.
            image = np.stack([normalize_image(channel_stack) for channel_stack in image])

    n_channels = image.shape[0]
    if n_channels > len(channel_names):
        print(f"  Warning: only the first {len(channel_names)} of {n_channels} channels "
              f"have names and will be processed")

    projections = {}
    for channel, (name, channel_data) in enumerate(zip(channel_names, image)):
        try:
            # Collapse z with contrast weighting, then CLAHE and sharpening.
            proj = contrast_weighted_projection(channel_data)
            enhanced = exposure.equalize_adapthist(proj, clip_limit=0.03)
            enhanced = filters.unsharp_mask(enhanced, radius=2, amount=1.5)
            projections[name] = enhanced
        except Exception as e:
            # Best effort: one failing channel must not abort the others.
            print(f"  Error applying contrast projection to channel {channel+1}: {e}")

    return projections

def _find_tiff_files(directory):
    """Return the sorted paths of the .tif/.tiff files directly inside *directory*.

    The extension is matched case-insensitively (.tif, .TIF, .Tif, .tiff, ...)
    and each file is listed exactly once. Hidden entries (names starting with
    '.') are skipped, as ``glob('*')`` does. A missing directory yields an
    empty list.
    """
    if not os.path.isdir(directory):
        return []
    matches = []
    for entry in os.listdir(directory):
        if entry.startswith('.'):
            continue
        if os.path.splitext(entry)[1].lower() not in ('.tif', '.tiff'):
            continue
        path = os.path.join(directory, entry)
        # Skip sub-directories that merely happen to end in ".tif".
        if os.path.isfile(path):
            matches.append(path)
    return sorted(matches)


def process_all_images(src_dir=None, dst_dir=None):
    """
    Process all 4D microscopy images in the input directory

    Every TIFF file in ``src_dir`` is loaded and projected with
    ``process_4d_microscopy_image``. Each channel's result is written to
    ``<dst_dir>/<channel name>/<base>_contrast_<channel name>.tif``, and
    sub-folders are created if missing. Files that are not 4-D, or that fail to
    load or process, are reported and skipped.

    Parameters:
    -----------
    src_dir : str, optional
        Folder to read from; defaults to the module-level ``input_dir``.
    dst_dir : str, optional
        Folder to write to; defaults to the module-level ``output_dir``.

    Returns:
    --------
    list of str
        Paths of all projection files written.
    """
    import traceback

    src_dir = input_dir if src_dir is None else src_dir
    dst_dir = output_dir if dst_dir is None else dst_dir

    all_files = _find_tiff_files(src_dir)
    if not all_files:
        print(f"No image files found in {src_dir}")
        return []

    print(f"Found {len(all_files)} image files. Processing each with contrast projection...")

    processed_files = []
    for idx, file_path in enumerate(all_files, start=1):
        filename = os.path.basename(file_path)
        print(f"Processing image {idx}/{len(all_files)}: {filename}")

        try:
            image = tifffile.imread(file_path)
            print(f"  Image shape: {image.shape}, dtype: {image.dtype}")

            if image.ndim != 4:
                print("  Skipping - not a 4D image")
                continue

            projections = process_4d_microscopy_image(image, filename)

            base_name = os.path.splitext(filename)[0]
            for channel_name, projection in projections.items():
                channel_dir = os.path.join(dst_dir, channel_name)
                os.makedirs(channel_dir, exist_ok=True)
                output_path = os.path.join(channel_dir, f"{base_name}_contrast_{channel_name}.tif")

                # Float results are stored as uint16 to keep 16-bit precision.
                if np.issubdtype(projection.dtype, np.floating):
                    projection = img_as_uint(projection)

                tifffile.imwrite(output_path, projection)
                processed_files.append(output_path)

            print(f"  Saved {len(projections)} projections")

        except Exception as e:
            # Best effort per file: report, show the traceback, move on.
            print(f"  Error processing image {filename}: {e}")
            traceback.print_exc()

    return processed_files

# Main execution
print("Starting 4D image processing with contrast projection method...")

# Process all image files in the directory
processed_files = process_all_images()

# Summarize results
if processed_files:
    print(f"\nSuccessfully processed {len(processed_files)} projections:")
    for output_file in processed_files[:5]:  # Show first 5
        print(f"  - {os.path.basename(output_file)}")
    if len(processed_files) > 5:
        print(f"  ... and {len(processed_files) - 5} more")

    # IPython.display.Image cannot embed TIFF files, and an Image object that
    # is not the cell's last expression is never rendered anyway. Read the
    # saved projection back and show it with matplotlib instead.
    try:
        print("\nDisplaying first result:")
        first_projection = tifffile.imread(processed_files[0])
        plt.figure(figsize=(6, 6))
        plt.imshow(first_projection, cmap="gray")
        plt.title(os.path.basename(processed_files[0]))
        plt.axis("off")
        plt.show()
    except Exception as e:
        print(f"Couldn't display image: {e}")
else:
    print("No images were processed successfully.")

Mounted at /content/drive
Starting 4D image processing with contrast projection method...
Found 16 image files. Processing each with contrast projection...
Processing image 1/16: denoised_0Pa_A1_19dec21_20xA_L2RA_FlatA_seq001.tif
  Image shape: (3, 13, 1024, 1024), dtype: float32
  Processing 4D image with shape (3, 13, 1024, 1024)
  Normalizing image values from range [-0.03957584500312805, 4.482115745544434] to [0, 1]
  Saved 3 projections
Processing image 2/16: denoised_0Pa_A1_19dec21_20xA_L2RA_FlatA_seq002.tif
  Image shape: (3, 13, 1024, 1024), dtype: float32
  Processing 4D image with shape (3, 13, 1024, 1024)
  Normalizing image values from range [-0.03917171433568001, 5.5141215324401855] to [0, 1]
  Saved 3 projections
Processing image 3/16: denoised_0Pa_A1_19dec21_20xA_L2RA_FlatA_seq003.tif
  Image shape: (3, 13, 1024, 1024), dtype: float32
  Processing 4D image with shape (3, 13, 1024, 1024)
  Normalizing image values from range [-0.05667900666594505, 4.3115363121032715] to [