In [13]:
import os
from PIL import Image
import torch
from tqdm import tqdm


# Multipliers i for which a directory named "...-lora-{45*i}-image" is
# processed (45*i runs over 225, 270, ..., 630).
_CHECKPOINT_MULTIPLIERS = range(5, 15)

# Source directories that the conversion loop scans for .pt files,
# one per multiplier above.
source_paths = [
    f"/public_data/jihai/understanding/scripts/v1_5/answer/answer-llava-v1.5-7b-vq-vq_68ku_180km-sw-lora-{45*i}-image"
    for i in _CHECKPOINT_MULTIPLIERS
]

# Convert every .pt file found in each source directory into a PNG stored in a
# sibling directory named "<source directory>-transferred".
for source_path in tqdm(source_paths):
    # Skip a missing source directory instead of letting os.listdir raise and
    # abort the remaining directories (and before creating an empty target).
    if not os.path.isdir(source_path):
        tqdm.write(f"Source directory not found, skipping: {source_path}")
        continue

    # Target directory: original directory name + "-transferred".
    target_dir = f"{source_path}-transferred"
    os.makedirs(target_dir, exist_ok=True)

    # os.listdir returns entries in arbitrary order; sort so that repeated
    # runs (and their error messages) come out in the same order.
    for pt_file in sorted(os.listdir(source_path)):
        if not pt_file.endswith('.pt'):
            continue

        # Build the input path and the matching output path (<stem>.png).
        pt_path = os.path.join(source_path, pt_file)
        png_filename = os.path.splitext(pt_file)[0] + '.png'
        png_path = os.path.join(target_dir, png_filename)

        try:
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code -- only run this on files you trust.
            # map_location="cpu" lets tensors that were saved from a GPU load
            # on any machine and keeps them off the GPU.
            img_tensor = torch.load(pt_path, map_location="cpu")

            # detach().cpu() makes .numpy() valid even if the stored tensor
            # requires grad or still lives on another device.
            # Assumes the tensor already has a layout/dtype that
            # Image.fromarray accepts (e.g. HxWxC uint8) -- TODO confirm.
            img = Image.fromarray(img_tensor.detach().cpu().numpy())

            img.save(png_path)

        except Exception as e:
            # Best-effort conversion: report the failing file and carry on.
            # tqdm.write prints without corrupting the active progress bar.
            tqdm.write(f"Error converting {pt_path}: {e}")

100%|██████████| 10/10 [00:22<00:00,  2.26s/it]


In [14]:
import subprocess
import csv

def calculate_fid(converted_path, gt_path=None, device=None):
    """Run ``python -m pytorch_fid`` on two image folders and return the FID.

    Args:
        converted_path: Directory of images to evaluate (first path passed to
            pytorch_fid).
        gt_path: Directory passed as the second path. When ``None`` the
            module-level ``GT_PATH`` is used.
        device: Value for pytorch_fid's ``--device`` option. When ``None`` the
            module-level ``CUDA_DEVICE`` is used.

    Returns:
        float: the parsed score when the output contains a ``FID:`` line.
        ``"ERROR"``: the subprocess exited with a non-zero status.
        ``"PARSE_ERROR"``: any other exception (e.g. a malformed ``FID:`` line).
        ``None``: the subprocess succeeded but printed no ``FID:`` line.
    """
    # Local import so this function does not depend on the cell's import list.
    import sys

    try:
        reference_dir = GT_PATH if gt_path is None else gt_path
        fid_device = CUDA_DEVICE if device is None else device

        cmd = [
            # Use the interpreter running this kernel so pytorch_fid is looked
            # up in the same environment; a bare "python" on PATH may be a
            # different installation.
            sys.executable or "python", "-m", "pytorch_fid",
            converted_path,
            reference_dir,
            "--device", fid_device
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)

        # Find the "FID: <value>" line; take the first token after the colon so
        # both "FID:  12.3" and "FID:12.3" are accepted.
        for line in result.stdout.splitlines():
            stripped = line.strip()
            if stripped.startswith("FID:"):
                return float(stripped[len("FID:"):].split()[0])

    except subprocess.CalledProcessError as e:
        print(f"FID计算失败: {e.stderr}")
        return "ERROR"
    except Exception as e:
        print(f"解析错误: {str(e)}")
        return "PARSE_ERROR"
    return None

# Settings read by calculate_fid (second image folder and --device value) and
# the file that receives the scores.
GT_PATH = "/public_data/jihai/data/multimodalout/smart_watch_image_test"
CUDA_DEVICE = "cuda:7"
RESULT_CSV = "fid_results.csv"

# Collected results, one entry per element of source_paths, in the same order.
fid_results = []

# Score each "<source_path>-transferred" directory in turn.
for source_path in source_paths:
    target_dir = f"{source_path}-transferred"
    print(f"Calculating FID for {target_dir}")
    fid_score = calculate_fid(target_dir)
    print(f"FID计算完成: {fid_score}")
    fid_results.append(fid_score)

# Write a two-row CSV: a header row (Model_1 .. Model_N) followed by a single
# row holding the results.
with open(RESULT_CSV, 'w', newline='') as f:
    writer = csv.writer(f)
    headers = [f"Model_{i+1}" for i, _ in enumerate(source_paths)]
    writer.writerows([headers, fid_results])


Calculating FID for /public_data/jihai/understanding/scripts/v1_5/answer/answer-llava-v1.5-7b-vq-vq_68ku_180km-sw-lora-225-image-transferred
FID计算完成: 182.95791657214625
Calculating FID for /public_data/jihai/understanding/scripts/v1_5/answer/answer-llava-v1.5-7b-vq-vq_68ku_180km-sw-lora-270-image-transferred
FID计算完成: 201.46815061505117
Calculating FID for /public_data/jihai/understanding/scripts/v1_5/answer/answer-llava-v1.5-7b-vq-vq_68ku_180km-sw-lora-315-image-transferred
FID计算完成: 196.35025107756644
Calculating FID for /public_data/jihai/understanding/scripts/v1_5/answer/answer-llava-v1.5-7b-vq-vq_68ku_180km-sw-lora-360-image-transferred
FID计算完成: 162.41236332381658
Calculating FID for /public_data/jihai/understanding/scripts/v1_5/answer/answer-llava-v1.5-7b-vq-vq_68ku_180km-sw-lora-405-image-transferred
FID计算完成: 186.08832933109505
Calculating FID for /public_data/jihai/understanding/scripts/v1_5/answer/answer-llava-v1.5-7b-vq-vq_68ku_180km-sw-lora-450-image-transferred
FID计算完成: 192.4