# Crystal Growth Kernel: Interactive Demo and Benchmarks

This notebook is a supplementary demonstration for the paper by **A.V. Redkov** and **D.I. Trotsenko** entitled **_"When Epitaxy Meets AI: Toward All-In-Silico Technology Design"_**. It showcases the capabilities of the AI-based approach for the creation of a crystal growth (or epitaxial) kernel. The kernel is a digital library consisting of several AI models for ultrafast and accurate epitaxial growth prediction with atomistic precision.

The kernel is trained on a large dataset of in-silico epitaxial growth experiments (299,000 experiments, 3.2 TB of data) conducted under various growth conditions using a GPU server and an atomistic growth model (Monte Carlo or other). For each experiment, the final outcome properties, such as growth rate, RMS roughness, average roughness, kurtosis, peak-to-valley, skewness, and growth stability, were analyzed. The surface morphology (simulated AFM images) was also saved, forming a dataset for analysis and AI model training.

<img src="./dataset_preparation.png" alt="Crystal growth kernel preparation" width="1200"/>
<i> Figure 1. The scheme of crystal growth kernel preparation. From the atomistic model to the huge dataset and training of AI models. </i>

The created kernel allows one to instantly predict all these properties, generate surface morphology typical for any combination of growth conditions, determine all possible types of growth within the model, and build multidimensional structure and stability zone diagrams, which show the regions of the growth parameter space where each type of growth occurs.

<img src="./cgkernel.png" alt="Scheme of CG-kernel capabilities" width="1200"/>

<i> Figure 2. The scheme of crystal growth kernel capabilities: from predicting single properties to building multidimensional structure zone diagrams and instant generation of surface morphology. </i>

The model system used to generate the data includes 7 key parameters (see the paper) that govern different aspects of epitaxial growth and affect the morphology. Because they are determined by the growth conditions and crystal properties, they are referred to as 'Growth conditions':
- Surface concentration of adatoms (Cs).
- Desorption probability of an atom (Pd).
- Terrace-to-terrace hopping probability (Pes), related to the Schwoebel barrier.
- Adatom drift perpendicular to steps (Pbias).
- Step transparency (T).
- Number of threading dislocations (M), which form voids on the surface.
- Categorical parameter (Nnucl) representing the surface processes artificially allowed in the cellular automaton model, i.e., 1D and 2D nucleation.
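As a sketch only, the seven growth conditions above can be bundled into a typed record before being passed to the kernel. The dictionary keys (`Cs`, `Pes`, `Pbias`, `T`, `Pd`, `M`, `Nnucl`) are the ones used in the example cells of this notebook; the `GrowthConditions` class itself is hypothetical and not part of CGKernel.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class GrowthConditions:
    """Hypothetical container for the seven growth conditions (not part of CGKernel)."""
    Cs: float     # surface concentration of adatoms
    Pes: float    # terrace-to-terrace hopping probability (Schwoebel barrier)
    Pbias: float  # adatom drift perpendicular to steps
    T: float      # step transparency
    Pd: float     # desorption probability of an atom
    M: int        # number of threading dislocations
    Nnucl: int    # categorical code for the allowed nucleation processes

    def to_dict(self) -> dict:
        # Plain dict with the same keys (and order) as the notebook's example inputs
        return asdict(self)


conditions = GrowthConditions(Cs=0.03, Pes=0.2, Pbias=0.03, T=0.5, Pd=1e-4, M=0, Nnucl=2)
sample = conditions.to_dict()
print(sample)
```

The resulting dict has the same shape as `_sample` in the prediction cells, so it could be passed wherever the notebook passes `_sample`, e.g. `kernel.predict_property(key, sample)`.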

The notebook progresses from simple calls to interactive visualization and benchmarking. Use the controls in later cells to explore maps, compare different regressors with dataset points, and generate morphology surfaces interactively.

<i>*Note that this notebook presents the CGKernel trained specifically using the simple atomistic base model described in our paper. Other kernels may be created for different base atomistic models, such as MBE, MOCVD, and others, involving any number of input parameters and predicted output properties. The framework for training the kernel is available upon request.</i>

## Please use a GPU for tests if available!
In Google Colab, a GPU can be enabled via the menu 'Edit' -> 'Notebook settings' -> 'GPU T4'.

## Google Colab Setup (run the next three cells only if you are working in Google Colab)

In [None]:
# Google Colab Setup Cell
# Run this cell first in Google Colab to install dependencies and clone the repository

# Install required packages (if pip reports warnings about some packages, you may need to restart the runtime)
!pip install -r https://raw.githubusercontent.com/avredkov/CG-Kernel/main/requirements_colab.txt

In [None]:
# Clone the repository 
!git clone https://github.com/avredkov/CG-Kernel.git

In [None]:
%cd '/content/CG-Kernel'
# After %cd, the repository folder is the current directory; add it to the import path
PATH_TO_KERNEL = '/content/CG-Kernel'
import sys
sys.path.append(PATH_TO_KERNEL)

## Jupyter Setup (run this cell only if you are running Jupyter locally)

In [None]:
PATH_TO_KERNEL = "Path to the folder with cgkernel.py"  # replace with the actual folder that contains cgkernel.py
import sys
sys.path.append(PATH_TO_KERNEL)

## Plotly library configuration for working in both Colab and Jupyter Notebook

In [None]:
# Plotly Configuration for Google Colab and JupyterLab Compatibility
# This cell configures Plotly to work properly in both Google Colab and JupyterLab

import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Import Plotly helper functions from demo_utils
from demo_utils import configure_plotly, show_plot, detect_environment, clear_output_safely

# Configure Plotly for the current environment (detects automatically)
configure_plotly()

print("Plotly is ready for interactive visualizations!")

## 1. Initialization

Setup: imports, logging, and kernel initialization.
- Reads `cgkernel_config.json` from the project root
- Initializes CGKernel (loads models and preprocessors)
- Loads config-driven display labels and dataset path
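Later cells read only a few keys from `cgkernel_config.json`: `numeric_columns`, `integer_columns`, `categorical_column` and `ranges_dataset_csv`. The sketch below uses a hypothetical config (the values are illustrative, not the shipped file) to show how the continuous feature list and the categorical column are derived, mirroring the `CONFIG.get(...)` pattern used in Sections 3, 6 and 7.

```python
import json

# Hypothetical config contents; the real cgkernel_config.json may hold more keys and other values
config_text = """
{
  "numeric_columns": ["Cs", "Pes", "Pbias", "T", "Pd"],
  "integer_columns": ["M"],
  "categorical_column": "Nnucl",
  "ranges_dataset_csv": "dataset.csv"
}
"""
config = json.loads(config_text)

# Sliders cover numeric and integer columns; a missing or null key falls back to an empty list
features_cont = (config.get("numeric_columns", []) or []) + (config.get("integer_columns", []) or [])
cat_col = config.get("categorical_column", "Nnucl")
int_cols = set(config.get("integer_columns", []) or [])

print(features_cont)
print(cat_col, int_cols)
```

The `or []` guard matters because `dict.get` returns `None` (not the default) when a key is present but set to `null` in the JSON.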

In [None]:
from __future__ import annotations

import logging
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import warnings
from sklearn.exceptions import InconsistentVersionWarning
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)

# Configure logging level (adjust as needed)
logging.basicConfig(level=logging.CRITICAL)

# Discover project root (folder containing cgkernel_config.json) without relying on local imports

def _discover_project_root(start: str | Path = ".", max_up: int = 6) -> Path:
    # For Google Colab, check if we're in the cloned directory
    if "/content/" in str(Path.cwd()):
        # We're in Colab, look for the cloned repo
        for path in Path("/content").iterdir():
            if path.is_dir() and (path / "cgkernel_config.json").exists():
                return path
    
    # Original logic for local development
    p = Path(start).resolve()
    for _ in range(max_up):
        if (p / "cgkernel_config.json").exists():
            return p
        if p.parent == p:
            break
        p = p.parent
    return Path(start).resolve()

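# Search upward from the current directory, so the root is found whether the notebook sits in the project root or in a subfolder
PROJECT_ROOT = _discover_project_root(".")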

# Ensure project root is importable before importing local modules
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Now import kernel and helpers
from cgkernel import CGKernel
from demo_utils import (
    load_json,
    get_axis_label,
    get_display_label,
    get_category_label,
    load_dataframe_from_config,
    format_array_or_str,
)

CONFIG = load_json(PROJECT_ROOT / "cgkernel_config.json")

# Initialize kernel using the project root
kernel = CGKernel(config_dir=str(PROJECT_ROOT))

# Report device and warn if CPU
import torch
_device = "cuda" if torch.cuda.is_available() else "cpu"
if _device == "cuda":
    try:
        _device_name = torch.cuda.get_device_name(torch.cuda.current_device())
    except Exception:
        _device_name = "Unknown CUDA device"
    print(f"\nCurrently running on device: CUDA ({_device_name})\n")
else:
    print("\nCurrently running on device: CPU")
    print("Warning: Running on CPU. Inference is significantly faster on GPU.")

# Load dataset for plotting ranges and overlay points in later cells
try:
    DF_ALL = load_dataframe_from_config(PROJECT_ROOT, CONFIG)
except Exception as e:
    DF_ALL = pd.DataFrame()
    print(f"Dataset not loaded (plots that need it will be disabled): {e}")

print("The kernel is ready. \n")
print("How to use: \n")
kernel.help()

## 2. Basic Property Predictions: Single and Batch
This cell demonstrates how to call the kernel's <b>predict_property()</b> function directly to predict various properties of the growing surface from a set of macroscopic growth conditions.

Examples include predictions of continuous properties: Growth rate (R), Kurtosis (Sku), Peak-to-valley (Sz), Average Roughness (Sa), RMS Roughness (Ssq), and Skewness (Ssk).

We demonstrate both single-sample and batch predictions.
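To compare several properties side by side, the batch outputs can be gathered into one table. This is a sketch that assumes, as in the next cell, that `predict_property(key, rows)` returns one value per input row. The helper name `predictions_table`, the stub predictor, and the property keys `"R"` and `"Sa"` are placeholders for illustration; the real keys come from `sorted(kernel.regressors.keys())`.

```python
from typing import Callable, Sequence

import numpy as np
import pandas as pd


def predictions_table(
    predict: Callable[[str, list], Sequence[float]],
    keys: Sequence[str],
    rows: list[dict],
) -> pd.DataFrame:
    """Return a DataFrame with the input rows plus one column per predicted property."""
    table = pd.DataFrame(rows)
    for key in keys:
        # Flatten in case the predictor returns a column vector
        values = np.asarray(predict(key, rows), dtype=float).reshape(-1)
        if values.size != len(rows):
            raise ValueError(f"{key}: expected {len(rows)} predictions, got {values.size}")
        table[key] = values
    return table


# Stand-in predictor for illustration only; in the notebook one would pass kernel.predict_property
def fake_predict(key: str, rows: list) -> np.ndarray:
    if key == "R":
        return np.array([r["Cs"] * (i + 1) for i, r in enumerate(rows)])
    return np.zeros(len(rows))


rows = [
    {"Cs": 0.03, "Pes": 0.20, "Pbias": 0.03, "T": 0.50, "Pd": 0.0001, "M": 0, "Nnucl": 2},
    {"Cs": 0.10, "Pes": 0.40, "Pbias": 0.00, "T": 0.50, "Pd": 0.0052, "M": 2, "Nnucl": 1},
]
df = predictions_table(fake_predict, ["R", "Sa"], rows)
print(df)
```

With the real kernel, the call would be `predictions_table(kernel.predict_property, sorted(kernel.regressors.keys()), _batch)`.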



In [None]:
from __future__ import annotations

import numpy as np
import pandas as pd


# Example inputs (edit values as needed) for a single set of growth conditions and for a batch (you may add as many rows as needed)

_sample = {
        "Cs": 0.03,
        "Pes": 0.2,
        "Pbias": 0.03,
        "T": 0.50,
        "Pd": 0.0001,
        "M": 0,
        "Nnucl": 2,
    }

_batch = [
        {"Cs": 0.03, "Pes": 0.20, "Pbias": 0.03, "T": 0.50, "Pd": 0.0001, "M": 0, "Nnucl": 2},
        {"Cs": 0.10, "Pes": 0.40, "Pbias": 0.00,  "T": 0.50, "Pd": 0.0052, "M": 2, "Nnucl": 1},
        {"Cs": 0.20, "Pes": 0.70, "Pbias": -0.10,  "T": 0.30, "Pd": 0.000023, "M": 4, "Nnucl": 0},
    ]

# Prediction of all the properties for a given set(s) of growth conditions
reg_keys = sorted(kernel.regressors.keys())


# Single-sample predictions
print("Single-sample predictions:\n")
for k in reg_keys:
    y = kernel.predict_property(k, _sample)[0]
    print(f"  {get_display_label(CONFIG, k, default=k)}: {y:.5g}")

# Batch predictions
print(f"\n Batch predictions (N={len(_batch)} rows):\n ")
for k in reg_keys:
    ys = kernel.predict_property(k, _batch)
    print(f"  {get_display_label(CONFIG, k, default=k)}: {format_array_or_str(ys, precision=4)}")
    


## 3. Interactive Prediction of Continuous Properties

This cell provides an interactive example of predictions. Choose the growth conditions with the sliders (and Nnucl with the dropdown) and see how the predicted properties change immediately.
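Besides the sliders, the same sensitivity can be explored non-interactively by varying one growth condition while keeping the others fixed, then sending all rows as one batch. The helper `sweep_rows` below is hypothetical; feeding its output to `kernel.predict_property(key, rows)` assumes the batch call shown in Section 2.

```python
import numpy as np


def sweep_rows(base: dict, name: str, values) -> list[dict]:
    """Copy `base` once per value, overriding only the parameter `name`."""
    if name not in base:
        raise KeyError(f"unknown growth condition: {name}")
    cast = type(base[name])  # keep integer parameters (e.g. M) as int, floats as float
    return [{**base, name: cast(v)} for v in values]


base = {"Cs": 0.03, "Pes": 0.2, "Pbias": 0.03, "T": 0.5, "Pd": 1e-4, "M": 0, "Nnucl": 2}
# Log-spaced adatom concentrations, matching the log X scale used for Cs in the maps below
cs_values = np.logspace(-3, -1, 5)
rows = sweep_rows(base, "Cs", cs_values)
print([round(r["Cs"], 5) for r in rows])
# In the notebook: preds = kernel.predict_property(some_key, rows)
```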

In [None]:
from __future__ import annotations

import numpy as np
import pandas as pd
import ipywidgets as widgets
from IPython.display import display
from demo_utils import percentile_range, clear_output_safely

# Fallbacks if helper label functions are not defined in the environment
try:
	get_axis_label
except NameError:
	def get_axis_label(config, name):  # noqa: N802
		return name

try:
	get_display_label
except NameError:
	def get_display_label(config, name, default=None):  # noqa: N802
		return default or name

try:
	get_category_label
except NameError:
	def get_category_label(config, cat_col, value):  # noqa: N802
		return f"{cat_col}={value}"

# Guard: need dataset for inputs
if DF_ALL.empty:
	display(widgets.HTML("<b>Dataset CSV not available.</b> Set 'ranges_dataset_csv' in cgkernel_config.json."))
else:
	# Feature choices from config
	FEATURES_CONT = (CONFIG.get("numeric_columns", []) or []) + (CONFIG.get("integer_columns", []) or [])
	CAT_COL = CONFIG.get("categorical_column", "Nnucl")
	REGRESSOR_KEYS = sorted(kernel.regressors.keys())

	# Regime dropdown
	unique_regs = sorted(
		pd.to_numeric(DF_ALL[CAT_COL], errors="coerce").dropna().astype(int).unique().tolist()
	)
	regime_dd = widgets.Dropdown(
		options=[(get_category_label(CONFIG, CAT_COL, int(v)), int(v)) for v in unique_regs],
		value=(unique_regs[2] if unique_regs else 0),
		description=get_axis_label(CONFIG, CAT_COL),
		style={"description_width": "auto"},
	)

	# Sliders for all continuous/integer features
	def make_feature_sliders() -> list[widgets.Widget]:
		sliders: list[widgets.Widget] = []
		int_cols = set(CONFIG.get("integer_columns", []) or [])
		for feat in FEATURES_CONT:
			vmin, vmax = percentile_range(DF_ALL[feat])
			if feat in int_cols:
				default_val = int(
					np.median(pd.to_numeric(DF_ALL[feat], errors="coerce").dropna().values)
				) if DF_ALL[feat].notna().any() else int(round((vmin + vmax) / 2.0))
				w = widgets.IntSlider(
					value=default_val,
					min=int(np.floor(vmin)),
					max=int(np.ceil(vmax)),
					step=1,
					description=get_axis_label(CONFIG, feat),
					continuous_update=True,
					layout=widgets.Layout(width="96%"),
					style={"description_width": "auto"},
				)
			else:
				step = max((vmax - vmin) / 200.0, 1e-6)
				default_val = float(
					np.median(pd.to_numeric(DF_ALL[feat], errors="coerce").dropna().values)
				) if DF_ALL[feat].notna().any() else float((vmin + vmax) / 2.0)
				w = widgets.FloatSlider(
					value=default_val,
					min=float(vmin),
					max=float(vmax),
					step=step,
					description=get_axis_label(CONFIG, feat),
					readout_format=".5g",
					continuous_update=True,
					layout=widgets.Layout(width="96%"),
					style={"description_width": "auto"},
				)
			# Attach original feature name for reverse lookup
			w._feature_name = feat  # type: ignore[attr-defined]
			sliders.append(w)
		return sliders

	sliders_box = widgets.VBox(make_feature_sliders())

	# Output area for predictions
	pred_out = widgets.Output()

	def collect_input_row() -> dict:
		row: dict = {}
		for w in sliders_box.children:
			feat_name = getattr(w, "_feature_name", w.description)
			val = w.value
			# Cast integers based on config
			if feat_name in (CONFIG.get("integer_columns", []) or []):
				row[feat_name] = int(val)
			else:
				row[feat_name] = float(val)
		# Always include selected regime
		row[CAT_COL] = int(regime_dd.value)
		return row

	def update(_=None):
		with pred_out:
			clear_output_safely(pred_out)
			if not REGRESSOR_KEYS:
				print("No regressors available.")
				return

			input_row = collect_input_row()

			# Predict all properties
			results = []
			for prop_key in REGRESSOR_KEYS:
				try:
					pred = kernel.predict_property(prop_key, [input_row])
					pred_val = float(pred[0]) if hasattr(pred, "__len__") else float(pred)
				except Exception as e:
					pred_val = np.nan
				results.append({
					"Property": get_display_label(CONFIG, prop_key, default=prop_key),
					"Key": prop_key,
					"Value": pred_val,
				})

			df_results = pd.DataFrame(results, columns=["Property",  "Value"])

			# Print current predictions
			print("\nPredicted values (all at once):")
			display(df_results)

	# Wire events
	regime_dd.observe(update, names="value")
	for w in sliders_box.children:
		w.observe(update, names="value")

	# Layout
	ui = widgets.VBox([
		widgets.HBox([regime_dd]),
		widgets.Label("Set input feature values:"),
		sliders_box,
	])

	display(ui)
	display(pred_out)
	update()

## 4. Overview of All Possible Growth Morphology Types Across the Dataset
This cell demonstrates the kernel's ability to display all types of morphology encountered across the dataset (function <b>show_morphology_classes()</b>). The morphologies were detected automatically using convolutional networks and clustering algorithms (see the paper). The images shown here were pre-saved during kernel creation. For on-the-fly GAN-generated surface morphologies based on exact growth conditions, proceed to the next cells. The color of the circle near the class number corresponds to the color of that class on the growth regime maps (Section 7).

In [None]:
kernel.show_morphology_classes()

## 5. Basic Prediction of morphology type and growth stability
This cell demonstrates how to call kernel functions directly to predict the morphology class (<b>predict_morphology_class()</b>) and estimate growth stability (<b>predict_stability()</b>).
For each predicted class, the cell also shows the pre-saved morphology of that class (for on-the-fly generation, see the following cells).

We demonstrate both single-sample and batch predictions.
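Both classifiers return a pair: predicted indices and class probabilities (`s_idx, s_proba` and `c_idx, c_proba` in the next cell). Assuming the probabilities form an array with one row per sample and one column per class, as the `np.max(proba, axis=1)` call in Section 7 implies, a per-sample confidence is the largest probability in each row. The helper `summarize` below is a hypothetical sketch of that bookkeeping, and the numbers are illustrative, not kernel output.

```python
import numpy as np


def summarize(idx, proba, labels=None) -> list[tuple]:
    """Pair each predicted index with an optional label and its confidence (max class probability)."""
    idx = np.asarray(idx).reshape(-1)
    proba = np.atleast_2d(np.asarray(proba, dtype=float))  # a single sample becomes shape (1, K)
    if proba.shape[0] != idx.size:
        raise ValueError("one probability row per prediction is expected")
    conf = proba.max(axis=1)
    return [
        (int(i), (labels or {}).get(int(i), str(int(i))), float(c))
        for i, c in zip(idx, conf)
    ]


# Illustrative numbers only, not kernel output
stab_idx = np.array([1, 0])
stab_proba = np.array([[0.1, 0.9], [0.7, 0.3]])
summary = summarize(stab_idx, stab_proba, labels={1: "Stable", 0: "Unstable"})
print(summary)
```

The label mapping `{1: "Stable", 0: "Unstable"}` follows the convention used in the cell below.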


In [None]:
import numpy as np
from pathlib import Path
import ipywidgets as widgets
from ipywidgets import Image as IPyImage
from IPython.display import display

_sample = {
        "Cs": 0.03,
        "Pes": 0.2,
        "Pbias": 0.03,
        "T": 0.50,
        "Pd": 0.0001,
        "M": 0,
        "Nnucl": 2,
    }

_batch = [
        {"Cs": 0.03, "Pes": 0.20, "Pbias": 0.03, "T": 0.50, "Pd": 0.0001, "M": 0, "Nnucl": 2},
        {"Cs": 0.001, "Pes": 0.20, "Pbias": 0.25,  "T": 1, "Pd": 0.0052, "M": 0, "Nnucl": 2},
        {"Cs": 0.20, "Pes": 0.70, "Pbias": -0.10,  "T": 0.30, "Pd": 0.000023, "M": 4, "Nnucl": 0},
    ]

print("Single-sample predictions:\n")

print (f'Growth conditions set: {_sample}')
# Predict growth stability for the given growth conditions
s_idx, s_proba = kernel.predict_stability(_sample)
s = int(s_idx[0])
s_label = "Stable" if s == 1 else "Unstable"
print(f"  Stability: {s} ({s_label})")

# Predict the morphology class for the given growth conditions and show the pre-saved example of that class
c_idx, c_proba = kernel.predict_morphology_class(_sample)
try:
    c_val = c_idx.flatten()[0] if isinstance(c_idx, np.ndarray) else c_idx
    c = int(str(c_val).strip())
    print(f"  Morphology class: {c}")

    p = (PROJECT_ROOT / "morphologies" / f"{c}.png")
    if p.exists():
        with open(p, 'rb') as f:
            img_data = f.read()
        display(IPyImage(value=img_data, width=150, height=150))
except Exception:
    pass

# Batch predictions
print(f"\n Batch predictions for {len(_batch)} rows:\n ")
print (f'Growth conditions set list: {_batch}')
s_idx, s_proba = kernel.predict_stability(_batch)
if isinstance(s_idx, np.ndarray):
    s_labels = ["Stable" if int(v) == 1 else "Unstable" for v in s_idx]
    print(f"  Stability: {format_array_or_str(s_idx.astype(int), separator=',')} ({', '.join(s_labels)})")
else:
    val = str(s_idx).strip()
    try:
        vi = int(val)
        s_label = "Stable" if vi == 1 else "Unstable"
        print(f"  Stability: {val} ({s_label})")
    except Exception:
        print(f"  Stability: {val}")


# Predict the morphology classes for the batch and show the pre-saved example of each distinct class
c_idx, c_proba = kernel.predict_morphology_class(_batch)
if isinstance(c_idx, np.ndarray):
    print(f"  Morphology classes: {format_array_or_str(c_idx.astype(int), separator=',')}")
    try:
        uniq = np.unique(c_idx.astype(int)).tolist()
        image_widgets = []  # List to hold image widgets
        for cid in uniq:
            p = (PROJECT_ROOT / "morphologies" / f"{int(cid)}.png")
            if p.exists():
                # Open the image and read it as bytes
                with open(p, 'rb') as f:
                    img_data = f.read()
                
                # Create the image widget using the image data (not the path)
                img_widget = IPyImage(value=img_data, width=150, height=150)
                image_widgets.append(img_widget)
        
        if image_widgets:
            print(f"\nPre-saved morphologies of the {len(image_widgets)} predicted classes:")
            row = widgets.HBox(image_widgets)
            display(row)
        else:
            print("No images found for the morphology classes.")
    except Exception as e:
        print(f"Error: {e}")
else:
    print(f"  Morphology classes: {str(c_idx)}")


## 6. Interactive 3D prediction surface with dataset points 

This interactive cell plots the dependence of an output property, namely Growth Rate (R), Kurtosis (Sku), Peak-to-Valley (Sz), Average Roughness (Sa), RMS Roughness (Ssq), or Skewness (Ssk), on any two input parameters.

- Choose the X/Y axes and set the scale (linear/log) for any axis. Adjust other parameters using sliders to see how the dependence of Z vs. X and Y changes.
- Compare kernel predictions (solid surface) against the original dataset points (black dots) saved in the kernel folder.

The 'Grid' parameter sets the size of the grid on which predictions are made, while the 'Random points from the dataset to show' parameter controls how many dataset points are displayed (rendering may become slow if this number is large).
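The surface is predicted on a rectangular grid: each grid node becomes one row of growth conditions (the non-axis parameters held at the slider values), the whole grid is sent as one batch, and the flat result is reshaped back to the grid, as `predict_surface` does in the next cell. The sketch reproduces that flattening with NumPy. `grid_rows` is hypothetical, the axis ranges and constants are illustrative, and the log option assumes a log-scaled axis is sampled uniformly in log10 space (the actual `build_meshgrid` in `demo_utils` may differ).

```python
import numpy as np


def grid_rows(x_name, x_range, y_name, y_range, n, const, x_log=False, y_log=False):
    """Build an n x n grid over two parameters and one input row per grid node."""
    def axis(lo, hi, log):
        # Uniform in log10 space for log axes (requires lo > 0), otherwise linear
        return np.logspace(np.log10(lo), np.log10(hi), n) if log else np.linspace(lo, hi, n)

    xv = axis(*x_range, x_log)
    yv = axis(*y_range, y_log)
    Xg, Yg = np.meshgrid(xv, yv)  # shape (n, n): rows follow y, columns follow x
    rows = [{**const, x_name: float(x), y_name: float(y)} for x, y in zip(Xg.ravel(), Yg.ravel())]
    return Xg, Yg, rows


const = {"Pes": 0.2, "T": 0.5, "Pd": 1e-4, "M": 0, "Nnucl": 2}
Xg, Yg, rows = grid_rows("Cs", (1e-3, 1e-1), "Pbias", (-0.1, 0.1), 4, const, x_log=True)
# Stand-in for kernel.predict_property(key, rows); reshape restores the grid layout
flat = np.array([r["Cs"] + r["Pbias"] for r in rows])
Z = flat.reshape(Xg.shape)
print(Z.shape)
```

Because `ravel()` and `reshape()` both use row-major order, the prediction for node `(i, j)` lands back at `Z[i, j]`, which is what makes the batch-then-reshape approach safe.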



In [None]:
from __future__ import annotations

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display

from demo_utils import clear_output_safely, percentile_range, build_meshgrid, load_colormap_from_png, find_colormap_png

# Guard: need dataset for overlay
if DF_ALL.empty:
    display(widgets.HTML("<b>Dataset CSV not available.</b> Set 'ranges_dataset_csv' in cgkernel_config.json."))
else:
    # Feature choices from config
    FEATURES_CONT = (CONFIG.get("numeric_columns", []) or []) + (CONFIG.get("integer_columns", []) or [])
    CAT_COL = CONFIG.get("categorical_column", "Nnucl")
    REGRESSOR_KEYS = sorted(kernel.regressors.keys())

    # Widgets
    prop_options = [(get_display_label(CONFIG, k, default=k), k) for k in REGRESSOR_KEYS]
    prop_dd = widgets.Dropdown(options=prop_options, value=(REGRESSOR_KEYS[4] if REGRESSOR_KEYS else None), description="Property:", style={"description_width": "auto"})
    x_dd = widgets.Dropdown(options=FEATURES_CONT, value=(FEATURES_CONT[0] if FEATURES_CONT else None), description="X:")
    y_dd = widgets.Dropdown(options=FEATURES_CONT, value=(FEATURES_CONT[1] if len(FEATURES_CONT) > 1 else (FEATURES_CONT[0] if FEATURES_CONT else None)), description="Y:")
    x_scale = widgets.ToggleButtons(options=[("linear", "linear"), ("log10", "log")], value="log", description="X scale:")
    y_scale = widgets.ToggleButtons(options=[("linear", "linear"), ("log10", "log")], value="linear", description="Y scale:")
    z_scale = widgets.ToggleButtons(options=[("linear", "linear"), ("log10", "log")], value="log", description="Z scale:")
    grid_slider = widgets.IntSlider(value=60, min=20, max=200, step=5, description="Grid:", continuous_update=False)
    unique_regs_3d = sorted(pd.to_numeric(DF_ALL[CAT_COL], errors="coerce").dropna().astype(int).unique().tolist())
    regime_dd = widgets.Dropdown(options=[(get_category_label(CONFIG, CAT_COL, int(v)), int(v)) for v in unique_regs_3d], value=(unique_regs_3d[2] if unique_regs_3d else 0), description=get_axis_label(CONFIG, CAT_COL), style={"description_width": "auto"})
    subsample_slider = widgets.IntSlider(value=3000, min=0, max=50000, step=500, description="Random points from the dataset to show:", continuous_update=False, layout=widgets.Layout(width="400px"), style={"description_width": "auto"})

    # Sliders for non-axis features
    def make_feature_sliders(x_name: str, y_name: str):
        sliders = []
        for feat in FEATURES_CONT:
            if feat in (x_name, y_name):
                continue
            vmin, vmax = percentile_range(DF_ALL[feat])
            step = max((vmax - vmin) / 200.0, 1e-6)
            if feat in (CONFIG.get("integer_columns", []) or []):
                w = widgets.IntSlider(value=int(np.median(pd.to_numeric(DF_ALL[feat], errors="coerce").dropna().values)) if DF_ALL[feat].notna().any() else int(round((vmin + vmax) / 2.0)), min=int(np.floor(vmin)), max=int(np.ceil(vmax)), step=1, description=get_axis_label(CONFIG, feat), continuous_update=True, layout=widgets.Layout(width="96%"), style={"description_width": "auto"})
            else:
                w = widgets.FloatSlider(value=float(np.median(pd.to_numeric(DF_ALL[feat], errors="coerce").dropna().values)) if DF_ALL[feat].notna().any() else float((vmin + vmax) / 2.0), min=float(vmin), max=float(vmax), step=step, description=get_axis_label(CONFIG, feat), readout_format=".5g", continuous_update=True, layout=widgets.Layout(width="96%"), style={"description_width": "auto"})
            sliders.append(w)
        return sliders

    sliders_box = widgets.VBox(make_feature_sliders(x_dd.value, y_dd.value))
    plot_out = widgets.Output()

    # Colormap
    cmap_png = find_colormap_png(PROJECT_ROOT, CONFIG)
    cmap = load_colormap_from_png(cmap_png) if cmap_png else None

    # Predict grid via kernel
    def predict_surface(prop_key: str, Xg: np.ndarray, Yg: np.ndarray, x_name: str, y_name: str, fixed_regime: Optional[int], const_values: dict) -> np.ndarray:
        vals = []
        for i in range(Xg.size):
            row: dict = {**const_values}
            row[x_name] = float(Xg.reshape(-1)[i])
            row[y_name] = float(Yg.reshape(-1)[i])
            # Always include selected regime
            row[CAT_COL] = int(regime_dd.value)
            vals.append(row)
        preds = kernel.predict_property(prop_key, vals)
        return preds.reshape(Xg.shape)

    def update(_=None):
        with plot_out:
            clear_output_safely(plot_out)
            if not prop_dd.value:
                print("No regressor selected.")
                return
            # Build grid
            Xv, Yv, Xg, Yg = build_meshgrid(DF_ALL, x_dd.value, y_dd.value, int(grid_slider.value), x_scale.value == "log", y_scale.value == "log")

            # Constant values
            const = {}
            for w in sliders_box.children:
                # Map back to canonical feature name by reverse label lookup
                label_to_name = {get_axis_label(CONFIG, f): f for f in FEATURES_CONT}
                name = label_to_name.get(w.description, w.description)
                const[name] = float(w.value)

            # Regime and scatter selection
            fixed_regime = None
            scatter_df = DF_ALL[[x_dd.value, y_dd.value, CAT_COL, prop_dd.value]].dropna()
            if regime_dd.value != "All":
                fixed_regime = int(regime_dd.value)
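                # Filter the already-cleaned subset by its own regime column (avoids index misalignment and NaN-to-int casts)
                scatter_df = scatter_df[pd.to_numeric(scatter_df[CAT_COL], errors="coerce") == fixed_regime]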

            # Predict
            Z = predict_surface(prop_dd.value, Xg, Yg, x_dd.value, y_dd.value, fixed_regime, const)

            # Apply z-scale
            if z_scale.value == "log":
                Z_plot = np.log10(np.clip(Z, 1e-12, None))
            else:
                Z_plot = Z

            # Prepare scatter overlay
            ss = subsample_slider.value
            if ss and len(scatter_df) > ss:
                scatter_df = scatter_df.sample(n=ss, random_state=0)
            x_sc = scatter_df[x_dd.value].to_numpy()
            y_sc = scatter_df[y_dd.value].to_numpy()
            z_sc_raw = scatter_df[prop_dd.value].to_numpy()
            z_sc = np.log10(np.clip(z_sc_raw, 1e-12, None)) if z_scale.value == "log" else z_sc_raw
            # Build figure
            fig = make_subplots(rows=1, cols=1, specs=[[{"type": "scene"}]])
            fig.add_trace(go.Surface(x=np.log10(np.clip(Xv, 1e-12, None)) if x_scale.value == "log" else Xv, y=np.log10(np.clip(Yv, 1e-12, None)) if y_scale.value == "log" else Yv, z=Z_plot, colorscale=(cmap.scale if cmap else "Viridis"), opacity=0.9, showscale=True))
            fig.add_trace(go.Scatter3d(x=np.log10(np.clip(x_sc, 1e-12, None)) if x_scale.value == "log" else x_sc, y=np.log10(np.clip(y_sc, 1e-12, None)) if y_scale.value == "log" else y_sc, z=z_sc, mode="markers", marker=dict(size=2, color="black", opacity=0.25), name="Dataset points"))
            fig.update_layout(template="plotly_white", width=1100, height=700, margin=dict(l=0, r=0, t=40, b=0), title=f"{get_display_label(CONFIG, prop_dd.value)} vs {get_axis_label(CONFIG, x_dd.value)} and {get_axis_label(CONFIG, y_dd.value)}")
            fig.update_scenes(xaxis_title=("log₁₀(" + get_axis_label(CONFIG, x_dd.value) + ")" if x_scale.value == "log" else get_axis_label(CONFIG, x_dd.value)), yaxis_title=("log₁₀(" + get_axis_label(CONFIG, y_dd.value) + ")" if y_scale.value == "log" else get_axis_label(CONFIG, y_dd.value)), zaxis_title=("log₁₀(" + get_display_label(CONFIG, prop_dd.value) + ")" if z_scale.value == "log" else get_display_label(CONFIG, prop_dd.value)))
            
            # Display the plot with proper configuration for both Colab and JupyterLab
            show_plot(fig)

    def on_axes_change(_):
        sliders_box.children = make_feature_sliders(x_dd.value, y_dd.value)
        for w in sliders_box.children:
            w.observe(update, names="value")
        update()

    # Wire events
    prop_dd.observe(update, names="value")
    x_dd.observe(on_axes_change, names="value")
    y_dd.observe(on_axes_change, names="value")
    x_scale.observe(update, names="value")
    y_scale.observe(update, names="value")
    z_scale.observe(update, names="value")
    grid_slider.observe(update, names="value")
    regime_dd.observe(update, names="value")
    subsample_slider.observe(update, names="value")
    for w in sliders_box.children:
        w.observe(update, names="value")

    # Layout
    ui = widgets.VBox([
        widgets.HBox([prop_dd, regime_dd]),
        widgets.HBox([x_dd, y_dd, x_scale, y_scale, z_scale, grid_slider, subsample_slider]),
        widgets.Label("Set constant values for non-axis features:"),
        sliders_box,
    ])
    display(ui)
    display(plot_out)
    update()


## 7. Interactive growth regime map and its confidence map 

This interactive cell allows you to plot the dependence of morphology type on two input parameters across any axes.

- Choose axes and scales (linear/log).
- View predictions and confidence for the morphology classifier across a 2D plane on any axes.
- Colors are consistent with those shown in the examples of the classes provided by the function <b>show_morphology_classes()</b>.

Note: The map uses the kernel’s morphology classifier (<b>predict_morphology_class()</b>). Representative pre-saved thumbnails of the morphology classes displayed on the map are shown below for convenience.
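A categorical map needs a colorscale in which each class occupies its own band of constant color rather than a smooth gradient. One common way to build such a scale for a Plotly heatmap is to repeat every color at both edges of an equal-width band. The function `step_colorscale` below is a hypothetical sketch of that idea; the notebook's own `make_discrete_step_colorscale` in `demo_utils` may be implemented differently. The three colors are the first entries of the regime palette used in the next cell.

```python
def step_colorscale(colors: list[str]) -> list[list]:
    """Plotly-style colorscale in which color i fills the band [i/n, (i+1)/n]."""
    n = len(colors)
    if n == 0:
        raise ValueError("at least one color is required")
    scale = []
    for i, color in enumerate(colors):
        # The same color at both band edges gives a hard step instead of a blend
        scale.append([i / n, color])
        scale.append([(i + 1) / n, color])
    return scale


scale = step_colorscale(["#0B4EB5", "#3195C5", "#EA733B"])
for stop in scale:
    print(stop)
# e.g. go.Heatmap(z=class_map, colorscale=scale, zmin=-0.5, zmax=2.5) for 3 classes coded 0, 1, 2
```

Setting `zmin=-0.5` and `zmax=n-0.5` places each integer class code in the middle of its band, assuming the classes have been re-indexed to 0..n-1.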


In [None]:
from __future__ import annotations

import numpy as np
import pandas as pd
import ipywidgets as widgets
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import display

from demo_utils import clear_output_safely, percentile_range, build_meshgrid, load_colormap_from_png, find_colormap_png, make_discrete_step_colorscale

if "morphology" not in kernel.classifiers or DF_ALL.empty:
    display(widgets.HTML("<b>Morphology classifier or dataset unavailable.</b>"))
else:
    FEATURES_CONT = (CONFIG.get("numeric_columns", []) or []) + (CONFIG.get("integer_columns", []) or [])
    CAT_COL = CONFIG.get("categorical_column", "Nnucl")

    x_dd = widgets.Dropdown(options=FEATURES_CONT, value=("Cs" if "Cs" in FEATURES_CONT else (FEATURES_CONT[0] if FEATURES_CONT else None)), description="X:")
    y_dd = widgets.Dropdown(options=FEATURES_CONT, value=("Pbias" if "Pbias" in FEATURES_CONT else (FEATURES_CONT[1] if len(FEATURES_CONT) > 1 else (FEATURES_CONT[0] if FEATURES_CONT else None))), description="Y:")
    unique_regs = sorted(pd.to_numeric(DF_ALL[CAT_COL], errors="coerce").dropna().astype(int).unique().tolist())
    regime_dd = widgets.Dropdown(options=[(get_category_label(CONFIG, CAT_COL, int(v)), int(v)) for v in unique_regs], value=(2 if 2 in unique_regs else int(unique_regs[0])), description=get_axis_label(CONFIG, CAT_COL))
    grid_slider = widgets.IntSlider(value=120, min=20, max=300, step=10, description="Grid:", continuous_update=False)
    x_scale = widgets.ToggleButtons(options=[("linear", "linear"), ("log10", "log")], value="log", description="X scale:")
    y_scale = widgets.ToggleButtons(options=[("linear", "linear"), ("log10", "log")], value="linear", description="Y scale:")

    # Sliders for non-axis features (constant values)
    def make_feature_sliders_m(x_name: str, y_name: str):
        sliders = []
        preset_values = {"Pes": 0.5, "T": 1.0, "Pd": 0.005, "M": 0}
        for feat in FEATURES_CONT:
            if feat in (x_name, y_name):
                continue
            vmin, vmax = percentile_range(DF_ALL[feat])
            step = max((vmax - vmin) / 200.0, 1e-6)
            if feat in (CONFIG.get("integer_columns", []) or []):
                default_val = int(preset_values.get(feat, int(np.median(pd.to_numeric(DF_ALL[feat], errors="coerce").dropna().values)) if DF_ALL[feat].notna().any() else int(round((vmin + vmax) / 2.0))))
                w = widgets.IntSlider(
                    value=default_val,
                    min=int(np.floor(vmin)), max=int(np.ceil(vmax)), step=1,
                    description=get_axis_label(CONFIG, feat), continuous_update=True, layout=widgets.Layout(width="96%"), style={"description_width": "auto"}
                )
            else:
                default_val = float(preset_values.get(feat, float(np.median(pd.to_numeric(DF_ALL[feat], errors="coerce").dropna().values)) if DF_ALL[feat].notna().any() else float((vmin + vmax) / 2.0)))
                w = widgets.FloatSlider(
                    value=default_val,
                    min=float(vmin), max=float(vmax), step=step,
                    description=get_axis_label(CONFIG, feat), readout_format=".5g", continuous_update=True, layout=widgets.Layout(width="96%"), style={"description_width": "auto"}
                )
            sliders.append(w)
        return sliders

    sliders_box_m = widgets.VBox(make_feature_sliders_m(x_dd.value, y_dd.value))

    plot_out = widgets.Output(layout=widgets.Layout(height="auto", max_height="none", overflow_y="visible"))
    images_out = widgets.Output(layout=widgets.Layout(height="auto", max_height="none", overflow_y="visible"))

    # Colors
    # Fixed regime colors (cluster palette)
    REGIME_COLORS = ['#0B4EB5', '#3195C5', '#EA733B', '#207AD4', '#DC4328', '#E35B31', '#3FA599', '#F18A44', '#D52B1E', '#0039A6', '#88B67A', '#FCC158', '#389DAF', '#57AE86', '#2A8DDC', '#B9BF6D', '#F49C4B', '#F8AF51', '#1564C4', '#EAC861']
    # Optional PNG colormap (not used by the maps below, which use the fixed palette and "Agsunset")
    cmap_png = find_colormap_png(PROJECT_ROOT, CONFIG)
    cmap = load_colormap_from_png(cmap_png) if cmap_png else None

    def predict_classes(Xg: np.ndarray, Yg: np.ndarray, x_name: str, y_name: str, fixed_regime: int, const_vals: dict) -> tuple[np.ndarray, np.ndarray]:
        vals = []
        for i in range(Xg.size):
            row = {**const_vals}
            row[x_name] = float(Xg.reshape(-1)[i])
            row[y_name] = float(Yg.reshape(-1)[i])
            row[CAT_COL] = int(fixed_regime)
            vals.append(row)
        idx, proba = kernel.predict_morphology_class(vals)
        conf = np.max(proba, axis=1) if proba is not None else np.ones_like(idx, dtype=float)
        return idx.reshape(Xg.shape), conf.reshape(Xg.shape)

    def update(_=None):
        with plot_out:
            clear_output_safely(plot_out)
            Xv, Yv, Xg, Yg = build_meshgrid(DF_ALL, x_dd.value, y_dd.value, int(grid_slider.value), x_scale.value == "log", y_scale.value == "log")
            # Build constant values from sliders
            const = {}
            label_to_name = {get_axis_label(CONFIG, f): f for f in FEATURES_CONT}
            for w in sliders_box_m.children:
                name = label_to_name.get(w.description, w.description)
                const[name] = float(w.value)
            z_idx, z_conf = predict_classes(Xg, Yg, x_dd.value, y_dd.value, int(regime_dd.value), const)

            # Colorscales: classes use the fixed discrete palette; confidence uses Plotly's "Agsunset" scale
            palette = [REGIME_COLORS[i % len(REGIME_COLORS)] for i in range(int(np.max(z_idx)) + 1)]
            class_scale = make_discrete_step_colorscale(palette)

            x_plot = np.log10(np.clip(Xv, 1e-12, None)) if x_scale.value == "log" else Xv
            y_plot = np.log10(np.clip(Yv, 1e-12, None)) if y_scale.value == "log" else Yv

            fig = make_subplots(rows=1, cols=2, subplot_titles=("Predicted class", "Confidence (probability of the most probable class)"))
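            # Pin the z-range to palette indices so each class keeps its color regardless of which classes are visible
            fig.add_trace(go.Heatmap(x=x_plot, y=y_plot, z=z_idx.astype(int), colorscale=class_scale, zmin=0, zmax=max(len(palette) - 1, 1), showscale=False, hovertemplate="x=%{x:.4g}<br>y=%{y:.4g}<br>class=%{z}<extra></extra>"), row=1, col=1)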
            fig.add_trace(go.Heatmap(x=x_plot, y=y_plot, z=z_conf, colorscale="Agsunset", zmin=0.0, zmax=1.0, colorbar=dict(title="Confidence"), hovertemplate="x=%{x:.4g}<br>y=%{y:.4g}<br>conf=%{z:.3f}<extra></extra>"), row=1, col=2)
            fig.update_layout(template="plotly_white", title="Growth regime (class) and confidence", width=1200, height=600, margin=dict(l=0, r=0, t=40, b=0))
            fig.update_xaxes(title_text=("log₁₀(" + get_axis_label(CONFIG, x_dd.value) + ")" if x_scale.value == "log" else get_axis_label(CONFIG, x_dd.value)), row=1, col=1)
            fig.update_yaxes(title_text=("log₁₀(" + get_axis_label(CONFIG, y_dd.value) + ")" if y_scale.value == "log" else get_axis_label(CONFIG, y_dd.value)), row=1, col=1)
            fig.update_xaxes(title_text=("log₁₀(" + get_axis_label(CONFIG, x_dd.value) + ")" if x_scale.value == "log" else get_axis_label(CONFIG, x_dd.value)), row=1, col=2)
            fig.update_yaxes(title_text=("log₁₀(" + get_axis_label(CONFIG, y_dd.value) + ")" if y_scale.value == "log" else get_axis_label(CONFIG, y_dd.value)), row=1, col=2)
            
            # Display the plot with proper configuration for both Colab and JupyterLab
            show_plot(fig)
        
        # Update images in separate output
        update_images(z_idx)

    def update_images(z_idx):
        with images_out:
            clear_output_safely(images_out)
            # Show example morphologies for classes present in the map
            classes_present = sorted(np.unique(z_idx.astype(int)).tolist())
            thumbs = []
            for cid in classes_present:
                p = (PROJECT_ROOT / "morphologies" / f"{cid}.png")
                if p.exists():
                    import base64
                    with p.open("rb") as _f:
                        _b64 = base64.b64encode(_f.read()).decode("ascii")
                    src = f"data:image/png;base64,{_b64}"
                    thumbs.append(f'<div style="margin:6px;text-align:center;"><img src="{src}" style="height:140px;image-rendering:crisp-edges;"><div style="font-size:12px;">Class {cid}</div></div>')
            if thumbs:
                html = '<div style="display:flex;flex-wrap:wrap;align-items:flex-start;">' + "".join(thumbs) + "</div>"
                display(widgets.HTML(value=html))

    def on_axes_change_m(_):
        sliders_box_m.children = make_feature_sliders_m(x_dd.value, y_dd.value)
        for w in sliders_box_m.children:
            w.observe(update, names="value")
        update()

    # Wire and layout
    x_dd.observe(on_axes_change_m, names="value")
    y_dd.observe(on_axes_change_m, names="value")
    regime_dd.observe(update, names="value")
    grid_slider.observe(update, names="value")
    x_scale.observe(update, names="value")
    y_scale.observe(update, names="value")
    for w in sliders_box_m.children:
        w.observe(update, names="value")

    ctrl = widgets.VBox([
        widgets.HBox([x_dd, y_dd, grid_slider]),
        widgets.HBox([regime_dd, x_scale, y_scale]),
        widgets.Label("Set constant values for non-axis features:"),
        sliders_box_m,
    ])
    display(ctrl)
    display(plot_out)
    display(images_out)
    display(widgets.HTML("""
    <style>
    /* Expand the Jupyter output area for this cell */
    div.jp-OutputArea-output {
        max-height: none !important;
        overflow: visible !important;
    }
    div.output_subarea {
        max-height: none !important;
        overflow: visible !important;
    }
    </style>
    """))
    update()
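
The class heatmap above colors regimes with a stepped colorscale built by `make_discrete_step_colorscale` from `demo_utils`, whose implementation is not shown in this notebook. Below is a minimal sketch of the general technique, assuming uniform color bands; `step_colorscale` and `class_position` are hypothetical helpers, not the `demo_utils` code. Each of the n colors owns the band [i/n, (i+1)/n], and repeating every boundary position makes the color jump instead of blending. This is the same idea as the hand-written binary scale in the stability cell.

```python
def step_colorscale(colors: list[str]) -> list[list]:
    """Build a Plotly-style stepped colorscale: each color fills one equal-width band.

    Hypothetical helper illustrating the technique; not the demo_utils implementation.
    """
    n = len(colors)
    if n == 0:
        raise ValueError("need at least one color")
    scale = []
    for i, color in enumerate(colors):
        # Each boundary position appears twice, so the color changes abruptly there
        scale.append([i / n, color])
        scale.append([(i + 1) / n, color])
    return scale


def class_position(k: int, n: int) -> float:
    """Normalized colorscale position of class k when the heatmap uses zmin=0, zmax=n-1."""
    return k / (n - 1) if n > 1 else 0.0


palette = ["#0B4EB5", "#3195C5", "#EA733B"]
scale = step_colorscale(palette)
for k in range(len(palette)):
    pos = class_position(k, len(palette))
    # The position of class k always falls inside its own band [k/n, (k+1)/n]
    assert k / len(palette) <= pos <= (k + 1) / len(palette)
print(scale)
```

In Plotly, such a scale would be passed as `colorscale=scale` together with `zmin=0` and `zmax=n - 1`. Without a fixed z-range, Plotly rescales the colors to the minimum and maximum of the data, so the color of a class would depend on which classes happen to appear in the current view.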


## 8. Interactive stability map and its confidence map 

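This cell mirrors the regime map above but uses the kernel's stability classifier, <b>predict_stability()</b>.

Select the axes and their scales to analyze how different growth parameters affect the dynamical stability of the growth process. The right panel shows the classifier's confidence, i.e. the probability of the predicted class.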
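In `predict_stab` (and `predict_classes` for the regime map), the confidence is the largest class probability returned by the classifier, `np.max(proba, axis=1)`, reshaped back onto the grid. The sketch below reproduces that bookkeeping on a vectorized grid, using a hypothetical logistic stand-in, `toy_stability_proba`, in place of the trained classifier; its decision boundary is arbitrary. For a binary classifier the confidence can never fall below 0.5 (for K classes, below 1/K), so values close to 0.5 indicate points near the predicted stability boundary.

```python
import numpy as np


def toy_stability_proba(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for a binary classifier: returns an (N, 2) probability array.

    Column 0 = unstable, column 1 = stable; the boundary y = 0.5 * x is arbitrary.
    """
    p_stable = 1.0 / (1.0 + np.exp(-4.0 * (y - 0.5 * x)))
    return np.column_stack([1.0 - p_stable, p_stable])


def class_and_confidence(xs, ys, proba_fn):
    """Evaluate a classifier on an (x, y) grid; return predicted-class and confidence maps."""
    # Default 'xy' indexing gives arrays of shape (len(ys), len(xs)): rows follow y and
    # columns follow x, which matches the z[row][col] layout expected by go.Heatmap
    Xg, Yg = np.meshgrid(xs, ys)
    proba = proba_fn(Xg.ravel(), Yg.ravel())
    idx = np.argmax(proba, axis=1)   # most probable class at each grid point
    conf = np.max(proba, axis=1)     # its probability, shown as "confidence"
    return idx.reshape(Xg.shape), conf.reshape(Xg.shape)


xs = np.linspace(0.0, 2.0, 5)
ys = np.linspace(0.0, 1.0, 4)
z_idx, z_conf = class_and_confidence(xs, ys, toy_stability_proba)
print(z_idx)
print(z_conf.round(3))
```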


In [None]:
from __future__ import annotations

import numpy as np
import pandas as pd
import ipywidgets as widgets
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import display

from demo_utils import clear_output_safely, percentile_range, build_meshgrid, load_colormap_from_png, find_colormap_png, make_discrete_step_colorscale, show_plot

if "stability" not in kernel.classifiers or DF_ALL.empty:
    display(widgets.HTML("<b>Stability classifier or dataset unavailable.</b>"))
else:
    FEATURES_CONT = (CONFIG.get("numeric_columns", []) or []) + (CONFIG.get("integer_columns", []) or [])
    CAT_COL = CONFIG.get("categorical_column", "Nnucl")

    x_dd_s = widgets.Dropdown(options=FEATURES_CONT, value=("Cs" if "Cs" in FEATURES_CONT else (FEATURES_CONT[0] if FEATURES_CONT else None)), description="X:")
    y_dd_s = widgets.Dropdown(options=FEATURES_CONT, value=("Pbias" if "Pbias" in FEATURES_CONT else (FEATURES_CONT[1] if len(FEATURES_CONT) > 1 else (FEATURES_CONT[0] if FEATURES_CONT else None))), description="Y:")
    unique_regs_s = sorted(pd.to_numeric(DF_ALL[CAT_COL], errors="coerce").dropna().astype(int).unique().tolist())
    regime_dd_s = widgets.Dropdown(options=[(get_category_label(CONFIG, CAT_COL, int(v)), int(v)) for v in unique_regs_s], value=(2 if 2 in unique_regs_s else int(unique_regs_s[0])), description=get_axis_label(CONFIG, CAT_COL), style={"description_width": "auto"})
    grid_slider_s = widgets.IntSlider(value=120, min=20, max=300, step=10, description="Grid:", continuous_update=False)
    x_scale_s = widgets.ToggleButtons(options=[("linear", "linear"), ("log10", "log")], value="log", description="X scale:")
    y_scale_s = widgets.ToggleButtons(options=[("linear", "linear"), ("log10", "log")], value="linear", description="Y scale:")

    plot_out_s = widgets.Output(layout=widgets.Layout(height="auto", max_height="none", overflow_y="visible"))

    # Optional PNG colormap (not used below: class colors are fixed and confidence uses "Agsunset")
    cmap_png_s = find_colormap_png(PROJECT_ROOT, CONFIG)
    cmap_s = load_colormap_from_png(cmap_png_s) if cmap_png_s else None

    def predict_stab(Xg: np.ndarray, Yg: np.ndarray, x_name: str, y_name: str, fixed_regime: int, const_vals: dict) -> tuple[np.ndarray, np.ndarray]:
        vals = []
        for i in range(Xg.size):
            row = {**const_vals}
            row[x_name] = float(Xg.reshape(-1)[i])
            row[y_name] = float(Yg.reshape(-1)[i])
            row[CAT_COL] = int(fixed_regime)
            vals.append(row)
        idx, proba = kernel.predict_stability(vals)
        conf = np.max(proba, axis=1) if proba is not None else np.ones_like(idx, dtype=float)
        return idx.reshape(Xg.shape), conf.reshape(Xg.shape)


    # Sliders for non-axis features (constant values) for stability
    def make_feature_sliders_s(x_name: str, y_name: str):
        sliders = []
        preset_values = {"Pes": 0.5, "T": 1.0, "Pd": 0.005, "M": 0}
        for feat in FEATURES_CONT:
            if feat in (x_name, y_name):
                continue
            vmin, vmax = percentile_range(DF_ALL[feat])
            step = max((vmax - vmin) / 200.0, 1e-6)
            if feat in (CONFIG.get("integer_columns", []) or []):
                default_val = int(preset_values.get(feat, int(np.median(pd.to_numeric(DF_ALL[feat], errors="coerce").dropna().values)) if DF_ALL[feat].notna().any() else int(round((vmin + vmax) / 2.0))))
                w = widgets.IntSlider(
                    value=default_val,
                    min=int(np.floor(vmin)), max=int(np.ceil(vmax)), step=1,
                    description=get_axis_label(CONFIG, feat), continuous_update=True, layout=widgets.Layout(width="96%"), style={"description_width": "auto"}
                )
            else:
                default_val = float(preset_values.get(feat, float(np.median(pd.to_numeric(DF_ALL[feat], errors="coerce").dropna().values)) if DF_ALL[feat].notna().any() else float((vmin + vmax) / 2.0)))
                w = widgets.FloatSlider(
                    value=default_val,
                    min=float(vmin), max=float(vmax), step=step,
                    description=get_axis_label(CONFIG, feat), readout_format=".5g", continuous_update=True, layout=widgets.Layout(width="96%"), style={"description_width": "auto"}
                )
            sliders.append(w)
        return sliders

    sliders_box_s = widgets.VBox(make_feature_sliders_s(x_dd_s.value, y_dd_s.value))

    def on_axes_change_s(_):
        sliders_box_s.children = make_feature_sliders_s(x_dd_s.value, y_dd_s.value)
        for w in sliders_box_s.children:
            w.observe(update_s, names="value")
        update_s()

    # Final update function (uses slider values for constants)
    def update_s(_=None):
        with plot_out_s:
            clear_output_safely(plot_out_s)
            Xv, Yv, Xg, Yg = build_meshgrid(DF_ALL, x_dd_s.value, y_dd_s.value, int(grid_slider_s.value), x_scale_s.value == "log", y_scale_s.value == "log")
            const = {}
            label_to_name = {get_axis_label(CONFIG, f): f for f in FEATURES_CONT}
            for w in sliders_box_s.children:
                name = label_to_name.get(w.description, w.description)
                const[name] = float(w.value)
            z_idx, z_conf = predict_stab(Xg, Yg, x_dd_s.value, y_dd_s.value, int(regime_dd_s.value), const)

            # Use a stepped binary colorscale and fix z-range to avoid interpolation to brown
            class_scale = [
                [0.0, "#F24942"], [0.4999, "#F24942"],
                [0.5, "#3CB27D"], [1.0, "#3CB27D"]
            ]

            x_plot = np.log10(np.clip(Xv, 1e-12, None)) if x_scale_s.value == "log" else Xv
            y_plot = np.log10(np.clip(Yv, 1e-12, None)) if y_scale_s.value == "log" else Yv

            fig = make_subplots(rows=1, cols=2, subplot_titles=("Predicted class: 0 (Red) - Unstable, 1 (Green) - Stable", "Confidence (max probability)"))
            fig.add_trace(
                go.Heatmap(
                    x=x_plot, y=y_plot, z=z_idx.astype(int),
                    colorscale=class_scale, zmin=0, zmax=1, showscale=False,
                    hovertemplate="x=%{x:.4g}<br>y=%{y:.4g}<br>class=%{z}<extra></extra>"
                ),
                row=1, col=1
            )
            fig.add_trace(
                go.Heatmap(
                    x=x_plot, y=y_plot, z=z_conf,
                    colorscale="Agsunset", zmin=0.0, zmax=1.0,
                    colorbar=dict(title="Confidence"),
                    hovertemplate="x=%{x:.4g}<br>y=%{y:.4g}<br>conf=%{z:.3f}<extra></extra>"
                ),
                row=1, col=2
            )
            fig.update_layout(template="plotly_white", title="Stability (class) and confidence", width=1200, height=600, margin=dict(l=0, r=0, t=40, b=0))
            fig.update_xaxes(title_text=("log₁₀(" + get_axis_label(CONFIG, x_dd_s.value) + ")" if x_scale_s.value == "log" else get_axis_label(CONFIG, x_dd_s.value)), row=1, col=1)
            fig.update_yaxes(title_text=("log₁₀(" + get_axis_label(CONFIG, y_dd_s.value) + ")" if y_scale_s.value == "log" else get_axis_label(CONFIG, y_dd_s.value)), row=1, col=1)
            fig.update_xaxes(title_text=("log₁₀(" + get_axis_label(CONFIG, x_dd_s.value) + ")" if x_scale_s.value == "log" else get_axis_label(CONFIG, x_dd_s.value)), row=1, col=2)
            fig.update_yaxes(title_text=("log₁₀(" + get_axis_label(CONFIG, y_dd_s.value) + ")" if y_scale_s.value == "log" else get_axis_label(CONFIG, y_dd_s.value)), row=1, col=2)
            
            # Display the plot with proper configuration for both Colab and JupyterLab
            show_plot(fig)

    # Wire and layout
    x_dd_s.observe(on_axes_change_s, names="value")
    y_dd_s.observe(on_axes_change_s, names="value")
    regime_dd_s.observe(update_s, names="value")
    grid_slider_s.observe(update_s, names="value")
    x_scale_s.observe(update_s, names="value")
    y_scale_s.observe(update_s, names="value")
    
    # Wire sliders to update function
    for w in sliders_box_s.children:
        w.observe(update_s, names="value")

    ctrl_s = widgets.VBox([
        widgets.HBox([x_dd_s, y_dd_s, grid_slider_s]),
        widgets.HBox([regime_dd_s, x_scale_s, y_scale_s]),
        widgets.Label("Set constant values for non-axis features:"),
        sliders_box_s,
    ])
    display(ctrl_s)
    display(plot_out_s)
    display(widgets.HTML("""
    <style>
    /* Expand the Jupyter output area for this cell */
    div.jp-OutputArea-output {
        max-height: none !important;
        overflow: visible !important;
    }
    div.output_subarea {
        max-height: none !important;
        overflow: visible !important;
    }
    </style>
    """))
    update_s()


## 9. Interactive morphology surface generation 

This cell demonstrates the kernel's <b>generate_morphology(...)</b> function, which uses a trained generative adversarial network (GAN) to generate a 256×256 surface morphology on the fly for a given set of growth conditions. The result is not a pre-saved, generic image of the class, but a surface generated for this exact set of conditions.

Use the sliders to adjust the growth conditions and explore how the surface evolves under the influence of each parameter. The figure updates whenever any parameter changes. 

You can also change the random seed to generate different but statistically similar surfaces for the same set of growth conditions.

The option "Scale by Sz" rescales the image from [0, 1] to the peak-to-valley height (Sz) predicted by the kernel for this set of growth conditions.
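The exact rescaling performed by `generate_morphology(..., scale_by_peak_to_valley=True)` is not shown in this notebook. The sketch below illustrates the idea under two assumptions: the generator returns a height map normalized to [0, 1], and "Scale by Sz" is a linear map whose peak-to-valley height (max minus min) equals the predicted Sz. `toy_surface` is a hypothetical stand-in for the GAN that uses a seeded random field, to show why a fixed seed reproduces the same surface while a different seed gives a new one.

```python
import numpy as np


def toy_surface(seed: int, size: int = 256) -> np.ndarray:
    """Hypothetical stand-in for the GAN: a seeded random height map normalized to [0, 1]."""
    rng = np.random.default_rng(seed)  # the seed fixes the noise, hence the whole surface
    z = rng.normal(size=(size, size))
    return (z - z.min()) / (z.max() - z.min())


def scale_to_peak_to_valley(img: np.ndarray, sz: float) -> np.ndarray:
    """Linearly rescale a height map so that its peak-to-valley height (max - min) equals sz."""
    lo, hi = float(img.min()), float(img.max())
    if hi == lo:  # flat surface: nothing to stretch
        return np.zeros_like(img, dtype=float)
    return (img - lo) / (hi - lo) * sz


surf = toy_surface(seed=42)
scaled = scale_to_peak_to_valley(surf, sz=12.5)  # 12.5 is an arbitrary example Sz
print(surf.shape, round(float(scaled.max() - scaled.min()), 6))
```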

In [None]:
from __future__ import annotations

import numpy as np
import pandas as pd
import ipywidgets as widgets
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display

from demo_utils import clear_output_safely, load_colormap_from_png, find_colormap_png, get_presaved_morphology_parameters, show_plot

# Guard for generator
HAS_GEN = getattr(kernel, "generator", None) is not None
if not HAS_GEN:
    display(widgets.HTML("<b>Generator not available.</b> Place TorchScript at models/generator.ts or configure 'generator' in cgkernel_config.json."))
else:
    FEATURES_CONT = (CONFIG.get("numeric_columns", []) or []) + (CONFIG.get("integer_columns", []) or [])
    CAT_COL = CONFIG.get("categorical_column", "Nnucl")

    # Load colormap from config
    colormap_png = find_colormap_png(PROJECT_ROOT, CONFIG)
    if colormap_png is not None:
        cmap = load_colormap_from_png(colormap_png)
        colorscale = cmap.scale
    else:
        colorscale = "Viridis"  # fallback

    # Build sliders from config percentiles/ranges
    def slider_for_feature(name: str):
        fs = next((f for f in CONFIG.get("features", []) if f.get("name") == name), None)
        if not fs:
            return None
        vmin = float(fs.get("min", 0.0))
        vmax = float(fs.get("max", 1.0))
        if name in (CONFIG.get("integer_columns", []) or []):
            return widgets.IntSlider(value=int(round((vmin + vmax) / 2.0)), min=int(np.floor(vmin)), max=int(np.ceil(vmax)), step=1, description=get_axis_label(CONFIG, name), continuous_update=True, layout=widgets.Layout(width="96%"), style={"description_width": "auto"})
        else:
            step = max((vmax - vmin) / 200.0, 1e-6)
            return widgets.FloatSlider(value=float((vmin + vmax) / 2.0), min=float(vmin), max=float(vmax), step=step, description=get_axis_label(CONFIG, name), readout_format=".5g", continuous_update=True, layout=widgets.Layout(width="96%"), style={"description_width": "auto"})

    sliders = [slider_for_feature(n) for n in FEATURES_CONT]
    sliders = [w for w in sliders if w is not None]
    regime_dd_g = widgets.Dropdown(options=[(get_category_label(CONFIG, CAT_COL, int(v)), int(v)) for v in range(int(next((f.get("max") for f in CONFIG.get("features", []) if f.get("name") == CAT_COL), 2)) + 1)], value=int(CONFIG.get("defaults", {}).get(CAT_COL, 0)), description=get_axis_label(CONFIG, CAT_COL), style={"description_width": "auto"})
    seed_slider = widgets.IntSlider(value=0, min=0, max=10000, step=1, description="Seed:", continuous_update=False)
    scale_chk = widgets.Checkbox(value=True, description="Scale by Sz")
    show_3d_chk = widgets.Checkbox(value=False, description="Show 3D graph")

    plot_out_2d = widgets.Output()
    plot_out_3d = widgets.Output()

    # Create presaved sets thumbnails
    def create_presaved_thumbnails():
        """Create clickable thumbnails for presaved morphology sets."""
        thumbnails = []
        
        for class_id in range(20):  # Classes 0-19
            img_path = PROJECT_ROOT / "morphologies" / f"{class_id}.png"
            if img_path.exists():
                # Read image and convert to base64 for HTML embedding
                import base64
                with open(img_path, 'rb') as f:
                    img_data = f.read()
                img_b64 = base64.b64encode(img_data).decode('ascii')
                
                # Create button with image background
                button = widgets.Button(
                    description=f"Class {class_id}",
                    layout=widgets.Layout(
                        width='70px',
                        height='70px',
                        margin='2px',
                        border='1px solid #ccc'
                    ),
                    style=widgets.ButtonStyle(
                        font_size='8px'
                    )
                )
                
                # Add click handler
                def make_click_handler(cid):
                    def on_click(b):
                        apply_presaved_parameters(cid)
                    return on_click
                
                button.on_click(make_click_handler(class_id))
                
                # Create container with button and image
                container = widgets.VBox([
                    widgets.HTML(f"""
                        <div style="
                            width: 60px; 
                            height: 50px; 
                            background-image: url('data:image/png;base64,{img_b64}');
                            background-size: contain;
                            background-repeat: no-repeat;
                            background-position: center;
                            border: 1px solid #ddd;
                            margin: 2px;
                        "></div>
                    """),
                    button
                ], layout=widgets.Layout(width='75px', height='90px'))
                
                thumbnails.append(container)
        
        return thumbnails

    def apply_presaved_parameters(class_id):
        """Apply predefined parameters for a specific morphology class."""
        params = get_presaved_morphology_parameters(class_id)
        
        # Update sliders with new values (label -> feature name lookup built once)
        label_to_name = {get_axis_label(CONFIG, f): f for f in FEATURES_CONT}
        for w in sliders:
            name = label_to_name.get(w.description, w.description)
            if name in params:
                w.value = params[name]
        
        # Update regime dropdown
        if CAT_COL in params:
            regime_dd_g.value = int(params[CAT_COL])
        
        # Trigger plot update
        update_plot()

    # Create presaved sets row
    presaved_thumbnails = create_presaved_thumbnails()
    presaved_title = widgets.HTML("<h4>Presaved sets (click to apply)</h4>")
    presaved_row = widgets.HBox(presaved_thumbnails, layout=widgets.Layout(
        overflow_x='scroll',
        width='100%',
        height='100px'
    ))

    def gather_values() -> dict:
        vals = {CAT_COL: int(regime_dd_g.value)}
        # Reverse lookup: slider label -> feature name
        label_to_name = {get_axis_label(CONFIG, f): f for f in FEATURES_CONT}
        for w in sliders:
            name = label_to_name.get(w.description, w.description)
            vals[name] = float(w.value)
        return vals

    def update_plot(_=None):
        vals = gather_values()
        try:
            img = kernel.generate_morphology(vals, seed=int(seed_slider.value), scale_by_peak_to_valley=bool(scale_chk.value))
            if img.ndim == 3:
                img = img[0]
            
            # Update 2D plot
            with plot_out_2d:
                clear_output_safely(plot_out_2d)
                # 2D surface plot with config colormap
                fig_2d = px.imshow(img, origin="lower", color_continuous_scale=colorscale)
                fig_2d.update_layout(width=700, height=600, margin=dict(l=0, r=0, t=30, b=0))
                fig_2d.update_coloraxes(colorbar_title="Height (a.u.)")
                
                # Display the plot with proper configuration for both Colab and JupyterLab
                show_plot(fig_2d)
            
            # Update 3D plot
            with plot_out_3d:
                clear_output_safely(plot_out_3d)
                # 3D surface plot if checkbox is checked
                if show_3d_chk.value:
                    # Get image dimensions
                    height, width = img.shape
                    
                    # Create coordinate arrays
                    x_coords = np.arange(width)
                    y_coords = np.arange(height)
                    X, Y = np.meshgrid(x_coords, y_coords)
                    
                    # Create 3D surface plot
                    fig_3d = go.Figure(data=go.Surface(
                        x=X,
                        y=Y,
                        z=img,
                        colorscale=colorscale,
                        opacity=0.9,
                        showscale=True,
                        colorbar=dict(title="Height (a.u.)")
                    ))
                    
                    # Set equal scale for all axes (same range as image dimensions)
                    max_range = max(width, height, img.max() - img.min())
                    fig_3d.update_layout(
                        scene=dict(
                            xaxis=dict(range=[0, max_range]),
                            yaxis=dict(range=[0, max_range]),
                            zaxis=dict(range=[img.min(), img.min() + max_range]),
                            aspectmode='cube'
                        ),
                        width=700,
                        height=600,
                        margin=dict(l=0, r=0, t=30, b=0),
                        title="3D Surface Visualization"
                    )
                    
                    # Display the plot with proper configuration for both Colab and JupyterLab
                    show_plot(fig_3d)
                    
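        except Exception as e:
            # Output printed from a widget callback may not be shown unless captured by an Output widget
            with plot_out_2d:
                print(f"Generation failed: {e}")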

    for w in sliders:
        w.observe(update_plot, names="value")
    regime_dd_g.observe(update_plot, names="value")
    seed_slider.observe(update_plot, names="value")
    scale_chk.observe(update_plot, names="value")
    show_3d_chk.observe(update_plot, names="value")

    ui = widgets.VBox([
        presaved_title,
        presaved_row,
        widgets.HBox([regime_dd_g, seed_slider, scale_chk, show_3d_chk]),
        widgets.VBox(sliders),
    ])
    display(ui)
    display(plot_out_2d)
    display(plot_out_3d)
    update_plot()


## 10. Inference-time benchmarking

This cell measures the per-sample inference latency across batch sizes for a selected model: a regressor, a classifier, or the GAN surface generator. Low latency is critical for FEM coupling and reactor optimization, where parameter sweeps or uncertainty quantification may require millions of evaluations.
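Per-sample latency is the wall-clock time of one batched call divided by the batch size, measured after several untimed warm-up calls. `benchmark_classifier_inference` in the cell below follows this scheme with the kernel's classifiers. The sketch that follows shows the same measurement as a self-contained harness. `toy_predict` is a hypothetical linear model over 7 inputs standing in for a trained regressor. Unlike the notebook cell, the sketch reports the median of the repeats rather than the mean, which is less sensitive to occasional slow runs.

```python
import time

import numpy as np


def per_sample_latency_us(predict, make_batch, batch_sizes, repeats=5, warmup=3):
    """Median per-sample latency (microseconds) of `predict` for each batch size."""
    results = {}
    for bs in batch_sizes:
        batch = make_batch(bs)           # input generation is not timed
        for _ in range(warmup):          # untimed calls: caches, lazy initialization
            predict(batch)
        times = []
        for _ in range(repeats):
            t0 = time.perf_counter()
            predict(batch)
            times.append(time.perf_counter() - t0)
        results[bs] = 1e6 * float(np.median(times)) / bs
    return results


rng = np.random.default_rng(0)
W = rng.normal(size=(7, 1))  # hypothetical linear "model" with 7 input features


def toy_predict(X: np.ndarray) -> np.ndarray:
    return X @ W


def make_batch(n: int) -> np.ndarray:
    return rng.uniform(size=(n, 7))


latency = per_sample_latency_us(toy_predict, make_batch, [1, 100, 10_000])
for bs, us in latency.items():
    print(f"batch {bs:>6}: {us:.4f} us/sample")
```

As a rule of thumb, N evaluations at a per-sample latency of t µs take about N·t µs, so 10^6 evaluations at 1 µs per sample take about one second of pure inference time. For PyTorch models on a GPU, CUDA kernels launch asynchronously, so the clock should be read only after the results are back on the host or after `torch.cuda.synchronize()`.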

In [None]:
from __future__ import annotations

from pathlib import Path

import pandas as pd
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display

from demo_utils import clear_output_safely, benchmark_property_inference, load_colormap_from_png, find_colormap_png, pick_equidistant_colors, violin_latency_plot, show_plot

# Create options for both regressors and classifiers
regressor_options = [(get_display_label(CONFIG, k, default=k), k) for k in sorted(kernel.regressors.keys())]
classifier_options = [(get_display_label(CONFIG, k, default=k), k) for k in sorted(kernel.classifiers.keys())]

# Check if GAN generator is available
has_generator = getattr(kernel, "generator", None) is not None
generator_options = [("GAN Morphology Generator", "gan_generator")] if has_generator else []

# Combine options with labels
all_options = [("--- Regressors ---", None)] + regressor_options
if classifier_options:
    all_options.append(("--- Classifiers ---", None))
    all_options.extend(classifier_options)
if generator_options:
    all_options.append(("--- Generators ---", None))
    all_options.extend(generator_options)

# Widgets
prop_dd = widgets.Dropdown(
    options=all_options,
    value=regressor_options[0][1] if regressor_options else None,
    description="Property to benchmark:",
    style={"description_width": "auto"}
)

benchmark_btn = widgets.Button(
    description="Perform Benchmark",
    button_style="primary",
    layout=widgets.Layout(width="200px")
)

# Benchmark parameters
BATCH_SIZES = [1, 10, 100, 1000, 10000, 100000]
REPEATS = 50
WARMUP = 100

# Output area
benchmark_out = widgets.Output()

def benchmark_classifier_inference(kernel, classifier_name, batch_sizes, repeats=5, warmup=100):
    """Benchmark classifier inference times across different batch sizes."""
    import time
    import torch
    import numpy as np
    import pandas as pd
    
    # Get feature ranges from config
    def get_feature_ranges():
        ranges = {}
        for fs in CONFIG.get("features", []):
            name = fs.get("name")
            if name in (CONFIG.get("numeric_columns", []) or []) + (CONFIG.get("integer_columns", []) or []):
                vmin = float(fs.get("min", 0.0))
                vmax = float(fs.get("max", 1.0))
                ranges[name] = (vmin, vmax)
        return ranges
    
    feature_ranges = get_feature_ranges()
    
    def generate_random_input():
        """Generate a single random input within allowed ranges."""
        input_data = {}
        for name, (vmin, vmax) in feature_ranges.items():
            if name in (CONFIG.get("integer_columns", []) or []):
                # Integer features
                input_data[name] = int(np.random.randint(int(vmin), int(vmax) + 1))
            else:
                # Float features
                input_data[name] = float(np.random.uniform(vmin, vmax))
        
        # Add categorical feature
        cat_col = CONFIG.get("categorical_column", "Nnucl")
        cat_max = int(next((f.get("max", 2) for f in CONFIG.get("features", []) if f.get("name") == cat_col), 2))
        input_data[cat_col] = int(np.random.randint(0, cat_max + 1))
        
        return input_data
    
    # Determine which classifier method to use
    if classifier_name == "morphology":
        predict_func = kernel.predict_morphology_class
    elif classifier_name == "stability":
        predict_func = kernel.predict_stability
    else:
        raise ValueError(f"Unknown classifier: {classifier_name}")
    
    results = []
    
    for batch_size in batch_sizes:
        # Generate batch data (time not included in benchmark)
        batch_data = [generate_random_input() for _ in range(batch_size)]
        
        # Warmup runs
        for _ in range(warmup):
            try:
                predict_func(batch_data)
            except Exception:
                pass  # Ignore warmup errors
        
        # Benchmark runs
        times = []
        for _ in range(repeats):
            # Generate fresh random data for each run (time not included in benchmark)
            batch_data = [generate_random_input() for _ in range(batch_size)]
            
            start_time = time.perf_counter()
            try:
                predict_func(batch_data)
                end_time = time.perf_counter()
                times.append((end_time - start_time) * 1e6)  # Convert to microseconds
            except Exception as e:
                print(f"Error during benchmark for batch size {batch_size}: {e}")
                times.append(np.nan)
        
        # Calculate statistics
        valid_times = [t for t in times if not np.isnan(t)]
        if valid_times:
            per_pred_time = np.mean(valid_times) / batch_size
            results.append({
                "batch_size": batch_size,
                "per_pred_us": per_pred_time,
                "total_time_us": np.mean(valid_times),
                "std_time_us": np.std(valid_times)
            })
        else:
            results.append({
                "batch_size": batch_size,
                "per_pred_us": np.nan,
                "total_time_us": np.nan,
                "std_time_us": np.nan
            })
    
    return pd.DataFrame(results)
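
# Sketch (hypothetical helper, not part of the kernel API): the statistics block
# above is repeated verbatim in benchmark_gan_inference, so it could be factored
# into one function. This assumes `times_us` holds per-run batch times in
# microseconds, with NaN marking failed runs. It reproduces the original numbers
# (mean batch time, population std as in np.std, mean divided by batch size).
# The median field is an extra assumption here: it is less sensitive to outliers.
import math
import statistics


def summarize_latencies(times_us, batch_size):
    """Summarize per-run batch times (microseconds) for one batch size."""
    valid = [t for t in times_us if not math.isnan(t)]
    if not valid:
        nan = float("nan")
        return {"batch_size": batch_size, "per_pred_us": nan,
                "total_time_us": nan, "std_time_us": nan, "median_time_us": nan}
    mean_us = statistics.fmean(valid)
    return {
        "batch_size": batch_size,
        "per_pred_us": mean_us / batch_size,
        "total_time_us": mean_us,
        "std_time_us": statistics.pstdev(valid),  # ddof=0, same as np.std
        "median_time_us": statistics.median(valid),
    }

# Example: summarize_latencies([100.0, 200.0, float("nan")], 10)
# gives per_pred_us 15.0, total_time_us 150.0, std_time_us 50.0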

def benchmark_gan_inference(kernel, batch_sizes, repeats=5, warmup=100):
    """Benchmark GAN generator inference times across different batch sizes."""
    import time
    import torch
    import numpy as np
    import pandas as pd
    
    # Get feature ranges from config
    def get_feature_ranges():
        ranges = {}
        for fs in CONFIG.get("features", []):
            name = fs.get("name")
            if name in (CONFIG.get("numeric_columns", []) or []) + (CONFIG.get("integer_columns", []) or []):
                vmin = float(fs.get("min", 0.0))
                vmax = float(fs.get("max", 1.0))
                ranges[name] = (vmin, vmax)
        return ranges
    
    feature_ranges = get_feature_ranges()
    
    def generate_random_input():
        """Generate a single random input within allowed ranges."""
        input_data = {}
        for name, (vmin, vmax) in feature_ranges.items():
            if name in (CONFIG.get("integer_columns", []) or []):
                # Integer features
                input_data[name] = int(np.random.randint(int(vmin), int(vmax) + 1))
            else:
                # Float features
                input_data[name] = float(np.random.uniform(vmin, vmax))
        
        # Add categorical feature
        cat_col = CONFIG.get("categorical_column", "Nnucl")
        cat_max = int(next((f.get("max", 2) for f in CONFIG.get("features", []) if f.get("name") == cat_col), 2))
        input_data[cat_col] = int(np.random.randint(0, cat_max + 1))
        
        return input_data
    
    results = []
    
    for batch_size in batch_sizes:
        # Generate batch data (time not included in benchmark)
        batch_data = [generate_random_input() for _ in range(batch_size)]
        
        # Warmup runs
        for _ in range(warmup):
            try:
                kernel.generate_morphology(batch_data)
            except Exception:
                pass  # Ignore warmup errors
        
        # Benchmark runs
        times = []
        for _ in range(repeats):
            # Generate fresh random data for each run (time not included in benchmark)
            batch_data = [generate_random_input() for _ in range(batch_size)]
            
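            # Finish any queued GPU work so it is not attributed to this run
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            try:
                kernel.generate_morphology(batch_data)
                # CUDA kernels run asynchronously; synchronize so the timer covers
                # the whole generation, not just the kernel launches
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                end_time = time.perf_counter()
                times.append((end_time - start_time) * 1e6)  # Convert to microseconds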
            except Exception as e:
                print(f"Error during GAN benchmark for batch size {batch_size}: {e}")
                times.append(np.nan)
        
        # Calculate statistics
        valid_times = [t for t in times if not np.isnan(t)]
        if valid_times:
            per_pred_time = np.mean(valid_times) / batch_size
            results.append({
                "batch_size": batch_size,
                "per_pred_us": per_pred_time,
                "total_time_us": np.mean(valid_times),
                "std_time_us": np.std(valid_times)
            })
        else:
            results.append({
                "batch_size": batch_size,
                "per_pred_us": np.nan,
                "total_time_us": np.nan,
                "std_time_us": np.nan
            })
    
    return pd.DataFrame(results)
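
# Sketch (hypothetical helper, not part of the kernel API): both benchmark
# functions above draw inputs from NumPy's global random state, so repeated
# benchmarks see different inputs. A seeded sampler like this one would make the
# inputs reproducible. It assumes the same conventions as generate_random_input:
# integer features drawn inclusively from [vmin, vmax], float features drawn
# uniformly, and the categorical feature drawn from 0..cat_max inclusive.
import random


def make_input_sampler(feature_ranges, integer_columns, cat_col, cat_max, seed=0):
    """Return a zero-argument function that samples one reproducible input dict."""
    rng = random.Random(seed)  # private generator; global RNG state is untouched
    integer_columns = set(integer_columns or [])

    def sample():
        data = {}
        for name, (vmin, vmax) in feature_ranges.items():
            if name in integer_columns:
                data[name] = rng.randint(int(vmin), int(vmax))  # inclusive bounds
            else:
                data[name] = rng.uniform(vmin, vmax)
        data[cat_col] = rng.randint(0, int(cat_max))  # categorical, inclusive
        return data

    return sample

# Example (feature name "feature_a" is hypothetical): two samplers with the same
# seed yield identical batches
# s1 = make_input_sampler({"feature_a": (0.0, 1.0)}, [], "Nnucl", 2, seed=42)
# s2 = make_input_sampler({"feature_a": (0.0, 1.0)}, [], "Nnucl", 2, seed=42)
# [s1() for _ in range(3)] == [s2() for _ in range(3)]  # True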

def run_benchmark(_):
    with benchmark_out:
        clear_output_safely(benchmark_out)
        
        if not prop_dd.value:
            print("Please select a property to benchmark.")
            return
            
        # Check if it's a regressor, classifier, or generator
        is_regressor = prop_dd.value in kernel.regressors
        is_classifier = prop_dd.value in kernel.classifiers
        is_generator = prop_dd.value == "gan_generator"
        
        if not (is_regressor or is_classifier or is_generator):
            print(f"Property '{prop_dd.value}' not found in regressors, classifiers, or generators.")
            return
        
        # Print parameters and warning
        print("Benchmark Parameters:")
        if is_generator:
            print("  Property: GAN Morphology Generator")
            print("  Type: Generator")
        else:
            print(f"  Property: {get_display_label(CONFIG, prop_dd.value, default=prop_dd.value)}")
            print(f"  Type: {'Regressor' if is_regressor else 'Classifier'}")
        print(f"  Batch sizes: {BATCH_SIZES}")
        print(f"  Repeats per batch size: {REPEATS}")
        print(f"  Warmup runs: {WARMUP}")
        
        # Report device
        import torch as _torch
        _dev = "cuda" if _torch.cuda.is_available() else "cpu"
        print(f"  Device: {_dev.upper()}")
        if _dev == "cpu":
            print("  Warning: running on CPU; latency will be higher than on a GPU.")

        print("\nThe benchmark may take from a few seconds to several minutes, depending on the settings and hardware...")
        print("Running benchmark...\n")
        
        try:
            # Run benchmark
            if is_regressor:
                latency_df = benchmark_property_inference(kernel, prop_dd.value, BATCH_SIZES, repeats=REPEATS, warmup=WARMUP)
            elif is_classifier:
                # Classifiers (morphology, stability) use their own benchmark routine
                latency_df = benchmark_classifier_inference(kernel, prop_dd.value, BATCH_SIZES, repeats=REPEATS, warmup=WARMUP)
            else:  # is_generator
                # For GAN generator
                latency_df = benchmark_gan_inference(kernel, BATCH_SIZES, repeats=REPEATS, warmup=WARMUP)
            
            # Pick one color per batch size from the project colormap PNG, if available
            colormap_png = find_colormap_png(PROJECT_ROOT, CONFIG)
            if colormap_png is not None:
                cmap = load_colormap_from_png(colormap_png)
                categories = sorted(latency_df["batch_size"].unique().tolist(), key=lambda x: int(x))
                discrete_colors = pick_equidistant_colors(cmap.discrete, len(categories))
            else:
                cmap = None
                discrete_colors = None

            # Violin plot of per-prediction latency (log-scaled y-axis)
            if is_generator:
                title = "GAN Morphology Generator: per-prediction inference time across batch sizes"
            else:
                title = f"{get_display_label(CONFIG, prop_dd.value)}: per-prediction inference time across batch sizes"
            
            fig = violin_latency_plot(
                latency_df,
                title=title,
                discrete_colors=discrete_colors,
            )
            
            # Display the plot (works in both Colab and JupyterLab)
            if fig:
                show_plot(fig)
            
            print("Benchmark completed successfully!")
            
        except Exception as e:
            print(f"Benchmark failed: {e}")
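
# Sketch (hypothetical helper, not part of the kernel API): on CUDA, PyTorch
# launches kernels asynchronously, so a wall-clock timer stopped right after the
# call may capture only the launch. A common pattern is to synchronize the device
# before starting and before stopping the timer. Here the synchronization is an
# optional callable (for example torch.cuda.synchronize), which keeps the helper
# usable on CPU and free of a hard torch dependency.
import time


def timed_call_us(fn, *args, sync=None, **kwargs):
    """Call fn(*args, **kwargs); return (result, elapsed time in microseconds)."""
    if sync is not None:
        sync()  # drain work queued before the measurement
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if sync is not None:
        sync()  # wait for work launched by fn to finish
    elapsed_us = (time.perf_counter() - start) * 1e6
    return result, elapsed_us

# Example usage inside a benchmark loop (assumes `torch`, `kernel`, `batch_data`):
# sync = torch.cuda.synchronize if torch.cuda.is_available() else None
# _, t_us = timed_call_us(kernel.generate_morphology, batch_data, sync=sync)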

# Wire events
benchmark_btn.on_click(run_benchmark)

# Layout
ui = widgets.VBox([
    widgets.HBox([prop_dd, benchmark_btn]),
])

# Display the UI and then the output widget separately
display(ui)
display(benchmark_out)
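
# Sketch (hypothetical helper, not part of the kernel API): the benchmarks report
# latency in microseconds per prediction; converting it to throughput
# (predictions per second) can make batch sizes easier to compare. This assumes
# records shaped like the dictionaries built in the benchmark functions above
# (a "per_pred_us" key, NaN for batch sizes whose runs all failed).
import math


def with_throughput(records):
    """Return copies of latency records with an added 'preds_per_s' field."""
    out = []
    for rec in records:
        us = rec.get("per_pred_us", float("nan"))
        ok = us is not None and not math.isnan(us) and us > 0
        out.append({**rec, "preds_per_s": 1e6 / us if ok else float("nan")})
    return out

# Example: with_throughput([{"batch_size": 1, "per_pred_us": 10.0}])
# -> [{"batch_size": 1, "per_pred_us": 10.0, "preds_per_s": 100000.0}]
# A latency DataFrame can be passed as latency_df.to_dict("records").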