# Influence of masking spatial regions of h-space (the `mid_block` output)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sdhelper import SD

# Load the SDXL-Turbo pipeline once; every later cell reuses this `sd` handle.
model_name = 'SDXL-Turbo'
sd = SD(model_name)

In [None]:
# config: generation settings shared by the base run and every modified run
prompt, seed = "a photo of a cat", 42

In [None]:
# base image: unmodified reference generation that the experiments below compare against
base_result = sd(prompt, seed=seed)
base_image = base_result.result_image
base_image  # last expression -> rendered as the cell output

In [None]:
# modified image
def _resize_mask_to(mask: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    """Resize a 2-D ``mask`` to ``output``'s spatial size and make it broadcastable.

    Resizing happens in factors of two. Upsampling repeats every cell into a
    2x2 block. Downsampling keeps a coarse cell if ANY of its 2x2 children is
    nonzero (logical OR), so a region is only suppressed where it fully covers
    a coarse cell. The result is cast to ``output``'s dtype/device and given
    leading singleton dims, so ``output * mask`` broadcasts.

    Raises:
        ValueError: if the mask cannot reach ``output``'s last two dims by
            doubling/halving (e.g. a 16x16 mask against a 24x24 feature map).
    """
    original_shape = tuple(mask.shape)
    target = output.shape[-1]
    while mask.shape[-1] < target:
        mask = mask.repeat_interleave(2, 0).repeat_interleave(2, 1)
    while mask.shape[-1] > target:
        if mask.shape[-1] % 2 or mask.shape[-2] % 2:
            break  # odd size cannot be pooled 2x2; reported by the check below
        mask = torch.logical_or(
            torch.logical_or(mask[::2, ::2], mask[1::2, ::2]),
            torch.logical_or(mask[::2, 1::2], mask[1::2, 1::2]),
        )
    if tuple(mask.shape) != tuple(output.shape[-2:]):
        raise ValueError(
            f"mask of shape {original_shape} cannot be resized to the spatial "
            f"size {tuple(output.shape[-2:])} by factors of two"
        )
    # logical_or yields bool; match output so the product keeps output's dtype
    mask = mask.to(device=output.device, dtype=output.dtype)
    # prepend singleton dims so the mask broadcasts over the leading dims of output
    while mask.dim() < output.dim():
        mask = mask.unsqueeze(0)
    return mask


def _difference_image(modified, base) -> np.ndarray:
    """Return ``(modified - base) / 2 + 128`` as a uint8 image (mid-grey = unchanged).

    Brighter than mid-grey means the modified image is brighter there; darker
    means it is darker. Expects 8-bit images (values 0..255). Both inputs are
    widened to int16 BEFORE subtracting, because uint8 subtraction wraps around
    (``10 - 20 == 246``) and would render every darkening as a bright artifact.
    """
    signed = np.asarray(modified, dtype=np.int16) - np.asarray(base, dtype=np.int16)
    return (signed / 2 + 128).clip(0, 255).astype(np.uint8)


def show_modification(mask: torch.Tensor):
    """Generate with ``mask`` applied to the ``mid_block`` output and plot the effect.

    Re-runs ``sd`` with the global ``prompt`` and ``seed``. At the ``mid_block``
    position, the block's output is multiplied by ``mask`` (resized to the
    output's spatial size), so zeros suppress those regions. Four panels are
    shown: base image, modified image, their difference (mid-grey = unchanged)
    and the mask.

    Args:
        mask: non-empty 2-D tensor. 0 suppresses a location and 1 keeps it.
            Fractional values scale the output, but they collapse to 0/1 if the
            mask has to be downsampled. Its side length must reach the
            mid-block's spatial size by repeated doubling/halving.

    Raises:
        ValueError: if ``mask`` is not a non-empty 2-D tensor.
    """
    # fail fast: other ranks or an empty mask would make the resize loops spin forever
    if mask.dim() != 2 or mask.numel() == 0:
        raise ValueError(f"expected a non-empty 2-D mask, got shape {tuple(mask.shape)}")

    # plot base
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 4, 1)
    plt.title('Base')
    plt.imshow(base_result.result_image)
    plt.axis('off')

    # setup mask application function; move the mask to the model's device once,
    # the per-call resize then casts it to whatever dtype the output has
    device_mask = mask.to(sd.device)

    def mod_fn(module, input, output, pos):
        # only modify the mid_block; returning None leaves other positions untouched
        if pos != 'mid_block':
            return None
        return output * _resize_mask_to(device_mask, output)

    # plot modified
    modified_result = sd(prompt, seed=seed, modification=mod_fn)
    plt.subplot(1, 4, 2)
    plt.title('Modified')
    plt.imshow(modified_result.result_image)
    plt.axis('off')

    # plot difference
    plt.subplot(1, 4, 3)
    plt.title('Difference')
    plt.imshow(_difference_image(modified_result.result_image, base_result.result_image))
    plt.axis('off')

    # plot mask (the caller's mask at its original resolution)
    plt.subplot(1, 4, 4)
    plt.title('Mask')
    plt.imshow(mask.detach().cpu().float().numpy())
    plt.axis('off')
    plt.show()

In [None]:
# keep the interior, suppress the outermost ring of cells
mask = torch.zeros(16, 16)
mask[1:-1, 1:-1] = 1
show_modification(mask)

In [None]:
# keep only the outermost ring of cells, suppress the whole interior
mask = torch.ones(16, 16)
mask[1:-1, 1:-1] = 0
show_modification(mask)

In [None]:
# keep only a one-cell-wide ring just inside the border; everything else is suppressed
mask = torch.zeros(16, 16)
mask[1:-1, 1:-1] = 1
mask[2:-2, 2:-2] = 0
show_modification(mask)

In [None]:
# concentric one-cell-wide rings: starting one cell in from the border,
# alternate keep (1) / suppress (0) down to a suppressed 4x4 centre
mask = torch.zeros(16, 16)
for depth in range(1, 7):
    mask[depth:-depth, depth:-depth] = depth % 2
show_modification(mask)

In [None]:
mask = torch.zeros(16, 16)  # control: all zeros suppresses the entire mid_block output
show_modification(mask)