In [4]:
import os
import logging
import pickle

import numpy as np 
from dotenv import load_dotenv

from code.model_activations.models.utils import load_model, load_full_identifier
from code.model_activations.activation_extractor import Activations
from code.encoding_score.regression.get_betas import NeuralRegression
from code.encoding_score.regression.scores_tools import get_bootstrap_rvalues
from code.eigen_analysis.compute_pcs import compute_model_pcs
from demo.model_configs import analysis_cfg as cfg
from config import RESULTS_PATH

# Load environment variables (including CACHE) from a local .env file, if present.
load_dotenv()
CACHE = os.getenv("CACHE")
if CACHE is None:
    # Fail fast: the PCA sweep builds cache paths with os.path.join(CACHE, ...),
    # which would otherwise raise an opaque TypeError much later in the run.
    raise RuntimeError("CACHE environment variable is not set; define it in .env or the shell environment.")

In [2]:
MODEL_NAME = 'expansion'
DATASET = 'majajhong_demo'
DEVICE = 'cuda'
BATCH_SIZE = 50
N_BOOTSTRAPS = 1000
N_ROWS = cfg[DATASET]['test_data_size']

# Seed the bootstrap resampling so the reported lower/upper bounds are reproducible
# under Restart-&-Run-All (the unseeded global np.random gave different draws each run).
SEED = 42
rng = np.random.default_rng(SEED)
# One row per bootstrap resample: N_ROWS test-set indices drawn with replacement.
ALL_SAMPLED_INDICES = rng.choice(N_ROWS, size=(N_BOOTSTRAPS, N_ROWS), replace=True)

In [3]:
# Sweep over model widths (`features`) and numbers of principal components:
# fit PCs once per width on the train set, project activations onto the leading
# n PCs, regress neural data on them, then bootstrap the resulting r-values.

# Loop-invariant settings (previously recomputed on every outer-loop iteration).
TOTAL_COMPONENTS = 10
# PC counts to evaluate: powers of ten from 1 up to TOTAL_COMPONENTS (here [1, 10]).
# np.rint guards against float round-off (e.g. 99.999...) before casting to int.
N_COMPONENTS = [
    int(n)
    for n in np.rint(
        np.logspace(0, np.log10(TOTAL_COMPONENTS),
                    num=int(np.log10(TOTAL_COMPONENTS)) + 1,
                    base=10)
    )
]
PCA_FEATURES = cfg[DATASET]['analysis']['pca']['features']
PCA_LAYERS = cfg[DATASET]['analysis']['pca']['layers']
REGION = cfg[DATASET]['regions']

for features in PCA_FEATURES:

    pca_identifier = load_full_identifier(model_name=MODEL_NAME,
                                          features=features,
                                          layers=PCA_LAYERS,
                                          dataset=DATASET,
                                          principal_components=TOTAL_COMPONENTS)

    # compute model PCs using the train set (skipped when already present in the cache)
    if not os.path.exists(os.path.join(CACHE, 'pca', pca_identifier)):
        compute_model_pcs(model_name=MODEL_NAME,
                          features=features,
                          layers=PCA_LAYERS,
                          dataset=DATASET,
                          components=TOTAL_COMPONENTS,
                          device=DEVICE,
                          batch_size=BATCH_SIZE)

    # project activations onto the computed PCs and predict neural data from them
    for n_components in N_COMPONENTS:

        activations_identifier = load_full_identifier(model_name=MODEL_NAME,
                                                      features=features,
                                                      layers=PCA_LAYERS,
                                                      dataset=DATASET,
                                                      principal_components=n_components)

        logging.info(f"Model: {activations_identifier}, Components = {n_components}, Region: {REGION}")

        # NOTE(review): the model is re-loaded for every n_components; it could be
        # hoisted to the outer loop if Activations does not mutate it — confirm.
        model = load_model(model_name=MODEL_NAME,
                           features=features,
                           layers=PCA_LAYERS,
                           device=DEVICE)

        # compute activations and project onto the leading n_components PCs
        Activations(model=model,
                    dataset=DATASET,
                    pca_iden=pca_identifier,
                    n_components=n_components,
                    batch_size=BATCH_SIZE,
                    device=DEVICE).get_array(activations_identifier)

        # predict neural data in a cross-validated manner using model PCs
        NeuralRegression(activations_identifier=activations_identifier,
                         dataset=DATASET,
                         region=REGION,
                         device=DEVICE).predict_data()


# get a bootstrap distribution of r-values between predicted and actual neural responses.
# Uses the same PC counts as the sweep above (was hard-coded to [1, 10], which silently
# diverged from the sweep whenever TOTAL_COMPONENTS changed).
get_bootstrap_rvalues(model_name=MODEL_NAME,
                      features=PCA_FEATURES,
                      layers=PCA_LAYERS,
                      principal_components=N_COMPONENTS,
                      dataset=DATASET,
                      subjects=cfg[DATASET]['subjects'],
                      region=REGION,
                      all_sampled_indices=ALL_SAMPLED_INDICES,
                      device=DEVICE,
                      file_name='pca')

2024-06-27 15:20:42,663 - INFO - Model: expansion_features=3_layers=5_dataset=majajhong_demo_principal_components=1, Components = 1, Region: IT
2024-06-27 15:20:43,678 - INFO - Activations already exist
2024-06-27 15:20:43,678 - INFO - Predicting neural data from model activations...
100%|██████████| 2/2 [00:00<00:00, 14614.30it/s]
  0%|          | 0/2 [00:00<?, ?it/s]
2024-06-27 15:20:43,683 - INFO - Model: expansion_features=3_layers=5_dataset=majajhong_demo_principal_components=10, Components = 10, Region: IT
2024-06-27 15:20:44,615 - INFO - Loading processed images...
2024-06-27 15:20:44,647 - INFO - Extracting activations...
100%|██████████| 1/1 [00:03<00:00,  3.21s/it]
2024-06-27 15:20:47,862 - INFO - Model activations are saved in cache
2024-06-27 15:20:47,945 - INFO - Predicting neural data from model activations...
100%|██████████| 2/2 [00:00<00:00, 13.71it/s]
100%|██████████| 2/2 [00:00<00:00, 13.50it/s]
2024-06-27 15:20:48,095 - INFO - Computing bootstrap distribution of r-v

/home/atlask/data/atlas/.cache expansion_features=3_layers=5_dataset=majajhong_demo_principal_components=10
/home/atlask/data/atlas/.cache expansion_features=3_layers=5_dataset=majajhong_demo_principal_components=10
/home/atlask/data/atlas/.cache expansion_features=3_layers=5_dataset=majajhong_demo_principal_components=10
/home/atlask/data/atlas/.cache expansion_features=3_layers=5_dataset=majajhong_demo_principal_components=10


100%|██████████| 2/2 [00:03<00:00,  1.78s/it]
100%|██████████| 2/2 [00:03<00:00,  1.65s/it]
2024-06-27 15:20:54,958 - INFO - Bootstrap r-values are now saved in cache


In [5]:
# Load and display the bootstrap scores written by get_bootstrap_rvalues(file_name='pca').
# The file name is derived from DATASET and the configured region instead of being
# hard-coded, so changing DATASET in the config cell does not load stale results.
# NOTE(review): assumes results are saved as f"{file_name}_{dataset}_{region}.pkl" and that
# cfg[DATASET]['regions'] is a single string (matches the previously hard-coded
# 'pca_majajhong_demo_IT.pkl') — confirm against get_bootstrap_rvalues.
file_path = os.path.join(RESULTS_PATH, f"pca_{DATASET}_{cfg[DATASET]['regions']}.pkl")
# pickle.load can execute arbitrary code: only load result files produced by this project.
with open(file_path, 'rb') as file:
    pca_expansion_score = pickle.load(file)
display(pca_expansion_score)

Unnamed: 0,model,features,pcs,init_type,nl_type,score,lower,upper
0,expansion_features=3_layers=5_dataset=majajhon...,3,1,kaiming_uniform,relu,tensor(-0.0860),-0.269513,0.122117
1,expansion_features=3_layers=5_dataset=majajhon...,3,10,kaiming_uniform,relu,tensor(-0.0747),-0.266319,0.139583
