In [4]:
import pandas as pd
import os
import numpy as np
import seaborn as sns 
import plotly.express as px
from sklearn.decomposition import PCA
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.cm as cm
from pathlib import Path


# Use the parent directory of this file for results
# results_dir = os.getcwd()
results_dir = "/net/trapnell/vol1/home/mdcolon/proj/morphseq/results/mcolon/20250917"
data_dir = os.path.join(results_dir, "data")
plot_dir = os.path.join(results_dir, "plots")

print(f"Results directory: {results_dir}")
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok=True)


morphseq_root = os.environ.get('MORPHSEQ_REPO_ROOT')
morphseq_root = "/net/trapnell/vol1/home/mdcolon/proj/morphseq"
print(f"MORPHSEQ_REPO_ROOT: {morphseq_root}")
os.chdir(morphseq_root)

from src.functions.embryo_df_performance_metrics import *
from src.functions.spline_morph_spline_metrics import *





Results directory: /net/trapnell/vol1/home/mdcolon/proj/morphseq/results/mcolon/20250917
MORPHSEQ_REPO_ROOT: /net/trapnell/vol1/home/mdcolon/proj/morphseq


In [5]:
# Import TZ experiments
experiments = ["20250305",'20250415', '20250416', '20250425', '20250501', 
               '20250509', '20250512', '20250515_part2', '20250519', '20250520', 
               '20250626', '20250703', '20250711', '20250725', '20250728']

build06_dir = "/net/trapnell/vol1/home/mdcolon/proj/morphseq/morphseq_playground/metadata/build06_output"

# Load all experiments
dfs = []
for exp in experiments:
    try:
        file_path = f"{build06_dir}/df03_final_output_with_latents_{exp}.csv"
        df = pd.read_csv(file_path)
        df['source_experiment'] = exp
        dfs.append(df)
        print(f"Loaded {exp}: {len(df)} rows")
    except:
        print(f"Missing: {exp}")

# Combine all data
combined_df = pd.concat(dfs, ignore_index=True)
print(f"\nTotal: {len(combined_df)} rows from {len(dfs)} experiments")

# Save combined data
combined_df.to_csv(os.path.join(data_dir, "TZ_experiments_combined_20250917.csv"), index=False)

Missing: 20250305
Missing: 20250415
Loaded 20250416: 3523 rows
Missing: 20250425
Missing: 20250501
Missing: 20250509
Loaded 20250512: 7061 rows
Loaded 20250515_part2: 981 rows
Loaded 20250519: 3095 rows
Loaded 20250520: 1374 rows
Missing: 20250626
Missing: 20250703
Loaded 20250711: 8330 rows
Missing: 20250725
Missing: 20250728

Total: 24364 rows from 6 experiments


In [6]:
# Breakdown by experiment
print("=" * 60)
print("GENOTYPE BREAKDOWN BY EXPERIMENT")
print("=" * 60)

for exp in sorted(combined_df['source_experiment'].unique()):
    exp_df = combined_df[combined_df['source_experiment'] == exp]
    print(f"\nExperiment {exp} (Total: {len(exp_df)} rows):")
    print("-" * 40)
    genotype_counts = exp_df['phenotype'].value_counts()
    for genotype, count in genotype_counts.items():
        percentage = (count / len(exp_df)) * 100
        print(f"  {genotype}: {count} ({percentage:.1f}%)")

# Summary table
print("\n" + "=" * 60)
print("SUMMARY TABLE")
print("=" * 60)
summary_table = combined_df.groupby(['source_experiment', 'phenotype']).size().unstack(fill_value=0)
print(summary_table)

# Add totals
summary_table.loc['TOTAL'] = summary_table.sum()
summary_table['TOTAL'] = summary_table.sum(axis=1)

print("\nWith totals:")
print(summary_table)

GENOTYPE BREAKDOWN BY EXPERIMENT

Experiment 20250416 (Total: 3523 rows):
----------------------------------------

Experiment 20250512 (Total: 7061 rows):
----------------------------------------
  cep290_heterozygous: 2560 (36.3%)
  cep290_wildtype: 2197 (31.1%)
  cep290_homozygous: 1246 (17.6%)
  cep290_unkown: 1058 (15.0%)

Experiment 20250515_part2 (Total: 981 rows):
----------------------------------------
  cep290_heterozygous: 514 (52.4%)
  cep290_homozygous: 217 (22.1%)
  cep290_wildtype: 179 (18.2%)
  cep290_unkown: 71 (7.2%)

Experiment 20250519 (Total: 3095 rows):
----------------------------------------
  b9d2_heterozygous: 1398 (45.2%)
  b9d2_wildtype: 732 (23.7%)
  b9d2_homozygous: 658 (21.3%)
  b9d2_unkown: 307 (9.9%)

Experiment 20250520 (Total: 1374 rows):
----------------------------------------
  b9d2_unkown: 1374 (100.0%)

Experiment 20250711 (Total: 8330 rows):
----------------------------------------
  tmem67_heterozygote: 3183 (38.2%)
  tmem67_homozygous: 2966 (

In [7]:
# Fix genotype naming
print("Current genotypes:")
print(combined_df['genotype'].value_counts())

# Convert genotype names
combined_df['genotype'] = combined_df['genotype'].replace({
    'cep290_het': 'cep290_heterozygous',
    'cep290_homo': 'cep290_homozygous', 
    'wildtype': 'cep290_wildtype',
    "cep290_unkown": "cep290_unknown",
    "b9d2_unkown": "b9d2_unknown",
    "tmem67_heterozygote": "tmem67_heterozygous",
    # "cep290_hetorzygote": "cep290_heterozygous",
    # "cep290_homozygote": "cep290_homozygous"
})

print("\nAfter renaming:")
print(combined_df['genotype'].value_counts())

Current genotypes:
genotype
tmem67_heterozygote    3183
cep290_heterozygous    3074
tmem67_homozygous      2966
cep290_wildtype        2376
tmem67_wildtype        2181
b9d2_unkown            1681
cep290_homozygous      1463
b9d2_heterozygous      1398
cep290_unkown          1129
b9d2_wildtype           732
b9d2_homozygous         658
Name: count, dtype: int64

After renaming:
genotype
tmem67_heterozygous    3183
cep290_heterozygous    3074
tmem67_homozygous      2966
cep290_wildtype        2376
tmem67_wildtype        2181
b9d2_unknown           1681
cep290_homozygous      1463
b9d2_heterozygous      1398
cep290_unknown         1129
b9d2_wildtype           732
b9d2_homozygous         658
Name: count, dtype: int64


In [8]:
# Breakdown by experiment
print("=" * 60)
print("GENOTYPE BREAKDOWN BY EXPERIMENT")
print("=" * 60)

for exp in sorted(combined_df['source_experiment'].unique()):
    exp_df = combined_df[combined_df['source_experiment'] == exp]
    print(f"\nExperiment {exp} (Total: {len(exp_df)} rows):")
    print("-" * 40)
    genotype_counts = exp_df['genotype'].value_counts()
    for genotype, count in genotype_counts.items():
        percentage = (count / len(exp_df)) * 100
        print(f"  {genotype}: {count} ({percentage:.1f}%)")

# Summary table
print("\n" + "=" * 60)
print("SUMMARY TABLE")
print("=" * 60)
summary_table = combined_df.groupby(['source_experiment', 'genotype']).size().unstack(fill_value=0)
print(summary_table)

# Add totals
summary_table.loc['TOTAL'] = summary_table.sum()
summary_table['TOTAL'] = summary_table.sum(axis=1)

print("\nWith totals:")
print(summary_table)

GENOTYPE BREAKDOWN BY EXPERIMENT

Experiment 20250416 (Total: 3523 rows):
----------------------------------------

Experiment 20250512 (Total: 7061 rows):
----------------------------------------
  cep290_heterozygous: 2560 (36.3%)
  cep290_wildtype: 2197 (31.1%)
  cep290_homozygous: 1246 (17.6%)
  cep290_unknown: 1058 (15.0%)

Experiment 20250515_part2 (Total: 981 rows):
----------------------------------------
  cep290_heterozygous: 514 (52.4%)
  cep290_homozygous: 217 (22.1%)
  cep290_wildtype: 179 (18.2%)
  cep290_unknown: 71 (7.2%)

Experiment 20250519 (Total: 3095 rows):
----------------------------------------
  b9d2_heterozygous: 1398 (45.2%)
  b9d2_wildtype: 732 (23.7%)
  b9d2_homozygous: 658 (21.3%)
  b9d2_unknown: 307 (9.9%)

Experiment 20250520 (Total: 1374 rows):
----------------------------------------
  b9d2_unknown: 1374 (100.0%)

Experiment 20250711 (Total: 8330 rows):
----------------------------------------
  tmem67_heterozygous: 3183 (38.2%)
  tmem67_homozygous: 29

In [9]:
# Use only z_mu_b columns and filter out bad rows by experiment
z_mu_cols = [col for col in combined_df.columns if col.startswith('z_mu_b_')]
print(f"Found {len(z_mu_cols)} z_mu_b columns")

# Check NaN values by experiment
print("\nNaN analysis by experiment:")
for exp in combined_df['source_experiment'].unique():
    exp_data = combined_df[combined_df['source_experiment'] == exp]
    rows_with_nan = exp_data[z_mu_cols].isnull().any(axis=1)
    print(f"{exp}: {rows_with_nan.sum()} NaN rows out of {len(exp_data)} total ({rows_with_nan.sum()/len(exp_data)*100:.1f}%)")

# Overall NaN check
rows_with_nan = combined_df[z_mu_cols].isnull().any(axis=1)
print(f"\nOverall: {rows_with_nan.sum()} NaN rows out of {len(combined_df)} total")

# Filter out rows with NaN values
clean_df = combined_df[~rows_with_nan].copy()
print(f"After cleaning: {len(clean_df)} rows remaining")

# Show remaining data by experiment
print("\nRemaining data by experiment:")
for exp in clean_df['source_experiment'].unique():
    exp_count = len(clean_df[clean_df['source_experiment'] == exp])
    print(f"{exp}: {exp_count} rows")

# Update combined_df to be the clean version
combined_df = clean_df

Found 80 z_mu_b columns

NaN analysis by experiment:
20250416: 0 NaN rows out of 3523 total (0.0%)
20250512: 0 NaN rows out of 7061 total (0.0%)
20250515_part2: 0 NaN rows out of 981 total (0.0%)
20250519: 0 NaN rows out of 3095 total (0.0%)
20250520: 0 NaN rows out of 1374 total (0.0%)


20250711: 0 NaN rows out of 8330 total (0.0%)

Overall: 0 NaN rows out of 24364 total
After cleaning: 24364 rows remaining

Remaining data by experiment:
20250416: 3523 rows
20250512: 7061 rows
20250515_part2: 981 rows
20250519: 3095 rows
20250520: 1374 rows
20250711: 8330 rows


In [10]:
# Check genotypes available
print("Available genotypes:")
print(combined_df['genotype'].value_counts())

print("\nAvailable phenotypes:")
if 'phenotype' in combined_df.columns:
    print(combined_df['phenotype'].value_counts())
else:
    print("No phenotype column found")

print("\nAvailable perturbations:")
if 'chem_perturbation' in combined_df.columns:
    print(combined_df['chem_perturbation'].value_counts())

Available genotypes:
genotype
tmem67_heterozygous    3183
cep290_heterozygous    3074
tmem67_homozygous      2966
cep290_wildtype        2376
tmem67_wildtype        2181
b9d2_unknown           1681
cep290_homozygous      1463
b9d2_heterozygous      1398
cep290_unknown         1129
b9d2_wildtype           732
b9d2_homozygous         658
Name: count, dtype: int64

Available phenotypes:
phenotype
tmem67_heterozygote    3183
cep290_heterozygous    3074
tmem67_homozygous      2966
cep290_wildtype        2376
tmem67_wildtype        2181
b9d2_unkown            1681
cep290_homozygous      1463
b9d2_heterozygous      1398
cep290_unkown          1129
b9d2_wildtype           732
b9d2_homozygous         658
Name: count, dtype: int64

Available perturbations:
Series([], Name: count, dtype: int64)


fixing naming errors in plate metadta


In [11]:
def build_splines_and_segments(
    df,
    model_index,
    LocalPrincipalCurveClass,
    save_dir = None,
    comparisons=None,
    group_by_col="genotype",
    z_mu_biological_columns=None,
    n_components=3,
    bandwidth=1.0,
    max_iter=250,
    tol=1e-3,
    angle_penalty_exp=2,
    early_stage_offset=1.0,
    late_stage_offset=3.0,
    k=50
):
    """
    1) Builds splines for each group in `comparisons` using LocalPrincipalCurve
    2) Creates `df_augmented` by assigning segment IDs for each group
    3) Returns `pert_splines`, `df_augmented`, and `segment_info_df`
    
    Parameters
    ----------
    df : pd.DataFrame
        Input DataFrame containing at least [group_by_col, "predicted_stage_hpf"] and either
        ["PCA_1", "PCA_2", "PCA_3"] OR z_mu_biological_columns for PCA computation.
    comparisons : list
        List of group values to process.
    group_by_col : str
        Column name to group by (default: "genotype"). Can be any column including integer columns.
    z_mu_biological_columns : list, optional
        List of column names to use for PCA if PCA columns don't exist. If None and PCA columns
        are missing, will attempt to auto-detect biological feature columns.
    n_components : int, optional
        Number of PCA components to compute (default: 3).
    save_dir : str
        Directory to save the spline CSV if desired.
    model_index : int
        Model index used in naming output files.
    LocalPrincipalCurveClass : class
        Reference to your LocalPrincipalCurve class (or a similar spline-fitting class).
    bandwidth : float
        Bandwidth parameter for LocalPrincipalCurve.
    max_iter : int
        Max iterations for LocalPrincipalCurve fitting.
    tol : float
        Tolerance for LocalPrincipalCurve convergence.
    angle_penalty_exp : int
        Angle penalty exponent for LocalPrincipalCurve.
    early_stage_offset : float
        Window (in hours) for selecting "early" timepoints to compute the average start point.
    late_stage_offset : float
        Window (in hours) for selecting "late" timepoints to compute the average end point.
    k : int
        Number of segments to split each spline into.

    Returns
    -------
    pert_splines : pd.DataFrame
        DataFrame containing the spline points for each group.
    df_augmented : pd.DataFrame
        Original DataFrame plus a `segment_id` column.
    segment_info_df : pd.DataFrame
        Per-segment PCA info (principal_axis, midpoint, etc.).
    """

    # ----------------------------
    # 0. Check for PCA columns and apply PCA if needed
    # ----------------------------
    pca_columns = ["PCA_1", "PCA_2", "PCA_3"]
    has_pca = all(col in df.columns for col in pca_columns)

    if not has_pca:
        print("PCA columns not found. Applying PCA...")

        # If z_mu_biological_columns not provided, use z_mu_b columns
        if z_mu_biological_columns is None:
            # Look for z_mu_b columns first
            z_mu_b_cols = [col for col in df.columns if col.startswith('z_mu_b_')]
            
            if z_mu_b_cols:
                z_mu_biological_columns = z_mu_b_cols
                print(f"Using z_mu_b columns for PCA: {len(z_mu_b_cols)} columns")
            else:
                # Fallback to other biological feature patterns
                potential_cols = [col for col in df.columns if any(pattern in col.lower()
                                for pattern in ['z_mu', 'embedding', 'feature', 'latent', 'biological'])]
                z_mu_biological_columns = potential_cols

        if not z_mu_biological_columns:
            raise ValueError("No suitable columns found for PCA. Please provide z_mu_biological_columns parameter.")

        print(f"Using {len(z_mu_biological_columns)} columns for PCA")

        # Filter out rows with NaN values before PCA
        print(f"Rows before NaN filtering: {len(df)}")
        df_clean = df.dropna(subset=z_mu_biological_columns)
        print(f"Rows after NaN filtering: {len(df_clean)}")

        # Apply PCA using the existing function
        df = apply_pca_on_pert_comparisons(
            df=df_clean,
            z_mu_biological_columns=z_mu_biological_columns,
            pert_comparisons=None,  # Use all data for PCA
            n_components=n_components
        )
        print("PCA applied successfully.")
    else:
        print("PCA columns found. Using existing PCA coordinates.")

    # ----------------------------
    # 1. Generate a color palette
    # and Handle `comparisons` Parameter
    # ----------------------------
    if comparisons is None:
        comparisons = list(df[group_by_col].unique())
        print(f"No comparisons specified. Using all available {group_by_col} values: {comparisons}")
    else:
        print(f"Using specified comparisons for spline building: {comparisons}")
        

    color_palette = px.colors.qualitative.Plotly
    if len(comparisons) > len(color_palette):
        extended_palette = color_palette * (len(comparisons) // len(color_palette) + 1)
    else:
        extended_palette = color_palette

    # ----------------------------
    # 2. Build Spline Data
    # ----------------------------
    print(f"Building spline data for each {group_by_col}...")
    splines_records = []

    df = df[df[group_by_col].isin(comparisons)]
    
    for idx, pert in enumerate(tqdm(comparisons, desc=f"Creating splines for each {group_by_col}")):
        # Filter the DataFrame for the given group
        pert_df = df[df[group_by_col] == pert].copy()
        if pert_df.empty:
            # If no data points for this phenotype, skip
            continue

        # Extract PCA coordinates
        pert_3d = pert_df[["PCA_1", "PCA_2", "PCA_3"]].values
        
        # Compute average early stage point
        min_time = pert_df["predicted_stage_hpf"].min()
        early_mask = (pert_df["predicted_stage_hpf"] >= min_time) & \
                     (pert_df["predicted_stage_hpf"] < min_time + early_stage_offset)
        avg_early_timepoint = pert_df.loc[early_mask, ["PCA_1", "PCA_2", "PCA_3"]].mean().values
        
        # Compute average late stage point
        max_time = pert_df["predicted_stage_hpf"].max()
        late_mask = (pert_df["predicted_stage_hpf"] >= (max_time - late_stage_offset))
        avg_late_timepoint = pert_df.loc[late_mask, ["PCA_1", "PCA_2", "PCA_3"]].mean().values

        # Downsample for curve fitting (example: 5% for wt, 10% for others)
        if len(pert_3d) == 0:
            continue
        
        if pert == "wt":
            subset_size = max(1, int(0.05 * len(pert_3d)))
        else:
            subset_size = max(1, int(0.10 * len(pert_3d)))

        # Randomly select a subset of points for fitting
        rng = np.random.RandomState(42)
        subset_indices = rng.choice(len(pert_3d), size=subset_size, replace=False)
        pert_3d_subset = pert_3d[subset_indices, :]

        # Fit LocalPrincipalCurve
        lpc = LocalPrincipalCurveClass(
            bandwidth=bandwidth,
            max_iter=max_iter,
            tol=tol,
            angle_penalty_exp=angle_penalty_exp
        )
        
        # Fit with the optional start_points/end_point to anchor the spline
        lpc.fit(
            pert_3d_subset,
            start_points=avg_early_timepoint,
            end_point=avg_late_timepoint,
            remove_similar_end_start_points=True
        )
        
        spline_points = None
        if len(lpc.cubic_splines) > 0:
            # If your local principal curve class stores the final spline
            spline_points = lpc.cubic_splines[0]
        else:
            # If no spline was built, skip
            continue
        
        # Create a temporary DataFrame for the current spline
        spline_df = pd.DataFrame(spline_points, columns=["PCA_1", "PCA_2", "PCA_3"])
        spline_df[group_by_col] = pert
        
        # Collect for later concatenation
        splines_records.append(spline_df)

    # Concatenate all spline DataFrames
    if splines_records:
        pert_splines = pd.concat(splines_records, ignore_index=True)
    else:
        # Fallback to an empty DataFrame if no splines
        pert_splines = pd.DataFrame(columns=["PCA_1", "PCA_2", "PCA_3", group_by_col])

    # Optionally, save the spline data
    if save_dir:
        spline_csv_path = os.path.join(save_dir, f"pert_splines_{model_index}_unique.csv")
        pert_splines.to_csv(spline_csv_path, index=False)
        print(f"Spline DataFrame 'pert_splines' saved to: {spline_csv_path}")

    # ----------------------------
    # 3. Create segments for each group using the function above
    # ----------------------------
    print("Assigning segments and building segment_info_df...")
    df_augmented, segment_info_df, pert_splines_out = create_spline_segments_for_df(
        df=df,
        pert_splines=pert_splines,
        k=k,
        group_by_col=group_by_col
    )

    # Return all three final structures
    return pert_splines_out, df_augmented, segment_info_df

In [12]:
# Build splines using the proper functions from the imported modules
model_index = 74  # Using same model index as general_exam_plots

# Build splines and segments using the imported function
pert_splines, df_augmented, segment_info_df = build_splines_and_segments(
    df=combined_df,
    model_index=model_index,
    LocalPrincipalCurveClass=LocalPrincipalCurve,
    bandwidth=0.5,
    max_iter=250,
    tol=1e-3,
    angle_penalty_exp=2,
    early_stage_offset=1.0,
    late_stage_offset=3.0,
    k=50,
    save_dir=data_dir,
    group_by_col="genotype",
)

print(f"Built splines:")
print(f"pert_splines shape: {pert_splines.shape}")
print(f"df_augmented shape: {df_augmented.shape}")
print(f"segment_info_df shape: {segment_info_df.shape}")

# Save the splines data
pert_splines.to_csv(os.path.join(data_dir, "TZ_pert_splines.csv"), index=False)
df_augmented.to_csv(os.path.join(data_dir, "TZ_df_augmented.csv"), index=False)

PCA columns not found. Applying PCA...
Using z_mu_b columns for PCA: 80 columns
Using 80 columns for PCA
Rows before NaN filtering: 24364
Rows after NaN filtering: 24364
No perturbations specified. Using all available phenotypes: [nan, 'cep290_wildtype', 'cep290_heterozygous', 'cep290_unkown', 'cep290_homozygous', 'b9d2_homozygous', 'b9d2_heterozygous', 'b9d2_unkown', 'b9d2_wildtype', 'tmem67_heterozygote', 'tmem67_homozygous', 'tmem67_wildtype']

PCA Explained Variance:
----------------------------
Total Explained Variance by 3 components: 90.80%

Principal Component Explained Variance Ratio
              PCA_1                   53.08%
              PCA_2                   28.19%
              PCA_3                    9.52%
----------------------------

PCA applied successfully.
No comparisons specified. Using all available genotype values: [nan, 'cep290_wildtype', 'cep290_heterozygous', 'cep290_unknown', 'cep290_homozygous', 'b9d2_homozygous', 'b9d2_heterozygous', 'b9d2_unknown', 'b9

Creating splines for each genotype:   0%|                                                                                                                                                   | 0/12 [00:00<?, ?it/s]

Starting point not in dataset. Using closest point: [ 1.4993594   1.97512895 -0.71960684]


Creating splines for each genotype:  17%|███████████████████████▏                                                                                                                   | 2/12 [00:00<00:04,  2.36it/s]

Starting point not in dataset. Using closest point: [ 1.47409282  2.2556861  -0.47994114]


Creating splines for each genotype:  33%|██████████████████████████████████████████████▎                                                                                            | 4/12 [00:01<00:03,  2.14it/s]

Starting point not in dataset. Using closest point: [ 1.45095666  2.03813221 -0.73111303]
Starting point not in dataset. Using closest point: [ 1.33225687  2.2381893  -0.668478  ]


Creating splines for each genotype:  50%|█████████████████████████████████████████████████████████████████████▌                                                                     | 6/12 [00:02<00:02,  2.91it/s]

Starting point not in dataset. Using closest point: [ 1.44256462  1.92026924 -0.66068651]
Starting point not in dataset. Using closest point: [ 1.46172947  2.03082142 -0.51567736]


Creating splines for each genotype:  58%|█████████████████████████████████████████████████████████████████████████████████                                                          | 7/12 [00:02<00:01,  2.97it/s]

Starting point not in dataset. Using closest point: [ 1.46103654  1.68378629 -0.57256132]


Creating splines for each genotype:  75%|████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                  | 9/12 [00:03<00:00,  3.84it/s]

Starting point not in dataset. Using closest point: [ 1.49800284  1.86209737 -0.4584709 ]
Starting point not in dataset. Using closest point: [ 1.84819709 -1.41199512  0.39698485]


Creating splines for each genotype:  83%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████                       | 10/12 [00:03<00:00,  2.47it/s]

Starting point not in dataset. Using closest point: [ 1.99937174 -0.76604345  0.37670129]


Creating splines for each genotype:  92%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌           | 11/12 [00:04<00:00,  2.40it/s]

Starting point not in dataset. Using closest point: [ 1.68181097  1.85008086 -0.53805412]


Creating splines for each genotype: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:04<00:00,  2.42it/s]


Spline DataFrame 'pert_splines' saved to: /net/trapnell/vol1/home/mdcolon/proj/morphseq/results/mcolon/20250917/data/pert_splines_74_unique.csv
Assigning segments and building segment_info_df...


Processing genotype: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:16<00:00,  1.41s/it]


Built splines:
pert_splines shape: (5500, 4)
df_augmented shape: (24364, 276)
segment_info_df shape: (550, 16)


In [13]:
# Copy the plot_pca_with_splines function from general_exam_plots
def plot_pca_with_splines(df_points, df_splines, 
                         point_opacity=0.7, 
                         save_dir=None, 
                         downsample_dict=None, 
                         color_dict=None,
                         title="PCA Plot with Splines",
                         min_snip_count=20,
                         phenotypes_of_interest=None,
                         show_splines=True,
                         spline_width=6,
                         group_by_col='genotype',
                         filename=None,
                         color_by=None,  # New parameter for coloring by continuous variable
                         colorscale='viridis'):  # Colorscale for continuous variables
    """Plot PCA points and their splines with customizable options."""
    import plotly.graph_objects as go
    import plotly.express as px
    import os
    
    # Filter by phenotypes of interest first
    if phenotypes_of_interest is not None:
        df_points = df_points[df_points[group_by_col].isin(phenotypes_of_interest)].copy()
        if df_splines is not None:
            df_splines = df_splines[df_splines[group_by_col].isin(phenotypes_of_interest)].copy()
    
    # Filter embryos based on snip count
    embryo_counts = df_points.groupby('embryo_id')['snip_id'].count()
    valid_embryos = embryo_counts[embryo_counts >= min_snip_count].index.tolist()
    
    # Filter out embryos with too few snips
    removed_embryos = set(df_points['embryo_id'].unique()) - set(valid_embryos)
    if removed_embryos:
        print(f"Removed {len(removed_embryos)} embryos with fewer than {min_snip_count} snips.")
    
    # Filter dataframe to include only valid embryos
    filtered_df = df_points[df_points['embryo_id'].isin(valid_embryos)].copy()
    
    # Check if any data remains after filtering
    if filtered_df.empty:
        print("Warning: No data remaining after filtering.")
        fig = go.Figure()
        fig.update_layout(title="No Data Available")
        return fig
    
    # Get unique phenotypes (after filtering)
    phenotypes = filtered_df[group_by_col].unique()
    
    # Create figure
    fig = go.Figure()
    
    # If color_by is specified, color by continuous variable
    if color_by is not None and color_by in filtered_df.columns:
        print(f"Coloring points by {color_by}")
        
        # FIXED: Calculate global color range across ALL data going into plot
        global_min = filtered_df[color_by].min()
        global_max = filtered_df[color_by].max()
        print(f"Global {color_by} range: {global_min:.2f} to {global_max:.2f}")
        
        # Process each phenotype separately to maintain grouping
        for phenotype in phenotypes:
            # Filter points for this phenotype
            points = filtered_df[filtered_df[group_by_col] == phenotype].copy()
            
            # Apply downsampling if specified
            if downsample_dict is not None and phenotype in downsample_dict:
                fraction = downsample_dict[phenotype]
                if 0 < fraction < 1:
                    points = points.sample(frac=fraction, random_state=42)
            
            # FIXED: Create consistent hover text with all information
            hover_text = []
            for _, row in points.iterrows():
                hover_info = (
                    f"<b>{group_by_col.title()}: {row[group_by_col]}</b><br>"
                    f"Embryo ID: {row['embryo_id']}<br>"
                    f"Snip ID: {row['snip_id']}<br>"
                    f"Predicted Stage (hpf): {row['predicted_stage_hpf']:.2f}<br>"
                    f"PCA_1: {row['PCA_1']:.3f}<br>"
                    f"PCA_2: {row['PCA_2']:.3f}<br>"
                    f"PCA_3: {row['PCA_3']:.3f}<br>"
                    f"{color_by}: {row[color_by]:.2f}"
                )
                hover_text.append(hover_info)
            
            # Add points to plot with GLOBAL color scale
            fig.add_trace(
                go.Scatter3d(
                    x=points['PCA_1'],
                    y=points['PCA_2'],
                    z=points['PCA_3'],
                    mode='markers',
                    marker=dict(
                        size=4,
                        color=points[color_by],
                        colorscale=colorscale,
                        opacity=point_opacity,
                        showscale=True,
                        colorbar=dict(title=color_by),
                        cmin=global_min,  # FIXED: Set global color range
                        cmax=global_max   # FIXED: Set global color range
                    ),
                    name=f"{phenotype}",
                    showlegend=True,
                    hovertemplate='%{text}<extra></extra>',
                    text=hover_text
                )
            )
            
            # Add spline for this phenotype
            if show_splines and df_splines is not None:
                spline = df_splines[df_splines[group_by_col] == phenotype]
                if not spline.empty:
                    # Use a solid color for splines when coloring by continuous variable
                    spline_color = px.colors.qualitative.Plotly[list(phenotypes).index(phenotype) % len(px.colors.qualitative.Plotly)]
                    fig.add_trace(
                        go.Scatter3d(
                            x=spline['PCA_1'],
                            y=spline['PCA_2'],
                            z=spline['PCA_3'],
                            mode='lines',
                            line=dict(
                                color=spline_color,
                                width=spline_width
                            ),
                            name=f"{phenotype} (spline)",
                            showlegend=True
                        )
                    )
    
    else:
        # Original behavior: color by group_by_col
        # Create color dictionary if not provided
        if color_dict is None:
            default_colors = [
                '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
            ]
            color_dict = {phenotype: default_colors[i % len(default_colors)] for i, phenotype in enumerate(phenotypes)}
        
        # Process each phenotype
        for phenotype in phenotypes:
            # Filter points for this phenotype
            points = filtered_df[filtered_df[group_by_col] == phenotype].copy()
            
            # Apply downsampling if specified
            if downsample_dict is not None and phenotype in downsample_dict:
                fraction = downsample_dict[phenotype]
                if 0 < fraction < 1:
                    points = points.sample(frac=fraction, random_state=42)
            
            # Get color for this phenotype
            color = color_dict.get(phenotype, '#1f77b4')
            
            # FIXED: Create consistent hover text with all information
            hover_text = []
            for _, row in points.iterrows():
                hover_info = (
                    f"<b>{group_by_col.title()}: {row[group_by_col]}</b><br>"
                    f"Embryo ID: {row['embryo_id']}<br>"
                    f"Snip ID: {row['snip_id']}<br>"
                    f"Predicted Stage (hpf): {row['predicted_stage_hpf']:.2f}<br>"
                    f"PCA_1: {row['PCA_1']:.3f}<br>"
                    f"PCA_2: {row['PCA_2']:.3f}<br>"
                    f"PCA_3: {row['PCA_3']:.3f}"
                )
                hover_text.append(hover_info)
            
            # Add points to plot
            fig.add_trace(
                go.Scatter3d(
                    x=points['PCA_1'],
                    y=points['PCA_2'],
                    z=points['PCA_3'],
                    mode='markers',
                    marker=dict(
                        size=4,
                        color=color,
                        opacity=point_opacity
                    ),
                    name=f"{phenotype} (points)",
                    showlegend=True,
                    hovertemplate='%{text}<extra></extra>',
                    text=hover_text
                )
            )
            
            # Add spline for this phenotype
            if show_splines and df_splines is not None:
                spline = df_splines[df_splines[group_by_col] == phenotype]
                if not spline.empty:
                    fig.add_trace(
                        go.Scatter3d(
                            x=spline['PCA_1'],
                            y=spline['PCA_2'],
                            z=spline['PCA_3'],
                            mode='lines',
                            line=dict(
                                color=color,
                                width=spline_width
                            ),
                            name=f"{phenotype} (spline)",
                            showlegend=True
                        )
                    )
    
    # Update layout
    fig.update_layout(
        scene=dict(
            xaxis_title='PCA 1',
            yaxis_title='PCA 2',
            zaxis_title='PCA 3'
        ),
        title=title,
        margin=dict(l=0, r=0, b=0, t=40),
        legend=dict(
            x=0.01,
            y=0.99,
            bordercolor="Black",
            borderwidth=1
        )
    )
    
    # Save figure if save_dir is provided
    if save_dir is not None:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        
        # Use custom filename if provided, otherwise default to title-based name
        if filename is not None:
            save_filename = filename if filename.endswith('.html') else f"{filename}.html"
        else:
            save_filename = f"{title.replace(' ', '_')}.html"
            
        save_path = os.path.join(save_dir, save_filename)
        fig.write_html(save_path)
        print(f"Plot saved to: {save_path}")
    
    return fig

In [14]:
# combined_df 

In [15]:
# Now use the plot function with the splines data
# Create color dictionary for all genotypes with auto-scaling palette
import plotly.express as px

genotypes = combined_df['genotype'].unique().tolist()
print(f"Found {len(genotypes)} genotypes: {genotypes}")

# Generate enough colors for all genotypes
if len(genotypes) <= len(px.colors.qualitative.Plotly):
    colors = px.colors.qualitative.Plotly[:len(genotypes)]
else:
    # If we need more colors, cycle through multiple palettes
    all_colors = (px.colors.qualitative.Plotly + 
                  px.colors.qualitative.Set1 + 
                  px.colors.qualitative.Set2)
    colors = all_colors[:len(genotypes)]

color_dict = {genotype: colors[i] for i, genotype in enumerate(genotypes)}
print(f"Generated color palette for {len(color_dict)} genotypes")

# Create the plot with splines colored by genotype (original)
fig1 = plot_pca_with_splines(
    df_points=df_augmented,
    df_splines=pert_splines,
    point_opacity=0.65,
    color_dict=color_dict,
    save_dir=plot_dir,
    title="TZ Experiments with Splines by Genotype",
    phenotypes_of_interest=genotypes,  # Use all genotypes instead of top 6
    show_splines=True,
    spline_width=15,
    min_snip_count=10,  # Lower threshold for TZ data
    group_by_col='genotype',  # Specify we're using genotype column
    filename='TZ_genotype_splines_plot'  # Custom filename
)

# Create a second plot colored by predicted_stage_hpf
fig2 = plot_pca_with_splines(
    df_points=df_augmented,
    df_splines=pert_splines,
    point_opacity=0.65,
    save_dir=plot_dir,
    title="TZ Experiments Colored by Developmental Stage",
    phenotypes_of_interest=genotypes,  # Use all genotypes
    show_splines=True,
    spline_width=15,
    min_snip_count=10,
    group_by_col='genotype',
    color_by='predicted_stage_hpf',  # Color by developmental stage
    colorscale='plasma',  # Different colorscale
    filename='TZ_stage_colored_plot'
)

# fig1.show()
# fig2.show()
print(f"Created plots with all {len(genotypes)} genotypes and saved to {plot_dir}")

Found 12 genotypes: [nan, 'cep290_wildtype', 'cep290_heterozygous', 'cep290_unknown', 'cep290_homozygous', 'b9d2_homozygous', 'b9d2_heterozygous', 'b9d2_unknown', 'b9d2_wildtype', 'tmem67_heterozygous', 'tmem67_homozygous', 'tmem67_wildtype']
Generated color palette for 12 genotypes
Removed 69 embryos with fewer than 10 snips.


Plot saved to: /net/trapnell/vol1/home/mdcolon/proj/morphseq/results/mcolon/20250917/plots/TZ_genotype_splines_plot.html
Removed 69 embryos with fewer than 10 snips.
Coloring points by predicted_stage_hpf
Global predicted_stage_hpf range: 7.00 to 127.94
Plot saved to: /net/trapnell/vol1/home/mdcolon/proj/morphseq/results/mcolon/20250917/plots/TZ_stage_colored_plot.html
Created plots with all 12 genotypes and saved to /net/trapnell/vol1/home/mdcolon/proj/morphseq/results/mcolon/20250917/plots


Unnamed: 0,image_id,embryo_id,snip_id,frame_index,area_px,bbox_x_min,bbox_y_min,bbox_x_max,bbox_y_max,mask_confidence,...,z_sigma_b_91,z_sigma_b_92,z_sigma_b_93,z_sigma_b_94,z_sigma_b_95,z_sigma_b_96,z_sigma_b_97,z_sigma_b_98,z_sigma_b_99,source_experiment
0,20250416_A01_ch00_t0049,20250416_A01_e01,20250416_A01_e01_t0049,49,134875.0,0.254340,0.406122,0.971354,0.696208,0.85,...,-0.009283,-0.037781,-0.032273,-0.076436,-0.012015,0.003293,0.004194,0.071856,-0.002426,20250416
1,20250416_A01_ch00_t0045,20250416_A01_e01,20250416_A01_e01_t0045,45,130893.0,0.255208,0.407035,0.959201,0.695295,0.85,...,0.010598,-0.043417,-0.036820,-0.070282,-0.019894,0.006405,0.009552,0.076997,-0.007899,20250416
2,20250416_A01_ch00_t0051,20250416_A01_e01,20250416_A01_e01_t0051,51,135931.0,0.255208,0.406122,0.979167,0.693924,0.85,...,-0.009493,-0.023123,-0.019317,-0.073106,-0.003603,-0.002479,0.009091,0.066040,-0.000359,20250416
3,20250416_A01_ch00_t0044,20250416_A01_e01,20250416_A01_e01_t0044,44,130062.0,0.255208,0.407035,0.954861,0.695295,0.85,...,0.012714,-0.042843,-0.027540,-0.062661,-0.024849,0.007280,-0.013583,0.069797,-0.008688,20250416
4,20250416_A01_ch00_t0052,20250416_A01_e01,20250416_A01_e01_t0052,52,136846.0,0.255208,0.406578,0.980903,0.694381,0.85,...,-0.012598,-0.027236,-0.019540,-0.063449,-0.012674,-0.005051,0.007555,0.062848,0.008178,20250416
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24359,20250711_H12_ch00_t0023,20250711_H12_e01,20250711_H12_e01_t0023,23,122271.0,0.149306,0.394244,0.676215,0.742348,0.85,...,0.026204,0.010527,-0.013890,-0.032507,-0.008373,-0.030821,0.049402,0.014236,0.010132,20250711
24360,20250711_H12_ch00_t0019,20250711_H12_e01,20250711_H12_e01_t0019,19,115293.0,0.458333,0.349018,0.763889,0.745546,0.85,...,0.031393,0.009409,-0.008797,-0.014967,-0.000692,-0.017334,0.032821,-0.000179,0.004642,20250711
24361,20250711_H12_ch00_t0043,20250711_H12_e01,20250711_H12_e01_t0043,43,132734.0,0.436632,0.206030,0.798611,0.647784,0.85,...,0.015627,0.000085,0.005657,-0.047938,-0.017249,-0.029161,0.007121,0.092808,0.001477,20250711
24362,20250711_H12_ch00_t0037,20250711_H12_e01,20250711_H12_e01_t0037,37,129037.0,0.159722,0.199635,0.596354,0.606213,0.85,...,-0.014544,0.016132,0.000956,-0.035742,0.013989,0.002497,-0.037557,0.024579,0.025255,20250711


GENOTYPE BREAKDOWN BY EXPERIMENT

Experiment 20250305 (Total: 2999 rows):
----------------------------------------
  wildtype: 1240 (41.3%)
  cep290_het: 800 (26.7%)
  cep290_homo: 635 (21.2%)
  unknown: 324 (10.8%)

Experiment 20250416 (Total: 3523 rows):
----------------------------------------

Experiment 20250512 (Total: 7061 rows):
----------------------------------------
  cep290_heterozygous: 2560 (36.3%)
  cep290_wildtype: 2197 (31.1%)
  cep290_homozygous: 1246 (17.6%)
  cep290_unkown: 1058 (15.0%)

Experiment 20250515_part2 (Total: 981 rows):
----------------------------------------
  cep290_heterozygous: 514 (52.4%)
  cep290_homozygous: 217 (22.1%)
  cep290_wildtype: 179 (18.2%)
  cep290_unkown: 71 (7.2%)

Experiment 20250519 (Total: 3095 rows):
----------------------------------------
  b9d2_heterozygous: 1398 (45.2%)
  b9d2_wildtype: 732 (23.7%)
  b9d2_homozygous: 658 (21.3%)
  b9d2_unkown: 307 (9.9%)

Experiment 20250520 (Total: 1374 rows):
--------------------------------