In [None]:
import os
import sys

# Parent of the notebook's working directory; presumably the repository root,
# so that the local `NeuroPredictor` package imported below can be found.
module_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
print(module_path)
# Guard the append so re-running this cell does not pile up duplicate entries.
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import clip
import timm
import open_clip
import torchvision.models as models

from NeuroPredictor.FeatExtractor import (
    TimmFeatureExtractor,
    TorchvisionFeatureExtractor,
    CLIPFeatureExtractor,
    OpenCLIPFeatureExtractor
)

#### Get available models

In [None]:
# Backbone library to query: 'timm', 'torchvision', 'clip' or 'openclip'.
backbone_type = 'torchvision'
# Substring used to filter the second printed list of model names.
model_filter = 'resnet'

# Map each backbone library to the function that enumerates its models.
model_listers = {
    'timm': timm.list_models,
    'torchvision': models.list_models,
    'clip': clip.available_models,
    'openclip': open_clip.list_models,
}
# Fail loudly on a typo instead of leaving `model_list` undefined/stale.
if backbone_type not in model_listers:
    raise ValueError(
        f"Unknown backbone_type {backbone_type!r}; expected one of {sorted(model_listers)}"
    )
model_list = model_listers[backbone_type]()

print(model_list)
print([m for m in model_list if model_filter in m])

#### Get available layers

In [None]:
backbone_type = 'torchvision'
# Alternative example: backbone_type = 'timm', model_name = 'vit_base_patch16_clip_224.laion2b'
model_name = 'resnet50'

# Map each backbone library to its feature-extractor wrapper class.
extractor_classes = {
    'timm': TimmFeatureExtractor,
    'torchvision': TorchvisionFeatureExtractor,
    'clip': CLIPFeatureExtractor,
    'openclip': OpenCLIPFeatureExtractor,
}
# Fail loudly on a typo instead of silently reusing a `model` from an earlier cell.
if backbone_type not in extractor_classes:
    raise ValueError(
        f"Unknown backbone_type {backbone_type!r}; expected one of {sorted(extractor_classes)}"
    )
model = extractor_classes[backbone_type](model_name=model_name)

print(model.list_hookable_layers())

In [None]:
import warnings

backbone_type = 'torchvision'
model_name = 'resnet50'
# Alternative checkpoint: r"D:\Dataset\robust_weights\resnet50\resnet50_linf_eps2.0.ckpt"
ckpt_path = r"D:\Dataset\ss_weights\resnet50\r50_1x_sk1.pth"

# Map each backbone library to its feature-extractor wrapper class.
extractor_classes = {
    'timm': TimmFeatureExtractor,
    'torchvision': TorchvisionFeatureExtractor,
    'clip': CLIPFeatureExtractor,
    'openclip': OpenCLIPFeatureExtractor,
}
# Fail loudly on a typo instead of silently reusing a `model` from an earlier cell.
if backbone_type not in extractor_classes:
    raise ValueError(
        f"Unknown backbone_type {backbone_type!r}; expected one of {sorted(extractor_classes)}"
    )

extractor_kwargs = {'model_name': model_name}
if backbone_type == 'torchvision':
    extractor_kwargs['ckpt_path'] = ckpt_path
else:
    # NOTE(review): this cell only hands the checkpoint to the torchvision
    # extractor; whether the other wrappers accept `ckpt_path` is not visible here.
    warnings.warn(f"ckpt_path is ignored for backbone_type={backbone_type!r}")
model = extractor_classes[backbone_type](**extractor_kwargs)

print(model.list_hookable_layers())

#### Get per-layer feature shapes

In [None]:
# Extractor inspected below. To inspect the checkpointed ResNet-50 instead,
# use model_name='resnet50' together with ckpt_path=ckpt_path.
extractor = TorchvisionFeatureExtractor(model_name="mobilenet_v3_small")

from collections import Counter

def param_dtype_summary(model):
    """Count a model's parameter tensors grouped by dtype.

    Args:
        model: any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        ``collections.Counter`` mapping ``str(dtype)`` (such as ``'torch.float32'``)
        to the number of parameter tensors with that dtype.
    """
    return Counter(str(param.dtype) for param in model.parameters())
dtype_counts = param_dtype_summary(extractor.model)
first_param_device = next(extractor.model.parameters()).device
print(dtype_counts, first_param_device)

layers = extractor.list_hookable_layers()
# Inspect every hookable layer; narrow this (e.g. ['layer1', 'layer2', 'layer3', 'layer4'])
# for a shorter listing.
selected_layers = layers
shapes = extractor.get_feature_shapes(selected_layers)

for layer_name, layer_shape in shapes.items():
    print(f"{layer_name}: {layer_shape}")

In [None]:
vit_name = 'vit_base_patch16_clip_224.laion2b'
extractor = TimmFeatureExtractor(model_name=vit_name)
layers = extractor.list_hookable_layers()
# Layers to inspect; presumably each name must appear in `layers` above.
selected_layers = [
    'patch_embed.proj',
    'blocks.0.norm1',
    'blocks.0.attn.qkv',
    'blocks.11',
    'norm',
]
shapes = extractor.get_feature_shapes(selected_layers)

for layer_name, layer_shape in shapes.items():
    print(f"{layer_name}: {layer_shape}")

#### Test loading checkpoint weights

In [None]:
import torch
from torchvision.models import resnet50, ResNet50_Weights

# 1) Build a torchvision ResNet-50 without downloading torchvision's
#    pretrained weights; the checkpoint below supplies the weights.
model = resnet50(weights=None)
model.eval()

# 2) Load the checkpoint (replace the path with your downloaded file).
#    NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
ckpt_path = r"D:\Dataset\ss_weights\resnet50\moco_v2_800ep_pretrain.pth.tar"
ckpt = torch.load(ckpt_path, map_location="cpu")


def _unwrap_state_dict(obj):
    """Return the state_dict inside a checkpoint object.

    Checkpoint layouts differ: a bare state_dict, or a dict wrapping it under
    'state_dict', 'model_state_dict' or (when it is itself a dict) 'model'.
    Anything else is assumed to already be the state_dict.
    """
    if not isinstance(obj, dict):
        return obj
    if "state_dict" in obj:
        return obj["state_dict"]
    if "model_state_dict" in obj:
        return obj["model_state_dict"]
    if isinstance(obj.get("model"), dict):
        return obj["model"]
    return obj


# Wrapper prefixes removed from each key, applied in this order:
#   'module.'    -- added by DataParallel / DistributedDataParallel
#   'model.'     -- used by some wrapped checkpoints (e.g. MadryLab-style; TODO confirm per source)
#   'encoder_q.' -- MoCo's query encoder (keys look like 'module.encoder_q.conv1.weight')
STRIP_PREFIXES = ("module.", "model.", "encoder_q.")


def _strip_prefixes(state_dict, prefixes=STRIP_PREFIXES):
    """Return a new dict whose keys have each prefix in `prefixes` removed (in order)."""
    cleaned = {}
    for key, value in state_dict.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


sd = _unwrap_state_dict(ckpt)

# 3) Clean the keys so they line up with torchvision's parameter names.
new_sd = _strip_prefixes(sd)

# load_state_dict raises on a shape mismatch even with strict=False (e.g. an
# SSL projection head saved as 'fc.*'), so set such entries aside and report them.
model_sd = model.state_dict()
shape_mismatch = {
    k for k, v in new_sd.items()
    if k in model_sd and getattr(v, "shape", None) != model_sd[k].shape
}
loadable_sd = {k: v for k, v in new_sd.items() if k not in shape_mismatch}

# 4) Load with strict=False so mismatched names are reported rather than fatal.
missing, unexpected = model.load_state_dict(loadable_sd, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)
if shape_mismatch:
    print("skipped (shape mismatch):", sorted(shape_mismatch))
# A loaded count of ~0 means the key cleaning above did not match this checkpoint.
print(f"loaded {len(model_sd) - len(missing)}/{len(model_sd)} model state entries")

In [None]:
# Show the cleaned checkpoint keys next to the architecture to spot naming mismatches.
for item in (new_sd.keys(), model):
    print(item)

In [None]:
import numpy as np


def flat_pearson(a, b):
    """Pearson correlation between two equally shaped arrays.

    Every element is one paired observation (both arrays are flattened), which
    yields a single scalar similarity between the two matrices.

    Why not ``np.corrcoef(a, b)[0, 1]``: for 2-D inputs corrcoef treats each
    *row* as a separate variable and builds a (2*rows x 2*rows) matrix, so
    ``[0, 1]`` is the correlation of rows 0 and 1 of ``a``, not of ``a`` vs ``b``.

    Args:
        a, b: array-likes with identical shapes.

    Returns:
        float in [-1, 1] (up to rounding); nan if either input is constant.

    Raises:
        ValueError: if the shapes differ.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    # Center once; dot products avoid the stacked copy np.corrcoef would make.
    a_c = a.ravel() - a.mean()
    b_c = b.ravel() - b.mean()
    return float(np.dot(a_c, b_c) / np.sqrt(np.dot(a_c, a_c) * np.dot(b_c, b_c)))


# The original benchmark shape (17390 x 139968) needs ~19.5 GB per float64
# matrix; keep it opt-in so Restart & Run All stays feasible.
RUN_FULL_SIZE = False
mat_shape = (17390, 139968) if RUN_FULL_SIZE else (174, 1400)

rng = np.random.default_rng(0)
mat_a = rng.standard_normal(mat_shape)
mat_b = rng.standard_normal(mat_shape)
corr = flat_pearson(mat_a, mat_b)
print(corr)