In [None]:
%load_ext autoreload
%autoreload 2
import os
# GPU selection is done through environment variables; they are assigned here,
# ahead of the `import torch` line below.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
# NOTE(review): hard-coded MIG device UUID — presumably specific to one
# machine/partition; confirm or update before running elsewhere.
os.environ["CUDA_VISIBLE_DEVICES"]="MIG-7d67cf82-0c6b-5aee-99ce-cab54f47f0b6" 
from hydra.utils import instantiate
import yaml
import torch
from src.models.compute_features import get_embeddings
import pandas as pd
from src.data.get_datamodules import CONFIG_LIST
from src.features.diagnosticsheet import diagnostic_sheet
import warnings
import matplotlib.pyplot as plt
import numpy as np
# Torch device string used for the distance computations later in the notebook.
device = "cuda:0"

# Notebook-wide presentation settings: dark matplotlib theme, and silence
# FutureWarning messages.
plt.style.use("dark_background")
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
def custom_contrast_adjust(img):
    """Zero every value that is not strictly between 0 and 10000.

    Values in the open interval (0, 10000) are kept unchanged; everything
    else (including NaN, which fails both comparisons) becomes 0. The input
    is not modified; a new array is returned.

    Parameters
    ----------
    img : array-like supporting element-wise comparison (e.g. ``np.ndarray``)
        Intensity values to threshold.

    Returns
    -------
    np.ndarray
        Array of the same shape as ``img`` with out-of-range values zeroed.
    """
    in_range = (img > 0) & (img < 10000)
    return np.where(in_range, img, 0)

def get_data2(dataset_name, batch_size):
    """Instantiate one object per config file registered for a dataset.

    Each YAML file listed under ``CONFIG_LIST[dataset_name]`` is loaded, its
    ``transforms`` list is cut down to the first entry, and the resulting
    config is passed to hydra's ``instantiate``.

    Parameters
    ----------
    dataset_name : key into ``CONFIG_LIST``.
    batch_size : if truthy, overrides ``batch_size`` in every config and also
        forces ``shuffle`` to ``False``; if falsy, both settings are left as
        written in the YAML file.

    Returns
    -------
    list
        The instantiated objects, in the same order as the config paths.
    """
    instantiated = []
    for cfg_file in CONFIG_LIST[dataset_name]:
        with open(cfg_file, "r") as handle:
            cfg = yaml.safe_load(handle)
            # Keep only the first transform of the pipeline.
            cfg['transforms'] = cfg['transforms'][:1]
            if batch_size:
                cfg.update({"batch_size": batch_size, "shuffle": False})
            instantiated.append(instantiate(cfg))
    return instantiated

In [None]:
# Manifest table; it is merged on `CellId` with the nearest-neighbour results
# further down in this notebook.
manifest_path = "/allen/aics/modeling/ritvik/variance_punctate/one_step/manifest.parquet"
orig_df = pd.read_parquet(manifest_path)

In [None]:
# Show the loaded manifest as this cell's output for a quick visual check.
orig_df

In [None]:
from tqdm import tqdm

# Runs whose embeddings are fetched (in a single call) for the dataset.
run_names = ['2048_ed_mae', '2048_int_ed_vndgcnn',
            'classical_image', 'so2_image', 'vit']
dataset = 'variance'
all_ret, _ = get_embeddings(run_names, dataset)

# Parallel lists, one entry per evaluated run:
#   eval_run_names[i] - value of the `model` column selecting the run's rows
#   data_inds[i]      - position of the matching entry in get_data2(dataset, ...)
#   save_paths[i]     - directory handed to diagnostic_sheet for that run
eval_run_names = ['classical_image', 'so2_image', '2048_ed_mae', '2048_int_ed_vndgcnn']
data_inds = [3, 3, 0, 1]
save_paths = ['./viz/variance/classical_image/', './viz/variance/so2_image/', './viz/variance/classical_pcloud/', './viz/variance/so3_pcloud/']
if not (len(eval_run_names) == len(data_inds) == len(save_paths)):
    raise ValueError("eval_run_names, data_inds and save_paths must have the same length")

# for first 5 test set samples, look at the top 2 nearest neighbors
k = 2
test_idxs = torch.arange(5)

# Loop-invariant work, done once instead of once per run.
cols = [i for i in all_ret.columns if "mu" in i]
data = get_data2(dataset, 1)

runs = zip(eval_run_names, data_inds, save_paths)
for run, data_ind, save_path in tqdm(runs, total=len(eval_run_names)):
    this_mo = all_ret.loc[all_ret['model'] == run].reset_index(drop=True)
    if this_mo.empty:
        raise ValueError(f"no embeddings found for model {run!r}")

    # Columns that contain NaN for this model are dropped before building the
    # embedding matrix.
    embeddings = torch.tensor(this_mo[cols].dropna(axis=1).values)

    # Distances are only needed from the query samples to every sample, so
    # compute that (len(test_idxs) x N) slab rather than the full N x N
    # matrix, and sort on the device before a single transfer to the CPU.
    dists = torch.cdist(embeddings[test_idxs].to(device), embeddings.to(device), p=2)
    dist_argsort = torch.argsort(dists, dim=1).cpu()

    # Row r holds the k+1 closest indices for query r. Rank 0 is normally the
    # query itself (distance 0), barring exact ties.
    this_dist_argsort = dist_argsort[:, :k+1]
    close_inds = this_dist_argsort.flatten()

    # The flattening is row-major, so the entry at flat position p belongs to
    # query row p // (k+1) at rank p % (k+1). Deriving both from the position
    # keeps the attribution correct even when the same cell shows up in the
    # neighbour lists of several queries (a value-based lookup does not).
    n_queries, n_ranks = this_dist_argsort.shape
    all_samples = np.repeat(np.arange(n_queries), n_ranks)
    all_closest = np.tile(np.arange(n_ranks), n_queries)

    # NOTE(review): this assumes the row order of `this_mo` matches the index
    # order of the 'test' dataset — confirm against get_embeddings.
    this_data = data[data_ind]
    bb = this_data.get_dataset('test')
    all_cell_ids = [bb[int(ids)]['cell_id'] for ids in close_inds]

    tmp_df = pd.DataFrame({
        'CellId': all_cell_ids,
        'index': close_inds.numpy(),
        'SampleId': all_samples,
        'ClosestId': all_closest,
    })
    tmp_df = tmp_df.merge(orig_df, on='CellId')

    # .copy() so the assignment below writes to an independent frame and not
    # to a slice of tmp_df (avoids SettingWithCopyWarning).
    df_sample = tmp_df.loc[tmp_df['ClosestId'].isin(list(range(k + 1)))].copy()
    # Offsets CellId by the row index — presumably so a cell that occurs more
    # than once still gets a distinct id in the sheet; TODO confirm.
    df_sample['CellId'] = df_sample['CellId'] + df_sample.index

    # An earlier multi-channel setup (per the original notes: DNA, membrane,
    # structure) was channels=[6, 7, 8] with colors=[[0, 1, 1], [1, 0, 1], [1, 1, 1]]
    # (cyan, magenta, white).
    diagnostic_sheet(df_sample,
                    save_dir = save_path, # output directory for this run's sheets
                    image_column = "registered_path", # Pass in the 3D image path or one of the 2D image paths like max_projection_x
                    max_cells=5, # max cells per sheet
                    channels = [2], # single channel rendered
                    colors = [[1,1,1]], # white
                    metadata ='ClosestId', # Metadata to stratify the diagnostic sheets
                    proj_method = "max", # options - max, mean, sum
                    feature = "SampleId", # Optional, Feature to add as text,
                    fig_width = None, # Default is number of columns * 7
                    fig_height = None, # Default is number of rows * 5,
                    distributed_executor_address = None, # An optional executor address to pass to some computation engine.
                    batch_size = None, # process all at once
                    contrast_adjust = custom_contrast_adjust,
                    overwrite=True)

