# Notebook Setup

In [None]:
# Show which machine is executing this kernel (IPython shell escape running `hostname`).
!hostname

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all modules before each
# cell executes, so edits to the imported project sources (`src`, `tools`) take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Imports

In [None]:
import sys
import os
# Root of the NOVA project checkout. An already-exported NOVA_HOME takes
# precedence, so the notebook can run against another checkout without editing
# this cell; otherwise fall back to the shared lab location.
NOVA_HOME = os.environ.get("NOVA_HOME", "/home/projects/hornsteinlab/Collaboration/NOVA")
# Export it so code that reads os.getenv("NOVA_HOME") sees the same value.
os.environ["NOVA_HOME"] = NOVA_HOME
# Make the project packages (`src`, `tools`) importable. Insert at index 1 so
# whatever sits at sys.path[0] keeps priority; skip when already present so
# re-running this cell does not keep growing sys.path.
if NOVA_HOME not in sys.path:
    sys.path.insert(1, NOVA_HOME)
print(os.getenv("NOVA_HOME"))


In [None]:
from tools.attn_maps_plotting.get_attn_from_paths import generate_attn_maps_with_model
from src.common.utils import load_config_file
from src.datasets.dataset_config import DatasetConfig
from src.figures.plot_config import PlotConfig
from src.attention_maps.attention_config import AttnConfig

# Params Setup

Define the parameters that control which data is used and how it is displayed.

In [None]:

# Models to visualise; each name is a sub-directory joined onto MODEL_DIR in the main loop.
# Alternative: ["pretrained_model"]
MODEL_NAMES = ["finetunedModel_MLPHead_acrossBatches_B56789_80pct_frozen"]

# Dataset config class name, resolved under manuscript/embeddings_config.
DATASET_CONFIG_NAME = 'EmbeddingsAlyssaCoyneDatasetConfig'

# Batch size forwarded to generate_attn_maps_with_model.
BATCH_SIZE = 50

# Plotting / runtime switches.
SAVE_PLOT: bool = False           # save attn figures in the target directory
SHOW_PLOT: bool = True            # display attn figures in notebook
DISPLAY_ATTN_SCORE: bool = False  # display attn scores on top of the image (also triggers loading the corr-scores config)
NUM_WORKERS: int = 1              # multiprocessing number of workers

# PATH_DICT maps:
#   key   -> description (str): the sub-directory name the figures are saved in
#   value -> list of tile paths, each a str of the form "path/to/<file_name>.npy/<tile_number>"
PATH_DICT = {
    "Test_Run": [
        "/home/projects/hornsteinlab/Collaboration/NOVA/input/images/processed/ManuscriptFinalData_80pct/AlyssaCoyne/batch1/Controls/Untreated/DCP1A/rep6_R11_w2confCy5_s57_panelA_Controls_processed.npy/4",
        "/home/projects/hornsteinlab/Collaboration/NOVA/input/images/processed/ManuscriptFinalData_80pct/AlyssaCoyne/batch1/Controls/Untreated/DCP1A/rep6_R11_w2confCy5_s57_panelA_Controls_processed.npy/5",
        "/home/projects/hornsteinlab/Collaboration/NOVA/input/images/processed/ManuscriptFinalData_80pct/AlyssaCoyne/batch1/Controls/Untreated/DCP1A/rep6_R11_w2confCy5_s57_panelA_Controls_processed.npy/6",
        "/home/projects/hornsteinlab/Collaboration/NOVA/input/images/processed/ManuscriptFinalData_80pct/AlyssaCoyne_new/batch1/Ctrl-EDi022/Untreated/DCP1A/rep1_control EDi022_DAPI_DCP1A_Map2_TDP-43_1-Orthogonal Projection-62-Image Export-11_c3_s11_panelA_Ctrl-EDi022_processed.npy/0",
    ],
}

# Paths and Configs

In [None]:

# DEFINE PATHS
# Root holding one sub-directory per model listed in MODEL_NAMES.
MODEL_DIR = os.path.join(os.getenv("NOVA_HOME"), "outputs", "vit_models")

# Config locations handed to load_config_file (presumably resolved relative to
# NOVA_HOME -- confirm in src.common.utils).
CONFIG_PATH_DATA = os.path.join('manuscript', 'embeddings_config', DATASET_CONFIG_NAME)
CONFIG_PATH_ATTN = os.path.join('manuscript', 'manuscript_attention_config', 'BaseAttnConfig')
CONFIG_PATH_PLOT = os.path.join(
    'tools', 'attn_maps_plotting', 'manuscript_plot_attention_map_config', 'BaseAttnMapPlotConfig'
)
# Loaded only when DISPLAY_ATTN_SCORE is enabled.
CONFIG_PATH_CORR = os.path.join('manuscript', 'manuscript_attn_corr_scores_config', 'AttnScoresConfig')


In [None]:
# load configs: build the config objects from the CONFIG_PATH_* constants.
# NOTE(review): the attention config is loaded with the "data" kind, like the
# dataset config -- confirm this is what load_config_file expects for AttnConfig.
config_data: DatasetConfig = load_config_file(CONFIG_PATH_DATA, "data")
config_attn: AttnConfig = load_config_file(CONFIG_PATH_ATTN, "data")
config_plot: PlotConfig = load_config_file(CONFIG_PATH_PLOT, "plot")

# The correlation-scores config is needed only when scores are drawn on the maps.
config_corr = load_config_file(CONFIG_PATH_CORR, "data") if DISPLAY_ATTN_SCORE else None

# update arguments: push the notebook-level switches into the plot config.
config_plot.SAVE_PLOT = SAVE_PLOT
config_plot.SHOW_PLOT = SHOW_PLOT
config_plot.PLOT_ATTN_NUM_WORKERS = NUM_WORKERS

# Main

In [None]:

# Generate attention maps for every requested model. Each model's outputs live in
# MODEL_DIR/<model_name>; the shared config_data object is repointed there before
# each run, since the same object is reused across iterations.
for model_name in MODEL_NAMES:
    outputs_folder_path = os.path.join(MODEL_DIR, model_name)
    config_data.OUTPUTS_FOLDER = outputs_folder_path

    print(f"Starting generate attention maps for {model_name}...")
    try:
        generate_attn_maps_with_model(PATH_DICT, outputs_folder_path, config_data,
                                      config_attn, config_plot,
                                      config_corr, BATCH_SIZE)
    except Exception as e:
        # Report which model failed (and the exception type) before propagating;
        # a bare `raise` re-raises the original exception with its traceback intact.
        print(f"Failed to generate attention maps for {model_name}: {e!r}")
        raise

print("Done")


In [None]:
# NOTE(review): leftover scratch note -- it looks like one attention-rollout step
# (matrix product of a running rollout with a single layer's attention). It is not
# executed; remove this cell or move the logic into project code if still needed.
# rollout = rollout @ layer_attn