# Diffusion Illusions: Tangram (七巧板)

优化一个 prime image，使其在两种七巧板排布（view A / view B）下呈现不同语义内容。

- prime: `LearnableImageFourier`
- pieces: 7 个可微 mask（5 三角 + 1 正方形 + 1 平行四边形）
- view: 对每块做可微仿射变换（`affine_grid` + `grid_sample`）并合成
- loss: `StableDiffusion.train_step()` (SDS)


In [None]:
%pip install --upgrade -r requirements.txt
%pip install rp --upgrade


In [None]:
import numpy as np
import rp
import torch
import torch.nn.functional as F
import source.stable_diffusion as sd
from source.learnable_textures import LearnableImageFourier
from source.stable_diffusion_labels import NegativeLabel
import math
import time


In [None]:

# --- Run configuration -------------------------------------------------------
# One text prompt per tangram arrangement, plus a negative prompt shared by both.
negative_prompt = "blurry, low quality"
prompt_a = "a red fox, ultra detailed"  # shown by arrangement A
prompt_b = "a blue robot, cinematic lighting"  # shown by arrangement B

# Image side length in pixels, number of optimisation iterations, and how
# often (in iterations) a preview frame is rendered.
SIZE, NUM_ITER, DISPLAY_INTERVAL = 256, 6000, 200

# Echo the prompts so the notebook output records what this run was asked for.
print('prompt_a =', prompt_a)
print('prompt_b =', prompt_b)
print('negative_prompt =', negative_prompt)


In [None]:

# Load the Stable Diffusion wrapper on the device chosen by rp, and reuse the
# device the wrapper reports for every tensor created later in the notebook.
model_name = "runwayml/stable-diffusion-v1-5"
gpu = rp.select_torch_device()
s = sd.StableDiffusion(gpu, model_name)
device = s.device

# One label per view; both pair their own prompt with the shared negative prompt.
label_a, label_b = (
    NegativeLabel(view_prompt, negative_prompt)
    for view_prompt in (prompt_a, prompt_b)
)


In [None]:

# The single learnable "prime" image that both tangram views are cut from.
prime = LearnableImageFourier(
    height=SIZE,
    width=SIZE,
    hidden_dim=256,
    num_features=256,
    scale=10,
)
prime = prime.to(device)

# Plain SGD over the prime image's parameters.
optim = torch.optim.SGD(prime.parameters(), lr=1e-4)


In [None]:

def make_xy_grid(h, w, device):
    """Return per-pixel coordinate planes spanning [-1, 1] on both axes.

    Args:
        h: number of rows.
        w: number of columns.
        device: torch device on which the coordinates are created.

    Returns:
        A pair ``(x, y)`` of ``(h, w)`` tensors. ``x`` varies along columns and
        ``y`` along rows; -1 is the first column/row and +1 the last.
    """
    col_coords = torch.linspace(-1, 1, w, device=device)
    row_coords = torch.linspace(-1, 1, h, device=device)
    # Broadcast the 1-D coordinate vectors to full (h, w) planes.
    x_plane = col_coords.unsqueeze(0).expand(h, w)
    y_plane = row_coords.unsqueeze(1).expand(h, w)
    return x_plane, y_plane

def halfplane_mask(x, y, a, b, c, sharpness=80.0):
    """Soft indicator of the half-plane ``a*x + b*y + c > 0``.

    Args:
        x, y: coordinate tensors of matching shape.
        a, b, c: coefficients of the line ``a*x + b*y + c = 0``.
        sharpness: slope of the sigmoid; larger values give a harder edge.

    Returns:
        Tensor shaped like ``x`` with values in (0, 1): close to 1 on the
        positive side of the line, 0.5 exactly on it, close to 0 on the other side.
    """
    signed_value = a * x + b * y + c
    return (sharpness * signed_value).sigmoid()

def polygon_mask(x, y, halfplanes, sharpness=80.0):
    """Soft mask of the intersection of several half-planes.

    Args:
        x, y: coordinate tensors of matching shape.
        halfplanes: iterable of ``(a, b, c)`` triples; each keeps the side
            where ``a*x + b*y + c > 0``.
        sharpness: sigmoid slope forwarded to ``halfplane_mask``.

    Returns:
        Product of the individual half-plane masks, shaped like ``x``.
        An empty ``halfplanes`` yields a mask of all ones.
    """
    mask = torch.ones_like(x)
    for coef_x, coef_y, offset in halfplanes:
        edge_mask = halfplane_mask(x, y, coef_x, coef_y, offset, sharpness)
        mask = mask * edge_mask
    return mask

def triangle_mask(x, y, p0, p1, p2, sharpness=80.0):
    """Soft mask of the triangle with vertices ``p0``, ``p1``, ``p2``.

    Args:
        x, y: coordinate tensors of matching shape.
        p0, p1, p2: vertices as ``(x, y)`` pairs, in either winding order.
        sharpness: sigmoid slope used for every edge.

    Returns:
        Tensor shaped like ``x``: the product of three half-plane masks, one
        per edge, each oriented to keep the side containing the third vertex.
    """
    vertices = (p0, p1, p2)
    edge_ends = vertices[1:] + vertices[:1]
    opposite_vertices = vertices[2:] + vertices[:2]

    inward_halfplanes = []
    for start, end, opposite in zip(vertices, edge_ends, opposite_vertices):
        edge_x = end[0] - start[0]
        edge_y = end[1] - start[1]
        # A normal to the edge, and the offset that puts `start` on the line.
        a, b = -edge_y, edge_x
        c = -(a * start[0] + b * start[1])
        # Flip the line if the remaining vertex falls on its negative side,
        # which makes the result independent of vertex winding.
        if a * opposite[0] + b * opposite[1] + c < 0:
            a, b, c = -a, -b, -c
        inward_halfplanes.append((a, b, c))
    return polygon_mask(x, y, inward_halfplanes, sharpness)

def square_mask(x, y, cx=0.0, cy=0.0, s=0.25, sharpness=80.0):
    """Soft mask of an axis-aligned square centred at ``(cx, cy)``.

    Args:
        x, y: coordinate tensors of matching shape.
        cx, cy: centre of the square.
        s: half the side length (the square spans ``cx - s .. cx + s``).
        sharpness: sigmoid slope used for each of the four edges.

    Returns:
        Tensor shaped like ``x`` with values in (0, 1).
    """
    x_low, x_high = cx - s, cx + s
    y_low, y_high = cy - s, cy + s
    past_x_low = torch.sigmoid(sharpness * (x - x_low))
    before_x_high = torch.sigmoid(sharpness * (x_high - x))
    past_y_low = torch.sigmoid(sharpness * (y - y_low))
    before_y_high = torch.sigmoid(sharpness * (y_high - y))
    return past_x_low * before_x_high * past_y_low * before_y_high

def parallelogram_mask(x, y, cx=0.0, cy=0.0, w=0.55, h=0.22, shear=0.6, sharpness=80.0):
    """Soft mask of a horizontally sheared rectangle centred at ``(cx, cy)``.

    The shape has two horizontal edges at ``cy - h/2`` and ``cy + h/2``; its
    slanted edges shift in x by ``shear`` per unit of y.

    Args:
        x, y: coordinate tensors of matching shape.
        cx, cy: centre of the parallelogram.
        w: width measured along a horizontal line through the shape.
        h: height (extent in y).
        shear: horizontal offset per unit of vertical offset from the centre.
        sharpness: sigmoid slope used for each of the four edges.

    Returns:
        Tensor shaped like ``x`` with values in (0, 1).
    """
    # Work in coordinates relative to the centre. Shearing about the centre
    # (rather than about the origin) keeps the shape at (cx, cy) for any cy;
    # for cy == 0 this reduces to u = x - shear*y measured against cx.
    dy = y - cy
    u = (x - cx) - shear * dy
    half_w = w / 2
    half_h = h / 2
    return (
        torch.sigmoid(sharpness * (u + half_w)) * torch.sigmoid(sharpness * (half_w - u)) *
        torch.sigmoid(sharpness * (dy + half_h)) * torch.sigmoid(sharpness * (half_h - dy))
    )

# Coordinate planes of the prime image; every piece mask is defined on them.
x, y = make_xy_grid(SIZE, SIZE, device)

# Layout of the seven pieces inside the prime image: five triangles, one
# square and one parallelogram.
T1 = triangle_mask(x, y, (-0.95, -0.95), (-0.15, -0.95), (-0.95, -0.15))
T2 = triangle_mask(x, y, ( 0.15, -0.95), ( 0.95, -0.95), ( 0.95, -0.15))
T3 = triangle_mask(x, y, (-0.95,  0.15), (-0.15,  0.15), (-0.95,  0.95))
T4 = triangle_mask(x, y, ( 0.15,  0.15), ( 0.55,  0.15), ( 0.15,  0.55))
T5 = triangle_mask(x, y, ( 0.60,  0.20), ( 0.95,  0.20), ( 0.95,  0.55))
SQ = square_mask(x, y, cx=0.55, cy=0.75, s=0.18)
PA = parallelogram_mask(x, y, cx=0.10, cy=0.75, w=0.55, h=0.22, shear=0.6)

# The order of this list is the order in which tangram_view pairs pieces with
# the per-piece parameter dicts of an arrangement.
piece_masks = [T1, T2, T3, T4, T5, SQ, PA]

# Renormalise only where pieces overlap (summed coverage above 1), so that no
# pixel of the prime image is counted more than once. Clamping the divisor at
# 1.0 leaves isolated pieces untouched; a tiny floor such as 1e-3 would instead
# push every mask to exactly 1.0 wherever its sigmoid tail exceeds the floor,
# erasing the soft edges and enlarging each piece.
mask_sum = torch.clamp(sum(piece_masks), min=1.0)
piece_masks = [m / mask_sum for m in piece_masks]


In [None]:
# Affine warping and tangram view composition
def affine_matrix_2x3(tx=0.0, ty=0.0, rot_deg=0.0, scale=1.0, flip_x=False):
    """Build a 2x3 affine matrix from translation, rotation, scale and flip.

    The linear part is a rotation by ``rot_deg`` whose columns are scaled by
    ``scale``; with ``flip_x`` the first column is negated. The last column
    is the translation ``(tx, ty)``.

    In this file the result is passed to ``warp`` as the ``theta`` of
    ``F.affine_grid``, which maps output coordinates to the input locations
    to sample from.

    Args:
        tx, ty: translation, in normalised [-1, 1] image coordinates.
        rot_deg: rotation angle in degrees.
        scale: uniform scale factor.
        flip_x: negate the x column of the linear part.

    Returns:
        A float32 tensor of shape (2, 3) on the module-level ``device``.
    """
    angle = math.radians(rot_deg)
    cos_t = math.cos(angle)
    sin_t = math.sin(angle)
    if flip_x:
        x_scale = -scale
    else:
        x_scale = scale
    rows = [
        [x_scale * cos_t, -scale * sin_t, tx],
        [x_scale * sin_t, scale * cos_t, ty],
    ]
    return torch.tensor(rows, device=device, dtype=torch.float32)

def warp(img_chw, A_2x3, out_size=SIZE):
    """Resample an image through an affine sampling grid.

    Args:
        img_chw: image tensor with three dimensions (channels, height, width).
        A_2x3: (2, 3) affine matrix used as ``theta`` for ``F.affine_grid``.
        out_size: side length of the square output.

    Returns:
        Tensor of shape (channels, out_size, out_size), bilinearly sampled
        from ``img_chw``; samples falling outside the input are zero.
    """
    # Unpacking three values also rejects inputs that are not 3-D.
    channels, _, _ = img_chw.shape
    batched_theta = A_2x3[None]
    target_shape = (1, channels, out_size, out_size)
    sampling_grid = F.affine_grid(batched_theta, size=target_shape, align_corners=False)
    warped = F.grid_sample(
        img_chw[None],
        sampling_grid,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=False,
    )
    # Drop the batch axis that was added for the functional API.
    return warped[0]

def tangram_view(prime_img_chw, arr_params):
    """Compose one tangram view by moving every piece of the prime image.

    Each piece is the prime image multiplied by its mask from the module-level
    ``piece_masks``; it is warped with its own affine transform and the warped
    pieces are summed.

    Args:
        prime_img_chw: prime image tensor, (channels, height, width).
        arr_params: one dict per piece, in ``piece_masks`` order. Recognised
            keys (all optional): ``tx``, ``ty``, ``rot_deg``, ``scale``,
            ``flip_x``; missing keys default to 0, 0, 0, 1 and False.

    Returns:
        The composed view, clamped to [0, 1].

    Raises:
        ValueError: if the number of parameter dicts differs from the number
            of pieces.
    """
    arr_params = list(arr_params)
    # zip() would silently drop the surplus pieces or parameters; a mismatch
    # is almost certainly a mistake in the arrangement, so report it.
    if len(arr_params) != len(piece_masks):
        raise ValueError(
            f"expected {len(piece_masks)} piece parameter dicts, got {len(arr_params)}"
        )

    out = torch.zeros_like(prime_img_chw)
    for mask, prm in zip(piece_masks, arr_params):
        A = affine_matrix_2x3(
            tx=prm.get('tx', 0.0), ty=prm.get('ty', 0.0),
            rot_deg=prm.get('rot_deg', 0.0), scale=prm.get('scale', 1.0),
            flip_x=prm.get('flip_x', False),
        )
        # Cut the piece out just before warping it, so only one masked copy
        # of the image is alive at a time.
        piece = prime_img_chw * mask.unsqueeze(0)
        out = out + warp(piece, A, out_size=SIZE)
    # Overlapping warped pieces can sum above 1; keep the result a valid image.
    return out.clamp(0, 1)


In [None]:
# The two arrangements (A and B).
# Each list holds one parameter dict per piece, in the same order as
# piece_masks; tangram_view reads the keys tx, ty, rot_deg, scale and flip_x.

ARR_A = [
    {'tx': -0.20, 'ty': -0.20, 'rot_deg': 0, 'scale': 1.00},
    {'tx': 0.20, 'ty': -0.20, 'rot_deg': 90, 'scale': 1.00},
    {'tx': -0.20, 'ty': 0.20, 'rot_deg': -90, 'scale': 0.90},
    {'tx': 0.15, 'ty': 0.15, 'rot_deg': 0, 'scale': 0.90},
    {'tx': 0.35, 'ty': 0.35, 'rot_deg': 180, 'scale': 0.85},
    {'tx': 0.00, 'ty': 0.35, 'rot_deg': 45, 'scale': 0.90},
    {'tx': -0.05, 'ty': 0.05, 'rot_deg': 0, 'scale': 0.95},
]

ARR_B = [
    {'tx': -0.35, 'ty': -0.10, 'rot_deg': 45, 'scale': 1.00},
    {'tx': 0.10, 'ty': -0.35, 'rot_deg': 135, 'scale': 1.00},
    {'tx': -0.05, 'ty': 0.05, 'rot_deg': 45, 'scale': 0.95},
    {'tx': 0.25, 'ty': 0.15, 'rot_deg': 45, 'scale': 0.90},
    {'tx': 0.35, 'ty': 0.35, 'rot_deg': 45, 'scale': 0.85},
    {'tx': 0.05, 'ty': 0.30, 'rot_deg': 0, 'scale': 0.95},
    {'tx': -0.20, 'ty': 0.25, 'rot_deg': 45, 'scale': 0.95, 'flip_x': True},
]


def learnable_image_a():
    """Render view A from the current state of the prime image."""
    return tangram_view(prime(), ARR_A)


def learnable_image_b():
    """Render view B from the current state of the prime image."""
    return tangram_view(prime(), ARR_B)


In [None]:
# Training setup (SDS)
# labels[i] is the prompt label that learnable_images[i] is optimised towards.
labels = [label_a, label_b]
learnable_images = [learnable_image_a, learnable_image_b]

# Per-view weights, rescaled so that they average to 1.
weights = rp.as_numpy_array([1, 1])
weights = weights / weights.sum() * len(weights)

# NOTE(review): presumably the range of diffusion timesteps train_step samples
# from — confirm against source.stable_diffusion.
s.max_step = 990
s.min_step = 10

# Preview frames collected during training, in chronological order.
ims = []

def get_display_image():
    """Render view A, view B and the raw prime image as one tiled preview.

    Returns:
        The result of ``rp.tiled_images`` over the three images, laid out
        with ``length=3`` and no border.
    """
    renderers = (learnable_image_a, learnable_image_b, prime)
    frames = [rp.as_numpy_image(render()) for render in renderers]
    return rp.tiled_images(frames, length=3, border_thickness=0)

# Notebook display channel that is updated in place with preview frames, and a
# progress reporter that is called once per iteration.
television = rp.JupyterDisplayChannel(); television.display()
display_eta = rp.eta(NUM_ITER, title='Tangram training')

try:
    for iter_num in range(NUM_ITER):
        display_eta(iter_num)

        # batch_size=1: a single (label, view, weight) triple is processed per
        # iteration, so each step optimises the prime image through one
        # arrangement (presumably picked at random — see rp.random_batch).
        for label, li, w in rp.random_batch(list(zip(labels, learnable_images, weights)), batch_size=1):
            # li() renders the view from the current prime image; [None] adds
            # a leading batch axis.
            # NOTE(review): no backward() call is visible in this loop, so
            # train_step presumably backpropagates into prime's parameters
            # itself — confirm in source.stable_diffusion.
            _ = s.train_step(
                label.embedding,
                li()[None],
                noise_coef=0.1 * w,
                guidance_scale=80,
            )

        # Every DISPLAY_INTERVAL iterations (iteration 0 included), render a
        # preview without tracking gradients, record it and show it.
        with torch.no_grad():
            if not iter_num % DISPLAY_INTERVAL:
                im = get_display_image()
                ims.append(im)
                television.update(im)
                rp.display_image(im)

        # Apply the gradients gathered during this iteration, then clear them.
        optim.step(); optim.zero_grad()

except KeyboardInterrupt:
    # Stopping the cell manually still records and shows one final frame.
    # This render is not wrapped in torch.no_grad().
    print('Interrupted at', iter_num)
    im = get_display_image()
    ims.append(im)
    rp.display_image(im)


In [None]:

def save_run(name):
    """Write the collected preview frames to disk as numbered PNG files.

    Frames from the module-level ``ims`` list are saved into
    ``untracked/tangram_runs/<name>``. If that path already exists, a Unix
    timestamp suffix is appended so the earlier run is not overwritten.

    Args:
        name: name of the run, used as the folder name.
    """
    target_dir = f"untracked/tangram_runs/{name}"
    if rp.path_exists(target_dir):
        target_dir = target_dir + f"_{int(time.time())}"
    rp.make_directory(target_dir)

    # One zero-padded file name per recorded frame.
    frame_names = [f"ims_{i:04d}.png" for i, _ in enumerate(ims)]
    with rp.SetCurrentDirectoryTemporarily(target_dir):
        rp.save_images(ims, frame_names, show_progress=True)
    print('Saved timelapse to:', target_dir)

save_run('tangram_a-b')
