In [1]:
WORKING_DIR = "/home/xavier/Documents/DAE_project"

# Process tables

## Merge motility labels

In [23]:
import pandas as pd
import re
import numpy as np
import os

# --- Configuration ---
WORKING_DIR = "/home/xavier/Documents/DAE_project"

# --- File Paths ---
EXPERIMENT_FILE_PATH = f'{WORKING_DIR}/dataset/Roy_training/Caro 3d 9.7.22_2.20_new.xlsx'
UPDATED_FILE_PATH = f'{WORKING_DIR}/dataset/Roy_training/Updated_with_He_et_al__1994.xlsx'
KAISER_FILE_PATH = f'{WORKING_DIR}/dataset/Roy_training/Kaiser strain list at UCD.xls'
IMAGES_BASE_PATH = f'{WORKING_DIR}/dataset/Roy_training/images/'

OUTPUT_FULL_FILE_PATH = f'{WORKING_DIR}/dataset/Roy_training/merged_strain_data_full.xlsx'
OUTPUT_FILE_PATH = f'{WORKING_DIR}/dataset/Roy_training/merged_strain_data.xlsx'

# --- Data Loading ---
try:
    print("Loading data files...")
    experiment_df = pd.read_excel(EXPERIMENT_FILE_PATH)
    updated_df = pd.read_excel(UPDATED_FILE_PATH)
    kaiser_df = pd.read_excel(KAISER_FILE_PATH)
    print("All files loaded successfully.")
except FileNotFoundError as e:
    print(f"Error: {e}. Please ensure all files are in the correct directory.")
    exit()

# --- Processing the 'Updated' Table ---
print("Processing the 'Updated' table...")

# 1. Select and filter columns
updated_df_processed = updated_df[
    ['Run', 'Mutant #', 'Movies', 'Reference', 'Jiangguo', 'Source_labelled', 'Source', 'Bib']].copy()

# 2. Keep rows where 'Run' has a value and remove duplicates
updated_df_processed.dropna(subset=['Run'], inplace=True)
updated_df_processed.drop_duplicates(inplace=True)
print(f"Filtered down to {len(updated_df_processed)} unique rows with 'Run' values.")


# 3. Clean the 'Mutant #' column to create 'Strain'
def clean_mutant_id(mutant_id):
    """
    Cleans the Mutant ID based on specified rules:
    - If an entry has a letter-digit pattern (e.g., DK1622A), it keeps the pattern and removes the rest (-> DK1622).
    - If an entry starts with a digit, it prepends 'DK'.
    """
    if pd.isna(mutant_id):
        return None

    mutant_id_str = str(mutant_id)
    parts = [part.strip() for part in mutant_id_str.split(',')]

    for part in parts:
        if not part:
            continue
        match = re.match(r'(DK\d+)', part)
        if match:
            return match.group(1)
        # Check if the first character is a digit
        if part and part[0].isdigit():
            return f'DK{part}'
    return parts[0] if parts else None


updated_df_processed['Strain'] = updated_df_processed['Mutant #'].apply(clean_mutant_id)
print("Cleaned 'Mutant #' column into 'Strain'.")


# --- Calculate Final Movie Count from Directories ---
def calculate_final_movies(run_ids_str, images_path):
    """
    Calculates the total number of movies by counting subdirectories for each run ID.
    """
    if pd.isna(run_ids_str) or not os.path.isdir(images_path):
        return 0

    total_movies = 0
    try:
        run_ids = [int(float(run_id.strip())) for run_id in str(run_ids_str).split(',') if run_id.strip()]
    except (ValueError, AttributeError):
        return 0

    # Get a list of all items in the images directory once for efficiency
    all_image_folders = os.listdir(images_path)

    for run_id in run_ids:
        # Pad run_id to 4 digits to match folder naming convention, e.g., 1 -> 0001
        run_id_str_padded = str(run_id).zfill(4)
        for folder_name in all_image_folders:
            if folder_name.endswith(run_id_str_padded):
                run_folder_path = os.path.join(images_path, folder_name)
                if os.path.isdir(run_folder_path):
                    # Count subdirectories inside the run folder
                    num_subfolders = sum(
                        os.path.isdir(os.path.join(run_folder_path, item)) for item in os.listdir(run_folder_path))
                    total_movies += num_subfolders
                    break  # Move to the next run_id once the folder is found
    return total_movies


print("Calculating final movie counts from image directories...")
updated_df_processed['Final Movies'] = updated_df_processed['Run'].apply(
    lambda run_ids: calculate_final_movies(run_ids, IMAGES_BASE_PATH)
)
print("Added 'Final Movies' column.")

# --- Retrieving Original Data from Experiment Table ---
print("Looking up original data from the experiment file...")
experiment_df.dropna(subset=['Run'], inplace=True)
experiment_df['Run'] = experiment_df['Run'].astype(int)


def get_original_info(run_ids_str, exp_df):
    """
    Looks up run IDs in the experiment_df and concatenates original mutant numbers and sources.
    """
    if pd.isna(run_ids_str):
        return pd.Series([None, None], index=['Original Mutant #', 'Original Source'])

    try:
        run_ids = [int(float(run_id.strip())) for run_id in str(run_ids_str).split(',') if run_id.strip()]
    except (ValueError, AttributeError):
        return pd.Series([None, None], index=['Original Mutant #', 'Original Source'])

    matches = exp_df[exp_df['Run'].isin(run_ids)]
    if matches.empty:
        return pd.Series([None, None], index=['Original Mutant #', 'Original Source'])

    original_mutants = ','.join(matches['Mutant #'].dropna().astype(str).unique())
    original_sources = ','.join(matches['Source'].dropna().astype(str).unique())
    return pd.Series([original_mutants, original_sources], index=['Original Mutant #', 'Original Source'])


updated_df_processed[['Original Mutant #', 'Original Source']] = updated_df_processed['Run'].apply(
    lambda run_ids: get_original_info(run_ids, experiment_df)
)
print("Added 'Original Mutant #' and 'Original Source' columns.")

# --- Processing the 'Kaiser' Table ---
print("Processing the 'Kaiser' table...")
# To prevent row duplication during the merge, we must ensure 'DK#' is unique in the Kaiser table.
# We will group by 'DK#' and aggregate the information from other columns.

kaiser_df_processed = kaiser_df[['DK#', 'genotype', 'phenotype', 'References']].copy()
# Drop rows where DK# is null as they can't be used for merging
kaiser_df_processed.dropna(subset=['DK#'], inplace=True)

# Convert all relevant columns to string to prevent aggregation errors with mixed types
for col in ['genotype', 'phenotype', 'References']:
    kaiser_df_processed[col] = kaiser_df_processed[col].astype(str)

# Group by 'DK#' and aggregate the other columns by joining unique, non-null values
kaiser_df_processed = kaiser_df_processed.groupby('DK#').agg({
    'genotype': lambda x: ', '.join(x.replace('nan', '').dropna().unique()),
    'phenotype': lambda x: ', '.join(x.replace('nan', '').dropna().unique()),
    'References': lambda x: ', '.join(x.replace('nan', '').dropna().unique())
}).reset_index()
print("Aggregated Kaiser table to ensure unique DK# entries, preventing duplicates in final output.")

# --- Merging the Tables ---
print("Joining the Updated and Kaiser tables...")
merged_df = pd.merge(
    updated_df_processed,
    kaiser_df_processed,
    left_on='Strain',
    right_on='DK#',
    how='left'
)
print("Join complete.")


# --- Adding Motility Column ---
def determine_motility(row):
    """Determines motility label based on 'Jiangguo' and 'phenotype' columns."""
    jiangguo_label = None
    phenotype_label = None

    if pd.notna(row['Jiangguo']):
        jg_val = str(row['Jiangguo']).strip()
        if jg_val in ['WT', 'A-S+', 'A+S-', 'A-S-']:
            jiangguo_label = jg_val

    if pd.notna(row['phenotype']):
        ph_val = str(row['phenotype'])
        if 'A-S+' in ph_val:
            phenotype_label = 'A-S+'
        elif 'A+S-' in ph_val:
            phenotype_label = 'A+S-'
        elif 'A-S-' in ph_val:
            phenotype_label = 'A-S-'
        elif 'A-' in ph_val:
            phenotype_label = 'A-S+'
        elif 'S-' in ph_val:
            phenotype_label = 'A+S-'

    final_label = jiangguo_label
    if phenotype_label:
        if jiangguo_label and jiangguo_label != phenotype_label:
            final_label = f"{phenotype_label}"
        else:
            final_label = phenotype_label
    return final_label


print("Adding 'motility' column...")
merged_df['motility'] = merged_df.apply(determine_motility, axis=1)

# --- Final Column Selection and Ordering ---
print("Reordering and selecting final columns...")
final_columns = [
    'Strain', 'Run', 'Movies', 'Final Movies',  # Use the new 'Final Movies' column
    'Original Mutant #', 'Original Source',
    'genotype', 'phenotype', 'Reference', 'motility',
    'Source_labelled', 'Source', 'References', 'Bib'
]
final_columns_exist = [col for col in final_columns if col in merged_df.columns]
final_df = merged_df[final_columns_exist]

# Filter out rows where 'Final Movies' is 0
final_df = final_df[final_df['Final Movies'] != 0].copy()

# Sort the final DataFrame by 'Strain'
print("Sorting final data by 'Strain'...")
final_df.sort_values(by='Strain', inplace=True)

# --- Saving the Result ---

final_df.to_excel(OUTPUT_FULL_FILE_PATH, index=False)
final_df[['Strain', 'Run', 'Final Movies', 'motility', 'Bib']].to_excel(OUTPUT_FILE_PATH, index=False)
print(f"Successfully created the final output file: {OUTPUT_FILE_PATH}")



Loading data files...
All files loaded successfully.
Processing the 'Updated' table...
Filtered down to 439 unique rows with 'Run' values.
Cleaned 'Mutant #' column into 'Strain'.
Calculating final movie counts from image directories...
Added 'Final Movies' column.
Looking up original data from the experiment file...
Added 'Original Mutant #' and 'Original Source' columns.
Processing the 'Kaiser' table...
Aggregated Kaiser table to ensure unique DK# entries, preventing duplicates in final output.
Joining the Updated and Kaiser tables...
Join complete.
Adding 'motility' column...
Reordering and selecting final columns...
Sorting final data by 'Strain'...
Successfully created the final output file: /home/xavier/Documents/DAE_project/dataset/Roy_training/merged_strain_data.xlsx


In [21]:
final_df

Unnamed: 0,Strain,Run,Movies,Final Movies,Original Mutant #,Original Source,genotype,phenotype,Reference,motility,Source_labelled,Source,References,Bib
137,ASX1,306,3.0,3,ASX1,"M Fontes, D Kaiser - … of the National Academy...",,,,,,"M Fontes, D Kaiser - … of the National Academy...",,
138,DK101,"407, 620, 660, 767",12.0,9,"DK101 ,DK101","J Hodgkin and D Kaiser, 1977, PNAS",,,https://link.springer.com/article/10.1007/bf00...,WT,Killeen & Nelson 1988,"J Hodgkin and D Kaiser, 1977, PNAS",,"hodgkin1979genetics, cheng1989dsg"
139,DK1013,341,3.0,2,DK1013,"LJ Shimkets - Journal of bacteriology, 1986 - ...",,"Mx1R, Mx4R, Mx8S, Mx8Cp2R, Mx9R, non-fruiting ...",https://journals.asm.org/doi/abs/10.1128/jb.16...,,,"LJ Shimkets - Journal of bacteriology, 1986 - ...",,shimkets1986correlation
140,DK1016,342,3.0,3,DK1016,"LJ Shimkets - Journal of bacteriology, 1986 - ...",,"Mx1R, Mx4R, Mx8S, Mx8Cp2S, non-fruiting in col...",https://journals.asm.org/doi/epdf/10.1128/jb.1...,,,"LJ Shimkets - Journal of bacteriology, 1986 - ...",,shimkets1986role
141,DK1031,338,3.0,2,DK1031,"LJ Shimkets - Journal of bacteriology, 1986 - ...",,"Mx1R, Mx4R, Mx8S, Mx8Cp2S, Mx9R, non-fruiting ...",https://journals.asm.org/doi/abs/10.1128/jb.16...,,,"LJ Shimkets - Journal of bacteriology, 1986 - ...",,shimkets1986correlation
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
476,MXAN7164,570,3.0,3,MXAN7164,,,,https://www.proquest.com/docview/305383337?pq-...,,,,,caberoy2005coordinating
479,Omega4531,626,3.0,1,Omega4531,"L Kroos, A Kuspa, D Kaiser - Developmental bio...",,,https://www.sciencedirect.com/science/article/...,,Shimkets 1998,"L Kroos, A Kuspa, D Kaiser - Developmental bio...",,kroos1986global
410,esgWen,227,3.0,2,esgWen,,,,,,,,,
477,omega4469,685,3.0,3,omega4469,"L Kroos, A Kuspa, D Kaiser - Developmental bio...",,,https://www.sciencedirect.com/science/article/...,,Shimkets 1998,"L Kroos, A Kuspa, D Kaiser - Developmental bio...",,kuspa1986intercellular


## Plot WT aggregating distribution

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

WT_LABEL_PATH = f"{WORKING_DIR}/dataset/WT/labeling_sheet.csv"

wt_df = pd.read_csv(WT_LABEL_PATH)
# Group by run_id and count T/F in aggregates_formed
counts = wt_df.groupby("run_id")["aggregates_formed"].value_counts().unstack(fill_value=0)

# Create bar plot
fig, ax = plt.subplots(figsize=(10, 6))

# F on top, T on bottom → stack bars with T first, then F
counts.plot(kind="bar", stacked=True, ax=ax, color={"T": "tab:blue", "F": "tab:orange"})

ax.set_ylabel("Count")
ax.set_title("Counts of Aggregates Formed (T and F) by run_id")
ax.legend(title="Aggregates Formed")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()


# Perform AUC and PCA analysis for WT aggregate formation and motility

In [None]:
# ==============================================================================
# Integrated Script: Trajectory Distance, Predictive Analysis, Dimensionality Reduction, and Frame Extraction
# ==============================================================================
# Purpose:
# 1. Load and process trajectory data, imputing NaN values with the last valid frame.
# 2. Print a summary of sample and run counts for each class.
# 3. Calculate pairwise distances and train SVM classifiers to evaluate predictive power (AUC).
#    Includes a fallback from StratifiedGroupKFold to StratifiedKFold if splits are invalid.
# 4. Generate bar plots of AUC scores and SAVE THE UNDERLYING DATA to a CSV file.
# 5. Perform dimensionality reduction (PCA/UMAP) on the feature space at specified time points,
#    and SAVE THE REDUCED COORDINATES to a CSV file.
# 6. Optionally, visualize the SVM decision boundary on the dimensionality reduction plots,
#    now including a color bar to indicate class probability.
# 7. For specified time points, find the corresponding raw images, perform a center crop,
#    and save the processed images to an output folder, organized by time and class.
#
# REVISIONS IN THIS VERSION:
# - Optimized the SVM boundary visualization in `plot_dimensionality_reduction` for significant speed improvement.
#   - Increased the `meshgrid` step size to reduce the number of prediction points.
#   - Removed the unnecessary cross-validation loop for plotting; a single SVM is now trained on the
#     sampled 2D data for a much faster, yet still representative, visualization.
# - Maintained all advanced features: probability-based shading, color consistency, and the color bar.
# ==============================================================================

import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import cdist, pdist
from scipy.stats import sem
from itertools import combinations, product
from multiprocessing import Pool, cpu_count
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.metrics import roc_curve, auc
from sklearn.decomposition import PCA
from umap import UMAP
import warnings
import math
import cv2
import glob
import shutil

import matplotlib as mpl

mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['text.usetex'] = False

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

CONFIG = {
    # --- Analysis Setup ---
    "analysis_type": "WT",  # "WT" or "motility"
    "random_seed": 42,  # Seed for all random operations to ensure reproducibility

    # --- Data and Model Paths ---
    "features_base_dir_wt": f"{WORKING_DIR}/encoded_features/WT_features",
    "labeling_csv_path_wt": f"{WORKING_DIR}/dataset/WT/labeling_sheet.csv",
    "WT_img_dir": f"{WORKING_DIR}/dataset/WT/images",

    "features_base_dir_motility": f"{WORKING_DIR}/encoded_features/Roy_training_features",
    "motility_csv_path": f'{WORKING_DIR}/dataset/Roy_training/merged_strain_data.xlsx',
    "motility_img_dir": f"{WORKING_DIR}/dataset/Roy_training/images",

    # --- Output Configuration ---
    "analysis_output_dir": f"{WORKING_DIR}/images/figure6/%s_analysis",
    "output_figure_name": "distance_and_prediction_summary.pdf",
    "roc_figure_name": "roc_curves_combined.pdf",
    "auc_barplot_name": "auc_barplot.pdf",
    "pca_plot_name": "pca_plot.pdf",
    "umap_plot_name": "umap_plot.pdf",

    # --- Rerun and Visualization Settings ---
    "force_rerun": True,
    "show_movie_distance_analysis": True,
    "n_splits": 3,

    # --- Frame Copying Settings ---
    "copy_frames": False,
    "copied_frames_dir": "copied_frames",

    # --- Analysis Parameters ---
    "selected_time_points": [1440, 0],
    "dist_method": "euclidean",
    "tolerance": 90,
    "num_workers": max(1, cpu_count() - 2),
    "required_frames_motility": 1441,

    # --- Dimensionality Reduction Settings ---
    "dimensionality_reduction": {
        "run": True,
        "method": "PCA",  # "PCA" or "UMAP"
        "sample_equal": True,
        "plot_in_one_figure": False,
        "show_svm_boundary": False,
        "plot_mean_features_at_times": []
    },

    # --- Motility Analysis Specific ---
    "motility_target_classes": ['WT', 'A+S-', 'A-S+', 'A-S-'],
    "motility_comparison_pairs": [
        ('WT', 'A+S-'),
        ('WT', 'A-S+'),
        ('WT', 'A-S-'),
        ('A+S-', 'A-S+'),
    ],
}

# Set the global random seed from the config for numpy operations
np.random.seed(CONFIG["random_seed"])

CONFIG["analysis_output_dir"] = CONFIG["analysis_output_dir"] % CONFIG["analysis_type"]
# Add paths for cached files
CONFIG["cached_distances_path"] = os.path.join(CONFIG['analysis_output_dir'], "cached_movie_distances.npz")
CONFIG["cached_dist_matrix_path"] = os.path.join(CONFIG['analysis_output_dir'], "cached_dist_matrix.npz")


# --- Image Processing Helper Functions ---

def resize_crop(img_dir, resize_by=1., resolution=512, brightness_norm=True, brightness_mean=107):
    """
    Loads an image, resizes it, and takes a center crop.
    """
    img = cv2.imread(img_dir, cv2.IMREAD_UNCHANGED)
    if img is None:
        return None

    if img.dtype != np.uint8:
        img = (img / 256).astype(np.uint8)

    if len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    img_shape = img.shape
    resize_shape = (int(img_shape[1] * resize_by), int(img_shape[0] * resize_by))

    if resize_by != 1:
        img = cv2.resize(img, resize_shape, cv2.INTER_LANCZOS4)

    h, w = img.shape[:2]
    start_y = max(0, (h - resolution) // 2)
    start_x = max(0, (w - resolution) // 2)
    new_img = img[start_y:start_y + resolution, start_x:start_x + resolution]

    if brightness_norm:
        obj_v = np.mean(new_img)
        value = brightness_mean - obj_v
        value_array = np.full(new_img.shape, value, dtype=np.float64)
        new_img = np.clip(new_img.astype(np.float64) + value_array, 0, 255).astype(np.uint8)

    return new_img


def copy_and_process_frames(run_id, scope_id, class_name, time_points, base_img_dir, output_dir):
    """
    Finds, processes, and saves specific frames for a given sample.
    Finds the closest available frame if the exact frame is not found.
    """
    for t in time_points:
        try:
            frame_output_dir = os.path.join(output_dir, f"{t}min", class_name)
            os.makedirs(frame_output_dir, exist_ok=True)

            run_dir_pattern = os.path.join(base_img_dir, f"*Run{run_id:04d}*")
            matching_run_dirs = glob.glob(run_dir_pattern)

            scope_dir_path = ""
            if matching_run_dirs:
                run_dir = matching_run_dirs[0]
                scope_dir_path = os.path.join(run_dir, f"Scope{scope_id:02d}")
            else:
                scope_dir_path = os.path.join(base_img_dir, f"Run{run_id:04d}", f"Scope{scope_id:02d}")

            if not os.path.isdir(scope_dir_path):
                print(
                    f"  Warning: Scope directory not found for Run {run_id}, Scope {scope_id}. Path: {scope_dir_path}")
                continue

            all_images = glob.glob(os.path.join(scope_dir_path, "*.jpg"))
            if not all_images:
                print(f"  Warning: No JPG images found for Run {run_id}, Scope {scope_id} in {scope_dir_path}")
                continue

            target_frame_idx = t + 1
            best_match_path = None
            min_diff = float('inf')
            frame_number_pattern = re.compile(r'_(\d+)\.jpg$')

            for img_path in all_images:
                match = frame_number_pattern.search(os.path.basename(img_path))
                if match:
                    frame_num = int(match.group(1))
                    diff = abs(frame_num - target_frame_idx)
                    if diff < min_diff:
                        min_diff = diff
                        best_match_path = img_path

            if best_match_path:
                source_path = best_match_path
                found_frame_num_match = frame_number_pattern.search(os.path.basename(source_path))
                if found_frame_num_match:
                    found_frame_num = int(found_frame_num_match.group(1))
                    if found_frame_num != target_frame_idx:
                        print(
                            f"  Info: For Run {run_id}, Scope {scope_id}, Time {t} min, using closest frame {found_frame_num}.")

                processed_img = resize_crop(source_path)
                if processed_img is not None:
                    dest_filename = f"{class_name}_Run{run_id:04d}_Scope{scope_id:02d}.jpg"
                    dest_path = os.path.join(frame_output_dir, dest_filename)
                    cv2.imwrite(dest_path, processed_img)
                else:
                    print(f"  Warning: Failed to process image: {source_path}")
            else:
                print(f"  Warning: Image not found for Run {run_id}, Scope {scope_id}, Time {t} min.")

        except Exception as e:
            print(f"  ERROR processing frame for Run {run_id}, Scope {scope_id} at time {t}: {e}")


# --- Analysis Helper Functions ---

def impute_nans_with_previous_frame(trajectory):
    for i in range(1, trajectory.shape[0]):
        if np.isnan(trajectory[i]).any():
            trajectory[i] = trajectory[i - 1]
    return trajectory


def get_closest_frame_index(requested_frame, total_frames):
    if total_frames == 0:
        raise ValueError("Cannot get frame index from a trajectory with zero frames.")
    max_index = total_frames - 1
    return min(requested_frame, max_index)


def train_and_get_roc_data(features, labels, groups, use_group_kfold, analysis_name, precomputed_kernel=False):
    n_splits = CONFIG['n_splits']
    random_seed = CONFIG['random_seed']

    if use_group_kfold:
        cv = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=random_seed)
        split_iterator = cv.split(features, labels, groups)
    else:
        cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_seed)
        split_iterator = cv.split(features, labels)

    tprs, aucs = [], []
    mean_fpr = np.linspace(0, 1, 100)

    try:
        for train_idx, test_idx in split_iterator:
            if len(np.unique(labels[test_idx])) < 2:
                print(f"  Skipping invalid fold in CV for {analysis_name}: test set contains only one class.")
                continue

            if precomputed_kernel:
                model = SVC(kernel='precomputed', class_weight='balanced', probability=True, random_state=random_seed)
                model.fit(features[np.ix_(train_idx, train_idx)], labels[train_idx])
                probas_ = model.predict_proba(features[np.ix_(test_idx, train_idx)])
            else:
                pipeline = make_pipeline(StandardScaler(), SVC(kernel='rbf', class_weight='balanced', probability=True,
                                                               random_state=random_seed))
                pipeline.fit(features[train_idx], labels[train_idx])
                probas_ = pipeline.predict_proba(features[test_idx])

            fpr, tpr, _ = roc_curve(labels[test_idx], probas_[:, 1])
            tprs.append(np.interp(mean_fpr, fpr, tpr))
            tprs[-1][0] = 0.0
            aucs.append(auc(fpr, tpr))
    except Exception as e:
        print(f"  CV failed for {analysis_name} with error: {e}. Cannot generate ROC data.")
        return None

    if not tprs:
        print(f"  Could not generate any valid CV folds for {analysis_name}. Cannot generate ROC data.")
        return None

    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0
    mean_auc = np.mean(aucs)
    std_auc = np.std(aucs)
    std_tpr = np.std(tprs, axis=0)

    return mean_fpr, mean_tpr, std_tpr, mean_auc, std_auc


# --- Plotting Functions ---

def plot_auc_barplot(auc_scores, time_points, output_path, title_prefix="", force_rerun=False):
    if not force_rerun and os.path.exists(output_path):
        print(f"Skipping existing AUC bar plot: {os.path.basename(output_path)}")
        return

    print(f"Generating AUC bar plot: {os.path.basename(output_path)}")
    labels = [f'{t // 60} h' for t in time_points]
    auc_means = [auc_scores.get(f'Time {t} min', (np.nan, np.nan))[0] for t in time_points]
    auc_stds = [auc_scores.get(f'Time {t} min', (np.nan, np.nan))[1] for t in time_points]

    y_err_lower, y_err_upper = [], []
    for mean, std in zip(auc_means, auc_stds):
        if np.isnan(mean) or np.isnan(std):
            y_err_lower.append(0)
            y_err_upper.append(0)
            continue
        margin = 1.96 * std
        y_err_upper.append(min(mean + margin, 1.0) - mean)
        y_err_lower.append(mean - max(mean - margin, 0.0))

    x = np.arange(len(labels))
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(x, auc_means, yerr=[y_err_lower, y_err_upper], capsize=5, color='skyblue', ecolor='gray',
           label='Mean AUC (95% CI)')
    ax.set_ylabel('Mean AUC Score')
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.set_ylim(0.45, 1.05)
    ax.axhline(y=0.5, color='r', linestyle='--', label='Random Chance (AUC=0.5)')
    ax.legend()
    plt.tight_layout()
    fig.savefig(output_path, dpi=300)
    plt.close(fig)


def plot_dimensionality_reduction(features, labels, time_points, method='PCA', sample_equal=True,
                                  plot_in_one_figure=True, output_path='dim_red.pdf', title_prefix="",
                                  legend_map=None, show_svm_boundary=False, force_rerun=False,
                                  plot_mean_features_at_times=None):
    """
    Samples data, performs PCA/UMAP, plots results, saves the data to a CSV,
    and optionally shows a shaded SVM decision boundary with a color bar.
    OPTIMIZED for performance.
    """
    if not force_rerun and os.path.exists(output_path) and plot_in_one_figure:
        print(f"Skipping existing dimensionality reduction plot: {os.path.basename(output_path)}")
        return

    print(f"Generating {method} plot(s): {os.path.basename(output_path)}")
    print(f"  Initial number of samples for plotting: {features.shape[0]}")

    random_seed = CONFIG['random_seed']
    rng = np.random.RandomState(random_seed)

    if plot_in_one_figure:
        fig, axes = plt.subplots(1, len(time_points), figsize=(5.5 * len(time_points), 5), sharex=False, sharey=False)
        if len(time_points) == 1: axes = [axes]
    else:
        fig, axes = None, None

    unique_labels = np.unique(labels)
    if len(unique_labels) < 2:
        print("Warning: Only one class present. Skipping dimensionality reduction plot.")
        return

    total_frames = features.shape[1]
    contour_object = None

    color_class_0 = '#3b75af'
    color_class_1 = '#d1495b'
    custom_palette_dict = None
    if legend_map:
        sorted_keys = sorted(legend_map.keys())
        if len(sorted_keys) == 2:
            custom_palette_dict = {
                legend_map[sorted_keys[0]]: color_class_0,
                legend_map[sorted_keys[1]]: color_class_1
            }

    for i, t_req in enumerate(time_points):
        t = get_closest_frame_index(t_req, total_frames)
        individual_output_path = output_path.replace('.pdf', f'_{t_req}min.pdf')
        if not plot_in_one_figure and not force_rerun and os.path.exists(individual_output_path):
            print(f"Skipping existing plot: {os.path.basename(individual_output_path)}")
            continue

        features_at_t = features[:, t, :]

        if sample_equal:
            min_samples = min(np.sum(labels == unique_labels[0]), np.sum(labels == unique_labels[1]))
            indices_0 = rng.choice(np.where(labels == unique_labels[0])[0], min_samples, replace=False)
            indices_1 = rng.choice(np.where(labels == unique_labels[1])[0], min_samples, replace=False)
            sampled_indices = np.concatenate([indices_0, indices_1])
            features_for_dim_red = features_at_t[sampled_indices]
            labels_for_dim_red = labels[sampled_indices]
        else:
            features_for_dim_red = features_at_t
            labels_for_dim_red = labels

        reducer = PCA(n_components=2, random_state=random_seed) if method == 'PCA' else UMAP(n_components=2,
                                                                                             random_state=random_seed)

        try:
            transformed_features = reducer.fit_transform(features_for_dim_red)
        except ValueError as e:
            print(f"  ERROR: Could not perform {method} at time {t_req} min. Error: {e}")
            continue

        plot_labels = pd.Series(labels_for_dim_red).map(legend_map) if legend_map else labels_for_dim_red

        base_name = os.path.splitext(os.path.basename(output_path))[0]
        data_filename = f"{base_name}_{t_req}min_data.csv"
        data_output_path = os.path.join(os.path.dirname(output_path), data_filename)
        if not os.path.exists(data_output_path) or force_rerun:
            print(f"  Saving {method} data for time {t_req} min to {os.path.basename(data_output_path)}...")
            pd.DataFrame({
                f'{method} Component 1': transformed_features[:, 0],
                f'{method} Component 2': transformed_features[:, 1],
                'Group': plot_labels
            }).to_csv(data_output_path, index=False, float_format='%.4f')
        else:
            print(f"  Skipping existing {method} data file: {os.path.basename(data_output_path)}")

        ax = axes[i] if plot_in_one_figure else plt.subplots(figsize=(7, 6))[1]
        if not plot_in_one_figure: fig_single = ax.get_figure()

        if show_svm_boundary:
            print(f"  Visualizing SVM boundaries for time {t_req} min...")
            x_min, x_max = transformed_features[:, 0].min() - 1, transformed_features[:, 0].max() + 1
            y_min, y_max = transformed_features[:, 1].min() - 1, transformed_features[:, 1].max() + 1

            # OPTIMIZATION: Increased meshgrid step size from 0.02 to 0.1 for a >20x speedup.
            # This creates a coarser grid for visualization without significant loss of quality.
            step_size = 0.1
            xx, yy = np.meshgrid(np.arange(x_min, x_max, step_size), np.arange(y_min, y_max, step_size))

            # OPTIMIZATION: Removed CV loop for visualization. A single SVM trained on the
            # sampled 2D data is sufficient and much faster for plotting a representative boundary.
            svm_2d = SVC(kernel='rbf', gamma='auto', probability=True, random_state=random_seed)
            svm_2d.fit(transformed_features, labels_for_dim_red)
            Z = svm_2d.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1]
            Z = Z.reshape(xx.shape)

            contour = ax.contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.5, levels=np.linspace(0, 1, 21))
            contour_object = contour

            if not plot_in_one_figure:
                cbar = fig_single.colorbar(contour, ax=ax)
                cbar.set_label('Class Probability', rotation=270, labelpad=15)
                cbar.set_ticks([0, 0.5, 1])
                if legend_map:
                    class0_label, class1_label = legend_map.get(0, 'Class 0'), legend_map.get(1, 'Class 1')
                    cbar.ax.set_yticklabels([class0_label, 'Boundary', class1_label], fontsize=8, rotation=90,
                                            va='center')

        sns.scatterplot(x=transformed_features[:, 0], y=transformed_features[:, 1], hue=plot_labels,
                        palette=custom_palette_dict if custom_palette_dict else 'Set2',
                        ax=ax, alpha=0.8, edgecolor='k')
        ax.set_title(f'Time {t_req // 60} h')
        ax.set_xlabel(f'{method} 1')
        ax.set_ylabel(f'{method} 2')
        ax.legend(title='Group')

        if plot_mean_features_at_times and t_req in plot_mean_features_at_times:
            print(f"  Calculating and plotting mean features for time {t_req} min...")
            mean_feature_c0 = np.mean(features_at_t[labels == unique_labels[0]], axis=0)
            mean_feature_c1 = np.mean(features_at_t[labels == unique_labels[1]], axis=0)

            transformed_mean_c0 = reducer.transform(mean_feature_c0.reshape(1, -1))
            transformed_mean_c1 = reducer.transform(mean_feature_c1.reshape(1, -1))
            mean_of_means = (transformed_mean_c0 + transformed_mean_c1) / 2.0
            ax.scatter(transformed_mean_c0[:, 0], transformed_mean_c0[:, 1], marker='*', s=300, c=color_class_0,
                       edgecolor='white', zorder=10)
            ax.scatter(transformed_mean_c1[:, 0], transformed_mean_c1[:, 1], marker='*', s=300, c=color_class_1,
                       edgecolor='white', zorder=10)
            ax.scatter(mean_of_means[:, 0], mean_of_means[:, 1], marker='*', s=300, c='yellow', edgecolor='black',
                       zorder=10)

        if not plot_in_one_figure:
            plt.tight_layout()
            fig_single.savefig(individual_output_path, dpi=300)
            plt.close(fig_single)

    if plot_in_one_figure:
        fig.tight_layout(rect=[0, 0, 0.9, 1])
        if contour_object:
            cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
            cbar = fig.colorbar(contour_object, cax=cbar_ax)
            cbar.set_label('Class Probability', rotation=270, labelpad=15)
            cbar.set_ticks([0, 0.5, 1])
            if legend_map:
                class0_label, class1_label = legend_map.get(0, 'Class 0'), legend_map.get(1, 'Class 1')
                cbar.ax.set_yticklabels([class0_label, 'Boundary', class1_label], fontsize=8, rotation=90, va='center')
        fig.savefig(output_path, dpi=300)
        plt.close(fig)


# --- Main Analysis Functions ---

def run_wt_analysis():
    print("\n[STAGE 1/4] Loading and preparing WT data...")
    try:
        labels_df = pd.read_csv(CONFIG['labeling_csv_path_wt'])
    except FileNotFoundError:
        print(f"ERROR: Labeling sheet not found at '{CONFIG['labeling_csv_path_wt']}'.")
        return

    labels_df = labels_df.dropna(subset=['aggregates_formed'])
    labels_df = labels_df[labels_df['aggregates_formed'].isin(['T', 'F'])]
    labels_df['label'] = labels_df['aggregates_formed'].map({'T': 1, 'F': 0})

    print("\n--- Data Summary ---")
    for class_label, count in labels_df['aggregates_formed'].value_counts().items():
        class_name = 'Aggregate' if class_label == 'T' else 'No Aggregate'
        num_runs = labels_df[labels_df['aggregates_formed'] == class_label]['run_id'].nunique()
        print(f"  Class '{class_name}': {count} samples from {num_runs} unique runs.")
    print("--------------------\n")

    all_features = {}
    if CONFIG['copy_frames']:
        print("[STAGE 1.5/4] Copying and processing selected frames for WT analysis...")
        copied_frames_output_dir = os.path.join(CONFIG['analysis_output_dir'], CONFIG['copied_frames_dir'], 'WT')

    for _, row in labels_df.iterrows():
        run_id, scope_id = row['run_id'], row['scope_id']
        key = f"R{run_id:04d}_S{scope_id:02d}"
        npz_path = os.path.join(CONFIG['features_base_dir_wt'], f"Run{run_id:04d}", f"Scope{scope_id:02d}",
                                "features.npz")
        if not os.path.exists(npz_path):
            print(f"Warning: Feature file not found for {key}, skipping.")
            continue

        features_data = impute_nans_with_previous_frame(np.load(npz_path)['z'])
        all_features[key] = {'label': row['label'], 'run_id': row['run_id'], 'features': features_data}

        if CONFIG['copy_frames']:
            class_name = 'Aggregate' if row['label'] == 1 else 'No_Aggregate'
            copy_and_process_frames(run_id, scope_id, class_name, CONFIG['selected_time_points'], CONFIG['WT_img_dir'],
                                    copied_frames_output_dir)

    if not all_features: print("ERROR: No valid data could be loaded."); return

    min_frames = min(v['features'].shape[0] for v in all_features.values())
    all_trajectories = np.array([v['features'][:min_frames] for v in all_features.values()])
    all_labels = np.array([v['label'] for v in all_features.values()])
    all_groups = np.array([v['run_id'] for v in all_features.values()])

    if len(np.unique(all_labels)) < 2: print("Warning: Insufficient data for one or both classes."); return

    print("\n[STAGE 2/4] Analyzing predictive power of features...")
    auc_scores = {}
    n_splits = CONFIG['n_splits']
    use_group_kfold = len(np.unique(all_groups[all_labels == 0])) >= n_splits and len(
        np.unique(all_groups[all_labels == 1])) >= n_splits

    for t in CONFIG['selected_time_points']:
        features_at_t = all_trajectories[:, get_closest_frame_index(t, min_frames), :]
        analysis_name = f"Time {t} min"
        roc_data = train_and_get_roc_data(features_at_t, all_labels, all_groups, use_group_kfold, analysis_name)
        if roc_data is None and use_group_kfold:
            print(f"  Warning: StratifiedGroupKFold failed for {analysis_name}. Retrying with StratifiedKFold.")
            roc_data = train_and_get_roc_data(features_at_t, all_labels, all_groups, False, analysis_name)
        if roc_data: auc_scores[analysis_name] = (roc_data[3], roc_data[4])

    print("\n[STAGE 3/4] Saving result data...")
    if auc_scores:
        auc_data_path = os.path.join(CONFIG['analysis_output_dir'], "WT_auc_scores.csv")
        if not os.path.exists(auc_data_path) or CONFIG['force_rerun']:
            auc_data = [
                {'Time (min)': int(re.search(r'(\d+)', name).group(1)), 'Mean AUC': mean_auc, 'Std Dev AUC': std_auc}
                for name, (mean_auc, std_auc) in auc_scores.items()]
            pd.DataFrame(auc_data).sort_values('Time (min)').to_csv(auc_data_path, index=False, float_format='%.4f')
            print(f"  Saved WT AUC scores to: {os.path.basename(auc_data_path)}")
        else:
            print(f"  Skipping existing WT AUC scores file: {os.path.basename(auc_data_path)}")

    print("\n[STAGE 4/4] Generating analysis plots...")
    plot_auc_barplot(auc_scores, CONFIG['selected_time_points'],
                     os.path.join(CONFIG['analysis_output_dir'], f"WT_{CONFIG['auc_barplot_name']}"), "WT Analysis: ",
                     CONFIG['force_rerun'])

    if CONFIG['dimensionality_reduction']['run']:
        dr_config = CONFIG['dimensionality_reduction']
        output_name = f"WT_{dr_config['method']}_plot.pdf"
        plot_dimensionality_reduction(all_trajectories, all_labels, time_points=CONFIG['selected_time_points'],
                                      method=dr_config['method'], sample_equal=dr_config['sample_equal'],
                                      plot_in_one_figure=dr_config['plot_in_one_figure'],
                                      output_path=os.path.join(CONFIG['analysis_output_dir'], output_name),
                                      legend_map={0: 'No Aggregate', 1: 'Aggregate'},
                                      show_svm_boundary=dr_config['show_svm_boundary'],
                                      force_rerun=CONFIG['force_rerun'],
                                      plot_mean_features_at_times=dr_config.get('plot_mean_features_at_times'))


def run_motility_analysis():
    print("\n[STAGE 1/4] Loading and preparing motility data...")
    try:
        labels_df = pd.read_excel(CONFIG['motility_csv_path'])
    except FileNotFoundError:
        print(f"ERROR: Could not load motility data xlsx. Check path: {CONFIG['motility_csv_path']}")
        return

    labels_df = labels_df.dropna(subset=['motility', 'Strain'])
    labels_df = labels_df[labels_df['motility'].isin(CONFIG['motility_target_classes'])]
    strain_to_label_map = pd.Series(labels_df.motility.values, index=labels_df.Strain).to_dict()

    all_samples_dict = {class_name: [] for class_name in CONFIG['motility_target_classes']}
    dir_pattern = re.compile(r'Run(\d+)_Mutant(\d+)')
    scope_pattern = re.compile(r'Scope(\d+)')

    for dir_name in os.listdir(CONFIG['features_base_dir_motility']):
        match = dir_pattern.match(dir_name)
        if not match: continue
        run_id, mutant_num = int(match.group(1)), int(match.group(2))
        strain_id = f"DK{mutant_num}"
        if strain_id in strain_to_label_map:
            label = strain_to_label_map[strain_id]
            for scope_dir_name in os.listdir(os.path.join(CONFIG['features_base_dir_motility'], dir_name)):
                scope_match = scope_pattern.match(scope_dir_name)
                if scope_match:
                    scope_id = int(scope_match.group(1))
                    npz_path = os.path.join(CONFIG['features_base_dir_motility'], dir_name, scope_dir_name,
                                            "features.npz")
                    if os.path.exists(npz_path):
                        features = impute_nans_with_previous_frame(np.load(npz_path)['z'])
                        all_samples_dict[label].append({'features': features, 'run_id': run_id, 'scope_id': scope_id})

    print("\n--- Data Summary ---")
    for class_name, samples in all_samples_dict.items():
        print(
            f"  Class '{class_name}': {len(samples)} samples from {len(np.unique([s['run_id'] for s in samples]))} unique runs.")
    print("--------------------\n")

    if CONFIG['copy_frames']:
        print("[STAGE 1.5/4] Copying and processing selected frames for motility analysis...")
        copied_frames_output_dir = os.path.join(CONFIG['analysis_output_dir'], CONFIG['copied_frames_dir'], 'motility')
        for class_name, samples in all_samples_dict.items():
            for sample in samples:
                copy_and_process_frames(sample['run_id'], sample['scope_id'], class_name,
                                        CONFIG['selected_time_points'], CONFIG['motility_img_dir'],
                                        copied_frames_output_dir)

    all_trajectories_dict = {k: [s['features'] for s in v] for k, v in all_samples_dict.items()}
    all_runs_dict = {k: [s['run_id'] for s in v] for k, v in all_samples_dict.items()}
    required_frames = CONFIG['required_frames_motility']
    for class_name, trajs in all_trajectories_dict.items():
        processed = [
            np.vstack([t, np.repeat(t[-1:], required_frames - len(t), axis=0)]) if len(t) < required_frames else t[
                :required_frames] for t in trajs]
        all_trajectories_dict[class_name] = np.array(processed) if processed else np.array([])

    for class1_name, class2_name in CONFIG['motility_comparison_pairs']:
        print(f"\n--- Comparing '{class1_name}' vs. '{class2_name}' ---")
        trajs1, trajs2 = all_trajectories_dict.get(class1_name), all_trajectories_dict.get(class2_name)
        if trajs1 is None or len(trajs1) == 0 or trajs2 is None or len(trajs2) == 0:
            print(f"  Warning: Insufficient data. Skipping pair.");
            continue

        runs1, runs2 = all_runs_dict.get(class1_name), all_runs_dict.get(class2_name)
        pair_trajs = np.concatenate([trajs1, trajs2])
        pair_labels = np.array([0] * len(trajs1) + [1] * len(trajs2))
        pair_groups = np.concatenate([runs1, runs2])

        print("  [Step 1/3] Analyzing predictive power...")
        auc_scores = {}
        n_splits = CONFIG['n_splits']
        use_group_kfold = len(np.unique(runs1)) >= n_splits and len(np.unique(runs2)) >= n_splits

        for t in CONFIG['selected_time_points']:
            features_at_t = pair_trajs[:, get_closest_frame_index(t, required_frames), :]
            analysis_name = f"Time {t} min"
            roc_data = train_and_get_roc_data(features_at_t, pair_labels, pair_groups, use_group_kfold, analysis_name)
            if roc_data is None and use_group_kfold:
                print(f"    Warning: StratifiedGroupKFold failed. Retrying with StratifiedKFold.")
                roc_data = train_and_get_roc_data(features_at_t, pair_labels, pair_groups, False, analysis_name)
            if roc_data: auc_scores[analysis_name] = (roc_data[3], roc_data[4])

        print("  [Step 2/3] Saving result data...")
        if auc_scores:
            auc_data_path = os.path.join(CONFIG['analysis_output_dir'],
                                         f"motility_{class1_name}_vs_{class2_name}_auc_scores.csv")
            if not os.path.exists(auc_data_path) or CONFIG['force_rerun']:
                auc_data = [{'Time (min)': int(re.search(r'(\d+)', name).group(1)), 'Mean AUC': mean_auc,
                             'Std Dev AUC': std_auc} for name, (mean_auc, std_auc) in auc_scores.items()]
                pd.DataFrame(auc_data).sort_values('Time (min)').to_csv(auc_data_path, index=False, float_format='%.4f')
                print(f"    Saved AUC scores to: {os.path.basename(auc_data_path)}")
            else:
                print(f"    Skipping existing AUC scores file: {os.path.basename(auc_data_path)}")

        print("  [Step 3/3] Generating analysis plots...")
        title_prefix = f"{class1_name}_vs_{class2_name}: "
        plot_auc_barplot(auc_scores, CONFIG['selected_time_points'], os.path.join(CONFIG['analysis_output_dir'],
                                                                                  f"motility_{class1_name}_vs_{class2_name}_{CONFIG['auc_barplot_name']}"),
                         title_prefix, CONFIG['force_rerun'])

        if CONFIG['dimensionality_reduction']['run']:
            dr_config = CONFIG['dimensionality_reduction']
            output_name = f"motility_{class1_name}_vs_{class2_name}_{dr_config['method']}_plot.pdf"
            plot_dimensionality_reduction(pair_trajs, pair_labels, time_points=CONFIG['selected_time_points'],
                                          method=dr_config['method'], sample_equal=dr_config['sample_equal'],
                                          plot_in_one_figure=dr_config['plot_in_one_figure'],
                                          output_path=os.path.join(CONFIG['analysis_output_dir'], output_name),
                                          legend_map={0: class1_name, 1: class2_name},
                                          show_svm_boundary=dr_config['show_svm_boundary'],
                                          force_rerun=CONFIG['force_rerun'],
                                          plot_mean_features_at_times=dr_config.get('plot_mean_features_at_times'))


def main():
    output_dir = CONFIG['analysis_output_dir']
    os.makedirs(output_dir, exist_ok=True)
    analysis_type = CONFIG['analysis_type']
    print(f"--- Starting Analysis Script for '{analysis_type}' ---")

    if analysis_type == 'WT':
        run_wt_analysis()
    elif analysis_type == 'motility':
        run_motility_analysis()
    else:
        print(f"ERROR: Unknown analysis type '{analysis_type}'. Choose \'WT\' or \'motility\'.")
        return

    print("\n--- Script execution complete ---")


if __name__ == '__main__':
    # Before running, ensure the WORKING_DIR at the top of the script is set correctly.
    if WORKING_DIR == ".":
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print("!!! WARNING: `WORKING_DIR` is not set. Please update it to  !!!")
        print("!!! your project's base directory before running this script. !!!")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    main()


--- Starting Analysis Script for 'WT' ---

[STAGE 1/4] Loading and preparing WT data...

--- Data Summary ---
  Class 'Aggregate': 220 samples from 34 unique runs.
  Class 'No Aggregate': 132 samples from 38 unique runs.
--------------------


[STAGE 2/4] Analyzing predictive power of features...

[STAGE 3/4] Saving result data...
  Saved WT AUC scores to: WT_auc_scores.csv

[STAGE 4/4] Generating analysis plots...
Generating AUC bar plot: WT_auc_barplot.pdf


## Combine bar plot

In [9]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

import matplotlib as mpl

mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['text.usetex'] = False

# ==============================================================================
# Configuration
# ==============================================================================
# --- File Paths ---
# Assumes the CSV files are in the same directory as the script.
# If not, provide the full path to the files.
motility_csv_path = f'{WORKING_DIR}/images/figure6/motility_analysis/motility_A+S-_vs_A-S+_auc_scores.csv'
wt_csv_path = f'{WORKING_DIR}/images/figure6/WT_analysis/WT_auc_scores.csv'
output_plot_path = f'{WORKING_DIR}/images/figure6/motility_analysis/combined_auc_barplot.pdf'

# --- Plotting Parameters ---
time_points_to_plot = [1440, 0]  # The order of time points for each group
bar_colors = ['#812db3', '#812db3', '#51ab4f', '#51ab4f']
figure_size = (12, 6)
y_axis_limit = [0.45, 1.05]


# ==============================================================================
# Main Script
# ==============================================================================

def create_combined_auc_plot():
    """
    Loads AUC score data from two CSV files and plots them on a single
    bar chart in a specified order.
    """
    print("--- Starting Combined AUC Plot Generation ---")

    # --- 1. Load and Validate Data ---
    if not os.path.exists(motility_csv_path):
        print(f"ERROR: Motility data file not found at '{motility_csv_path}'")
        return
    if not os.path.exists(wt_csv_path):
        print(f"ERROR: WT data file not found at '{wt_csv_path}'")
        return

    print(f"Loading motility data from: {motility_csv_path}")
    motility_df = pd.read_csv(motility_csv_path)

    print(f"Loading WT data from: {wt_csv_path}")
    wt_df = pd.read_csv(wt_csv_path)

    # --- 2. Prepare Data for Plotting ---
    plot_data = {
        'labels': [],
        'means': [],
        'stds': []
    }

    # Extract Motility data in the specified order
    for time in time_points_to_plot:
        row = motility_df[motility_df['Time (min)'] == time]
        if not row.empty:
            plot_data['labels'].append(f'A+S- vs A-S+\n{time // 60} h')
            plot_data['means'].append(row['Mean AUC'].iloc[0])
            plot_data['stds'].append(row['Std Dev AUC'].iloc[0])
        else:
            print(f"Warning: Time point {time} min not found in motility data.")

    # Extract WT data in the specified order
    for time in time_points_to_plot:
        row = wt_df[wt_df['Time (min)'] == time]
        if not row.empty:
            plot_data['labels'].append(f'WT Agg. vs No Agg.\n{time // 60} h')
            plot_data['means'].append(row['Mean AUC'].iloc[0])
            plot_data['stds'].append(row['Std Dev AUC'].iloc[0])
        else:
            print(f"Warning: Time point {time} min not found in WT data.")

    if not plot_data['means']:
        print("ERROR: No data was extracted for plotting. Please check CSV files and time points.")
        return

    print(f"Plotting data for labels: {plot_data['labels']}")

    # --- 3. Calculate Asymmetric 95% Confidence Intervals ---
    y_err_lower = []
    y_err_upper = []
    for mean, std in zip(plot_data['means'], plot_data['stds']):
        margin = 1.96 * std  # 95% CI margin
        upper_error = min(mean + margin, 1.0) - mean
        lower_error = mean - max(mean - margin, 0.0)
        y_err_upper.append(upper_error)
        y_err_lower.append(lower_error)

    asymmetric_error = [y_err_lower, y_err_upper]

    # --- 4. Generate the Plot ---
    print("Generating the bar plot...")
    fig, ax = plt.subplots(figsize=figure_size)
    ax.grid(True, axis='y', linestyle='--', linewidth=0.5, zorder=0)

    for spine in ['top', 'right']:
        ax.spines[spine].set_visible(False)
    x_pos = np.arange(len(plot_data['labels']))

    ax.bar(x_pos, plot_data['means'], yerr=asymmetric_error,
           color=bar_colors, capsize=5, ecolor='gray', zorder=2)

    ax.set_ylabel('Mean AUC Score')
    # ax.set_title('Comparison of Predictive Power')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(plot_data['labels'])
    ax.set_ylim(y_axis_limit)

    # Add a line for random chance
    ax.axhline(y=0.5, color='r', linestyle='--', label='Random Chance (AUC=0.5)')

    # Create custom legend handles
    from matplotlib.patches import Patch
    legend_elements = [
        Patch(facecolor=bar_colors[0], edgecolor=bar_colors[0], label='A+S- vs A-S+'),
        Patch(facecolor=bar_colors[len(time_points_to_plot)], edgecolor=bar_colors[len(time_points_to_plot)],
              label='WT Agg. vs No Agg.')
    ]
    ax.legend(handles=legend_elements, loc='upper right')

    plt.tight_layout()

    # --- 5. Save the Plot ---
    try:
        fig.savefig(output_plot_path, dpi=300)
        print(f"Successfully saved plot to: {output_plot_path}")
    except Exception as e:
        print(f"ERROR: Could not save the plot. Reason: {e}")

    plt.close(fig)
    print("--- Script finished ---")


if __name__ == '__main__':
    create_combined_auc_plot()


--- Starting Combined AUC Plot Generation ---
Loading motility data from: D:/Projects/DAE_project/images/figure6/motility_analysis/motility_A+S-_vs_A-S+_auc_scores.csv
Loading WT data from: D:/Projects/DAE_project/images/figure6/WT_analysis/WT_auc_scores.csv
Plotting data for labels: ['A+S- vs A-S+\n24 h', 'A+S- vs A-S+\n0 h', 'WT Agg. vs No Agg.\n24 h', 'WT Agg. vs No Agg.\n0 h']
Generating the bar plot...
Successfully saved plot to: D:/Projects/DAE_project/images/figure6/motility_analysis/combined_auc_barplot.pdf
--- Script finished ---


In [4]:
# ==============================================================================
# Standalone Script for Image Reconstruction from Mean Features
# ==============================================================================
# Purpose:
# This script loads encoded features for specified classes, calculates the
# mean feature vector for each class and their midpoint, and then uses a
# pre-trained generator network (e.g., StyleGAN) to synthesize representative
# images from these vectors.
#
# It is designed to be independent of the main analysis pipeline.
# ==============================================================================

import os
import re
import numpy as np
import pandas as pd
import torch
import cv2
import warnings

import dnnlib
import legacy

CONFIG = {
    # --- Analysis Target ---
    "analysis_type": "WT",  # "WT" or "motility"
    "time_point_to_reconstruct": 0,  # Time in minutes (e.g., 0, 720, 1440)

    # --- Paths ---
    "network_pkl_path": f"{WORKING_DIR}/models/network-snapshot-001512-patched.pkl",
    "output_dir": f"{WORKING_DIR}/images/figure6/WT_analysis/centers",

    # WT analysis paths
    "features_base_dir_wt": f"{WORKING_DIR}/encoded_features/WT_features",
    "labeling_csv_path_wt": f"{WORKING_DIR}/dataset/WT/labeling_sheet.csv",

    # Motility analysis paths
    "features_base_dir_motility": f"{WORKING_DIR}/encoded_features/Roy_training_features",
    "motility_csv_path": f'{WORKING_DIR}/dataset/Roy_training/merged_strain_data.xlsx',

    # --- Class Selection ---
    # For "WT" analysis_type
    "wt_classes_to_compare": ('F', 'T'),  # ('F' = No Aggregate, 'T' = Aggregate)

    # For "motility" analysis_type
    "motility_classes_to_compare": ('WT', 'A-S-'),  # e.g., ('WT', 'A+S-'), ('A-S+', 'A-S-')
}

os.environ['CC'] = "/usr/bin/gcc-9"
os.environ['CXX'] = "/usr/bin/g++-9"


# ==============================================================================
# ---                          HELPER FUNCTIONS                            ---
# ==============================================================================

def impute_nans_with_previous_frame(trajectory):
    """Fills NaN values in a trajectory with the values from the last valid frame."""
    for i in range(1, trajectory.shape[0]):
        if np.isnan(trajectory[i]).any():
            trajectory[i] = trajectory[i - 1]
    return trajectory


def get_closest_frame_index(requested_frame, total_frames):
    """Finds the valid index for a requested frame number."""
    if total_frames == 0:
        raise ValueError("Cannot get frame index from a trajectory with zero frames.")
    max_index = total_frames - 1
    return min(requested_frame, max_index)


# ==============================================================================
# ---                        CORE RECONSTRUCTION LOGIC                     ---
# ==============================================================================

def reconstruct_images(features_c0, features_c1, class_name_c0, class_name_c1, config):
    """
    Calculates mean vectors and generates images using the StyleGAN generator.
    """
    print("\n--- Starting Image Reconstruction ---")

    if len(features_c0) == 0 or len(features_c1) == 0:
        print(f"Error: Not enough data for one or both classes. "
              f"Found {len(features_c0)} samples for '{class_name_c0}' and "
              f"{len(features_c1)} for '{class_name_c1}'. Aborting.")
        return

    # 1. Calculate Mean Feature Vectors
    print("Calculating mean feature vectors...")
    mean_feature_c0 = np.mean(features_c0, axis=0)
    mean_feature_c1 = np.mean(features_c1, axis=0)
    mean_of_means = (mean_feature_c0 + mean_feature_c1) / 2.0
    print(f"  - Calculated mean for {len(features_c0)} '{class_name_c0}' samples.")
    print(f"  - Calculated mean for {len(features_c1)} '{class_name_c1}' samples.")
    print("  - Calculated midpoint vector.")

    # 2. Setup Device and Load Generator
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    print(f"Loading generator from '{config['network_pkl_path']}'...")
    try:
        with dnnlib.util.open_url(config['network_pkl_path']) as fp:
            models = legacy.load_network_pkl(fp)
            G = models['G_ema'].to(device)
    except Exception as e:
        print(f"FATAL ERROR: Could not load the network PKL file. Check the path. Error: {e}")
        return

    # 3. Generate Images
    print("Synthesizing images from mean vectors...")
    batch_zs = np.vstack([mean_feature_c0, mean_feature_c1, mean_of_means])
    batch_zs_tensor = torch.from_numpy(batch_zs).to(device)

    # Assumes an unconditional generator (class labels `c` is None)
    synth_images = G(batch_zs_tensor, None, noise_mode="const")

    # 4. Post-process and Save Images
    synth_images = (synth_images + 1) * 127.5  # Denormalize from [-1, 1] to [0, 255]
    synth_images = synth_images.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8).cpu().numpy()

    # Create a unique sub-directory for this reconstruction run
    output_subdir_name = f"{config['analysis_type']}_T{config['time_point_to_reconstruct']}_{class_name_c0}_vs_{class_name_c1}"
    final_output_dir = os.path.join(config['output_dir'], output_subdir_name)
    os.makedirs(final_output_dir, exist_ok=True)
    print(f"Saving images to: {final_output_dir}")

    # Define filenames and save
    paths = {
        "class0": os.path.join(final_output_dir, f"reconstruction_{class_name_c0}.png"),
        "class1": os.path.join(final_output_dir, f"reconstruction_{class_name_c1}.png"),
        "midpoint": os.path.join(final_output_dir, f"reconstruction_Midpoint.png"),
    }

    # OpenCV expects BGR format, so convert from RGB
    cv2.imwrite(paths["class0"], cv2.cvtColor(synth_images[0], cv2.COLOR_RGB2BGR))
    cv2.imwrite(paths["class1"], cv2.cvtColor(synth_images[1], cv2.COLOR_RGB2BGR))
    cv2.imwrite(paths["midpoint"], cv2.cvtColor(synth_images[2], cv2.COLOR_RGB2BGR))

    print(f"  - Saved: {os.path.basename(paths['class0'])}")
    print(f"  - Saved: {os.path.basename(paths['class1'])}")
    print(f"  - Saved: {os.path.basename(paths['midpoint'])}")
    print("--- Reconstruction complete! ---")


# ==============================================================================
# ---                            MAIN EXECUTION                            ---
# ==============================================================================

def main():
    """Main function to load data and orchestrate the reconstruction."""
    cfg = CONFIG

    # --- Data Loading Logic ---
    if cfg['analysis_type'] == 'WT':
        print(f"--- Starting WT Analysis for Reconstruction ---")
        class0_label, class1_label = cfg['wt_classes_to_compare']
        class_name_map = {'F': 'No_Aggregate', 'T': 'Aggregate'}
        class_name_c0, class_name_c1 = class_name_map[class0_label], class_name_map[class1_label]

        print(f"Comparing classes: '{class_name_c0}' ({class0_label}) vs. '{class_name_c1}' ({class1_label})")

        try:
            labels_df = pd.read_csv(cfg['labeling_csv_path_wt'])
        except FileNotFoundError:
            print(f"ERROR: Labeling sheet not found at '{cfg['labeling_csv_path_wt']}'.")
            return

        labels_df = labels_df[labels_df['aggregates_formed'].isin([class0_label, class1_label])]

        features_c0, features_c1 = [], []

        for _, row in labels_df.iterrows():
            run_id, scope_id = row['run_id'], row['scope_id']
            npz_path = os.path.join(cfg['features_base_dir_wt'], f"Run{run_id:04d}", f"Scope{scope_id:02d}",
                                    "features.npz")

            if not os.path.exists(npz_path): continue

            features_data = impute_nans_with_previous_frame(np.load(npz_path)['z'])
            frame_idx = get_closest_frame_index(cfg['time_point_to_reconstruct'], features_data.shape[0])
            feature_vec = features_data[frame_idx]

            if row['aggregates_formed'] == class0_label:
                features_c0.append(feature_vec)
            else:
                features_c1.append(feature_vec)

        reconstruct_images(features_c0, features_c1, class_name_c0, class_name_c1, cfg)

    elif cfg['analysis_type'] == 'motility':
        print(f"--- Starting Motility Analysis for Reconstruction ---")
        class_name_c0, class_name_c1 = cfg['motility_classes_to_compare']
        print(f"Comparing classes: '{class_name_c0}' vs. '{class_name_c1}'")

        try:
            labels_df = pd.read_excel(cfg['motility_csv_path'])
        except FileNotFoundError:
            print(f"ERROR: Motility data not found at '{cfg['motility_csv_path']}'.")
            return

        strain_to_label_map = pd.Series(labels_df.motility.values, index=labels_df.Strain).to_dict()
        dir_pattern = re.compile(r'Run(\d+)_Mutant(\d+)')

        features_c0, features_c1 = [], []

        for dir_name in os.listdir(cfg['features_base_dir_motility']):
            match = dir_pattern.match(dir_name)
            if not match: continue

            mutant_num = int(match.group(2))
            strain_id = f"DK{mutant_num}"

            if strain_id in strain_to_label_map:
                label = strain_to_label_map[strain_id]
                if label not in [class_name_c0, class_name_c1]: continue

                # Find all scope directories within this run/mutant folder
                run_mutant_path = os.path.join(cfg['features_base_dir_motility'], dir_name)
                for scope_dir_name in os.listdir(run_mutant_path):
                    npz_path = os.path.join(run_mutant_path, scope_dir_name, "features.npz")
                    if os.path.exists(npz_path):
                        features_data = impute_nans_with_previous_frame(np.load(npz_path)['z'])
                        frame_idx = get_closest_frame_index(cfg['time_point_to_reconstruct'], features_data.shape[0])
                        feature_vec = features_data[frame_idx]

                        if label == class_name_c0:
                            features_c0.append(feature_vec)
                        else:
                            features_c1.append(feature_vec)

        reconstruct_images(features_c0, features_c1, class_name_c0, class_name_c1, cfg)

    else:
        print(f"ERROR: Unknown analysis type '{cfg['analysis_type']}'. Choose 'WT' or 'motility'.")
        return


if __name__ == '__main__':
    if not os.path.exists(CONFIG["network_pkl_path"]):
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print("!!! WARNING: `network_pkl_path` is not set or file not found.   !!!")
        print("!!! Please update the CONFIG section before running this script.  !!!")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    else:
        main()


--- Starting WT Analysis for Reconstruction ---
Comparing classes: 'No_Aggregate' (F) vs. 'Aggregate' (T)

--- Starting Image Reconstruction ---
Calculating mean feature vectors...
  - Calculated mean for 132 'No_Aggregate' samples.
  - Calculated mean for 220 'Aggregate' samples.
  - Calculated midpoint vector.
Using device: cuda
Loading generator from '/home/xavier/Documents/DAE_project/models/network-snapshot-001512-patched.pkl'...
Synthesizing images from mean vectors...
Setting up PyTorch plugin "bias_act_plugin"... Done.
Setting up PyTorch plugin "upfirdn2d_plugin"... Done.
Saving images to: /home/xavier/Documents/DAE_project/images/figure6/WT_analysis/centers/WT_T0_No_Aggregate_vs_Aggregate
  - Saved: reconstruction_No_Aggregate.png
  - Saved: reconstruction_Aggregate.png
  - Saved: reconstruction_Midpoint.png
--- Reconstruction complete! ---


# Optional

## Reconstruct label centers

## Sample images

In [5]:
# ==============================================================================
# Script: Image Sampler and Processor
# ==============================================================================
# Purpose:
# This script processes image files by first sampling from a labeling sheet.
# 1. Loads a CSV file that classifies experiments (by run_id and scope_id)
#    into different classes (e.g., aggregates formed 'T' or 'F').
# 2. Randomly samples a specified number of experiments from each class.
# 3. For each sampled experiment, it processes a list of specified frame numbers.
# 4. For each frame, it locates the image, resizes, crops, normalizes brightness,
#    and saves the result as a PNG file.
#
# Instructions:
# - Update the CONFIG dictionary with your specific parameters.
# - 'labeling_csv_path': Path to the CSV file with run/scope classifications.
# - 'samples_per_class': How many experiments to randomly sample from each class.
# - 'target_frames_to_process': A list of frame numbers to process for each
#   sampled experiment.
# ==============================================================================

import os
import numpy as np
import pandas as pd
import cv2
import warnings

# Suppress warnings from scikit-image about low contrast images
warnings.filterwarnings("ignore", category=UserWarning, module='skimage.io')

# --- Configuration ---
CONFIG = {
    # --- Paths ---
    "base_image_dir": f"{WORKING_DIR}/dataset/WT/images",
    "output_dir": f"{WORKING_DIR}/images/figure6/WT_analysis/sampled",
    "labeling_csv_path": f"{WORKING_DIR}/dataset/WT/labeling_sheet.csv",

    # --- Sampling Parameters ---
    "samples_per_class": 3,  # Number of experiments to sample from each class
    "target_frames_to_process": [1, 1441],  # Frames to process for each sampled experiment
    "random_seed": 42,  # Seed for reproducible random sampling

    # --- Image Processing Parameters ---
    "resize_by": 1.0,
    "resolution": 512,
    "brightness_norm": True,
    "brightness_mean": 107.2,
    "locations": ["center"],  # Can be a list, e.g., ["left", "center", "right"]
    "crop_offset": 128,
}


def resize_crop(img_name, strain_dir, resize_by=1.0, resolution=512, brightness_norm=True, brightness_mean=107.2,
                locations=None, crop_offset=128):
    """
    Loads, resizes, and crops an image from multiple locations.
    """
    if locations is None:
        locations = ["center"]
    img_path = os.path.join(strain_dir, img_name)
    if not os.path.exists(img_path): return None

    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None: return None
    if img.dtype != np.uint8: img = np.uint8(img / 256)

    img_h, img_w = img.shape[:2]
    resize_w, resize_h = int(img_w * resize_by), int(img_h * resize_by)

    if resize_by != 1.0:
        img = cv2.resize(img, (resize_w, resize_h), interpolation=cv2.INTER_LANCZOS4)

    cropped_imgs = []
    for location in locations:
        y_start = (resize_h - resolution) // 2
        y_end = y_start + resolution

        if location == "left":
            x_start = crop_offset
        elif location == "right":
            x_start = resize_w - crop_offset - resolution
        else:  # "center" or default
            x_start = (resize_w - resolution) // 2

        x_end = x_start + resolution
        new_img = img[y_start:y_end, x_start:x_end]

        if brightness_norm:
            obj_v = np.mean(new_img)
            value = brightness_mean - obj_v
            new_img = cv2.add(new_img, value)
        cropped_imgs.append((new_img, location))  # Return image and its location
    return cropped_imgs


def process_and_save_frame(run_id, scope_id, frame_num, label, base_dir, output_dir, processing_params):
    """
    Loads, processes, and saves a single image frame.

    Args:
        run_id (int): The run identifier.
        scope_id (int): The scope identifier.
        frame_num (int): The specific frame number to process.
        label (str): The class label for the experiment (e.g., 'T' or 'F').
        base_dir (str): The root directory containing the image data.
        output_dir (str): The directory where the output .png files will be saved.
        processing_params (dict): A dictionary with image processing settings.
    """
    try:
        # --- 1. Find the correct run folder and construct the image path ---
        run_suffix = f"Run{run_id:04d}"
        run_folder_name = None

        try:
            run_folder_name = next(
                d for d in os.listdir(base_dir) if d.endswith(run_suffix) and os.path.isdir(os.path.join(base_dir, d)))
        except StopIteration:
            print(f"Warning: No folder found ending with '{run_suffix}' in '{base_dir}'. Skipping frame {frame_num}.")
            return

        scope_dir = os.path.join(base_dir, run_folder_name, f"Scope{scope_id:02d}")
        image_filename = f"Run{run_id:04d}_scope{scope_id:d}-00_{frame_num:04d}.jpg"

        # --- 2. Process the image using the resize_crop function ---
        cropped_results = resize_crop(
            img_name=image_filename,
            strain_dir=scope_dir,
            **processing_params
        )

        if not cropped_results:
            print(
                f"Warning: Cropping failed for run {run_id}, scope {scope_id}, frame {frame_num}. Image might not exist or be invalid.")
            return

        # --- 3. Save the resulting image(s) ---
        os.makedirs(output_dir, exist_ok=True)

        for img, location in cropped_results:
            # Construct a descriptive output filename including the class label
            output_filename = f"run{run_id:04d}_scope{scope_id:02d}_class{label}_{frame_num:04d}_{location}.png"
            output_path = os.path.join(output_dir, output_filename)

            cv2.imwrite(output_path, img)
            print(f"Successfully processed frame {frame_num} ({location}). Saved to: {output_path}")

    except Exception as e:
        print(f"An error occurred while processing run {run_id}, scope {scope_id}, frame {frame_num}: {e}")


def main():
    """
    Main execution function to sample experiments and process frames.
    """
    print("--- Starting Image Sampler and Processor Script ---")

    # --- 1. Load and prepare the labeling data ---
    try:
        labels_df = pd.read_csv(CONFIG['labeling_csv_path'])
        labels_df = labels_df.dropna(subset=['aggregates_formed', 'run_id', 'scope_id'])
        labels_df = labels_df[labels_df['aggregates_formed'].isin(['T', 'F'])]
    except FileNotFoundError:
        print(f"ERROR: Labeling sheet not found at '{CONFIG['labeling_csv_path']}'. Exiting.")
        return
    except KeyError as e:
        print(f"ERROR: The CSV file is missing a required column: {e}. Exiting.")
        return

    # --- 2. Randomly sample from each class ---
    samples_per_class = CONFIG['samples_per_class']
    class_T_df = labels_df[labels_df['aggregates_formed'] == 'T']
    class_F_df = labels_df[labels_df['aggregates_formed'] == 'F']

    if len(class_T_df) < samples_per_class or len(class_F_df) < samples_per_class:
        print("Warning: Not enough samples in the CSV for the requested number.")
        print(f"  - Class 'T' has {len(class_T_df)} samples.")
        print(f"  - Class 'F' has {len(class_F_df)} samples.")
        print(f"  - Requested {samples_per_class} samples per class.")
        # Adjusting sample count to the minimum available
        samples_per_class = min(len(class_T_df), len(class_F_df))
        if samples_per_class == 0:
            print("ERROR: Cannot proceed with 0 samples in one of the classes. Exiting.")
            return
        print(f"  Proceeding with {samples_per_class} samples per class.")

    sampled_T = class_T_df.sample(n=samples_per_class, random_state=CONFIG['random_seed'])
    sampled_F = class_F_df.sample(n=samples_per_class, random_state=CONFIG['random_seed'])

    combined_samples = pd.concat([sampled_T, sampled_F])
    print(f"\nSuccessfully sampled {len(combined_samples)} total experiments.")

    # --- 3. Process the sampled frames ---
    base_dir = CONFIG["base_image_dir"]
    output_dir = CONFIG["output_dir"]
    target_frames = CONFIG["target_frames_to_process"]

    processing_params = {
        "resize_by": CONFIG["resize_by"],
        "resolution": CONFIG["resolution"],
        "brightness_norm": CONFIG["brightness_norm"],
        "brightness_mean": CONFIG["brightness_mean"],
        "locations": CONFIG["locations"],
        "crop_offset": CONFIG["crop_offset"],
    }

    print(f"Configuration:")
    print(f"  - Base Directory: {base_dir}")
    print(f"  - Output Directory: {output_dir}")
    print(f"  - Frames to process per sample: {target_frames}")

    # Loop through each sampled experiment
    for index, row in combined_samples.iterrows():
        run_id = int(row['run_id'])
        scope_id = int(row['scope_id'])
        label = row['aggregates_formed']

        print(f"\nProcessing sampled experiment: Run {run_id}, Scope {scope_id} (Class: {label})")

        # Loop through each specified frame number for that experiment
        for frame in target_frames:
            process_and_save_frame(run_id, scope_id, frame, label, base_dir, output_dir, processing_params)

    print("\n--- Script execution complete ---")


if __name__ == '__main__':
    main()

--- Starting Image Sampler and Processor Script ---

Successfully sampled 6 total experiments.
Configuration:
  - Base Directory: /home/xavier/Documents/DAE_project/dataset/WT/images
  - Output Directory: /home/xavier/Documents/DAE_project/images/figure6/WT_analysis
  - Frames to process per sample: [1, 1441]

Processing sampled experiment: Run 220, Scope 46 (Class: T)
Successfully processed frame 1 (center). Saved to: /home/xavier/Documents/DAE_project/images/figure6/WT_analysis/run0220_scope46_classT_0001_center.png
Successfully processed frame 1441 (center). Saved to: /home/xavier/Documents/DAE_project/images/figure6/WT_analysis/run0220_scope46_classT_1441_center.png

Processing sampled experiment: Run 373, Scope 40 (Class: T)
Successfully processed frame 1 (center). Saved to: /home/xavier/Documents/DAE_project/images/figure6/WT_analysis/run0373_scope40_classT_0001_center.png
Successfully processed frame 1441 (center). Saved to: /home/xavier/Documents/DAE_project/images/figure6/WT_a