In [None]:
import torch
from safetensors import safe_open
from safetensors.torch import save_file
import torch.nn as nn

import os  # only needed here, to make sure the output directory exists

# Checkpoint to read and the file the zeroed copy is written to.
model_path = "/home2/y2024/s2430069/DiffSynth-Studio/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint/model.safetensors"
output_path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint/model_zero.safetensors"  # separate file so the source checkpoint is not overwritten

# Load the .safetensors file into a state dict (tensor name -> tensor) on the CPU.
print(f"正在从 {model_path} 加载模型权重...")
with safe_open(model_path, framework="pt", device="cpu") as f:
    model_state_dict = {k: f.get_tensor(k) for k in f.keys()}

print("模型中的原始键和形状:")
for key in list(model_state_dict)[:5]:  # only the first five keys, as a sample
    print(f"- {key}: {model_state_dict[key].shape}")
print("...")

# --- Zero every tensor in the checkpoint ---
print("\n正在将整个模型的所有权重清零...")

for key, tensor in model_state_dict.items():
    # In-place fill with zeros; shape and dtype are kept.
    nn.init.zeros_(tensor)
    print(f"已将张量 '{key}' 清零。")

# --- Save the modified checkpoint ---
print(f"\n正在将修改后的权重保存到 {output_path}...")
# NOTE(review): save_file is not documented to create missing parent directories,
# so create the target directory explicitly instead of relying on it.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
save_file(model_state_dict, output_path)

print("操作完成！")

# Verify the saved file: every tensor (not just the first one) must be all zeros.
print("\n正在验证已保存的文件...")
with safe_open(output_path, framework="pt", device="cpu") as f:
    saved_keys = list(f.keys())
    non_zero_keys = [k for k in saved_keys if not torch.all(f.get_tensor(k) == 0)]

if non_zero_keys:
    for key in non_zero_keys:
        print(f"验证失败：'{key}' 未被完全清零。")
else:
    print(f"验证成功：全部 {len(saved_keys)} 个张量中的所有值都为 0。")

正在从 /home2/y2024/s2430069/DiffSynth-Studio/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint/model.safetensors 加载模型权重...
模型中的原始键和形状:
- controlnet_blocks.0.input_proj.bias: torch.Size([3072])
- controlnet_blocks.0.input_proj.weight: torch.Size([3072, 3072])
- controlnet_blocks.0.output_proj.bias: torch.Size([3072])
- controlnet_blocks.0.output_proj.weight: torch.Size([3072, 3072])
- controlnet_blocks.0.x_rms.weight: torch.Size([3072])
...

正在将指定层的权重清零...
已将张量 'controlnet_blocks.0.input_proj.bias' 清零。
已将张量 'controlnet_blocks.0.input_proj.weight' 清零。
已将张量 'controlnet_blocks.0.output_proj.bias' 清零。
已将张量 'controlnet_blocks.0.output_proj.weight' 清零。
已将张量 'controlnet_blocks.0.x_rms.weight' 清零。
已将张量 'controlnet_blocks.0.y_rms.weight' 清零。
已将张量 'controlnet_blocks.1.input_proj.bias' 清零。
已将张量 'controlnet_blocks.1.input_proj.weight' 清零。
已将张量 'controlnet_blocks.1.output_proj.bias' 清零。
已将张量 'controlnet_blocks.1.output_proj.weight' 清零。
已将张量 'controlnet_blocks.1.x_rms.weight' 清零。
已将张量 'co

In [3]:


# NOTE(review): `model` is not defined in any cell visible in this notebook, so this
# cell only works with leftover kernel state — confirm what it refers to before re-running.
# NOTE(review): the target is the relative form of the checkpoint path loaded in the first
# cell; presumably this overwrites the original weights when the kernel runs from the
# DiffSynth-Studio checkout — verify that this is intended.
from safetensors.torch import save_file
save_file(model, "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint/model.safetensors")


模型中的键:
- controlnet_blocks.0.input_proj.bias: torch.Size([3072])
- controlnet_blocks.0.input_proj.weight: torch.Size([3072, 3072])
- controlnet_blocks.0.output_proj.bias: torch.Size([3072])
- controlnet_blocks.0.output_proj.weight: torch.Size([3072, 3072])
- controlnet_blocks.0.x_rms.weight: torch.Size([3072])
- controlnet_blocks.0.y_rms.weight: torch.Size([3072])
- controlnet_blocks.1.input_proj.bias: torch.Size([3072])
- controlnet_blocks.1.input_proj.weight: torch.Size([3072, 3072])
- controlnet_blocks.1.output_proj.bias: torch.Size([3072])
- controlnet_blocks.1.output_proj.weight: torch.Size([3072, 3072])
- controlnet_blocks.1.x_rms.weight: torch.Size([3072])
- controlnet_blocks.1.y_rms.weight: torch.Size([3072])
- controlnet_blocks.10.input_proj.bias: torch.Size([3072])
- controlnet_blocks.10.input_proj.weight: torch.Size([3072, 3072])
- controlnet_blocks.10.output_proj.bias: torch.Size([3072])
- controlnet_blocks.10.output_proj.weight: torch.Size([3072, 3072])
- controlnet_blocks

AttributeError: 'dict' object has no attribute 'img_in'

In [None]:
import json

json_path = "prepared_data_original/metadata.json"

# Metadata field -> sub-directory prefix its stored path must carry.
_PATH_PREFIXES = {
    "ref_file": "ref",
    "tgt_original_file": "tgt_original",
    "tgt_original_mask_file": "tgt_original",
    "edit_image": "tgt_clean",
}


def _with_prefix(value, prefix):
    """Return `value` with `prefix/` prepended, unless it already starts with it.

    Empty values and non-string values (e.g. an `edit_image` that is already a
    list of paths) are returned unchanged, so re-running the cell on an already
    processed metadata file is a no-op instead of producing `ref/ref/...`.
    """
    if not isinstance(value, str) or not value:
        return value
    if value.startswith(f"{prefix}/"):
        return value
    return f"{prefix}/{value}"


with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)

for item in data:
    for field, prefix in _PATH_PREFIXES.items():
        if field in item:
            item[field] = _with_prefix(item[field], prefix)

with open(json_path, "w", encoding="utf-8") as f:
    json.dump(data, f, ensure_ascii=False, indent=2)

print("图片路径已批量修正为 ref/ 和 tgt_original/ 子目录。")

In [None]:
# Preview only the first few records instead of dumping the whole list into the output.
data[:3]

In [None]:
import os
import json
import tqdm

json_path = "prepared_data_original/metadata.json"
base_ref = "prepared_data_original/ref"
base_tgt = "prepared_data_original/tgt_original"
# "edit_image" is rewritten to tgt_clean/<name>, so its file has to be looked up
# in the tgt_clean directory rather than in tgt_original.
base_clean = "prepared_data_original/tgt_clean"

with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)


def _normalize_path(item, field, base_dir, prefix):
    """Point item[field] at `prefix/<file name>` if that file exists in `base_dir`.

    The field is set to None when the file is missing. Empty values and
    non-string values (e.g. a list-valued "edit_image") are left untouched.

    Returns:
        True when the file was found and the path rewritten, False otherwise.
    """
    value = item.get(field)
    if not value or not isinstance(value, str):
        return False
    name = value.split("/")[-1]
    if os.path.isfile(os.path.join(base_dir, name)):
        item[field] = f"{prefix}/{name}"
        return True
    item[field] = None
    return False


# Count the ref / tgt_original images that actually exist on disk.
valid_ref = 0
valid_tgt = 0
for item in tqdm.tqdm(data):
    if _normalize_path(item, "ref_file", base_ref, "ref"):
        valid_ref += 1
    if _normalize_path(item, "tgt_original_file", base_tgt, "tgt_original"):
        valid_tgt += 1
    _normalize_path(item, "tgt_original_mask_file", base_tgt, "tgt_original")
    _normalize_path(item, "edit_image", base_clean, "tgt_clean")

# Drop entries whose ref or tgt_original image is missing.
data = [item for item in data if item.get("ref_file") and item.get("tgt_original_file")]

with open(json_path, "w", encoding="utf-8") as f:
    json.dump(data, f, ensure_ascii=False, indent=2)

print(f"有效ref图片数量: {valid_ref}")
print(f"有效tgt_original图片数量: {valid_tgt}")
print("无效图片条目已删除，路径已修正。")

In [None]:
# Rename the single-path "edit_image" field to "image" (pointing into tgt_clean/)
# and remove the old key.
for item in data:
    edit_path = item.get("edit_image")
    # Only a non-empty string is the not-yet-renamed form. A list-valued
    # "edit_image" must be left alone, otherwise re-running this cell would
    # fail on `.split`.
    if isinstance(edit_path, str) and edit_path:
        item["image"] = f"tgt_clean/{edit_path.split('/')[-1]}"
        del item["edit_image"]


In [None]:
# Add the conditioning fields derived from paths already present on each record:
#   edit_image                        -> [tgt_original_file, ref_file]
#   blockwise_controlnet_image        -> tgt_original_file
#   blockwise_controlnet_inpaint_mask -> tgt_original_mask_file
# The former `del` of a garbled, never-existing key is removed: it raised
# KeyError on the first record.
for item in data:
    item["edit_image"] = [item.get("tgt_original_file"), item.get("ref_file")]
    item["blockwise_controlnet_image"] = item.get("tgt_original_file")
    item["blockwise_controlnet_inpaint_mask"] = item.get("tgt_original_mask_file")


In [25]:
# Preview only the first few records instead of dumping the whole list into the output.
data[:3]

[{'id': '0000749a25d7882bad58e6a34200804db1da5a5ff704abf96b96f17688c03f4d',
  'prompt': 'At the foot of a rainforest trail under heavy overcast skies, it rests on dewy leaves, with towering green foliage all around.',
  'ref_file': 'ref/0000749a25d7882bad58e6a34200804db1da5a5ff704abf96b96f17688c03f4d.png',
  'ref_mask_file': None,
  'ref_mask_type': None,
  'tgt_original_file': 'tgt_original/0000749a25d7882bad58e6a34200804db1da5a5ff704abf96b96f17688c03f4d.png',
  'tgt_original_mask_file': 'tgt_original/0000749a25d7882bad58e6a34200804db1da5a5ff704abf96b96f17688c03f4d_mask.png',
  'image': 'tgt_clean/0000749a25d7882bad58e6a34200804db1da5a5ff704abf96b96f17688c03f4d.png',
  'edit_image': ['tgt_original/0000749a25d7882bad58e6a34200804db1da5a5ff704abf96b96f17688c03f4d.png',
   'ref/0000749a25d7882bad58e6a34200804db1da5a5ff704abf96b96f17688c03f4d.png'],
  'blockwise_controlnet_image': 'tgt_original/0000749a25d7882bad58e6a34200804db1da5a5ff704abf96b96f17688c03f4d.png',
  'blockwise_controlnet_

In [24]:
# Display the metadata path currently bound to `json_path` in the kernel.
json_path

'prepared_data_original/metadata.json'

In [None]:
# Write the in-memory records back to the metadata file (4-space indented JSON).
with open(json_path, mode="w", encoding="utf-8") as metadata_file:
    json.dump(data, metadata_file, indent=4, ensure_ascii=False)


In [None]:
import os
import json
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# Metadata file, plus the directories its ref / tgt_original images are checked against.
json_path = "prepared_data_original/metadata.json"
base_ref = "prepared_data_original/ref"
base_tgt = "prepared_data_original/tgt_original"

# Read every metadata record into memory.
with open(json_path, mode="r", encoding="utf-8") as metadata_file:
    data = json.load(metadata_file)

def check_and_fix(item):
    """Validate the image paths of one metadata record.

    Works on a shallow copy of `item`. For "ref_file", "tgt_original_file" and
    "tgt_original_mask_file", a non-empty value is reduced to its file name and
    looked up in `base_ref` (ref) or `base_tgt` (tgt and mask). When the file
    exists the field becomes "ref/<name>" or "tgt_original/<name>"; otherwise
    it is set to None. Empty or absent fields are left as they are.

    Returns:
        (record copy, ref file found, tgt_original file found)
    """
    fixed = item.copy()
    ref_found = False
    tgt_found = False

    # Reference image.
    ref_value = item.get("ref_file")
    if ref_value:
        ref_name = ref_value.rsplit("/", 1)[-1]
        ref_found = os.path.isfile(os.path.join(base_ref, ref_name))
        fixed["ref_file"] = f"ref/{ref_name}" if ref_found else None

    # Target image.
    tgt_value = item.get("tgt_original_file")
    if tgt_value:
        tgt_name = tgt_value.rsplit("/", 1)[-1]
        tgt_found = os.path.isfile(os.path.join(base_tgt, tgt_name))
        fixed["tgt_original_file"] = f"tgt_original/{tgt_name}" if tgt_found else None

    # Target mask; it does not contribute to the returned flags.
    mask_value = item.get("tgt_original_mask_file")
    if mask_value:
        mask_name = mask_value.rsplit("/", 1)[-1]
        mask_found = os.path.isfile(os.path.join(base_tgt, mask_name))
        fixed["tgt_original_mask_file"] = f"tgt_original/{mask_name}" if mask_found else None

    return fixed, ref_found, tgt_found

new_data = []
valid_ref = 0
valid_tgt = 0
with ThreadPoolExecutor(max_workers=64) as executor:
    # executor.map yields results in the order of `data`, so the surviving
    # records keep their original order. Iterating as_completed() instead
    # would reshuffle the metadata file differently on every run.
    checked = executor.map(check_and_fix, data)
    for item, ref_ok, tgt_ok in tqdm(checked, total=len(data), desc="检查图片"):
        # Keep only records whose ref and tgt_original images both exist.
        if item.get("ref_file") and item.get("tgt_original_file"):
            new_data.append(item)
        if ref_ok:
            valid_ref += 1
        if tgt_ok:
            valid_tgt += 1

with open(json_path, "w", encoding="utf-8") as f:
    json.dump(new_data, f, ensure_ascii=False, indent=2)

print(f"有效ref图片数量: {valid_ref}")
print(f"有效tgt_original图片数量: {valid_tgt}")
print("无效图片条目已删除，路径已修正。")

In [None]:
import os
import json

def _record_check(results, label, file_path, base_dir):
    """Check one path and store the outcome under `label` in `results`."""
    if file_path is None:
        results["details"][label] = {"path": None, "exists": None}
        return

    # Resolve against base_dir only for a non-empty path, so that an empty
    # string is still reported as missing rather than matching base_dir itself.
    if base_dir is not None and file_path:
        target = os.path.join(base_dir, file_path)
    else:
        target = file_path
    exists = os.path.exists(target)

    results["details"][label] = {"path": file_path, "exists": exists}
    if exists:
        results["existing_files"].append(file_path)
    else:
        results["missing_files"].append(file_path)
        results["all_exist"] = False


def check_files(data, base_dir=None):
    """
    Check whether every file referenced by one metadata record exists.

    Args:
        data: dict holding the file paths of a single record.
        base_dir: optional directory the stored paths are relative to. With the
            default (None) paths are checked as given, i.e. relative paths are
            resolved against the current working directory.

    Returns:
        dict with:
            "all_exist": False as soon as one referenced file is missing.
            "missing_files" / "existing_files": paths as stored in `data`.
            "details": per-field {"path", "exists"}; both are None for a field
                that is absent or None.
    """
    results = {
        "all_exist": True,
        "missing_files": [],
        "existing_files": [],
        "details": {}
    }

    # Single-path fields to check.
    file_fields = [
        "ref_file",
        "ref_mask_file",
        "tgt_original_file",
        "tgt_original_mask_file",
        "image",
        "blockwise_controlnet_image",
        "blockwise_controlnet_inpaint_mask"
    ]

    for field in file_fields:
        _record_check(results, field, data.get(field), base_dir)

    # "edit_image" is normally a list of paths. A single path string is wrapped
    # so that it is not iterated character by character.
    edit_images = data.get("edit_image")
    if isinstance(edit_images, (str, os.PathLike)):
        edit_images = [edit_images]
    for idx, file_path in enumerate(edit_images or []):
        _record_check(results, f"edit_image[{idx}]", file_path, base_dir)

    return results


def print_results(results):
    """Print a human-readable report for a result dict produced by check_files.

    Sections, in order: banner, overall verdict, existing/missing counts, one
    line per entry of results["details"], and, only when files are missing,
    the list of missing paths.
    """
    rule = "=" * 60

    print(rule)
    print("文件存在性检查结果")
    print(rule)

    # Overall verdict.
    verdict = (
        "\n✓ 所有文件都存在！"
        if results["all_exist"]
        else f"\n✗ 发现 {len(results['missing_files'])} 个文件缺失"
    )
    print(verdict)

    print(f"\n存在的文件: {len(results['existing_files'])}")
    print(f"缺失的文件: {len(results['missing_files'])}")

    print("\n详细信息:")
    print("-" * 60)

    # One line per checked field: marker, field name, status, stored path.
    for field, info in results["details"].items():
        path = info["path"]
        if path is None:
            symbol, status = "○", "N/A"
        elif info["exists"]:
            symbol, status = "✓", "存在"
        else:
            symbol, status = "✗", "缺失"
        print(f"{symbol} {field:35} {status:6} {path or 'null'}")

    missing = results["missing_files"]
    if missing:
        print("\n" + rule)
        print("缺失的文件列表:")
        print(rule)
        for file in missing:
            print(f"  - {file}")

    print("\n" + rule)


# Example usage: run the existence check on one hard-coded sample record.
# NOTE(review): this rebinds the notebook-level name `data` (a list of records in
# the other cells) to a single record dict.
# NOTE(review): the sample paths are relative and check_files passes them to
# os.path.exists unchanged, so they are resolved against the current working
# directory — confirm the kernel runs from the dataset root.
if __name__ == "__main__":
    # Sample record
    data = {
        "id": "7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61",
        "prompt": "Nestled in a dense forest, this item is captured from a front-facing perspective, partly hidden by the foliage under the dappled light of the afternoon, while leaves swirl gently in the cool breeze.",
        "ref_file": "ref/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61.png",
        "ref_mask_file": None,
        "ref_mask_type": None,
        "tgt_original_file": "tgt_original/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61.png",
        "tgt_original_mask_file": "tgt_original/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61_mask.png",
        "image": "tgt_clean/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61.png",
        "edit_image": [
            "tgt_original/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61.png",
            "ref/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61.png"
        ],
        "blockwise_controlnet_image": "tgt_original/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61.png",
        "blockwise_controlnet_inpaint_mask": "tgt_original/7cb96767571821fc1453bf8a557d9566605a47a8f00f3488ef792d432950dc61_mask.png"
    }
    
    # Check the referenced files
    results = check_files(data)
    
    # Print the report
    print_results(results)
    
    # Optional: save the results to a JSON file
    # with open("file_check_results.json", "w", encoding="utf-8") as f:
    #     json.dump(results, f, indent=2, ensure_ascii=False)