In [None]:
!git clone https://github.com/huggingface/diffusers.git

In [None]:
!pip install huggingface_hub transformers datasets accelerate sentencepiece einops protobuf matplotlib accelerate

In [None]:
!pip install -U ./diffusers

In [None]:
from huggingface_hub import notebook_login

In [None]:
# Interactive Hugging Face Hub login; presumably only required if a model repo used below needs authentication.
notebook_login()

In [None]:
# Evaluation prompts, one image per entry. Every entry needs its own trailing comma:
# adjacent string literals without a separating comma are silently concatenated by Python.
test_prompts_expt = [
    "A charismatic speaker is captured mid-speech. He has short, tousled brown hair that’s slightly messy on top. He has a round circle face, clean shaven, adorned with rounded rectangular-framed glasses with dark rims, is animated as he gestures with his left hand. He is holding a black microphone in his right hand, speaking passionately. The man is wearing a light grey sweater over a white t-shirt. He’s also wearing a simple black lanyard hanging around his neck. The lanyard badge has the text “Anakin AI”. Behind him, there is a blurred background with a white banner containing logos and text (including Anakin AI), a professional conference setting.",  # Anakin AI
    "a red dog wearing a blue hat sits with a yellow cat wearing pink sunglasses",  # wonderflex on reddit.
    "A Samsung LED moniter's screen on a table displays an image of a garden with signboard mentions 'All is Well', A teddy toy placed on the table, a cat is sleeping near the teddy toy, a mushroom dish on red plate placed on the table, raining outside, a parrot sitting on the nearby window, a flex banner with text 'Enjoy the life' visible from outside of the window",
    "3d model of a green war balloon, clash of clans, fantasy game, front view, game asset, detailed, war ready, photorealistic, in a war enviroment, spring, disney style, pixar style",
    "Photo of a felt puppet diorama scene of a tranquil nature scene of a secluded forest clearing with a large friendly, rounded robot is rendered in a risograph style. An owl sits on the robots shoulders and a fox at its feet. Soft washes of color, 5 color, and a light-filled palette create a sense of peace and serenity, inviting contemplation and the appreciation of natural beauty.",
    "The Golden gate bridge",
]

In [None]:
import torch
import einops
from utils import plot_image_grid, cosine_similarity, plot_similarity_matrix
from custom_auraflow_pipeline import AuraFlowPipeline
from custom_auraflow_transformer import AuraFlowTransformer2DModel

In [None]:
expt_transformer = AuraFlowTransformer2DModel.from_pretrained("fal/AuraFlow-v0.3", torch_dtype=torch.float16, subfolder="transformer", variant="fp16")

In [None]:
pipe = AuraFlowPipeline.from_pretrained(
        "fal/AuraFlow-v0.2",
        transformer=expt_transformer,
        torch_dtype=torch.float16,
        variant="fp16",
        )
pipe.to("cuda")
pipe.enable_sequential_cpu_offload()

In [None]:
def test_aura(prompts, layer_order,single_layer_order):
    results = []
    print(f"layer_order {layer_order}")
    print(f"single_layer_order {single_layer_order}")
    generator = torch.Generator('cpu')
    generator.manual_seed(19943434)
    for prompt in prompts:
        out, _, _, _ = pipe(
            prompt=prompt,
            height=1024,
            width=1024,
            num_inference_steps=50,
            guidance_scale=3.5,
            max_sequence_length=256,
            return_dict=True,
            layer_order=layer_order,
            single_layer_order=single_layer_order,
            generator=generator
        )
        results.append((prompt, out.images[0]))
    plot_image_grid(results, rows=len(prompts), cols=1, figsize=(25,25))
    return results
    

## Baseline

In [None]:
import torch.nn.functional as F
# tensor = steps x layers x batch x a x b
def cosine_similarity(tensor, compute_mean=False):
    tensor = tensor.to("cuda")

    if compute_mean:
        # Step 1: Compute the average along the 's' dimension -> shape becomes (l, b, a, c)
        averaged_tensor = einops.reduce(tensor, 's l b a c -> l b a c', 'mean')
    else:
        averaged_tensor = tensor
    # Step 2: Reshape (b, a, c) into a single dimension -> shape becomes (l, a*b*c)
    reshaped_tensor = einops.rearrange(averaged_tensor, 'l a b c -> l (a b c)')

    # Step 3: Compute cosine similarity between each pair in the 'l' dimension
    # Cosine similarity: (x * y) / (||x|| * ||y||)
    # Here, we compute pairwise cosine similarities using PyTorch's F.cosine_similarity
    layers = reshaped_tensor.size()[0]
    # Initialize a similarity matrix
    cosine_similarities = torch.zeros((layers, layers))

    # Compute cosine similarity between each pair of tensors in the 'l' dimension
    for i in range(layers):
        for j in range(layers):
            cosine_similarities[i, j] = F.cosine_similarity(
                reshaped_tensor[i], reshaped_tensor[j], dim=0
            )
    return cosine_similarities

In [None]:
# Baseline: run every prompt with the unmodified layer order while recording activations.
# Each prompt's activations are reduced to (layers x layers) cosine-similarity matrices.
full_joint_order = list(range(len(expt_transformer.joint_transformer_blocks)))
full_single_order = list(range(len(expt_transformer.single_transformer_blocks)))

results = []
encoder_actns_sims, actns_cos_sims, single_actns_cos_sims = [], [], []

# One CPU generator, seeded once and shared by all prompts (same scheme as test_aura).
generator = torch.Generator('cpu')
generator.manual_seed(19943434)

for prompt in test_prompts_expt:
    out, encoder_actns, actns, single_actns = pipe(
        prompt=prompt,
        height=1024,
        width=1024,
        num_inference_steps=50,
        max_sequence_length=256,
        guidance_scale=3.5,
        return_dict=True,
        layer_order=full_joint_order,
        single_layer_order=full_single_order,
        generator=generator,
        track_activations=True,
    )
    results.append((prompt, out.images[0]))
    encoder_actns_sims.append(cosine_similarity(encoder_actns))
    actns_cos_sims.append(cosine_similarity(actns))
    single_actns_cos_sims.append(cosine_similarity(single_actns))
    # Release the raw activations before the next prompt; only the similarity matrices are kept.
    del encoder_actns, actns, single_actns

In [None]:
plot_similarity_matrix(torch.mean(torch.stack(encoder_actns_sims),dim=0))

In [None]:
plot_similarity_matrix(torch.mean(torch.stack(actns_cos_sims),dim=0))

In [None]:
plot_similarity_matrix(torch.mean(torch.stack(single_actns_cos_sims),dim=0))

In [None]:
plot_image_grid(results, rows=5,cols=1)

In [None]:
# Observation from the plots above: the "middle" layers appear to span roughly indices 5 to 20.

In [None]:
M = len(expt_transformer.joint_transformer_blocks)
N = len(expt_transformer.single_transformer_blocks)

In [None]:
M,N

## Skipping

In [None]:
layer_order = [0,1,3]
single_layer_order = list(range(N))
results_skip_layer_middle = test_aura(test_prompts_expt, layer_order, single_layer_order)
#skipping one joint layer

In [None]:
layer_order = [1,2,3]
single_layer_order = list(range(N))
results_skip_layer_first = test_aura(test_prompts_expt, layer_order, single_layer_order)
#skipping first joint layer

In [None]:
layer_order = [0,1,2]
single_layer_order = list(range(N))
results_skip_layer_last = test_aura(test_prompts_expt, layer_order, single_layer_order)
#skipping last joint layer

In [None]:
layer_order = list(range(M))
single_layer_order = list(range(4,N))
results_skip_single_layer_first = test_aura(test_prompts_expt, layer_order, single_layer_order)
#skipping first 4 single layers

In [None]:
layer_order = list(range(M))
single_layer_order = list(range(N-2))
results_skip_single_layer_last = test_aura(test_prompts_expt, layer_order, single_layer_order)
#skipping last 2 single layers

In [None]:
layer_order = list(range(M))
single_layer_order = list(range(N-14)) + list(range(N-11,N))
results_skip_single_layer_middle = test_aura(test_prompts_expt, layer_order, single_layer_order)
#skipping random 3 middle single layers

## Skip Repeat

In [None]:
layer_order = [0,1,1,3]
single_layer_order = list(range(N))
results_repeat_layer = test_aura(test_prompts_expt, layer_order, single_layer_order)
#skipping first joint layer

In [None]:
layer_order = list(range(M))
single_layer_order = list(range(N-14)) + [19, 19, 19] + list(range(N-11,N))
results_skip_single_layer_middle_repeat = test_aura(test_prompts_expt, layer_order, single_layer_order)

## Reverse

In [None]:
layer_order = [0,2,1,3]
single_layer_order = list(range(N))
results_skip_reverse_layer = test_aura(test_prompts_expt, layer_order, single_layer_order)

In [None]:
layer_order = list(range(M))
single_layer_order = list(range(N-14)) + [20, 19, 18] + list(range(N-11,N))
results_single_layer_middle_reverse = test_aura(test_prompts_expt, layer_order, single_layer_order)

## Parallel

In [None]:
layer_order = [0, (1,2), 3]
single_layer_order = list(range(N))
results_parallel_layer = test_aura(test_prompts_expt, layer_order, single_layer_order)

In [None]:
layer_order = list(range(M))
single_layer_order = list(range(N-14)) + [(18, 19, 20)] + list(range(N-11,N))
results_parallel_single_layer = test_aura(test_prompts_expt, layer_order, single_layer_order)

## Looped-Parallel

In [None]:
layer_order = [0, (1,2), (1,2), 3]
single_layer_order = list(range(N))
results_parallel_layer_loop = test_aura(test_prompts_expt, layer_order, single_layer_order)

In [None]:
layer_order = list(range(M))
single_layer_order = list(range(N-14)) + [(18, 19, 20)] * 3 + list(range(N-11,N))
results_parallel_single_layer_loop = test_aura(test_prompts_expt, layer_order, single_layer_order)

In [None]:
from utils import plot_similarity_matrices

# Average each list of per-prompt (layers x layers) similarity matrices over the prompts.
mean_similarity_matrices = [
    torch.stack(per_prompt).mean(dim=0)
    for per_prompt in (encoder_actns_sims, actns_cos_sims, single_actns_cos_sims)
]
plot_similarity_matrices(
    mean_similarity_matrices,
    [
        "Encoder hidden state activations - MMDiT layers",
        "Hidden state activations - MMDiT layers",
        "Hidden state activations - Single layers",
    ],
)

In [None]:
import matplotlib.pyplot as pyplot
import textwrap

# Comparison grid: one row per prompt, one column per configuration.
titles = [
    "Baseline", "Execute middle MMDiT layers in looped parallel", "Execute some middle single layers in looped parallel"
]
runs = (results, results_parallel_layer_loop, results_parallel_single_layer_loop)

prompts = []
images = []
for row in zip(*runs):
    # Each run is a list of (prompt, image) pairs in the same prompt order.
    prompts.append(row[0][0])
    images.append([image for _, image in row])

num_rows = len(prompts)
num_cols = len(runs)
# squeeze=False keeps `axes` 2-D even with a single row, so axes[i, j] always works.
# 5 inches of height per row reproduces the previous 20x25 figure for five prompts.
fig, axes = pyplot.subplots(
    num_rows, num_cols, figsize=(20, 5 * num_rows), constrained_layout=True, squeeze=False
)

for i, row_images in enumerate(images):
    for j, image in enumerate(row_images):
        ax = axes[i, j]
        ax.imshow(image)
        ax.axis('off')
        if i == 0:
            ax.set_title(titles[j], fontsize=16)

fig.savefig('aura_six.png')