In [None]:
# Fetch the diffusers source tree into ./diffusers; a later cell installs the
# package from this local checkout.
!git clone https://github.com/huggingface/diffusers.git

In [None]:
!pip install huggingface_hub transformers datasets accelerate sentencepiece einops protobuf matplotlib

In [None]:
!pip install -U ./diffusers

In [None]:
from huggingface_hub import notebook_login

In [None]:
# Log in to the Hugging Face Hub from inside the notebook.
# NOTE(review): presumably required to download the model weights loaded below — confirm.
notebook_login()

In [None]:
# Evaluation prompts shared by every experiment in this notebook.
# Every entry must end with a comma: adjacent string literals are concatenated
# implicitly, which silently merges two prompts into one.
# Plotting code should size its grids from the number of results rather than a
# literal row count, so it keeps working when prompts are added or removed.
test_prompts_expt = [
    "A charismatic speaker is captured mid-speech. He has short, tousled brown hair that’s slightly messy on top. He has a round circle face, clean shaven, adorned with rounded rectangular-framed glasses with dark rims, is animated as he gestures with his left hand. He is holding a black microphone in his right hand, speaking passionately. The man is wearing a light grey sweater over a white t-shirt. He’s also wearing a simple black lanyard hanging around his neck. The lanyard badge has the text “Anakin AI”. Behind him, there is a blurred background with a white banner containing logos and text (including Anakin AI), a professional conference setting.", #Anakin AI
    "a red dog wearing a blue hat sits with a yellow cat wearing pink sunglasses", #wonderflex on reddit.
    "A Samsung LED moniter's screen on a table displays an image of a garden with signboard mentions 'All is Well', A teddy toy placed on the table, a cat is sleeping near the teddy toy, a mushroom dish on red plate placed on the table, raining outside, a parrot sitting on the nearby window, a flex banner with text 'Enjoy the life' visible from outside of the window,",
    "3d model of a green war balloon, clash of clans, fantasy game, front view, game asset, detailed, war ready, photorealistic, in a war enviroment, spring, disney style, pixar style",
    "Photo of a felt puppet diorama scene of a tranquil nature scene of a secluded forest clearing with a large friendly, rounded robot is rendered in a risograph style. An owl sits on the robots shoulders and a fox at its feet. Soft washes of color, 5 color, and a light-filled palette create a sense of peace and serenity, inviting contemplation and the appreciation of natural beauty.",
    "The Golden gate bridge",
]

In [None]:
import torch
import einops
from utils import plot_image_grid, cosine_similarity, plot_similarity_matrix
from custom_sd3_pipeline import StableDiffusion3Pipeline
from custom_sd3_transformer import SD3Transformer2DModel

In [None]:
# Load only the "transformer" subfolder of the SD3-medium checkpoint, using the
# custom SD3Transformer2DModel class imported above, in bfloat16.
expt_transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

In [None]:
# Build the custom SD3 pipeline around the transformer loaded above (so both
# share the same block list), then move the whole pipeline to the GPU.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.bfloat16,
    transformer=expt_transformer,
)
pipe.to("cuda")

In [None]:
def test_sd3(prompts, layer_order):
    results = []
    print(f"layer_order {layer_order}")
    generator = torch.Generator('cpu')
    generator.manual_seed(19943434)
    for prompt in prompts:
        out, _, _ = pipe(
            prompt=prompt,
            height=768,
            width=1360,
            num_inference_steps=28,
            guidance_scale=5.0,
            max_sequence_length=256,
            return_dict=True,
            layer_order=layer_order,
            generator=generator
        )
        results.append((prompt, out.images[0]))
    plot_image_grid(results, rows=len(prompts), cols=1, figsize=(25,25))
    return results
    

## Baseline

In [None]:
# Baseline run: all transformer blocks in their natural order, with activation
# tracking switched on so a similarity result can be computed for each prompt.
results = []
encoder_actns_sims = []
actns_cos_sims = []

# A single CPU generator, seeded once and shared by every prompt.
generator = torch.Generator('cpu').manual_seed(19943434)
for prompt in test_prompts_expt:
    out, encoder_actns, actns = pipe(
        prompt=prompt,
        width=1360,
        height=768,
        guidance_scale=5.0,
        num_inference_steps=28,
        max_sequence_length=256,
        layer_order=list(range(len(expt_transformer.transformer_blocks))),
        track_activations=True,
        generator=generator,
        return_dict=True,
    )
    results.append((prompt, out.images[0]))
    encoder_actns_sims.append(cosine_similarity(encoder_actns))
    actns_cos_sims.append(cosine_similarity(actns))
    # Drop the references to the tracked activations before the next prompt runs.
    del encoder_actns, actns

# Stack the per-prompt results along a new leading (prompt) dimension.
encoder_actns_sims = torch.stack(encoder_actns_sims)
actns_cos_sims = torch.stack(actns_cos_sims)

In [None]:
# Average the stacked per-prompt similarities over the prompt dimension (dim 0).
plot_similarity_matrix(encoder_actns_sims.mean(dim=0))

In [None]:
# Average the stacked per-prompt similarities over the prompt dimension (dim 0).
plot_similarity_matrix(actns_cos_sims.mean(dim=0))

In [None]:
# One row per baseline result; derive the row count from the data instead of
# hard-coding it so the grid always matches the number of prompts.
plot_image_grid(results, rows=len(results), cols=1, figsize=(30, 25))

In [None]:
# Middle layers seem to be from 5 to 20

In [None]:
# Number of transformer blocks in the loaded model; used as the exclusive upper
# bound when building the layer_order lists in the experiments below.
N = len(expt_transformer.transformer_blocks)

## Skipping

In [None]:
# Skip experiment: blocks 0-19 followed by 22..N-1 (blocks 20 and 21 left out).
layer_order = [*range(20), *range(22, N)]

results_skip_transformer_20_21 = test_sd3(test_prompts_expt, layer_order)

In [None]:
# Skip experiment: blocks 0-12 followed by 15..N-1 (blocks 13 and 14 left out).
layer_order = [*range(13), *range(15, N)]

results_skip_transformer = test_sd3(test_prompts_expt, layer_order)

In [None]:
# Skip experiment: every block except the first one (block 0).
layer_order = [*range(1, N)]

results_skip_first_transformer = test_sd3(test_prompts_expt, layer_order)

## Skip Repeat

In [None]:
# Skip-repeat experiment: blocks 14-17 are replaced by block 15 applied four times.
layer_order = [*range(14), *[15] * 4, *range(18, N)]

results_skip_repeat_transformer_14_17 = test_sd3(test_prompts_expt, layer_order)

In [None]:
# Skip-repeat experiment: blocks 18-21 are replaced by block 19 applied four times.
layer_order = [*range(18), *[19] * 4, *range(22, N)]

results_skip_repeat_transformer_18_21 = test_sd3(test_prompts_expt, layer_order)

## Reverse

In [None]:
# Reverse experiment: blocks 14-17 are visited in the order 17, 16, 15, 14;
# everything before and after keeps its natural order.
layer_order = [*range(14), *range(17, 13, -1), *range(18, N)]

results_skip_reverese_transformer_14_17 = test_sd3(test_prompts_expt, layer_order)

In [None]:
# Reverse experiment: blocks 18-21 are visited in the order 21, 20, 19, 18;
# everything before and after keeps its natural order.
layer_order = [*range(18), *range(21, 17, -1), *range(22, N)]

results_skip_reverese_transformer = test_sd3(test_prompts_expt, layer_order)

## Parallel

In [None]:
# Parallel experiment: blocks 10-13 are passed as a single tuple entry.
# NOTE(review): the tuple is presumably executed as one parallel group by the
# custom pipeline — confirm against custom_sd3_transformer.
# The tail resumes at block 14, directly after the group, so that blocks 14-21
# are not dropped and only the grouped blocks differ from the baseline order.
layer_order = list(range(10)) + [(10, 11, 12, 13)] + list(range(14, N))

results_skip_parallel_transformer_10_14 = test_sd3(test_prompts_expt, layer_order)

In [None]:
# Parallel experiment: blocks 18-21 are passed as a single tuple entry, with
# all other blocks in their natural order.
# NOTE(review): the tuple is presumably executed as one parallel group by the
# custom pipeline — confirm against custom_sd3_transformer.
layer_order = [*range(18), (18, 19, 20, 21), *range(22, N)]

results_skip_parallel_transformer = test_sd3(test_prompts_expt, layer_order)

## Looped-Parallel

In [None]:
# Looped-parallel experiment: the tuple entry for blocks 10-13 appears four
# times in a row, then the order resumes at block 14.
layer_order = [*range(10), *[(10, 11, 12, 13)] * 4, *range(14, N)]

results_looped_parallel_transformer_10_13 = test_sd3(test_prompts_expt, layer_order)

In [None]:
# Looped-parallel experiment: the tuple entry for blocks 15-18 appears four
# times in a row, then the order resumes at block 19.
layer_order = [*range(15), *[(15, 16, 17, 18)] * 4, *range(19, N)]

results_looped_parallel_transformer = test_sd3(test_prompts_expt, layer_order)

In [None]:
# Only the value of a cell's last expression is displayed, so the first line
# below has no visible effect; the cell shows the baseline `results` list.
results_looped_parallel_transformer #looped 15-18
results #base

In [None]:
import matplotlib.pyplot as pyplot
import textwrap

# Side-by-side comparison figure: one row per prompt, one column per experiment.
titles = [
    "Baseline", "Repeat transformer layers 15 for 14-17", "Repeat transformer layers 19 for 18-21"
]
prompts = []
images = []
# Each `item` holds the (prompt, image) pairs of one prompt across the three runs.
for item in zip(
        results,
        results_skip_repeat_transformer_14_17,
        results_skip_repeat_transformer_18_21,
    ):
    prompts.append(item[0][0])
    images.append([i[1] for i in item])

num_rows = len(prompts)
num_cols = len(titles)
# squeeze=False keeps `axes` two-dimensional even when there is a single row,
# so the axes[i, j] indexing below is always valid.
fig, axes = pyplot.subplots(
    num_rows, num_cols, figsize=(30, 25), constrained_layout=True, squeeze=False
)

# Plot the images in the grid; column titles go on the first row only.
for i, row_images in enumerate(images):
    for j, image in enumerate(row_images):
        axes[i, j].imshow(image)
        axes[i, j].axis('off')
        if i == 0:
            axes[i, j].set_title(titles[j], fontsize=24)

# Save through the figure object rather than the implicit "current figure".
fig.savefig('sd3_image_1.png')