In [2]:
import os, sys, yaml, time
import numpy as np
import joblib
import torch
from sklearn.cross_decomposition import CCA
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)
def read_yaml(): 
    ENV = os.getenv("MY_ENV", "dev") 
    with open("../../config.yaml", "r") as f: 
        config = yaml.safe_load(f) 
    paths = config[ENV]["paths"] 
    return paths 
paths = read_yaml()
sys.path.append(paths["src_path"])
from datetime import datetime
from dim_redu_anns.utils import get_relevant_output_layers

In [5]:
def CCA_loop_between_mod(model_names, pooling, num_components, pca_option, res_path):
    layer_names = [get_relevant_output_layers(m) for m in model_names]
    cca_dir = f"{res_path}/cca_{model_names[0]}_vs_{model_names[1]}_{pooling}"
    os.makedirs(cca_dir, exist_ok=True)
    layers_RSA = np.zeros((len(layer_names[0]),len(layer_names[1])))
    for layer_idx1 in range(len(layer_names[0])):
        target_layer1 = layer_names[0][layer_idx1]
        feats_path1 = f"{res_path}/imagenet_val_{model_names[0]}_{target_layer1}_{pooling}_features.pkl"
        print(
            datetime.now().strftime("%H:%M:%S"),
            f"starting loading {target_layer1}",
            flush=True
        )
        all_acts1 = joblib.load(feats_path1)
        print(
            datetime.now().strftime("%H:%M:%S"),
            f"finished loading {target_layer1}",
            flush=True
        )
        if pca_option==True:
            first_projection = False
            if all_acts1.shape[1] > 1000:
                PCs_path1 =f"{res_path}/imagenet_val_{model_names[0]}_{target_layer1}_{pooling}_pca_model_1000_PCs.pkl"
                print(
                    datetime.now().strftime("%H:%M:%S"),
                    f"starting loading PC1",
                    flush=True
                )
                PCs1 = joblib.load(PCs_path1)
                print(
                    datetime.now().strftime("%H:%M:%S"),
                    f"starting loading PC1",
                    flush=True
                )
                all_acts1 = all_acts1 @ PCs1.components_.T
                print(
                     datetime.now().strftime("%H:%M:%S"),
                     f"finished backprojecting in PCs1",
                     flush=True
                )
                first_projection = True

        for layer_idx2 in range(len(layer_names[1])):
            target_layer2 = layer_names[1][layer_idx2]
            print(datetime.now().strftime("%H:%M:%S"), f"starting layers {target_layer1} vs {target_layer2}")
            feats_path2 = f"{res_path}/imagenet_val_{model_names[1]}_{target_layer2}_{pooling}_features.pkl"
            all_acts2 = joblib.load(feats_path2)
            save_path = f"{cca_dir}/cca_{model_names[0]}_vs_{model_names[1]}_{num_components}_components_{target_layer1}_vs_{target_layer2}.pkl"
            if pca_option==True:
                second_projection = False
                if all_acts2.shape[1] > 1000:
                    PCs_path2 =f"{res_path}/imagenet_val_{model_names[1]}_{target_layer2}_{pooling}_pca_model_1000_PCs.pkl"
                    PCs2 = joblib.load(PCs_path2)
                    all_acts2 = all_acts2 @ PCs2.components_.T
                    second_projection = True

                if first_projection == True or second_projection == True:
                    save_path = f"{cca_dir}/cca_{model_names[0]}_vs_{model_names[1]}_{num_components}_components_pca_{target_layer1}_vs_{target_layer2}.pkl"
            if os.path.exists(save_path):
                print(
                    datetime.now().strftime("%H:%M:%S"),
                    f"CCA already exists for {target_layer1} vs {target_layer2}  at {save_path}",
                    flush=True
                )
                weights_dict = joblib.load(save_path)
                d1 = all_acts1 @ weights_dict["W1"]
                d2 = all_acts2 @ weights_dict["W2"]
                coefs_CCA = np.array([
                    np.corrcoef(d1[:, i], d2[:, i])[0, 1] for i in range(d1.shape[1])
                ])
                layers_RSA[layer_idx1, layer_idx2] = np.mean(coefs_CCA)
                print(datetime.now().strftime("%H:%M:%S"), f"{target_layer1} vs {target_layer2} corr {np.round(np.mean(coefs_CCA), 3)}", flush=True)

            else:
                print(
                     datetime.now().strftime("%H:%M:%S"),
                     f"starting CCA",
                     flush=True
                )
                cca = CCA(n_components = num_components)
                cca.fit(all_acts1, all_acts2)
                print(
                     datetime.now().strftime("%H:%M:%S"),
                     f"finished CCA fit",
                     flush=True
                )
                weights_dict = {}
                weights_dict["W1"] = cca.x_weights_  # shape: (n_features1, n_components)
                weights_dict["W2"] = cca.y_weights_  # shape: (n_features2, n_components)

                # 3. Project the data manually (optional, equivalent to fit_transform)
                d1 = all_acts1 @ weights_dict["W1"]
                d2 = all_acts2 @ weights_dict["W2"]
                coefs_CCA = np.array([
                    np.corrcoef(d1[:, i], d2[:, i])[0, 1] for i in range(d1.shape[1])
                ])
                weights_dict["coefs"] = coefs_CCA
                joblib.dump(weights_dict, save_path)
                layers_RSA[layer_idx1, layer_idx2] = np.mean(coefs_CCA)
                print(datetime.now().strftime("%H:%M:%S"), f"{target_layer1} vs {target_layer2} corr {np.round(np.mean(coefs_CCA), 3)}", flush=True)
    csv_save_path = f"{cca_dir}/{model_names[0]}_vs_{model_names[1]}_similarity_layers.csv"
    if pca_option == True:
        csv_save_path = f"{cca_dir}/{model_names[0]}_vs_{model_names[1]}_similarity_layers_pca.csv"
    np.savetxt(csv_save_path, layers_RSA, delimiter=",")


In [None]:
model_names = ["alexnet", "alexnet"]
pooling = "maxpool"
num_components = 50
pca_option = True
res_path = paths["results_path"]
CCA_loop_between_mod(model_names, pooling, num_components, pca_option, res_path)

09:55:21 starting loading features.0
09:55:21 finished loading features.0
09:55:21 starting layers features.0 vs features.0
inside outer if
09:55:21 starting CCA
09:55:28 finished CCA fit
09:55:28 features.0 vs features.0 corr 1.0
09:55:28 starting layers features.0 vs features.4
inside outer if
09:55:28 starting CCA
09:56:14 finished CCA fit
09:56:14 features.0 vs features.4 corr 0.412
09:56:14 starting layers features.0 vs features.7
inside outer if
09:56:14 starting CCA
09:57:42 finished CCA fit
09:57:42 features.0 vs features.7 corr 0.281
09:57:42 starting layers features.0 vs features.9
inside outer if
09:57:42 starting CCA
09:58:38 finished CCA fit
09:58:38 features.0 vs features.9 corr 0.207
09:58:38 starting layers features.0 vs features.11
inside outer if
09:58:38 starting CCA
09:59:31 finished CCA fit
09:59:31 features.0 vs features.11 corr 0.186
09:59:31 starting layers features.0 vs classifier.2
inside outer if
09:59:46 starting CCA
10:05:52 finished CCA fit
10:05:52 featur



10:17:08 finished CCA fit
10:17:08 features.4 vs features.7 corr 0.587
10:17:08 starting layers features.4 vs features.9
inside outer if
10:17:08 starting CCA




10:18:54 finished CCA fit
10:18:54 features.4 vs features.9 corr 0.432
10:18:54 starting layers features.4 vs features.11
inside outer if
10:18:54 starting CCA
10:20:41 finished CCA fit
10:20:42 features.4 vs features.11 corr 0.461
10:20:42 starting layers features.4 vs classifier.2
inside outer if
10:20:45 starting CCA
10:29:01 finished CCA fit
10:29:02 features.4 vs classifier.2 corr 0.503
10:29:02 starting layers features.4 vs classifier.5
inside outer if
10:29:05 starting CCA
10:36:58 finished CCA fit
10:36:58 features.4 vs classifier.5 corr 0.486
10:36:58 starting loading features.7
10:36:58 finished loading features.7
10:36:58 starting layers features.7 vs features.0
inside outer if
10:36:58 starting CCA
10:38:24 finished CCA fit
10:38:24 features.7 vs features.0 corr 0.282
10:38:24 starting layers features.7 vs features.4
inside outer if
10:38:24 starting CCA




10:41:18 finished CCA fit
10:41:18 features.7 vs features.4 corr 0.587
10:41:18 starting layers features.7 vs features.7
inside outer if
10:41:18 starting CCA
10:43:04 finished CCA fit
10:43:04 features.7 vs features.7 corr 1.0
10:43:04 starting layers features.7 vs features.9
inside outer if
10:43:04 starting CCA




10:46:27 finished CCA fit
10:46:27 features.7 vs features.9 corr 0.634
10:46:27 starting layers features.7 vs features.11
inside outer if
10:46:27 starting CCA




10:49:56 finished CCA fit
10:49:56 features.7 vs features.11 corr 0.647
10:49:56 starting layers features.7 vs classifier.2
inside outer if
10:50:00 starting CCA
