In [None]:
# Install the notebook's dependencies (the outputs below show it was run on Google Colab).
# NOTE(review): versions are unpinned; the recorded run resolved smalldiffusion 0.4.3 and
# xformers 0.0.28.post3 — consider pinning them (and using `%pip`, which targets the running
# kernel) so Restart & Run All is reproducible. matplotlib, used for the figures, is not
# installed here and is presumably preinstalled in the environment — confirm.
!pip3 install torch torchvision -f https://download.pytorch.org/whl/torch_stable.html
!pip3 install numpy accelerate smalldiffusion tqdm diffusers transformers xformers

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting smalldiffusion
  Downloading smalldiffusion-0.4.3-py3-none-any.whl.metadata (11 kB)
Collecting xformers
  Downloading xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)
Downloading smalldiffusion-0.4.3-py3-none-any.whl (16 kB)
Downloading xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl (16.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.7/16.7 MB[0m [31m40.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: xformers, smalldiffusion
Successfully installed smalldiffusion-0.4.3 xformers-0.0.28.post3


In [None]:
from collections import namedtuple
from itertools import pairwise
from statistics import geometric_mean

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.v2.functional as TF
from accelerate import Accelerator
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from smalldiffusion import ModelMixin, ScheduleLDM
from torch import nn
from torchvision.utils import save_image, make_grid
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPProcessor, CLIPModel

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [None]:
def alpha_bar(sigma):
    """Convert a noise level `sigma` into alpha_bar = 1 / (sigma^2 + 1).

    Works element-wise on tensors as well as on plain Python numbers.
    """
    one_plus_variance = sigma**2 + 1
    return 1 / one_plus_variance

class ModelLatentDiffusion(nn.Module, ModelMixin):
    """Stable Diffusion components wrapped as a smalldiffusion-style model.

    Loads the VAE, CLIP tokenizer/text encoder, UNet and DDIM scheduler from
    `model_key` and exposes `forward(x, sigma, cond)`, which is called with a
    noise level `sigma` instead of a discrete timestep: `sigma` is converted to
    a training timestep by `sigma_to_t`, and `x` is rescaled by
    sqrt(alpha_bar(sigma)) before being passed to the UNet.
    """
    def __init__(self, model_key, accelerator=None):
        """Load every sub-model of `model_key` and move them to the accelerator device.

        Args:
            model_key: Hugging Face Hub id (or local path) of a diffusers
                checkpoint with vae/, tokenizer/, text_encoder/, unet/ and
                scheduler/ subfolders.
            accelerator: optional `Accelerator`; a new one is created if omitted.
        """
        super().__init__()
        self.accelerator = accelerator or Accelerator()
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae")
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder")
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet")
        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
        # Latent shape (channels, height, width) read from the UNet config;
        # presumably consumed by ModelMixin.rand_input to draw the initial noise.
        self.input_dims = (self.unet.config.in_channels, self.unet.sample_size, self.unet.sample_size,)
        # NOTE(review): these two attributes are not read anywhere in this notebook.
        self.text_condition = None
        self.text_guidance_scale = None
        # Use xformers' memory-efficient attention in the UNet when the package is installed.
        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()
        # Move the nn.Module sub-models (vae, text_encoder, unet) to the accelerator's device.
        self.to(self.accelerator.device)

    def tokenize(self, prompt):
        """Tokenize `prompt` into CLIP input ids.

        Ids are padded/truncated to the tokenizer's `model_max_length` and
        returned on the accelerator device.
        """
        return self.tokenizer(
            prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
            truncation=True, return_tensors='pt'
        ).input_ids.to(self.accelerator.device)

    def embed_prompt(self, prompt):
        """Encode `prompt` with the CLIP text encoder, without tracking gradients.

        Returns the first element of the text-encoder output; the samplers in
        this notebook concatenate these embeddings and pass them as `cond`,
        which `forward` forwards to the UNet as `encoder_hidden_states`.
        """
        with torch.no_grad():
            return self.text_encoder(self.tokenize(prompt))[0]

    @torch.no_grad()
    def decode_latents(self, latents):
        """Decode VAE latents into an image with values clamped to [0, 1].

        `squeeze()` drops every size-1 dimension, so a single-sample batch is
        returned without its batch dimension.
        """
        # Undo the latent scaling before decoding.
        # NOTE(review): 0.18215 is hard-coded; it should match the loaded VAE's
        # `config.scaling_factor` — confirm before using other checkpoints.
        decoded = self.vae.decode(latents / 0.18215).sample
        normalized = (decoded.squeeze()+1)/2 # from [-1,1] to [0, 1]
        return normalized.clamp(0,1)

    def sigma_to_t(self, sigma):
        """Map a noise level `sigma` to an integer training timestep.

        alpha_bar(sigma) is located in the reversed `alphas_cumprod` table with
        `torch.searchsorted`, and the index is converted back to a timestep as
        num_train_timesteps - 1 - idx. Assumes `alphas_cumprod` decreases with
        t, so the reversed table is ascending as searchsorted requires — TODO
        confirm for non-default schedulers.
        """
        # `.to(sigma)` matches the table's dtype/device to `sigma`.
        idx = torch.searchsorted(reversed(self.scheduler.alphas_cumprod.to(sigma)), alpha_bar(sigma))
        return self.scheduler.config.num_train_timesteps - 1 - idx

    def forward(self, x, sigma, cond=None):
        """Run the UNet on `x` at noise level `sigma` with text conditioning `cond`.

        `x` is multiplied by sqrt(alpha_bar(sigma)) before the UNet call, and the
        UNet's `.sample` output is returned (presumably an eps prediction, since
        the samplers consume it through `predict_eps` — verify against the
        scheduler's prediction_type).
        """
        z = alpha_bar(sigma).sqrt() * x
        return self.unet(z, self.sigma_to_t(sigma), encoder_hidden_states=cond).sample

# An invertible transform: `fwd` maps an input into the transformed view and
# `inv` maps a result computed in that view back to the original one.
Transform = namedtuple('Transform', ['fwd', 'inv'])


def _identity(x):
    """Return `x` unchanged."""
    return x


id_t = Transform(fwd=_identity, inv=_identity)


def r(angle):
    """Build a function that rotates its input by `angle` degrees using TF.rotate."""
    return lambda x: TF.rotate(x, angle)


# Rotations, each paired with the rotation by the opposite angle as its inverse.
rot_180 = Transform(fwd=r(180), inv=r(-180))
rot_90 = Transform(fwd=r(90), inv=r(-90))

def show_tensor(x):
    """Render the image tensor `x` inline in the notebook (via a PIL conversion)."""
    pil_img = TF.to_pil_image(x)
    display(pil_img)


def sample_base(model,
           prompt    = 'An astronaut riding a horse',
           N          = 50,
           gam        = 1.,
           mu         = 0.,
           seed       = 0,):
    """
      Generate an image provided a prompt and parameters N, gam, mu.
      The random seed can also be fixed.
    """
    model.eval() # evaluatiuon mode (not training)
    torch.manual_seed(seed) # for reproducibility

    accelerator = Accelerator() # to handle CPU, GPU, TPU

    # Process the empty prompt '' and the input prompt into their encodings
    embeds = torch.cat([model.embed_prompt(''),model.embed_prompt(prompt)])

    # Denoising schedule
    schedule = ScheduleLDM(1000)
    sigmas = schedule.sample_sigmas(N)

    xt = model.rand_input(1).to(accelerator.device) * sigmas[0]
    eps = None

    for i, (sig, sig_prev) in enumerate(tqdm(pairwise(sigmas))):
        # Predict eps with '' and with the prompt
        xts = torch.cat([xt,xt])
        with torch.no_grad(): # no need of backprogation here (save memory)
            eps_uncond,eps_cond = model.predict_eps(xts, sig.to(xt), embeds).chunk(N)
        # Hardcoded weighted average of the two computed eps
        eps_prev, eps = eps, eps_uncond*(-6.5) + eps_cond*7.5

        # Applying various modifications to improve the update
        eps_av = eps * gam + eps_prev * (1-gam)  if i > 0 else eps
        sig_p = (sig_prev/sig**mu)**(1/(1-mu))
        eta = (sig_prev**2 - sig_p**2).sqrt()

        # Update latent space
        xt = xt - (sig - sig_p) * eps_av + eta * model.rand_input(xt.shape[0]).to(xt)

    # Convert from latent space to image
    image = model.decode_latents(xt)

    return image

In [None]:
# Stable Diffusion 2.1 base checkpoint on the Hugging Face Hub.
MODEL_KEY = 'stabilityai/stable-diffusion-2-1-base'
model = ModelLatentDiffusion(MODEL_KEY)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


vae/config.json:   0%|          | 0.00/553 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

tokenizer/tokenizer_config.json:   0%|          | 0.00/807 [00:00<?, ?B/s]

tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer/special_tokens_map.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

text_encoder/config.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.36G [00:00<?, ?B/s]

unet/config.json:   0%|          | 0.00/911 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

scheduler/scheduler_config.json:   0%|          | 0.00/346 [00:00<?, ?B/s]

# 1. Warm-up


### 1.1 Varying N

In [None]:
# Prompt shared by the warm-up experiments, and the step counts to compare.
prompt = (
    "Epic, abandoned city with ruined buildings, long deserted streets, "
    "cars aged by time, trees, flowers, scattered leaves, empty street, "
    "vibrant colors, Ghibli-inspired "
)
N_iter = [2, 3, 5, 10, 15,
          25, 50, 75]

In [None]:
# One panel per step count in N_iter; prompt, seed, gam and mu are held fixed.
fig, axs = plt.subplots(2, 4, figsize=(10, 5))
axs = axs.flatten()

for ax, n in zip(axs, N_iter):
    img = sample_base(model, prompt=prompt, N=n, gam=1.2, mu=0.5, seed=42)
    ax.imshow(img.cpu().permute(1, 2, 0).numpy())
    ax.set(title=f'N={n}')
    ax.set_axis_off()

fig.tight_layout()
fig.savefig("warm_up_N.pdf", dpi=300)
plt.show()

### 1.2 Varying $\gamma$

In [None]:
# Same prompt as above; values of gam to compare.
prompt = (
    "Epic, abandoned city with ruined buildings, long deserted streets, "
    "cars aged by time, trees, flowers, scattered leaves, empty street, "
    "vibrant colors, Ghibli-inspired "
)
gam_iter = [1.1, 1.5,
            2.0, 2.5, 3.0]

In [None]:
# One panel per value of gam; N, mu, prompt and seed are held fixed.
fig, axs = plt.subplots(1, 5, figsize=(8, 5))
axs = axs.flatten()

for ax, gam in zip(axs, gam_iter):
    img = sample_base(model, prompt=prompt, N=25, gam=gam, mu=0.5, seed=42)
    ax.imshow(img.cpu().permute(1, 2, 0).numpy())
    ax.set(title=rf'$\gamma$={gam}')
    ax.set_axis_off()

fig.tight_layout()
fig.savefig("warm_up_gamma.pdf", dpi=300)
plt.show()

### 1.3 Varying $\lambda$

In [None]:
# Same prompt as above; guidance scales lam to compare.
prompt = (
    "Epic, abandoned city with ruined buildings, long deserted streets, "
    "cars aged by time, trees, flowers, scattered leaves, empty street, "
    "vibrant colors, Ghibli-inspired "
)
lam_iter = [1.0, 2.0,
            5.0, 7.0, 10.0]

In [None]:
# One panel per guidance scale lam (forwarded to sample_base's `lam` argument).
fig, axs = plt.subplots(1, 5, figsize=(8, 5))
axs = axs.flatten()

for ax, lam in zip(axs, lam_iter):
    img = sample_base(model, prompt=prompt, N=25, gam=1.2, mu=0.5, lam=lam, seed=42)
    ax.imshow(img.cpu().permute(1, 2, 0).numpy())
    ax.set(title=rf'$\lambda$={lam}')
    ax.set_axis_off()

fig.tight_layout()
fig.savefig("warm_up_lam.pdf", dpi=300)
plt.show()

### 1.4 Varying $\mu$

In [None]:
# Same prompt as above; stochasticity levels mu to compare (must stay below 1).
prompt = (
    "Epic, abandoned city with ruined buildings, long deserted streets, "
    "cars aged by time, trees, flowers, scattered leaves, empty street, "
    "vibrant colors, Ghibli-inspired "
)
mu_iter = [0.0, 0.2,
           0.5, 0.8, 0.98]

In [None]:
# One panel per value of mu; N, gam, prompt and seed are held fixed.
fig, axs = plt.subplots(1, 5, figsize=(8, 5))
axs = axs.flatten()

for ax, mu in zip(axs, mu_iter):
    img = sample_base(model, prompt=prompt, N=25, gam=1.2, mu=mu, seed=42)
    ax.imshow(img.cpu().permute(1, 2, 0).numpy())
    ax.set(title=rf'$\mu$={mu}')
    ax.set_axis_off()

fig.tight_layout()
fig.savefig("warm_up_mu.pdf", dpi=300)
plt.show()

# 2. Interpolation

### 2.1 New Sampling Function

In [None]:
def sample_base_prompts(
    model,
    list_prompts,
    weights,
    N = 50,
    gam = 1.,
    mu = 0.,
    seed = 0) :

    model.eval() # evaluatiuon mode (not training)
    torch.manual_seed(seed)


    accelerator = Accelerator()

    embeds = torch.cat([model.embed_prompt(p) for p in list_prompts], dim = 0)

    schedule = ScheduleLDM(1000)
    sigmas = schedule.sample_sigmas(N)


    xt = model.rand_input(1).to(accelerator.device) * sigmas[0]
    eps_curr = None


    for i, (sig, sig_prev) in enumerate(tqdm(pairwise(sigmas))):

        xts = xt.repeat(len(list_prompts), 1, 1, 1)

        with torch.no_grad(): # no need of backprogation here (save memory)
            eps_preds = model.predict_eps(xts, sig.to(xt), embeds).chunk(N)

        eps_prev, eps_curr = eps_curr, sum(w * e for w, e in zip(weights, eps_preds))

        # Applying various modifications to improve the update
        eps_av = eps_curr * gam + eps_prev * (1-gam)  if i > 0 else eps_curr
        sig_p = (sig_prev/sig**mu)**(1/(1-mu))
        eta = (sig_prev**2 - sig_p**2).sqrt()

        # Update latent space
        xt = xt - (sig - sig_p) * eps_av + eta * model.rand_input(xt.shape[0]).to(xt)

    # Convert from latent space to image
    image = model.decode_latents(xt)

    return image

In [None]:
# Guidance strength used to build the prompt weights below. The weights sum to 1;
# the negative weight on 'a castle' pushes the sample away from that prompt.
lam = 7.0
im1 = sample_base_prompts(
    model=model,
    list_prompts=['', 'a castle', 'oil painting, ghibli inspired', 'in the woods'],
    weights=[1 - lam, -lam * 6, lam * 2.0, lam * 5.0],
    N=25, gam=1.1, mu=0,
)

In [None]:
# Show the interpolated image on its own, without axes, and save it.
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(im1.cpu().permute(1, 2, 0).numpy())
ax.axis("off")

fig.savefig("interpolation.pdf", dpi=300)
plt.show()

### 2.2 Morphing Sequence

In [None]:
def generate_morp_sequence(model,
                           prompt_start,
                           prompt_end,
                           steps = 5,
                           seed = 42,
                           lam = 7.0,
                           **kwargs):
    """
    Render a morph from `prompt_start` to `prompt_end` as `steps + 1` frames.

    Frame i is sampled with weights [(1 - lam), lam * (steps - i) / steps,
    lam * i / steps] on ['', prompt_start, prompt_end], so the guidance moves
    linearly from the start prompt (i = 0) to the end prompt (i = steps).
    Every frame uses the same `seed`, so only the prompt weighting changes.

    Args:
        model:        model passed through to `sample_base_prompts`.
        prompt_start: prompt for the first frame.
        prompt_end:   prompt for the last frame.
        steps:        number of interpolation intervals (>= 1).
        seed:         torch RNG seed used for every frame.
        lam:          total guidance scale shared by the two prompts.
        **kwargs:     forwarded to `sample_base_prompts` (e.g. N, gam, mu).

    Returns:
        List of `steps + 1` numpy arrays produced by `permute(1, 2, 0)`
        (HWC, assuming decode_latents returns CHW — confirm), ready for imshow.

    Raises:
        ValueError: if `steps < 1` (previously steps=0 raised ZeroDivisionError
            and negative values silently returned an empty list).
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    images = []

    for i in range(steps + 1):
        # Linear interpolation of the conditional weights between the two prompts.
        weight_start = (steps - i) / steps
        weight_end = i / steps

        image = sample_base_prompts(
            model = model,
            list_prompts = ['', prompt_start, prompt_end],
            weights = [(1-lam), lam*weight_start, lam*weight_end],
            seed = seed,
            **kwargs
        )

        images.append(image.cpu().permute(1, 2, 0).numpy())

    return images

In [None]:
prompt_start = "A horse standing left side-on looking in a grassy field, eating grass"
prompt_end = "A zebra standing left side-on looking in a grassy field, eating grass"

# Ten frames (steps=9) from horse to zebra, all from the same seed.
morph_settings = dict(steps=9, seed=42, N=25, gam=1.2, lam=7, mu=0.5)
images = generate_morp_sequence(model, prompt_start, prompt_end, **morph_settings)

In [None]:
# Lay the morph frames out on a 2 x 5 grid, in order.
fig, axs = plt.subplots(2, 5, figsize=(20, 8))
axs = axs.flatten()

for ax, im in zip(axs, images):
    ax.imshow(im)
    ax.set_axis_off()

fig.tight_layout()
plt.show()

# 3. Illusion

In [None]:
# Part. I - Interpolation
# see other notebook

In [None]:
# Part. II - Illusion

# Simple adaptation of the sample method.
def sample_illusion(model,
                    prompt1    = 'A pinguin body',
                    prompt2    = 'A zebra',
                    c          = 3.0,
                    flip       = rot_180,
                    N          = 50,
                    gam        = 1.,
                    mu         = 0.,
                    seed       = 0,):
    """
    Generate a two-view illusion: the image follows `prompt1` as-is and
    `prompt2` once `flip.fwd` is applied to it.

    At every step the UNet is evaluated on the latent and on its flipped copy,
    each with the empty prompt and with its text prompt. The flipped-view
    predictions are mapped back with `flip.inv`, and the four are combined with
    classifier-free guidance of scale `c`, each view weighted 0.5.

    Args:
        model:   latent diffusion model providing `embed_prompt`, `predict_eps`,
                 `rand_input` and `decode_latents`.
        prompt1: prompt for the untransformed view.
        prompt2: prompt for the view seen through `flip.fwd`.
        c:       guidance scale.
        flip:    Transform (fwd, inv) defining the second view.
        N:       number of denoising steps passed to the schedule.
        gam:     from the second step on, eps_av = gam * eps + (1 - gam) * eps_prev.
        mu:      stochasticity in [0, 1); mu = 0 injects no fresh noise.
        seed:    torch RNG seed, for reproducibility.

    Returns:
        The image from `model.decode_latents`, with values clamped to [0, 1].

    Raises:
        ValueError: if `mu` is outside [0, 1).
    """
    if not 0 <= mu < 1:
        # mu = 1 divides by zero below; other out-of-range values make eta NaN.
        raise ValueError(f"mu must be in [0, 1), got {mu}")

    model.eval() # evaluation mode (not training)
    torch.manual_seed(seed) # for reproducibility

    accelerator = Accelerator() # to handle CPU, GPU, TPU

    # Encodings in batch order: '' (view 1), '' (view 2), prompt1, prompt2.
    embeds = torch.cat([model.embed_prompt(''), model.embed_prompt(''),
                        model.embed_prompt(prompt1), model.embed_prompt(prompt2)])

    # Denoising schedule
    schedule = ScheduleLDM(1000)
    sigmas = schedule.sample_sigmas(N)

    xt = model.rand_input(1).to(accelerator.device) * sigmas[0]
    eps = None
    alpha, beta = 0.5, 0.5  # weights of view 1 (prompt1) and view 2 (prompt2)

    for i, (sig, sig_prev) in enumerate(tqdm(pairwise(sigmas))):
        # Batch matching `embeds`: (xt, flipped xt) unconditional, then conditional.
        xt_flipped = flip.fwd(xt)
        xts = torch.cat([xt, xt_flipped, xt, xt_flipped])

        with torch.no_grad(): # inference only: no autograd graph needed
            # Exactly one chunk per batch entry; the previous `.chunk(N)`
            # produced too few chunks (and an unpack error) for N < 4.
            # eps0 / eps0p: unconditional, view 1 / view 2
            # eps1 / eps2:  conditional on prompt1 / prompt2
            eps0, eps0p, eps1, eps2 = model.predict_eps(xts, sig.to(xt), embeds).chunk(4)

        # Classifier-free guidance over both views; view-2 predictions are
        # mapped back with flip.inv. The weights sum to 1.
        eps_prev, eps = eps, (alpha * (1 - c) * eps0 + beta * (1 - c) * flip.inv(eps0p)
                              + c * (alpha * eps1 + beta * flip.inv(eps2)))

        # Applying various modifications to improve the update
        eps_av = eps * gam + eps_prev * (1-gam)  if i > 0 else eps    # Acceleration
        sig_p = (sig_prev/sig**mu)**(1/(1-mu))                        # Compute sigma_t_prime, see Slides
        eta = (sig_prev**2 - sig_p**2).sqrt()                         # see Slides

        # Update latent space
        xt = xt - (sig - sig_p) * eps_av + eta * model.rand_input(xt.shape[0]).to(xt)

    # Convert from latent space to image
    image = model.decode_latents(xt)

    return image

In [None]:
# General method presented in the report.
def sample_illusion_generalized(model,
           prompts,
           weights,
           transforms,
           N          = 50,
           gam        = 1.,
           mu         = 0.,
           seed       = 0,):
    """
      Generate an image provided a prompt and parameters N, gam, mu.
      The random seed can also be fixed.
    """
    model.eval() # evaluatiuon mode (not training)
    torch.manual_seed(seed) # for reproducibility

    accelerator = Accelerator() # to handle CPU, GPU, TPU

    # Process the empty prompt '' and the input prompts into their encodings
    embeds = torch.cat([model.embed_prompt(p) for p in prompts])

    # Denoising schedule
    schedule = ScheduleLDM(1000)
    sigmas = schedule.sample_sigmas(N)

    xt = model.rand_input(1).to(accelerator.device) * sigmas[0]
    eps = None

    for i, (sig, sig_prev) in enumerate(tqdm(pairwise(sigmas))):

        # Predict eps with '' and the two prompts (for xt and np.flip(xt))
        xts = torch.cat([t.fwd(xt) for t in transforms])

        with torch.no_grad(): # no need of backprogation here (save memory)
            epsilons = model.predict_eps(xts, sig.to(xt), embeds).chunk(N)

        # Hardcoded weighted average of the two computed eps
        eps_prev, eps = eps,  sum(w * t.inv(e) for w,t,e in zip(weights, transforms, epsilons))

        # Applying various modifications to improve the update
        eps_av = eps * gam + eps_prev * (1-gam)  if i > 0 else eps    # Acceleration
        sig_p = (sig_prev/sig**mu)**(1/(1-mu))                        # Compute sigma_t_prime, see Slides
        eta = (sig_prev**2 - sig_p**2).sqrt()                         # see Slides

        # Update latent space
        xt = xt - (sig - sig_p) * eps_av + eta * model.rand_input(xt.shape[0]).to(xt)

    # Convert from latent space to image
    image = model.decode_latents(xt)

    return image

In [None]:
# Part. II: Simple case

# Simple examples using first method

# img = sample_illusion(
#         model, gam=1.2, mu=0.5, N=50, seed=0,
#         prompt1 = 'A painting of a snowy mountain',
#         prompt2 = "A painting of a horse",
#         c = 7.0,
#         flip = rot_180
#     )

# img = sample_illusion(
#         model, gam=1.2, mu=0.5, N=50, seed=0,
#         prompt1 = "A drawing of pinguin",
#         prompt2 = "A drawing of girafe",
#         c = 7.0,
#         flip = rot_90
#     )



In [None]:
# Useful function
def show_and_save_tensor(x, filename):
    """Display the image tensor `x` inline and write the same image to `filename`."""
    pil_image = TF.to_pil_image(x)
    display(pil_image)
    pil_image.save(filename)


# Additional transformations; each one uses the same function as its own inverse.
h_flip = Transform(fwd=TF.hflip, inv=TF.hflip)
v_flip = Transform(fwd=TF.vflip, inv=TF.vflip)
invert = Transform(fwd=TF.invert, inv=TF.invert)


def _negate(x):
    """Return -x."""
    return -x


invert_simplest = Transform(fwd=_negate, inv=_negate)

In [None]:
# Check with first examples.
# img = sample_illusion_generalized(
#     model, gam=1.2, mu=0.5, N=50, seed=0,
#     prompts=["", "", "A drawing of pinguin", "A drawing of girafe"],
#     weights= [-3.0, -3.0, 3.5, 3.5],
#     transforms = [id_t, rot_90, id_t, rot_90]
# )


# img = sample_illusion_generalized(
#     model, gam=1.2, mu=0.5, N=50, seed=0,
#     prompts=["", "", "a blockprint of a red panda", " a blockprint of elvis presley"],
#     weights= [-3.0, -3.0, 3.5, 3.5],
#     transforms = [id_t, invert_simplest, id_t, invert_simplest]
# )


# ALL EXAMPLES IN THE REPORTS

# img = sample_illusion_generalized(
#     model, gam=1.2, mu=0.5, N=50, seed=0,
#     prompts=["", "", "An oil painting of a butterfly", "An oil paitning of a human face profile"],
#     weights= [-3.0, -3.0, 3.5, 3.5],
#     transforms = [id_t, v_flip, id_t, v_flip]
# )

# img = sample_illusion_generalized(
#     model, gam=1.2, mu=0.5, N=50, seed=0,
#     prompts=["", "", "An oil painting of a surfer in the sea", "An oil painting of a biker"],
#     weights= [-3.0, -3.0, 3.5, 3.5],
#     transforms = [id_t, v_flip, id_t, v_flip]
# )

# img = sample_illusion_generalized(
#     model, gam=1.2, mu=0.5, N=50, seed=0,
#     prompts=["", "", "An oil painting of a butterfly", "An oil painting of a human face profile"],
#     weights= [-3.0, -3.0, 3.5, 3.5],
#     transforms = [id_t, h_flip, id_t, h_flip]
# )


# img = sample_illusion_generalized(
#     model, gam=1.2, mu=0.5, N=50, seed=0,
#     prompts=["", "", "An oil painting of a cat sitting on a windowsill,", "An oil painting of a bird in flight"],
#     weights= [-3.0, -3.0, 3.5, 3.5],
#     transforms = [id_t, rot_180, id_t, rot_180]
# )


# img = sample_illusion_generalized(
#     model, gam=1.2, mu=0.5, N=50, seed=0,
#     prompts=["", "", "An oil painting of a snowy mountains", "An oil painting of a lion face"],
#     weights= [-3.0, -3.0, 3.5, 3.5],
#     transforms = [id_t, rot_90, id_t, rot_90]
# )


# img = sample_illusion_generalized(
#     model, gam=1.2, mu=0.5, N=50, seed=0,
#     prompts=["", "", "An oil painting of a human face", "An oil painting of an helicopter"],
#     weights= [-3.0, -3.0, 3.5, 3.5],
#     transforms = [id_t, invert_simplest, id_t, invert_simplest]
# )





In [None]:
# Useful cell to plot and save the results

# show_tensor(img)
# print("\n")
# show_tensor(h_flip.fwd(img))

# show_and_save_tensor(img, "An oil painting of a human face.pdf")
# print("\n")
# show_and_save_tensor(h_flip.fwd(img),"An oil painting of a human face profile_hflip.pdf")

In [None]:
# JIGSAW PUZZLE IMPLEMENTATION


import random

# Fix the random seed so the tile permutation is reproducible across runs.
RANDOM_SEED = 0
random.seed(RANDOM_SEED)

# Define a fixed permutation: rows and columns of the grid are shuffled independently.
grid_size = (3,3)

permutation_rows = list(range(grid_size[0]))
random.shuffle(permutation_rows)
permutation_cols = list(range(grid_size[1]))
random.shuffle(permutation_cols)


def _band(k, size):
    """Slice selecting the k-th band of `size` consecutive indices along one axis."""
    return slice(k * size, (k + 1) * size)


def _jigsaw_tile_pairs(shape):
    """Yield (dst, src) index tuples that pair each output tile with its source tile.

    The last two dims of `shape` are cut into a `grid_size` grid of tiles of
    size (H // rows, W // cols). Output tile (i, j) comes from input tile
    (permutation_rows[i], permutation_cols[j]). Trailing pixels that do not
    fill a whole tile belong to no pair.
    """
    rows, cols = grid_size
    nx, ny = shape[-2] // rows, shape[-1] // cols
    for i in range(rows):
        for j in range(cols):
            dst = (..., _band(i, nx), _band(j, ny))
            src = (..., _band(permutation_rows[i], nx), _band(permutation_cols[j], ny))
            yield dst, src


def jigsaw_transform(x):
    """Shuffle the tiles of `x` over its last two dims according to the fixed permutation.

    Leading dims are untouched. Pixels beyond the last full tile (when H or W is
    not divisible by the grid size) keep their original position. Returns a new
    tensor; `x` is not modified.
    """
    x_tiled = torch.clone(x)
    for dst, src in _jigsaw_tile_pairs(x.shape):
        x_tiled[dst] = x[src]
    return x_tiled


def jigsaw_inverse(x_tiled):
    """Undo `jigsaw_transform`: move every tile back to its original position.

    Returns a new tensor; `x_tiled` is not modified.
    """
    x = torch.clone(x_tiled)
    for dst, src in _jigsaw_tile_pairs(x_tiled.shape):
        x[src] = x_tiled[dst]
    return x


# Wrap the jigsaw functions as a Transform. The lambdas look the functions up by
# name when called, so a later redefinition of jigsaw_transform / jigsaw_inverse
# is picked up.
jigsaw_t = Transform(
    fwd=lambda x: jigsaw_transform(x),
    inv=lambda x: jigsaw_inverse(x),
)

# Show the drawn permutation: row order first, then column order.
for perm in (permutation_rows, permutation_cols):
    print(perm)

[0, 2, 1]
[2, 1, 0]


In [None]:
# img = sample_illusion_generalized(
#     model, gam=1.2, mu=0.5, N=50, seed=0,
#     prompts=["", "", "An oil painting of a hulk face", "An oil painting of a hut in the greeny wood"],
#     weights= [-3.0, -3.0, 3.5, 3.5],
#     transforms = [id_t, jigsaw_t, id_t, jigsaw_t]
# )


# img = sample_illusion_generalized(
#     model, gam=1.2, mu=0.5, N=50, seed=42,
#     prompts=["", "", "An oil painting of a fruit bowl", "An oil painting of an old man"],
#     weights= [-3.0, -3.0, 3.5, 3.5],
#     transforms = [id_t, jigsaw_t, id_t, jigsaw_t]
# )



In [None]:
# show_tensor(img)
# print("\n")
# show_tensor(jigsaw_t.fwd(img))

# show_and_save_tensor(img, "An oil painting of a fruit bowl.pdf")
# print("\n")
# show_and_save_tensor(jigsaw_t.fwd(img),"An oil painting of an old man.pdf")

In [None]:
# Example with 3 prompts: the guidance scale lamb is split evenly over Np views,
# with (1 - lamb) / Np on each unconditional term and lamb / Np on each prompt,
# so all weights sum to 1.
lamb, Np = 7.0, 3
wmin, wpos = (1 - lamb) / Np, lamb / Np

# img = sample_illusion_generalized(
#     model, gam=1.2, mu=0.5, N=50, seed=42,
#     prompts=["", "", "",  "An oil painting of a snowy moutain", "An oil painting of a lion face", "An oil painting of a beach at the sea"],
#     weights= [wmin, wmin, wmin, wpos, wpos, wpos],
#     transforms = [id_t, rot_90, rot_180, id_t, rot_90, rot_180]
# )



50it [01:04,  1.29s/it]


In [None]:
# show_tensor(img)
# print("\n")
# show_tensor(rot_90.fwd(img))
# print("\n")
# show_tensor(rot_180.fwd(img))

# show_and_save_tensor(img, "An oil painting of a snowy moutain_3.pdf")
# print("\n")
# show_and_save_tensor(rot_90.fwd(img),"An oil painting of a lion face_3.pdf")
# print("\n")
# show_and_save_tensor(rot_180.fwd(img),"An oil painting of a beach at the sea_3.pdf")