In [10]:
import os
from PIL import Image
import torch
from tqdm import tqdm


# Prediction directories to convert. The number 45*i appears in the directory name
# (presumably a checkpoint step); range(10, 11) selects only i = 10, i.e. "...-lora-450-image".
source_paths = [
    f"/public_data/jihai/understanding/scripts/v1_5/answer/answer-llava-v1.5-7b-vq-vq-af2-sw-lora-{45*i}-image"
    for i in range(10, 11)
]

for source_path in tqdm(source_paths):
    # Skip missing directories instead of letting os.listdir abort the whole batch
    # (and before creating an empty "-transferred" directory for it).
    if not os.path.isdir(source_path):
        print(f"Source directory not found, skipping: {source_path}")
        continue

    # Converted PNGs go to a sibling directory named <source dir> + "-transferred".
    target_dir = f"{source_path}-transferred"
    os.makedirs(target_dir, exist_ok=True)

    # Sorted so the conversion order (and any error log) is reproducible.
    for pt_file in sorted(os.listdir(source_path)):
        if not pt_file.endswith('.pt'):
            continue

        # <name>.pt in the source directory -> <name>.png in the target directory.
        pt_path = os.path.join(source_path, pt_file)
        png_filename = os.path.splitext(pt_file)[0] + '.png'
        png_path = os.path.join(target_dir, png_filename)

        try:
            # NOTE(review): torch.load unpickles the file, so only run this on trusted files.
            # map_location="cpu" lets tensors that were saved from a GPU load on any machine.
            img_tensor = torch.load(pt_path, map_location="cpu")

            # detach()/cpu() make .numpy() work for grad-tracking or device tensors.
            # Assumes the array already has a layout/dtype PIL accepts (e.g. uint8 HxW or HxWx3).
            # TODO: confirm against the code that writes these .pt files.
            img = Image.fromarray(img_tensor.detach().cpu().numpy())
            img.save(png_path)
        except Exception as e:
            # Best-effort: report the bad file and keep converting the rest.
            print(f"Error converting {pt_path}: {str(e)}")

100%|██████████| 1/1 [00:02<00:00,  2.44s/it]


In [11]:
import subprocess
import csv

# source_paths = [
#     f"/public_data/jihai/understanding/scripts/v1_5/answer/answer-llava-v1.5-7b-siglip-vq_180ku_180km-sw-lora-{45*i}-image"
#     for i in range(20, 21)
# ]

def calculate_fid(converted_path, gt_path=None, device=None):
    """Compute the FID between two image folders by running pytorch_fid in a subprocess.

    Args:
        converted_path: Folder of generated images (first pytorch_fid argument).
        gt_path: Reference image folder (second argument). Defaults to the
            module-level ``GT_PATH``.
        device: Value passed to pytorch_fid's ``--device`` flag. Defaults to the
            module-level ``CUDA_DEVICE``.

    Returns:
        - The FID as a float, when stdout contains a line starting with "FID:".
        - None, if no such line appears.
        - The string "ERROR", if the subprocess exits non-zero (its stderr is printed).
        - The string "PARSE_ERROR", on any other exception (e.g. an unparsable number).
    """
    import sys

    try:
        # Defaults are resolved at call time so the module-level constants may be
        # assigned after this function is defined.
        if gt_path is None:
            gt_path = GT_PATH
        if device is None:
            device = CUDA_DEVICE

        cmd = [
            # sys.executable runs pytorch_fid with the kernel's own interpreter; a bare
            # "python" may resolve to a different environment on PATH.
            sys.executable, "-m", "pytorch_fid",
            converted_path,
            gt_path,
            "--device", device,
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)

        # Expects a stdout line of the form "FID: <value>". splitlines() + strip()
        # tolerate \r\n line endings and leading whitespace.
        for line in result.stdout.splitlines():
            line = line.strip()
            if line.startswith("FID:"):
                return float(line.split()[1])

    except subprocess.CalledProcessError as e:
        print(f"FID计算失败: {e.stderr}")
        return "ERROR"
    except Exception as e:
        print(f"解析错误: {str(e)}")
        return "PARSE_ERROR"
    return None

# Configuration read by calculate_fid (reference folder, --device value) and the CSV output path.
GT_PATH = "/public_data/jihai/data/multimodalout/smart_watch_image_test"
CUDA_DEVICE = "cuda:7"
RESULT_CSV = "fid_results.csv"

# One FID entry per element of source_paths, in the same order; each entry is scored
# against that path's "-transferred" PNG directory.
fid_results = []
for model_dir in source_paths:
    converted_dir = f"{model_dir}-transferred"
    print(f"Calculating FID for {converted_dir}")
    score = calculate_fid(converted_dir)
    fid_results.append(score)
    print(f"FID计算完成: {score}")

# Two-row CSV: positional headers Model_1..Model_N, then the matching FID values.
header_row = [f"Model_{pos}" for pos, _ in enumerate(source_paths, start=1)]
with open(RESULT_CSV, 'w', newline='') as csv_file:
    csv.writer(csv_file).writerows([header_row, fid_results])


Calculating FID for /public_data/jihai/understanding/scripts/v1_5/answer/answer-llava-v1.5-7b-vq-vq-af2-sw-lora-450-image-transferred
FID计算完成: 195.36088087017606


In [3]:
import os
import shutil
from pathlib import Path
from tqdm import tqdm

def copy_first_n_images(src_folder, dst_folder, n=2000):
    """Copy the first ``n`` PNG files (sorted by file name) from one folder into a fresh folder.

    The destination folder is deleted and recreated first, so afterwards it
    contains exactly the copied files.

    Args:
        src_folder: Folder to read ``*.png`` files from (non-recursive).
        dst_folder: Folder to (re)create and copy into. It must not be the
            source folder or one of its ancestors, because it is wiped.
        n: Maximum number of files to copy; values <= 0 copy nothing.

    Returns:
        The number of files copied. This is 0 if the source folder is missing
        or if the destination would overlap the source.
    """
    src_path = Path(src_folder)
    if not src_path.is_dir():
        print(f"源文件夹 {src_folder} 不存在！")
        return 0

    dst_path = Path(dst_folder)
    # The rmtree below would destroy the source images if dst is src or one of its parents.
    src_resolved = src_path.resolve()
    dst_resolved = dst_path.resolve()
    if dst_resolved == src_resolved or dst_resolved in src_resolved.parents:
        print(f"Refusing to overwrite {dst_folder}: it is or contains the source folder {src_folder}")
        return 0

    # Start from an empty destination so stale files from earlier runs don't linger.
    if dst_path.exists():
        shutil.rmtree(dst_path)
    dst_path.mkdir(parents=True, exist_ok=True)

    # Sort by name so "first n" is deterministic (glob order is not guaranteed).
    # To select by modification time instead, use: key=lambda f: f.stat().st_mtime
    files = sorted(src_path.glob("*.png"), key=lambda f: f.name)
    selected = files[:max(n, 0)]

    # Iterate only the selected files so the progress bar total matches the work done.
    for file in tqdm(selected):
        shutil.copy(file, dst_path / file.name)

    print(f"完成！共拷贝 {len(selected)} 个文件到 {dst_folder}")
    return len(selected)

# 使用示例
# Copy the first 2000 name-sorted PNGs from the training image folder into a separate "_2k"
# folder. copy_first_n_images deletes and recreates the destination folder before copying.
copy_first_n_images(
    "/public_data/jihai/data/multimodalout/smart_watch_image_train",
    "/public_data/jihai/data/multimodalout/smart_watch_image_2k",
    n=2000,
);

  3%|▎         | 2000/60000 [00:00<00:03, 18125.73it/s]

完成！共拷贝 2000 个文件到 /public_data/jihai/data/multimodalout/smart_watch_image_2k



