In [None]:
import torch
from PIL import Image
import numpy as np

from marble import (
    get_session,
    run_parametric_control,
    setup_control_mlps,
    setup_pipeline,
)

In [None]:
# One-time setup from the `marble` package; the objects created here are reused by the
# editing cell below, so this cell does not need to be re-run between edits.
# control_mlps: looked up by attribute name ("metallic", "roughness") in the editing cell.
control_mlps = setup_control_mlps()
# ip_adapter: passed as the first argument to run_parametric_control in the editing cell.
ip_adapter = setup_pipeline()
# NOTE(review): rembg_session is not used by any cell in this notebook; presumably a
# background-removal session kept for other workflows — confirm, or drop this call.
rembg_session = get_session()

# Create the metallic x roughness editing grid (3x3 with the default strengths)

In [None]:
# Edit strengths for metallic. More negative = more metallic; best results in the range -20 to 20.
edit_strengths1 = [-20, 0, 20]

# Edit strengths for roughness. More positive = more rough; best results in the range -1 to 1.
edit_strengths2 = [-1, 0, 1]

# Size (width, height) of each tile shown in the output grid.
GRID_TILE_SIZE = (512, 512)

# The inputs are identical for every grid cell, so open them once instead of once per
# (metallic, roughness) combination.
# NOTE(review): this assumes run_parametric_control does not modify its input images
# in place — confirm if results ever differ from re-opening per iteration.
target_image_path = "input_images/context_image/toy_car.png"
target_image = Image.open(target_image_path)

texture_image_path = "input_images/texture/metal_bowl.png"
texture_image = Image.open(texture_image_path)

# all_images is filled row-major: the outer loop (metallic) selects the row and the
# inner loop (roughness) selects the column, i.e. len(edit_strengths2) images per row.
all_images = []
for edit_strength1 in edit_strengths1:
    for edit_strength2 in edit_strengths2:
        result = run_parametric_control(
            ip_adapter,
            target_image,
            {
                control_mlps["metallic"]: edit_strength1,
                control_mlps["roughness"]: edit_strength2,
            },
            texture_image=texture_image,
        )

        all_images.append(result.resize(GRID_TILE_SIZE))

In [None]:
import matplotlib.pyplot as plt


def show_image_grid(images, x, y, figsize=(10, 10)):
    """
    Display a list of images in an x by y grid.

    Images fill the grid row by row (row-major): ``images[0]`` is the top-left
    cell, ``images[x - 1]`` ends the first row, ``images[x]`` starts the second,
    and so on. Grid cells without a matching image are left blank; images beyond
    the first ``x * y`` are not shown.

    Args:
        images (list): Images accepted by ``Axes.imshow`` (e.g. numpy arrays or
            PIL images).
        x (int): Number of columns.
        y (int): Number of rows.
        figsize (tuple): Size of the figure.
    """
    # squeeze=False always returns a 2-D array of Axes. Without it, a 1x1 grid
    # yields a bare Axes object, which has no .flat/.flatten().
    fig, axes = plt.subplots(y, x, figsize=figsize, squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < len(images):
            ax.imshow(images[i])
        # Hide ticks and frames on every cell; cells with no image end up blank.
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# all_images is row-major with one row per metallic strength (outer loop) and one
# column per roughness strength (inner loop): columns = len(edit_strengths2),
# rows = len(edit_strengths1). Passing them the other way round misaligns rows
# whenever the two strength lists have different lengths.
show_image_grid(all_images, len(edit_strengths2), len(edit_strengths1))