### This notebook fits a reference spline to HF and AB reference data

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
from glob2 import glob
from pathlib import Path
from tqdm import tqdm
from src.functions.plot_functions import format_2d_plotly, format_3d_plotly

In [None]:
# Load the latent embeddings and embryo metadata for the current best model.
# root = "/media/nick/hdd02/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/"

root = Path("/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/morphseq/")
model_name = "20241107_ds_sweep01_optimum"
model_class = "legacy"
# later_than = 20250501
experiments = [
    "20240812", "20250215",
    "20250703_chem3_28C_T00_1325", "20250703_chem3_34C_T00_1131", "20250703_chem3_34C_T01_1457",
    "20250703_chem3_35C_T00_1101", "20250703_chem3_35C_T01_1437",
    '20250622_chem_28C_T00_1425', '20250622_chem_28C_T01_1658', '20250622_chem_34C_T00_1256',
    '20250622_chem_34C_T01_1632', '20250622_chem_35C_T00_1223_check', '20250622_chem_35C_T01_1605',
    '20250623_chem_28C_T02_1259', '20250623_chem_34C_T02_1231', '20250623_chem_35C_T02_1204',
    '20250624_chem02_28C_T00_1356', '20250624_chem02_28C_T01_1808', '20250624_chem02_34C_T00_1243',
    '20250624_chem02_34C_T01_1739', '20250624_chem02_35C_T00_1216', '20250624_chem02_35C_T01_1711',
    '20250625_chem02_28C_T02_1332', '20250625_chem02_34C_T02_1301', '20250625_chem02_35C_T02_1228',
]


def _load_experiment_csvs(folder, filename_template, exp_list):
    """Read one CSV per experiment from *folder* and stack them into a single DataFrame.

    *filename_template* is formatted with ``exp=<experiment name>``. The stacked frame
    gets a fresh RangeIndex so row labels stay unique across experiments (each file's
    own 0..n-1 index would otherwise repeat and make later label-based ``.loc`` ambiguous).
    """
    frames = [pd.read_csv(folder / filename_template.format(exp=exp)) for exp in tqdm(exp_list)]
    return pd.concat(frames, ignore_index=True)


# load latent embeddings
latent_path = root / "analysis" / "latent_embeddings" / model_class / model_name
latent_df = _load_experiment_csvs(latent_path, "morph_latents_{exp}.csv", experiments)

# load metadata
meta_path = root / "metadata" / "embryo_metadata_files"
meta_df = _load_experiment_csvs(meta_path, "{exp}_embryo_metadata.csv", experiments)

# remove problematic embryo IDs (printing the table shape before and after)
print(meta_df.shape)
rm_ids = ["20250624_chem02_35C_T00_1216_C02_e01", "20250624_chem02_35C_T01_1711_C02_e01", "20250625_chem02_35C_T02_1228_C02_e01"]
meta_df = meta_df.loc[~meta_df["embryo_id"].isin(rm_ids)]
print(meta_df.shape)

# path to save data (string path with a trailing separator; the directory is created if missing)
out_path = os.path.join(root, "results", "20240707", "")
os.makedirs(out_path, exist_ok=True)

# path to figures and data
# fig_path = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/slides/morphseq/20250312/morph_metrics/"
# os.makedirs(fig_path, exist_ok=True)
# meta_df.head()

#### Merge embeddings with metadata

In [None]:
# Attach selected metadata columns to every latent-space snip (inner join on snip_id),
# then print the shapes before/after the join.
keep_cols = [
    'snip_id', 'well', 'nd2_series_num', 'microscope', 'time_int', 'genotype',
    'chem_perturbation', 'start_age_hpf', 'temperature',
    'well_qc_flag', 'Time Rel (s)', 'use_embryo_flag', 'frame_flag', 'dead_flag',
]

master_df = pd.merge(meta_df.loc[:, keep_cols], latent_df, how="inner", on=["snip_id"])
# experiment_date comes from the latent tables; force it to text for the name parsing below
master_df = master_df.astype({"experiment_date": str})

for frame in (latent_df, master_df):
    print(frame.shape)

#### Make some helpful flags

In [None]:
import re

# Short experiment labels: the two reference experiments get fixed names; every chem
# experiment is labelled by the second "_" token of its folder name with zero padding
# removed ("chem02" -> "chem2"; "chem" and "chem3" are unchanged).
_REF_EXP_LABELS = {"20240812": "ref0", "20250215": "ref1"}


def _exp_label(exp):
    """Return the short experiment id (e.g. "ref0", "chem2") for experiment name *exp*."""
    if exp in _REF_EXP_LABELS:
        return _REF_EXP_LABELS[exp]
    # Drop only the zeros that sit between the letters and the number, so that e.g.
    # "chem10" stays "chem10" instead of collapsing to "chem1".
    return re.sub(r"(?<=[A-Za-z])0+(?=\d)", "", exp.split("_")[1])


def _exp_tokens(exp):
    """Split a chem experiment name on "_", ignoring a trailing "_check" tag."""
    return exp.replace("_check", "").split("_")


master_df["exp_id"] = [_exp_label(exp) for exp in master_df["experiment_date"]]

# split datasets
is_ref = master_df["exp_id"].isin(["ref0", "ref1"])
ref_embryo_df = master_df.loc[is_ref].copy()
chem_embryo_df = master_df.loc[~is_ref].copy()

# Time-point info parsed from names like "20250622_chem_28C_T00_1425":
# token 3 is the time point ("T00"); the last token is stored as time_stamp.
exp_tokens = [_exp_tokens(exp) for exp in chem_embryo_df["experiment_date"]]
# Assign whole columns (not .loc[:, col]) so the columns take the new integer dtype even if
# the metadata column was loaded as float; the "T0001"-style snip IDs below rely on ints.
chem_embryo_df["time_int"] = [int(tok[3].replace("T", "")) for tok in exp_tokens]
chem_embryo_df["time_stamp"] = [int(tok[-1]) for tok in exp_tokens]
chem_embryo_df["datetime_stamp"] = [int(tok[0] + tok[-1]) for tok in exp_tokens]

# create new ids: embryo_id = "<exp_id>_<well>_<temp>C", snip_id = "<embryo_id>_T<4-digit time_int>"
temp_label = chem_embryo_df["temperature"].astype(int).astype(str) + "C"
eid_vec = chem_embryo_df["exp_id"].str.cat([chem_embryo_df["well"], temp_label], sep="_")
snip_vec = eid_vec + "_T" + chem_embryo_df["time_int"].astype(str).str.zfill(4)
chem_embryo_df["snip_id"] = snip_vec
chem_embryo_df["embryo_id"] = eid_vec

In [None]:
# print(np.mean(chem_embryo_df["frame_flag"]))
# chem_embryo_df.loc[chem_embryo_df.embryo_id=="20250624_C05_34C"]

### QC

In [None]:
# Quick fix for chem2: relabel its "_13" inhibitor entries to the "_6" names in fix_dict.
# chem_embryo_df.loc[chem_embryo_df["exp_id"]=="chem2", "chem_perturbation"].unique()
fix_dict = {'tgfb_i_13': 'tgfb_i_6', 'wnt_i_13': 'wnt_i_6', 'fgf_i_13': 'fgf_i_6', 'bmp_i_13': 'bmp_i_6'}

m = chem_embryo_df["exp_id"] == "chem2"
relabelled = chem_embryo_df.loc[m, "chem_perturbation"].replace(fix_dict)
chem_embryo_df.loc[m, "chem_perturbation"] = relabelled

# Problem observations: listed embryos and individual snips are removed.
qc_emb_list = ["chem_C10_28C", "chem2_A01_35C", "chem2_E01_35C", "chem2_A11_35C", "chem2_B11_35C",
               "chem2_B11_28C", "chem2_A11_34C", "chem_A05_28C", "chem_A02_35C", "chem2_C05_28C"]
qc_snip_list = ["chem3_F06_35C_T0001"]

# Drop flagged snips, then flagged embryos, then dead / frame-flagged rows,
# printing the table shape before and after every step.
print(chem_embryo_df.shape)
for id_col, bad_ids in (("snip_id", qc_snip_list), ("embryo_id", qc_emb_list)):
    chem_embryo_df = chem_embryo_df[~chem_embryo_df[id_col].isin(bad_ids)]
    print(chem_embryo_df.shape)

qc_mask = ~chem_embryo_df["dead_flag"] & ~chem_embryo_df["frame_flag"]
chem_embryo_df = chem_embryo_df[qc_mask]
print(chem_embryo_df.shape)

### Fit PCA to just the ref and hotfish data

In [None]:
from sklearn.decomposition import PCA
import re

# params
n_components = 10
z_pattern = "z_mu_b"  # latent columns whose name contains this pattern are used
mu_cols = [col for col in ref_embryo_df.columns if re.search(z_pattern, col)]
pca_cols = [f"PCA_{p:02}_bio" for p in range(n_components)]

# Fit on all chem snips plus a random subsample of reference snips of the same size
# (presumably so the larger reference set does not dominate the principal axes).
np.random.seed(345)
n_ref_fit = min(ref_embryo_df.shape[0], chem_embryo_df.shape[0])  # cannot draw more than exist without replacement
ref_indices = np.random.choice(ref_embryo_df.shape[0], n_ref_fit, replace=False)
morph_pca = PCA(n_components=n_components)
# ref_indices are row POSITIONS, so select with .iloc; .loc would treat them as index
# labels, which only coincide with positions if the reference rows happen to come first.
fit_df = pd.concat([chem_embryo_df[mu_cols], ref_embryo_df[mu_cols].iloc[ref_indices]])
morph_pca.fit(fit_df)

# transform every snip into the shared PCA space
ref_pca_array = morph_pca.transform(ref_embryo_df[mu_cols])
chem_pca_array = morph_pca.transform(chem_embryo_df[mu_cols])

# carry identifying columns along (time_int is renamed to timepoint)
to_cols = ["snip_id", "embryo_id", "exp_id", "temperature", "timepoint", "chem_perturbation"]
from_cols = ["snip_id", "embryo_id", "exp_id", "temperature", "time_int", "chem_perturbation"]
ref_pca_df = pd.DataFrame(ref_pca_array, columns=pca_cols)
ref_pca_df[to_cols] = ref_embryo_df[from_cols].to_numpy()

chem_pca_df = pd.DataFrame(chem_pca_array, columns=pca_cols)
chem_pca_df[to_cols] = chem_embryo_df[from_cols].to_numpy()

In [None]:
# ref_indices

### Do the same for UMAP

In [None]:
# import umap.umap_ as umap
# from sklearn.preprocessing import StandardScaler

# n_u = 3
# umap_cols = [f"UMAP_{p:02}_bio" for p in range(n_u)]

# reducer_bio = umap.UMAP(n_components=n_u)
# z_mu_array_b = morph_pca.transform(pd.concat([chem_embryo_df[mu_cols]]))#, chem02_df[mu_cols], chem03_df[mu_cols]])#, ref_df[mu_cols]]).to_numpy()
# # z_mu_array_b = morph_pca.transform(pd.concat([chem01_df[mu_cols], chem02_df[mu_cols], chem03_df[mu_cols], ref_df[mu_cols]]).to_numpy())
# # scaled_z_mu_bio = StandardScaler().fit_transform(z_mu_array_b)
# reducer_bio.fit(z_mu_array_b)

# # transform
# ref_umap_array = reducer_bio.transform(morph_pca.transform(ref_embryo_df[mu_cols]))
# chem_umap_array = reducer_bio.transform(morph_pca.transform(chem_embryo_df[mu_cols]))

# ref_umap_df = pd.DataFrame(ref_umap_array, columns=umap_cols)
# ref_umap_df[to_cols] = ref_embryo_df[from_cols].to_numpy()


# chem_umap_df = pd.DataFrame(chem_umap_array, columns=umap_cols)
# chem_umap_df[to_cols] = chem_embryo_df[from_cols].to_numpy()

In [None]:
# Cumulative fraction of latent variance captured by the leading PCs (x is the 0-based PC index)
var_cumulative = morph_pca.explained_variance_ratio_.cumsum()
pc_index = np.arange(n_components)
fig = px.line(x=pc_index, y=var_cumulative, markers=True)

global_font = dict(
    family="Arial, sans-serif",
    size=18,  # global font size
    color="black",
)
fig.update_layout(
    title="PCA decomposition of morphVAE latent space",
    xaxis=dict(title="PC number"),
    yaxis=dict(title="total variance explained"),
    font=global_font,
)

fig = format_2d_plotly(fig, marker_size=12)
fig.show()

# fig.write_image(os.path.join(fig_path, "morph_pca_var_explained.png"))

### Visualize embryo morphologies

In [None]:
# Reference embryos only, coloured by experiment, in the first three PCs
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols
fig = px.scatter_3d(ref_pca_df, x=x_col, y=y_col, z=z_col, color="exp_id", opacity=0.1)
fig.show()

In [None]:
# Chem snips with timepoint < 2, coloured by perturbation, over a faint cloud of all
# reference snips in the first three PCs.
plot_cols = pca_cols[:3]
plot_filter = (chem_pca_df["timepoint"] < 2)  # & chem_pca_df["chem_perturbation"].isin(["mTOR_i_6"])

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=plot_cols[0], y=plot_cols[1], z=plot_cols[2],
    color="chem_perturbation",
    hover_data=["snip_id"],  # list (not set) so the hover field order is deterministic
)

# The reference cloud is background context only; keep it out of the legend like the other panels.
fig.add_trace(go.Scatter3d(x=ref_pca_df[plot_cols[0]], y=ref_pca_df[plot_cols[1]], z=ref_pca_df[plot_cols[2]],
                           mode="markers", marker=dict(color="rgba(0,0,0,0.01)"), showlegend=False))
fig.show()

### Plot just the controls

In [None]:
# Chem 1: DMSO controls at a single time point (one marker symbol per temperature),
# drawn over the reference embryos coloured by their time_int.
plot_time = 0
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols

at_time = chem_pca_df["timepoint"] == plot_time
is_dmso = chem_pca_df["chem_perturbation"].isin(["DMSO_6"])
in_exp = chem_pca_df["exp_id"].isin(["chem"])
plot_filter = at_time & is_dmso & in_exp

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=x_col,
    y=y_col,
    z=z_col,
    color="chem_perturbation",
    symbol="temperature",
    hover_data={"snip_id", "temperature"},
)
fig = format_3d_plotly(fig, marker_size=10, theme="light")

ref_cloud = go.Scatter3d(
    x=ref_pca_df[x_col],
    y=ref_pca_df[y_col],
    z=ref_pca_df[z_col],
    mode="markers",
    marker=dict(color=ref_embryo_df["time_int"], opacity=0.01),
    showlegend=False,
)
fig.add_trace(ref_cloud)
fig.show()

In [None]:
# Chem 2: DMSO controls at a single time point (one marker symbol per temperature),
# drawn over the reference embryos coloured by their time_int.
plot_time = 0
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols

at_time = chem_pca_df["timepoint"] == plot_time
is_dmso = chem_pca_df["chem_perturbation"].isin(["DMSO_6"])
in_exp = chem_pca_df["exp_id"].isin(["chem2"])
plot_filter = at_time & is_dmso & in_exp

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=x_col,
    y=y_col,
    z=z_col,
    color="chem_perturbation",
    symbol="temperature",
    hover_data={"snip_id", "temperature"},
)
fig = format_3d_plotly(fig, marker_size=10, theme="light")

ref_cloud = go.Scatter3d(
    x=ref_pca_df[x_col],
    y=ref_pca_df[y_col],
    z=ref_pca_df[z_col],
    mode="markers",
    marker=dict(color=ref_embryo_df["time_int"], opacity=0.01),
    showlegend=False,
)
fig.add_trace(ref_cloud)
fig.show()

In [None]:
# Chem 3: DMSO controls at a single time point (one marker symbol per temperature),
# drawn over the reference embryos coloured by their time_int.
plot_time = 0
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols

at_time = chem_pca_df["timepoint"] == plot_time
is_dmso = chem_pca_df["chem_perturbation"].isin(["DMSO_6"])
in_exp = chem_pca_df["exp_id"].isin(["chem3"])
plot_filter = at_time & is_dmso & in_exp

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=x_col,
    y=y_col,
    z=z_col,
    color="chem_perturbation",
    symbol="temperature",
    hover_data={"snip_id", "temperature"},
)
fig = format_3d_plotly(fig, marker_size=10, theme="light")

ref_cloud = go.Scatter3d(
    x=ref_pca_df[x_col],
    y=ref_pca_df[y_col],
    z=ref_pca_df[z_col],
    mode="markers",
    marker=dict(color=ref_embryo_df["time_int"], opacity=0.01),
    showlegend=False,
)
fig.add_trace(ref_cloud)
fig.show()

### Now let's actually try to put this to work
We should be able to use temp-only and chem-only morphologies to construct a "null" expectation for where a combined perturbation should live.

In [None]:
from itertools import product
# Per-embryo perturbation vectors: within each experiment and temperature, every treated
# embryo is differenced against every DMSO control embryo at that same temperature.
# Limited to the first time point for now--if we can avoid bringing flux into it we should.

n_pc = 10  # number of PCs to use (leading columns of pca_cols)
times_to_use = np.asarray([0])
ctrl_temp = 28  # baseline temperature (not used in this cell; read by the bootstrap comparison)
ctrl_pert = "DMSO_6"
delta_cols = pca_cols[:n_pc]

analysis_df = chem_pca_df.loc[chem_pca_df["timepoint"].isin(times_to_use)].reset_index(drop=True)
analysis_df["temperature"] = analysis_df["temperature"].astype(int)

# get vectors
exp_id_vec = chem_embryo_df["exp_id"].unique()

df_list = []
for exp in tqdm(exp_id_vec):
    mask = analysis_df["exp_id"].eq(exp)

    # treatments (excluding the control) and temperatures present in this experiment
    pert_id_vec = analysis_df.loc[mask, "chem_perturbation"].unique()
    pert_id_vec = pert_id_vec[pert_id_vec != ctrl_pert]
    temp_id_vec = analysis_df.loc[mask, "temperature"].unique()

    for temp, chem in product(temp_id_vec, pert_id_vec):
        temp_mask = mask & analysis_df["temperature"].eq(temp)
        ref_array = analysis_df.loc[temp_mask & analysis_df["chem_perturbation"].eq(ctrl_pert), delta_cols].to_numpy()
        pert_array = analysis_df.loc[temp_mask & analysis_df["chem_perturbation"].eq(chem), delta_cols].to_numpy()
        if ref_array.shape[0] == 0 or pert_array.shape[0] == 0:
            # no control or no treated embryos at this temperature -> no pairs to form
            continue

        # all pairwise differences: (n_pert, n_ref, n_pc), flattened to (n_pert * n_ref, n_pc) rows
        delta = pert_array[:, None, :] - ref_array[None, :, :]
        temp_df = pd.DataFrame(delta.reshape(-1, delta.shape[-1]), columns=delta_cols)

        temp_df["exp_id"] = exp
        temp_df["chem"] = chem
        temp_df["temp"] = temp
        df_list.append(temp_df.loc[:, ["exp_id", "chem", "temp"] + delta_cols])

if df_list:
    pert_df = pd.concat(df_list, axis=0, ignore_index=True)
else:
    pert_df = pd.DataFrame(columns=["exp_id", "chem", "temp"] + delta_cols)

Now we want to look at the **differences** in perturbation vectors, with respect to both magnitude and direction.

In [None]:
np.random.seed(345)
n_bootstrap = 1000
n_samp = 8  # max number of delta vectors drawn from each set per bootstrap replicate
# PC columns actually stored in pert_df by the previous cell
delta_cols = [col for col in pca_cols if col in pert_df.columns]


def _band_se(samples):
    """Half-width of the central 68% (16th-84th percentile) band of *samples*."""
    lo, hi = np.percentile(samples, [16, 84])
    return (hi - lo) / 2


def _bootstrap_pert_comparison(ref_array, target_array, n_bootstrap, n_samp):
    """Bootstrap magnitude and direction differences between two sets of perturbation vectors.

    Each replicate draws ``min(n_samp, len(ref_array), len(target_array))`` rows with
    replacement from EACH array, independently and over all of its rows, and records:
      * mean |target| - mean |ref|                 (magnitude change)
      * mean |target|                              (magnitude)
      * mean cosine similarity of the paired draws (direction agreement; 1 = same direction)
    Returns the median and 16th-84th percentile half-width of each statistic, keyed by the
    output column names. Uses the global NumPy RNG.
    """
    n_samples = min(n_samp, ref_array.shape[0], target_array.shape[0])

    ref_norms = np.linalg.norm(ref_array, axis=1)
    target_norms = np.linalg.norm(target_array, axis=1)
    # unit vectors once, so each cosine similarity is a plain dot product
    ref_unit = ref_array / ref_norms[:, np.newaxis]
    target_unit = target_array / target_norms[:, np.newaxis]

    # (n_bootstrap, n_samples) index draws spanning ALL rows of each array
    # (drawing from range(n_samples) would only ever reuse the first few rows).
    ref_idx = np.random.choice(ref_array.shape[0], size=(n_bootstrap, n_samples), replace=True)
    target_idx = np.random.choice(target_array.shape[0], size=(n_bootstrap, n_samples), replace=True)

    target_mean_norm = target_norms[target_idx].mean(axis=1)
    diff_means = target_mean_norm - ref_norms[ref_idx].mean(axis=1)
    cos_sims = np.einsum("bij,bij->bi", ref_unit[ref_idx], target_unit[target_idx]).mean(axis=1)

    # NOTE: the "pert_cosine_dist" columns hold cosine SIMILARITY, not 1 - similarity.
    return {
        "pert_size_delta": np.median(diff_means),
        "pert_size_delta_se": _band_se(diff_means),
        "pert_size": np.median(target_mean_norm),
        "pert_size_se": _band_se(target_mean_norm),
        "pert_cosine_dist": np.median(cos_sims),
        "pert_cosine_dist_se": _band_se(cos_sims),
    }


dd_df_list = []
for exp in tqdm(exp_id_vec):
    mask0 = pert_df["exp_id"].eq(exp)

    for chem in pert_df.loc[mask0, "chem"].unique():
        mask1 = mask0 & pert_df["chem"].eq(chem)

        # the chem delta at the control temperature is the reference for every temperature
        ref_array = pert_df.loc[mask1 & pert_df["temp"].eq(ctrl_temp), delta_cols].to_numpy()
        if ref_array.shape[0] == 0:
            continue  # no control-temperature delta for this chem -> nothing to compare against

        for temp in pert_df.loc[mask1, "temp"].unique():
            target_array = pert_df.loc[mask1 & pert_df["temp"].eq(temp), delta_cols].to_numpy()
            stats = _bootstrap_pert_comparison(ref_array, target_array, n_bootstrap, n_samp)
            dd_df_list.append({"exp_id": exp, "chem": chem, "temp": temp, **stats})

pert_comp_df = pd.DataFrame(dd_df_list)

In [None]:
# Change in perturbation size vs. direction similarity relative to the control temperature,
# one point per (experiment, perturbation, non-control temperature).
plot_filter = pert_comp_df["temp"].ne(ctrl_temp)

fig = px.scatter(pert_comp_df.loc[plot_filter], x="pert_size_delta", y="pert_cosine_dist",
                 error_x="pert_size_delta_se", error_y="pert_cosine_dist_se", color="chem", symbol="exp_id",
                 hover_data=["exp_id", "temp"])

# thin error bars without end caps
fig.update_traces(
    error_x=dict(thickness=1, width=0),
    error_y=dict(thickness=1, width=0),
)

# Half-widths of the plotted x / y ranges; the dashed x = 0 and y = 0 guide lines span them.
xm = 3
ym = 1  # cosine similarity lies in [-1, 1]

fig.add_shape(
    type="line",
    x0=0, x1=0,
    y0=-ym, y1=ym,
    line=dict(color="white", width=3, dash="dash"),
    xref='x', yref='y'
)
fig.add_shape(
    type="line",
    x0=-xm, x1=xm,
    y0=0, y1=0,
    line=dict(color="white", width=3, dash="dash"),
    xref='x', yref='y'
)

fig = format_2d_plotly(fig, axis_labels=["temperature-dependent effect increase", "similarity"], marker_size=10)

fig.update_layout(
    xaxis=dict(range=[-xm, xm], zeroline=False),
    yaxis=dict(range=[-ym, ym], zeroline=False),
    width=800,
    height=600
)

fig.show()

### Create wide form of table

In [None]:
# One row per (experiment, perturbation); each statistic becomes one column per temperature.
id_cols = ["exp_id", "chem"]

# every statistic and its bootstrap SE ('temp' becomes the column level)
value_cols = [
    f"{stat}{suffix}"
    for stat in ("pert_size_delta", "pert_size", "pert_cosine_dist")
    for suffix in ("", "_se")
]

pert_df_wide = pert_comp_df.pivot_table(index=id_cols, columns="temp", values=value_cols)

# flatten the (statistic, temp) column MultiIndex: ("pert_size", 28) -> "pert_size_28"
pert_df_wide.columns = [f"{stat}_{temp}" for stat, temp in pert_df_wide.columns.to_flat_index()]
pert_df_wide = pert_df_wide.reset_index()
pert_df_wide.head()

### Size effects

In [None]:
# Perturbation size at 28C vs 35C (median bootstrap |delta|); the dashed line is y = x.
# NOTE(review): the exp_id values built earlier are "chem", "chem2" and "chem3", so this
# filter currently keeps every row -- confirm which experiment was meant to be excluded.
plot_filter = pert_df_wide["exp_id"].ne("chem10")

fig = px.scatter(pert_df_wide.loc[plot_filter], x="pert_size_28", y="pert_size_35",
                 error_x="pert_size_se_28", error_y="pert_size_se_35", color="chem", symbol="exp_id",
                 hover_data=["exp_id"])

# larger markers, thin error bars without end caps
fig.update_traces(
    marker_size=10,
    error_x=dict(thickness=1, width=0),
    error_y=dict(thickness=1, width=0),
)
m = 5  # upper limit of both axes
ref_line = np.linspace(0, m)
fig.add_trace(go.Scatter(x=ref_line, y=ref_line, mode="lines", line=dict(dash="dash", color="white"), showlegend=False))

fig = format_2d_plotly(fig, axis_labels=["effect size (28C)", "effect size (35C)"], marker_size=10)

fig.update_layout(
    xaxis=dict(range=[0, m], zeroline=False),
    yaxis=dict(range=[0, m], zeroline=False),
    width=800,
    height=600
)
fig.show()

### Look at variability

In [None]:
group_vars = ["timepoint", "exp_id", "temperature", "chem_perturbation"]

# per-group mean and variance of every PC
chem_pca_summary = (
    chem_pca_df[group_vars + pca_cols]
      .groupby(group_vars)
      .agg(['mean', 'var'])
)

# Flatten the MultiIndex columns:  ('PCA_00_bio', 'mean') -> 'PCA_00_bio_mean'
chem_pca_summary.columns = [
    f"{col}_{stat}" for col, stat in chem_pca_summary.columns
]

chem_pca_summary = chem_pca_summary.reset_index()

var_cols = [col for col in chem_pca_summary.columns if "var" in col]

# Total variance = sum of per-PC variances (trace of the group covariance in PCA space).
# min_count=1 keeps single-embryo groups (sample variance undefined -> NaN) as NaN rather
# than letting a NaN-skipping sum report them as a misleading 0 variability.
chem_pca_summary["total_var"] = chem_pca_summary[var_cols].sum(axis=1, min_count=1)
chem_pca_summary["total_std"] = np.sqrt(chem_pca_summary["total_var"].to_numpy())

# generate smaller table to just look at 1D summary stats
drop_cols = [col for col in chem_pca_summary if "PCA" in col]
chem_narrow = chem_pca_summary.drop(labels=drop_cols, axis=1)
chem_narrow["temperature"] = chem_narrow["temperature"].astype(int)
# Set index columns
id_cols = ["exp_id", "chem_perturbation", "timepoint"]

# value columns to spread across temperatures
value_cols = [
    "total_var", "total_std"
]

# Use pivot to widen: one column per (statistic, temperature)
chem_wide = chem_narrow.pivot_table(
    index=id_cols,
    columns="temperature",
    values=value_cols
)

# flatten column MultiIndex: ("total_std", 28) -> "total_std_28"
chem_wide.columns = [f"{col}_{temp}" for col, temp in chem_wide.columns]
chem_wide = chem_wide.reset_index()
chem_wide.head()

In [None]:
# Phenotypic spread (total_std = sqrt of summed per-PC variance) at 28C vs 35C per
# (experiment, perturbation, time point), excluding time point 2; dashed line is y = x.
plot_filter = chem_wide["timepoint"].ne(2)  # chem_wide["exp_id"].ne("chem2") &

fig = px.scatter(chem_wide.loc[plot_filter], x="total_std_28", y="total_std_35", color="chem_perturbation", symbol="exp_id",
                 hover_data=["exp_id", "timepoint"])

# these points carry no error bars, so only the marker size needs setting
fig.update_traces(marker_size=10)
m = 2.5  # upper limit of both axes
ref_line = np.linspace(0, m)
fig.add_trace(go.Scatter(x=ref_line, y=ref_line, mode="lines", line=dict(dash="dash", color="white"), showlegend=False))

fig = format_2d_plotly(fig, axis_labels=["phenotype variability (28C)", "phenotype variability (35C)"], marker_size=10)

fig.update_layout(
    xaxis=dict(range=[0, m], zeroline=False),
    yaxis=dict(range=[0, m], zeroline=False),
    width=800,
    height=600
)
fig.show()

In [None]:
# Phenotypic spread (total_std = sqrt of summed per-PC variance) at 28C vs 34C per
# (experiment, perturbation, time point), excluding time point 2; dashed line is y = x.
plot_filter = chem_wide["timepoint"].ne(2)  # chem_wide["exp_id"].ne("chem2") &

fig = px.scatter(chem_wide.loc[plot_filter], x="total_std_28", y="total_std_34", color="chem_perturbation", symbol="exp_id",
                 hover_data=["exp_id", "timepoint"])

# these points carry no error bars, so only the marker size needs setting
fig.update_traces(marker_size=10)
m = 2.5  # upper limit of both axes
ref_line = np.linspace(0, m)
fig.add_trace(go.Scatter(x=ref_line, y=ref_line, mode="lines", line=dict(dash="dash", color="white"), showlegend=False))

fig = format_2d_plotly(fig, axis_labels=["phenotype variability (28C)", "phenotype variability (34C)"], marker_size=10)

fig.update_layout(
    xaxis=dict(range=[0, m], zeroline=False),
    yaxis=dict(range=[0, m], zeroline=False),
    width=800,
    height=600
)
fig.show()

### Let's kick the tires a little

#### Fgf-i

In [None]:
# Fgf inhibitor vs DMSO in the "chem" experiment at time point 0, over the reference
# embryos coloured by their time_int.
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols

early = chem_pca_df["timepoint"] < 1
wanted_pert = chem_pca_df["chem_perturbation"].isin(["fgf_i_13", "DMSO_6"])
wanted_exp = chem_pca_df["exp_id"].isin(["chem"])
plot_filter = early & wanted_pert & wanted_exp

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=x_col, y=y_col, z=z_col,
    color="chem_perturbation",
    symbol="temperature",
    hover_data={"snip_id", "temperature"},
)
fig = format_3d_plotly(fig, marker_size=10, theme="light")

ref_cloud = go.Scatter3d(
    x=ref_pca_df[x_col], y=ref_pca_df[y_col], z=ref_pca_df[z_col],
    mode="markers",
    marker=dict(color=ref_embryo_df["time_int"], opacity=0.01),
    showlegend=False,
)
fig.add_trace(ref_cloud)
fig.show()

#### BMP-i

In [None]:
# BMP inhibitor ("bmp_i_6") in chem2 at time point 0, coloured by temperature, over the
# reference embryos coloured by their time_int.
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols

early = chem_pca_df["timepoint"] < 1
wanted_pert = chem_pca_df["chem_perturbation"].isin(["bmp_i_6"])
wanted_exp = chem_pca_df["exp_id"].isin(["chem2"])
plot_filter = early & wanted_pert & wanted_exp

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=x_col, y=y_col, z=z_col,
    color="temperature",
    symbol="temperature",
    hover_data={"snip_id", "temperature"},
)
fig = format_3d_plotly(fig, marker_size=10, theme="light")

ref_cloud = go.Scatter3d(
    x=ref_pca_df[x_col], y=ref_pca_df[y_col], z=ref_pca_df[z_col],
    mode="markers",
    marker=dict(color=ref_embryo_df["time_int"], opacity=0.01),
    showlegend=False,
)
fig.add_trace(ref_cloud)
fig.show()

#### Shh-i

In [None]:
# Shh inhibitor at time points 0-1 in every experiment except chem3, coloured by temperature,
# marker symbol = experiment, over a faint grey reference cloud.
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols

early = chem_pca_df["timepoint"] < 2
wanted_pert = chem_pca_df["chem_perturbation"].isin(["shh_i_6"])
not_chem3 = ~chem_pca_df["exp_id"].isin(["chem3"])
plot_filter = early & wanted_pert & not_chem3

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=x_col, y=y_col, z=z_col,
    color="temperature",
    symbol="exp_id",
    hover_data={"snip_id", "temperature", "timepoint"},
)

ref_cloud = go.Scatter3d(
    x=ref_pca_df[x_col], y=ref_pca_df[y_col], z=ref_pca_df[z_col],
    mode="markers",
    marker=dict(color="rgba(0,0,0,0.01)"),
    showlegend=False,
)
fig.add_trace(ref_cloud)
fig.show()

#### NF-KB

In [None]:
# NF-kB inhibitor in chem2 at time point 0, coloured by temperature, over a faint grey
# reference cloud.
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols

early = chem_pca_df["timepoint"] < 1
wanted_pert = chem_pca_df["chem_perturbation"].isin(["nfkb_i_6"])
wanted_exp = chem_pca_df["exp_id"].isin(["chem2"])
plot_filter = early & wanted_pert & wanted_exp

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=x_col, y=y_col, z=z_col,
    color="temperature",
    symbol="temperature",
    hover_data={"snip_id", "temperature"},
)

ref_cloud = go.Scatter3d(
    x=ref_pca_df[x_col], y=ref_pca_df[y_col], z=ref_pca_df[z_col],
    mode="markers",
    marker=dict(color="rgba(0,0,0,0.01)"),
    showlegend=False,
)
fig.add_trace(ref_cloud)
fig.show()

### HSP90

In [None]:
# HSP90 inhibitor at time point 0 in every experiment except chem3, coloured by
# temperature, over a faint grey reference cloud.
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols

early = chem_pca_df["timepoint"] < 1
wanted_pert = chem_pca_df["chem_perturbation"].isin(["hsp90_i_6"])
not_chem3 = ~chem_pca_df["exp_id"].isin(["chem3"])
plot_filter = early & wanted_pert & not_chem3

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=x_col, y=y_col, z=z_col,
    color="temperature",
    symbol="temperature",
    hover_data={"snip_id", "temperature"},
)

ref_cloud = go.Scatter3d(
    x=ref_pca_df[x_col], y=ref_pca_df[y_col], z=ref_pca_df[z_col],
    mode="markers",
    marker=dict(color="rgba(0,0,0,0.01)"),
    showlegend=False,
)
fig.add_trace(ref_cloud)
fig.show()

### RA

In [None]:
# RA ("ra_lo_i_6") vs DMSO in chem3 at time point 0, coloured by perturbation,
# marker symbol = temperature, over a faint grey reference cloud.
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols

early = chem_pca_df["timepoint"] < 1
wanted_pert = chem_pca_df["chem_perturbation"].isin(["ra_lo_i_6", "DMSO_6"])
wanted_exp = chem_pca_df["exp_id"].isin(["chem3"])
plot_filter = early & wanted_pert & wanted_exp

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=x_col, y=y_col, z=z_col,
    color="chem_perturbation",
    symbol="temperature",
    hover_data={"snip_id", "temperature"},
)

ref_cloud = go.Scatter3d(
    x=ref_pca_df[x_col], y=ref_pca_df[y_col], z=ref_pca_df[z_col],
    mode="markers",
    marker=dict(color="rgba(0,0,0,0.01)"),
    showlegend=False,
)
fig.add_trace(ref_cloud)
fig.show()

#### Wnt

In [None]:
# Wnt inhibitor ("wnt_i_13") at time points 0-1 outside chem2, coloured by temperature,
# marker symbol = time point, over a faint grey reference cloud.
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols

early = chem_pca_df["timepoint"] < 2
wanted_pert = chem_pca_df["chem_perturbation"].isin(["wnt_i_13"])
not_chem2 = ~chem_pca_df["exp_id"].isin(["chem2"])
plot_filter = early & wanted_pert & not_chem2

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=x_col, y=y_col, z=z_col,
    color="temperature",
    symbol="timepoint",
    hover_data={"snip_id", "temperature"},
)

ref_cloud = go.Scatter3d(
    x=ref_pca_df[x_col], y=ref_pca_df[y_col], z=ref_pca_df[z_col],
    mode="markers",
    marker=dict(color="rgba(0,0,0,0.01)"),
    showlegend=False,
)
fig.add_trace(ref_cloud)
fig.show()

### BMP

In [None]:
# BMP inhibitor ("bmp_i_13") in the "chem" experiment at time point 0, coloured by
# temperature, over a faint grey reference cloud.
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols

early = chem_pca_df["timepoint"] < 1
wanted_pert = chem_pca_df["chem_perturbation"].isin(["bmp_i_13"])
wanted_exp = chem_pca_df["exp_id"].isin(["chem"])
plot_filter = early & wanted_pert & wanted_exp

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=x_col, y=y_col, z=z_col,
    color="temperature",
    symbol="temperature",
    hover_data={"snip_id", "temperature"},
)

ref_cloud = go.Scatter3d(
    x=ref_pca_df[x_col], y=ref_pca_df[y_col], z=ref_pca_df[z_col],
    mode="markers",
    marker=dict(color="rgba(0,0,0,0.01)"),
    showlegend=False,
)
fig.add_trace(ref_cloud)
fig.show()

### PI3K

In [None]:
# PI3K inhibitor ("pi3k_lo_i_6") and DMSO in chem3 at time point 0, coloured by
# temperature, over a faint grey reference cloud.
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols

early = chem_pca_df["timepoint"] < 1
wanted_pert = chem_pca_df["chem_perturbation"].isin(["DMSO_6", "pi3k_lo_i_6"])
wanted_exp = chem_pca_df["exp_id"].isin(["chem3"])
plot_filter = early & wanted_pert & wanted_exp

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=x_col, y=y_col, z=z_col,
    color="temperature",
    symbol="temperature",
    hover_data={"snip_id", "chem_perturbation", "temperature"},
)

ref_cloud = go.Scatter3d(
    x=ref_pca_df[x_col], y=ref_pca_df[y_col], z=ref_pca_df[z_col],
    mode="markers",
    marker=dict(color="rgba(0,0,0,0.01)"),
    showlegend=False,
)
fig.add_trace(ref_cloud)
fig.show()

### mTOR

In [None]:
# mTOR inhibitor and DMSO at time points 0-1 outside chem3, coloured by temperature,
# marker symbol = perturbation, over the reference embryos coloured by their time_int.
plot_cols = pca_cols[:3]
x_col, y_col, z_col = plot_cols

early = chem_pca_df["timepoint"] < 2
wanted_pert = chem_pca_df["chem_perturbation"].isin(["DMSO_6", "mTOR_i_6"])
not_chem3 = ~chem_pca_df["exp_id"].isin(["chem3"])
plot_filter = early & wanted_pert & not_chem3  # optionally: & chem_pca_df["temperature"].eq(34)

fig = px.scatter_3d(
    chem_pca_df.loc[plot_filter],
    x=x_col, y=y_col, z=z_col,
    color="temperature",
    symbol="chem_perturbation",
    hover_data={"snip_id", "chem_perturbation", "temperature"},
)

# (a plain grey reference cloud, marker=dict(color="rgba(0,0,0,0.01)"), is the alternative)
ref_cloud = go.Scatter3d(
    x=ref_pca_df[x_col], y=ref_pca_df[y_col], z=ref_pca_df[z_col],
    mode="markers",
    marker=dict(color=ref_embryo_df["time_int"], opacity=0.01),
    showlegend=False,
)
fig.add_trace(ref_cloud)
fig.show()

# fraction of chem snips whose use_embryo_flag is set
np.mean(chem_embryo_df["use_embryo_flag"])

In [None]:
# NOTE(review): `mask` was last assigned inside the perturbation-delta loop
# (analysis_df["exp_id"].eq(exp)), so it is a boolean Series indexed like analysis_df,
# not like meta_df. pandas will likely refuse to align it (or select unrelated rows);
# confirm which mask was intended here.
meta_df.loc[mask].shape

In [None]:
# embryo_ids of metadata rows that are focus-flagged but neither frame-flagged nor dead
focus_only_rows = meta_df["focus_flag"] & ~meta_df["frame_flag"] & ~meta_df["dead_flag"]
check_list = meta_df.loc[focus_only_rows, "embryo_id"].tolist()
check_list