In [1]:
import time 
from subprocess import call
import numpy as np
import matplotlib.pyplot as plt
import h5py
import os
import copy

from tqdm import tqdm
from PIL import Image
from sklearn.preprocessing import StandardScaler

import webdataset as wds
import sys

# nsd_access is from this repo: https://github.com/tknapen/nsd_access
# also see https://cvnlab.slite.page/p/dC~rBTjqjb/How-to-get-the-data for how to download the NSD data!
from nsd_access import NSDAccess
nsd_path = '/scratch/gpfs/KNORMAN/natural-scenes-dataset'
nsda = NSDAccess(nsd_path)

import nibabel as nib

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device:",device)

ModuleNotFoundError: No module named 'nsd_access'

In [None]:
tmp = '/scratch/gpfs/KNORMAN'
shared1000 = np.load("shared1000.npy") # download from https://huggingface.co/datasets/pscotti/mindeyev2/tree/main

In [1]:
for sub in [0]: #,1,2,3,4,5,6,7]:
    subject=f'subj0{sub+1}'
    subj=subject
    print(subject)
    
    abs_cnt = -1
    abs_notshared1000_cnt = -1
    abs_shared1000_cnt = -1
    
    # load coco 73k indices
    indices_path = "COCO_73k_subj_indices.hdf5"
    hdf5_file = h5py.File(indices_path, "r")
    indices = hdf5_file[f"{subj}"][:]

    nsessions_allsubj=np.array([40, 40, 32, 30, 40, 32, 40, 30]) 
    nsessions=nsessions_allsubj[sub];
    ntrials = 750*nsessions
    print(nsessions,ntrials)

    print(time.strftime("\nCurrent time: %H:%M:%S", time.localtime())) 
    
    file = f"/scratch/gpfs/KNORMAN/natural-scenes-dataset/nsddata/ppdata/{subject}/func1pt8mm/roi/nsdgeneral.nii.gz"
    nifti = nib.load(file) 
    mask = nifti.get_data()
    mask[mask<1] = 0 
    nsdgeneral_mask = mask

    for tar in tqdm(range(nsessions)):
        sess=tar+1
        
        behav = nsda.read_behavior(subject=subject, 
                    session_index=sess, 
                    trial_index=[]) 

        # pull single-trial betas and mask them
        betas = nsda.read_betas(subject=subject, 
                            session_index=sess, 
                            trial_index=[], # empty list as index means get all for this session
                            data_type='betas_fithrf', # GLMSingle beta2
                            data_format='func1pt8mm') 

        # betas = betas[mask]
        betas = np.moveaxis(betas,-1,0)
        
        vox_include = copy.deepcopy(nsdgeneral_mask)
        ncsnr = nib.load(f"{subject}_ncsnr.nii.gz").get_fdata()
        ncsnr[ncsnr<.15] = np.nan 
        if tar==0: print("voxels left:", len(vox_include[vox_include>0]))
        vox_include[np.isnan(ncsnr)] -= 1 # keep all nsdgeneral voxels even if they are below the threshold
        vox_include[vox_include<0] = 0
        if tar==0: print("voxels left after ncsnr thresholding:", len(vox_include[vox_include>0])) # subj01 = 49329
        
        betas = betas.reshape(len(betas),-1)
        betas = betas[:,vox_include.flatten().astype(bool)]
        shape = betas.shape
        scalar = StandardScaler(with_mean=True, with_std=True).fit(betas) # YOU SHOULD EXCLUDE SHARED1000 FROM THIS (NOT DONE HERE BUT DONE IN ACTUAL MINDEYE2 PAPER)
        betas_mean = scalar.mean_
        betas_std = scalar.scale_
        betas = (betas - betas_mean) / betas_std
        betas = betas.reshape(shape).astype('float16') # (1, 15724)    
        
        globals()[f'betas_ses{sess}'] = betas  
        globals()[f'behav_ses{sess}'] = behav   
        print(betas.shape)
        
    for tar in range(nsessions):
        sess=tar+1
        
        if sess==1:
            betas_all = globals()[f'betas_ses{sess}']
        else:
            betas_all = np.vstack((betas_all,globals()[f'betas_ses{sess}']))
        print(betas_all.shape)
        
    with h5py.File(f'betas_{subject}.hdf5', 'w') as f:
        f.create_dataset('betas', data=betas_all)
    print(f"saved betas_{subject}.hdf5")
        
    os.makedirs(f"{tmp}/mindeyev2_wds/{subj}",exist_ok=True)
    os.makedirs(f"{tmp}/mindeyev2_wds/{subj}/train",exist_ok=True)
    os.makedirs(f"{tmp}/mindeyev2_wds/{subj}/test",exist_ok=True)
    sink1 = wds.TarWriter(f"{tmp}/mindeyev2_wds/{subj}/test/0.tar")
    for tar in tqdm(range(nsessions)):
        behav = globals()[f'behav_ses{tar+1}']
        
        sink2 = wds.TarWriter(f"{tmp}/mindeyev2_wds/{subj}/train/{tar}.tar")
        for i in range(len(behav)):
            abs_cnt += 1                

            trial_numbers = np.where(indices==indices[abs_cnt])[0]
            assert np.isin(abs_cnt,trial_numbers)
            trial_numbers[trial_numbers == abs_cnt] = -1 # current trial becomes negative 1
            if len(trial_numbers) == 1:
                trial_numbers = np.append(trial_numbers, -1)
                trial_numbers = np.append(trial_numbers, -1)
            if len(trial_numbers) == 2:
                trial_numbers = np.append(trial_numbers, -1)
            assert len(trial_numbers) == 3

            sess=tar+1
            behav = globals()[f'behav_ses{sess}']
            behav_matrix = np.ones((1, 17))*-1
            jjj=-1
            for j in range(1):
                jj = i-j
                jjj += 1

                if jj >= 0:
                    # change NaNs to negative-one integers
                    iscorrect = behav.iloc[jj]['ISCORRECT']
                    if np.isnan(iscorrect): iscorrect = -1

                    isoldcurrent = behav.iloc[jj]['ISOLDCURRENT']
                    if np.isnan(isoldcurrent): isoldcurrent = -1

                    iscorrectcurrent = behav.iloc[jj]['ISCORRECTCURRENT']
                    if np.isnan(iscorrectcurrent): iscorrectcurrent = -1

                    rt = behav.iloc[jj]['RT']
                    if np.isnan(rt): rt = -1

                    changemind = behav.iloc[jj]['CHANGEMIND']
                    if np.isnan(changemind): changemind = -1

                    button = behav.iloc[jj]['BUTTON']
                    if np.isnan(button): button = -1

                    total1 = behav.iloc[jj]['TOTAL1']
                    if np.isnan(total1): total1 = -1

                    total2 = behav.iloc[jj]['TOTAL2']
                    if np.isnan(total2): total2 = -1
                    
                    coco73 = int(behav.iloc[jj]['73KID'])-1
                    assert coco73 >= 0 and coco73 < 730000

                    behavior = {
                        "cocoidx": coco73, #0
                        "subject": sub+1,                          #1
                        "session": int(behav.iloc[jj]['SESSION']), #2
                        "run": int(behav.iloc[jj]['RUN']),         #3
                        "trial": int(behav.iloc[jj]['TRIAL']),     #4
                        "global_trial": (int(behav.iloc[jj]['SESSION'])-1)*750 + jj,        #5
                        "time": int(behav.iloc[jj]['TIME']),       #6
                        "isold": int(behav.iloc[jj]['ISOLD']),     #7
                        "iscorrect": iscorrect,                    #8
                        "rt": rt, # 0 = no RT                      #9
                        "changemind": changemind,                  #10
                        "isoldcurrent": isoldcurrent,              #11
                        "iscorrectcurrent": iscorrectcurrent,      #12
                        "total1": total1,   #13
                        "total2": total2,   #14
                        "button": button,                          #15
                        "shared1000": shared1000[int(behav.iloc[jj]['73KID'])-1], #16
                    }
                    
                    assert (int(behav.iloc[jj]['SESSION'])-1)*750 + jj >= 0
                    assert (int(behav.iloc[jj]['SESSION'])-1)*750 + jj < 27750

                    behav_matrix[jjj] = np.array(list(behavior.values()))
                    
            past_behav_matrix = np.ones((15, 17))*-1
            jjj=-1
            for j in range(1,16):
                jj = i-j
                jjj += 1

                if jj >= 0:
                    # change NaNs to negative-one integers
                    iscorrect = behav.iloc[jj]['ISCORRECT']
                    if np.isnan(iscorrect): iscorrect = -1

                    isoldcurrent = behav.iloc[jj]['ISOLDCURRENT']
                    if np.isnan(isoldcurrent): isoldcurrent = -1

                    iscorrectcurrent = behav.iloc[jj]['ISCORRECTCURRENT']
                    if np.isnan(iscorrectcurrent): iscorrectcurrent = -1

                    rt = behav.iloc[jj]['RT']
                    if np.isnan(rt): rt = -1

                    changemind = behav.iloc[jj]['CHANGEMIND']
                    if np.isnan(changemind): changemind = -1

                    button = behav.iloc[jj]['BUTTON']
                    if np.isnan(button): button = -1

                    total1 = behav.iloc[jj]['TOTAL1']
                    if np.isnan(total1): total1 = -1

                    total2 = behav.iloc[jj]['TOTAL2']
                    if np.isnan(total2): total2 = -1
                    
                    coco73 = int(behav.iloc[jj]['73KID'])-1
                    assert coco73 >= 0 and coco73 < 730000

                    behavior = {
                        "cocoidx": coco73, #0
                        "subject": sub+1,                          #1
                        "session": int(behav.iloc[jj]['SESSION']), #2
                        "run": int(behav.iloc[jj]['RUN']),         #3
                        "trial": int(behav.iloc[jj]['TRIAL']),     #4
                        "global_trial": (int(behav.iloc[jj]['SESSION'])-1)*750 + jj,        #5
                        "time": int(behav.iloc[jj]['TIME']),       #6
                        "isold": int(behav.iloc[jj]['ISOLD']),     #7
                        "iscorrect": iscorrect,                    #8
                        "rt": rt, # 0 = no RT                      #9
                        "changemind": changemind,                  #10
                        "isoldcurrent": isoldcurrent,              #11
                        "iscorrectcurrent": iscorrectcurrent,      #12
                        "total1": total1,   #13
                        "total2": total2,   #14
                        "button": button,                          #15
                        "shared1000": shared1000[int(behav.iloc[jj]['73KID'])-1], #16
                    }
                    
                    assert (int(behav.iloc[jj]['SESSION'])-1)*750 + jj >= 0
                    assert (int(behav.iloc[jj]['SESSION'])-1)*750 + jj < 27750

                    past_behav_matrix[jjj] = np.array(list(behavior.values()))
                    
            future_behav_matrix = np.ones((15, 17))*-1
            jjj=-1
            for j in range(1,16):
                jj = i+j
                jjj += 1

                if jj >= 0 and jj<750:
                    # change NaNs to negative-one integers
                    iscorrect = behav.iloc[jj]['ISCORRECT']
                    if np.isnan(iscorrect): iscorrect = -1

                    isoldcurrent = behav.iloc[jj]['ISOLDCURRENT']
                    if np.isnan(isoldcurrent): isoldcurrent = -1

                    iscorrectcurrent = behav.iloc[jj]['ISCORRECTCURRENT']
                    if np.isnan(iscorrectcurrent): iscorrectcurrent = -1

                    rt = behav.iloc[jj]['RT']
                    if np.isnan(rt): rt = -1

                    changemind = behav.iloc[jj]['CHANGEMIND']
                    if np.isnan(changemind): changemind = -1

                    button = behav.iloc[jj]['BUTTON']
                    if np.isnan(button): button = -1

                    total1 = behav.iloc[jj]['TOTAL1']
                    if np.isnan(total1): total1 = -1

                    total2 = behav.iloc[jj]['TOTAL2']
                    if np.isnan(total2): total2 = -1
                    
                    coco73 = int(behav.iloc[jj]['73KID'])-1
                    assert coco73 >= 0 and coco73 < 730000

                    behavior = {
                        "cocoidx": coco73, #0
                        "subject": sub+1,                          #1
                        "session": int(behav.iloc[jj]['SESSION']), #2
                        "run": int(behav.iloc[jj]['RUN']),         #3
                        "trial": int(behav.iloc[jj]['TRIAL']),     #4
                        "global_trial": (int(behav.iloc[jj]['SESSION'])-1)*750 + jj,        #5
                        "time": int(behav.iloc[jj]['TIME']),       #6
                        "isold": int(behav.iloc[jj]['ISOLD']),     #7
                        "iscorrect": iscorrect,                    #8
                        "rt": rt, # 0 = no RT                      #9
                        "changemind": changemind,                  #10
                        "isoldcurrent": isoldcurrent,              #11
                        "iscorrectcurrent": iscorrectcurrent,      #12
                        "total1": total1,   #13
                        "total2": total2,   #14
                        "button": button,                          #15
                        "shared1000": shared1000[int(behav.iloc[jj]['73KID'])-1], #16
                    }
                    
                    assert (int(behav.iloc[jj]['SESSION'])-1)*750 + jj >= 0
                    assert (int(behav.iloc[jj]['SESSION'])-1)*750 + jj < 27750

                    future_behav_matrix[jjj] = np.array(list(behavior.values()))

            olds_behav_matrix = np.ones((3, 17))*-1
            jjj=-1
            for j in range(3):
                jj = trial_numbers[j]

                if jj>=0:
                    jjj += 1
                    old_session = int(np.floor(jj / 750)) + 1
                    old_trial = jj % 750
                    behav = globals()[f'behav_ses{old_session}']
                    jj = old_trial

                    # change NaNs to negative-one integers
                    iscorrect = behav.iloc[jj]['ISCORRECT']
                    if np.isnan(iscorrect): iscorrect = -1

                    isoldcurrent = behav.iloc[jj]['ISOLDCURRENT']
                    if np.isnan(isoldcurrent): isoldcurrent = -1

                    iscorrectcurrent = behav.iloc[jj]['ISCORRECTCURRENT']
                    if np.isnan(iscorrectcurrent): iscorrectcurrent = -1

                    rt = behav.iloc[jj]['RT']
                    if np.isnan(rt): rt = -1

                    changemind = behav.iloc[jj]['CHANGEMIND']
                    if np.isnan(changemind): changemind = -1

                    button = behav.iloc[jj]['BUTTON']
                    if np.isnan(button): button = -1

                    total1 = behav.iloc[jj]['TOTAL1']
                    if np.isnan(total1): total1 = -1

                    total2 = behav.iloc[jj]['TOTAL2']
                    if np.isnan(total2): total2 = -1
                    
                    coco73 = int(behav.iloc[jj]['73KID'])-1
                    assert coco73 >= 0 and coco73 < 730000

                    behavior = {
                        "cocoidx": coco73, #0
                        "subject": sub+1,                          #1
                        "session": int(behav.iloc[jj]['SESSION']), #2
                        "run": int(behav.iloc[jj]['RUN']),         #3
                        "trial": int(behav.iloc[jj]['TRIAL']),     #4
                        "global_trial": (int(behav.iloc[jj]['SESSION'])-1)*750 + jj,        #5
                        "time": int(behav.iloc[jj]['TIME']),       #6
                        "isold": int(behav.iloc[jj]['ISOLD']),     #7
                        "iscorrect": iscorrect,                    #8
                        "rt": rt, # 0 = no RT                      #9
                        "changemind": changemind,                  #10
                        "isoldcurrent": isoldcurrent,              #11
                        "iscorrectcurrent": iscorrectcurrent,      #12
                        "total1": total1,   #13
                        "total2": total2,   #14
                        "button": button,                          #15
                        "shared1000": shared1000[int(behav.iloc[jj]['73KID'])-1], #16
                    }
                    
                    assert (int(behav.iloc[jj]['SESSION'])-1)*750 + jj >= 0
                    assert (int(behav.iloc[jj]['SESSION'])-1)*750 + jj < 27750

                    olds_behav_matrix[jjj] = np.array(list(behavior.values()))

            behav = globals()[f'behav_ses{sess}']
            # Check if this is a shared1000 trial
            if shared1000[int(behav.iloc[i]['73KID'])-1]:
                abs_shared1000_cnt += 1
            else:
                abs_notshared1000_cnt += 1
                
            with torch.no_grad(): #https://cvnlab.slite.page/p/fRv4lz5V2F/Untitled
                if shared1000[int(behav.iloc[i]['73KID'])-1]:
                    sink1.write({
                        "__key__": "sample%09d" % abs_shared1000_cnt,
                        "behav.npy": behav_matrix,
                        "past_behav.npy": past_behav_matrix,
                        "future_behav.npy": future_behav_matrix,
                        "olds_behav.npy": olds_behav_matrix,
                    })
                    assert behav_matrix[-1,0] < 73000
                else:
                    sink2.write({
                        "__key__": "sample%09d" % abs_notshared1000_cnt,
                        "behav.npy": behav_matrix,
                        "past_behav.npy": past_behav_matrix,
                        "future_behav.npy": future_behav_matrix,
                        "olds_behav.npy": olds_behav_matrix,
                    })
                    assert behav_matrix[-1,0] < 73000
        sink2.close()
    sink1.close()
    
    print(time.strftime("\nCurrent time: %H:%M:%S", time.localtime())) 

In [2]:
# --- 定义所有输入文件的绝对路径 ---

# 评估用的标准图像和标题
ALL_IMAGES_PT = "/home/vipuser/MindEyeV2_Project/src/evals/all_images.pt"
ALL_CAPTIONS_PT = "/home/vipuser/MindEyeV2_Project/src/evals/all_captions.pt"

# 你自己训练生成的重建图像所在的目录
# 请确保这个路径是正确的
RECONS_IMG_DIR = "/home/vipuser/train_logs/s1_ps1p5_h512_e5_cycle/inference/images"

# 打印路径以供检查
print("Ground Truth Images:", ALL_IMAGES_PT)
print("Ground Truth Captions:", ALL_CAPTIONS_PT)
print("Reconstructed Images Directory:", RECONS_IMG_DIR)


Ground Truth Images: /home/vipuser/MindEyeV2_Project/src/evals/all_images.pt
Ground Truth Captions: /home/vipuser/MindEyeV2_Project/src/evals/all_captions.pt
Reconstructed Images Directory: /home/vipuser/train_logs/s1_ps1p5_h512_e5_cycle/inference/images


In [5]:
# ====================================================================
# 卡点 A: 从 PNG 图片生成评估所需的 recons.pt 文件 (最终修正版)
# ====================================================================

# --- 0. 强制升级 open_clip_torch 库 ---
# 这是解决问题的关键步骤
print("正在强制升级 open_clip_torch 库到最新版本...")
!pip install --upgrade open_clip_torch
print("升级完成！")


import os
import glob
import torch
from PIL import Image
from tqdm import tqdm
import open_clip # 重新导入

# --- 重新确认路径 ---
print(f"Source PNG directory: {RECONS_IMG_DIR}")

device = "cuda" if torch.cuda.is_available() else "cpu"

# --- 1. 加载 CLIP 模型 (使用 openai 标签) ---
print("Loading CLIP model ViT-H-14 from OpenAI...")
try:
    # 升级后，这行代码应该就能正常工作了
    model, _, preprocess = open_clip.create_model_and_transforms("ViT-H-14", pretrained="openai")
    model = model.to(device).eval()
    print("CLIP model loaded successfully.")
except Exception as e:
    print(f"❌ 加载模型失败: {e}")
    print("如果升级后仍然失败，请检查网络连接或重启 Kernel 再试。")
    raise e

# --- 2. 找到所有 PNG 文件并计算特征 ---
png_files = sorted(glob.glob(os.path.join(RECONS_IMG_DIR, "*.png")))
if not png_files:
    raise FileNotFoundError(f"错误：在目录 {RECONS_IMG_DIR} 中没有找到任何 .png 文件。请检查路径！")

print(f"Found {len(png_files)} PNG files to process.")

all_embeddings = []
with torch.no_grad():
    for file_path in tqdm(png_files, desc="Calculating embeddings"):
        image = Image.open(file_path).convert("RGB")
        image_tensor = preprocess(image).unsqueeze(0).to(device)
        embedding = model.encode_image(image_tensor)
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        all_embeddings.append(embedding.cpu())

recons_embeds = torch.cat(all_embeddings, dim=0)

# --- 3. 保存特征文件 ---
RECONS_PT_PATH = "/home/vipuser/train_logs/s1_ps1p5_h512_e5_cycle/inference/recons.pt"
torch.save(recons_embeds, RECONS_PT_PATH)

print(f"\nSuccessfully saved embeddings to: {RECONS_PT_PATH}")
print(f"Shape of the saved tensor: {recons_embeds.shape}")

# --- 4. 将生成的特征加载到变量中，供后续使用 ---
recons_embeds = torch.load(RECONS_PT_PATH).to(device)
print("Reconstruction embeddings are loaded into 'recons_embeds' variable and ready for evaluation.")

# 清理显存
del model
torch.cuda.empty_cache()


正在强制升级 open_clip_torch 库到最新版本...


[0m

ERROR:root:Pretrained tag or path (openai) for 'ViT-H-14' not found. Available tags: ['laion2b_s32b_b79k', 'metaclip_fullcc', 'metaclip_altogether', 'dfn5b']


升级完成！
Source PNG directory: /home/vipuser/train_logs/s1_ps1p5_h512_e5_cycle/inference/images
Loading CLIP model ViT-H-14 from OpenAI...
❌ 加载模型失败: Pretrained value 'openai' is not a known tag or valid file path
如果升级后仍然失败，请检查网络连接或重启 Kernel 再试。


RuntimeError: Pretrained value 'openai' is not a known tag or valid file path

In [None]:
# ====================================================================
# 第 5 步: 加载 Ground Truth 数据并计算评估指标
# ====================================================================
import torch
import numpy as np
from tqdm import tqdm

# --- 1. 加载 Ground Truth (GT) 图像和标题的特征 ---
# 确保这些变量路径正确
print("Loading Ground Truth embeddings...")
gt_images_pt = torch.load(ALL_IMAGES_PT, map_location=device)
# gt_captions_pt = torch.load(ALL_CAPTIONS_PT, map_location=device) # 标题暂时不用，先注释掉

# gt_images_pt 可能是一个字典，我们需要取出真正的特征张量
# 常见的 key 是 'images' 或 'embeds'，我们检查一下
if isinstance(gt_images_pt, dict):
    if 'images' in gt_images_pt:
        gt_embeds = gt_images_pt['images'].to(device)
    elif 'embeds' in gt_images_pt:
        gt_embeds = gt_images_pt['embeds'].to(device)
    else:
        raise KeyError("在 all_images.pt 文件中找不到 'images' 或 'embeds' 键")
else:
    gt_embeds = gt_images_pt.to(device)

print(f"Ground Truth embeddings loaded. Shape: {gt_embeds.shape}")
print(f"Reconstruction embeddings shape: {recons_embeds.shape}")

# --- 2. 检查并匹配样本数量 ---
# 你的重建结果数量可能和GT数量不一致，这是正常的。评估时以前者为准。
num_recons = recons_embeds.shape[0]
if num_recons > gt_embeds.shape[0]:
    print(f"警告：重建图像数量 ({num_recons}) 大于GT图像数量 ({gt_embeds.shape[0]})。将只评估前 {gt_embeds.shape[0]} 个样本。")
    recons_embeds = recons_embeds[:gt_embeds.shape[0]]
else:
    # 从完整的GT中，只取出与你重建数量相匹配的前 N 个
    gt_embeds = gt_embeds[:num_recons]

print(f"Embeddings matched for evaluation. Using {num_recons} samples.")


# --- 3. 计算指标 ---

# 指标 1: CLIP Cosine Similarity (CLIP-Cos)
# 逐个计算你的重建图像与对应的GT图像的相似度
print("\nCalculating CLIP Cosine Similarity...")
# (N, D) * (N, D) -> (N,)
cos_sims = (recons_embeds * gt_embeds).sum(dim=1)
clip_cos = cos_sims.mean().item()
print(f" => Average CLIP Cosine Similarity: {clip_cos:.4f}")


# 指标 2: Retrieval Accuracy (Top-1, Top-5)
# 检查你的每个重建图像，在所有GT图像中，能否找回正确的那个
print("\nCalculating Retrieval Accuracy (Top-1, Top-5)...")
# 计算你的每个重建图像与 *所有* GT图像的相似度矩阵
# (N, D) @ (D, N) -> (N, N)
sim_matrix = torch.matmul(recons_embeds, gt_embeds.t())

# 找到每个重建图像最相似的GT图像的索引
_, top_indices = torch.topk(sim_matrix, k=5, dim=1)

# 正确的索引应该是对角线 (0, 1, 2, ..., N-1)
correct_indices = torch.arange(num_recons, device=device)

# Top-1: 最相似的是否是正确的那个
top1_correct = top_indices[:, 0] == correct_indices
top1_new = top1_correct.float().mean().item()

# Top-5: 前5个最相似的是否包含正确的那个
top5_correct = (top_indices == correct_indices.unsqueeze(1)).any(dim=1)
top5_new = top5_correct.float().mean().item()

print(f" => Top-1 Accuracy: {top1_new:.4f} ({int(top1_new*num_recons)}/{num_recons})")
print(f" => Top-5 Accuracy: {top5_new:.4f} ({int(top5_new*num_recons)}/{num_recons})")

# 为了让最后一个单元格能用，我们把计算出的指标存为全局变量
# （在Jupyter Notebook中，一个单元格的变量在运行后对其他单元格可见）
print("\nMetrics calculated and stored in variables: clip_cos, top1_new, top5_new")


In [None]:
# ====================================================================
# 第 6 步 (最终步): 将所有指标保存到 JSON 文件
# ====================================================================
import json
import time
import os

# --- 1. 准备要保存的指标字典 ---
# 我们从全局变量中获取上一步计算好的指标
# 如果某些指标不存在（比如 mse, lpips），我们将其设为 None

# 检查变量是否存在，不存在则设为 None
# 这样做可以避免代码因缺少某个非核心指标而报错
mse_value = float(mse) if 'mse' in globals() else None
lpips_value = float(lpips) if 'lpips' in globals() else None
latency_value = float(globals().get("latency_ms_per_img")) if 'latency_ms_per_img' in globals() and globals().get("latency_ms_per_img") is not None else None


metrics = {
    "top1_new": float(top1_new),
    "top5_new": float(top5_new),
    "clip_cosine": float(clip_cos),
    "mse": mse_value,
    "lpips": lpips_value,
    "latency_ms": latency_value,
    "set": "new_test", # 标记这是哪次测试
    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S") # 添加时间戳
}

# --- 2. 定义输出文件路径 ---
# 我们将它保存在你的模型训练日志目录下，方便管理
output_json_path = "/home/vipuser/train_logs/s1_ps1p5_h512_e5_cycle/metrics.json"

# --- 3. 写入 JSON 文件 ---
try:
    with open(output_json_path, "w") as f:
        json.dump(metrics, f, indent=4) # indent=4 让文件格式更美观
    print(f"✅ 评估完成！指标已成功保存到: {output_json_path}")
    
    # --- 4. 打印最终结果 ---
    print("\n--- Final Metrics ---")
    print(json.dumps(metrics, indent=4))
    print("---------------------")

except Exception as e:
    print(f"❌ 保存文件时出错: {e}")
    print("请检查路径 '/home/vipuser/train_logs/s1_ps1p5_h512_e5_cycle/' 是否存在且有写入权限。")

