In [None]:
!pip install tifffile tensorflow scikit-learn matplotlib

In [None]:
import os
import numpy as np
import tifffile as tiff
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import plotly.graph_objects as go

# Module-level constants.
# Organ-class count; NOTE(review): not referenced by any code in this notebook.
NUM_CLASSES = 15

# Load and min-max normalize a single .tif stack without resizing
def load_tif_stack(tif_folder, extensions=('.tif',)):
    """Load every matching TIFF file in *tif_folder* into one normalized array.

    Files are read in sorted path order (plain lexicographic, so slice names
    should be zero-padded to keep numeric order) and stacked along a new
    leading axis: ``stack[i]`` is the image from the i-th file. The whole
    stack is then min-max scaled to [0, 1] with its global min and max.

    Args:
        tif_folder: Directory containing the image files.
        extensions: Filename suffix or tuple of suffixes to load, matched
            case-sensitively with ``str.endswith``. Defaults to ``('.tif',)``.

    Returns:
        ``np.ndarray`` of shape ``(n_files, *image_shape)`` scaled to [0, 1].
        A constant-valued stack (zero intensity range) is returned as all
        zeros instead of the NaNs a 0/0 division would produce.

    Raises:
        ValueError: if the folder contains no file with a matching suffix.
    """
    print(f"Loading TIFF stack from folder: {tif_folder}")
    # Accept a single suffix string as well; tuple('.tif') would split it into characters.
    suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)
    tif_files = sorted(
        os.path.join(tif_folder, name)
        for name in os.listdir(tif_folder)
        if name.endswith(suffixes)
    )
    if not tif_files:
        # Fail with a clear message instead of numpy's "zero-size array" error from np.min.
        raise ValueError(f"No {' / '.join(suffixes)} files found in {tif_folder!r}")

    slices = []
    for tif_file in tif_files:
        print(f"Loading file: {tif_file}")
        slices.append(tiff.imread(tif_file))
    stack = np.array(slices)

    lo, hi = np.min(stack), np.max(stack)
    span = hi - lo
    if span == 0:
        # Constant stack: the min-max formula would be 0/0 (all NaN).
        stack = np.zeros(stack.shape, dtype=np.float64)
    else:
        stack = (stack - lo) / span  # Normalize to [0, 1]
    print(f"Loaded stack shape: {stack.shape}")
    return stack

# Autoencoder model for anomaly detection
def build_autoencoder(input_shape):
    """Build a small convolutional autoencoder whose output shape equals its input shape.

    Layers: Conv2D(32)+MaxPool, Conv2D(64)+MaxPool (encoder), Conv2D(128)
    (bottleneck), UpSample+Conv2D(64), UpSample, then a sigmoid Conv2D with as
    many filters as the input has channels. The sigmoid output assumes inputs
    scaled to [0, 1].

    Args:
        input_shape: ``(height, width, channels)`` of one sample, without the
            batch dimension. Height/width may be ``None``; in that case no
            cropping is added (see below).

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    # Local import keeps the notebook's top-level import cell unchanged.
    from tensorflow.keras.layers import Cropping2D

    height, width, channels = input_shape
    inputs = Input(shape=input_shape)

    # Encoder
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    x = MaxPooling2D((2, 2), padding='same')(x)
    x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = MaxPooling2D((2, 2), padding='same')(x)

    # Bottleneck
    x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)

    # Decoder
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = UpSampling2D((2, 2))(x)

    # Two 'same'-padded 2x poolings round odd sizes up, so the decoder yields
    # 4 * ceil(size / 4) rows/cols. Crop the excess (0-3 px at the bottom/right)
    # so reconstructions line up with inputs; without this, fit() fails with a
    # shape mismatch whenever height or width is not a multiple of 4.
    crop_h = (-height) % 4 if height is not None else 0
    crop_w = (-width) % 4 if width is not None else 0
    if crop_h or crop_w:
        x = Cropping2D(cropping=((0, crop_h), (0, crop_w)))(x)

    # Match the input's channel count (the original hard-coded 1 channel).
    outputs = Conv2D(channels, (3, 3), activation='sigmoid', padding='same')(x)

    return Model(inputs, outputs)

# Ensure all stacks have consistent dimensions by zero-padding
def preprocess_stacks(all_stacks):
    """Zero-pad each stack to the element-wise largest shape and stack them.

    Padding is appended at the end of every axis, including axis 0: a stack
    with fewer slices than the largest one is extended with all-zero slices.
    Callers that compute per-stack statistics should remember the original
    slice counts and trim those blank slices off again.

    Args:
        all_stacks: Non-empty sequence of arrays that all have the same ndim.

    Returns:
        ``np.ndarray`` of shape ``(len(all_stacks), *target_shape)``, where
        ``target_shape`` is the per-axis maximum over all input shapes.

    Raises:
        ValueError: if *all_stacks* is empty or the arrays differ in ndim.
    """
    # len() rather than truthiness so an ndarray argument is also accepted.
    if len(all_stacks) == 0:
        raise ValueError("preprocess_stacks() needs at least one stack")
    ndims = {np.ndim(stack) for stack in all_stacks}
    if len(ndims) != 1:
        raise ValueError(
            f"All stacks must have the same number of dimensions, got {sorted(ndims)}"
        )

    # Target shape: largest extent along each axis across all stacks.
    target_shape = tuple(np.max([stack.shape for stack in all_stacks], axis=0))

    padded = [
        np.pad(
            stack,
            [(0, target - size) for size, target in zip(stack.shape, target_shape)],
            mode='constant',
            constant_values=0,  # Pad with zeros
        )
        for stack in all_stacks
    ]
    return np.array(padded)

# Train the autoencoder on slices
def train_shared_autoencoder(autoencoder, all_stacks, epochs=5, batch_size=8,
                             validation_split=0.2, seed=0):
    """Compile *autoencoder* (adam / MSE) and fit it to reconstruct every slice.

    All stacks are concatenated along axis 0, so each slice of each stack is
    one training sample, used as both input and target.

    Keras' ``validation_split`` takes the *last* fraction of the arrays before
    any shuffling. Because the concatenation is ordered by group, passing it
    directly would hold out (most of) the last group and never train on it,
    inflating that group's reconstruction error afterwards. The samples are
    therefore shuffled once here, before ``fit``.

    Args:
        autoencoder: A Keras model with ``compile`` and ``fit``.
        all_stacks: Sequence (or array) of stacks with identical per-slice shape.
        epochs: Number of training epochs.
        batch_size: Mini-batch size.
        validation_split: Fraction of (shuffled) samples held out for validation.
        seed: Seed for the pre-fit shuffle; ``None`` gives a random order.

    Returns:
        The ``History`` object returned by ``fit``.
    """
    autoencoder.compile(optimizer='adam', loss='mse')
    all_data = np.concatenate(all_stacks, axis=0)
    # In-place shuffle along axis 0 (all_data is a fresh array from concatenate),
    # so the validation split draws from every group, not just the last one.
    np.random.default_rng(seed).shuffle(all_data)
    history = autoencoder.fit(
        all_data, all_data,
        epochs=epochs,
        batch_size=batch_size,
        shuffle=True,
        validation_split=validation_split,
    )
    return history

# Compute reconstruction error for anomaly detection and highlight differences
def compute_reconstruction_error(autoencoder, data, return_reconstruction=False):
    """Reconstruct *data* and measure how far each sample is from its reconstruction.

    Args:
        autoencoder: Model exposing ``predict(data)`` that returns an array
            broadcast-compatible with *data*.
        data: Array whose first axis indexes samples (slices).
        return_reconstruction: If true, also return the predicted array so the
            caller does not have to run ``predict`` a second time.

    Returns:
        ``(errors, differences)``, or ``(errors, differences, reconstructed)``
        when *return_reconstruction* is true, where
        - ``errors`` is a list of per-sample mean squared errors (Python floats),
        - ``differences`` is ``abs(data - reconstructed)``.
    """
    reconstructed = autoencoder.predict(data)
    differences = np.abs(data - reconstructed)
    # Per-sample MSE over every non-sample axis, vectorized instead of one sklearn
    # call per slice; accumulate in float64 to avoid float32 rounding.
    sample_axes = tuple(range(1, differences.ndim))
    errors = np.mean(np.square(differences, dtype=np.float64), axis=sample_axes).tolist()
    if return_reconstruction:
        return errors, differences, reconstructed
    return errors, differences

# Save slice-by-slice comparisons using Plotly
def save_slice_comparisons(differences, stack, reconstructed, output_folder,
                           include_plotlyjs=True):
    """Write one HTML file per slice showing original, reconstruction and difference.

    Each figure has three side-by-side grayscale/heat panels. The original used
    three ``go.Image`` traces on a single set of axes: they overlaid each other,
    and ``go.Image`` expects RGB(A) pixels on a 0-255 scale, so 2-D [0, 1]
    slices did not render correctly. Heatmaps in separate subplots fix both.

    Args:
        differences: Per-slice ``|original - reconstructed|`` arrays.
        stack: Original slices (a trailing singleton channel axis is squeezed).
        reconstructed: Reconstructed slices, same layout as *stack*.
        output_folder: Directory for ``slice_XXX_comparison.html`` files;
            created if missing.
        include_plotlyjs: Passed to ``Figure.write_html``. ``True`` embeds the
            plotly.js bundle in every file (large for many slices); ``'cdn'``
            links it instead.
    """
    # Local import keeps the notebook's top-level import cell unchanged.
    from plotly.subplots import make_subplots

    os.makedirs(output_folder, exist_ok=True)
    for i, (orig, recon, diff) in enumerate(zip(stack, reconstructed, differences)):
        orig, recon, diff = np.squeeze(orig), np.squeeze(recon), np.squeeze(diff)
        # Shared intensity range so original and reconstruction are directly comparable.
        lo = float(min(orig.min(), recon.min()))
        hi = float(max(orig.max(), recon.max()))

        fig = make_subplots(rows=1, cols=3,
                            subplot_titles=('Original', 'Reconstructed', 'Difference'))
        fig.add_trace(go.Heatmap(z=orig, colorscale='gray', zmin=lo, zmax=hi,
                                 showscale=False, name='Original'), row=1, col=1)
        fig.add_trace(go.Heatmap(z=recon, colorscale='gray', zmin=lo, zmax=hi,
                                 showscale=False, name='Reconstructed'), row=1, col=2)
        fig.add_trace(go.Heatmap(z=diff, colorscale='hot', name='Difference'),
                      row=1, col=3)
        for col in (1, 2, 3):
            # Image orientation (row 0 at the top) and square pixels per panel.
            fig.update_yaxes(autorange='reversed',
                             scaleanchor='x' if col == 1 else f'x{col}',
                             row=1, col=col)
        fig.update_layout(title=f"Slice {i}")
        fig.write_html(os.path.join(output_folder, f"slice_{i:03d}_comparison.html"),
                       include_plotlyjs=include_plotlyjs)

# Generate a summary bar chart (mean ± std per group) using Plotly
def generate_summary_metrics(group_errors, output_path):
    """Write an HTML bar chart of each group's mean reconstruction error.

    Only a bar chart is produced (no boxplots). Each bar is the mean of that
    group's per-slice errors, and its error bar is their population standard
    deviation (``np.std``).

    Args:
        group_errors: Mapping of group label -> sequence of per-slice errors.
            Insertion order determines the bar order on the x-axis.
        output_path: Destination path for the HTML file.
    """
    labels, means, spreads = [], [], []
    for label, errs in group_errors.items():
        labels.append(label)
        means.append(np.mean(errs))
        spreads.append(np.std(errs))

    bars = go.Bar(
        x=labels,
        y=means,
        error_y=dict(type='data', array=spreads),
        name='Mean Reconstruction Error',
    )
    figure = go.Figure(data=[bars])
    figure.update_layout(
        title="Summary of Reconstruction Errors",
        xaxis_title="Groups",
        yaxis_title="Mean Reconstruction Error",
        showlegend=True,
    )
    figure.write_html(output_path)

# Analyze differences between groups using a shared autoencoder
def analyze_group_differences_with_shared_autoencoder(group_folders):
    """Train one autoencoder on all groups and report per-group reconstruction error.

    Steps:
    1. Load each folder's TIFF stack and add a channel axis.
    2. Zero-pad the stacks to a common slice size.
    3. Train a shared autoencoder on the slices of all groups.
    4. Compute per-slice MSE per group.
    5. Write per-slice comparison HTML files into ``<group>/slice_comparisons``.
    6. Write a summary chart to ``summary_metrics.html`` in the working directory.

    Args:
        group_folders: Non-empty sequence of folder paths, one per group.

    Returns:
        Dict mapping each folder path to its list of per-slice errors
        (a repeated path overwrites the earlier entry).

    Raises:
        ValueError: if *group_folders* is empty.
    """
    if not group_folders:
        raise ValueError("group_folders must contain at least one folder")

    all_stacks = []
    slice_counts = []
    for group in group_folders:
        print(f"Loading data for group: {group}")
        stack = load_tif_stack(group)[..., np.newaxis]  # Add channel dimension
        slice_counts.append(stack.shape[0])
        all_stacks.append(stack)

    # Pad to a common shape. preprocess_stacks also pads the slice axis with
    # all-zero slices; trim those off again so they are neither trained on nor
    # counted in a group's error statistics and comparisons.
    padded = preprocess_stacks(all_stacks)
    all_stacks = [stack[:count] for stack, count in zip(padded, slice_counts)]

    # Build and train a shared autoencoder (all slices now share one shape)
    input_shape = all_stacks[0].shape[1:]
    autoencoder = build_autoencoder(input_shape)
    print("Training shared autoencoder on all groups...")
    train_shared_autoencoder(autoencoder, all_stacks)

    group_errors = {}
    for group, stack in zip(group_folders, all_stacks):
        print(f"Computing reconstruction error for group {group}...")
        # Predict once and reuse the result for both the error metrics and the
        # saved comparisons (previously predict() ran twice per group). The
        # per-slice MSE is the same metric compute_reconstruction_error reports.
        reconstructed = autoencoder.predict(stack)
        differences = np.abs(stack - reconstructed)
        sample_axes = tuple(range(1, differences.ndim))
        group_errors[group] = np.mean(
            np.square(differences, dtype=np.float64), axis=sample_axes
        ).tolist()

        # Save slice-by-slice comparisons
        comparison_folder = os.path.join(group, 'slice_comparisons')
        save_slice_comparisons(differences, stack, reconstructed, comparison_folder)

    # Generate summary metrics
    generate_summary_metrics(group_errors, 'summary_metrics.html')

    return group_errors

# Example usage: point CT_GROUPS_DIR (or the default below) at the folder that
# holds one sub-folder of .tif slices per group (CT_1 ... CT_4).
# The __main__ guard keeps importing this file free of side effects; in a
# notebook __name__ is "__main__", so the cell still runs as before.
if __name__ == "__main__":
    base_dir = os.environ.get(
        "CT_GROUPS_DIR", "/Users/alexandergadin/Downloads/group_3-selected"
    )
    group_folders = [os.path.join(base_dir, f"CT_{i}") for i in range(1, 5)]
    group_differences = analyze_group_differences_with_shared_autoencoder(group_folders)
