# 支持Image Prompt的Tangram

In [None]:
%pip install --upgrade -r requirements.txt
%pip install rp --upgrade


In [None]:
# 环境与依赖导入
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
import time
import math
import numpy as np
import torch
import torch.nn.functional as F
import rp
import source.stable_diffusion as sd
from source.learnable_textures import LearnableImageFourier
from source.stable_diffusion_labels import NegativeLabel


In [None]:
# Helper: load a reference image for the image prompt.
def load_ref_image(path, device):
    """Load an image file as a (1, 3, 512, 512) float tensor on *device*.

    The image is read with rp.load_image and resized to 512x512 with
    rp.resize_image. It is then forced to three channels: grayscale or
    single-channel images are replicated to RGB, and RGBA images (common for
    PNG) have their alpha channel dropped. Finally it is scaled to [0, 1].

    Args:
        path: Path of the image file.
        device: Torch device the returned tensor is moved to.

    Returns:
        Float tensor of shape (1, 3, 512, 512).

    Raises:
        FileNotFoundError: If *path* does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"找不到图片文件: {os.path.abspath(path)}")
    img = rp.load_image(path)
    img = rp.resize_image(img, (512, 512))
    img = np.asarray(img)

    # Normalise the channel layout to HxWx3 so permute(2, 0, 1) always
    # yields exactly 3 channels.
    if img.ndim == 2:
        img = img[:, :, None]
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] == 4:
        img = img[:, :, :3]

    # Integer images are scaled by their full-scale value.
    # Float/bool images are assumed to already be in [0, 1].
    if img.dtype == np.uint16:
        scale = 65535.0
    elif np.issubdtype(img.dtype, np.integer):
        scale = 255.0
    else:
        scale = 1.0

    img_t = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) / scale
    return img_t.to(device)


In [None]:
# Device and model initialisation (edit model_name to use another checkpoint).
model_name = "runwayml/stable-diffusion-v1-5"
gpu = rp.select_torch_device()
s = sd.StableDiffusion(gpu, model_name)
# Every tensor created later in the notebook is placed on the model's device.
device = s.device


In [None]:
# Reference images for the image prompt. The training cell reads ref_image_a/b
# and ref_image_coef_a/b, so skipping this cell there causes a NameError.
ref_image_a = load_ref_image("pictures/Original for Image Prompt/3.png", device)
ref_image_b = load_ref_image("pictures/Original for Image Prompt/2.png", device)
ref_image_coef_a = ref_image_coef_b = 0.1


In [None]:
# Prompts (one per view) and training hyper-parameters.
prompt_a = "a red fox, ultra detailed"
prompt_b = "a blue robot, cinematic lighting"
negative_prompt = "blurry, low quality"

SIZE = 256              # height and width of the learnable image
NUM_ITER = 6000         # can be set small temporarily in the notebook for testing
DISPLAY_INTERVAL = 200  # iterations between preview refreshes

label_a, label_b = (NegativeLabel(text, negative_prompt) for text in (prompt_a, prompt_b))


In [None]:
# Learnable Fourier texture shared by both views, and its optimizer.
prime = LearnableImageFourier(
    height=SIZE,
    width=SIZE,
    hidden_dim=256,
    num_features=256,
    scale=10,
)
prime = prime.to(device)

# Adam over the texture's parameters; adjust the learning rate as needed.
optim = torch.optim.Adam(prime.parameters(), lr=1e-3)


## 七巧板几何与掩码函数

In [None]:
def make_xy_grid(h, w, device):
    """Return coordinate tensors (x, y), each of shape (h, w), spanning [-1, 1].

    x varies along the width (columns) and y along the height (rows).
    """
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(-1, 1, h, device=device),
        torch.linspace(-1, 1, w, device=device),
        indexing='ij',
    )
    return grid_x, grid_y

def triangle_mask(x, y, p0, p1, p2, sharpness=80.0):
    """Soft indicator of the triangle p0-p1-p2 evaluated on coordinate tensors x, y.

    Each edge contributes a factor sigmoid(sharpness * (a*x + b*y + c)). The
    line equation is oriented so the opposite vertex lies on the positive side.
    Factors are multiplied together, so the result is ~1 inside, ~0 outside and
    ~0.5 on an edge. The line equation is not normalised by edge length, so the
    edge softness varies slightly between edges. Vertex order does not matter.
    """
    verts = (p0, p1, p2)
    mask = torch.ones_like(x)
    for k, start in enumerate(verts):
        end = verts[(k + 1) % 3]
        opposite = verts[(k + 2) % 3]
        dx = end[0] - start[0]
        dy = end[1] - start[1]
        # Normal of the edge direction, offset so the line passes through `start`.
        a, b = -dy, dx
        c = -(a * start[0] + b * start[1])
        # Flip the sign so the triangle interior is on the positive side.
        if a * opposite[0] + b * opposite[1] + c < 0:
            a, b, c = -a, -b, -c
        mask = mask * torch.sigmoid(sharpness * (a * x + b * y + c))
    return mask

def square_mask(x, y, cx=0.0, cy=0.0, s=0.25, sharpness=80.0):
    """Soft axis-aligned square of half-side *s* centred at (cx, cy).

    The result is the product of four sigmoid edge factors: ~1 inside, ~0
    outside, ~0.5 on an edge.
    """
    x_lo = torch.sigmoid(sharpness * (x - (cx - s)))
    x_hi = torch.sigmoid(sharpness * ((cx + s) - x))
    y_lo = torch.sigmoid(sharpness * (y - (cy - s)))
    y_hi = torch.sigmoid(sharpness * ((cy + s) - y))
    return x_lo * x_hi * y_lo * y_hi

def parallelogram_mask(x, y, cx=0.0, cy=0.0, w=0.55, h=0.22, shear=0.6, sharpness=80.0):
    """Soft parallelogram centred at (cx, cy).

    The shape is a w-by-h box in the sheared coordinate u = x - shear * (y - cy).
    Its horizontal edges are at y = cy +- h/2. Its slanted edges are at
    x = cx +- w/2 + shear * (y - cy). The result is the product of four sigmoid
    edge factors: ~1 inside, ~0 outside.
    """
    # Shear about the shape's own centre row. Shearing about y = 0 would shift
    # the whole shape horizontally by shear*cy, so (cx, cy) would not be its
    # centre whenever cy != 0.
    u = x - shear * (y - cy)
    return (torch.sigmoid(sharpness*(u-(cx-w/2))) * torch.sigmoid(sharpness*((cx+w/2)-u)) *
            torch.sigmoid(sharpness*(y-(cy-h/2))) * torch.sigmoid(sharpness*((cy+h/2)-y)))

# Tangram pieces on a SIZE x SIZE grid with coordinates in [-1, 1]:
# five triangles, one square and one parallelogram.
x, y = make_xy_grid(SIZE, SIZE, device)
T1 = triangle_mask(x, y, (-0.95, -0.95), (-0.15, -0.95), (-0.95, -0.15))
T2 = triangle_mask(x, y, ( 0.15, -0.95), ( 0.95, -0.95), ( 0.95, -0.15))
T3 = triangle_mask(x, y, (-0.95,  0.15), (-0.15,  0.15), (-0.95,  0.95))
T4 = triangle_mask(x, y, ( 0.15,  0.15), ( 0.55,  0.15), ( 0.15,  0.55))
T5 = triangle_mask(x, y, ( 0.60,  0.20), ( 0.95,  0.20), ( 0.95,  0.55))
SQ = square_mask(x, y, cx=0.55, cy=0.75, s=0.18)
PA = parallelogram_mask(x, y, cx=0.10, cy=0.75, w=0.55, h=0.22, shear=0.6)

piece_masks = [T1, T2, T3, T4, T5, SQ, PA]
# Divide by the per-pixel total so overlapping pieces share a pixel rather than
# duplicating it. The clamp avoids dividing by ~0 where no piece covers the pixel.
mask_sum = sum(piece_masks).clamp(min=1e-3)
piece_masks = [piece / mask_sum for piece in piece_masks]


## 视图变换与排布函数

In [None]:
def warp(img_chw, tx, ty, rot_deg, scale, flip_x):
    """Resample a (C, H, W) image through a 2-D affine transform.

    theta = [[sx*cos, -scale*sin, tx], [sx*sin, scale*cos, ty]], with
    sx = -scale when flip_x is true and sx = scale otherwise; the angle is
    rot_deg in degrees. F.affine_grid interprets theta as mapping *output*
    normalised coordinates to the *input* coordinates that are sampled. The
    visible effect is therefore the inverse mapping: positive tx moves content
    toward -x, and scale > 1 shrinks it. Samples falling outside the input are
    zero (padding_mode='zeros').

    Args:
        img_chw: Image tensor of shape (C, H, W).
        tx, ty: Translation in normalised [-1, 1] coordinates.
        rot_deg: Rotation angle in degrees.
        scale: Uniform scale factor.
        flip_x: Whether to mirror along x.

    Returns:
        Tensor with the same shape, dtype and device as img_chw.
    """
    th = math.radians(rot_deg)
    cos_t, sin_t = math.cos(th), math.sin(th)
    sx = -scale if flip_x else scale
    # Build theta on the image's own device/dtype instead of the global `device`
    # with default float32, so float64/half or other-device inputs work.
    theta = torch.tensor(
        [[[sx * cos_t, -scale * sin_t, tx], [sx * sin_t, scale * cos_t, ty]]],
        device=img_chw.device,
        dtype=img_chw.dtype,
    )
    # Output size follows the input instead of the global SIZE.
    grid = F.affine_grid(theta, size=(1, *img_chw.shape), align_corners=False)
    return F.grid_sample(img_chw.unsqueeze(0), grid, mode='bilinear', padding_mode='zeros', align_corners=False).squeeze(0)

def tangram_view(prime_img, arr):
    """Cut prime_img into the tangram pieces and reassemble them according to *arr*.

    arr[i] is a dict of warp parameters for piece_masks[i]. Missing keys default
    to tx=0, ty=0, rot_deg=0, scale=1, flip_x=False. Each masked piece is warped
    independently, and the results are summed onto a zero canvas shaped like
    prime_img.
    """
    keys_and_defaults = (('tx', 0), ('ty', 0), ('rot_deg', 0), ('scale', 1), ('flip_x', False))
    canvas = torch.zeros_like(prime_img)
    for index, params in enumerate(arr):
        piece = prime_img * piece_masks[index][None]
        warp_args = [params.get(key, default) for key, default in keys_and_defaults]
        canvas = canvas + warp(piece, *warp_args)
    return canvas


In [None]:
# Two layouts; entry i holds the warp parameters for piece_masks[i]
# (order: T1, T2, T3, T4, T5, SQ, PA).
ARR_A = [
    dict(tx=-0.2, ty=-0.2, rot_deg=0, scale=1.0),
    dict(tx=0.2, ty=-0.2, rot_deg=90, scale=1.0),
    dict(tx=-0.2, ty=0.2, rot_deg=-90, scale=0.9),
    dict(tx=0.15, ty=0.15, rot_deg=0, scale=0.9),
    dict(tx=0.35, ty=0.35, rot_deg=180, scale=0.85),
    dict(tx=0.0, ty=0.35, rot_deg=45, scale=0.9),
    dict(tx=-0.05, ty=0.05, rot_deg=0, scale=0.95),
]
ARR_B = [
    dict(tx=-0.35, ty=-0.1, rot_deg=45, scale=1.0),
    dict(tx=0.1, ty=-0.35, rot_deg=135, scale=1.0),
    dict(tx=-0.05, ty=0.05, rot_deg=45, scale=0.95),
    dict(tx=0.25, ty=0.15, rot_deg=45, scale=0.9),
    dict(tx=0.35, ty=0.35, rot_deg=45, scale=0.85),
    dict(tx=0.05, ty=0.3, rot_deg=0, scale=0.95),
    dict(tx=-0.2, ty=0.25, rot_deg=45, scale=0.95, flip_x=True),
]


def learnable_image_a():
    """Render view A: the shared prime texture rearranged with ARR_A."""
    return tangram_view(prime(), ARR_A)


def learnable_image_b():
    """Render view B: the shared prime texture rearranged with ARR_B."""
    return tangram_view(prime(), ARR_B)


## 训练循环（SDS）
运行前可在上方修改 NUM_ITER、DISPLAY_INTERVAL 等参数。

In [None]:
labels = [label_a, label_b]

# Per-view weights, rescaled so they sum to the number of views (mean 1).
weights = rp.as_numpy_array([1, 1])
weights = weights / weights.sum() * len(weights)

# NOTE(review): presumably the timestep range train_step samples from — confirm in source.stable_diffusion.
s.max_step = 990
s.min_step = 10

# Preview frames appended during training (later written out by save_run).
ims = []

@torch.no_grad()
def get_display_image():
    """Build a preview image from view A, view B and the raw prime texture.

    Each render is clamped to [0, 1] and converted with rp.as_numpy_image. The
    three are then combined with rp.tiled_images(length=3, border_thickness=0).
    Autograd is disabled for the whole call.
    """
    renderers = (learnable_image_a, learnable_image_b, prime)
    frames = [rp.as_numpy_image(render().clamp(0, 1)) for render in renderers]
    return rp.tiled_images(frames, length=3, border_thickness=0)

# Live preview channel and a progress/ETA reporter (called once per iteration).
television = rp.JupyterDisplayChannel()
television.display()
display_eta = rp.eta(NUM_ITER, title='Tangram training')

# One tuple per view:
# (label, view renderer, weight scaling noise_coef, reference image, reference coefficient).
train_data = list(zip(
    (label_a, label_b),
    (learnable_image_a, learnable_image_b),
    (weights[0], weights[1]),
    (ref_image_a, ref_image_b),
    (ref_image_coef_a, ref_image_coef_b),
))

try:
    for iter_num in range(NUM_ITER):
        display_eta(iter_num)

        # Train on one randomly chosen view per iteration. The loop never calls
        # .backward() itself, so train_step must push gradients into prime.
        # NOTE(review): confirm against source.stable_diffusion.
        for label, render_view, w, ref_img, ref_coef in rp.random_batch(train_data, batch_size=1):
            s.train_step(
                text_embeddings=label.embedding,
                pred_rgb=render_view()[None],
                noise_coef=0.1 * w,
                guidance_scale=80,
                ref_image=ref_img,
                ref_image_coef=ref_coef,
            )

        optim.step()
        optim.zero_grad()

        if not iter_num % DISPLAY_INTERVAL:
            # get_display_image already runs under torch.no_grad().
            im = get_display_image()
            ims.append(im)
            television.update(im)
            # Show the frame just rendered instead of rendering a second copy.
            rp.display_image(im)

except KeyboardInterrupt:
    # An interrupt between train_step and zero_grad leaves gradients behind.
    # Clear them so they are not applied on the first step if this cell is re-run.
    optim.zero_grad()
    print('Interrupted at', iter_num)
    rp.display_image(get_display_image())


In [None]:
# Save the recorded preview frames to disk.
def save_run(name):
    """Write every frame in the global ``ims`` list as a numbered PNG.

    Frames go to untracked/tangram_image_prompt_runs/<name>_<unix seconds>/ and
    are named ims_0000.png, ims_0001.png, ... The folder path is printed.
    """
    timestamp = int(time.time())
    folder = f"untracked/tangram_image_prompt_runs/{name}_{timestamp}"
    rp.make_directory(folder)
    frame_names = ["ims_%04d.png" % index for index in range(len(ims))]
    with rp.SetCurrentDirectoryTemporarily(folder):
        rp.save_images(ims, frame_names, show_progress=True)
    print('Saved timelapse to:', folder)

# save_run('tangram_a-b')