In [None]:
from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel
from rembg import remove
from PIL import Image, ImageFilter
import torch
import torch.nn as nn
from ip_adapter_instantstyle import IPAdapterXL
from ip_adapter_instantstyle.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images
from PIL import Image, ImageChops
import numpy as np
import glob
import os


"""Import DPT for Depth Model"""
import DPT.util.io

from torchvision.transforms import Compose

from DPT.dpt.models import DPTDepthModel
from DPT.dpt.midas_net import MidasNet_large
from DPT.dpt.transforms import Resize, NormalizeImage, PrepareForNet

from parametric_control_mlp import control_mlp

In [None]:
def _load_control_mlp(weights_path):
    """Build a ``control_mlp(1024)``, load its weights from ``weights_path``,
    and return it moved to CUDA in float16."""
    net = control_mlp(1024)
    state_dict = torch.load(weights_path)
    net.load_state_dict(state_dict)
    return net.to("cuda", dtype=torch.float16)


# Metallic MLP (inference mode)
mlp = _load_control_mlp('model_weights/metallic.pt')
mlp.eval()

# Roughness MLP (inference mode)
mlp2 = _load_control_mlp('model_weights/roughness.pt')
mlp2.eval()


In [None]:
# --- Get MARBLE model ready: checkpoint locations and target device ---
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
device = "cuda"

# --- Load IP-Adapter + Instant Style + editing MLPs ---
# cur_block is turned into the UNet module name "up_blocks.0.attentions.1"
# below and handed to IPAdapterXL as its only target block.
cur_block = ('up', 0, 1)
torch.cuda.empty_cache()

# Depth ControlNet, loaded from the fp16 safetensors variant.
controlnet = ControlNetModel.from_pretrained(
    controlnet_path,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)
controlnet = controlnet.to(device)

# SDXL inpainting pipeline conditioned by the ControlNet above.
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    use_safetensors=True,
    torch_dtype=torch.float16,
    add_watermarker=False,
)
pipe = pipe.to(device)

# Install the cross-attention hook on the UNet before wrapping the pipeline.
pipe.unet = register_cross_attention_hook(pipe.unet)

block_name = (
    cur_block[0]
    + "_blocks." + str(cur_block[1])
    + ".attentions." + str(cur_block[2])
)
print("Testing block {}".format(block_name))

# mlp / mlp2 are the metallic and roughness MLPs loaded in the previous cell.
ip_model = IPAdapterXL(
    pipe,
    image_encoder_path,
    ip_ckpt,
    device,
    target_blocks=[block_name],
    edit_mlp=mlp,
    edit_mlp2=mlp2,
)


In [None]:
# --- Get Depth Model Ready ---
import cv2

model_path = "DPT/dpt_weights/dpt_hybrid-midas-501f0c75.pt"

# Network input resolution passed to the Resize step (width and height are equal).
net_w = 384
net_h = net_w

model = DPTDepthModel(
    path=model_path,
    backbone="vitb_rn50_384",
    non_negative=True,
    enable_attention_hooks=False,
)

normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

# Resize step configured with keep_aspect_ratio=True, sizes that are
# multiples of 32, the "minimal" resize method and bicubic interpolation.
resize_step = Resize(
    net_w,
    net_h,
    resize_target=None,
    keep_aspect_ratio=True,
    ensure_multiple_of=32,
    resize_method="minimal",
    image_interpolation_method=cv2.INTER_CUBIC,
)

# Preprocessing applied to the {"image": ...} dict: resize -> normalize -> PrepareForNet.
transform = Compose([resize_step, normalization, PrepareForNet()])

model.eval()

In [None]:
from PIL import ImageEnhance

# Edit strengths for metallic. More negative = more metallic, best results between range -20 to 20
edit_strengths1 = [-20, 0, 20]

# Edit strengths for roughness. More positive = more roughness, best results between range -1 to 1
edit_strengths2 = [-1, 0, 1]

# ---------------------------------------------------------------------------
# Inputs shared by every (metallic, roughness) combination.
# None of this depends on the edit strengths, so it is computed once here
# instead of once per loop iteration.
# ---------------------------------------------------------------------------
target_image_path = 'input_images/context_image/toy_car.png'
target_image = Image.open(target_image_path).convert('RGB')

# --- Compute depth map from the input image ---
img = np.array(target_image)
img_input = transform({"image": img})["image"]

with torch.no_grad():
    sample = torch.from_numpy(img_input).unsqueeze(0)

    prediction = model.forward(sample)
    # Resample the prediction back to the input image's (height, width).
    prediction = (
        torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        )
        .squeeze()
        .cpu()
        .numpy()
    )

# Stretch the prediction to the 16-bit range [0, 65535], then divide by 256
# to obtain an 8-bit image.
depth_min = prediction.min()
depth_max = prediction.max()
bits = 2
max_val = (2 ** (8 * bits)) - 1

if depth_max - depth_min > np.finfo("float").eps:
    out = max_val * (prediction - depth_min) / (depth_max - depth_min)
else:
    # Constant depth: fall back to an all-zero map. (The original referenced an
    # undefined name `depth` here, which raised NameError on this branch.)
    out = np.zeros(prediction.shape, dtype=prediction.dtype)

out = (out / 256).astype('uint8')
depth_map = Image.fromarray(out).resize((1024, 1024))

# --- Preprocessing data for MARBLE ---
rm_bg = remove(target_image)
# Binarize the background-removed image: channel values < 1 become 0, everything else 255.
target_mask = rm_bg.convert("RGB").point(lambda x: 0 if x < 1 else 255).convert('L').convert('RGB')

# The intermediates below are not consumed by the generation call in this cell;
# they are kept as notebook globals in case other cells rely on them.
# PIL's .size is (width, height) while the array shape must be (height, width, 3).
noise = np.random.randint(0, 256, target_image.size[::-1] + (3,), dtype=np.uint8)
noise_image = Image.fromarray(noise)
mask_target_img = ImageChops.lighter(target_image, target_mask)
invert_target_mask = ImageChops.invert(target_mask)

gray_target_image = target_image.convert('L').convert('RGB')
gray_target_image = ImageEnhance.Brightness(gray_target_image)

# Adjust brightness
# The factor 1.0 means original brightness, greater than 1.0 makes the image brighter. Adjust this if the image is too dim
factor = 1.0  # Try adjusting this to get the desired brightness

gray_target_image = gray_target_image.enhance(factor)
grayscale_img = ImageChops.darker(gray_target_image, target_mask)
img_black_mask = ImageChops.darker(target_image, invert_target_mask)
grayscale_init_img = ImageChops.lighter(img_black_mask, grayscale_img)

# The texture to be applied onto car
ip_image = Image.open('input_images/texture/metal_bowl.png')

# The original colour image (not grayscale_init_img) is used as the init image.
init_img = target_image.resize((1024, 1024))
mask = target_mask.resize((1024, 1024))

cur_seed = 42

# ---------------------------------------------------------------------------
# Sweep: metallic strength is the outer loop, roughness the inner loop, so
# all_images is ordered with roughness varying fastest.
# ---------------------------------------------------------------------------
all_images = []
for edit_strength1 in edit_strengths1:
    for edit_strength2 in edit_strengths2:
        images = ip_model.generate_edit_mlp_lr_multi(
            pil_image=ip_image,
            image=init_img,
            control_image=depth_map,
            mask_image=mask,
            controlnet_conditioning_scale=1.,
            num_samples=1,
            num_inference_steps=30,
            seed=cur_seed,
            edit_strength=edit_strength1,
            edit_strength2=edit_strength2,
            strength=1,
        )
        all_images.append(images[0].resize((512, 512)))



In [None]:
import matplotlib.pyplot as plt
def show_image_grid(images, x, y, figsize=(10, 10)):
    """
    Display a list of images in an x by y grid.

    Images are placed row by row: ``images[i]`` goes to row ``i // x``,
    column ``i % x``. Cells without an image are left blank, and images
    beyond ``x * y`` are not shown.

    Args:
        images (list): Images accepted by ``Axes.imshow`` (e.g. numpy arrays
            or PIL images).
        x (int): Number of columns.
        y (int): Number of rows.
        figsize (tuple): Size of the figure.
    """
    # squeeze=False makes subplots always return a 2-D array of Axes; without
    # it a 1x1 grid yields a bare Axes object and .flatten() fails.
    fig, axes = plt.subplots(y, x, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    for i, ax in enumerate(axes):
        if i < len(images):
            ax.imshow(images[i])
        # Hide ticks and frame on every cell, used or unused.
        ax.axis('off')

    plt.tight_layout()
    plt.show()
# all_images was filled with edit_strengths1 (metallic) as the outer loop and
# edit_strengths2 (roughness) as the inner loop, and the grid is filled row by
# row. So columns = len(edit_strengths2) and rows = len(edit_strengths1):
# each row is one metallic strength, each column one roughness strength.
show_image_grid(all_images, len(edit_strengths2), len(edit_strengths1))