In [None]:
import torch
from safetensors import safe_open
from safetensors.torch import save_file
import torch.nn as nn

# --- 1. 定义文件路径 ---
# !!! 请务必将这里的路径修改为您自己的文件路径 !!!
model_path = "/home2/y2024/s2430069/DiffSynth-Studio/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint/model.safetensors"
output_path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint/model.safetensors" # 建议保存为新文件以防意外

# --- 2. 加载模型权重 ---
# 使用 safetensors 库安全地加载 .safetensors 文件
# 这会返回一个 state_dict (一个从字符串键映射到 PyTorch 张量的字典)
print(f"正在从 {model_path} 加载模型权重...")
try:
    with safe_open(model_path, framework="pt", device="cpu") as f:
        # 使用字典推导式高效加载所有张量
        model_state_dict = {k: f.get_tensor(k) for k in f.keys()}
except FileNotFoundError:
    print(f"错误：找不到输入文件 '{model_path}'。请检查路径是否正确。")
    exit()


print("\n模型中的原始键和形状 (示例):")
# 只打印前5个键作为示例，避免刷屏
for key in list(model_state_dict.keys())[:5]:
    print(f"- {key}: {model_state_dict[key].shape}")
print("...")

# --- 3. 将整个模型的权重随机初始化 ---
print("\n正在将整个模型的所有权重进行随机初始化...")

# 遍历字典中的每一项
for key, tensor in model_state_dict.items():
    # 检查张量是否为浮点类型，因为整数类型的张量（如位置ID）通常不应被随机化
    if torch.is_floating_point(tensor):
        # 使用 Kaiming Normal 初始化，这是一种常用的权重初始化方法
        # 它会根据张量的形状调整随机值的分布，有助于在训练初期保持梯度的稳定性
        # 注意：这会原地修改张量
        if tensor.dim() > 1: # 通常只对超过一维的权重张量（如 conv, linear）使用 Kaiming 初始化
            nn.init.normal_(tensor, mean=0, std=0)
            print(f"已对张量 '{key}' 应用 Kaiming Normal 初始化。")
        else: # 对于一维张量（如 bias），可以使用正态分布或均匀分布初始化
            nn.init.normal_(tensor, mean=0, std=0)
            print(f"已对一维张量 '{key}' 应用 Normal 初始化。")
    else:
        print(f"跳过非浮点类型的张量 '{key}' (类型: {tensor.dtype})。")


# --- 4. 保存修改后的模型 ---
print(f"\n正在将修改后的权重保存到 {output_path}...")
# save_file 函数会自动创建目标目录（如果它不存在）
save_file(model_state_dict, output_path)

print("操作完成！")

# --- 5. (可选) 验证修改是否成功 ---
print("\n正在验证已保存的文件...")
with safe_open(output_path, framework="pt", device="cpu") as f:
    # 随机选择一个键进行检查
    first_key = next(iter(f.keys()))
    randomized_tensor = f.get_tensor(first_key)
    
    # 检查张量是否不再是全零（一个简单的随机化成功标志）
    # torch.any(tensor != 0) 会检查张量中是否至少有一个非零元素
    if torch.any(randomized_tensor != 0):
        print(f"验证成功：'{first_key}' 中的值已被随机化。")
        # 打印一些样本值
        print(f"  - 均值: {randomized_tensor.mean():.4f}")
        print(f"  - 标准差: {randomized_tensor.std():.4f}")
        print(f"  - 样本值: {randomized_tensor.flatten()[:5].tolist()}")
    else:
        print(f"验证失败：'{first_key}' 仍然是全零，可能初始化未成功。")



正在从 /home2/y2024/s2430069/DiffSynth-Studio/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint/model.safetensors 加载模型权重...

模型中的原始键和形状 (示例):
- controlnet_blocks.0.input_proj.bias: torch.Size([3072])
- controlnet_blocks.0.input_proj.weight: torch.Size([3072, 3072])
- controlnet_blocks.0.output_proj.bias: torch.Size([3072])
- controlnet_blocks.0.output_proj.weight: torch.Size([3072, 3072])
- controlnet_blocks.0.x_rms.weight: torch.Size([3072])
...

正在将整个模型的所有权重进行随机初始化...
已对一维张量 'controlnet_blocks.0.input_proj.bias' 应用 Normal 初始化。
已对张量 'controlnet_blocks.0.input_proj.weight' 应用 Kaiming Normal 初始化。
已对一维张量 'controlnet_blocks.0.output_proj.bias' 应用 Normal 初始化。
已对张量 'controlnet_blocks.0.output_proj.weight' 应用 Kaiming Normal 初始化。
已对一维张量 'controlnet_blocks.0.x_rms.weight' 应用 Normal 初始化。
已对一维张量 'controlnet_blocks.0.y_rms.weight' 应用 Normal 初始化。
已对一维张量 'controlnet_blocks.1.input_proj.bias' 应用 Normal 初始化。
已对张量 'controlnet_blocks.1.input_proj.weight' 应用 Kaiming Normal 初始化。
已对一维张量 'contro

In [1]:
from safetensors import safe_open
model_path = 'train/Qwen-Image-Edit-2509_inpaint_controlnet_and_lora/step-18000.safetensors'
print(f"正在从 {model_path} 加载模型权重...")
with safe_open(model_path, framework="pt", device="cpu") as f:
    model_state_dict = {}
    for k in f.keys():
        model_state_dict[k] = f.get_tensor(k)

print("模型中的原始键和形状:")
for key in list(model_state_dict.keys())[:5]: # 只打印前5个键作为示例
    print(f"- {key}: {model_state_dict[key].shape}")
print("...")

正在从 train/Qwen-Image-Edit-2509_inpaint_controlnet_and_lora/step-18000.safetensors 加载模型权重...


  import pynvml  # type: ignore[import]


模型中的原始键和形状:
- transformer_blocks.0.attn.add_k_proj.lora_A.default.weight: torch.Size([64, 3072])
- transformer_blocks.0.attn.add_k_proj.lora_B.default.weight: torch.Size([3072, 64])
- transformer_blocks.0.attn.add_q_proj.lora_A.default.weight: torch.Size([64, 3072])
- transformer_blocks.0.attn.add_q_proj.lora_B.default.weight: torch.Size([3072, 64])
- transformer_blocks.0.attn.add_v_proj.lora_A.default.weight: torch.Size([64, 3072])
...


In [2]:
#18000
model_state_dict

{'transformer_blocks.0.attn.add_k_proj.lora_A.default.weight': tensor([[-0.0024,  0.0117, -0.0024,  ..., -0.0079, -0.0064, -0.0111],
         [ 0.0126,  0.0007,  0.0055,  ..., -0.0137,  0.0117,  0.0010],
         [-0.0057, -0.0139, -0.0152,  ...,  0.0058, -0.0182, -0.0140],
         ...,
         [-0.0179,  0.0078, -0.0083,  ..., -0.0052,  0.0056, -0.0085],
         [-0.0052, -0.0012,  0.0096,  ...,  0.0081, -0.0063, -0.0148],
         [ 0.0182, -0.0060, -0.0116,  ...,  0.0094,  0.0031,  0.0020]],
        dtype=torch.bfloat16),
 'transformer_blocks.0.attn.add_k_proj.lora_B.default.weight': tensor([[ 3.3875e-03,  3.1891e-03, -4.3945e-03,  ...,  2.2888e-03,
           2.4414e-03, -1.2283e-03],
         [-5.1270e-03,  1.7319e-03, -1.6403e-03,  ...,  3.3875e-03,
           2.3346e-03, -3.1128e-03],
         [-3.8605e-03, -2.0447e-03,  4.5166e-03,  ..., -3.0060e-03,
          -2.2125e-03,  3.8147e-03],
         ...,
         [ 5.6763e-03, -7.5989e-03, -5.8746e-04,  ..., -8.8501e-03,
       

In [28]:
model_state_dict

{'pipe.blockwise_controlnet.models.0.controlnet_blocks.0.input_proj.bias': tensor([   3.9062,  196.0000, -182.0000,  ...,  -22.5000,  162.0000,
           66.0000], dtype=torch.bfloat16),
 'pipe.blockwise_controlnet.models.0.controlnet_blocks.0.input_proj.weight': tensor([[  84.0000,  217.0000,   76.0000,  ...,   87.0000,   74.5000,
           190.0000],
         [  51.2500,  -73.0000,  129.0000,  ...,  139.0000,  270.0000,
           139.0000],
         [ 100.0000,  166.0000,  180.0000,  ...,   78.5000,  173.0000,
            51.5000],
         ...,
         [ 169.0000,  112.5000,  212.0000,  ...,  -44.0000, -156.0000,
            46.0000],
         [ 118.5000,  155.0000,   30.5000,  ...,  272.0000,   50.5000,
            75.0000],
         [ 241.0000,  128.0000,  152.0000,  ...,   87.5000,    7.0000,
           158.0000]], dtype=torch.bfloat16),
 'pipe.blockwise_controlnet.models.0.controlnet_blocks.0.output_proj.bias': tensor([   0.8750,  110.0000,   53.2500,  ..., -124.0000, -169.0

In [3]:
import json

json_path = "prepared_data_original/metadata.json"

with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)

for item in data:
    if item.get("ref_file"):
        item["ref_file"] = f"ref/{item['ref_file']}"
    if item.get("tgt_original_file"):
        item["tgt_original_file"] = f"tgt_original/{item['tgt_original_file']}"
    if item.get("tgt_original_mask_file"):
        item["tgt_original_mask_file"] = f"tgt_original/{item['tgt_original_mask_file']}"
    if item.get("edit_image"):
        item["edit_image"] = f"tgt_clean/{item['edit_image']}"

with open(json_path, "w", encoding="utf-8") as f:
    json.dump(data, f, ensure_ascii=False, indent=2)

print("图片路径已批量修正为 ref/ 和 tgt_original/ 子目录。")

图片路径已批量修正为 ref/ 和 tgt_original/ 子目录。


In [4]:
data

[{'id': '000543f6b0ba609b88c9197e114194fda1a652cb4f91884835714895afb07807',
  'caption': 'Encircled by flickering candlelight on a quiet altar, it gleams softly, with the heady aroma of incense wafting gently through the sacred space.',
  'ref_file': 'ref/000543f6b0ba609b88c9197e114194fda1a652cb4f91884835714895afb07807.png',
  'ref_mask_file': None,
  'ref_mask_type': None,
  'tgt_clean_file': '000543f6b0ba609b88c9197e114194fda1a652cb4f91884835714895afb07807.png',
  'tgt_original_file': 'tgt_original/000543f6b0ba609b88c9197e114194fda1a652cb4f91884835714895afb07807.png',
  'tgt_original_mask_file': 'tgt_original/000543f6b0ba609b88c9197e114194fda1a652cb4f91884835714895afb07807_mask.png'},
 {'id': '00b8587e8073744ffef0405cd0eacf53c8ba67613a81c9cede4e81bd583a3b53',
  'caption': "In a shallow tide pool along a rocky beach, it navigates through strands of seaweed, with the overcast sky lending a muted light that enhances the pool's rich, natural colors.",
  'ref_file': 'ref/00b8587e8073744ff

In [5]:
import os
import json
import tqdm

json_path = "prepared_data_original/metadata.json"
base_ref = "prepared_data_original/ref"
base_tgt = "prepared_data_original/tgt_original"

with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)

# 统计有效图片数量
valid_ref = 0
valid_tgt = 0
for item in tqdm.tqdm(data):
    if item.get("ref_file"):
        ref_path = os.path.join(base_ref, item["ref_file"].split("/")[-1])
        if os.path.isfile(ref_path):
            valid_ref += 1
            item["ref_file"] = f"ref/{item['ref_file'].split('/')[-1]}"
        else:
            item["ref_file"] = None
    if item.get("tgt_original_file"):
        tgt_path = os.path.join(base_tgt, item["tgt_original_file"].split("/")[-1])
        if os.path.isfile(tgt_path):
            valid_tgt += 1
            item["tgt_original_file"] = f"tgt_original/{item['tgt_original_file'].split('/')[-1]}"
        else:
            item["tgt_original_file"] = None
    if item.get("tgt_original_mask_file"):
        mask_path = os.path.join(base_tgt, item["tgt_original_mask_file"].split("/")[-1])
        if os.path.isfile(mask_path):
            item["tgt_original_mask_file"] = f"tgt_original/{item['tgt_original_mask_file'].split('/')[-1]}"
        else:
            item["tgt_original_mask_file"] = None
    if item.get("edit_image"):
        edit_path = os.path.join(base_tgt, item["edit_image"].split("/")[-1])
        if os.path.isfile(edit_path):
            item["edit_image"] = f"tgt_clean/{item['edit_image'].split('/')[-1]}"
        else:
            item["edit_image"] = None

# 删除无效图片的条目
data = [item for item in data if item.get("ref_file") and item.get("tgt_original_file")]

with open(json_path, "w", encoding="utf-8") as f:
    json.dump(data, f, ensure_ascii=False, indent=2)

print(f"有效ref图片数量: {valid_ref}")
print(f"有效tgt_original图片数量: {valid_tgt}")
print("无效图片条目已删除，路径已修正。")

100%|██████████| 2000/2000 [00:03<00:00, 614.20it/s]

有效ref图片数量: 2000
有效tgt_original图片数量: 2000
无效图片条目已删除，路径已修正。





In [13]:
data

[{'id': '000543f6b0ba609b88c9197e114194fda1a652cb4f91884835714895afb07807',
  'caption': 'Encircled by flickering candlelight on a quiet altar, it gleams softly, with the heady aroma of incense wafting gently through the sacred space.',
  'ref_file': 'ref/000543f6b0ba609b88c9197e114194fda1a652cb4f91884835714895afb07807.png',
  'ref_mask_file': None,
  'ref_mask_type': None,
  'tgt_clean_file': '000543f6b0ba609b88c9197e114194fda1a652cb4f91884835714895afb07807.png',
  'tgt_original_file': 'tgt_original/000543f6b0ba609b88c9197e114194fda1a652cb4f91884835714895afb07807.png',
  'tgt_original_mask_file': 'tgt_original/000543f6b0ba609b88c9197e114194fda1a652cb4f91884835714895afb07807_mask.png',
  'blockwise_controlnet_image': 'tgt_original/000543f6b0ba609b88c9197e114194fda1a652cb4f91884835714895afb07807.png',
  'blockwise_controlnet_inpaint_mask': 'tgt_original/000543f6b0ba609b88c9197e114194fda1a652cb4f91884835714895afb07807_mask.png',
  'edit_image': ['tgt_original/000543f6b0ba609b88c9197e1141

In [11]:
for item in data:
    if item.get("tgt_original_file") and item.get("ref_file"):
        #[edittgt_original_file_image,ref_file]
        item["edit_image"] = [f"tgt_original/{item['tgt_original_file'].split('/')[-1]}", f"ref/{item['ref_file'].split('/')[-1]}"]

In [12]:
#修改"edit_image"为"image",删除"edit_image"
for item in data:
    if item.get("tgt_clean_file"):
        item["image"] = f"tgt_clean/{item['tgt_clean_file'].split('/')[-1]}"


In [15]:
#caption-> prompt
for item in data:
    if item.get("caption"):
        item["prompt"] = item["caption"]
        del item["caption"]

In [7]:
#加入'edit_image" 为['tgt_original_file'，'ref_file']
for item in data:
    item["blockwise_controlnet_image"] = item.get("tgt_original_file")
    item["blockwise_controlnet_inpaint_mask"] = item.get("tgt_original_mask_file")
    del item["blockwise_controblockwise_controlnet_inpaint_masklnet_image"]


KeyError: 'blockwise_controblockwise_controlnet_inpaint_masklnet_image'

In [None]:
data

In [None]:
json_path

In [16]:
#保存
with open(json_path, "w", encoding="utf-8") as f:
    json.dump(data, f, ensure_ascii=False, indent=4)


In [None]:
import os
import json
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

json_path = "prepared_data_original/metadata.json"
base_ref = "prepared_data_original/ref"
base_tgt = "prepared_data_original/tgt_original"

with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)

def check_and_fix(item):
    result = item.copy()
    valid_ref = False
    valid_tgt = False
    # 检查ref图片
    if item.get("ref_file"):
        ref_name = item["ref_file"].split("/")[-1]
        ref_path = os.path.join(base_ref, ref_name)
        if os.path.isfile(ref_path):
            result["ref_file"] = f"ref/{ref_name}"
            valid_ref = True
        else:
            result["ref_file"] = None
    # 检查tgt图片
    if item.get("tgt_original_file"):
        tgt_name = item["tgt_original_file"].split("/")[-1]
        tgt_path = os.path.join(base_tgt, tgt_name)
        if os.path.isfile(tgt_path):
            result["tgt_original_file"] = f"tgt_original/{tgt_name}"
            valid_tgt = True
        else:
            result["tgt_original_file"] = None
    # 检查mask图片
    if item.get("tgt_original_mask_file"):
        mask_name = item["tgt_original_mask_file"].split("/")[-1]
        mask_path = os.path.join(base_tgt, mask_name)
        if os.path.isfile(mask_path):
            result["tgt_original_mask_file"] = f"tgt_original/{mask_name}"
        else:
            result["tgt_original_mask_file"] = None
    return result, valid_ref, valid_tgt

new_data = []
valid_ref = 0
valid_tgt = 0
with ThreadPoolExecutor(max_workers=64) as executor:
    futures = [executor.submit(check_and_fix, item) for item in data]
    for fut in tqdm(as_completed(futures), total=len(futures), desc="检查图片"):
        item, ref_ok, tgt_ok = fut.result()
        if item.get("ref_file") and item.get("tgt_original_file"):
            new_data.append(item)
        if ref_ok: valid_ref += 1
        if tgt_ok: valid_tgt += 1

with open(json_path, "w", encoding="utf-8") as f:
    json.dump(new_data, f, ensure_ascii=False, indent=2)

print(f"有效ref图片数量: {valid_ref}")
print(f"有效tgt_original图片数量: {valid_tgt}")
print("无效图片条目已删除，路径已修正。")

In [None]:
import os
import json

def check_files(data):
    """
    检查JSON数据中引用的所有文件是否存在
    
    参数:
        data: 包含文件路径的字典
    
    返回:
        包含检查结果的字典
    """
    results = {
        "all_exist": True,
        "missing_files": [],
        "existing_files": [],
        "details": {}
    }
    
    # 需要检查的文件字段
    file_fields = [
        "ref_file",
        "ref_mask_file",
        "tgt_original_file",
        "tgt_original_mask_file",
        "image",
        "blockwise_controlnet_image",
        "blockwise_controlnet_inpaint_mask"
    ]
    
    # 检查单个文件字段
    for field in file_fields:
        if field in data and data[field] is not None:
            file_path = data[field]
            exists = os.path.exists(file_path)
            
            results["details"][field] = {
                "path": file_path,
                "exists": exists
            }
            
            if exists:
                results["existing_files"].append(file_path)
            else:
                results["missing_files"].append(file_path)
                results["all_exist"] = False
        else:
            results["details"][field] = {
                "path": None,
                "exists": None
            }
    
    # 检查edit_image数组中的文件
    if "edit_image" in data and data["edit_image"]:
        for idx, file_path in enumerate(data["edit_image"]):
            exists = os.path.exists(file_path)
            field_name = f"edit_image[{idx}]"
            
            results["details"][field_name] = {
                "path": file_path,
                "exists": exists
            }
            
            if exists:
                results["existing_files"].append(file_path)
            else:
                results["missing_files"].append(file_path)
                results["all_exist"] = False
    
    return results


def print_results(results):
    """打印检查结果"""
    print("=" * 60)
    print("文件存在性检查结果")
    print("=" * 60)
    
    if results["all_exist"]:
        print("\n✓ 所有文件都存在！")
    else:
        print(f"\n✗ 发现 {len(results['missing_files'])} 个文件缺失")
    
    print(f"\n存在的文件: {len(results['existing_files'])}")
    print(f"缺失的文件: {len(results['missing_files'])}")
    
    print("\n详细信息:")
    print("-" * 60)
    
    for field, info in results["details"].items():
        if info["path"] is None:
            status = "N/A"
            symbol = "○"
        elif info["exists"]:
            status = "存在"
            symbol = "✓"
        else:
            status = "缺失"
            symbol = "✗"
        
        print(f"{symbol} {field:35} {status:6} {info['path'] or 'null'}")
    
    if results["missing_files"]:
        print("\n" + "=" * 60)
        print("缺失的文件列表:")
        print("=" * 60)
        for file in results["missing_files"]:
            print(f"  - {file}")
    
    print("\n" + "=" * 60)


# 示例用法
if __name__ == "__main__":
    # 你的数据
    data = {
        "id": "7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61",
        "prompt": "Nestled in a dense forest, this item is captured from a front-facing perspective, partly hidden by the foliage under the dappled light of the afternoon, while leaves swirl gently in the cool breeze.",
        "ref_file": "ref/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61.png",
        "ref_mask_file": None,
        "ref_mask_type": None,
        "tgt_original_file": "tgt_original/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61.png",
        "tgt_original_mask_file": "tgt_original/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61_mask.png",
        "image": "tgt_clean/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61.png",
        "edit_image": [
            "tgt_original/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61.png",
            "ref/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61.png"
        ],
        "blockwise_controlnet_image": "tgt_original/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61.png",
        "blockwise_controlnet_inpaint_mask": "tgt_original/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61_mask.png"
    }
    
    # 检查文件
    results = check_files(data)
    
    # 打印结果
    print_results(results)
    
    # 可选：保存结果到JSON文件
    # with open("file_check_results.json", "w", encoding="utf-8") as f:
    #     json.dump(results, f, indent=2, ensure_ascii=False)