# STYLE-ALIGNED WITH CONTROLNET

In [None]:
# Clone the project repository; later cells rely on its `src` package,
# its `imgs/` folder and its `requirements.txt`.
!git clone https://github.com/alessioborgi/StyleAlignedDiffModels.git

# Move into the cloned repository and list its contents as a quick sanity check.
%cd StyleAlignedDiffModels
%ls

# Set the global Git identity (name and e-mail) used for commits made from this session.
!git config --global user.name "Alessio Borgi"
!git config --global user.email "alessioborgi3@gmail.com"

# The three steps below are disabled; uncomment them only to publish changes.

# Stage the changes
#!git add .

# Commit the changes
#!git commit -m "Added some content to your-file.txt"

# Push the changes to `main` (pushing typically requires credentials, e.g. a personal access token)
#!git push origin main

In [None]:
# Install the required packages
!pip install -r requirements.txt > /dev/null

In [None]:
from __future__ import annotations
import cv2
import copy
import torch
import einops
import mediapy
import numpy as np
from PIL import Image
import torch.nn as nn
from tqdm import tqdm
from typing import Any
from typing import Callable
from dataclasses import dataclass
from diffusers.utils import load_image
from torch.nn import functional as nnf
from diffusers.models import attention_processor
from diffusers.image_processor import PipelineImageInput
from transformers import DPTImageProcessor, DPTForDepthEstimation
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL

from src.Handler import Handler
from src.StyleAlignedArgs import StyleAlignedArgs
from src.ControlNet import SDXL_ControlNet_Model, concat_zero_control
from src.Depth_Map import get_depth_map
from src.HarrisCorner import get_edge_map

# Short aliases intended to make tensor-related code easier to read.
# NOTE(review): `torch.tensor` is the factory *function*, not the `torch.Tensor`
# class; if these aliases are meant for type annotations, `torch.Tensor` is
# probably what was intended — confirm before relying on them.
T = torch.tensor 
TN = T  # Second name bound to the same object as `T`.

In [None]:
# Build the SDXL + ControlNet pipeline and register StyleAligned on it.
#
# None of the models is moved to the GPU by hand: `enable_model_cpu_offload()`
# (called below) takes over device placement. Pushing the ControlNet, the VAE
# and the whole pipeline to CUDA first only adds host<->GPU copies and a VRAM
# peak before the offload hooks are installed.

# Depth ControlNet for SDXL, loaded from the fp16 safetensors weights.
ControlNet_Model = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0",
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)

# VAE checkpoint passed to the pipeline as its `vae` component, in float16.
AutoencoderKL_Model = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

# Stable Diffusion XL base pipeline wired to the ControlNet and VAE above.
SDXL_ControlNet_Pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=ControlNet_Model,
    vae=AutoencoderKL_Model,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)

# Let diffusers manage device placement to reduce GPU memory usage.
SDXL_ControlNet_Pipeline.enable_model_cpu_offload()

# StyleAligned configuration: only attention is shared, with AdaIN applied to
# queries and keys but not to values; norm layers are not shared.
sa_args = StyleAlignedArgs(
    share_group_norm=False,
    share_layer_norm=False,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)

# Wrap the pipeline in a Handler and register the StyleAligned arguments on it.
handler = Handler(SDXL_ControlNet_Pipeline)
handler.register(sa_args)

#### 7.1: CONTROL-NET WITH SIMPLE IMAGE & STYLE-ALIGNMENT


In [None]:
# Read the sun picture from disk, then bring it to the 1024x1024 working size.
control_image = load_image("./imgs/sun.png")
control_image = control_image.resize((1024, 1024))

# Preview the image that will be fed to ControlNet.
mediapy.show_image(control_image)

In [None]:
# Shared style descriptor plus the two prompts built from it:
# the first prompt produces the reference image, the second the targets.
reference_style_controlnet = "flat design style"
reference_prompt = f"a poster in {reference_style_controlnet}"
target_prompt = f"the sun in {reference_style_controlnet}"

# Weight given to the ControlNet conditioning.
controlnet_conditioning_scale = 0.8

# How many target images to produce (lower this if VRAM is limited).
num_images_per_prompt = 3

# Starting noise: one latent for the reference plus one per target image,
# cast to the UNet's dtype. The target rows are then re-drawn.
latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128)
latents = latents.to(SDXL_ControlNet_Pipeline.unet.dtype)
latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(latents.dtype)

# Run the ControlNet pipeline on both prompts with the control image.
images_generated = SDXL_ControlNet_Model(
    SDXL_ControlNet_Pipeline,
    [reference_prompt, target_prompt],
    image=control_image,
    num_inference_steps=50,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    num_images_per_prompt=num_images_per_prompt,
    latents=latents,
)

# Show the reference first, then the control image, then every target result.
mediapy.show_images(
    [images_generated[0], control_image, *images_generated[1:]],
    titles=[
        "reference",
        "depth",
        *(f'result {i}' for i in range(1, len(images_generated))),
    ],
)

#### 7.2: CONTROL-NET WITH DEPTH MAP & STYLE-ALIGNMENT


In [None]:
# DPT depth-estimation network, placed on the GPU once loaded.
DPT_Estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
DPT_Estimator = DPT_Estimator.to("cuda")

# Image processor that belongs to the same DPT checkpoint.
DPT_Feature_Processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")

# Source picture from which the control signal is derived.
control_image = load_image("./imgs/train.png")

# Turn the picture into a depth map with the processor/estimator pair.
control_depth_image = get_depth_map(
    control_image,
    DPT_Feature_Processor,
    DPT_Estimator,
)

# Preview the resulting depth map.
mediapy.show_image(control_depth_image)

In [None]:
# Shared style descriptor plus the two prompts built from it:
# the first prompt produces the reference image, the second the targets.
reference_style_controlnet = "flat design style"
reference_prompt = f"a poster in {reference_style_controlnet}"
target_prompt = f"a train in {reference_style_controlnet}"

# Weight given to the ControlNet conditioning.
controlnet_conditioning_scale = 0.8

# How many target images to produce (lower this if VRAM is limited).
num_images_per_prompt = 3

# Starting noise: one latent for the reference plus one per target image,
# cast to the UNet's dtype. The target rows are then re-drawn.
latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128)
latents = latents.to(SDXL_ControlNet_Pipeline.unet.dtype)
latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(latents.dtype)

# Run the ControlNet pipeline on both prompts, conditioned on the depth map.
images = SDXL_ControlNet_Model(
    SDXL_ControlNet_Pipeline,
    [reference_prompt, target_prompt],
    image=control_depth_image,
    num_inference_steps=50,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    num_images_per_prompt=num_images_per_prompt,
    latents=latents,
)

# Show the reference first, then the depth map, then every target result.
mediapy.show_images(
    [images[0], control_depth_image, *images[1:]],
    titles=[
        "reference",
        "depth",
        *(f'result {i}' for i in range(1, len(images))),
    ],
)

#### 7.4: CONTROL-NET WITH EDGE MAP (CANNY DETECTOR) & STYLE-ALIGNMENT

In [None]:
# DPT network and its image processor; both are handed to `get_edge_map` below.
DPT_Estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
DPT_Estimator = DPT_Estimator.to("cuda")
DPT_Feature_Processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")

# Source picture from which the control signal is derived.
control_image = load_image("./imgs/train.png")

# Build the edge map for the picture.
control_edge_image = get_edge_map(
    control_image,
    DPT_Feature_Processor,
    DPT_Estimator,
)

# Preview the resulting edge map.
mediapy.show_image(control_edge_image)

In [None]:
# Shared style descriptor plus the two prompts built from it:
# the first prompt produces the reference image, the second the targets.
reference_style_controlnet = "flat design style"
reference_prompt = f"a poster in {reference_style_controlnet}"
target_prompt = f"a train in {reference_style_controlnet}"

# Weight given to the ControlNet conditioning.
controlnet_conditioning_scale = 0.8

# How many target images to produce (lower this if VRAM is limited).
num_images_per_prompt = 3

# Starting noise: one latent for the reference plus one per target image,
# cast to the UNet's dtype. The target rows are then re-drawn.
latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128)
latents = latents.to(SDXL_ControlNet_Pipeline.unet.dtype)
latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(latents.dtype)

# Run the ControlNet pipeline on both prompts, conditioned on the edge map.
images = SDXL_ControlNet_Model(
    SDXL_ControlNet_Pipeline,
    [reference_prompt, target_prompt],
    image=control_edge_image,
    num_inference_steps=50,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    num_images_per_prompt=num_images_per_prompt,
    latents=latents,
)

# Show the reference first, then the edge map, then every target result.
mediapy.show_images(
    [images[0], control_edge_image, *images[1:]],
    titles=[
        "reference",
        "edge",
        *(f'result {i}' for i in range(1, len(images))),
    ],
)