In [None]:
!git clone https://github.com/huggingface/diffusers.git
# use the 803e817 commit

In [None]:
!pip install huggingface_hub transformers datasets accelerate sentencepiece einops protobuf matplotlib

In [None]:
!pip install ./diffusers

In [None]:
from huggingface_hub import notebook_login

In [None]:
# Interactive Hugging Face Hub login; the FLUX.1-schnell weights are fetched from the Hub further below.
notebook_login()

## Set up the test dataset

In [None]:
# Prompts used by every experiment; one image is generated per entry.
# Every element needs a trailing comma: adjacent string literals are otherwise
# implicitly concatenated into a single prompt.
test_prompts_expt = [
    "A charismatic speaker is captured mid-speech. He has short, tousled brown hair that’s slightly messy on top. He has a round circle face, clean shaven, adorned with rounded rectangular-framed glasses with dark rims, is animated as he gestures with his left hand. He is holding a black microphone in his right hand, speaking passionately. The man is wearing a light grey sweater over a white t-shirt. He’s also wearing a simple black lanyard hanging around his neck. The lanyard badge has the text “Anakin AI”. Behind him, there is a blurred background with a white banner containing logos and text (including Anakin AI), a professional conference setting.", #Anakin AI
    "a red dog wearing a blue hat sits with a yellow cat wearing pink sunglasses", #wonderflex on reddit.
    "A Samsung LED moniter's screen on a table displays an image of a garden with signboard mentions 'All is Well', A teddy toy placed on the table, a cat is sleeping near the teddy toy, a mushroom dish on red plate placed on the table, raining outside, a parrot sitting on the nearby window, a flex banner with text 'Enjoy the life' visible from outside of the window,",
    "3d model of a green war balloon, clash of clans, fantasy game, front view, game asset, detailed, war ready, photorealistic, in a war enviroment, spring, disney style, pixar style",
    "Photo of a felt puppet diorama scene of a tranquil nature scene of a secluded forest clearing with a large friendly, rounded robot is rendered in a risograph style. An owl sits on the robots shoulders and a fox at its feet. Soft washes of color, 5 color, and a light-filled palette create a sense of peace and serenity, inviting contemplation and the appreciation of natural beauty.",
    "The Golden gate bridge"
]

In [None]:
# Display the prompt list (as the cell's output) to sanity-check the number of prompts.
test_prompts_expt

## Load the FLUX model

In [None]:
import torch
import einops
from utils import plot_image_grid, cosine_similarity, plot_similarity_matrix
from custom_flux_pipeline import FluxPipeline
from custom_flux_transformer import FluxTransformer2DModel

In [None]:
# Load the FLUX.1-schnell transformer through the custom FluxTransformer2DModel class,
# which is what lets the pipeline accept custom block orders.
expt_transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

In [None]:
# Build the custom FluxPipeline around the experimental transformer and move it to the GPU.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    transformer=expt_transformer,
)
pipe.to("cuda")

### Question 1:
Is there a common representation space across the layers of the diffusion transformer, as shown for other transformers in [Transformer layers as painters](https://arxiv.org/pdf/2407.09298)?

### Test: 
1. Collect the hidden states of the transformer for each layer at each timestep across multiple inputs.
2. Test #1: Average the hidden states of each layer across inputs.
3. Test #2: Average the hidden states of each layer across timesteps. (Intuitively, this should give about the same result as a single timestep, i.e. the hidden states shouldn't vary much across timesteps.)
4. Test #3: Average the hidden states of each layer across both inputs and timesteps.
5. Compute the cosine similarity between layers using the averaged hidden states (activations) from the step above.

In [None]:
def test_flux(prompts, layer_order, single_layer_order, seed=19943434, height=768, width=1360,
              num_inference_steps=4, max_sequence_length=256, figsize=(30, 25)):
    """Generate one image per prompt with a custom block execution order and plot them.

    The global ``pipe`` (the custom FluxPipeline) is called once per prompt with
    ``layer_order`` / ``single_layer_order`` forwarded unchanged, so skip / repeat /
    reverse / parallel orderings can be compared against the baseline run.

    Args:
        prompts: Text prompts; one image is generated per prompt.
        layer_order: Execution order for ``transformer_blocks``. The experiments in this
            notebook pass ints and tuples of ints (tuples group blocks run in parallel).
        single_layer_order: Same, for ``single_transformer_blocks``.
        seed: Seed for the CPU generator. One generator is seeded once and shared across
            all prompts, mirroring the baseline cell, so with the same prompts in the same
            order each prompt should draw the same initial noise as in the baseline.
        height, width, num_inference_steps, max_sequence_length: Forwarded to ``pipe``.
        figsize: Figure size passed to ``plot_image_grid``.

    Returns:
        List of ``(prompt, out.images[0])`` tuples, in prompt order.
    """
    results = []
    print(f"layer_order {layer_order}")
    print(f"single_layer_order {single_layer_order}")
    generator = torch.Generator('cpu')
    generator.manual_seed(seed)
    for prompt in prompts:
        # The custom pipeline also returns three activation outputs (used by the baseline
        # similarity analysis); only the generated image is needed here.
        out, _, _, _ = pipe(
            prompt=prompt,
            guidance_scale=0.,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            max_sequence_length=max_sequence_length,
            return_dict=True,
            layer_order=layer_order,
            single_layer_order=single_layer_order,
            generator=generator
        )
        results.append((prompt, out.images[0]))
    # Nothing to show for an empty prompt list; a zero-row grid is pointless and may fail.
    if results:
        plot_image_grid(results, rows=len(results), cols=1, figsize=figsize)
    return results

## Baseline

In [None]:
# Baseline: every block runs once in its natural order. Per prompt we keep the image plus one
# cosine-similarity result for each of the three activation outputs returned by the pipeline.
encoder_actns_sims, actns_cos_sims, single_actns_cos_sims = [], [], []
results = []

generator = torch.Generator('cpu')
generator.manual_seed(19943434)
for prompt in test_prompts_expt:
    out, encoder_actns, actns, single_actns = pipe(
        prompt=prompt,
        guidance_scale=0.,
        height=768,
        width=1360,
        num_inference_steps=4,
        max_sequence_length=256,
        return_dict=True,
        layer_order=list(range(len(expt_transformer.transformer_blocks))),
        single_layer_order=list(range(len(expt_transformer.single_transformer_blocks))),
        generator=generator,
    )
    results.append((prompt, out.images[0]))
    # Pair each accumulator with the activation output it summarizes.
    for sims, activations in zip(
        (encoder_actns_sims, actns_cos_sims, single_actns_cos_sims),
        (encoder_actns, actns, single_actns),
    ):
        sims.append(cosine_similarity(activations, compute_mean=True))

# Stack the per-prompt results along a new leading "prompt" axis.
encoder_actns_sims, actns_cos_sims, single_actns_cos_sims = [
    torch.stack(sims, dim=0)
    for sims in (encoder_actns_sims, actns_cos_sims, single_actns_cos_sims)
]

In [None]:
# Prompt-averaged layer-vs-layer similarity of `encoder_actns` (average over the leading prompt axis).
plot_similarity_matrix(encoder_actns_sims.mean(dim=0))

In [None]:
# Prompt-averaged layer-vs-layer similarity of `actns` (average over the leading prompt axis).
plot_similarity_matrix(actns_cos_sims.mean(dim=0))

In [None]:
# Prompt-averaged layer-vs-layer similarity of `single_actns` (average over the leading prompt axis).
plot_similarity_matrix(single_actns_cos_sims.mean(dim=0))

In [None]:
# One row per baseline prompt; derive the row count from the data instead of hard-coding it.
plot_image_grid(results, rows=len(results), cols=1, figsize=(30,25))

In [None]:
## Skipping middle layers of the transformer blocks seems to affect prompt adherence.
### E.g.: "An antique chest with 3 drawers", "A red circle on top of a blue square" -> retest these.
## Skipping middle layers of the single transformer blocks seems to affect ... (TODO: record the observed effect)
# Layer-order notation used by the experiments below (a tuple groups blocks that run in parallel):
# [0,1,2,3] -> base
# [0,2,3] -> skip
# [0,1,1,3] -> skip-repeat
# [0,(1,2),3] -> parallel
# [3,2,1,0] -> reverse
# [0, (1,2), (1,2), 3] -> looped-parallel

## Skip

### Skipping middle transformer_layers

In [None]:
# Skip transformer_blocks 7-10; every single_transformer_block still runs in order.
# The block count comes from the model (as in the other cells) rather than a hard-coded 18.
skipped_layers = range(7, 11)
layer_order = [i for i in range(len(expt_transformer.transformer_blocks)) if i not in skipped_layers]
single_layer_order = list(range(len(expt_transformer.single_transformer_blocks)))

results_skip_transformer = test_flux(test_prompts_expt, layer_order, single_layer_order)

### Skipping middle single transformer blocks
There seem to be two groupings in the middle layers: one group up to layer 16-17 and another up to layer 33.

In [None]:
# Skip single_transformer_blocks 10-13; all transformer_blocks run in order.
num_single_blocks = len(expt_transformer.single_transformer_blocks)
layer_order = list(range(len(expt_transformer.transformer_blocks)))
single_layer_order = [*range(10), *range(14, num_single_blocks)]

results_skip_single_layers = test_flux(test_prompts_expt, layer_order, single_layer_order)

In [None]:
# Skip single_transformer_blocks 22-26; all transformer_blocks run in order.
num_single_blocks = len(expt_transformer.single_transformer_blocks)
layer_order = list(range(len(expt_transformer.transformer_blocks)))
single_layer_order = [*range(22), *range(27, num_single_blocks)]

results_skip_single_layers_later = test_flux(test_prompts_expt, layer_order, single_layer_order)

### Skipping first and last layers: transformer blocks

In [None]:
# Drop the first transformer_block; keep every single_transformer_block.
num_blocks = len(expt_transformer.transformer_blocks)
layer_order = list(range(num_blocks))[1:]
single_layer_order = list(range(len(expt_transformer.single_transformer_blocks)))

results_skip_layers_first = test_flux(test_prompts_expt, layer_order, single_layer_order)

In [None]:
# Drop the last transformer_block; keep every single_transformer_block.
num_blocks = len(expt_transformer.transformer_blocks)
layer_order = list(range(num_blocks))[:-1]
single_layer_order = list(range(len(expt_transformer.single_transformer_blocks)))

results_skip_layers_last = test_flux(test_prompts_expt, layer_order, single_layer_order)

### Skipping first and last layers: single transformer blocks

In [None]:
# Drop the first single_transformer_block; keep every transformer_block.
num_single_blocks = len(expt_transformer.single_transformer_blocks)
layer_order = list(range(len(expt_transformer.transformer_blocks)))
single_layer_order = list(range(num_single_blocks))[1:]

results_skip_single_layers_first = test_flux(test_prompts_expt, layer_order, single_layer_order)

In [None]:
# Drop the last single_transformer_block; keep every transformer_block.
num_single_blocks = len(expt_transformer.single_transformer_blocks)
layer_order = list(range(len(expt_transformer.transformer_blocks)))
single_layer_order = list(range(num_single_blocks))[:-1]
results_skip_single_layers_last = test_flux(test_prompts_expt, layer_order, single_layer_order)

## Repeat

### Repeat transformer blocks

In [None]:
# Replace transformer_blocks 3-16 (inclusive) with repeated applications of block 9.
repeated_span = range(3, 17)
layer_order = [9 if idx in repeated_span else idx for idx in range(len(expt_transformer.transformer_blocks))]
single_layer_order = list(range(len(expt_transformer.single_transformer_blocks)))

results_layer_repeat = test_flux(test_prompts_expt, layer_order, single_layer_order)

### Repeat single transformer blocks

In [None]:
# Replace single_transformer_blocks 4-35 (inclusive) with repeated applications of block 19.
repeated_single_span = range(4, 36)
layer_order = list(range(len(expt_transformer.transformer_blocks)))
single_layer_order = [19 if idx in repeated_single_span else idx for idx in range(len(expt_transformer.single_transformer_blocks))]

results_single_layer_repeat = test_flux(test_prompts_expt, layer_order, single_layer_order)

## Reverse

### Reverse transformer blocks

In [None]:
# Keep the first 3 and last 3 transformer_blocks in place and run the middle ones in reverse.
num_blocks = len(expt_transformer.transformer_blocks)
layer_order = [*range(3), *reversed(range(3, num_blocks - 3)), *range(num_blocks - 3, num_blocks)]
single_layer_order = list(range(len(expt_transformer.single_transformer_blocks)))

results_mid_layer_reverse = test_flux(test_prompts_expt, layer_order, single_layer_order)

### Reverse single transformer blocks

In [None]:
# Keep the first 5 and last 4 single_transformer_blocks in place and run the middle ones in reverse.
num_single_blocks = len(expt_transformer.single_transformer_blocks)
layer_order = list(range(len(expt_transformer.transformer_blocks)))
single_layer_order = [*range(5), *reversed(range(5, num_single_blocks - 4)), *range(num_single_blocks - 4, num_single_blocks)]

results_mid_single_layer_reverse = test_flux(test_prompts_expt, layer_order, single_layer_order)

## Parallel

### Parallel: transformer blocks

In [None]:
# Run transformer_blocks 3-16 (inclusive, the same span the repeat experiment replaces) as one
# parallel group. The group has to stop exactly where the sequential tail starts; with
# tuple(range(3, 16)) and a tail from 17, block 16 was silently left out of the schedule.
parallel_start, parallel_stop = 3, 17
layer_order = [*range(parallel_start)] + [tuple(range(parallel_start, parallel_stop))] + list(range(parallel_stop, len(expt_transformer.transformer_blocks)))
single_layer_order = list(range(len(expt_transformer.single_transformer_blocks)))
results_parallel_transformer = test_flux(test_prompts_expt, layer_order, single_layer_order)

### Parallel: single transformer blocks

In [None]:
# Run the middle single_transformer_blocks (4 .. n-5) as one parallel group; the first 4 and last 4 stay sequential.
num_single_blocks = len(expt_transformer.single_transformer_blocks)
layer_order = list(range(len(expt_transformer.transformer_blocks)))
single_layer_order = [*range(4), tuple(range(4, num_single_blocks - 4)), *range(num_single_blocks - 4, num_single_blocks)]
results_parallel_single_transformer = test_flux(test_prompts_expt, layer_order, single_layer_order)

## Looped-Parallel

### Looped Parallel: transformer blocks

In [None]:
# Looped parallel: transformer_blocks 3-16 (inclusive) form one parallel group applied
# len(parallel_layers) times in a row. The group must end where the sequential tail (from 17)
# begins; tuple(range(3, 16)) left block 16 out of the schedule entirely.
parallel_layers = tuple(range(3, 17))
layer_order = [0, 1, 2] + [parallel_layers]*len(parallel_layers) + list(range(17, len(expt_transformer.transformer_blocks)))
single_layer_order = list(range(len(expt_transformer.single_transformer_blocks)))
results_parallel_transformer_looped = test_flux(test_prompts_expt, layer_order, single_layer_order)

### Looped Parallel: single transformer blocks

In [None]:
# Looped parallel: the middle single_transformer_blocks (4 .. n-5) form one parallel group
# applied len(parallel_layers) times; the first 4 and last 4 blocks stay sequential.
num_single_blocks = len(expt_transformer.single_transformer_blocks)
parallel_layers = tuple(range(4, num_single_blocks - 4))
layer_order = list(range(len(expt_transformer.transformer_blocks)))
single_layer_order = [*range(4), *([parallel_layers] * len(parallel_layers)), *range(num_single_blocks - 4, num_single_blocks)]
results_parallel_single_transformer_looped = test_flux(test_prompts_expt, layer_order, single_layer_order)

In [None]:
import matplotlib.pyplot as pyplot
import textwrap

# Side-by-side comparison: one row per prompt, one column per experiment.
# Each title sits next to the results it labels so the two cannot drift apart.
# NOTE(review): the two "Repeat parallel ..." titles label the single-pass parallel runs,
# not the *_looped runs -- confirm which results were intended for this figure.
columns = [
    ("Baseline", results),
    ("Repeat parallel and avg MMDiT middle layers", results_parallel_transformer),
    ("Repeat parallel and avg Single middle layers", results_parallel_single_transformer),
]
titles = [title for title, _ in columns]
column_results = [column for _, column in columns]
# zip() silently truncates to the shortest run; every run should cover the same prompts.
assert len({len(column) for column in column_results}) == 1, "all runs must cover the same prompts"

prompts = []
images = []
for item in zip(*column_results):
    prompts.append(item[0][0])
    images.append([i[1] for i in item])

num_rows = len(prompts)
num_cols = len(columns)
# squeeze=False keeps `axes` 2-D even with a single prompt or a single column,
# so axes[i, j] indexing always works.
fig, axes = pyplot.subplots(num_rows, num_cols, figsize=(30, 25), constrained_layout=True, squeeze=False)

for i, row_images in enumerate(images):
    for j, image in enumerate(row_images):
        axes[i, j].imshow(image)
        axes[i, j].axis('off')
for j, title in enumerate(titles):
    axes[0, j].set_title(title, fontsize=24)

fig.savefig('image_eight.png')