In [None]:
# Reload imported modules automatically before each cell executes, so edits made to
# the project packages on disk are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import os

# GPU selection. These variables are set ahead of `import torch` below, presumably so
# they are already in the environment when CUDA is initialised — keep them above the
# torch import.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
# NOTE(review): hard-coded, machine-specific device identifier (the "MIG-" prefix
# suggests a MIG instance UUID — confirm); it has to be changed on any other host.
os.environ["CUDA_VISIBLE_DEVICES"] = "MIG-ffdee303-0dd4-513d-b18c-beba028b49c7"
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from hydra.utils import instantiate
from PIL import Image
from torch.utils.data import DataLoader, Dataset

# Device string forwarded to save_embeddings, save_emissions and compute_features in
# the cells below.
device = "cuda:0"

# Load data and models

In [None]:
# Switch the working directory to the project folder so that relative paths used
# later in the notebook (e.g. the "./test_cellpack_save_embeddings/" output folder)
# resolve against it.
# NOTE(review): absolute, site-specific path — adjust when running elsewhere.
project_root = "/allen/aics/modeling/ritvik/projects/benchmarking_representations/"
os.chdir(project_root)

In [None]:
from cyto_dl.models.utils.mlflow import load_model_from_checkpoint

from br.data.get_datamodules import get_data
from br.models.load_models import load_model_from_path
from br.models.save_embeddings import get_pc_loss


def get_data_and_models(dataset_name, batch_size, results_path, debug=False):
    """Fetch the data objects and the trained models for one dataset.

    Thin convenience wrapper that calls ``get_data`` and then
    ``load_model_from_path`` and bundles their results.

    Args:
        dataset_name: Dataset identifier, forwarded to both helpers.
        batch_size: Forwarded to ``get_data``.
        results_path: Forwarded to both helpers.
        debug: Forwarded to ``get_data`` only. Defaults to ``False``.

    Returns:
        A 4-tuple ``(data_list, all_models, run_names, model_sizes)``: the value
        returned by ``get_data`` followed by the three values returned by
        ``load_model_from_path``.
    """
    datasets = get_data(dataset_name, batch_size, results_path, debug)
    # Uses the default list of models defined in load_models.py.
    models, names, sizes = load_model_from_path(dataset_name, results_path)
    return datasets, models, names, sizes

In [None]:
# Run configuration for this notebook.
# NOTE(review): `results_path` is an absolute, site-specific path.
results_path = (
    "/allen/aics/modeling/ritvik/projects/benchmarking_representations/br/configs/results/"
)
dataset_name = "cellpack"
batch_size = 2
debug = True

# Load the data objects and the models for `dataset_name` in one go.
data_list, all_models, run_names, model_sizes = get_data_and_models(
    dataset_name=dataset_name,
    batch_size=batch_size,
    results_path=results_path,
    debug=debug,
)

# Compute embeddings and emissions

In [None]:
from br.models.save_embeddings import save_embeddings

# Settings forwarded to save_embeddings.
debug = True  # same value as bound in the data-loading cell
splits_list = ["test"]
meta_key = "rule"
save_path = "./test_cellpack_save_embeddings/"  # relative to the current working directory
loss_eval_list = None
skew_scale = 100

# Per-model settings; presumably aligned index-by-index with `all_models` /
# `run_names` — TODO confirm against br.models.save_embeddings.
sample_points_list = [True, True, False, False, False]
num_models = len(sample_points_list)
eval_scaled_img = [False] * num_models
# Build one fresh dict per entry: `[{}] * n` would repeat a reference to a single
# shared dict, so a mutation of one entry would show up in every entry.
eval_scaled_img_params = [{} for _ in range(num_models)]

# Arguments are positional; their order must match the save_embeddings signature.
save_embeddings(
    save_path,
    data_list,
    all_models,
    run_names,
    debug,
    splits_list,
    device,
    meta_key,
    loss_eval_list,
    sample_points_list,
    skew_scale,
    eval_scaled_img,
    eval_scaled_img_params,
)

In [None]:
from br.models.save_embeddings import save_emissions

# Value passed to save_emissions as its `max_batches` argument.
max_batches = 2
# NOTE(review): every other argument is a notebook-level variable bound by an earlier
# cell. `loss_eval_list` is None when the notebook is run top to bottom, but the
# benchmarking-features cell further down re-binds it to a list of losses, so
# re-running this cell after that one passes a different value.
# Arguments are positional; their order must match the save_emissions signature.
save_emissions(
    save_path,
    data_list,
    all_models,
    run_names,
    max_batches,
    debug,
    device,
    loss_eval_list,
    sample_points_list,
    skew_scale,
    eval_scaled_img,
    eval_scaled_img_params,
)

# Compute benchmarking features

In [None]:
from br.models.compute_features import compute_features
from br.models.save_embeddings import get_pc_loss_chamfer
from br.models.utils import get_all_configs_per_dataset

# Per-model settings. Every per-model list below is sized from this one list instead
# of repeating a literal count; presumably the entries line up index-by-index with
# `all_models` / `run_names` — TODO confirm.
use_sample_points_list = [True, True, False, False, False]
num_models = len(use_sample_points_list)

keys = ["pcloud"] * num_models
max_embed_dim = 256

# Data config paths registered for this dataset under `results_path`.
DATA_LIST = get_all_configs_per_dataset(results_path)
data_config_list = DATA_LIST[dataset_name]["data_paths"]

evolve_params = {
    "modality_list_evolve": keys,
    "config_list_evolve": data_config_list,
    "num_evolve_samples": 40,
    "compute_evolve_dataloaders": False,
    "eval_meshed_img": [False] * num_models,
    "skew_scale": 100,
    "eval_meshed_img_model_type": [None] * num_models,
    "only_embedding": False,
    "fit_pca": False,
}

# The same loss object is reused for every model.
# NOTE(review): this re-binds `loss_eval_list`, which the embedding / emission cells
# above pass as None; re-running those cells after this one changes their input.
loss_eval = get_pc_loss_chamfer()
loss_eval_list = [loss_eval] * num_models

classification_params = {"class_label": "rule"}
rot_inv_params = {"squeeze_2d": False, "id": "cell_id"}

regression_params = {"df_feat": None, "target_cols": None, "feature_df_path": None}

compactness_params = {
    "method": "mle",
    "num_PCs": None,
    "blobby_outlier_max_cc": None,
    "check_duplicates": True,
}

# NOTE(review): this re-binds `splits_list`; the embeddings cell above saved only the
# "test" split, while three splits are requested here with compute_embeds=False.
# Confirm that compute_features can find train/val embeddings under `save_path`.
splits_list = ["train", "val", "test"]
compute_embeds = False

metric_list = [
    "Rotation Invariance Error",
    "Evolution Energy",
    "Reconstruction",
    "Classification",
    "Compactness",
]


# Embeddings are read from, and results written to, the same folder (`save_path`).
compute_features(
    dataset=dataset_name,
    results_path=results_path,
    embeddings_path=save_path,
    save_folder=save_path,
    data_list=data_list,
    all_models=all_models,
    run_names=run_names,
    use_sample_points_list=use_sample_points_list,
    keys=keys,
    device=device,
    max_embed_dim=max_embed_dim,
    splits_list=splits_list,
    compute_embeds=compute_embeds,
    classification_params=classification_params,
    regression_params=regression_params,
    metric_list=metric_list,
    loss_eval_list=loss_eval_list,
    evolve_params=evolve_params,
    rot_inv_params=rot_inv_params,
    compactness_params=compactness_params,
)