In [None]:
from __future__ import annotations
import sys
from typing import Optional, Iterable, Union
import h5py
import cv2
import numpy as np
import pandas as pd
from skimage.measure import regionprops, label
from skimage.io import imread
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from scipy import ndimage


### Source the cell images and extract some features (old)

In [None]:
def extract_nuclei_cell_features(image: np.ndarray, mask: np.ndarray, cell_name: str) -> dict:
    """
    Compute shape and intensity features for one cell.

    Shape features come from the first region that ``regionprops`` returns for the
    labelled mask; the intensity totals use every non-zero mask pixel.

    Args:
        image: Intensity image the features are measured on.
        mask: Segmentation mask; pixels > 0 count as "inside".
        cell_name: Identifier stored under the 'cell_name' key.

    Returns:
        Flat dict of features (not a DataFrame): region properties, 'cell_name',
        and the inside/outside intensity totals with their ratios.

    Raises:
        ValueError: If the mask contains no labelled region.
    """
    # Label connected regions in the mask and measure them against the image
    labeled_mask = label(mask)
    properties = regionprops(labeled_mask, intensity_image=image)
    if not properties:
        # Fail with a message that names the cell instead of a bare IndexError
        raise ValueError(f"No labelled region found in mask for cell {cell_name}")
    region = properties[0]

    feat_dict = {
        'label': region.label,
        'area': region.area,
        'perimeter': region.perimeter,
        'mean_intensity': region.mean_intensity,
        'eccentricity': region.eccentricity,
        'solidity': region.solidity,
        'extent': region.extent,
        'major_axis_length': region.major_axis_length,
        'minor_axis_length': region.minor_axis_length,
        'cell_name': cell_name,
    }

    # Total intensity inside and outside the mask
    total_intensity_inside = np.sum(image[mask > 0])
    total_intensity_outside = np.sum(image[mask == 0])
    total_intensity = total_intensity_inside + total_intensity_outside

    # Ratio of intensity outside vs inside; all four values are zeroed when the
    # overall total is not positive.
    if total_intensity > 0:
        feat_dict['intensity_inside'] = total_intensity_inside
        feat_dict['intensity_outside'] = total_intensity_outside
        feat_dict['intensity_ratio_outside_inside'] = (
            total_intensity_outside / total_intensity_inside if total_intensity_inside > 0 else np.inf
        )
        feat_dict['intensity_fraction_outside'] = total_intensity_outside / total_intensity
    else:
        feat_dict['intensity_inside'] = 0
        feat_dict['intensity_outside'] = 0
        feat_dict['intensity_ratio_outside_inside'] = 0
        feat_dict['intensity_fraction_outside'] = 0
    return feat_dict

def normalize_and_filter_image(image: np.ndarray, apply_gaussian: bool = True, sigma: float = 1.0) -> np.ndarray:
    """
    Optionally smooth an image with a Gaussian filter, then z-score normalize it.

    Args:
        image: Input image array
        apply_gaussian: Whether to apply Gaussian filter
        sigma: Standard deviation for Gaussian filter

    Returns:
        ``(image - mean) / std`` of the (optionally filtered) image. If the standard
        deviation is zero a warning is printed and the image is returned un-normalized.
    """
    image = np.asarray(image)

    # Optional: apply the Gaussian filter
    if apply_gaussian:
        # gaussian_filter stores its result in the input dtype, so an integer image
        # would be smoothed with integer rounding. Filter in floating point instead.
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(np.float64)
        image = ndimage.gaussian_filter(image, sigma=sigma)

    # Normalize to zero mean and unit standard deviation
    img_mean = image.mean()
    img_std = image.std()
    if img_std > 0:
        norm_image = (image - img_mean) / img_std
    else:
        print("Warning: Standard deviation is zero during normalization.")
        norm_image = image

    return norm_image

def illumination_correction_basicpy(images: list, apply_to_image: Optional[np.ndarray] = None) -> tuple:
    """
    Estimate an illumination correction with BaSiC (basicpy).

    Args:
        images: Images used to estimate the background. A list/tuple is stacked
            along a new first axis; anything else is passed to BaSiC as is.
        apply_to_image: Optional single image to correct with the fitted model.

    Returns:
        (corrected_image, flatfield, darkfield) if apply_to_image is given,
        (flatfield, darkfield) otherwise.
    """
    # The notebook's import cell does not import basicpy, so import it here;
    # this also keeps the dependency optional for code paths that never correct.
    import basicpy

    # Convert a sequence of images into a single stacked array
    if isinstance(images, (list, tuple)):
        images_array = np.stack(images, axis=0)
    else:
        images_array = images

    # Fit the BaSiC model, estimating the darkfield as well
    basic = basicpy.BaSiC(get_darkfield=True, smoothness_flatfield=1)
    basic.fit(images_array)

    flatfield = basic.flatfield
    darkfield = basic.darkfield

    if apply_to_image is not None:
        # transform() is given a stack here, so add a leading axis and strip it again
        corrected = basic.transform(apply_to_image[np.newaxis, ...])[0]
        return corrected, flatfield, darkfield

    return flatfield, darkfield

def normalize_image(image: np.ndarray) -> np.ndarray:
    """
    Z-score an image: subtract its mean and divide by its standard deviation.

    When the standard deviation is not positive, a warning is printed and the
    input is returned unchanged.
    """
    center = image.mean()
    spread = image.std()
    if spread > 0:
        return (image - center) / spread
    print("Warning: Standard deviation is zero during normalization.")
    return image

def get_aggregated_pixel_intensity_stats(image: np.ndarray, mask: np.ndarray, channel: str) -> dict:
    """
    Mean and median of the normalized image intensities inside a mask.

    Args:
        image: Intensity image; it is normalized with normalize_image first.
        mask: Segmentation mask; pixels > 0 are aggregated.
        channel: Channel name appended to the result keys.

    Returns:
        Dict with the keys 'mean_intensity_in_mask_<channel>' and
        'median_intensity_in_mask_<channel>'. Both are 0 when the mask is empty.
    """
    # Build the keys once so both branches produce exactly the same schema
    mean_key = f'mean_intensity_in_mask_{channel}'
    median_key = f'median_intensity_in_mask_{channel}'

    masked_pixels = normalize_image(image)[mask > 0]
    if masked_pixels.size > 0:
        return {mean_key: np.mean(masked_pixels), median_key: np.median(masked_pixels)}
    return {mean_key: 0, median_key: 0}

def load_from_h5_file(file_path: str, count: Union[int, None] = None, features_path: str = None,
                      apply_gaussian: bool = True, sigma: float = 1.0,
                      apply_illumination_correction: bool = False) -> dict:
    """
    Load cell images from an H5 file, preprocess them and optionally extract features.

    The file is read as group (one per cell) -> channel -> plane. Every channel
    except 'seg' is optionally illumination-corrected and then passed through
    normalize_and_filter_image; 'seg' planes are stored unchanged.

    Args:
        file_path: Path to the H5 file.
        count: Maximum number of cells (groups) to load; None loads all of them.
        features_path: Base path of the feature table. When given, features are
            extracted per cell and written as parquet. The written file name gets a
            "<count>_" prefix (if count is set) and a "_corrected" / "_gaussian"
            suffix in front of ".parquet".
        apply_gaussian: Whether to apply the Gaussian filter.
        sigma: Standard deviation of the Gaussian filter.
        apply_illumination_correction: Whether to estimate and apply a BaSiC
            illumination correction per channel.

    Returns:
        Dict mapping cell name -> {channel name: [(plane name, image), ...]}, plus a
        'features' dict for every cell whose feature extraction succeeded.
    """
    from pathlib import Path  # only needed to assemble the output file name

    examples = {}
    features = []
    # Per-channel correction fields; they stay empty when correction is disabled
    # and hold None for channels whose estimation failed.
    flatfields = {}
    darkfields = {}

    with h5py.File(file_path, "r") as f:
        # First pass: collect reference images and estimate the correction
        if apply_illumination_correction:
            print("Sammle Bilder für Beleuchtungskorrektur...")
            channel_images_for_correction = {}
            max_reference_cells = min(100, count if count else 100)  # at most 100 cells
            for reference_index, group_name in enumerate(tqdm(f.keys())):
                if reference_index >= max_reference_cells:
                    break
                for channel_name, channel_data in f[group_name].items():
                    if channel_name == 'seg':  # skip the segmentation
                        continue
                    plane_names = list(channel_data.keys())
                    if not plane_names:
                        # A channel without planes has no reference image to offer
                        continue
                    # Use the middle plane of each channel as its reference image
                    middle_plane = plane_names[len(plane_names) // 2]
                    channel_images_for_correction.setdefault(channel_name, []).append(
                        channel_data[middle_plane][()]
                    )

            # Estimate flatfield and darkfield per channel (best effort)
            print("Berechne Beleuchtungskorrektur...")
            for channel_name, images in channel_images_for_correction.items():
                if len(images) > 1:
                    try:
                        flatfield, darkfield = illumination_correction_basicpy(images)
                        flatfields[channel_name] = flatfield
                        darkfields[channel_name] = darkfield
                        print(f"Beleuchtungskorrektur für Kanal {channel_name} berechnet")
                    except Exception as e:
                        print(f"Fehler bei Beleuchtungskorrektur für Kanal {channel_name}: {e}")
                        flatfields[channel_name] = None
                        darkfields[channel_name] = None

        # Second pass: load and process all requested cells
        print("Lade und verarbeite Zelldaten...")
        group_index = 0
        for group_name in tqdm(f.keys()):
            if count is not None and group_index >= count:
                break
            group_index += 1
            grp = f[group_name]
            examples[group_name] = {}
            plane_count = 0

            for channel_name, channel_data in grp.items():
                examples[group_name][channel_name] = []
                # NOTE: this ends up as the plane count of the last channel iterated
                # and is used below to pick the middle plane of every channel.
                plane_count = len(channel_data.values())

                for plane_name, plane_data in channel_data.items():
                    image = plane_data[()]

                    if channel_name != 'seg':
                        # Apply the illumination correction if one was estimated
                        if (apply_illumination_correction
                                and flatfields.get(channel_name) is not None):
                            try:
                                image = _apply_basic_correction(
                                    image, flatfields[channel_name], darkfields[channel_name]
                                )
                            except Exception as e:
                                print(f"Fehler bei Beleuchtungskorrektur für {group_name}, {channel_name}: {e}")

                        # Gaussian filter and normalization
                        processed_img = normalize_and_filter_image(image, apply_gaussian, sigma)
                    else:
                        processed_img = image

                    examples[group_name][channel_name].append((plane_name, processed_img))

            if features_path is not None:
                try:
                    cell_channels = examples[group_name]
                    middle = plane_count // 2
                    nuclei_features = extract_nuclei_cell_features(
                        cell_channels['405'][middle][1], cell_channels['seg'][middle][1], group_name
                    )
                    red_channel_stats = get_aggregated_pixel_intensity_stats(
                        cell_channels['561'][middle][1], cell_channels['seg'][middle][1], channel='561'
                    )
                    green_channel_stats = get_aggregated_pixel_intensity_stats(
                        cell_channels['488'][middle][1], cell_channels['seg'][middle][1], channel='488'
                    )
                    features_dict = {**nuclei_features, **red_channel_stats, **green_channel_stats}
                    cell_channels['features'] = features_dict
                    features.append(features_dict)
                except Exception as e:
                    print(f"Error extracting features for group {group_name}: {e}", file=sys.stderr)

    if features_path is not None:
        # Write the table collected here, not a notebook-level global
        features_df = pd.DataFrame(features)
        file_name_prefix = "" if count is None else str(count) + "_"
        suffix = "_corrected" if apply_illumination_correction else "_gaussian" if apply_gaussian else ""
        # Prefix the file name only, so paths that contain a directory stay valid
        target = Path(features_path)
        output_name = file_name_prefix + target.name.replace(".parquet", f"{suffix}.parquet")
        features_df.to_parquet(target.with_name(output_name))
    return examples


def _apply_basic_correction(image, flatfield, darkfield):
    """Apply previously estimated BaSiC flatfield/darkfield arrays to a single image."""
    # Imported here because the notebook's import cell does not import basicpy
    import basicpy

    basic = basicpy.BaSiC(get_darkfield=True)
    basic.flatfield = flatfield
    basic.darkfield = darkfield
    # transform() is given a stack here, so add a leading axis and strip it again
    return basic.transform(image[np.newaxis, ...])[0]

# Load the data with illumination correction and Gaussian filter.
# load_from_h5_file itself adds the "<count>_" prefix and the "_corrected" suffix to
# the file name, so the base name passed here must not already contain the suffix;
# otherwise the file written would not be the one read back below.
examples = load_from_h5_file("/mydata/iris/andreas/fucci_3t3_221124_filtered_noNG030JP208.h5",
                           count=2000, features_path="cell_features.parquet",
                           apply_gaussian=True, sigma=1.0,
                           apply_illumination_correction=True)
features_preprocessed_df = pd.read_parquet("2000_cell_features_corrected.parquet")

### Source the cell images and extract some features (new)

In [None]:
from data_processing import H5CellDataset, FeatureExtractor, ImagePreprocessor
import numpy as np
# Seed NumPy's global RNG so the sampled subset of cells is reproducible
np.random.seed(42)

H5_WITH_NUCLEI_SEG_PATH = "/myhome/iris/data/fucci_3t3_221124_filtered_noNG030JP208_with_nuclei_seg.h5"
N_SAMPLED_CELLS = 1000

feature_extractor = FeatureExtractor()
image_preprocessor = ImagePreprocessor(apply_gaussian=True, sigma=0.2, apply_illumination_correction=True)

# Two views on the same file: one with preprocessing, one without
dataset_preprocessed = H5CellDataset(
    H5_WITH_NUCLEI_SEG_PATH,
    feature_extractor=feature_extractor,
    preprocessor=image_preprocessor,
    return_raw=True,
    return_features=True,
)
dataset_raw = H5CellDataset(
    H5_WITH_NUCLEI_SEG_PATH,
    feature_extractor=feature_extractor,
    preprocessor=None,
    return_raw=True,
    return_features=True,
)

# Sampling without replacement raises if more items are requested than exist,
# so clamp the sample size to the dataset length.
sample_size = min(N_SAMPLED_CELLS, len(dataset_preprocessed))
indices = np.random.choice(len(dataset_preprocessed), size=sample_size, replace=False)
# The same indices are used for both datasets so the feature tables are comparable
features_preprocessed_df = dataset_preprocessed.get_features_dataframe(indices=indices)
features_raw_df = dataset_raw.get_features_dataframe(indices=indices)

### Visualize some cell examples

In [None]:
import matplotlib.pyplot as plt
show_n = 2


def show_cell_examples(dataset, features_df, show_n=10, channels_to_show=("405", "488", "561", "seg", "nuclei_seg", "bf")):
    """
    Plot the image planes of up to `show_n` cells, one figure per cell.

    Args:
        dataset: Indexable dataset; dataset[i] must provide "raw_data"
            (channel -> list of (plane_name, image)) and "cell_name".
        features_df: Feature table with an "index_cell" column (the dataset index
            of each row), "cell_name" and the feature columns printed on the figure.
        show_n: Maximum number of cells to plot.
        channels_to_show: Channels to draw, one subplot row each; None draws all
            channels found in the cell.
    """
    features_to_show = ["cell_name", "nucleus_area", "cell_area", "mean_intensity_in_mask_561", "mean_intensity_in_mask_488"]

    cell_indices = features_df["index_cell"].tolist()
    for position, cell_idx in enumerate(cell_indices):
        # Check the limit before touching the dataset so no extra cell is loaded
        if position >= show_n:
            break
        cell_data = dataset[cell_idx]
        cells = cell_data["raw_data"]
        cell_name = cell_data["cell_name"]

        n_rows = len(cells.values()) if channels_to_show is None else len(channels_to_show)
        # squeeze=False keeps `axes` two-dimensional even for a single row
        fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
        fig.suptitle(f"Cell {cell_name}")

        # A boolean mask instead of a string-built query keeps names with quotes working
        matching_rows = features_df.loc[features_df["cell_name"] == cell_name, features_to_show]
        feature_values = matching_rows.to_dict(orient="records")[0]
        features_text = "\n".join(
            f"{k}: {v:.3f}" if isinstance(v, float) else f"{k}: {v}"
            for k, v in feature_values.items()
        )
        fig.text(0.5, 0.5 * 0.1, features_text, fontsize=10, ha="center")

        channel_idx = 0
        for channel, imgs in cells.items():
            if channels_to_show is not None and channel not in channels_to_show:
                continue
            if channel == "features":
                continue
            for i, (plane_name, img) in enumerate(imgs):
                ax = axes[channel_idx, i]
                if not isinstance(img, np.ndarray):
                    print(f"Skipping non-array image for cell {cell_name}, channel {channel}, plane {plane_name}")
                    continue
                ax.imshow(
                    img,
                    cmap="gray",
                    aspect="auto",
                )
                # Overlay the cell (red) and nucleus (green) outlines on intensity images
                if channel != "seg" and channel != "nuclei_seg":
                    ax.contour(cells["seg"][i][1], colors="r")
                    ax.contour(cells["nuclei_seg"][i][1], colors="g")
                ax.axis("off")
                ax.set_title(f"{channel} - Plane {plane_name}")

            channel_idx += 1
        plt.show()

# Show the first cells of each feature table, once per dataset variant
for heading, cell_dataset, cell_features in (
    ("Preprocessed Data Examples:", dataset_preprocessed, features_preprocessed_df),
    ("Raw Data Examples:", dataset_raw, features_raw_df),
):
    print(heading)
    show_cell_examples(cell_dataset, cell_features, show_n=show_n)

In [None]:
MAX_OUTLIER_EXAMPLES = 10

# Rows whose cell_nucleus_area_ratio exceeds 1
outlier_features = features_preprocessed_df[features_preprocessed_df["cell_nucleus_area_ratio"] > 1]
print(f"Number of cells with nucleus area larger than cell area: {len(outlier_features)} out of {len(features_preprocessed_df)}")

# Plot the first few outliers straight from the filtered table. This avoids
# re-selecting rows by position with index labels taken from iterrows().
show_cell_examples(
    dataset_preprocessed,
    features_df=outlier_features.head(MAX_OUTLIER_EXAMPLES),
    show_n=MAX_OUTLIER_EXAMPLES,
    channels_to_show=["bf"],
)

In [None]:
# Scatter the per-cell mean intensities of the two fluorescence channels against
# each other, colouring every point by a third feature.
color_feature = "cell_area"  # Change this to any other feature if needed
green_col = 'mean_intensity_in_mask_488'
red_col = 'mean_intensity_in_mask_561'
green_values = features_preprocessed_df[green_col]
red_values = features_preprocessed_df[red_col]

fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(
    green_values,
    red_values,
    c=features_preprocessed_df[color_feature],
    cmap='Blues',
    alpha=1.0,
    s=50
)

ax.set_xlabel('Mean Intensity in Mask - Channel 488 (Green: G1 phase)', fontsize=12)
ax.set_ylabel('Mean Intensity in Mask - Channel 561 (Red:  S/G2/M phase)', fontsize=12)
ax.set_title('Mean Intensity Comparison: Channel 488 vs Channel 561', fontsize=14)
ax.grid(True, alpha=0.3)

# Colour scale for the third feature
plt.colorbar(scatter, ax=ax, label=color_feature)

plt.tight_layout()
plt.show()

# Per-channel summary statistics and the correlation between both channels
channel_correlation = features_preprocessed_df[[green_col, red_col]].corr().iloc[0, 1]
print(f"Channel 488 - Mean: {green_values.mean():.3f}, Std: {green_values.std():.3f}")
print(f"Channel 561 - Mean: {red_values.mean():.3f}, Std: {red_values.std():.3f}")
print(f"Correlation: {channel_correlation:.3f}")

### Reduce dimension and visualize clusters from features

In [None]:
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
import numpy as np
import plotly.express as px

# Select the cell_* / nucleus_* feature columns as t-SNE input (cell_name is an identifier)
feature_cols = [
    col
    for col in features_preprocessed_df.columns
    if col.startswith(('cell_', 'nucleus_')) and col != 'cell_name'
]
X = features_preprocessed_df[feature_cols].values

# Standardize features so no single feature dominates the distances
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Run t-SNE (fixed random_state for reproducible embeddings)
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
X_tsne = tsne.fit_transform(X_scaled)

# Add t-SNE coordinates to a copy of the feature table
tsne_features_df = features_preprocessed_df.copy()
tsne_features_df['tsne_1'] = X_tsne[:, 0]
tsne_features_df['tsne_2'] = X_tsne[:, 1]

# One embedding plot per fluorescence channel, each with a matching colour scale
for intensity_col, plot_title, color_scale in (
    ("mean_intensity_in_mask_488", "t-SNE: 488nm Channel (Green)", "Greens"),
    ("mean_intensity_in_mask_561", "t-SNE: 561nm Channel (Red)", "Reds"),
):
    fig = px.scatter(
        tsne_features_df,
        x="tsne_1",
        y="tsne_2",
        color=intensity_col,
        hover_data=["cell_name"],
        title=plot_title,
        color_continuous_scale=color_scale,
    )
    fig.show()


def _scale_to_unit_interval(values):
    """Min-max scale a Series to [0, 1]; a constant Series maps to all zeros."""
    lowest = values.min()
    value_range = values.max() - lowest
    if not value_range > 0:
        return values - lowest
    return (values - lowest) / value_range


# Normalize intensities to [0, 1] for colour mapping. Min-max scaling is used because
# dividing by the maximum alone leaves negative intensities negative, which would
# produce invalid rgb() strings.
norm_561 = _scale_to_unit_interval(tsne_features_df['mean_intensity_in_mask_561'])
norm_488 = _scale_to_unit_interval(tsne_features_df['mean_intensity_in_mask_488'])

# Map 561 to red, 488 to green (assuming these are the respective channels)
rgb_colors = [f"rgb({int(r * 255)},{int(g * 255)},0)" for r, g in zip(norm_561, norm_488)]
fig = px.scatter(
    tsne_features_df,
    x="tsne_1",
    y="tsne_2",
    color=rgb_colors,
    # Without "identity" plotly treats the rgb strings as category labels and
    # assigns its own palette, so the points would not get these colours.
    color_discrete_map="identity",
    hover_data=["cell_name"],
    title="t-SNE: True Color by Channel Intensities",
)
fig.update_traces(marker=dict(size=10, line=dict(width=0)))
fig.show()

In [None]:
# Helper function to get indices for different fluorescence cases
def get_example_indices(norm_561, norm_488, n=3):
    """
    Pick up to `n` example cells for four fluorescence patterns.

    Args:
        norm_561: Series of normalized 561 (red) intensities.
        norm_488: Series of normalized 488 (green) intensities, same index.
        n: Maximum number of examples per pattern.

    Returns:
        Dict mapping pattern name -> index labels of the first (at most `n`) cells
        matching that pattern. A pattern with no matching cell maps to an empty index.
    """
    limit = max(n, 0)

    def first_matches(condition):
        # Keep only rows where the condition holds. nlargest() on the boolean mask
        # would pad the result with non-matching rows when fewer than n match.
        return condition[condition].index[:limit]

    return {
        # Very red: high 561, low 488
        "Very Red": first_matches((norm_561 > 0.8) & (norm_488 < 0.2)),
        # Very green: high 488, low 561
        "Very Green": first_matches((norm_488 > 0.8) & (norm_561 < 0.2)),
        # Mixed: both high
        "Mixed": first_matches((norm_561 > 0.6) & (norm_488 > 0.6)),
        # Both low
        "Low Both": first_matches((norm_561 < 0.2) & (norm_488 < 0.2)),
    }

example_indices = get_example_indices(norm_561, norm_488, n=3)

# The loop variable is deliberately not called `label`: that name is the
# skimage.measure.label function imported in the first cell.
for category, idxs in example_indices.items():
    print(f"\n{category} examples:")
    for i in idxs:
        # `i` is an index label, so use label-based access (.loc) throughout
        print(f"  Cell index: {i}, norm_561={norm_561.loc[i]:.2f}, norm_488={norm_488.loc[i]:.2f}")
        show_cell_examples(dataset_preprocessed, features_df=features_preprocessed_df.loc[[i]], show_n=1)

In [None]:
import seaborn as sns

N_TOP_FEATURES = 6  # number of features shown in the scatter grid below
SCATTER_GRID_COLS = 2
intensity_cols = ['mean_intensity_in_mask_561', 'mean_intensity_in_mask_488']

# Select features to correlate with mean intensities
corr_features = feature_cols + intensity_cols
corr_df = features_preprocessed_df[corr_features]

# Compute correlation matrix and keep only feature-vs-channel entries
corr_matrix = corr_df.corr()
feature_vs_channel = corr_matrix[intensity_cols].loc[feature_cols]

# Visualize correlation with mean_intensity_in_mask_561 and mean_intensity_in_mask_488
plt.figure(figsize=(20, 12))
sns.heatmap(feature_vs_channel, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
plt.title('Correlation of Features with Channel Mean Intensities')
plt.show()

# Find the N_TOP_FEATURES features most correlated (absolute value) with either channel
top_features = (
    feature_vs_channel
    .abs()
    .max(axis=1)
    .sort_values(ascending=False)
    .head(N_TOP_FEATURES)
    .index.tolist()
)

# Plot scatter plots for each top feature vs both mean intensities.
# The grid is sized from the number of features actually found.
n_grid_rows = max(1, -(-len(top_features) // SCATTER_GRID_COLS))
fig, axes = plt.subplots(n_grid_rows, SCATTER_GRID_COLS, figsize=(28, 20), squeeze=False)
for i, feat in enumerate(top_features):
    ax = axes.flat[i]
    sns.scatterplot(
        x=features_preprocessed_df[feat],
        y=features_preprocessed_df['mean_intensity_in_mask_561'],
        label='561', alpha=0.6, ax=ax
    )
    sns.scatterplot(
        x=features_preprocessed_df[feat],
        y=features_preprocessed_df['mean_intensity_in_mask_488'],
        label='488', alpha=0.6, ax=ax
    )
    ax.set_xlabel(feat)
    ax.set_ylabel('Mean Intensity')
    ax.set_title(f"{feat} vs Channel Mean Intensities")
    ax.legend()
# Hide grid cells that have no feature to show
for unused_ax in axes.flat[len(top_features):]:
    unused_ax.set_visible(False)
plt.tight_layout()
plt.show()

### Get outliers and visualize some examples

In [None]:
from sklearn.ensemble import IsolationForest

def get_outlier_cells(features_df, feature_cols, contamination=0.05):
    """
    Flag outlier rows of a feature table with an IsolationForest.

    Args:
        features_df: Feature table, one row per cell.
        feature_cols: Columns of `features_df` passed to the model.
        contamination: Expected fraction of outliers, forwarded to IsolationForest.

    Returns:
        List of index labels of `features_df` for the rows predicted as outliers.
    """
    outlier_model = IsolationForest(contamination=contamination, random_state=42)
    outlier_pred = outlier_model.fit_predict(features_df[feature_cols])
    # fit_predict marks outliers with -1
    outlier_mask = outlier_pred == -1
    print(f"Anzahl Outlier: {outlier_mask.sum()}")
    # Index the table that was passed in, not a notebook-level global
    return features_df.index[outlier_mask].tolist()

def show_bf_and_405_images_for_outliers(examples, outlier_idx, max_n=10, features_df=None):
    """
    Show example images for outlier cells via show_cell_examples.

    Args:
        examples: Dataset passed on to show_cell_examples.
        outlier_idx: Index labels of the outlier rows in the feature table.
        max_n: Maximum number of outliers to show.
        features_df: Feature table the labels refer to; defaults to the
            notebook-level `features_preprocessed_df`.
    """
    if features_df is None:
        features_df = features_preprocessed_df
    for i, idx in enumerate(outlier_idx):
        if i >= max_n:
            break
        # outlier_idx holds index labels, so select with .loc rather than .iloc
        show_cell_examples(examples, features_df=features_df.loc[[idx]], show_n=1)


# Outlier detection on the non-intensity features.
# Only numeric columns are used: the string column cell_name cannot be fed to the
# model, and index_cell is a dataset row identifier rather than a measurement.
identifier_cols = {"index_cell"}
feature_cols = [
    col
    for col in features_preprocessed_df.select_dtypes(include="number").columns
    if not col.endswith(("488", "561")) and col not in identifier_cols
]

outlier_idx = get_outlier_cells(features_preprocessed_df, feature_cols, contamination=0.05)
show_bf_and_405_images_for_outliers(dataset_preprocessed, outlier_idx, max_n=10)

In [None]:
# Display the feature columns that were passed to the outlier detection above
feature_cols