In [None]:
import os
import random
import sys
from copy import deepcopy


# Set the MOmaps code root and input-data root for this kernel session.
os.environ.update({
    'MOMAPS_HOME': '/home/labs/hornsteinlab/Collaboration/MOmaps_Sagy/MOmaps',
    'MOMAPS_DATA_HOME': '/home/labs/hornsteinlab/Collaboration/MOmaps/input',
})

# Put MOMAPS_HOME on the import path so the project packages imported below
# (`src.*`, `sandbox.*`) can be resolved.
sys.path.insert(1, os.getenv("MOMAPS_HOME"))
print(f"MOMAPS_HOME: {os.getenv('MOMAPS_HOME')}")


from sandbox.eval_new_arch.dino4cells.utils import utils
import torch.backends.cudnn as cudnn
from sandbox.eval_new_arch.dino4cells.archs import vision_transformer as vits
import numpy as np
import logging
import torch

from src.common.lib.dataset import Dataset
from src.datasets.dataset_spd import DatasetSPD
from src.common.lib.utils import load_config_file
from src.common.lib.data_loader import get_dataloader
import matplotlib.pyplot as plt 
%matplotlib inline
from umap import UMAP

%load_ext autoreload
%autoreload 2

In [None]:


def plot_umap_all_conditions(preds, true_labels):
    """Project embeddings to 2D with UMAP and scatter-plot them, one colour per class.

    Args:
        preds: 2D array-like with one embedding vector per row
            (n_samples, n_features).
        true_labels: 1D array-like with one class label per row of ``preds``.

    Raises:
        ValueError: if ``preds`` and ``true_labels`` have different lengths.

    Side effects:
        Prints the number of samples of every class and shows a matplotlib figure.
    """
    # Coerce to arrays so that `true_labels == c` is always an element-wise
    # comparison (with a plain list it would collapse to a single bool).
    preds = np.asarray(preds)
    true_labels = np.asarray(true_labels)
    if len(preds) != len(true_labels):
        raise ValueError(
            f"preds has {len(preds)} rows but true_labels has {len(true_labels)} entries"
        )

    classes = np.unique(true_labels)

    # Fixed random_state keeps the projection reproducible between runs.
    reducer = UMAP(n_components=2, random_state=1)
    preds_umap = reducer.fit_transform(preds)

    colors = ['blue', 'orange', 'red', 'black', 'green', 'purple', 'pink', 'brown', 'grey']
    marker_size = 30
    handles = []
    for ci, c in enumerate(classes):
        class_mask = true_labels == c

        print(f"class: {c}, shape: {preds[class_mask].shape}")
        # Colours are reused cyclically when there are more classes than colours.
        handle = plt.scatter(
            preds_umap[class_mask, 0],
            preds_umap[class_mask, 1],
            alpha=0.7,
            s=marker_size,
            c=colors[ci % len(colors)],
        )
        handles.append(handle)

    plt.xlabel("UMAP 1")
    plt.ylabel("UMAP 2")
    # Pass handles explicitly so every label is tied to its own scatter artist.
    plt.legend(handles, classes, bbox_to_anchor=(1.1, 1.05))
    plt.show()



In [None]:
# Checkpoint file that is passed to utils.load_model_from_checkpoint when the
# model is built further down in this notebook.
chkp_path = "/home/labs/hornsteinlab/Collaboration/MOmaps_Sagy/MOmaps/sandbox/eval_new_arch/vit_contrastive/checkpoints/checkpoints_130724_124329_509907_vit_tiny_opencell/checkpoint_best.pth"
# Dataset configuration handed to load_config_file.
# NOTE(review): this is a relative path - presumably resolved against MOMAPS_HOME
# (or the working directory) by load_config_file; confirm.
config_path_data = "./src/datasets/configs/training_data_config/OpenCellTrainContrastiveDatasetConfig"


In [None]:

# Load the dataset configuration (a single call is enough).
config_data = load_config_file(config_path_data)

cudnn.benchmark = False
# NOTE(review): only Python's `random` module is seeded here; numpy / torch RNGs
# are left untouched - confirm whether dataset.split() depends on them.
random.seed(1)

dataset = DatasetSPD(config_data, None)


# Evaluate on the test split when the config asks for a split, otherwise on
# the whole dataset.
if config_data.SPLIT_DATA:
    logging.info("Split data...")
    train_indexes, val_indexes, test_indexes = dataset.split()
    dataset_test_subset = Dataset.get_subset(dataset, test_indexes)
else:
    dataset_test_subset = dataset


# FILTER BY MARKERS #
markers_to_test = ['G3BP1', 'CLTC']
markers_to_test_full = [f'{m}_WT_Untreated' for m in markers_to_test]

markers_samples_indexes = np.where(np.isin(dataset_test_subset.label, markers_to_test_full))[0]
if markers_samples_indexes.size == 0:
    # Fail here with a clear message instead of building an empty data loader.
    raise ValueError(f"No samples found for labels {markers_to_test_full}")

# Filter a copy: in the non-split branch `dataset_test_subset` is `dataset`
# itself, and `dataset` is still used later in the notebook.
dataset_test_subset = deepcopy(dataset_test_subset)
dataset_test_subset.X_paths = dataset_test_subset.X_paths[markers_samples_indexes]
dataset_test_subset.label = dataset_test_subset.label[markers_samples_indexes]
dataset_test_subset.y = dataset_test_subset.y[markers_samples_indexes]
print(f"unique markers subset: {dataset_test_subset.unique_markers} ({len(dataset_test_subset.unique_markers)}) {np.unique(dataset_test_subset.label)}")
####


# Batches of 10 samples, loaded in order (shuffle=False) by a single worker.
data_loader_test = get_dataloader(dataset_test_subset, 10, num_workers=1, shuffle=False)
print(f"unique markers: {dataset_test_subset.unique_markers} ({len(dataset_test_subset.unique_markers)})")

In [None]:
# Settings for this evaluation run. In the visible notebook code only
# `embedding.image_size`, `patch_size`, `num_channels` and `num_classes` are read
# (when the ViT is constructed); the remaining keys are not used here.
config = {
    'seed': 1,

    # Model / input geometry
    'embedding': {'image_size': 100},
    'num_classes': 128,
    'patch_size': 14,
    'num_channels': 2,
    'out_dim': 225,
    'use_bn_in_head': False,
    'norm_last_layer': True,

    # Crops, temperatures and momentum
    'local_crops_number': 8,
    'warmup_teacher_temp': 0.04,
    'teacher_temp': 0.04,
    'warmup_teacher_temp_epochs': 0,
    'epochs': 5,
    'student_temp': 0.1,
    'center_momentum': 0.9,
    'momentum_teacher': 0.996,

    # Learning-rate schedule
    'lr': 0.0005,
    'min_lr': 1e-6,
    'warmup_epochs': 10,

    # Weight-decay schedule
    'weight_decay': 0.04,
    'weight_decay_end': 0.4,

    # Data loading
    'batch_size_per_gpu': 4,
    'num_workers': 1,
}
    
class DictToObject:
    """Attribute-style view of a (possibly nested) dictionary.

    Every key of ``dict_obj`` becomes an attribute on the instance. Values that
    are themselves dicts are wrapped in their own ``DictToObject``; all other
    values are stored unchanged.
    """

    def __init__(self, dict_obj):
        for name, item in dict_obj.items():
            # Nested dicts are converted recursively so that chained attribute
            # access (a.b.c) works at any depth.
            wrapped = DictToObject(item) if isinstance(item, dict) else item
            setattr(self, name, wrapped)

# Wrap the plain dict so settings are read as attributes (config.patch_size, ...).
config = DictToObject(config)

# Square input: the same side length is used for both image dimensions.
image_side = config.embedding.image_size

# Use the configured output size; fall back to one output per marker only when
# num_classes is None.
if config.num_classes is not None:
    n_outputs = config.num_classes
else:
    n_outputs = len(dataset.unique_markers)

vit_kwargs = dict(
    img_size=[image_side, image_side],
    patch_size=config.patch_size,
    in_chans=config.num_channels,
    num_classes=n_outputs,
)
model = vits.vit_tiny(**vit_kwargs).cuda()

# Load the checkpoint at chkp_path into the model via the project helper.
model = utils.load_model_from_checkpoint(chkp_path, model)


In [None]:

# Maximum number of batches to push through the model.
n_batches = 30

preds_array = []
lbls_array = []
hidden_array = []

model.eval()
with torch.no_grad():
    for i, sample in enumerate(data_loader_test):
        # Stop before processing batch number n_batches, so that exactly
        # n_batches batches are used (the previous end-of-loop check ran one extra).
        if i >= n_batches:
            break

        inpt = sample['image'].to(torch.float).cuda()
        lbls = sample['label'].cpu().detach().numpy()
        # Map numeric label ids back to label names.
        lbls = dataset.id2label(lbls)

        preds, hidden = model(inpt, return_hidden=True)
        # Move results to host memory as numpy arrays so they can be stacked below.
        preds = preds.cpu().numpy()
        hidden = hidden.cpu().numpy()

        preds_array.append(preds)
        hidden_array.append(hidden)
        lbls_array.extend(lbls)
        print(preds.shape, lbls.shape, np.unique(lbls), hidden.shape)

if not preds_array:
    # np.vstack fails with an unhelpful message on an empty list.
    raise RuntimeError("No batches were processed; data_loader_test is empty or n_batches <= 0")

preds_array = np.vstack(preds_array)
hidden_array = np.vstack(hidden_array)
lbls_array = np.asarray(lbls_array)

preds_array.shape, hidden_array.shape, lbls_array.shape, np.unique(lbls_array)

In [None]:
# Model: vit_tiny contrastive (128 output features)
# Data:  OpenCell test set

plot_umap_all_conditions(preds=preds_array, true_labels=lbls_array)