In [None]:
%load_ext autoreload
%autoreload 2

import os, sys
from pathlib import Path

# Add parent directory to sys.path
parent_dir = Path.cwd().parent.resolve()
if str(parent_dir) not in sys.path:
    sys.path.insert(0, str(parent_dir))

# Verify that the path has been added correctly
print(sys.path[0])

from diffusers import FluxPipeline
from diffusers.models import AutoencoderTiny
import torch
import optparse

os.environ['HF_HOME'] = '/scratch/nevali'
os.environ['TRANSFORMERS_CACHE'] = '/scratch/nevali'
os.environ['HF_DATASETS_CACHE'] = '/scratch/nevali'

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import numpy as np
from PIL import Image
from importlib import reload

In [None]:
import torch
import gc
from accelerate.utils import release_memory

def clear_all_gpu_memory():
    """Best-effort release of cached CUDA memory on every visible GPU.

    Runs Python garbage collection, then, for each CUDA device in turn, empties the
    caching allocator, resets the peak/accumulated memory statistics, calls
    accelerate's `release_memory()` and collects CUDA IPC handles.

    Note: `torch.cuda.empty_cache()` only returns *unoccupied* cached blocks; memory
    held by tensors that are still referenced (e.g. a loaded pipeline) stays allocated.
    """

    def flush_caches():
        # Return unused cached blocks to the driver, then drop unreachable Python objects.
        torch.cuda.empty_cache()
        gc.collect()

    gc.collect()

    gpu_ids = list(range(torch.cuda.device_count()))
    print(f"Found {len(gpu_ids)} GPU(s).")

    for gpu in gpu_ids:
        # Make `gpu` the current device so the per-device calls below act on it.
        with torch.cuda.device(gpu):
            flush_caches()
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.reset_accumulated_memory_stats()
            release_memory()
            flush_caches()
            torch.cuda.ipc_collect()

    print("GPU memory cleared across all available devices.")
    print("GPU memory cleared across all available devices.")

In [None]:
from diffusers import FluxPipeline

# Load FLUX.1-dev sharded across the available GPUs ("balanced" device map).
# NOTE(review): the diffusers FLUX examples load this checkpoint in torch.bfloat16;
# float16 is kept here as in the original — confirm it does not yield NaN/black images.
dtype = torch.float16
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype,
    device_map="balanced",
)

# Silence the per-step tqdm bar; many pipeline runs happen in this notebook.
pipe.set_progress_bar_config(disable=True)

In [None]:
# `%autoreload 2` (first cell) already tracks edits. These explicit reloads keep the
# cell self-contained. Reload the `hooks` submodule BEFORE the package so that
# anything the package re-imports from `hooks` binds to the fresh code. Only then
# re-run the wildcard import, so the notebook-level names are not stale.
import cache_and_edit
import cache_and_edit.hooks
reload(cache_and_edit.hooks)
reload(cache_and_edit)
from cache_and_edit import *

# Wrap the loaded FLUX pipeline with the project's caching/editing helper.
cached_pipe = CachedPipeline(pipe)
# cached_pipe_dest = CachedPipeline(pipe)

In [None]:
from typing import Optional, Tuple

def resize_bounding_box(
    bb_mask: torch.Tensor,
    target_size: Tuple[int, int] = (64, 64),
) -> torch.Tensor:
    """
    Downsample a binary mask to ``target_size`` by "any-pooling" non-overlapping patches.

    The last two dimensions (H, W) of ``bb_mask`` are split into a
    ``target_size[0] x target_size[1]`` grid of equally sized patches. An output cell
    is 1 if its patch contains at least one positive value and 0 otherwise. The result
    therefore conservatively covers every pixel that was set in the input.

    Args:
        bb_mask (torch.Tensor): Mask of shape (..., H, W), typically a 2D 0/1 float or
            boolean tensor. Any leading dimensions are preserved.
        target_size (Tuple[int, int]): Output size (H_out, W_out). H must be divisible
            by H_out and W by W_out.

    Returns:
        torch.Tensor: Mask of shape (..., H_out, W_out) containing only 0/1
            (False/True), with the same dtype and device as ``bb_mask``.

    Raises:
        ValueError: If ``bb_mask`` has fewer than 2 dimensions, a target dimension is
            not positive, or the mask size is not a multiple of ``target_size``.
    """
    if bb_mask.ndim < 2:
        raise ValueError(
            f"Mask must have at least 2 dimensions, got shape {tuple(bb_mask.shape)}"
        )

    h_mask, w_mask = bb_mask.shape[-2:]
    h_target, w_target = target_size

    # Guard against a zero/negative grid (would otherwise raise ZeroDivisionError below).
    if h_target <= 0 or w_target <= 0:
        raise ValueError(f"Target size must be positive, got {target_size}")

    # Make sure the sizes are compatible
    if h_mask % h_target != 0 or w_mask % w_target != 0:
        raise ValueError(
            f"Mask size {bb_mask.shape[-2:]} is not compatible with target size {target_size}"
        )

    patch_h, patch_w = h_mask // h_target, w_mask // w_target

    # View (..., H, W) as (..., H_out, patch_h, W_out, patch_w): element (r, c) lands at
    # (r // patch_h, r % patch_h, c // patch_w, c % patch_w), so dims -3 and -1 index
    # positions *inside* each patch.
    patches = bb_mask.reshape(*bb_mask.shape[:-2], h_target, patch_h, w_target, patch_w)

    # A patch is "on" if any entry is positive. For non-negative masks this matches
    # the former per-patch `sum > 0` test.
    occupied = (patches > 0).sum(dim=(-3, -1)) > 0
    return occupied.to(dtype=bb_mask.dtype)


def get_combined_latents(
    bg_latents: torch.Tensor,
    fg_latents: torch.Tensor,
    bb_mask: torch.Tensor,
) -> torch.Tensor:
    """
    Blend background and foreground latents with a per-token mask.

    Tokens where ``bb_mask`` is 1/True take the foreground latent. Tokens where it is
    0/False keep the background latent. Fractional mask values give the linear blend
    ``bg * (1 - m) + fg * m``.

    Args:
        bg_latents (torch.Tensor): Background latents of shape (H * W, C).
        fg_latents (torch.Tensor): Foreground latents, same shape as ``bg_latents``.
        bb_mask (torch.Tensor): Mask with H * W elements in any layout, e.g. (H * W,),
            (H * W, 1) or (H, W). May be boolean, integer or floating point.

    Returns:
        torch.Tensor: The combined latents of shape (H * W, C), with the dtype and
            device of ``bg_latents``.

    Raises:
        AssertionError: If the latents' shapes differ, or the mask does not have one
            entry per latent row.
    """
    # Flatten any mask layout to one entry per token.
    if bb_mask.ndim != 1:
        bb_mask = bb_mask.reshape(-1)

    assert bg_latents.shape == fg_latents.shape, "Background and foreground latents must have the same shape"
    assert bg_latents.shape[0] == bb_mask.shape[0], "Background latents and bounding box mask must have the same number of elements"

    # Match the latents' dtype and device. Otherwise a bool mask converted with
    # .float() upcasts fp16/bf16 latents to float32, and a CPU mask fails against
    # CUDA latents. unsqueeze(1) broadcasts one weight across all C channels.
    weights = bb_mask.to(device=bg_latents.device, dtype=bg_latents.dtype).unsqueeze(1)

    return bg_latents * (1 - weights) + fg_latents * weights

In [None]:
 # [0, 1, 2, 17, 18, 25, 28, 53, 54, 56].  [28, 53, 54, 56, 25]
vital_layers = [f"transformer.transformer_blocks.{i}" for i in [0, 1, 17, 18]] + \
                [f"transformer.single_transformer_blocks.{i-19}" for i in [25, 28, 53, 54, 56]]
mask_path = "../data/Real-Real/0001 a professional photograph of a puppy in the snow, ultra realistic/mask_bg_fg.jpg"
mask = Image.open(mask_path).convert("L")
# convert to tensor
mask = np.array(mask)
mask = torch.from_numpy(mask)
mask = mask / 255.0
mask = resize_bounding_box(mask, target_size=(64, 64)).flatten().unsqueeze(1)
prompts = ["a house in the forest", "a green cow", "a house in the forest"]

In [None]:
def display_images(imgs: list, panel_size: float = 5):
    """
    Show images side by side in a single row, with axes hidden.

    Args:
        imgs: Sequence of images accepted by ``Axes.imshow`` (e.g. PIL images or
            arrays). Nothing is drawn if it is empty.
        panel_size: Width and height, in inches, of each image panel (default 5).
    """
    if len(imgs) == 0:
        return

    # squeeze=False keeps `axes` a 2D array even for a single image. With the default
    # squeeze=True, plt.subplots(1, 1) returns a bare Axes, which cannot be zipped.
    fig, axes = plt.subplots(
        1, len(imgs), figsize=(len(imgs) * panel_size, panel_size), squeeze=False
    )
    for ax, im in zip(axes[0], imgs):
        ax.imshow(im)
        ax.axis('off')  # hide axes ticks

    fig.tight_layout()
    plt.show()

In [None]:
from data.benchmark_data import gather_images

# Collect the benchmark samples under ../data/ via the project helper, then inspect
# the first one: its category, its prompt, and a plot of the sample.
os_path = "../data/"
all_images = gather_images(os_path)
example = all_images[0]

print(example.category, example.prompt, sep="\n")
example.plot_sample()

In [None]:
from cache_and_edit.inversion import * 

# Cut the foreground object out of its image. The mask is resized to the image size,
# turned into a boolean (H, W, 1) gate and broadcast over the channels.
# NOTE(review): PIL's default resampling may be non-nearest depending on the mask's
# mode, and `.to(torch.bool)` keeps every nonzero value, so edges may grow. Confirm.
cut_img = (
    torch.tensor(np.array(example.fg_mask.resize(example.fg_image.size)))
    .to(torch.bool)
    .unsqueeze(-1)
    * torch.tensor(np.array(example.fg_image))
)
reframed_fg_img, resized_mask = place_image_in_bounding_box(
    cut_img,
    (torch.from_numpy(np.array(example.target_mask)) / 255.0).to(dtype=bool),
)
reframed_fg_img = Image.fromarray(reframed_fg_img.numpy())


def _reconstruct_preview(noise, width, height):
    """Run the pipeline from `noise` with an empty prompt; return a 256x256 thumbnail."""
    result = cached_pipe.run(
        "",
        num_inference_steps=28,
        seed=42,
        guidance_scale=3.5,
        latents=noise.unsqueeze(0),
        width=width,
        height=height,
    )
    return result.images[0].resize((256, 256))


bg_noise = get_inverted_input_noise(cached_pipe, example.bg_image, 100)
fg_noise = get_inverted_input_noise(cached_pipe, reframed_fg_img, 100)

# Sanity check: regenerate from each inverted noise (foreground first, then background);
# presumably these should resemble the input images.
display(_reconstruct_preview(fg_noise, *reframed_fg_img.size))
display(_reconstruct_preview(bg_noise, *example.bg_image.size))

# NOTE(review): 32x32 presumably matches the token grid of the 512x512 generation in
# the next cell (512 / 16 = 32 for FLUX) — keep in sync with its width/height.
resized_mask = resize_bounding_box(resized_mask, (32, 32)).flatten()

In [None]:
def _all_block_names():
    """Return a fresh list naming all 19 double-stream and 38 single-stream blocks."""
    double_blocks = [f"transformer.transformer_blocks.{i}" for i in range(19)]
    single_blocks = [f"transformer.single_transformer_blocks.{i}" for i in range(38)]
    return double_blocks + single_blocks


# Batch of three empty prompts with latents (background, foreground, background).
# NOTE(review): presumably the third stream is the composited edit — confirm against
# CachedPipeline.run_inject_qkv.
images = cached_pipe.run_inject_qkv(
    ["", "", ""],
    num_inference_steps=28,
    seed=42,
    guidance_scale=0,
    empty_clip_embeddings=False,
    q_mask=resized_mask.unsqueeze(1),
    positions_to_inject=_all_block_names(),
    positions_to_inject_foreground=_all_block_names(),
    latents=torch.stack([bg_noise, fg_noise, bg_noise]),
    width=512,
    height=512,
).images

display_images(images)