# 支持Learnable Layout的Tangram, 使用了Image Prompt

In [None]:
# Notebook setup cell: install/upgrade the packages listed in requirements.txt,
# then upgrade the `rp` helper package that the rest of the notebook uses.
%pip install --upgrade -r requirements.txt
%pip install rp --upgrade


In [None]:
import os

# Environment: use a mirror endpoint to speed up model downloads.
# This has to happen BEFORE the imports below: huggingface_hub reads
# HF_ENDPOINT when it is first imported, and source.stable_diffusion
# presumably pulls it in (via diffusers/transformers) — setting the variable
# afterwards would leave the default endpoint in effect.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import math
import time

import numpy as np
import rp
import torch
import torch.nn as nn
import torch.nn.functional as F

import source.stable_diffusion as sd
from source.learnable_textures import LearnableImageFourier
from source.stable_diffusion_labels import NegativeLabel

# Initialise the compute device and the Stable Diffusion wrapper.
gpu = rp.select_torch_device()
model_name = "runwayml/stable-diffusion-v1-5"
s = sd.StableDiffusion(gpu, model_name)
# Use the device reported by the wrapper for every tensor created later.
device = s.device

print(f"Device: {device} | Model: {model_name}")

In [None]:
def make_xy_grid(h, w, device):
    """Return ``(x, y)`` coordinate grids of shape ``(h, w)`` spanning [-1, 1].

    ``x`` varies along the last axis (columns) and ``y`` along the first axis
    (rows); both tensors live on ``device``.
    """
    row_coords, col_coords = (
        torch.linspace(-1, 1, count, device=device) for count in (h, w)
    )
    # 'ij' indexing: first output follows rows (y), second follows columns (x).
    grid_y, grid_x = torch.meshgrid(row_coords, col_coords, indexing='ij')
    return grid_x, grid_y

def triangle_mask(x, y, p0, p1, p2, sharpness=80.0):
    """Soft (sigmoid-edged) mask of the triangle ``p0, p1, p2`` on grid ``x, y``.

    For every edge an implicit line ``a*x + b*y + c`` is built and oriented so
    that the vertex opposite the edge lies on its non-negative side.  The mask
    is the product of the three sigmoids of those line values, so it is close
    to 1 inside the triangle and close to 0 outside.

    The edge normal is not normalised, so the effective edge softness scales
    with the edge length as well as with ``sharpness``.
    """
    vertices = (p0, p1, p2)
    mask = torch.ones_like(x)
    for k, start in enumerate(vertices):
        end = vertices[(k + 1) % 3]
        opposite = vertices[(k + 2) % 3]

        edge_x = end[0] - start[0]
        edge_y = end[1] - start[1]

        # Line through `start` with normal (-edge_y, edge_x).
        a = -edge_y
        b = edge_x
        c = -(a * start[0] + b * start[1])

        # Flip the line so the interior (where `opposite` lies) is positive.
        if a * opposite[0] + b * opposite[1] + c < 0:
            a, b, c = -a, -b, -c

        mask = mask * torch.sigmoid(sharpness * (a * x + b * y + c))
    return mask

def square_mask(x, y, cx=0.0, cy=0.0, s=0.25, sharpness=80.0):
    """Soft mask of an axis-aligned square centred at ``(cx, cy)``.

    ``s`` is the half side length; each of the four sides is a sigmoid step
    whose steepness is controlled by ``sharpness``.
    """
    above_x_min = torch.sigmoid(sharpness * (x - (cx - s)))
    below_x_max = torch.sigmoid(sharpness * ((cx + s) - x))
    above_y_min = torch.sigmoid(sharpness * (y - (cy - s)))
    below_y_max = torch.sigmoid(sharpness * ((cy + s) - y))
    return above_x_min * below_x_max * above_y_min * below_y_max

def parallelogram_mask(x, y, cx=0.0, cy=0.0, w=0.55, h=0.22, shear=0.6, sharpness=80.0):
    """Soft mask of a parallelogram obtained by shearing a ``w`` x ``h`` box.

    The horizontal test is done on the sheared coordinate ``x - shear*y``,
    the vertical test on ``y`` itself; ``(cx, cy)`` is the box centre in that
    sheared frame.
    """
    sheared_x = x - shear * y
    half_w = w / 2
    half_h = h / 2

    above_u_min = torch.sigmoid(sharpness * (sheared_x - (cx - half_w)))
    below_u_max = torch.sigmoid(sharpness * ((cx + half_w) - sheared_x))
    above_y_min = torch.sigmoid(sharpness * (y - (cy - half_h)))
    below_y_max = torch.sigmoid(sharpness * ((cy + half_h) - y))
    return above_u_min * below_u_max * above_y_min * below_y_max

# Build the 7 tangram pieces as soft masks on a SIZE x SIZE grid.
SIZE = 512
x_grid, y_grid = make_xy_grid(SIZE, SIZE, device)

# Vertex triples (p0, p1, p2) for the five triangles, in T1..T5 order.
_TRIANGLE_VERTICES = [
    ((-0.95, -0.95), (-0.15, -0.95), (-0.95, -0.15)),
    (( 0.15, -0.95), ( 0.95, -0.95), ( 0.95, -0.15)),
    ((-0.95,  0.15), (-0.15,  0.15), (-0.95,  0.95)),
    (( 0.15,  0.15), ( 0.55,  0.15), ( 0.15,  0.55)),
    (( 0.60,  0.20), ( 0.95,  0.20), ( 0.95,  0.55)),
]
T1, T2, T3, T4, T5 = [
    triangle_mask(x_grid, y_grid, v0, v1, v2)
    for v0, v1, v2 in _TRIANGLE_VERTICES
]

# The square and the parallelogram complete the set.
SQ = square_mask(x_grid, y_grid, cx=0.55, cy=0.75, s=0.18)
PA = parallelogram_mask(x_grid, y_grid, cx=0.10, cy=0.75, w=0.55, h=0.22, shear=0.6)

# Order matters: index i here matches piece i of every layout configuration.
piece_masks = [T1, T2, T3, T4, T5, SQ, PA]

In [None]:
class MultiTangramLayout(nn.Module):
    """Learnable affine placement of 7 tangram pieces for several layouts.

    ``self.params`` has shape ``[N, 7, 4]`` (N layouts, 7 pieces).  The four
    raw values per piece are decoded in ``forward`` as::

        tx    = tanh(p[0])      ty    = tanh(p[1])
        rot   = p[2] (radians)  scale = exp(p[3])

    All-zero parameters therefore decode to the identity transform.

    Args:
        device: Torch device on which the parameter tensor is created.
        num_layouts: Number of independent layouts ``N``.
        initial_configs: Optional sequence with one entry per layout, each a
            sequence of per-piece dicts.  Only the keys ``tx``, ``ty``,
            ``rot_deg`` and ``scale`` are read; any other key is ignored.
            Values are given in decoded units (``tx``/``ty`` in (-1, 1),
            rotation in degrees, ``scale`` > 0) and are stored through the
            inverse of the decoding above, so the first forward pass
            reproduces them.  Layouts without an entry keep the identity.
    """

    # Keeps atanh finite when a configured translation is exactly +/-1.
    _TANH_EPS = 1e-6

    def __init__(self, device, num_layouts=2, initial_configs=None):
        super().__init__()
        # Shape [N, 7, 4], N being the number of layouts.
        self.params = nn.Parameter(torch.zeros((num_layouts, 7, 4), device=device))

        if initial_configs:
            # initial_configs is expected to be a list of N lists of dicts.
            with torch.no_grad():
                # zip() stops at the shorter side, so a config list shorter
                # than num_layouts leaves the remaining layouts at identity.
                for n, layout_cfg in zip(range(num_layouts), initial_configs):
                    for i, cfg in enumerate(layout_cfg):
                        self.params[n, i, 0] = self._inverse_tanh(cfg.get('tx', 0))
                        self.params[n, i, 1] = self._inverse_tanh(cfg.get('ty', 0))
                        self.params[n, i, 2] = math.radians(cfg.get('rot_deg', 0))
                        self.params[n, i, 3] = math.log(cfg.get('scale', 1.0))

    @classmethod
    def _inverse_tanh(cls, value):
        """Return the raw parameter whose ``tanh`` equals ``value`` (clamped)."""
        limit = 1.0 - cls._TANH_EPS
        return math.atanh(max(-limit, min(limit, float(value))))

    def forward(self, prime_img, piece_masks, size, layout_idx=0):
        """Cut ``prime_img`` into pieces and re-place them for one layout.

        Args:
            prime_img: Source image; it is multiplied by each 2-D mask and
                resampled on a ``size`` x ``size`` grid, so it is expected to
                be ``(3, size, size)``.
            piece_masks: Sequence of ``(size, size)`` masks, one per piece.
            size: Spatial side length of image and masks.
            layout_idx: Which layout's parameters to use.

        Returns:
            ``(out, transformed_masks)``: the sum of all warped masked pieces,
            and a list with the warped mask of every piece, each of shape
            ``(1, 1, size, size)``.
        """
        out = torch.zeros_like(prime_img)
        transformed_masks = []

        layout_params = self.params[layout_idx]

        for i, piece_mask in enumerate(piece_masks):
            p = layout_params[i]
            tx, ty = torch.tanh(p[0]), torch.tanh(p[1])
            rot_rad = p[2]
            scale = torch.exp(p[3])

            # 2x3 affine matrix handed to affine_grid.  It acts on the
            # normalised sampling coordinates, i.e. it maps each output
            # position to the source position that is read.
            cos_t, sin_t = torch.cos(rot_rad), torch.sin(rot_rad)
            theta = torch.stack([
                torch.stack([scale * cos_t, -scale * sin_t, tx]),
                torch.stack([scale * sin_t,  scale * cos_t, ty])
            ]).unsqueeze(0)

            # Isolate the piece's texture, then warp texture and mask with
            # the same grid so they stay aligned.
            p_img = prime_img * piece_mask.unsqueeze(0)
            grid = F.affine_grid(theta, size=(1, 3, size, size), align_corners=False)
            warped = F.grid_sample(p_img.unsqueeze(0), grid, mode='bilinear',
                                   padding_mode='zeros', align_corners=False).squeeze(0)
            out = out + warped

            m_warped = F.grid_sample(piece_mask.view(1, 1, size, size), grid,
                                     mode='bilinear', padding_mode='zeros', align_corners=False)
            transformed_masks.append(m_warped)

        return out, transformed_masks

In [None]:
# Initial placement of the 7 pieces for layout A, one dict per piece in
# piece_masks order (tx/ty translation, rot_deg rotation in degrees, scale).
ARR_A = [
    {'tx': -0.2, 'ty': -0.2, 'rot_deg': 0, 'scale': 1.0},
    {'tx': 0.2, 'ty': -0.2, 'rot_deg': 90, 'scale': 1.0},
    {'tx': -0.2, 'ty': 0.2, 'rot_deg': -90, 'scale': 0.9},
    {'tx': 0.15, 'ty': 0.15, 'rot_deg': 0, 'scale': 0.9},
    {'tx': 0.35, 'ty': 0.35, 'rot_deg': 180, 'scale': 0.85},
    {'tx': 0.0, 'ty': 0.35, 'rot_deg': 45, 'scale': 0.9},
    {'tx': -0.05, 'ty': 0.05, 'rot_deg': 0, 'scale': 0.95},
]

# Initial placement for layout B.
# NOTE(review): the last piece carries flip_x=True, but
# MultiTangramLayout.__init__ only reads tx, ty, rot_deg and scale, so the
# flag currently has no effect — confirm whether a flip was intended.
ARR_B = [
    {'tx': -0.35, 'ty': -0.1, 'rot_deg': 45, 'scale': 1.0},
    {'tx': 0.1, 'ty': -0.35, 'rot_deg': 135, 'scale': 1.0},
    {'tx': -0.05, 'ty': 0.05, 'rot_deg': 45, 'scale': 0.95},
    {'tx': 0.25, 'ty': 0.15, 'rot_deg': 45, 'scale': 0.9},
    {'tx': 0.35, 'ty': 0.35, 'rot_deg': 45, 'scale': 0.85},
    {'tx': 0.05, 'ty': 0.3, 'rot_deg': 0, 'scale': 0.95},
    {'tx': -0.2, 'ty': 0.25, 'rot_deg': 45, 'scale': 0.95, 'flip_x': True},
]

# One entry per layout, in layout-index order.
all_initial_configs = [ARR_A, ARR_B]

def load_ref_image(path, device, size=512):
    """Load a reference image as a ``(1, 3, size, size)`` float tensor.

    Args:
        path: Image file path passed to ``rp.load_image``.
        device: Torch device the resulting tensor is moved to.
        size: Side length the image is resized to.  The default of 512
            matches the resolution used for Stable Diffusion in this notebook.

    Returns:
        Tensor of shape ``(1, 3, size, size)`` on ``device``.
    """
    img = rp.load_image(path)
    img = rp.resize_image(img, (size, size))  # must match the SD resolution
    img = np.asarray(img)

    # Normalise the channel layout to H x W x 3 so that grayscale and RGBA
    # files work too (permute(2, 0, 1) below needs a 3-D, 3-channel array).
    if img.ndim == 2:
        img = img[:, :, None]
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    else:
        img = img[:, :, :3]  # drop alpha if present

    # Dividing by 255 assumes 0-255 pixel values — TODO confirm what
    # rp.load_image / rp.resize_image return for non-uint8 inputs.
    img_t = torch.from_numpy(np.ascontiguousarray(img))
    img_t = img_t.permute(2, 0, 1).float().unsqueeze(0) / 255.0
    return img_t.to(device)
    
# Reference images used as image prompts; entry i belongs to prompt/layout i
# ("3.jpg" -> prompt A, "2.jpg" -> prompt B).
_REF_IMAGE_PATHS = ("3.jpg", "2.jpg")
ref_images = [load_ref_image(ref_path, device) for ref_path in _REF_IMAGE_PATHS]
# Per-layout weight passed to train_step as ref_image_coef.
ref_image_coefs = [0.3, 0.2]

# Text prompts, one per layout; the first one is left empty.
target_prompts = ["", "a blue robot, cinematic lighting, high quality"]
labels = [
    NegativeLabel(prompt, "low quality, blurry")
    for prompt in target_prompts
]

# Learnable texture shared by all layouts, plus the learnable piece placement.
prime = LearnableImageFourier(height=SIZE, width=SIZE).to(device)
multi_layout = MultiTangramLayout(
    device,
    num_layouts=2,
    initial_configs=all_initial_configs,
).to(device)

# Separate learning rates: the texture moves faster than the layout.
optim = torch.optim.Adam([
    dict(params=prime.parameters(), lr=1e-3),
    dict(params=multi_layout.parameters(), lr=5e-4),
])


In [None]:

# --- Training configuration ---
NUM_ITER = 4000          # number of optimisation steps
DISPLAY_INTERVAL = 100   # a preview is rendered every this many steps
LAMBDA_OVERLAP = 25.0    # weight of the piece-overlap penalty
LAMBDA_COMPACT = 0.05    # weight of the translation (compactness) penalty

# Live preview channel shown in the notebook output.
television = rp.JupyterDisplayChannel()
television.display()

# Progress/ETA reporter, called once per iteration.
display_eta = rp.eta(NUM_ITER, title='Multi-Tangram Training')

# Preview frames collected during training (saved later as a timelapse).
ims = []

def get_constraints_multi(masks, idx, model):
    """Return ``(overlap_loss, compact_loss)`` for layout ``idx``.

    * ``overlap_loss``: the masks are concatenated along dim 1 and summed;
      wherever that coverage exceeds 1.1 the excess is squared, and the mean
      over all positions is returned.
    * ``compact_loss``: mean square of the first two raw parameters (the
      translation entries) of every piece in ``model.params[idx]``.
    """
    coverage = torch.cat(masks, dim=1).sum(dim=1)
    excess = F.relu(coverage - 1.1)
    overlap_loss = torch.mean(excess ** 2)

    translations = model.params[idx, :, :2]
    compact_loss = torch.mean(translations ** 2)

    return overlap_loss, compact_loss

# Main optimisation loop.  Each iteration renders every layout from the shared
# texture, runs one diffusion guidance step per layout, adds the geometric
# penalties, and takes a single optimiser step.  Ctrl-C / "interrupt kernel"
# stops training early while keeping everything learned so far.
try:
    for iter_num in range(NUM_ITER):
        display_eta(iter_num)
        
        # Accumulates the geometric penalty over all layouts of this step.
        total_geom_loss = 0
        
        for idx in range(len(labels)):
            
            # Re-arrange the current texture according to layout `idx`.
            img, masks = multi_layout(prime(), piece_masks, SIZE, layout_idx=idx)
            
            # NOTE(review): train_step returns nothing that is used here, so it
            # presumably back-propagates the guidance gradient into `img`'s
            # graph itself (and must retain that graph for the backward call
            # below to work) — confirm against source.stable_diffusion.
            s.train_step(
                text_embeddings = labels[idx].embedding,
                pred_rgb = img[None],
                guidance_scale = 100,
                noise_coef = 0.1,
                ref_image = ref_images[idx],   
                ref_image_coef = ref_image_coefs[idx]
            )
            
            # Geometric constraints for this layout.
            overlap_l, compact_l = get_constraints_multi(masks, idx, multi_layout)
            total_geom_loss += (overlap_l * LAMBDA_OVERLAP) + (compact_l * LAMBDA_COMPACT)
        
        # Add the geometric gradients to whatever train_step accumulated,
        # then update both parameter groups and reset gradients.
        total_geom_loss.backward()
        optim.step()
        optim.zero_grad()
        
        # Periodic preview (also fires on iteration 0).
        if not iter_num % DISPLAY_INTERVAL:
            with torch.no_grad():
                current_views = []
                for idx in range(len(labels)):
                    v_img, _ = multi_layout(prime(), piece_masks, SIZE, layout_idx=idx)
                    current_views.append(rp.as_numpy_image(v_img.clamp(0, 1)))
                

                # The un-arranged texture is shown next to the layouts.
                base_tex = rp.as_numpy_image(prime().clamp(0, 1))
                
                combined = rp.tiled_images(current_views + [base_tex], length=len(labels)+1)
                
                television.update(combined)
                # Keep the frame so save_run can write a timelapse later.
                ims.append(combined)
                rp.display_image(combined)

except KeyboardInterrupt:
    print('Interrupted at', iter_num)
    

In [None]:
# Report the learned placement of every piece in every layout, decoded the
# same way MultiTangramLayout.forward decodes the raw parameters
# (tanh for translation, exp for scale; rotation converted to degrees).
all_params = multi_layout.params
# Iterate over the actual number of layouts instead of a hard-coded 2.
for layout_idx in range(all_params.size(0)):
    current_layout = all_params[layout_idx]

    print(f"Learned Tangram Layout {layout_idx} (tx, ty, rot_deg, scale):")
    for i in range(current_layout.size(0)):
        p = current_layout[i]

        print(f"Piece {i}: tx={torch.tanh(p[0]).item():.2f}, "
              f"ty={torch.tanh(p[1]).item():.2f}, "
              f"rot={math.degrees(p[2].item()):.1f}°, "
              f"scale={torch.exp(p[3]).item():.2f}")

In [None]:
import time
def save_run(name):
    """Save every preview frame in ``ims`` as numbered PNG files.

    Frames go to ``untracked/learnabletangram_runs/<name>``.  If that folder
    already exists, the current Unix timestamp is appended to the folder name
    so that an earlier run is not overwritten.
    """
    folder = f"untracked/learnabletangram_runs/{name}"
    already_there = rp.path_exists(folder)
    if already_there:
        folder = folder + f"_{int(time.time())}"
    rp.make_directory(folder)

    # File names are relative: they are resolved inside `folder` below.
    ims_names = [f"ims_{index:04d}.png" for index, _ in enumerate(ims)]
    with rp.SetCurrentDirectoryTemporarily(folder):
        rp.save_images(ims, ims_names, show_progress=True)

    print('Saved timelapse to:', folder)

# Write the collected preview frames of this run to disk.
save_run('tangram_a-b')
