In [None]:
import os
import sys
module_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(module_path)

import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from NeuroPredictor.FeatExtractor import (
    TimmFeatureExtractor,
    TorchvisionFeatureExtractor,
    CLIPFeatureExtractor,
    OpenCLIPFeatureExtractor
)
from NeuroPredictor.Encoder import Encoder

In [None]:
# ==============================
#      Hyperparameters
# ==============================
# Use the GPU when one is visible to torch; otherwise fall back to the CPU so
# the notebook still runs (slowly) on machines without CUDA.
device        = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size    = 32

# Training set (images and their corresponding neural responses)
train_img_dir = r"D:\Analysis\NSD_Alignment\NSD_shared1000"
train_resp    = r"D:\Analysis\Ephys_data_Face.npz"   # shape (N_train, n_neurons)
# Test set (images used for validation after MEI generation)
test_img_dir  = r"D:\Analysis\results\figures_face_timm"
save_dir = r"D:\Analysis\results"

# Primary encoder (the one that was used when generating the MEIs)
PRIMARY_ENCODER_CFG = {
    'model_type': 'timm',             # 'timm', 'torchvision', 'clip', 'open_clip'
    'model_name': 'vit_base_patch16_clip_224.laion2b',
    'model_layer': 'blocks.8'
}

# Secondary encoders (the other models used for comparison).
# Each entry needs a unique 'name'; it is used as the dictionary key for the
# features/predictions and as part of the saved figure's file name.
SECONDARY_ENCODERS = [
    {
        'name': 'ResNet50_layer4',
        'model_type': 'torchvision',
        'model_name': 'resnet50',
        'model_layer': 'layer4.2.relu'
    },
    {
        'name': 'CLIP_RN50_layer4',
        'model_type': 'clip',
        'model_name': 'RN50',
        'model_layer': 'layer4.2.relu3'
    },
    {
        'name': 'OpenCLIP_ViT32_layer9',
        'model_type': 'open_clip',
        'model_name': 'ViT-B/32',
        'model_layer': 'transformer.resblocks.9'
    },
]

In [None]:
# ==============================
#       Utility Classes
# ==============================
class ImageFolderDataset(Dataset):
    """Dataset over the image files located directly inside one folder.

    Files are selected by extension (case-insensitive) and kept in sorted
    path order, so index ``i`` always refers to the same file for a given
    folder content. Sub-folders are not traversed.

    Args:
        folder: Directory to list.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
    """

    # Every suffix carries its leading dot so that only real extensions match
    # (a bare 'jpg' would also accept a file that is merely named 'rawjpg').
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def __init__(self, folder, transform=None):
        self.paths = sorted(
            os.path.join(folder, fname)
            for fname in os.listdir(folder)
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files found."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one was given."""
        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the with-block exits.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

# ==============================
#   Backbone Factory & Config
# ==============================
# Per-backbone-family settings, looked up by a config's 'model_type':
#   'class'       - extractor class to instantiate
#   'init_kwargs' - constructor kwargs template ('model_name' is supplied later)
#   'hook_arg'    - keyword under which the layer name is passed when the
#                   extractor is called
#   'transform'   - fixed preprocessing pipeline, or None to use the
#                   extractor's own get_preprocess()
BACKBONE_CONFIG = {
    'timm': {
        'class': TimmFeatureExtractor,
        'init_kwargs': {'model_name': None},
        'hook_arg': 'layer_or_names',
        'embedtype': 'spatial',
        # NOTE(review): these mean/std values look like the OpenAI CLIP
        # normalization constants, which suits the CLIP-pretrained timm ViT
        # configured as primary encoder - confirm before using other timm models.
        'transform': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711],
            ),
        ]),
    },
    # The remaining families share one layout; the comprehension builds a
    # separate 'init_kwargs' dict for each of them.
    **{
        family: {
            'class': extractor_cls,
            'init_kwargs': {'model_name': None},
            'hook_arg': 'module_names',
            'transform': None,
        }
        for family, extractor_cls in (
            ('torchvision', TorchvisionFeatureExtractor),
            ('clip', CLIPFeatureExtractor),
            ('open_clip', OpenCLIPFeatureExtractor),
        )
    },
}

def make_extractor(cfg):
    """Instantiate a frozen feature extractor for one encoder config.

    Args:
        cfg: Mapping with 'model_type' (a key of BACKBONE_CONFIG),
            'model_name' and 'model_layer'.

    Returns:
        Tuple ``(extractor, hook_args, transform)``: the extractor moved to
        ``device`` in eval mode with gradients disabled, the keyword
        arguments that select the layer when calling the extractor, and the
        image preprocessing to use (the configured transform, or the
        extractor's ``get_preprocess()`` when none is configured).

    Raises:
        KeyError: If ``cfg['model_type']`` is not in BACKBONE_CONFIG or a
            required key is missing from ``cfg``.
    """
    bc = BACKBONE_CONFIG[cfg['model_type']]
    # Build the constructor kwargs on a copy: writing 'model_name' into the
    # shared BACKBONE_CONFIG template would leak the last-used model name
    # into every later lookup of that backbone family.
    init_kwargs = dict(bc['init_kwargs'], model_name=cfg['model_name'])
    extractor = bc['class'](**init_kwargs)
    extractor.to(device).eval()
    for p in extractor.parameters():
        p.requires_grad = False
    hook_args = {bc['hook_arg']: cfg['model_layer']}
    transform = bc['transform']
    if transform is None:
        transform = extractor.get_preprocess()
    return extractor, hook_args, transform

def extract_all_features(extractor, hook_args, transform, img_dir):
    """Run ``extractor`` over every image in ``img_dir`` and stack the outputs.

    Images are read through ImageFolderDataset in its (sorted) order and
    processed in batches of ``batch_size`` on ``device`` without gradient
    tracking.

    Args:
        extractor: Callable invoked as ``extractor(imgs, **hook_args)``; its
            result must support ``.cpu().numpy()``.
        hook_args: Keyword arguments forwarded to every extractor call.
        transform: Preprocessing applied to each image by the dataset.
        img_dir: Folder containing the images.

    Returns:
        numpy.ndarray with the per-batch outputs concatenated along axis 0.

    Raises:
        ValueError: If ``img_dir`` contains no image files.
    """
    ds = ImageFolderDataset(img_dir, transform=transform)
    if len(ds) == 0:
        # Without this guard the failure surfaces later as numpy's generic
        # "need at least one array to concatenate".
        raise ValueError(f"No image files found in {img_dir!r}; nothing to extract.")
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
    feats = []
    with torch.no_grad():
        for imgs in loader:
            out = extractor(imgs.to(device), **hook_args)
            feats.append(out.cpu().numpy())
    return np.concatenate(feats, axis=0)

In [None]:
def _flatten_features(feats):
    """Collapse all non-batch axes so every sample becomes one feature vector."""
    return feats.reshape(feats.shape[0], -1) if feats.ndim > 2 else feats


# Single source of truth for the key under which the primary encoder's
# features and predictions are stored.
PRIMARY_KEY = PRIMARY_ENCODER_CFG.get('name', 'primary')

# Read the responses inside a with-block so the .npz archive is closed again.
with np.load(train_resp) as resp_file:
    y_train = resp_file['data']      # shape (N_train, n_neurons)

# --- Build primary encoder and fit ---
primary_cfg = PRIMARY_ENCODER_CFG
ext_p, hook_p, trans_p = make_extractor(primary_cfg)
X_train = _flatten_features(
    extract_all_features(ext_p, hook_p, trans_p, train_img_dir)
)
# Images are paired with response rows purely by position, so a count
# mismatch means the folder and the response file do not belong together.
if X_train.shape[0] != y_train.shape[0]:
    raise ValueError(
        f"{X_train.shape[0]} training images but {y_train.shape[0]} response rows"
    )
enc_primary = Encoder(method='Ridge', cv_folds=5)
enc_primary.fit(X_train, y_train)
y_pred_primary = enc_primary.predict(X_train)

# --- Primary encoder on the test images (reuses the already loaded model) ---
test_feats = {
    PRIMARY_KEY: _flatten_features(
        extract_all_features(ext_p, hook_p, trans_p, test_img_dir)
    )
}
preds = {PRIMARY_KEY: enc_primary.predict(test_feats[PRIMARY_KEY])}
print(PRIMARY_KEY, preds[PRIMARY_KEY].shape)

# --- Secondary encoders: fit on the same responses, predict on test images ---
for enc_cfg in SECONDARY_ENCODERS:
    name = enc_cfg['name']
    # Build each backbone once and use it for both the train and test images.
    ext, hook, trans = make_extractor(enc_cfg)
    Xtr = _flatten_features(extract_all_features(ext, hook, trans, train_img_dir))
    test_feats[name] = _flatten_features(
        extract_all_features(ext, hook, trans, test_img_dir)
    )
    enc = Encoder(method='Ridge', cv_folds=5)
    enc.fit(Xtr, y_train)
    preds[name] = enc.predict(test_feats[name])
    print(name, preds[name].shape)

# --- Visualization ---
# savefig does not create missing directories.
os.makedirs(save_dir, exist_ok=True)
primary_mean = np.mean(preds[PRIMARY_KEY], axis=1)
for name, y_pred in preds.items():
    if name == PRIMARY_KEY:
        continue
    other_mean = np.mean(y_pred, axis=1)
    # The identity line has to cover both axes, not only the y-data range.
    lo = min(primary_mean.min(), other_mean.min())
    hi = max(primary_mean.max(), other_mean.max())

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(primary_mean, other_mean, alpha=0.3, s=5)
    ax.plot([lo, hi], [lo, hi], 'r--')
    ax.set_xlabel('Primary Encoder Prediction')
    ax.set_ylabel(f'{name} Prediction')
    ax.set_title(f'Primary vs {name} on Test MEIs')
    fig.tight_layout()
    save_path = os.path.join(save_dir, f'compare_primary_vs_{name}.png')
    fig.savefig(save_path, dpi=150)
    plt.close(fig)