In [None]:
%load_ext autoreload
%autoreload 2
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="MIG-c5d3210f-194b-58d7-b64c-80067ff44d0a" 
from hydra.utils import instantiate
import yaml
import torch
import numpy as np
import torch
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from PIL import Image
device = "cuda:0"

In [None]:
from src.data.get_datamodules import get_data
from src.models.load_models import load_models
from cyto_dl.models.utils.mlflow import load_model_from_checkpoint
from src.models.save_embeddings import get_pc_loss

def get_data_and_models(dataset_name, batch_size, data_indices=(0, 1, -2, -2, -1)):
    """Load the datamodules and models for ``dataset_name``.

    Args:
        dataset_name: Dataset key forwarded to both ``get_data`` and ``load_models``.
        batch_size: Batch size forwarded to ``get_data``.
        data_indices: Positions in the sequence returned by ``get_data`` to pick,
            in order, when building ``data_list``. The default
            ``(0, 1, -2, -2, -1)`` reproduces the original hand-picked list for the
            default model set (the second-to-last datamodule is used twice).
            Presumably entry ``k`` pairs with ``all_models[k]`` — verify against
            ``load_models`` when changing either side.

    Returns:
        Tuple ``(data_list, all_models, run_names)``.
    """
    data = get_data(dataset_name, batch_size)
    data_list = [data[i] for i in data_indices]
    all_models, run_names = load_models(dataset_name)  # default list of models in load_models.py
    return data_list, all_models, run_names

In [None]:
# Load the PCNA datamodules together with the models that will be evaluated below.
batch_size = 1
dataset_name = 'pcna'
data_list, all_models, run_names = get_data_and_models(
    dataset_name=dataset_name,
    batch_size=batch_size,
)

In [None]:
from src.models.save_embeddings import save_embeddings

# Compute and save embeddings for every loaded model.
# NOTE(review): this runs with debug = True; set debug = False for the real run
# (what the flag changes is defined inside save_embeddings — confirm there).
splits_list = ['test']
path = "./test"  # alternative output dir: "./embeddings_pcna_vit"
debug = True
outs = save_embeddings(
    path, data_list, all_models, run_names, debug, splits_list,
    device=device,
)

In [None]:
from src.models.save_embeddings import save_emissions

# Measure emissions for every loaded model and save them under `path`.
# NOTE(review): the literal 20 is passed positionally to save_emissions; its meaning
# (e.g. number of samples/repeats) is defined there — confirm and consider naming it.
path = "./test"  # alternative output dir: "./emissions_pcna_vit"
debug = True
save_emissions(path, data_list, all_models, run_names, 20, debug, device)

In [None]:
# viz
import pandas as pd
import seaborn as sns

# Plot per-model emissions from the CSV under `path`
# (presumably written by the save_emissions cell).
# os.path.join avoids a doubled separator when `path` ends with "/".
emissions = pd.read_csv(os.path.join(path, 'emissions.csv'))
g = sns.catplot(data=emissions, x='model', y='emissions', kind='point')
# Trailing ';' suppresses the FacetGrid repr so only the figure is shown.
g.set_xticklabels(rotation=30);


In [None]:
# compute features on saved embeddings
# embedding_save_location should be updated in DATASET_INFO in compute_features.py
# DATASET_INFO also contains:
# orig_df: original dataframe, which often has some useful metadata
# image_path: manifest with paths for image models
# pc_path: manifest with paths for pc models

from src.models.compute_features import compute_features

# evolve params
# Presumably one entry per model, in the same order as data_list / all_models / run_names
# (they are all passed together to compute_features) — verify there when editing.
keys = ['pcloud', 'pcloud', 'image', 'image', 'image']
data_config_list = [
        "../data/configs/inference_pcna_data_configs/pointcloud_3.yaml",
        "../data/configs/inference_pcna_data_configs/pointcloud_4.yaml",
        "../data/configs/inference_pcna_data_configs/image_full.yaml",
        "../data/configs/inference_pcna_data_configs/image_full.yaml",
        "../data/configs/inference_pcna_data_configs/mae.yaml",
]
# Fail fast if the per-model lists drift out of sync (e.g. after adding/removing a model).
assert len(keys) == len(data_config_list) == len(data_list), (
    f"per-model lists disagree: keys={len(keys)}, "
    f"data_config_list={len(data_config_list)}, data_list={len(data_list)}"
)
evolve_params = {'modality_list_evolve': keys, 'config_list_evolve': data_config_list, 'num_evolve_samples': 2}

# classification params
classification_params = {'class_label': 'cell_stage_fine'}

# rot inv params
rot_inv_params = {'squeeze_2d': False, 'id': 'cell_id'}

# stereotypy params (disabled; also uncomment the stereotypy_params argument below)
#stereotypy_params = {'max_pcs': 2, 'max_bins': 9, 'get_baseline': True, 'return_correlation_matrix': False}

# regression params (disabled; also uncomment the regression_params argument below)
# feature_df_path = "/allen/aics/assay-dev/MicroscopyOtherData/Viana/projects/cvapipe_analysis/local_staging_variance/computefeatures/manifest.csv"
# target_cols = ['STR_connectivity_cc',
#  'STR_shape_volume',
#  'STR_position_depth',
#  'STR_position_height',
#  'STR_position_width',
#  'STR_roundness_surface_area']
# # regression_params = {'df_feat': df_feat, 'target_cols': target_cols, 'feature_df_path': feature_df_path}
# regression_params = {'df_feat': None, 'target_cols': target_cols, 'feature_df_path': feature_df_path}


# general params
save_folder = './test/'
max_embed_dim = 192
# NOTE(review): the save_embeddings cell above only ran for ['test'], and compute_embeds=False
# here, so 'train'/'val' presumably rely on embeddings saved elsewhere — confirm that
# DATASET_INFO's embedding_save_location contains all three splits.
splits_list = ['train', 'val', 'test']
metric_list = [
    "Rotation Invariance Error",
    "Evolution Energy",
    "Reconstruction",
    "Classification",
    "Compactness",
]
compute_embeds = False

compute_features(
    dataset=dataset_name,
    save_folder=save_folder,
    data_list=data_list,
    all_models=all_models,
    run_names=run_names,
    keys=keys,
    device=device,  # reuse the device chosen in the setup cell rather than repeating "cuda:0"
    max_embed_dim=max_embed_dim,
    splits_list=splits_list,
    compute_embeds=compute_embeds,
    classification_params=classification_params,
    # regression_params=regression_params,
    metric_list=metric_list,
    evolve_params=evolve_params,
    rot_inv_params=rot_inv_params,
    # stereotypy_params=stereotypy_params
)