In [1]:
%load_ext autoreload
%autoreload 2


In [8]:
import os
import sys
sys.path.append("/n/home12/binxuwang/Github/Closed-loop-visual-insilico")
import timm
import torch
import torch as th
import torch.nn as nn
from torchvision.models.feature_extraction import create_feature_extractor
from tqdm.auto import tqdm
from os.path import join
import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
from horama import maco, plot_maco
import torchvision.transforms as T
from torchvision.transforms import ToPILImage, ToTensor, Normalize, Resize
from torchvision.models import resnet50
from circuit_toolkit.CNN_scorers import TorchScorer
from circuit_toolkit.GAN_utils import upconvGAN, Caffenet
from circuit_toolkit.plot_utils import to_imgrid, show_imgrid, save_imgrid, saveallforms
from circuit_toolkit.layer_hook_utils import featureFetcher_module, featureFetcher, get_module_names
from circuit_toolkit.dataset_utils import ImagePathDataset
from torch.utils.data import DataLoader
from neural_regress.regress_lib import sweep_regressors, perform_regression_sweeplayer_RidgeCV, perform_regression_sweeplayer
from neural_regress.sklearn_torchify_lib import SRP_torch, PCA_torch, LinearRegression_torch, SpatialAvg_torch, LinearLayer_from_sklearn

import seaborn as sns
import pandas as pd

import sklearn
from sklearn.pipeline import make_pipeline
from sklearn.random_projection import SparseRandomProjection, GaussianRandomProjection
from sklearn.linear_model import Ridge, Lasso
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import RidgeCV
from sklearn.decomposition import PCA
from sklearn.kernel_ridge import KernelRidge

#%% Utility Functions

def load_neural_data(data_path, subject_id, stimroot):
    """Load one subject's neural responses, metadata, and stimulus image paths.

    Parameters
    ----------
    data_path : str
        Path to the data file, read with ``core.data_utils.load_from_hdf5``.
    subject_id : str
        Top-level key of the subject inside the loaded data.
    stimroot : str
        Directory holding the stimulus images; each stimulus name is appended
        to it (``f"{stimroot}/{name}"``) to build ``image_fps``.

    Returns
    -------
    dict
        ``brain_area``, ``ncsnr``, ``reliability``: from ``neuron_metadata``.
        ``stim_pos``, ``stim_size``: from ``trials`` (``stimulus_pos_deg`` /
        ``stimulus_size_pix``).
        ``resp_mat``: ``repavg/response_peak``.
        ``resp_temp_mat``: ``repavg/response_temporal``.
        ``image_fps``: list of str paths, one per ``repavg/stimulus_name``
        entry, in the same order.
    """
    from core.data_utils import load_from_hdf5  # project-local; imported lazily
    data = load_from_hdf5(data_path)
    subject = data[subject_id]
    meta = subject["neuron_metadata"]
    trials = subject["trials"]
    repavg = subject["repavg"]

    # Depending on how the strings were written / how the reader decodes them,
    # stimulus names can arrive as bytes or already as str; accept both.
    image_fps = [
        f"{stimroot}/{name.decode('utf8') if isinstance(name, bytes) else name}"
        for name in repavg["stimulus_name"]
    ]
    return {
        'brain_area': meta["brain_area"],
        'ncsnr': meta["ncsnr"],
        'reliability': meta["reliability"],
        'stim_pos': trials['stimulus_pos_deg'],
        'stim_size': trials['stimulus_size_pix'],
        'resp_mat': repavg['response_peak'],  # Peak, avg response
        'resp_temp_mat': repavg['response_temporal'],  # Temporal response
        'image_fps': image_fps,
    }


@th.no_grad()
def record_features(model, fetcher, dataset, batch_size=20, device="cuda"):
    """Run ``dataset`` through ``model`` and gather what ``fetcher`` captured.

    The model is moved to ``device`` and switched to eval mode. After each
    forward pass, every key present in ``fetcher.activations`` is read through
    ``fetcher[key]``, copied to CPU and stored. Once all batches are done, the
    per-batch tensors of each key are concatenated along dim 0 and the
    resulting shape is printed.

    Parameters
    ----------
    model : torch.nn.Module
    fetcher : object exposing ``.activations`` (with ``.keys()``) and ``__getitem__``
    dataset : dataset whose items unpack as ``(image, _)``
    batch_size : int
    device : str or torch.device

    Returns
    -------
    dict
        Activation key -> CPU tensor with all batches stacked on dim 0.
    """
    batches = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    model.to(device).eval()
    per_batch = {}
    for images, _unused in tqdm(batches):
        model(images.to(device))
        for name in fetcher.activations.keys():
            per_batch.setdefault(name, []).append(fetcher[name].cpu())
    for name, chunks in per_batch.items():
        stacked = th.cat(chunks, dim=0)
        per_batch[name] = stacked
        print(name, stacked.shape)
    return per_batch


def extract_features(model, dataset, layer_name="last_block", batch_size=20, device="cuda", module=None):
    """Extract the activations of one submodule of ``model`` over ``dataset``.

    Parameters
    ----------
    model : torch.nn.Module
        Network to run; it is moved to ``device`` and set to eval mode by
        ``record_features``.
    dataset : dataset whose items unpack as ``(image, _)`` (as consumed by
        ``record_features``).
    layer_name : str
        Name under which the submodule is registered with the fetcher.
    batch_size, device :
        Forwarded to ``record_features``.
    module : torch.nn.Module, optional
        Submodule to hook. Defaults to ``model.layer4``, the last stage of a
        torchvision ResNet; pass another submodule for other architectures.

    Returns
    -------
    dict
        Output of ``record_features`` (presumably ``{layer_name: tensor}``;
        keys come from the fetcher's ``activations``).
    """
    if module is None:
        module = model.layer4  # ResNet-specific default, kept for backward compatibility
    fetcher = featureFetcher_module()
    fetcher.record_module(module, layer_name)
    try:
        return record_features(model, fetcher, dataset, batch_size=batch_size, device=device)
    finally:
        # Run cleanup even when recording fails, so whatever the fetcher
        # registered on the model (presumably a forward hook) is not left behind.
        fetcher.cleanup()


def check_gradient(objective_fn, device="cuda", input_size=(3, 224, 224)):
    """Check that gradients flow from ``objective_fn``'s output back to its input.

    A random probe batch of shape ``(1, *input_size)`` is created on ``device``,
    passed through ``objective_fn``, and ``.mean().backward()`` is called on the
    result. The output shape is printed.

    Parameters
    ----------
    objective_fn : callable
        Maps an image tensor to a tensor of responses.
    device : str or torch.device
        Where the probe image is created (default ``"cuda"``, as before).
    input_size : tuple of int
        Per-image shape ``(C, H, W)`` of the probe (default ``(3, 224, 224)``).

    Raises
    ------
    AssertionError
        If the probe image received no gradient. Raised explicitly rather than
        via ``assert`` so the check is not stripped under ``python -O``.
    """
    img_opt = th.randn(1, *input_size, device=device, requires_grad=True)
    resp = objective_fn(img_opt)
    resp.mean().backward()
    print(resp.shape)
    if img_opt.grad is None:
        raise AssertionError("objective_fn output is not differentiable w.r.t. its input image")
    

def visualize_results(img_col, D2, ncols=5):
    """Plot each optimized image in a grid, titled with its unit index and score.

    Parameters
    ----------
    img_col : sequence
        One entry per unit; entry ``i`` is drawn with
        ``plot_maco(entry[0], entry[1])`` on its own axes.
    D2 : sequence of float
        Per-unit score shown in each title as ``R2=...``; must be at least as
        long as ``img_col``.
    ncols : int
        Panels per row (default 5, as before). Unused trailing panels are
        turned off.

    Returns
    -------
    matplotlib.figure.Figure or None
        The created figure, or ``None`` when ``img_col`` is empty.
    """
    n_img = len(img_col)
    if n_img == 0:
        # plt.subplots(0, ncols) raises; there is nothing to draw.
        return None
    nrows = -(-n_img // ncols)  # ceiling division
    # squeeze=False keeps `axs` a 2-D array, so `.flat` also works for a 1x1 grid.
    figh, axs = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 5), squeeze=False)
    for i, ax in enumerate(axs.flat):
        plt.sca(ax)  # plot_maco draws on the current axes
        if i >= n_img:
            ax.axis("off")
            continue
        tup = img_col[i]
        plot_maco(tup[0], tup[1])
        plt.title(f"Unit {i} R2={D2[i]:.2f}")
    return figh


def _imagenet_eval_transform(size=(224, 224)):
    """Return the ToTensor -> Resize -> ImageNet-Normalize pipeline shared by the ResNet variants."""
    return T.Compose([
        T.ToTensor(),
        T.Resize(size),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


_ROBUST_RESNET50_CKPT = "/n/home12/binxuwang/Github/Closed-loop-visual-insilico/checkpoints/imagenet_linf_8_pure.pt"


def load_model_transform(modelname, device="cuda", robust_ckpt_path=_ROBUST_RESNET50_CKPT):
    """Build a frozen, eval-mode model and its matching input transform.

    Parameters
    ----------
    modelname : str
        One of ``"resnet50_robust"`` (torchvision ResNet-50 with weights loaded
        from ``robust_ckpt_path``), ``"resnet50"`` (torchvision pretrained),
        ``"resnet50_clip"`` (visual trunk of OpenAI CLIP RN50, with CLIP's own
        preprocessing), or ``"resnet50_dino"`` (DINO ResNet-50 via torch.hub).
    device : str or torch.device
        Device the model is moved to.
    robust_ckpt_path : str
        State-dict checkpoint used for ``"resnet50_robust"``.

    Returns
    -------
    (torch.nn.Module, callable)
        The model (eval mode, ``requires_grad`` disabled on all parameters) and
        the image transform to apply before feeding it.

    Raises
    ------
    ValueError
        If ``modelname`` is not one of the names above.
    """
    if modelname == "resnet50_robust":
        model = resnet50(pretrained=False)
        # map_location="cpu" lets a checkpoint saved from any device load here;
        # the model is moved to `device` at the end.
        model.load_state_dict(th.load(robust_ckpt_path, map_location="cpu"))
        transforms_pipeline = _imagenet_eval_transform()
    elif modelname == "resnet50":
        model = resnet50(pretrained=True)
        transforms_pipeline = _imagenet_eval_transform()
    elif modelname == "resnet50_clip":
        import clip
        model_clip, preprocess = clip.load('RN50', device=device)
        # NOTE(review): clip.load may keep half-precision weights on CUDA, in
        # which case the recorded activations will be fp16 unlike the other
        # models — confirm downstream regression handles that.
        model = model_clip.visual
        transforms_pipeline = preprocess
    elif modelname == "resnet50_dino":
        # https://github.com/facebookresearch/dino
        model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
        transforms_pipeline = _imagenet_eval_transform()
    else:
        raise ValueError(f"Unknown model: {modelname}")
    model = model.to(device).eval()
    model.requires_grad_(False)
    return model, transforms_pipeline

In [4]:
# Layer sweep: for each backbone, record features of every Bottleneck block in
# layer3/layer4, fit RidgeCV regressions onto the reliable neural channels, and
# save the scores, fitted models and a train/test R2 summary figure.
dataroot = r"/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/Projects/VVS_Accentuation"
data_path = join(dataroot, "nsd_shared1000_6monkeys_2024.h5")
stimroot = join(dataroot, "shared1000")
subject_id = 'paul_240713-240710'  # 'paul_20240713-20240710'
batch_size = 96
device = "cuda"
reliability_thresh = 0.7
# Ridge penalties searched by RidgeCV for every (layer, dimred) combination.
ridge_alphas = [1E-4, 1E-3, 1E-2, 1E-1, 1, 10, 100, 1E3, 1E4, 1E5, 1E6, 1E7, 1E8, 1E9]

data_dict = load_neural_data(data_path, subject_id, stimroot)
image_fps = data_dict['image_fps']
resp_mat = data_dict['resp_mat']
reliability = data_dict['reliability']
ncsnr = data_dict['ncsnr']
figdir = join(dataroot, "model_outputs", subject_id)
os.makedirs(figdir, exist_ok=True)

# The regression targets do not depend on the model: select reliable channels once.
thresh = reliability_thresh
chan_mask = reliability > thresh
resp_mat_sel = resp_mat[:, chan_mask]

for modelname in ["resnet50_dino", "resnet50_robust", "resnet50", "resnet50_clip", ]:
    model, transforms_pipeline = load_model_transform(modelname, device=device)
    dataset = ImagePathDataset(image_fps, scores=resp_mat, transform=transforms_pipeline)

    # Hook every Bottleneck block of layer3 / layer4.
    fetcher = featureFetcher(model, input_size=(3, 224, 224), print_module=False)
    module_names = [name for name in fetcher.module_names.values()
                    if "Bottleneck" in name and ("layer4" in name or "layer3" in name)]
    print(module_names)
    for name in module_names:
        fetcher.record(name, store_device='cpu', ingraph=False)
    try:
        feat_dict_lyrswp = record_features(model, fetcher, dataset, batch_size=batch_size, device=device)
    finally:
        # Remove the hooks even if recording fails.
        fetcher.cleanup()

    print(f"Fitting models for reliable channels > {thresh} N={chan_mask.sum()}")
    # Only the last 9 hooked blocks are fitted (later half of the sweep, ~20 min).
    result_df_lyrswp_RidgeCV, fit_models_lyrswp_RidgeCV, Xdict_lyrswp_RidgeCV, Xtfmer_lyrswp_RidgeCV = perform_regression_sweeplayer_RidgeCV(
        feat_dict_lyrswp, resp_mat_sel,
        layer_names=module_names[-9:], dimred_list=["srp", "pca1000"],
        alpha_list=ridge_alphas,
        verbose=True)
    result_df_lyrswp_RidgeCV.to_csv(join(figdir, f"{subject_id}_{modelname}_sweep_regressors_highreliab_layers_sweep_RidgeCV.csv"))
    th.save(fit_models_lyrswp_RidgeCV, join(figdir, f"{subject_id}_{modelname}_sweep_regressors_highreliab_layers_fitmodels_RidgeCV.pth"))
    # NOTE: written with th.save despite the .pkl extension; read it back with th.load.
    th.save(Xtfmer_lyrswp_RidgeCV, join(figdir, f"{subject_id}_{modelname}_sweep_regressors_highreliab_layers_Xtfmer_RidgeCV.pkl"))

    # Split the "<layer>_<dimred>" index level into separate columns for plotting.
    result_df_lyrswp_formatted = result_df_lyrswp_RidgeCV.reset_index()
    result_df_lyrswp_formatted.rename(columns={"level_0": "layer_dimred", "level_1": "regressor", }, inplace=True)
    result_df_lyrswp_formatted["layer"] = result_df_lyrswp_formatted["layer_dimred"].apply(lambda x: x.split("_")[0])
    result_df_lyrswp_formatted["dimred"] = result_df_lyrswp_formatted["layer_dimred"].apply(lambda x: x.split("_")[-1])
    result_df_lyrswp_formatted.to_csv(join(figdir, f"{subject_id}_{modelname}_sweep_regressors_highreliab_layers_sweep_RidgeCV_formatted.csv"))

    figh, axs = plt.subplots(1, 2, figsize=(10, 5))
    for ax, score_col, panel_title in zip(axs, ["train_score", "test_score"], ["Training R2", "Test R2"]):
        plt.sca(ax)
        sns.lineplot(data=result_df_lyrswp_formatted, x="layer",
                     y=score_col, style="regressor", hue="dimred", ax=ax, marker="o")
        plt.xticks(rotation=45)
        # Shorten long module names on the axis, e.g. "Bottleneck" -> "B", ".layer" -> "L".
        xticklabels = [label.get_text().replace("Bottleneck", "B").replace(".layer", "L")
                       for label in ax.get_xticklabels()]
        plt.xticks(ticks=range(len(xticklabels)), labels=xticklabels, rotation=45)
        plt.title(panel_title)

    # Add the suptitle before tight_layout so the layout reserves room for it.
    plt.suptitle(f"{subject_id} {modelname} layer sweep")
    plt.tight_layout()
    saveallforms(figdir, f"{subject_id}_{modelname}_layer_sweep_GridCV_synopisis")
    plt.show()

The (227, 227) setting is overwritten by the size in custom transform


In [10]:
# Show which fields load_neural_data returned.
list(data_dict)

['brain_area',
 'ncsnr',
 'reliability',
 'stim_pos',
 'stim_size',
 'resp_mat',
 'resp_temp_mat',
 'image_fps']