In [None]:
%pip install --upgrade -r requirements.txt
%pip install rp --upgrade
# You may need to restart the runtime (kernel) after these installs finish.
# The reason is unclear, but skipping the restart tends to cause assorted random errors in Colab.

In [None]:
import numpy as np
import rp
import torch
import torch.nn as nn
import source.stable_diffusion as sd
from easydict import EasyDict
from source.learnable_textures import LearnableImageFourier
from source.stable_diffusion_labels import NegativeLabel
from itertools import chain
import time
import torchvision.transforms.functional as TF
import math

In [None]:
import os
# Automatic download needs ModelScope: pip install modelscope

# Load Stable Diffusion only once per kernel session. Everything that is
# expensive (download check, device selection, model construction) lives
# inside this guard, so re-running the cell reuses the existing `s` instead
# of building a second copy of the model.
if 's' not in dir():
    # --- 1. Where the model weights are stored ---
    # A relative path (a `weights` folder under the current directory) keeps
    # the notebook portable between machines.
    model_name = "./weights/sd-v1-4"

    # --- 2. Check for local weights and download them when missing ---
    if not os.path.exists(model_name):
        print(f"模型路径 {model_name} 不存在，正在通过 ModelScope 自动下载...")
        # Only the import itself is guarded: an ImportError raised from inside
        # the download must not be misreported as "modelscope is not installed".
        try:
            from modelscope import snapshot_download
        except ImportError as err:
            raise ImportError("请先运行 pip install modelscope 以支持自动下载功能") from err
        snapshot_download('AI-ModelScope/stable-diffusion-v1-4', local_dir=model_name)
        print("模型下载完成！")
    else:
        print(f"检测到本地模型: {model_name}")

    # --- 3. Build the model on the selected torch device ---
    gpu=rp.select_torch_device()
    s=sd.StableDiffusion(gpu,model_name)

# Always refreshed, so later cells can rely on `device` after a re-run too.
device=s.device

In [None]:
# One prompt per rotation step: the number of prompts listed here decides how
# many rotated views are generated.
prompts_list = [
    "a lion",
    "a mountain",
    "a skull",
    "a rose",
    "an eye",
    # "a dog",
    # "a cat"
]
negative_prompt = 'ugly blur low quality'

N = len(prompts_list)
ANGLE_STEP = 360.0 / N  # degrees between two consecutive views

# Echo the configuration so the run can be identified from its output.
print(f"\n=== Current configuration: generating {N} images (each rotated by {ANGLE_STEP:.2f}°) ===\n")
print('Negative prompt:', repr(negative_prompt))
print('Chosen prompts:')
for idx, prompt in enumerate(prompts_list):
    print(f'    Image {idx+1} (Angle {idx*ANGLE_STEP:.1f}°) = {repr(prompt)}')

# One label per prompt, each paired with the shared negative prompt.
labels = []
for prompt in prompts_list:
    labels.append(NegativeLabel(prompt, negative_prompt))

In [None]:
SIZE = 256  # height and width passed to every learnable image


def learnable_image_maker():
    """Build a fresh SIZE x SIZE LearnableImageFourier and move it to the model's device."""
    return LearnableImageFourier(height=SIZE, width=SIZE, hidden_dim=256, num_features=128).to(s.device)


# The two layers that get optimised; the bottom one is created first.
bottom_image, top_image = learnable_image_maker(), learnable_image_maker()

# NOTE(review): simulate_overlay assigns its own local `brightness`, so this
# module-level value is not what that function reads.
brightness = 3

# True: simulate_overlay uses a fixed exponent and brightness.
# False: it draws random values for them on every call.
CLEAN_MODE = True

# A circular mask keeps the visible part of every image round.
def create_circular_mask(h, w, target_device=None):
    """Return an (h, w) float mask that is 1 inside the largest centred circle and 0 outside.

    The circle is centred on pixel (int(w/2), int(h/2)); its radius is the
    distance from that centre to the nearest image edge.

    Args:
        h: Mask height in pixels.
        w: Mask width in pixels.
        target_device: Device to place the mask on. When omitted, the
            notebook-level `device` is used, as before.

    Returns:
        A float tensor of shape (h, w) containing only 0.0 and 1.0.
    """
    if target_device is None:
        # Backward-compatible default: fall back to the notebook global.
        target_device = device

    center_x, center_y = int(w/2), int(h/2)
    radius = min(center_x, center_y, w-center_x, h-center_y)

    # Column vector of row indices and row vector of column indices; they
    # broadcast to the full (h, w) grid without building a meshgrid.
    rows = torch.arange(h).unsqueeze(1)
    cols = torch.arange(w).unsqueeze(0)

    # Compare squared integer distances: exact, and no float sqrt that could
    # round a just-outside pixel onto the boundary.
    dist_sq = (cols - center_x)**2 + (rows - center_y)**2
    inside = dist_sq <= radius**2
    return inside.float().to(target_device)

# Shared circular mask, built once for the SIZE x SIZE canvas.
mask_tensor = create_circular_mask(SIZE, SIZE)


def simulate_overlay(bottom, top, angle_degrees):
    """Compose `bottom` with `top` rotated by `angle_degrees`, then apply the circular mask.

    The rotated top layer (bilinear interpolation) is multiplied elementwise
    with the bottom layer, scaled by a gain, clamped to [0, 99], passed
    through tanh and finally multiplied by `mask_tensor`.

    With CLEAN_MODE on, the exponent is 1 and the gain is 3. Otherwise both
    are drawn with rp.random_float, and each layer is first passed through
    rp.blend together with a shared random dark level.
    NOTE(review): rp.blend presumably interpolates toward that level; confirm in rp.
    """
    top_rotated = TF.rotate(top, angle_degrees, interpolation=TF.InterpolationMode.BILINEAR)

    if not CLEAN_MODE:
        # Draw order matters for reproducibility: exponent, gain, dark level,
        # then one blend amount per layer.
        gamma = rp.random_float(.5, 1)
        gain = rp.random_float(1, 5)
        dark_level = rp.random_float(0, .5)
        bottom = rp.blend(bottom, dark_level, rp.random_float())
        top_rotated = rp.blend(top_rotated, dark_level, rp.random_float())
    else:
        gamma = 1
        gain = 3

    product = bottom**gamma * top_rotated**gamma * gain
    squashed = product.clamp(0, 99).tanh()
    return squashed * mask_tensor

# Rotation applied to the top layer for each prompt, in degrees.
angles = [step_index * ANGLE_STEP for step_index in range(N)]


def _overlay_at(angle):
    """Return a zero-argument callable that renders the overlay at `angle` degrees."""
    def render():
        return simulate_overlay(bottom_image(), top_image(), angle)
    return render


# One lazily evaluated view per angle; calling an entry re-renders it from the
# current state of both layers.
learnable_images = [_overlay_at(angle) for angle in angles]

# `chain` yields a one-shot iterator, so `params` is presumably exhausted once
# the optimiser has read it; do not reuse it afterwards.
params = chain(bottom_image.parameters(), top_image.parameters())
optim = torch.optim.SGD(params, lr=1e-4)

nums = [*range(N)]

# Uniform per-view weights, normalised so that they sum to the number of views.
weights = rp.as_numpy_array([1] * N)
weights = weights / weights.sum() * len(weights)

In [None]:
ims = []  # progress snapshots collected while training


def get_display_image():
    """Render the current state as one image: a grid of all views above the two masked layers.

    The top grid holds every rotated overlay view, using between 3 and 5
    columns. The bottom row shows the masked bottom and top layers side by
    side. The narrower of the two is centred by padding, and a 20-pixel
    separator is placed between them. Padding and separator use the maximum
    value found in the top grid as their fill colour.
    """
    view_images = [rp.as_numpy_image(render()) for render in learnable_images]
    layer_images = [
        rp.as_numpy_image(layer() * mask_tensor)
        for layer in (bottom_image, top_image)
    ]

    # ceil(sqrt(number of views)), clamped to the range 3..5.
    columns = min(5, max(3, math.ceil(math.sqrt(len(view_images)))))

    frame = dict(border_thickness=2, border_color=(255, 255, 255))
    top_grid = rp.tiled_images(view_images, length=columns, **frame)
    bottom_row = rp.tiled_images(layer_images, length=2, **frame)

    _, top_width, _ = top_grid.shape
    _, bottom_width, c = bottom_row.shape
    target_w = max(top_width, bottom_width)
    bg_color = top_grid.max()

    def pad_to_center(img, target_w):
        """Pad `img` on the left and right with the fill colour until it is `target_w` wide."""
        height, width, _ = img.shape
        missing = target_w - width
        if missing <= 0:
            return img
        # Any odd leftover column goes to the right-hand side.
        left_cols = missing // 2
        right_cols = missing - left_cols
        left = np.ones((height, left_cols, c), dtype=img.dtype) * bg_color
        right = np.ones((height, right_cols, c), dtype=img.dtype) * bg_color
        return np.concatenate([left, img, right], axis=1)

    separator = np.ones((20, target_w, c), dtype=top_grid.dtype) * bg_color
    stacked = [
        pad_to_center(top_grid, target_w),
        separator,
        pad_to_center(bottom_row, target_w),
    ]
    return np.concatenate(stacked, axis=0)

# --- Training configuration ---
NUM_ITER = 10000        # number of optimisation iterations
DISPLAY_INTERVAL = 200  # render a progress snapshot every this many iterations

# Step bounds handed to the Stable Diffusion wrapper.
# NOTE(review): presumably the diffusion timestep range sampled by train_step; confirm in source.stable_diffusion.
MAX_STEP = 990
MIN_STEP = 10
s.max_step = MAX_STEP
s.min_step = MIN_STEP

# Progress reporter, called once per iteration by the training loop.
display_eta = rp.eta(NUM_ITER, title='Status: ')

print(
    f'Starting training for {N} images...',
    f'Layout: Top grid = Illusions ({N} views). Bottom row = Bottom Image + Top Image.',
    f'Displaying progress every {DISPLAY_INTERVAL} iterations.',
    sep='\n',
)

# Imported once up front instead of on every pass through the loop body.
from IPython.display import clear_output

# Pre-bind the counter so the interrupt handler below always has a value from
# this run to report, even if the interrupt arrives before the first iteration.
iter_num = 0

try:
    for iter_num in range(NUM_ITER):
        display_eta(iter_num)

        # Train on one randomly chosen (label, view, weight) triple per iteration.
        preds=[]
        for label,learnable_image,weight in rp.random_batch(list(zip(labels,learnable_images,weights)), batch_size=1):
            pred=s.train_step(
                label.embedding,
                learnable_image()[None],  # [None] adds a leading axis to the rendered view
                noise_coef=.1*weight,guidance_scale=60,
            )
            preds+=list(pred)

        # Progress rendering is display-only, so it runs without autograd.
        with torch.no_grad():
            # Clear the accumulated cell output every 50 snapshots. With the
            # values set in this notebook (DISPLAY_INTERVAL=200, NUM_ITER=10000)
            # that threshold is not reached within a single run.
            if iter_num and not iter_num%(DISPLAY_INTERVAL*50):
                clear_output()

            if not iter_num%DISPLAY_INTERVAL:
                im = get_display_image()
                ims.append(im)
                rp.display_image(im)

        optim.step()
        optim.zero_grad()
except KeyboardInterrupt:
    # Stopping early is a supported workflow: keep a snapshot of the state so far.
    print()
    print('Interrupted early at iteration %i'%iter_num)
    with torch.no_grad():
        im = get_display_image()
    ims.append(im)
    rp.display_image(im)

print('Final Result Preview:')
# The preview is never back-propagated through, so skip building a graph for it.
with torch.no_grad():
    rp.display_image(get_display_image())

In [None]:
def _masked_layer_image(layer):
    """Render `layer`, apply the circular mask and convert the result to a numpy image."""
    return rp.as_numpy_image(layer() * mask_tensor)


# Show each optimised layer on its own, then both side by side.
print(">>> Prime Image 1: Bottom Layer (Base) <<<")
img_bottom = _masked_layer_image(bottom_image)
rp.display_image(img_bottom)

print("\n>>> Prime Image 2: Top Layer (Rotator) <<<")
img_top = _masked_layer_image(top_image)
rp.display_image(img_top)

print("\n>>> Side-by-Side Comparison (For Printing) <<<")
comparison = rp.tiled_images([img_bottom, img_top], length=2)
rp.display_image(comparison)

In [None]:
def save_run(name):
    """Write the progress snapshots and the final layers of this run to disk.

    Output goes to untracked/rotator_<N>_prompts_<angle>_degrees/<name>. If
    that folder already exists, the current Unix timestamp is appended to
    the folder name. The snapshots in `ims` are saved as ims_0000.png,
    ims_0001.png, and so on; both layers are saved raw and with the circular
    mask applied.

    Args:
        name: Last path component of the output folder.
    """
    folder = "untracked/rotator_%d_prompts_%d_degrees/%s" % (N, int(ANGLE_STEP), name)
    if rp.path_exists(folder):
        folder += '_%i' % time.time()
    rp.make_directory(folder)

    # Snapshot file names are relative, so they are written from inside the folder.
    frame_names = []
    for frame_index in range(len(ims)):
        frame_names.append('ims_%04i.png' % frame_index)
    with rp.SetCurrentDirectoryTemporarily(folder):
        rp.save_images(ims, frame_names, show_progress=False)

    # (file suffix, layer, whether to apply the circular mask), in write order.
    final_outputs = (
        ('/final_bottom_raw.png', bottom_image, False),
        ('/final_top_raw.png', top_image, False),
        ('/final_bottom_masked.png', bottom_image, True),
        ('/final_top_masked.png', top_image, True),
    )
    for suffix, layer, use_mask in final_outputs:
        rendered = layer()
        if use_mask:
            rendered = rendered*mask_tensor
        rp.save_image(rp.as_numpy_image(rendered), folder+suffix)

    print()
    print('Saved timelapse and final images to folder:',repr(folder))


save_run(f'circle_{N}_prompts')