# Interactive demonstration of sampling-function effects
This notebook uses interactive plots with adjustable ipywidgets sliders to show how each sampling function chooses where to sample the noise.

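The base image encodes position as color: red increases with distance from the top-left corner, green increases down the rows and blue across the columns, giving a smooth, noticeable gradient.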

In the distorted image, each pixel's color therefore shows where in the base image that pixel was sampled.

In [152]:
# Import Required Libraries
import ipywidgets as widgets
import numpy as np
from IPython.display import display

%matplotlib inline
import io
import math
from pathlib import Path

from PIL import Image

# Import all the distortion functions
from mettagrid.map.scenes.simsam_functions import (
    arbitrary_tilted_lattice,  # noqa: F401
    arbitrary_tilted_napkin,  # noqa: F401
    cross_curse,  # noqa: F401
    radial_symmetry,  # noqa: F401
    spiral,  # noqa: F401
    squeezed_noise,  # noqa: F401
    the_sphere,  # noqa: F401
    xy_noise,  # noqa: F401
)


## Create Base Image and Apply Distortion
Generate the base image (only if it is not already saved on disk), then define the function that applies a distortion function with adjustable parameters.

In [153]:
def create_base_image(width=1000, height=1000):
    """Create an RGB "position map" image whose colors encode pixel coordinates.

    For the pixel at (row, col):
      * red   ~ Euclidean distance sqrt(row**2 + col**2) from the top-left pixel,
      * green ~ row index,
      * blue  ~ column index.
    All three channels are normalized together (one shared min/max, i.e. divided by
    the distance to the far corner) and floored to 0..255, so red reaches 255 while
    green and blue generally stay below it.

    Args:
        width: number of columns of the returned image.
        height: number of rows of the returned image.

    Returns:
        A ``PIL.Image.Image`` in RGB mode with ``.size == (width, height)``.
    """
    # Shape (height, width): rows first, so PIL reports size (width, height).
    rows, cols = np.indices((height, width), dtype=np.float64)
    distance = np.sqrt(rows**2 + cols**2)
    base = np.stack([distance, rows, cols], axis=-1)

    lo, hi = base.min(), base.max()
    if hi > lo:
        base = (base - lo) / (hi - lo)
    else:
        # Only a 1x1 image has no spread (its single pixel is all zeros); avoid 0/0 -> NaN.
        base = np.zeros_like(base)
    # A uint8 (H, W, 3) array is interpreted as RGB by PIL.
    return Image.fromarray(np.floor(base * 255).astype(np.uint8))


# Create the base image and save it to disk, but only when it is not already there,
# so re-running this cell reuses the existing file instead of rebuilding it.
img_path = "./Base_image.png"
if not Path(img_path).is_file():
    base_img = create_base_image()
    # Save to the same path that is checked above and read later via img_path.
    base_img.save(img_path)

In [154]:
def _sample_indices(distortion_func, params, width, height, max_index):
    """Map every cell of a ``width`` x ``height`` grid to integer sample indices.

    ``distortion_func(x, y, width, height, **params)`` is evaluated once per cell
    (x = column, y = row) and must return two numbers. All returned values are
    shifted so the smallest becomes 0, scaled by one shared factor so the largest
    becomes ``max_index`` (keeping both axes in proportion), then floored.

    Returns:
        int32 array of shape (height, width, 2).
    """
    xs, ys = np.meshgrid(np.arange(width), np.arange(height))
    # Called per point on scalar coordinates (a Python loop): the distortion
    # functions are not assumed to accept arrays.
    samples = np.array(
        [distortion_func(x, y, width, height, **params) for x, y in zip(xs.ravel(), ys.ravel())]
    ).reshape(height, width, 2)
    samples = samples - samples.min()
    span = samples.max()
    if span > 0:
        samples = np.floor(samples * max_index / span)
    else:
        # Constant output: 0/0 would yield NaN, whose cast to int is undefined.
        samples = np.zeros_like(samples)
    return samples.astype(np.int32)


def apply_distortion(distortion_func, params, width, height, scale=5):
    """Sample the base image through ``distortion_func`` and show it next to the original.

    Each cell of a ``width`` x ``height`` map takes the color of the base-image pixel
    at the (rescaled) coordinates returned by ``distortion_func``. Because the base
    image encodes position as color, the result shows where each cell sampled.

    Args:
        distortion_func: callable ``(x, y, width, height, **params)`` returning two numbers.
        params: keyword arguments forwarded to ``distortion_func``.
        width: number of map columns.
        height: number of map rows.
        scale: nearest-neighbour upscaling factor applied to the distorted image for display.

    Returns:
        An ``ipywidgets.HBox`` with the original image (300x300) and the distorted image.
    """
    with Image.open(img_path) as img:
        base_pixels = np.array(img.convert("RGB"))
    last_row, last_col = base_pixels.shape[0] - 1, base_pixels.shape[1] - 1

    # One shared scale for both axes; for a 1000x1000 base image this is 0..999.
    indices = _sample_indices(distortion_func, params, width, height, min(last_row, last_col))
    # Advanced indexing: channel 0 selects the base-image row, channel 1 the column.
    result = base_pixels[indices[:, :, 0].clip(0, last_row), indices[:, :, 1].clip(0, last_col)]

    final = Image.fromarray(result).resize((int(width * scale), int(height * scale)), Image.NEAREST)  # type: ignore
    buf_result = io.BytesIO()
    final.save(buf_result, format="PNG")

    # The original never changes between calls, so render its fixed 300x300 thumbnail once.
    if not hasattr(apply_distortion, "cached_original"):
        original_resized = Image.fromarray(base_pixels).resize((300, 300), Image.NEAREST)  # type: ignore
        buf_original = io.BytesIO()
        original_resized.save(buf_original, format="PNG")
        apply_distortion.cached_original = buf_original.getvalue()

    def labelled_image(label, png_bytes):
        # A caption stacked above a PNG image widget.
        return widgets.VBox(
            [widgets.Label(label), widgets.Image(value=png_bytes, format="png")],
            layout=widgets.Layout(height="auto", align_items="flex-start"),
        )

    return widgets.HBox(
        [
            labelled_image("Original:", apply_distortion.cached_original),
            labelled_image("Distorted:", buf_result.getvalue()),
        ],
        layout=widgets.Layout(align_items="flex-start"),
    )


In [158]:
# Slider controls. In each function-specific dict the keys are the keyword arguments
# passed to the distortion function, and the dict order is the on-screen order.

# Size of the sampled map, in cells.
dimensions_widgets = {
    key: widgets.IntSlider(value=60, min=20, max=200, description=label)
    for key, label in (("width", "Width:"), ("height", "Height:"))
}


def _zoom_and_angle_sliders():
    """Return new X/Y zoom sliders plus a fine-grained (0.005 step) angle slider."""
    return {
        "x_zoom": widgets.FloatSlider(value=1.5, min=0.1, max=5.0, step=0.1, description="X Zoom:"),
        "y_zoom": widgets.FloatSlider(value=1.5, min=0.1, max=5.0, step=0.1, description="Y Zoom:"),
        "angle_theta": widgets.FloatSlider(
            value=0, min=0, max=1, step=0.005, description="Angle:", readout_format=".3f"
        ),
    }


def _center_offset_sliders():
    """Return new x/y off-center sliders in [-1, 1] with a 0.05 step."""
    return {
        "xc": widgets.FloatSlider(value=0.0, min=-1.0, max=1.0, step=0.05, description="x off-center:"),
        "yc": widgets.FloatSlider(value=0.0, min=-1.0, max=1.0, step=0.05, description="y off-center:"),
    }


# The helpers are called once per dict, so no two dicts share a widget instance.
napkin_lattice_widgets = {
    **_zoom_and_angle_sliders(),
    "line1_wavelength": widgets.IntSlider(value=5, min=1, max=20, description="Line1 Wave:"),
    "line2_wavelength": widgets.IntSlider(value=5, min=1, max=20, description="Line2 Wave:"),
    "line1_thickness": widgets.IntSlider(value=1, min=1, max=10, description="Line1 Thick:"),
    "line2_thickness": widgets.IntSlider(value=1, min=1, max=10, description="Line2 Thick:"),
}

# The spiral uses its own slider ranges/steps rather than the shared helpers.
spiral_widgets = dict(
    zoom=widgets.FloatSlider(value=0.10, min=0.01, max=2.00, step=0.01, description="Zoom:"),
    squeeze=widgets.FloatSlider(value=1.5, min=1.0, max=5.0, step=0.1, description="Squeeze:"),
    angle_theta=widgets.FloatSlider(value=0, min=0, max=1, step=0.01, description="Angle:"),
    P=widgets.FloatSlider(value=2, min=-20, max=20, description="P param, spiral thickness:"),
    xc=widgets.FloatSlider(value=0, min=-1, max=1, description="x off-center:"),
    yc=widgets.FloatSlider(value=0, min=-1, max=1, description="y off-center:"),
)

the_sphere_widgets = {
    **_zoom_and_angle_sliders(),
    "x_pow": widgets.IntSlider(value=2, min=1, max=6, description="Power of x:"),
    "y_pow": widgets.IntSlider(value=2, min=1, max=6, description="Power of y:"),
    **_center_offset_sliders(),
    "P": widgets.FloatSlider(value=2, min=-10, max=10.0, step=0.1, description="P param:"),
    **{
        name: widgets.FloatSlider(value=0.0, min=-1.0, max=1.0, step=0.05, description=f"{name}:")
        for name in ("ax", "ay", "bx", "by")
    },
}

cross_curse_widgets = {
    **_zoom_and_angle_sliders(),
    # NOTE(review): x_pow allows 0 while y_pow (and the_sphere's x_pow) start at 1 — confirm intended.
    "x_pow": widgets.IntSlider(value=2, min=0, max=6, description="Power of x:"),
    "y_pow": widgets.IntSlider(value=2, min=1, max=6, description="Power of y:"),
    **_center_offset_sliders(),
}

radial_symmetry_widgets = {
    **_zoom_and_angle_sliders(),
    # NOTE(review): min=0 permits symmetry=0 — confirm radial_symmetry accepts it.
    "symmetry": widgets.IntSlider(value=2, min=0, max=8, description="Symmetry:"),
    **_center_offset_sliders(),
}

# Pick the distortion to explore; the widget dict and the function must match
# (e.g. spiral_widgets with spiral, the_sphere_widgets with the_sphere).
current_widgets = radial_symmetry_widgets
current_function = radial_symmetry

# Output area that is redrawn whenever a slider changes.
output = widgets.Output()


def on_value_change(change=None):
    """Redraw ``output`` with the current distortion function and slider values.

    Registered as the ``observe`` callback for every slider. ``change`` (the change
    notification) is ignored because all slider values are re-read each time, which
    also allows calling this directly, with or without an argument, for the first draw.
    """
    with output:
        params = {name: widget.value for name, widget in current_widgets.items()}
        dims = {name: widget.value for name, widget in dimensions_widgets.items()}
        # wait=True keeps the previous image until the new one is displayed,
        # avoiding a flicker/layout jump on every slider tick.
        output.clear_output(wait=True)
        display(apply_distortion(current_function, params, dims["width"], dims["height"]))


# Re-render whenever any slider (function-specific or map size) changes.
for widget in [*current_widgets.values(), *dimensions_widgets.values()]:
    widget.observe(on_value_change, names="value")

# Show the size sliders, then the function-specific sliders, followed by the output area.
widgets_vbox = widgets.VBox(list(dimensions_widgets.values()) + list(current_widgets.values()))
display(widgets_vbox, output)
# Draw the initial image now; otherwise the output stays empty until a slider is moved.
on_value_change(None)

VBox(children=(IntSlider(value=60, description='Width:', max=200, min=20), IntSlider(value=60, description='He…

Output()