In [None]:
#这个代码对应的是我们论文中的Component-Aware Diffusion Illusions
import numpy as np
import rp
import torch
import torch.nn as nn
import cv2
import source.stable_diffusion as sd
from source.learnable_textures import LearnableImageFourier
from source.stable_diffusion_labels import NegativeLabel
from itertools import chain
import torchvision.transforms.functional as TF
from face_parser import FaceParser

In [None]:
# 1. Load Stable Diffusion (downloaded from ModelScope on first use).
import os
model_name = "./weights/sd-v1-4"
# Automatic download requires modelscope: pip install modelscope
#
# The guard keeps a re-run of this cell from reloading the full model (and
# briefly holding two copies in GPU memory). Run `del s` first to force a
# reload, e.g. after changing `model_name`.
if 's' not in globals():
    # Download the weights only if the local folder is missing.
    if not os.path.exists(model_name):
        print(f"模型路径 {model_name} 不存在，正在通过 ModelScope 自动下载...")
        try:
            from modelscope import snapshot_download
            snapshot_download('AI-ModelScope/stable-diffusion-v1-4', local_dir=model_name)
            print("模型下载完成！")
        except ImportError as err:
            raise ImportError("请先运行 pip install modelscope 以支持自动下载功能") from err
    else:
        print(f"检测到本地模型: {model_name}")
    gpu = rp.select_torch_device()
    s = sd.StableDiffusion(gpu, model_name)
device = s.device

# 2. Input photo and facial-part parsing.
INPUT_IMAGE_PATH = 'images/face1.jpg'  # source face photo
SIZE = 512  # side length in pixels requested from the parser (target_size)

print("正在解析人脸五官...")
parser = FaceParser(model_path='cp/79999_iter.pth')
parse_result = parser.parse(INPUT_IMAGE_PATH, target_size=(SIZE, SIZE))
# The parser returns a dict; pull out the per-part masks and the resized photo.
masks, original_img = parse_result['masks'], parse_result['original_img']

rp.display_image(original_img)
print("解析完成。可用掩码:", masks.keys())

# 3. Map each facial part (a mask name) to the fruit / plate it should become.
#    "Global" is not a mask name; it is the prompt for the whole, unmasked image.
_EYE_PROMPT = "a single dark blackberry fruit, realistic, macro photography, high contrast, white background"

# Insertion order matters: the training loop samples from list(prompt_map.keys()).
prompt_map = dict(
    Left_Eye=_EYE_PROMPT,
    Right_Eye=_EYE_PROMPT,
    Nose="a single ripe pear, top view, subtle shadows, white background",
    Mouth="a fresh banana slice, detailed texture, macro, white background",
    # Key idea: the plate itself carries a faint human-face relief.
    Face_Skin="a white ceramic plate with faint human face contours, marble texture, smooth porcelain, minimalist",
    # Whole image: emphasise the Arcimboldo style (a face composed of fruit).
    Global="a portrait of a person made entirely out of fruits, Giuseppe Arcimboldo style, surrealism, high quality, fruits arranged in a face shape",
)

# Negative prompt.
# NOTE(review): not passed to any call in the cells shown here, and "color"
# seems at odds with a colourful fruit image — confirm the intent.
negative_prompt = "color, messy, human skin, realistic eye, teeth, noise, blur, watermark, text, low quality"

# Convert every parser mask into a float32 tensor of shape [1, 1, H, W] on the
# model device. Masks are assumed to be 2-D (H, W) arrays with values in {0, 1}
# — TODO confirm against FaceParser.
# The explicit float32 cast matters: a bool mask would make `1 - mask` raise
# inside get_masked_view, and a float64 mask would silently promote the
# composited view to float64 before it reaches Stable Diffusion.
mask_tensors = {
    name: torch.as_tensor(mask).to(device=device, dtype=torch.float32)[None, None]
    for name, mask in masks.items()
}

In [None]:
# Learnable image, parameterised by Fourier features, optimised directly.
main_image = LearnableImageFourier(height=SIZE, width=SIZE, hidden_dim=256, num_features=128).to(device)

# Make every parameter trainable explicitly (typically already the default).
for param in main_image.parameters():
    param.requires_grad = True

LEARNING_RATE = 1e-3
optim = torch.optim.Adam(main_image.parameters(), lr=LEARNING_RATE)

# Encode each part's prompt once up front; the training loop only looks them up.
print("正在编码 Prompt Embeddings...")
embeddings = {part: s.get_text_embeddings(text) for part, text in prompt_map.items()}

print("Embeddings 准备就绪。")

def get_masked_view(image_tensor, mask_name, background=1.0):
    """Composite the learnable image over a flat background using a part mask.

    The result is linear in ``image_tensor``, so gradients flow back to the
    image wherever the mask is non-zero. Optimising a masked view therefore
    only updates that facial part's pixels.

    Args:
        image_tensor: Rendered image tensor, assumed (3, H, W) or (1, 3, H, W)
            with values in [0, 1] — TODO confirm against LearnableImageFourier.
        mask_name: Key into the module-level ``mask_tensors`` dict, or
            ``"Global"`` to return the image unmasked.
        background: Fill value for pixels outside the mask (1.0 = white).

    Returns:
        ``image_tensor`` itself for ``"Global"``. Otherwise the composite, whose
        shape is the broadcast of the image and mask shapes.

    Raises:
        KeyError: if ``mask_name`` is neither ``"Global"`` nor in ``mask_tensors``.
    """
    if mask_name == "Global":
        return image_tensor

    # Match the image's dtype and device. A bool mask cannot be used in
    # `1 - mask`, and a float64 mask would promote the view to float64.
    mask = mask_tensors[mask_name].to(image_tensor)
    return image_tensor * mask + background * (1 - mask)

In [None]:
import torch
import gc
import os

import traceback

# NOTE(review): PyTorch reads PYTORCH_CUDA_ALLOC_CONF when its CUDA caching
# allocator is initialised. Stable Diffusion was already placed on `device` in
# an earlier cell, so on a CUDA device this assignment most likely has no
# effect here. Move it to the first cell, before any CUDA work, if it is needed.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# --- Training schedule --------------------------------------------------------
NUM_ITER = 5000          # total optimisation steps
DISPLAY_INTERVAL = 200   # preview the image and free cached memory every N steps
SWITCH_INTERVAL = 20     # optimise the same facial part for N consecutive steps

# --- Score-distillation settings (formerly inline magic numbers) --------------
NOISE_COEF = 0.1
PART_GUIDANCE_SCALE = 60    # per-part prompts (eyes, nose, mouth, skin)
GLOBAL_GUIDANCE_SCALE = 30  # whole-image "Global" prompt

display_eta = rp.eta(NUM_ITER, title='生成进度: ')
keys = list(prompt_map.keys())
chosen_part = np.random.choice(keys)

print('开始生成 "水果人脸拼盘"...')

gc.collect()
torch.cuda.empty_cache()

try:
    for iter_num in range(1, NUM_ITER + 1):
        display_eta(iter_num)

        # Pick a new part at the start of each SWITCH_INTERVAL-step window
        # (steps 1, 21, 41, ...). The old `iter_num % 20 == 0` test gave the
        # first window only 19 steps.
        if iter_num > 1 and (iter_num - 1) % SWITCH_INTERVAL == 0:
            chosen_part = np.random.choice(keys)

        optim.zero_grad()

        with torch.enable_grad():
            current_image = main_image()
            view = get_masked_view(current_image, chosen_part)
            if view.dim() == 3:
                view = view.unsqueeze(0)  # add a batch dimension

            # There is no loss.backward() here: s.train_step is expected to
            # back-propagate internally. TODO confirm in source.stable_diffusion.
            # The former `view + sum(p.sum() for p in params) * 0` trick was
            # removed. `view` already depends on main_image's parameters, and a
            # zero-valued term could only mask a broken graph with zero gradients.
            loss = s.train_step(
                embeddings[chosen_part],
                view,
                noise_coef=NOISE_COEF,
                guidance_scale=GLOBAL_GUIDANCE_SCALE if chosen_part == "Global" else PART_GUIDANCE_SCALE,
            )

        optim.step()

        if iter_num % DISPLAY_INTERVAL == 0:
            with torch.no_grad():
                # Preview shows the image rendered before this step's update.
                im = rp.as_numpy_image(current_image.detach())
                print(f"\n--- 第 {iter_num} 轮 | 部位: {chosen_part} ---")
                rp.display_image(im)
            gc.collect()
            torch.cuda.empty_cache()

        # Drop this step's graph references before the next forward pass.
        del view, current_image

except Exception as e:
    # Best effort: report the failure but keep the partially trained
    # `main_image` available in the kernel for inspection.
    traceback.print_exc()
    print(f"\n[运行出错]: {e}")
else:
    # Only claim completion when the loop actually finished.
    print("训练完成！")
finally:
    gc.collect()
    torch.cuda.empty_cache()