In [1]:
%load_ext autoreload
%autoreload 2
import os

# GPU selection, applied before `import torch` below: enumerate devices in
# PCI bus order (see issue #152) and restrict visibility to one device ID
# (the "MIG-" prefix suggests a MIG slice — confirm it matches your machine).
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # see issue #152
        "CUDA_VISIBLE_DEVICES": "MIG-864c07c4-8eeb-5b23-8d57-eaeb942a9a0f",
    }
)
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from hydra.utils import instantiate
from PIL import Image
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader, Dataset

# Device string passed to `save_embeddings` further down.
# NOTE(review): with CUDA_VISIBLE_DEVICES set above, "cuda:0" should be the
# single visible device — confirm on the target machine.
device = "cuda:0"

In [2]:
# Switch the working directory so the relative paths used in this notebook
# (e.g. `save_path`) resolve against this directory.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
os.chdir("/allen/aics/modeling/ritvik/projects/benchmarking_representations/")
# Output location handed to `save_embeddings`, relative to the directory above.
save_path = "./test_npm1_save_embeddings/"

In [3]:
from cyto_dl.models.utils.mlflow import load_model_from_checkpoint

from br.data.get_datamodules import get_data
from br.models.load_models import load_model_from_path
from br.models.save_embeddings import get_pc_loss


def get_data_and_models(dataset_name, batch_size, results_path, debug=False):
    """Load the data and the models for one dataset.

    Calls `get_data` first and `load_model_from_path` second, forwarding the
    arguments unchanged. `load_model_from_path` is called with only the
    dataset name and results path, so it uses its default model list
    (defined in load_models.py).

    Args:
        dataset_name: Dataset identifier passed to both loaders.
        batch_size: Passed to `get_data`.
        results_path: Passed to both loaders.
        debug: Passed to `get_data`; defaults to False.

    Returns:
        Tuple `(data_list, all_models, run_names, model_sizes)`.
    """
    datamodules = get_data(dataset_name, batch_size, results_path, debug)
    models, names, sizes = load_model_from_path(dataset_name, results_path)
    return datamodules, models, names, sizes

In [4]:
# Run configuration for the NPM1 dataset.
dataset_name = "npm1"
batch_size = 2
debug = True
results_path = "/allen/aics/modeling/ritvik/projects/benchmarking_representations/br/configs/results/"

# Load the data and the default set of models in one call.
data_list, all_models, run_names, model_sizes = get_data_and_models(
    dataset_name=dataset_name,
    batch_size=batch_size,
    results_path=results_path,
    debug=debug,
)

# Compute embeddings and emissions

In [5]:
# Display the run names returned by `get_data_and_models`.
run_names

['SO3_pointcloud_SDF',
 'SO3_image_SDF',
 'SO3_image_seg',
 'Classical_image_SDF',
 'Classical_image_seg']

In [8]:
from br.models.save_embeddings import save_embeddings

# Splits to process, in this order.
splits_list = ["train", "val", "test"]
meta_key = None
skew_scale = None

# One flag per entry of `run_names` (five runs are loaded above).
eval_scaled_img = [True] * 5
sample_points_list = [False] * 5
loss_eval_list = [torch.nn.MSELoss(reduction="none")]

# Ground-truth inputs referenced by the per-run evaluation settings below.
gt_mesh_dir = "/allen/aics/assay-dev/users/Alex/replearn/rep_paper/data/var_blobby_noalign/meshes"
gt_sampled_pts_dir = "/allen/aics/assay-dev/users/Alex/replearn/rep_paper/data/sampled_pcs/npm1_var_noalign_global/1_res/0"
gt_scale_factor_dict_path = "/allen/aics/assay-dev/users/Alex/replearn/rep_paper/data/npm1_var_scale_factor_32res_noalign_global.npz"


def _scaled_img_params(model_type, scale_factor_path=None, sampled_pts_dir=None):
    """Build the evaluation settings dict for one run (resolution 32, STL meshes)."""
    return {
        "eval_scaled_img_model_type": model_type,
        "eval_scaled_img_resolution": 32,
        "gt_mesh_dir": gt_mesh_dir,
        "gt_scale_factor_dict_path": scale_factor_path,
        "gt_sampled_pts_dir": sampled_pts_dir,
        "mesh_ext": "stl",
    }


# Position i configures run_names[i]: the first run uses the "iae" settings
# with sampled points and no scale factors; the other four alternate
# "sdf" / "seg" with the scale-factor file and no sampled points.
eval_scaled_img_params = [_scaled_img_params("iae", sampled_pts_dir=gt_sampled_pts_dir)] + [
    _scaled_img_params(model_type, scale_factor_path=gt_scale_factor_dict_path)
    for model_type in ("sdf", "seg", "sdf", "seg")
]

save_embeddings(
    save_path, data_list, all_models, run_names, debug, splits_list, device,
    meta_key, loss_eval_list, sample_points_list, skew_scale,
    eval_scaled_img, eval_scaled_img_params,
)

Processing train


  0%|                                                                                                             | 4/4267 [00:11<3:17:21,  2.78s/it]


Processing val


  0%|                                                                                                                        | 0/754 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# dataset_name = 'npm1_test'
# Restrict to a single run; this overwrites the `run_names` list loaded earlier.
run_names = ["equiv_vnn"]
# NOTE(review): this imports from `src.*` while earlier cells import from `br.*` — confirm the package name.
from src.models.compute_features import get_embeddings

# `all_ret` is merged on "CellId" in later cells; `df` is reassigned in the next cell.
all_ret, df = get_embeddings(run_names, dataset_name)

In [None]:
import pandas as pd

# Manifest CSV. It supplies extra per-cell columns for `all_ret` and the
# "mesh_path_noalign" column that later cells look up through `mesh_df`.
manifest_path = "/allen/aics/assay-dev/users/Alex/replearn/rep_paper/data/var_npm1_manifest.csv"

# Read the manifest once; it was previously read from disk twice.
df = pd.read_csv(manifest_path)

# Bring over only the manifest columns `all_ret` does not already have,
# plus the "CellId" join key.
cols_to_use = df.columns.difference(all_ret.columns).tolist() + ["CellId"]
all_ret = all_ret.merge(df[cols_to_use], on="CellId")

# Independent copy of the manifest used for mesh-path lookups.
mesh_df = df.copy()

In [None]:
# Load the feature table from npm1_fullres_features.csv; later cells read its
# "CellId" and "connectivity_cc" columns.
updated_feat_df = pd.read_csv(
    "/allen/aics/assay-dev/users/Alex/replearn/rep_paper/processing_data/npm1_fullres_features.csv"
)

In [None]:
# One label per row of `updated_feat_df`: connectivity values of 5 or more are
# collapsed into ">=5"; every other value is kept as its string form.
all_vals = [
    ">=5" if row["connectivity_cc"] >= 5.0 else str(row["connectivity_cc"])
    for _, row in updated_feat_df.iterrows()
]

In [None]:
# Store the thresholded connectivity labels (`all_vals`) as a new column.
updated_feat_df["new_connectivity_thresh"] = all_vals

In [None]:
# Display the feature table, now including "new_connectivity_thresh".
updated_feat_df

In [None]:
# Attach the thresholded connectivity label to each cell in `all_ret` (default inner join on "CellId").
all_ret = all_ret.merge(updated_feat_df[["CellId", "new_connectivity_thresh"]], on="CellId")

In [None]:
# Number of cells per connectivity label.
all_ret["new_connectivity_thresh"].value_counts()

In [None]:
from src.features.classification import get_classification_df

# Classification results with "new_connectivity_thresh" as the target column.
# NOTE(review): which columns of `all_ret` serve as inputs is decided inside
# `get_classification_df` and is not visible here.
connect_class = get_classification_df(all_ret, "new_connectivity_thresh")

In [None]:
# Regression target columns. `regress_cols` is read by the next two cells but
# was only present here as a commented-out list, so a fresh kernel hit a
# NameError. It is now defined to match the column that is dropped below and
# re-merged from `updated_feat_df`.
# NOTE(review): to regress on the full set instead, use the list below and
# make sure none of those columns remain in `all_ret` before the merge:
# regress_cols = ['avg_dists', 'std_dists', 'mean_volume',
#                 'std_volume', 'mean_surface_area', 'std_surface_area']
regress_cols = ["avg_dists"]

# Drop the existing copies from `all_ret` so the merge in the next cell does
# not produce suffixed duplicate columns.
cols = list(regress_cols)
all_ret = all_ret.drop(columns=cols)

In [None]:
# Re-attach the regression target columns from `updated_feat_df`, joined on "CellId".
# `regress_cols` must be defined before this cell runs.
all_ret = all_ret.merge(updated_feat_df[["CellId"] + regress_cols], on="CellId")

In [None]:
from src.features.regression import get_regression_df

# Regression results for the targets in `regress_cols`.
# NOTE(review): the meaning of the third argument (None) is not visible here.
regress = get_regression_df(all_ret, regress_cols, None)

In [None]:
# Mean of the regression results for each value of "target".
regress.groupby(["target"]).mean()

In [None]:
# regress.to_csv('./npm1_global/regression.csv')

In [None]:
import pyvista as pv
from cyto_dl.image.transforms import RotationMask
from skimage.io import imread
from sklearn.decomposition import PCA
from src.data.utils import mesh_seg_model_output
from tqdm import tqdm

# Latent walk along one principal component, run separately for each
# connectivity class. For each step of the walk, the real cell whose
# embedding is closest to the walked point is found and its mesh is saved.
latent_dim = 512  # upper bound on the number of PCA components
rank = 0  # index of the principal component to walk along
latent_walk_range = [-2, 0, 2]  # offsets, in units of that component's std
# latent_walk_range = [-2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2]
x_label = "pcloud"

for num_pieces in ["1.0", "2.0", "3.0", "4.0", ">=5"]:
    this_sub_m = all_ret.loc[all_ret["new_connectivity_thresh"] == num_pieces].reset_index(
        drop=True
    )
    if this_sub_m.empty:
        # No cells in this class, so there is nothing to fit a PCA on.
        continue

    # Embedding matrix: every column whose name contains "mu".
    all_features = this_sub_m[[i for i in this_sub_m.columns if "mu" in i]].values

    # PCA needs n_components <= min(n_samples, n_features); a fixed 512 fails
    # for classes with fewer cells than that.
    dim_size = min(latent_dim, *all_features.shape)
    pca = PCA(n_components=dim_size)
    pca_features = pca.fit_transform(all_features)
    pca_std_list = pca_features.std(axis=0)

    # Not used in this cell; kept in case other cells rely on them.
    all_xhat = []
    all_closest_real = []
    all_closest_img = []

    for value_index, value in enumerate(tqdm(latent_walk_range, total=len(latent_walk_range))):
        # Point in PCA space that is zero everywhere except component `rank`.
        # A NumPy array is used because scikit-learn expects array input.
        z_inf = np.zeros((1, dim_size))
        z_inf[:, rank] += value * pca_std_list[rank]
        # Map the point back to the original embedding space.
        z_inf = pca.inverse_transform(z_inf)

        # Squared Euclidean distance from every real embedding to that point.
        dist = (all_features - z_inf) ** 2
        dist = np.sum(dist, axis=1)
        closest_idx = np.argmin(dist)
        closest_real_id = this_sub_m.iloc[closest_idx]["CellId"]
        mesh = pv.read(
            mesh_df.loc[mesh_df["CellId"] == closest_real_id]["mesh_path_noalign"].iloc[0]
        )
        mesh.save(f"./npm1_global/closest2/{num_pieces}_{rank}_{value_index}.ply")

In [None]:
# Use every cell; `this_ret` is another name for `all_ret`, not a copy.
this_ret = all_ret
# Feature matrix from all columns whose name contains "mu".
# NOTE(review): presumably the latent columns — confirm no unrelated column name contains "mu".
matrix = this_ret[[i for i in this_ret.columns if "mu" in i]].values

In [None]:
from src.features.archetype import AA_Fast

# Fit AA_Fast with 5 archetypes on the feature matrix; `aa.Z` is read in the next cell.
n_archetypes = 5
aa = AA_Fast(n_archetypes, max_iter=1000, tol=1e-6).fit(matrix)

In [None]:
import pandas as pd

# Wrap `aa.Z` in a DataFrame with one "mu_<j>" column per column of `matrix`;
# later cells read row i as archetype i.
archetypes_df = pd.DataFrame(aa.Z, columns=[f"mu_{i}" for i in range(matrix.shape[1])])

In [None]:
# Display the fitted archetypes.
archetypes_df

In [None]:
# For each archetype, find the real cell whose embedding is nearest to it
# (squared Euclidean distance) and save that cell's mesh.
all_features = matrix
for i in range(n_archetypes):
    this_mu = archetypes_df.iloc[i].values
    dist = np.square(all_features - this_mu).sum(axis=1)
    closest_idx = dist.argmin()
    closest_real_id = this_ret.iloc[closest_idx]["CellId"]
    print(dist, closest_real_id)
    mesh_path = mesh_df.loc[mesh_df["CellId"] == closest_real_id, "mesh_path_noalign"].iloc[0]
    mesh = pv.read(mesh_path)
    mesh.save(f"./npm1_global/archetype2/{i}.ply")

In [None]:
# all_features =  matrix
# for i in range(n_archetypes):
#     this_mu = archetypes_df.iloc[i].values
#     dist = (all_features - this_mu) ** 2
#     dist = np.sum(dist, axis=1)
#     closest_idx = np.argmin(dist)
#     closest_real_id = this_ret.iloc[closest_idx]['CellId']
#     print(dist, closest_real_id)
#     mesh = pv.read(mesh_df.loc[mesh_df['CellId'] == closest_real_id]['mesh_path_noalign'].iloc[0])
#     mesh.save(f'./npm1_test/closest/archetype/{i}.ply')

In [None]:
# Display the "STR_connectivity_cc_thresh" column.
all_ret["STR_connectivity_cc_thresh"]

In [None]:
# Fit archetypes separately for each connectivity class and, for each
# archetype, save the mesh of the real cell in that class whose embedding is
# nearest to it.
from src.features.archetype import AA_Fast

n_archetypes = 3
for hh in all_ret["new_connectivity_thresh"].unique():
    this_ret = all_ret.loc[all_ret["new_connectivity_thresh"] == hh].reset_index(drop=True)
    labels = this_ret["structure_name"].values  # not used below
    # labels = this_ret['cell_stage_fine'].values
    matrix = this_ret[[i for i in this_ret.columns if "mu" in i]].values

    aa = AA_Fast(n_archetypes, max_iter=1000, tol=1e-6).fit(matrix)
    # Use the archetypes of THIS fit. The loop previously read
    # `archetypes_df`, which holds the archetypes of the earlier fit on all
    # cells, so the per-class fit was never used.
    archetypes = np.asarray(aa.Z)

    all_features = matrix
    for i in range(n_archetypes):
        print(hh, i)
        this_mu = archetypes[i]
        # Squared Euclidean distance from every cell in this class to archetype i.
        dist = (all_features - this_mu) ** 2
        dist = np.sum(dist, axis=1)
        closest_idx = np.argmin(dist)
        closest_real_id = this_ret.iloc[closest_idx]["CellId"]
        print(dist, closest_real_id)
        mesh = pv.read(
            mesh_df.loc[mesh_df["CellId"] == closest_real_id]["mesh_path_noalign"].iloc[0]
        )
        mesh.save(f"./npm1_global/archetype/per_piece/{hh}_{i}_{closest_real_id}.ply")

In [None]:
from skimage.io import imread

# Read the image stored at the "crop_seg_masked" path of one specific cell (CellId 974872).
img = imread(all_ret.loc[all_ret["CellId"] == 974872]["crop_seg_masked"].iloc[0])

In [None]:
# Scale "NUC_shape_volume_lcc" by 0.108**3.
# NOTE(review): presumably converts voxel counts to um^3 with a 0.108 um isotropic voxel size — confirm.
all_ret["volume_of_nucleus_um3"] = all_ret["NUC_shape_volume_lcc"] * 0.108**3

In [None]:
# Volume intervals, open on the left and closed on the right. The original
# line used interval notation "(a, b]" directly, which is not valid Python.
# The original (unsorted) order is kept.
bins = [
    pd.Interval(370.839, 577.444, closed="right"),
    pd.Interval(577.444, 784.05, closed="right"),
    pd.Interval(163.2, 370.839, closed="right"),
    pd.Interval(784.05, 990.656, closed="right"),
]

In [None]:
# Largest value of the scaled volume column.
all_ret["volume_of_nucleus_um3"].max()

In [None]:
# Count of rows per value of the "outlier" column.
all_ret["outlier"].value_counts()

In [None]:
# List every column whose name contains "outlier".
[i for i in all_ret.columns if "outlier" in i]

In [None]:
# Keep only cells whose volume lies strictly between the 2.5% and 97.5%
# quantiles of the full set, then renumber the rows.
feat = "volume_of_nucleus_um3"
lower, upper = (np.quantile(all_ret[feat], q=q) for q in (0.025, 0.975))

inside_range = (all_ret[feat] < upper) & (all_ret[feat] > lower)
this = all_ret.loc[inside_range].reset_index(drop=True)
# this = all_ret

# this = this.loc[this['CellId'] != 956566].reset_index(drop=True)

In [None]:
# Histogram of the trimmed volume values.
this[feat].hist()

In [None]:
# Cut the trimmed volumes into 5 bins (an integer `bins` gives equal-width
# bins) and give each bin an integer code.
# Note: `pd.factorize` assigns codes by order of first appearance, so the
# codes are not ordered by volume.
this["vol_bins"] = pd.cut(this[feat], bins=5)
this["vol_bins_ind"] = pd.factorize(this["vol_bins"])[0]

# Quantile-based alternative (equal-count bins):
# this['vol_bins'] = pd.qcut(this[feat], q=5)
# this['vol_bins_ind'] = pd.factorize(this['vol_bins'])[0]

In [None]:
# Number of cells in each volume bin.
this["vol_bins"].value_counts()

In [None]:
# Number of cells per integer bin code.
this["vol_bins_ind"].value_counts()

In [None]:
# Same count as the previous cell (duplicate cell).
this["vol_bins_ind"].value_counts()

In [None]:
# String form of the first row's bin label; the next cell uses these strings as bin keys and in file names.
this["vol_bins"].astype(str).iloc[0]

In [None]:
# For each volume bin, compute the mean embedding of the cells in that bin,
# find the k real cells whose embeddings are nearest to that mean, and save
# their meshes.
# NOTE(review): the nearest cells are searched among ALL cells in `this`,
# not only those in the bin — confirm this is intended.
all_features = this[[i for i in this.columns if "mu" in i]].values
this["vol_bins"] = this["vol_bins"].astype(str)
k = 2  # number of closest cells to export per bin

for hh in this["vol_bins"].unique():
    this_ret = this.loc[this["vol_bins"] == hh].reset_index(drop=True)

    # Mean embedding of this bin, shaped (1, n_features) for broadcasting.
    this_mu = np.expand_dims(
        this_ret[[i for i in this_ret.columns if "mu" in i]].mean(axis=0), axis=0
    )
    # Per-dimension squared differences; kept 2-D because later cells inspect `dist.shape`.
    dist = (all_features - this_mu) ** 2
    row_dist = dist.sum(axis=-1)

    # `argpartition` requires kth < len(row_dist); clamp so very small inputs
    # do not raise. The first k positions hold the k smallest distances
    # (not necessarily sorted).
    kth = min(k, len(row_dist) - 1)
    inds = np.argpartition(row_dist, kth)[:k]
    closest_samples = this.iloc[inds].reset_index(drop=True)

    for ind, row in closest_samples.iterrows():
        closest_real_id = row["CellId"]
        # Row position of this sample in `this`. The print previously used a
        # stale `closest_idx` left over from an earlier cell.
        closest_idx = inds[ind]
        print(
            closest_idx,
            this_ret["vol_bins"].unique(),
            all_features.shape,
            this_ret.shape,
            this_ret["NUC_shape_volume_lcc"].mean(),
            closest_real_id,
        )
        mesh = pv.read(
            mesh_df.loc[mesh_df["CellId"] == closest_real_id]["mesh_path_noalign"].iloc[0]
        )
        mesh.save(f"./npm1_global/vol_bin2/{hh}_{ind}_{closest_real_id}.ply")

In [None]:
# Shape of the squared-difference array left over from the last loop iteration.
dist.shape

In [None]:
# Shape of the mean embedding from the last loop iteration.
this_mu.shape

In [None]:
# Display the trimmed, binned DataFrame.
this