# W&B Enterprise Workshop: Advanced Geological AI

This hands-on workshop demonstrates how Weights & Biases acts as a system of record for a conditional diffusion workflow in subsurface modeling, using a fast, simulated setup so you can see results in ~60 minutes.

- Who this is for
  - Geoscientists, ML engineers, and platform teams experimenting with conditional diffusion on seismic-conditioned 3D volumes.
- What you’ll accomplish in ~60 minutes
  - Run a simulated conditional diffusion training loop with live monitoring
  - Inspect 3D volumes (prediction, forward-modeled condition, residual) inline
  - Log validation tables mixing metrics, images, and interactive plots
  - Launch a short hyperparameter sweep with early termination
  - Generate a stakeholder-ready, programmatic report

What to expect
- Runtime: end-to-end in about an hour; the model is simulated for speed.
- Data: uses a pre-versioned dataset artifact from the W&B Registry.
- 3D: PyVista Html renders are logged inline; ipyvolume is optional and off by default (`enable_ipyvolume` in the config).
- GPU: not required for this simulation.

Where to look in W&B (links appear automatically after the first run)
- Run Overview: live metrics and logs
- System: CPU/GPU/memory correlated with your metrics
- Artifacts: dataset lineage and model checkpoints
- Tables: per-epoch validation table with images, Html, and metrics
- Model Registry: tracked checkpoints with aliases (e.g., “staging”)
- Report: auto-generated, shareable executive summary

Why W&B here
- Single source of truth: configs, metrics, system telemetry, artifacts, and reports in one place
- First-class lineage: dataset → run → model graph without stitching tools
- Rich media: images, Html 3D, and tables together for SME review
- Governance-ready: model registry with aliases and embedded model cards
- Programmatic reports: auto-updated, stakeholder-friendly reporting

Next step: install requirements and import libraries below.

In [1]:
# Install minimal dependencies for the W&B workshop
# What this does
# - Installs W&B core (`wandb`) and Workspaces (`wandb-workspaces`) for programmatic reports
# - Scientific stack: numpy, tqdm, scikit-image, pillow
# - 3D backends: PyVista (default) + trame exporters; ipyvolume is optional
# - Plotly for 2D/HTML charts; 3D volume defaults to PyVista Html renders
#
# 3D defaults and toggles (set in the training config)
# - enable_3d: True/False           # disable all 3D when False
# - enable_high_fidelity_3d: True/False  # controls PyVista renders
# - enable_ipyvolume: True/False    # optional ipyvolume viewer
#
# Why W&B here
# - Logs rich media (images, Html 3D) alongside metrics and tables for SME review without leaving the run context
# - No GPU required for this simulation (works in local notebooks and Colab)

%pip install wandb numpy tqdm wandb-workspaces plotly pillow scikit-image pyvista ipyvolume ipython-genutils trame trame-vuetify trame-vtk -q

import numpy as np
import wandb
import wandb_workspaces.reports.v2 as wr
from tqdm import tqdm
import time
import os
from pathlib import Path
from typing import List, Tuple, Dict, Any
import json
import plotly.graph_objects as go
import plotly.io as pio
from skimage.metrics import structural_similarity as ssim
from plotly.subplots import make_subplots
import pyvista as pv
import ipyvolume as ipv
from skimage.transform import resize
from textwrap import dedent

# Disable PyVista's Jupyter display backend; 3D renders are produced off-screen
# (Plotter(off_screen=True)) and exported to HTML for logging to W&B instead
pv.set_jupyter_backend(None)




[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
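The 3D toggles described in the install cell depend on optional packages. The following is a minimal sketch, assuming you want the notebook to degrade gracefully when a backend is missing. The hypothetical helper `resolve_3d_flags` uses `importlib.util.find_spec` to check whether PyVista and ipyvolume are importable. It switches off the matching flags so the run does not fail later.

```python
import importlib.util
from typing import Any, Dict


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be imported in this environment."""
    return importlib.util.find_spec(module_name) is not None


def resolve_3d_flags(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of config with 3D flags turned off
    when 3D is disabled globally or the backend package is not installed."""
    resolved = dict(config)  # never mutate the caller's config
    if not resolved.get("enable_3d", False):
        # 3D is off globally, so every backend-specific flag is off too
        resolved["enable_high_fidelity_3d"] = False
        resolved["enable_ipyvolume"] = False
        return resolved
    if resolved.get("enable_high_fidelity_3d") and not is_installed("pyvista"):
        resolved["enable_high_fidelity_3d"] = False
    if resolved.get("enable_ipyvolume") and not is_installed("ipyvolume"):
        resolved["enable_ipyvolume"] = False
    return resolved


flags = resolve_3d_flags({"enable_3d": True, "enable_high_fidelity_3d": True, "enable_ipyvolume": False})
print(flags)
```

If you adopt something like this, apply it to the config before calling `wandb.init`. That way the flags recorded with the run match what was actually rendered.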


## 1. W&B Project Configuration

In this section you’ll authenticate to W&B and set `ENTITY` (team/org) and `PROJECT` (workspace for this workshop). All subsequent runs, artifacts, tables, and reports will be organized here.

- What this does
  - Initializes W&B auth for your user
  - Defines the team `ENTITY` and `PROJECT` used by runs, artifacts, sweeps, and reports
  - Ensures configs and system telemetry are centralized per project for easy comparison

- Where to look in W&B (after the first run)
  - Run Overview: live metrics, configs, and logs
  - System: CPU/GPU/memory automatically captured alongside your metrics
  - Artifacts: dataset and model lineage within this project
  - Tables: per-epoch validation table appears under the run’s media

- Authentication notes
  - By default, runs log to `wandb.ai`. To log to a dedicated or self-managed W&B instance, set the `WANDB_BASE_URL` environment variable (or pass `host=` to `wandb.login()`).
  - If your API key isn’t set in the environment, the first `wandb.login()` will prompt you to paste it.
  - You can set these environment variables to skip prompts and preselect your workspace:
    - `WANDB_API_KEY` (find it at `https://wandb.ai/authorize`)
    - `WANDB_ENTITY`
    - `WANDB_PROJECT`

- Why W&B here
  - Team workspaces with RBAC keep experiments, configs, and artifacts in one place
  - Built‑in auditability: dataset → run → model lineage without stitching tools

- Try it yourself
  - [ ] Set `ENTITY` to your team and `PROJECT` to this workshop’s project name
  - [ ] Optionally export `WANDB_API_KEY`, `WANDB_ENTITY`, `WANDB_PROJECT` as environment variables
  - [ ] Run the auth cell; confirm you see the “View project/run” links after the first run
  - [ ] Click through to the project page to verify organization and permissions
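The environment-variable route in the authentication notes can be sketched as follows. This assumes the notebook defaults should apply only when `WANDB_ENTITY` / `WANDB_PROJECT` are unset or empty. `resolve_workspace` is a hypothetical helper. It only reads variables and never calls W&B, so it is safe to run before `wandb.login()`.

```python
import os
from typing import Mapping, Optional, Tuple


def resolve_workspace(
    default_entity: str,
    default_project: str,
    env: Optional[Mapping[str, str]] = None,
) -> Tuple[str, str]:
    """Return (entity, project), preferring WANDB_ENTITY / WANDB_PROJECT when set and non-empty."""
    env = os.environ if env is None else env
    entity = env.get("WANDB_ENTITY") or default_entity
    project = env.get("WANDB_PROJECT") or default_project
    return entity, project


print(resolve_workspace("my-team", "geo-workshop", env={"WANDB_PROJECT": "karst-diffusion"}))
# → ('my-team', 'karst-diffusion')
```

Passing `env` explicitly, as in the usage line, makes the behavior easy to check without touching the real environment.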

In [2]:

# Replace the defaults with your W&B team entity and project (or export WANDB_ENTITY / WANDB_PROJECT)
ENTITY = os.environ.get("WANDB_ENTITY", "wandb_emea")
PROJECT = os.environ.get("WANDB_PROJECT", "workshop-ex123456789")

# Fail fast if either value is empty
assert ENTITY, "Please set ENTITY to your W&B team entity"
assert PROJECT, "Please set PROJECT to your W&B project name"

# W&B Team Authentication
wandb.login()


print(f"W&B Authentication: Entity = {ENTITY}")
print(f"Project: {PROJECT}")

# Configuration for (simulated) conditional diffusion training.
# Passed to wandb.init(config=...), these values become wandb.config: they are
# saved with every run, so each result can be traced to the exact settings that
# produced it, and hyperparameter sweeps can override them.

example_config = {
    # Model Architecture
    "model_architecture": "ConditionalUNet3D",
    "task": "geological_structure_generation",
    "input_modality": "seismic_amplitude",
    "output_modality": "karst_structures",
    
    # Training Parameters
    "epochs": 10,
    "batches_per_epoch": 20,
    "learning_rate": 1e-3,
    "batch_size": 4,
    
    # Diffusion Parameters
    "timesteps": 1000,
    "noise_schedule": "cosine",
    "conditioning_strength": 0.8,
    
    # Geological Domain
    "voxel_resolution": "25m",
    "depth_range": "0-800m",
    "geological_context": "karst_detection",

    # Visualization Parameters
    "enable_3d": True,
    "enable_high_fidelity_3d": True,
    "enable_ipyvolume": False,
}

print("Training configuration loaded")

[34m[1mwandb[0m: Currently logged in as: [33mallanstevenson[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


W&B Authentication: Entity = wandb_emea
Project: workshop-ex123456789
Training configuration loaded
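A sweep varies the config keys defined above. The following is a minimal sketch of a sweep definition for the short hyperparameter sweep with early termination mentioned in the overview. It assumes the training function logs a metric named `val_loss`. That name is an assumption, so use whatever key the training loop actually passes to `wandb.log`. The dictionary uses W&B's sweep-configuration keys (`method`, `metric`, `parameters`, `early_terminate`). It would be registered with `wandb.sweep(sweep_config, entity=ENTITY, project=PROJECT)` and executed with `wandb.agent(sweep_id, function=train_conditional_diffusion, count=...)`.

```python
from typing import Any, Dict

# Hypothetical sweep over the simulated-diffusion knobs in example_config.
# "val_loss" is an assumed metric name; it must match a key logged via wandb.log().
sweep_config: Dict[str, Any] = {
    "method": "random",  # "grid" and "bayes" are the other built-in search methods
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-4, "max": 1e-2},
        "conditioning_strength": {"values": [0.4, 0.6, 0.8, 1.0]},
        "noise_schedule": {"values": ["cosine", "linear", "sigmoid"]},
        "batch_size": {"values": [2, 4, 8]},
        # Fix the run length so early termination compares runs of equal length
        "epochs": {"value": 10},
    },
    # Hyperband evaluates runs at bracket points starting at min_iter and stops underperformers
    "early_terminate": {"type": "hyperband", "min_iter": 3},
}
```

Parameters not listed in the sweep keep their `example_config` values. In a sweep, `train_conditional_diffusion` is called without a config and falls back to `example_config`, and the sweep's values then override the matching keys in `wandb.config`.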


## 2. Simulation and Visualization Helpers

This cell contains the helper functions for our workshop. This is not part of the core W&B integration.

- What this does
  - `simulate_conditional_diffusion_progress`: simulates learning progression for conditional diffusion (X_pred given Y)
  - `simulate_forward_model`: approximates Y_pred = f(X_pred) to evaluate condition-consistency
  - `plot_well_log_comparison`: interactive Plotly well-log comparison (GT vs prediction)
  - `normalize_for_visualization`: stable normalization for images/slices
  - Optional 3D view utilities used later in validation logging

- Where to look in W&B (after first validation epoch)
  - Media: `Viz/diagnostics/central_slice_grid` (2D grid of Y, X, X_pred, Residual)
  - Media: `Viz3D/PyVista_Renders/*` (inline Html 3D renders for prediction and residual)
  - Tables: `val_table/validation_table` (per-epoch rows with images, Html, metrics, and well-log plot)

- 3D backends (defaults and options)
  - Default: PyVista Html renders for a responsive, stable inline 3D experience
  - Optional: ipyvolume can be enabled via config. The Plotly 3D helpers (Tier 1/2) are defined, but their logging is commented out for speed and stability
  - You can control 3D via config toggles: `enable_3d`, `enable_high_fidelity_3d`, `enable_ipyvolume`

- Why W&B here
  - Rich media (Html 3D + images) is logged alongside metrics and tables, enabling SME review directly in the run context
  - Results are comparable across runs without managing files or screenshots

- Try it yourself
  - [ ] Set `enable_3d=False` in config to see a lightweight, metrics-only run
  - [ ] Set `enable_3d=True` and `enable_high_fidelity_3d=True` to inspect PyVista Html renders
  - [ ] Toggle `enable_ipyvolume=True` to compare 3D backends
  - [ ] Inspect `val_table/validation_table` and expand a row to view all media for a sample
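`simulate_forward_model` exists so that a prediction can be scored against its own conditioning input. The following is a minimal NumPy sketch of that check, assuming volumes normalized to [0, 1]. The hypothetical helper `condition_consistency` computes the residual `Y_pred - Y` along with MSE, MAE and PSNR. These metrics are illustrative choices and are not necessarily the ones the training loop logs.

```python
import numpy as np
from typing import Dict


def condition_consistency(y: np.ndarray, y_pred: np.ndarray, data_range: float = 1.0) -> Dict[str, float]:
    """Score how well y_pred = f(x_pred) reproduces the input condition y."""
    if y.shape != y_pred.shape:
        raise ValueError(f"shape mismatch: {y.shape} vs {y_pred.shape}")
    residual = y_pred.astype(np.float64) - y.astype(np.float64)
    mse = float(np.mean(residual ** 2))
    mae = float(np.mean(np.abs(residual)))
    # PSNR is infinite for a perfect reconstruction
    psnr = float("inf") if mse == 0.0 else float(10.0 * np.log10(data_range ** 2 / mse))
    return {"mse": mse, "mae": mae, "psnr": psnr, "residual_mean": float(residual.mean())}


rng = np.random.default_rng(0)
y = rng.random((16, 16, 16))
print(condition_consistency(y, np.clip(y + 0.1, 0, 1)))
```

In the notebook, `y` would be the seismic condition and `y_pred` the output of `simulate_forward_model(x_pred)`. The returned dict has plain float values, so it could be passed directly to `wandb.log`.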

In [3]:
# By returning a wandb.Html object, we can log custom, interactive visualizations
# like this Plotly chart directly into the W&B dashboard. This allows domain experts
# to analyze results without switching contexts or downloading files.
def plot_well_log_comparison(gt_log: np.ndarray, pred_log: np.ndarray, depth: np.ndarray) -> wandb.Html:
    """Creates an interactive Plotly chart comparing ground truth and predicted well logs."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=gt_log, y=depth, mode='lines', name='Ground Truth Log', line=dict(color='black')))
    fig.add_trace(go.Scatter(x=pred_log, y=depth, mode='lines', name='Predicted Log', line=dict(color='crimson', dash='dash')))
    fig.update_layout(
        title="Well Log Comparison",
        xaxis_title="Signal Amplitude",
        yaxis_title="Depth (m)",
        yaxis_autorange='reversed' # Depth increases downwards
    )
    html = pio.to_html(fig)
    return wandb.Html(html)

def generate_synthetic_loss(
    step: int,
    schedule: str = "cosine",
    conditioning_strength: float = 0.8,
    learning_rate: float = 1e-3,
    batch_size: int = 4,
    base_seed: int = 42,
) -> float:
    """
    Parameterized synthetic loss that responds to sweep knobs.
    - schedule: shapes the decay (cosine|linear|sigmoid)
    - conditioning_strength: higher => faster improvement
    - learning_rate: modestly speeds decay
    - batch_size: larger => lower noise
    Deterministic per step via base_seed + step.
    """
    # Progress in [0, 1]
    progress = np.clip(step / 600.0, 0.0, 1.0)

    if schedule == "linear":
        sched = progress
    elif schedule == "sigmoid":
        sched = 1.0 / (1.0 + np.exp(-10.0 * (progress - 0.5)))
    else:  # cosine
        sched = 0.5 - 0.5 * np.cos(np.pi * progress)

    # Faster decay with stronger conditioning and slightly higher LR
    decay_speed = 1.0 + 0.8 * conditioning_strength + 0.3 * np.log10(max(learning_rate, 1e-6) / 1e-3 + 1.0)
    base_curve = 1.5 * np.exp(-3.0 * sched * decay_speed) + 0.05

    # Noise shrinks with batch size
    rng = np.random.RandomState(int(base_seed) + int(step))
    noise = rng.normal(0.0, 0.05 / max(batch_size, 1))

    return float(max(0.005, base_curve + noise))

# Simulation controls:
# - conditioning_strength controls how quickly structure emerges
# - noise_schedule ('cosine', 'linear', 'sigmoid') controls the noise-decay shape
# - Runs are deterministic per epoch via base_seed
def simulate_conditional_diffusion_progress(
    seismic_condition: np.ndarray,
    karst_target: np.ndarray,
    epoch: int,
    total_epochs: int = 10,
    conditioning_strength: float = 0.8,
    noise_schedule: str = "cosine",
    base_seed: int = 42,
) -> np.ndarray:
    """
    Simulate conditional diffusion learning progression (single-channel).
    - conditioning_strength: scales structure emergence (higher = stronger alignment to karst_target)
    - noise_schedule: 'cosine' | 'linear' | 'sigmoid' controls noise decay shape over epochs
    - Deterministic per epoch via base_seed + epoch
    """
    # Training progress (0.0 at start, 1.0 at end) → used to shape schedule
    progress = epoch / float(max(1, total_epochs - 1))

    # Deterministic seed derived from base_seed and epoch
    np.random.seed(int(base_seed) + int(epoch) * 1000)

    # Schedule shaping (affects how fast noise decreases / structure increases)
    if noise_schedule == "linear":
        schedule_factor = progress
    elif noise_schedule == "sigmoid":
        schedule_factor = 1.0 / (1.0 + np.exp(-12.0 * (progress - 0.5)))
    else:  # "cosine" default
        schedule_factor = 0.5 - 0.5 * np.cos(np.pi * progress)

    # Stage 1 (epochs 0–3): learn basic correlations; high noise
    if epoch <= 3:
        noise_level = 0.8 - (epoch / 3.0) * 0.4          # 0.8 → 0.4
        noise_level *= (1.0 - 0.25 * schedule_factor)    # slightly faster decay with schedule
        structural_learning = (epoch / 3.0) * 0.3        # 0.0 → 0.3
        structural_learning *= conditioning_strength

        structure_mask = (seismic_condition > np.percentile(seismic_condition, 60)).astype(float)
        noise = np.random.normal(0, noise_level, seismic_condition.shape)

        prediction = (
            noise * 0.7 +
            seismic_condition * structural_learning +
            karst_target * structure_mask * 0.1
        )

    # Stage 2 (epochs 4–6): refine geological structures; reduce noise
    elif epoch <= 6:
        stage_progress = (epoch - 4.0) / 2.0             # 0.0 → 1.0
        karst_regions = (
            (seismic_condition > np.percentile(seismic_condition, 40)) &
            (seismic_condition < np.percentile(seismic_condition, 85))
        )

        noise_level = 0.4 - stage_progress * 0.25        # 0.4 → 0.15
        noise_level *= (1.0 - 0.25 * schedule_factor)
        noise = np.random.normal(0, noise_level, seismic_condition.shape)

        geological_understanding = 0.3 + stage_progress * 0.4   # 0.3 → 0.7
        geological_understanding *= conditioning_strength

        prediction = (
            noise * 0.3 +
            karst_target * karst_regions * geological_understanding +
            seismic_condition * 0.2 * (1 - karst_regions)
        )

    # Stage 3 (epochs 7+): fine-tuning; high accuracy, low noise
    else:
        # Clip at 1.0 so runs longer than 10 epochs keep noise_level positive
        stage_progress = min((epoch - 7.0) / 2.0, 1.0)   # 0.0 → 1.0
        noise_level = 0.15 - stage_progress * 0.10       # 0.15 → 0.05
        noise_level *= (1.0 - 0.25 * schedule_factor)
        noise = np.random.normal(0, noise_level, seismic_condition.shape)

        accuracy = 0.7 + stage_progress * 0.25           # 0.7 → 0.95
        accuracy *= (0.85 + 0.3 * conditioning_strength)
        accuracy = np.clip(accuracy, 0.0, 1.0)

        uncertainty_mask = np.random.random(seismic_condition.shape) < 0.1

        prediction = (
            karst_target * accuracy +
            noise * 0.1 +
            seismic_condition * uncertainty_mask * 0.05
        )

    # Keep realistic value ranges and avoid exact zeros in residuals
    prediction = np.clip(prediction, 0, 1)
    prediction += np.random.uniform(-1e-9, 1e-9, prediction.shape)
    return prediction

def normalize_for_visualization(data: np.ndarray) -> np.ndarray:
    """Normalize geological data for W&B visualization with proper contrast."""
    data_min, data_max = np.min(data), np.max(data)
    if data_max == data_min:
        return np.full_like(data, 0.5)
    normalized = (data - data_min) / (data_max - data_min)
    # Apply gamma correction for better geological feature visibility
    return np.power(normalized, 0.7)

def simulate_forward_model(x_pred: np.ndarray) -> np.ndarray:
    """
    Simulates a forward model y=f(x) to generate a predicted condition y_pred.
    This step is crucial for evaluating the "cycle consistency" of our generator;
    we check if the forward model of our prediction (y_pred) matches the original input condition (y).
    In a real scenario, this would be a physics-based or learned model.
    Here, we simulate it by applying a slight blur and adding minor noise.
    """
    # Apply a simple blurring effect (convolution with a 3x3x3 mean kernel)
    kernel = np.ones((3, 3, 3)) / 27.0
    # Use scipy for the convolution if available; otherwise skip the blur
    try:
        from scipy.ndimage import convolve
        blurred = convolve(x_pred, kernel, mode='reflect')
    except ImportError:
        # Fallback without scipy: use the unblurred prediction
        blurred = x_pred
        
    # Add a small amount of random noise
    noise = np.random.normal(0, 0.02, x_pred.shape)
    y_pred = blurred + noise
    
    return np.clip(y_pred, 0, 1)
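`simulate_conditional_diffusion_progress` and `simulate_forward_model` both draw from NumPy's global random state, so calling one changes the noise the other sees. The following is a minimal sketch of an alternative, assuming you want each call to be reproducible on its own. It passes an explicit `np.random.Generator` and implements the 3×3×3 mean filter in NumPy, so the blur no longer depends on SciPy being installed. Padding with `mode="symmetric"` reproduces the half-sample edge handling of `scipy.ndimage.convolve(..., mode='reflect')`. `box_blur_3d` and `forward_model_with_rng` are hypothetical names.

```python
import numpy as np


def box_blur_3d(x: np.ndarray) -> np.ndarray:
    """3x3x3 mean filter with half-sample symmetric edges (like scipy's mode='reflect')."""
    padded = np.pad(x.astype(np.float64), 1, mode="symmetric")
    nz, ny, nx = x.shape
    out = np.zeros((nz, ny, nx), dtype=np.float64)
    # Sum the 27 shifted copies of the volume, then average
    for dz in range(3):
        for dy in range(3):
            for dx in range(3):
                out += padded[dz:dz + nz, dy:dy + ny, dx:dx + nx]
    return out / 27.0


def forward_model_with_rng(x_pred: np.ndarray, rng: np.random.Generator, noise_std: float = 0.02) -> np.ndarray:
    """Blur plus additive Gaussian noise, clipped to [0, 1], using only the supplied generator."""
    y_pred = box_blur_3d(x_pred) + rng.normal(0.0, noise_std, size=x_pred.shape)
    return np.clip(y_pred, 0.0, 1.0)


x = np.random.default_rng(1).random((8, 8, 8))
a = forward_model_with_rng(x, np.random.default_rng(42))
b = forward_model_with_rng(x, np.random.default_rng(42))
print(np.array_equal(a, b))  # → True
```

Because the generator is an argument, a validation loop could derive one generator per epoch (for example `np.random.default_rng(base_seed + epoch)`) without affecting any other random draws in the notebook.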

### Tier 1 Analysis: Global Sanity Check (Downsampled View)

**The Question:** "Is the model globally stable?" Before diving into details, we need a quick, high-level check of the entire volume.

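**The Technique:** A downsampled 3D comparison (e.g., 128³ → 64³) keeps rendering fast enough for quick, interactive inspection of the whole volume. The CigKarst patches used in this workshop are already 64³, so at the default target size the resampling step leaves their shape unchanged.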

Current implementation status
- The Plotly-based downsampled 3D comparison (`create_downsampled_comparison`) is defined but not currently logged due to stability/performance trade-offs in notebook environments. The logging lines are commented out.
- You can still assess global stability using the per-epoch 2D central-slice grid and (if enabled) PyVista Html renders.

Where to look in W&B
- Tables: `val_table/validation_table` (per-epoch rows with images and the well-log Plotly Html)
- Media: if enabled, 3D Html appears under `Viz3D/PyVista_Renders/*` (prediction, residual)

Why W&B here
- Downsampled overviews and detailed media are tracked in the same run context, so reviewers don’t switch tools
- Results remain versioned and comparable across runs and sweeps

Try it yourself
- [ ] Run a short training; open the run and expand a row in `val_table/validation_table`
- [ ] Compare early vs late epochs to see stability changes in the 2D grid (`Viz/diagnostics/central_slice_grid`)
- [ ] Optional: enable 3D (PyVista) via config toggles (`enable_3d=True`, `enable_high_fidelity_3d=True`)
- [ ] Future improvement: re-enable the Plotly 3D logging lines in the validation loop once environment constraints allow
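The Tier 1 view uses `skimage.transform.resize` with anti-aliasing. When the target size divides the input size evenly, as in 128³ → 64³, a block mean is a simple, dependency-free alternative that makes the averaging explicit. The following is a minimal sketch. `block_mean_downsample` is a hypothetical helper that the notebook does not call.

```python
import numpy as np


def block_mean_downsample(volume: np.ndarray, factor: int) -> np.ndarray:
    """Average non-overlapping factor^3 blocks; every axis must be divisible by factor."""
    nz, ny, nx = volume.shape
    if nz % factor or ny % factor or nx % factor:
        raise ValueError(f"shape {volume.shape} is not divisible by factor {factor}")
    # Split each axis into (blocks, factor) and average over the factor axes
    blocks = volume.reshape(nz // factor, factor, ny // factor, factor, nx // factor, factor)
    return blocks.mean(axis=(1, 3, 5))


vol = np.random.default_rng(0).random((128, 128, 128)).astype(np.float32)
small = block_mean_downsample(vol, 2)
print(small.shape)  # → (64, 64, 64)
```

Each output voxel is the exact mean of a 2×2×2 input block. Thin karst features are therefore attenuated rather than aliased, which is usually what a global sanity check needs.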

In [4]:
def create_downsampled_comparison(volumes: Dict[str, np.ndarray]) -> go.Figure:
    """Downsamples volumes and returns a side-by-side Plotly figure."""
    target_shape = (64, 64, 64)
    subplot_titles = [name.replace("_", " ") for name in volumes.keys()]
    fig = make_subplots(
        rows=1, cols=len(volumes),
        specs=[[{'type': 'volume'}] * len(volumes)],
        subplot_titles=subplot_titles
    )
    for i, (name, data) in enumerate(volumes.items()):
        resized_data = resize(data, target_shape, anti_aliasing=True)
        vmin, vmax = float(np.min(resized_data)), float(np.max(resized_data))
        if vmax == vmin:
            vmax = vmin + 1.0
        # go.Volume needs explicit x/y/z vertex coordinates, flattened to match `value`
        zz, yy, xx = np.mgrid[0:resized_data.shape[0], 0:resized_data.shape[1], 0:resized_data.shape[2]]
        fig.add_trace(
            go.Volume(
                x=xx.ravel(), y=yy.ravel(), z=zz.ravel(),
                value=resized_data.ravel(),
                cmin=vmin, cmax=vmax,
                isomin=vmin + (vmax - vmin) * 0.02,
                isomax=vmax - (vmax - vmin) * 0.02,
                opacity=0.15, opacityscale="uniform",
                surface_count=12,
                colorscale='RdBu' if ('Residual' in name or 'Error' in name) else 'viridis',
            ),
            row=1, col=i + 1
        )
    fig.update_layout(title_text="Downsampled 3D Comparison", height=420, margin=dict(t=50, b=10, l=10, r=10))
    return fig


### Tier 2 Analysis: Fine-Detail Inspection (Cropped View)

**The Question:** "Now that the global structure looks right, is the model accurately preserving fine details?"

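**The Technique:** We extract a smaller cube (e.g., 64³) from the center of the original, full‑resolution volumes. This allows inspection of a specific region without downsampling. With this workshop's 64³ patches, a 64³ crop covers the whole patch, so the technique pays off only on larger volumes.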

Current implementation status
- The Plotly-based cropped 3D comparison (`create_cropped_comparison`) is defined but its logging is currently commented out due to notebook environment stability/performance constraints.
- This function is not connected to the validation table. The table currently logs 2D slices and a well‑log Plotly Html produced elsewhere.

Where to look in W&B (available today)
- Media: `Viz/diagnostics/central_slice_grid` for 2D central‑slice grids (Y, X, X_pred, Residual)
- Media: `Viz3D/PyVista_Renders/*` for inline Html 3D renders (prediction, residual) if 3D is enabled

Why W&B here
- Detailed media and metrics live in the same run context, enabling SME review without switching tools
- Results remain versioned and comparable across runs and sweeps

Try it yourself
- [ ] Enable PyVista 3D via config (`enable_3d=True`, `enable_high_fidelity_3d=True`) and compare epochs 0 / mid / last
- [ ] Inspect `Viz/diagnostics/central_slice_grid` in the Media panel for texture/sharpness
- [ ] Future improvement: re‑enable the Plotly cropped 3D logging lines when environment constraints allow
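`create_cropped_comparison` falls back to the full volume when any axis is shorter than `crop_size`. The following is a minimal sketch of a per-axis alternative, assuming you would rather crop each axis independently. Axes shorter than the requested size are kept whole, and the remaining axes are cropped around their center using the same start index as the original helper. `center_crop` is a hypothetical helper.

```python
import numpy as np
from typing import Tuple


def center_crop(volume: np.ndarray, size: int) -> np.ndarray:
    """Crop a centered window of up to `size` voxels along each axis."""
    slices: Tuple[slice, ...] = tuple(
        slice(max(0, (dim - size) // 2), max(0, (dim - size) // 2) + min(size, dim))
        for dim in volume.shape
    )
    return volume[slices]


vol = np.zeros((100, 40, 64))
print(center_crop(vol, 64).shape)  # → (64, 40, 64)
```

This returns a view rather than a copy, so cropping is cheap even on full-resolution volumes.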

In [5]:
def create_cropped_comparison(volumes: Dict[str, np.ndarray]) -> go.Figure:
    """Crops the center of volumes and returns a side-by-side Plotly figure."""
    crop_size = 64
    subplot_titles = [name.replace("_", " ") for name in volumes.keys()]
    fig = make_subplots(
        rows=1, cols=len(volumes),
        specs=[[{'type': 'volume'}] * len(volumes)],
        subplot_titles=subplot_titles
    )
    for i, (name, data) in enumerate(volumes.items()):
        if all(dim >= crop_size for dim in data.shape):
            center = [dim // 2 for dim in data.shape]
            start = [c - crop_size // 2 for c in center]
            cropped_data = data[
                start[0]:start[0]+crop_size,
                start[1]:start[1]+crop_size,
                start[2]:start[2]+crop_size
            ]
        else:
            cropped_data = data
        vmin, vmax = float(np.min(cropped_data)), float(np.max(cropped_data))
        if vmax == vmin:
            vmax = vmin + 1.0
        # go.Volume needs explicit x/y/z vertex coordinates, flattened to match `value`
        zz, yy, xx = np.mgrid[0:cropped_data.shape[0], 0:cropped_data.shape[1], 0:cropped_data.shape[2]]
        fig.add_trace(
            go.Volume(
                x=xx.ravel(), y=yy.ravel(), z=zz.ravel(),
                value=cropped_data.ravel(),
                cmin=vmin, cmax=vmax,
                isomin=vmin + (vmax - vmin) * 0.02,
                isomax=vmax - (vmax - vmin) * 0.02,
                opacity=0.15, opacityscale="uniform",
                surface_count=12,
                colorscale='RdBu' if ('Residual' in name or 'Error' in name) else 'viridis',
            ),
            row=1, col=i + 1
        )
    fig.update_layout(title_text="Cropped (Full-Res) 3D Comparison", height=420, margin=dict(t=50, b=10, l=10, r=10))
    return fig

### Tier 3 Analysis: Debugging with the Residual (High-Fidelity Render)

**The Question:** "Where, specifically, is my model failing?" To improve the model, we must understand the nature and location of its errors.

**The Technique:** We use a high-performance library like PyVista (default) or ipyvolume (optional) to create the best possible render of the most critical volume: the **residual** (`AI_Condition_Y_pred - Input_Condition_Y`). This visualizes error structure in 3D.

**What We Look For:**
- **Systematic Errors:** Are large errors concentrated in a specific geological layer or region?
- **Error Polarity:** Over‑prediction (positive residual) vs under‑prediction (negative residual).
- **Interactivity:** Peel back layers via opacity to locate problematic horizons.

Current implementation status
- PyVista Html renders are generated (0 / mid / last epoch) when `enable_3d=True` and `enable_high_fidelity_3d=True`.
- ipyvolume can be enabled via `enable_ipyvolume=True` (optional).
- Plotly 3D volume logging is intentionally disabled in this workshop for stability/performance in notebooks.

Where to look in W&B
- Media: `Viz3D/PyVista_Renders/*` (AI prediction and residual 3D Html)
- Media: `Viz/diagnostics/central_slice_grid` (compact 2D context per epoch)

Why W&B here
- Inline 3D Html is logged alongside metrics/tables, enabling SME review in the run context
- Results are versioned, comparable across runs/sweeps, and traceable via lineage

Try it yourself
- [ ] Ensure `enable_3d=True` and `enable_high_fidelity_3d=True`; compare epochs 0 / mid / last
- [ ] Toggle `enable_ipyvolume=True` to compare backends
- [ ] Use the residual render to target misfit regions for data curation or hyperparameter changes
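The questions above about layer concentration and error polarity can also be answered numerically, which is useful to log alongside the renders. The following is a minimal sketch. It assumes the residual volume is indexed `(z, y, x)` with depth along axis 0, and that a small tolerance `tol` separates over-prediction from under-prediction. `residual_summary` and `tol` are hypothetical.

```python
import numpy as np
from typing import Dict


def residual_summary(residual: np.ndarray, depth_axis: int = 0, tol: float = 1e-3) -> Dict[str, object]:
    """Summarize error polarity and where (in depth) the error concentrates."""
    other_axes = tuple(a for a in range(residual.ndim) if a != depth_axis)
    mae_per_layer = np.mean(np.abs(residual), axis=other_axes)
    return {
        "frac_over": float(np.mean(residual > tol)),    # positive residual: over-prediction
        "frac_under": float(np.mean(residual < -tol)),  # negative residual: under-prediction
        "mae_per_layer": mae_per_layer,
        "worst_layer": int(np.argmax(mae_per_layer)),
    }


res = np.zeros((8, 4, 4))
res[5] = 0.3  # a systematic over-prediction confined to one depth layer
print(residual_summary(res)["worst_layer"])  # → 5
```

A single layer with a much larger mean absolute residual than its neighbors is the numerical signature of the "systematic error in a specific layer" case described above.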

In [6]:
def create_pyvista_render(volume: np.ndarray, title: str) -> str:
    """
    Render a single volume with PyVista and return HTML.
    - Downsamples large volumes so that no axis exceeds 96 voxels (speeds up export)
    - Uses pv.ImageData (recent PyVista releases; older releases call this class UniformGrid)
    - Expects the input array to be indexed as (z, y, x)
    """
    html_filename = f"{title.replace(' ', '_').replace('(', '').replace(')', '')}_temp.html"
    plotter = None
    try:
        # Downsample if needed; rounding the scale up keeps every axis <= max_dim voxels
        max_dim = 96
        dz, dy, dx = map(int, volume.shape)
        scale = max(1, int(np.ceil(max(dz, dy, dx) / max_dim)))
        target = (max(1, dz // scale), max(1, dy // scale), max(1, dx // scale))
        vol_ds = volume if (dz, dy, dx) == target else resize(volume, target, anti_aliasing=True)
        vol_ds = np.asarray(vol_ds, dtype=np.float32)

        # Build an ImageData grid; VTK takes dimensions in (x, y, z) order
        nz, ny, nx = vol_ds.shape  # z, y, x
        grid = pv.ImageData(dimensions=(nx, ny, nz))
        grid.spacing = (1.0, 1.0, 1.0)
        grid.origin = (0.0, 0.0, 0.0)
        # VTK numbers points with x varying fastest; a C-order ravel of a (z, y, x) array matches that
        grid.point_data["values"] = vol_ds.ravel(order="C")

        plotter = pv.Plotter(off_screen=True, notebook=True, window_size=(800, 600))
        plotter.add_volume(
            grid,
            scalars="values",
            cmap="RdBu" if ("Residual" in title or "Error" in title) else "viridis",
            opacity="sigmoid",
            shade=True,
        )
        plotter.add_axes()
        plotter.export_html(html_filename)
        return Path(html_filename).read_text()
    except Exception as e:
        return f"<p>PyVista rendering failed: {e}</p>"
    finally:
        # Always release the render window and remove the temporary HTML file
        try:
            if plotter is not None:
                plotter.close()
            if os.path.exists(html_filename):
                os.remove(html_filename)
        except Exception:
            pass
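`create_pyvista_render` builds the grid with `dimensions=(nx, ny, nz)` from an array its comments describe as `(z, y, x)`. VTK numbers image-data points with x varying fastest, so the point id is `x + nx * (y + ny * z)`. The self-contained check below shows which NumPy flattening order produces that numbering for a `(z, y, x)` array. It uses a tiny 2×3×4 example rather than real data.

```python
import numpy as np

nz, ny, nx = 2, 3, 4
vol = np.arange(nz * ny * nx).reshape(nz, ny, nx)  # vol[z, y, x]


def vtk_point_id(x: int, y: int, z: int) -> int:
    """VTK image-data point numbering: x varies fastest, then y, then z."""
    return x + nx * (y + ny * z)


def matches_vtk_order(flat: np.ndarray) -> bool:
    """True if flat[vtk_point_id(x, y, z)] == vol[z, y, x] for every voxel."""
    return all(
        flat[vtk_point_id(x, y, z)] == vol[z, y, x]
        for z in range(nz) for y in range(ny) for x in range(nx)
    )


print(matches_vtk_order(vol.ravel(order="C")))  # → True
print(matches_vtk_order(vol.ravel(order="F")))  # → False
```

Under this layout, a `(z, y, x)` volume should be flattened in C order before it is assigned to `point_data`. Fortran order would instead suit an array indexed `(x, y, z)`. A mismatch does not raise an error on a cube; it silently swaps the depth and x axes in the render.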

## 3. The Core Training Function with W&B Integration

This function encapsulates the entire training and evaluation workflow. It serves as a blueprint for a reproducible, auditable, and transparent ML lifecycle. Each section is clearly marked to show how a specific W&B feature is used to track, version, and evaluate the model, turning a standard training script into an enterprise-ready system of record.

Notation and channel scope  
X: ground-truth geology (karst), X_pred: generated prediction  
Y: input seismic condition, Y_pred = f(X_pred)  
Residual = Y_pred − Y  

Single-channel visuals: This demo renders a single channel for stability and speed. Multi-channel rendering/metrics are feasible (loop over channels and log per-channel), but intentionally out of scope here.
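In symbols, the notation above defines a condition-space consistency check. The mean squared residual over all N voxels is one natural scalar summary of that check. It is shown here as an illustrative choice and is not necessarily a metric the training loop logs. For one 64³ CigKarst patch, N = 262,144.

```latex
Y_{\text{pred}} = f\left(X_{\text{pred}}\right), \qquad
R = Y_{\text{pred}} - Y, \qquad
\mathrm{MSE}_{\text{cond}} = \frac{1}{N} \sum_{v=1}^{N} R_v^{2}
```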

- What this does
  - Initializes a W&B run with `ENTITY`/`PROJECT`, config, tags
  - Declares and downloads a versioned dataset artifact (lineage: dataset → run)
  - Simulates training and logs live metrics; correlates with system telemetry
  - Per-epoch validation: logs a rich `val_table/validation_table` (images + well‑log Html) and, if enabled, 3D Html renders (0/mid/last)
  - Tracks best validation loss and versions a model checkpoint as an artifact; links it in the Model Registry with alias (e.g., `staging`)
  - Finishes the run cleanly

- Where to look in W&B
  - Run Overview: live metrics and config
  - System: CPU/GPU/memory correlated with training metrics
  - Artifacts: dataset (`CigKarst`, resolved from the `latest` alias when the run starts) and model checkpoint (`conditional-diffusion-checkpoint`)
  - Media: `Viz/diagnostics/central_slice_grid`, `Viz3D/PyVista_Renders/*` (if 3D enabled)
  - Tables: `val_table/validation_table` (per-epoch rows with images + well‑log Plotly Html)
  - Model Registry: model registered with an alias (e.g., `staging`) for retrieval/CI

- Why W&B here
  - End-to-end lineage and governance: dataset → runs → models, without stitching tools
  - Rich media (images + Html 3D) and tables live in the same run context for SME review
  - Model Registry aliases provide stable pointers for promotion and CI/CD workflows

- Try it yourself
  - [ ] Confirm `ENTITY`/`PROJECT` are set; run the cell
  - [ ] Open the run link; monitor live metrics and the System tab
  - [ ] Expand a row in `val_table/validation_table` to inspect images and well‑log Html
  - [ ] If 3D is enabled, open `Viz3D/PyVista_Renders/*` (0/mid/last)
  - [ ] Find the model artifact and verify the registry alias (e.g., `staging`)
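The step that tracks the best validation loss and then versions and links the checkpoint can be sketched as follows. This is a minimal sketch and not the notebook's own code. `BestCheckpointTracker` and `publish_checkpoint` are hypothetical names, and you must supply the registry `target_path` yourself. The W&B calls are limited to `wandb.Artifact`, `Artifact.add_file`, `run.log_artifact` and `run.link_artifact`.

```python
import math
from pathlib import Path
from typing import Any, Dict, List, Optional


class BestCheckpointTracker:
    """Remember the lowest validation loss seen so far (lower is better)."""

    def __init__(self) -> None:
        self.best_loss = math.inf
        self.best_epoch: Optional[int] = None

    def update(self, epoch: int, val_loss: float) -> bool:
        """Return True when val_loss improves on the best value so far."""
        if val_loss < self.best_loss:
            self.best_loss, self.best_epoch = val_loss, epoch
            return True
        return False


def publish_checkpoint(run: Any, ckpt_path: Path, name: str, target_path: str,
                       metadata: Dict[str, Any], aliases: List[str]) -> None:
    """Version a checkpoint file as a W&B artifact and link it into a registry collection."""
    import wandb  # imported lazily so the tracker above works without W&B installed

    artifact = wandb.Artifact(name, type="model", metadata=metadata)
    artifact.add_file(str(ckpt_path))
    run.log_artifact(artifact)
    run.link_artifact(artifact, target_path, aliases=aliases)


tracker = BestCheckpointTracker()
improved = [tracker.update(e, loss) for e, loss in enumerate([0.9, 0.7, 0.8, 0.6])]
print(improved, tracker.best_epoch)  # → [True, True, False, True] 3
```

Only epochs where `update` returns True would write and publish a checkpoint. Passing `aliases=["staging"]` would mirror the registry alias described above.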

### About the central slice grid (per‑epoch 2D diagnostic)

- What it is
  - A compact 2×2 view from the central slice of each 3D volume:
    - Top: Y (seismic input), X (ground‑truth geology)
    - Bottom: X_pred (prediction), Residual (X_pred − X) normalized around 0
- Why it matters
  - Fast, per‑epoch qualitative check of conditioning consistency and structural fidelity without rendering full 3D.
  - Lets SMEs see whether structure emerges and residuals shrink/localize as training progresses.
- How to read it
  - Compare X vs X_pred for texture and boundaries; the residual should decrease and concentrate around true misfit regions.
  - Large, structured residuals suggest systematic errors (bias, misalignment, or forward‑model mismatch).
- Tips
  - If residuals are uniformly high: revisit conditioning strength or forward‑model assumptions.
  - If checkerboard artifacts appear: check upsampling/architecture or normalization.
  - For localized issues, open the corresponding PyVista Html render under `Viz3D/PyVista_Renders/*`.
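The following is a minimal sketch of how such a 2×2 mosaic could be assembled with NumPy, assuming `(z, y, x)` volumes with the central slice taken along axis 0. The residual panel is scaled symmetrically so that zero error lands at mid-gray (0.5), which is one reading of "normalized around 0". The residual is passed in as an argument, so the sketch works with whichever residual definition you use. `central_slice_grid` is a hypothetical helper, and the notebook's own implementation may differ.

```python
import numpy as np


def _minmax(img: np.ndarray) -> np.ndarray:
    """Scale an image to [0, 1]; constant images map to 0.5."""
    lo, hi = float(img.min()), float(img.max())
    return np.full_like(img, 0.5, dtype=np.float64) if hi == lo else (img - lo) / (hi - lo)


def central_slice_grid(y: np.ndarray, x: np.ndarray, x_pred: np.ndarray, residual: np.ndarray) -> np.ndarray:
    """Return a 2x2 mosaic [[Y, X], [X_pred, Residual]] of central slices, scaled to [0, 1]."""
    mid = y.shape[0] // 2
    top = np.hstack([_minmax(y[mid]), _minmax(x[mid])])
    # Symmetric scaling: zero error -> 0.5, +max|residual| -> 1.0, -max|residual| -> 0.0
    r = residual[mid].astype(np.float64)
    scale = float(np.max(np.abs(r))) or 1.0
    bottom = np.hstack([_minmax(x_pred[mid]), 0.5 + 0.5 * r / scale])
    return np.vstack([top, bottom])


vols = np.random.default_rng(0).random((4, 8, 8, 8))
grid = central_slice_grid(vols[0], vols[1], vols[2], vols[2] - vols[1])
print(grid.shape)  # → (16, 16)
```

The resulting array could be wrapped in `wandb.Image(...)` for logging. With symmetric scaling, a residual panel that drifts toward uniform mid-gray across epochs shows the error shrinking.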

In [7]:
def train_conditional_diffusion(config=None):
    """
    Train a conditional diffusion model for geological structure generation.
    This function is instrumented with W&B to track the entire lifecycle.
    """
    # ===================================================================
    # 1. W&B SETUP: Initialize Run and Centralize Configuration
    #    - wandb.init() starts a new run to track this experiment.
    #    - wandb.config pulls hyperparameter settings, making them available to sweeps.
    # ===================================================================
 

    # Initialize a new W&B run
    run = wandb.init(
        entity=ENTITY,
        project=PROJECT,
        config=(config or example_config), # use config if it exists, otherwise use example_config (useful for sweeps)
        # job_type and tags are powerful organizational tools that make it easy
        # to filter, group, and compare runs across a large project in the W&B UI
        job_type="training",
        tags=["conditional-diffusion"]
    )

    # Use the config from W&B (this allows sweeps to override values)
    config = wandb.config  # dict-like view of the hyperparameters, including any sweep overrides
    np.random.seed(config.get("seed", 42))  # set global seed once

    # ===================================================================
    # 2. W&B ARTIFACTS: Versioning and Loading the Dataset
    #    - run.use_artifact() declares a dependency on a specific dataset version.
    #    - This creates a lineage graph, giving us a full audit trail.
    # ===================================================================
    # Load the CigKarst dataset from the W&B Registry.
    # run.use_artifact() links this run to the exact dataset version it consumed;
    # W&B records that link in the data-to-model lineage graph used for auditing and debugging.
    # The `latest` alias resolves when the run starts; pin a version (e.g. `:v0`) for strict reproducibility.
    artifact = run.use_artifact('wandb-registry-dataset/CigKarst:latest', type='dataset')
    artifact_dir = artifact.download()

    # Load geological samples from the dataset
    geological_samples = []

    metadata_path = Path(artifact_dir) / "dataset_metadata.json"

    if metadata_path.exists():
        with open(metadata_path, "r", encoding="utf-8") as f:
            meta = json.load(f)
        samples = meta.get("samples", [])
        total_samples = int(meta.get("total_samples", 0))
        for sample_info in samples:
            seismic_file = Path(artifact_dir) / sample_info["seismic_file"]
            karst_file = Path(artifact_dir) / sample_info["karst_file"]
            if seismic_file.exists() and karst_file.exists():
                seismic_data = np.load(seismic_file)
                karst_data = np.load(karst_file)
                # Sanity checks for determinism and structure
                assert seismic_data.dtype == np.float32 and karst_data.dtype == np.float32, (seismic_data.dtype, karst_data.dtype)
                assert seismic_data.shape == (64, 64, 64) and karst_data.shape == (64, 64, 64), (seismic_data.shape, karst_data.shape)
                sample = {
                    "seismic": seismic_data,
                    "karst": karst_data,
                    "sample_id": sample_info["sample_id"],
                    "coordinates": sample_info["coordinates"],
                    "source": sample_info["source_volume"],
                }
                geological_samples.append(sample)
        assert len(geological_samples) == total_samples, (len(geological_samples), total_samples)
        print(f"Loaded {len(geological_samples)} real geological samples from CigKarst")
    else:
        print("Metadata file not found in artifact")
        run.finish()
        return None

    # Select a fixed set of samples for consistent validation across runs
    good_sample_ids = ["volume_0_patch_1", "volume_0_patch_3", "volume_1_patch_0"]
    fixed_samples = [s for s in geological_samples if s['sample_id'] in good_sample_ids]
    if len(fixed_samples) != len(good_sample_ids):
        raise ValueError(f"Error: Could not find all specified good samples.")
    print(f"Selected {len(fixed_samples)} pre-validated samples for tracking.")

    # Track best model
    best_validation_loss = float('inf')

    # ===================================================================
    # 3. W&B TRAINING LOOP: Logging Live Metrics
    #    - run.log() is called inside the loop to stream metrics in real-time.
    #    - This allows us to monitor model performance and system utilization live from the dashboard.
    # ===================================================================
    for epoch in range(config.get("epochs", 10)):
        print(f"Epoch {epoch}: Simulating conditional diffusion training...")
        batches_per_epoch = config.get("batches_per_epoch", 20)
        
        for batch in range(batches_per_epoch):
            step = epoch * batches_per_epoch + batch
            # Simulate training progress and calculate synthetic losses
            loss_kwargs = dict(
                schedule=str(config.get("noise_schedule", "cosine")),
                conditioning_strength=float(config.get("conditioning_strength", 0.8)),
                learning_rate=float(config.get("learning_rate", 1e-3)),
                batch_size=int(config.get("batch_size", 4)),
                base_seed=int(config.get("seed", 42)),
            )
            noise_pred_loss = generate_synthetic_loss(step, **loss_kwargs)
            reconstruction_proxy = generate_synthetic_loss(int(step * 1.5), **loss_kwargs)
            total_loss = 0.6 * noise_pred_loss + 0.4 * reconstruction_proxy

            # Training-time metrics. These are synthetic proxies: reconstruction_mse approximates
            # X_pred vs X; true condition consistency (Y_pred vs Y) is computed at validation.
            log_dict = {
                "train/noise_prediction_loss": noise_pred_loss,
                "train/reconstruction_mse": reconstruction_proxy,  # X_pred vs X (synthetic proxy)
                "train/total_loss": total_loss,
                # Logged step decay: the learning rate is multiplied by 0.95 every 3 epochs
                "train/learning_rate": config.get("learning_rate", 1e-3) * (0.95 ** (epoch // 3)),
            }
            
            # One run.log call per training step. W&B also captures system metrics
            # (CPU/GPU utilization, memory) alongside these custom metrics, so
            # performance bottlenecks can be diagnosed in real time.
            run.log(log_dict)

        # ===================================================================
        # 4. W&B VALIDATION: Logging Rich Media and Tables
        #    - At the end of each epoch, we log qualitative results.
        #    - wandb.Image() for 2D slice comparisons.
        #    - wandb.Html() for interactive 3D renders (PyVista, optional ipyvolume).
        #    - wandb.Table() to create structured, sortable tables of predictions.
        # ===================================================================
        
        
        # wandb.Table creates an interactive table in the UI, so results can be
        # sorted, filtered, and compared (images, plots, metrics) side by side
        # across runs in a single view.
        validation_table = wandb.Table(columns=[
            "epoch", "sample_id", "seismic_input", "ground_truth",
            "prediction", "residual_map", "well_log_comparison",
            "reconstruction_mse", "condition_consistency_mse", "ssim_score", "log_correlation",
        ])

        total_reconstruction_mse = 0.0
        total_condition_consistency_mse = 0.0

        # Visualizations for the fixed validation samples. Each view gets its own
        # stable key in this dictionary, which is merged into the epoch-level log.
        visualizations_log = {}
        
        for sample in fixed_samples:
            
            # Generate the necessary volumes for this specific sample
            prediction_3d = simulate_conditional_diffusion_progress(
                sample['seismic'],
                sample['karst'],
                epoch,
                total_epochs=int(config.get("epochs", 10)),
                conditioning_strength=float(config.get("conditioning_strength", 0.8)),
                noise_schedule=str(config.get("noise_schedule", "cosine")),
                base_seed=int(config.get("seed", 42)),
            )

            y_pred_3d = simulate_forward_model(prediction_3d)
            volumes_for_viz = {
                "Input_Seismic_(Y)": sample['seismic'],
                "Ground_Truth_Karst_(X)": sample['karst'],
                "Predicted_Karst_(X_pred)": prediction_3d,
                "Forward_Model_Seismic_(Y_pred)": y_pred_3d,
                "Model_Error_(Residual)": y_pred_3d - sample['seismic']
            }
            
            # 3D renders: first, middle, and last epoch only, for the first fixed sample only.
            # Single-channel; PyVista is on by default, ipyvolume is off by default.
            epochs_total = int(config.get("epochs", 10))
            mid_epoch = (epochs_total - 1) // 2
            should_log_3d = bool(config.get("enable_3d", True)) and (epoch in [0, mid_epoch, epochs_total - 1])

            if should_log_3d and sample['sample_id'] == fixed_samples[0]['sample_id']:
                # Plotly comparisons are disabled for now; uncomment to re-enable them.
                # plotly_views = {
                #     k: volumes_for_viz[k]
                #     for k in ("Predicted_Karst_(X_pred)", "Forward_Model_Seismic_(Y_pred)", "Model_Error_(Residual)")
                # }
                # visualizations_log["3D/Tier1_Downsampled"] = wandb.Plotly(create_downsampled_comparison(plotly_views))
                # visualizations_log["3D/Tier2_Cropped"] = wandb.Plotly(create_cropped_comparison(plotly_views))
                if bool(config.get("enable_high_fidelity_3d", True)):
                    visualizations_log["Viz3D/PyVista_Renders/AI_Prediction"] = wandb.Html(
                        create_pyvista_render(volumes_for_viz["Predicted_Karst_(X_pred)"], "PV_Prediction")
                    )
                    visualizations_log["Viz3D/PyVista_Renders/Model_Error"] = wandb.Html(
                        create_pyvista_render(volumes_for_viz["Model_Error_(Residual)"], "PV_Error")
                    )

                if bool(config.get("enable_ipyvolume", False)):
                    visualizations_log["Viz3D/ipyvolume_Renders/AI_Prediction"] = wandb.Html(
                        create_ipyvolume_render(volumes_for_viz["Predicted_Karst_(X_pred)"], "IPV_Prediction")
                    )
                    visualizations_log["Viz3D/ipyvolume_Renders/Model_Error"] = wandb.Html(
                        create_ipyvolume_render(volumes_for_viz["Model_Error_(Residual)"], "IPV_Error")
                    )

            # Prepare data for logging (2D slices, logs, etc.)
            # Single-channel rendering: we use the scalar volume (or channel 0) for slices.
            # Multi-channel is feasible later by indexing and looping channels.
            slice_idx = sample['seismic'].shape[2] // 2
            seismic_slice = sample['seismic'][:, :, slice_idx]
            gt_slice = sample['karst'][:, :, slice_idx]
            pred_slice = prediction_3d[:, :, slice_idx]
            residual_slice = pred_slice - gt_slice
            max_abs_val = np.max(np.abs(residual_slice))
            residual_norm = (residual_slice + max_abs_val) / (2 * max_abs_val) if max_abs_val > 0 else np.zeros_like(residual_slice)

            # Per-epoch 2D diagnostic: compact 2x2 central-slice grid for one representative sample (single-channel).
            if sample['sample_id'] == fixed_samples[0]['sample_id']:
                grid_top = np.hstack([
                    normalize_for_visualization(seismic_slice),  # Y
                    normalize_for_visualization(gt_slice)        # X
                ])
                grid_bottom = np.hstack([
                    normalize_for_visualization(pred_slice),     # X_pred
                    residual_norm                                # Residual (already normalized to 0..1 around 0)
                ])
                grid_img = np.vstack([grid_top, grid_bottom])

                # Stable key (no epoch in the key): W&B will store history by step
                visualizations_log["Viz/diagnostics/central_slice_grid"] = wandb.Image(
                    grid_img,
                    caption="Top: Y, X | Bottom: X_pred, Residual"
                )

                # Residual_Y (Y_pred - Y) at central x-slice (sagittal)
                y_pred_slice = y_pred_3d[:, :, slice_idx]
                residual_y_slice = y_pred_slice - seismic_slice
                max_abs_val_y = float(np.max(np.abs(residual_y_slice)))
                residual_y_norm = (residual_y_slice + max_abs_val_y) / (2.0 * max_abs_val_y) if max_abs_val_y > 0 else np.zeros_like(residual_y_slice)
                visualizations_log["Viz/diagnostics/residual_Yslice"] = wandb.Image(residual_y_norm, caption="Residual Y (Y_pred - Y), central sagittal slice")

                # Orientation indices (volumes are indexed [z, y, x])
                z_mid = sample['seismic'].shape[0] // 2
                y_mid = sample['seismic'].shape[1] // 2
                x_mid = slice_idx  # already computed

                # Helper to build a 2x2 grid for a given slicer
                def _grid_from_slices(y2d, x2d, xpred2d, label=""):
                    res = xpred2d - x2d
                    m = float(np.max(np.abs(res)))
                    res_norm = (res + m) / (2.0 * m) if m > 0 else np.zeros_like(res)
                    top = np.hstack([normalize_for_visualization(y2d), normalize_for_visualization(x2d)])
                    bottom = np.hstack([normalize_for_visualization(xpred2d), res_norm])
                    return np.vstack([top, bottom])

                # Axial (constant z)
                axial_grid = _grid_from_slices(
                    sample['seismic'][z_mid, :, :],
                    sample['karst'][z_mid, :, :],
                    prediction_3d[z_mid, :, :],
                    "axial",
                )
                visualizations_log["Viz/slices/axial_central"] = wandb.Image(axial_grid, caption="Axial: Y, X | X_pred, Residual_X")

                # Coronal (constant y)
                coronal_grid = _grid_from_slices(
                    sample['seismic'][:, y_mid, :],
                    sample['karst'][:, y_mid, :],
                    prediction_3d[:, y_mid, :],
                    "coronal",
                )
                visualizations_log["Viz/slices/coronal_central"] = wandb.Image(coronal_grid, caption="Coronal: Y, X | X_pred, Residual_X")

                # Sagittal (constant x) — complements the existing central grid
                sagittal_grid = _grid_from_slices(
                    sample['seismic'][:, :, x_mid],
                    sample['karst'][:, :, x_mid],
                    prediction_3d[:, :, x_mid],
                    "sagittal",
                )
                visualizations_log["Viz/slices/sagittal_central"] = wandb.Image(sagittal_grid, caption="Sagittal: Y, X | X_pred, Residual_X")

                # XZ section at y_mid: Residual_Y (Y_pred - Y)
                xz_residual_y = y_pred_3d[:, y_mid, :] - sample['seismic'][:, y_mid, :]
                m_xz = float(np.max(np.abs(xz_residual_y)))
                xz_residual_y_norm = (xz_residual_y + m_xz) / (2.0 * m_xz) if m_xz > 0 else np.zeros_like(xz_residual_y)
                visualizations_log["Viz/sections/xz_central"] = wandb.Image(xz_residual_y_norm, caption="XZ Residual Y (Y_pred - Y) @ y_mid")

                # YZ section at x_mid: Residual_Y (Y_pred - Y)
                yz_residual_y = y_pred_3d[:, :, x_mid] - sample['seismic'][:, :, x_mid]
                m_yz = float(np.max(np.abs(yz_residual_y)))
                yz_residual_y_norm = (yz_residual_y + m_yz) / (2.0 * m_yz) if m_yz > 0 else np.zeros_like(yz_residual_y)
                visualizations_log["Viz/sections/yz_central"] = wandb.Image(yz_residual_y_norm, caption="YZ Residual Y (Y_pred - Y) @ x_mid")

                # Axial montage: 5 evenly spaced z-slices of Residual_X (X_pred - X)
                nz = prediction_3d.shape[0]
                z_indices = [int(round(p * (nz - 1))) for p in [0.2, 0.35, 0.5, 0.65, 0.8]]

                def _norm_residual_x_at_z(z_idx):
                    res = prediction_3d[z_idx, :, :] - sample['karst'][z_idx, :, :]
                    m = float(np.max(np.abs(res)))
                    return (res + m) / (2.0 * m) if m > 0 else np.zeros_like(res)

                tiles = [_norm_residual_x_at_z(z) for z in z_indices]
                axial_montage = np.hstack(tiles)  # single wide image for scrubber
                visualizations_log["Viz/slices/axial_tiled_5z"] = wandb.Image(axial_montage, caption=f"Residual_X axial montage z={z_indices}")
            else:
                # Same grid under a per-sample stable key
                grid_top = np.hstack([
                    normalize_for_visualization(seismic_slice),
                    normalize_for_visualization(gt_slice)
                ])
                grid_bottom = np.hstack([
                    normalize_for_visualization(pred_slice),
                    residual_norm
                ])
                grid_img = np.vstack([grid_top, grid_bottom])

                visualizations_log[f"Viz/diagnostics/central_slice_grid/sample:{sample['sample_id']}"] = wandb.Image(
                    grid_img, caption=f"Top: Y, X | Bottom: X_pred, Residual_X (sample {sample['sample_id']})"
                )
            
            # Metrics: reconstruction MSE and SSIM use the central slice (data_range=1.0 assumes values in [0, 1]);
            # condition consistency compares Y_pred with Y over the full 3D volume
            reconstruction_mse = float(np.mean((pred_slice - gt_slice) ** 2))
            condition_consistency_mse = float(np.mean((y_pred_3d - sample['seismic']) ** 2))
            total_reconstruction_mse += reconstruction_mse
            total_condition_consistency_mse += condition_consistency_mse
            ssim_score = ssim(gt_slice, pred_slice, data_range=1.0)

            # Pseudo well log: vertical trace (axis 0) at the lateral centre, depth axis in steps of 25
            well_log_depth = np.arange(prediction_3d.shape[0]) * 25
            well_iy, well_ix = prediction_3d.shape[1] // 2, prediction_3d.shape[2] // 2
            gt_well_log = sample['karst'][:, well_iy, well_ix]
            pred_well_log = prediction_3d[:, well_iy, well_ix]
            # Pearson correlation is undefined for a constant log, so fall back to 0.0
            log_correlation = np.corrcoef(gt_well_log, pred_well_log)[0, 1] if np.std(gt_well_log) > 0 and np.std(pred_well_log) > 0 else 0.0
            
            # Add one row per sample to the validation table
            validation_table.add_data(
                epoch,
                sample['sample_id'],
                wandb.Image(normalize_for_visualization(seismic_slice)),
                wandb.Image(normalize_for_visualization(gt_slice)),
                wandb.Image(normalize_for_visualization(pred_slice)),
                wandb.Image(residual_norm, caption="Residual Map"),
                plot_well_log_comparison(gt_well_log, pred_well_log, well_log_depth),
                reconstruction_mse,                 # X_pred vs X (central slice)
                condition_consistency_mse,          # Y_pred vs Y (full 3D)
                ssim_score,
                log_correlation,
            )

        # Log all epoch-level validation data in a single call
        avg_recon = total_reconstruction_mse / len(fixed_samples)
        avg_cond = total_condition_consistency_mse / len(fixed_samples)

        # Validation loss tied to the actual validation metrics
        val_total_loss = 0.45 * avg_recon + 0.55 * avg_cond
        epoch_log_data = {
            "val/total_loss": val_total_loss,
            "val/avg_reconstruction_mse": avg_recon,
            "val/avg_condition_consistency_mse": avg_cond,
            "val_table/validation_table": validation_table,
        }
        epoch_log_data.update(visualizations_log)  # add the 2D slice views and any 3D renders
        run.log(epoch_log_data)

        # ===================================================================
        # 5. W&B MODEL REGISTRY: Versioning Models as Artifacts
        #    - We check if the model has improved and, if so, save it.
        #    - wandb.Artifact() creates a versioned package of model files and metadata.
        #    - run.link_artifact() registers the model in the W&B Model Registry,
        #      assigning it an alias like "best" or "staging" for easy retrieval.
        # ===================================================================
        if val_total_loss < best_validation_loss:
            best_validation_loss = val_total_loss
            print(f"New best model at epoch {epoch}! Validation loss: {val_total_loss:.4f}")
            
            # Build a concise model card as markdown; this will render as the artifact/model “card”
            model_card_md = dedent(f"""\
            # Model Card: Conditional Diffusion (Simulated)

            ## Overview
            - Task: Geological structure generation (conditional diffusion)
            - Notation: X (GT karst), X_pred (generated), Y (seismic), Y_pred = f(X_pred), Residual = Y_pred − Y
            - Demo visuals: single-channel; 3D at 0/mid/last; 2D grid every epoch

            ## Data & Lineage
            - Dataset artifact: `{artifact.name}`
            - Validation samples: {len(fixed_samples)} (pre-validated)

            ## Training Summary (this version)
            - Epoch: {epoch}
            - val/total_loss: {val_total_loss:.6f}
            - val/avg_reconstruction_mse: {avg_recon:.6f}
            - val/avg_condition_consistency_mse: {avg_cond:.6f}

            ## Key Config
            - epochs={config.get('epochs')}, batches_per_epoch={config.get('batches_per_epoch')}
            - learning_rate={config.get('learning_rate')}, batch_size={config.get('batch_size')}
            - noise_schedule={config.get('noise_schedule')}, conditioning_strength={config.get('conditioning_strength')}
            - seed={config.get('seed')}

            ## Limitations
            - Simulated training and losses for reproducible demo
            - Embedded 3D is downsampled for responsiveness
            """).strip()
            
            checkpoint_artifact = wandb.Artifact(
                name="conditional-diffusion-checkpoint",
                type="model",
                description=model_card_md,  # <- model card as markdown
                metadata={
                    "epoch": epoch,
                    "validation_loss": round(val_total_loss, 4),
                    "dataset_artifact": artifact.name
                }
            )
            
            # Placeholder checkpoint: the model is simulated, so a small text file stands in for real weights
            checkpoint_path = f"best_model_epoch_{epoch}.pth"
            Path(checkpoint_path).write_text(f"Best model checkpoint at epoch {epoch}")
            checkpoint_artifact.add_file(checkpoint_path)
            
            # Log the artifact and link it to the registry
            logged_artifact = run.log_artifact(checkpoint_artifact, aliases=["best"])
           
            # run.link_artifact() registers the model in the W&B Model Registry. Aliases
            # like "staging" or "production" create pointers for CI/CD systems, automating
            # the path from training to deployment.
            run.link_artifact(
                artifact=logged_artifact,
                target_path="wandb-registry-model/conditional-diffusion",
                aliases=["staging"]
            )
            os.remove(checkpoint_path)

    # ===================================================================
    # 6. W&B CLEANUP: Finalizing the Run
    #    - run.finish() marks the run as complete and uploads any remaining data.
    # ===================================================================
    run.finish()
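
The training function applies one normalization to every residual image: the signed error is shifted by its largest absolute value and divided by twice that value, so zero error lands at mid-gray (0.5) and positive and negative errors share one symmetric scale. The sketch below isolates that mapping and the 2x2 diagnostic grid. The helper names `normalize_residual_symmetric`, `minmax`, and `central_slice_grid` are hypothetical, and plain min-max scaling is an assumed stand-in for the notebook's `normalize_for_visualization`, whose body is not shown in this section.

```python
import numpy as np


def normalize_residual_symmetric(res: np.ndarray) -> np.ndarray:
    """Map a signed residual into [0, 1] so that 0 -> 0.5 (the notebook's formula)."""
    m = float(np.max(np.abs(res)))
    if m == 0.0:
        # No error anywhere: the notebook returns an all-zero image in this case.
        return np.zeros_like(res, dtype=float)
    return (res + m) / (2.0 * m)


def minmax(img: np.ndarray) -> np.ndarray:
    """Assumed stand-in for normalize_for_visualization: plain min-max scaling."""
    lo, hi = float(img.min()), float(img.max())
    return (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img, dtype=float)


def central_slice_grid(y: np.ndarray, x: np.ndarray, x_pred: np.ndarray) -> np.ndarray:
    """2x2 grid of central sagittal slices from volumes indexed [z, y, x].

    Top row: Y, X. Bottom row: X_pred, Residual_X (X_pred - X).
    """
    k = y.shape[2] // 2
    y2d, x2d, p2d = y[:, :, k], x[:, :, k], x_pred[:, :, k]
    top = np.hstack([minmax(y2d), minmax(x2d)])
    bottom = np.hstack([minmax(p2d), normalize_residual_symmetric(p2d - x2d)])
    return np.vstack([top, bottom])


print(normalize_residual_symmetric(np.array([-2.0, 0.0, 1.0])).tolist())  # [0.0, 0.5, 0.75]
rng = np.random.default_rng(0)
vol = rng.random((8, 8, 8)).astype(np.float32)
print(central_slice_grid(vol, vol, vol).shape)  # (16, 16)
```

Because the scale is recomputed per image, a gray tile always means agreement, but the brightness of an error is relative to that image's own worst error, so tiles from different epochs are not directly comparable in absolute terms.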

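The `log_correlation` column guards `np.corrcoef` against constant logs. If a pseudo-well passes through a region with no karst, the ground-truth trace is constant, its standard deviation is zero, and the Pearson coefficient is undefined (NumPy returns `nan` with a runtime warning). A minimal sketch of the same guard, extracting the vertical trace at the lateral centre of a `[z, y, x]` volume; `well_log_correlation` is a hypothetical helper, and the default step of 25 mirrors the notebook's `* 25` depth spacing:

```python
import numpy as np


def well_log_correlation(gt_vol: np.ndarray, pred_vol: np.ndarray, depth_step: float = 25.0):
    """Pearson correlation between vertical traces at the lateral centre of two [z, y, x] volumes.

    Returns (depth_axis, correlation). The correlation is 0.0 when either trace is
    constant, matching the notebook's fallback instead of propagating NaN.
    """
    iy, ix = gt_vol.shape[1] // 2, gt_vol.shape[2] // 2
    gt_log, pred_log = gt_vol[:, iy, ix], pred_vol[:, iy, ix]
    depth = np.arange(gt_vol.shape[0]) * depth_step
    if np.std(gt_log) == 0 or np.std(pred_log) == 0:
        return depth, 0.0
    return depth, float(np.corrcoef(gt_log, pred_log)[0, 1])


z = np.linspace(0.0, 1.0, 16)
gt = np.broadcast_to(z[:, None, None], (16, 4, 4)).copy()
depth, r = well_log_correlation(gt, 2.0 * gt + 0.1)  # affine copy -> perfect correlation
print(round(r, 6), depth[-1])  # 1.0 375.0
_, r_flat = well_log_correlation(np.zeros((16, 4, 4)), gt)
print(r_flat)  # 0.0
```

Reporting 0.0 for a flat log keeps the table column numeric and sortable, at the cost of not distinguishing "no correlation" from "undefined"; a `None` or NaN sentinel is an alternative if that distinction matters.
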
## 4. Execute a Single Training Run

Run a baseline training to generate live metrics, validation media, and a versioned model artifact.

- What this does
  - Starts a W&B run with your config and tags
  - Streams training metrics; correlates with system telemetry
  - Logs per-epoch validation media:
    - `Viz/diagnostics/central_slice_grid` (2D: Y, X, X_pred, Residual), plus `Viz/slices/*` and `Viz/sections/*` orientation views
    - Optional 3D Html at the first, middle, and last epochs when `enable_3d=True` and `enable_high_fidelity_3d=True`
  - Tracks the best validation loss, versions each improving checkpoint as an artifact (alias `best`), and links it in the Model Registry (alias `staging`)

- Where to look in W&B
  - Run Overview: metrics and config
  - System: CPU/memory correlated with metrics
  - Media: `Viz/diagnostics/central_slice_grid`, `Viz3D/PyVista_Renders/*` (if 3D enabled)
  - Tables: `val_table/validation_table` (images + well-log Plotly Html)
  - Artifacts/Registry: `conditional-diffusion-checkpoint` (alias `best`), linked to `wandb-registry-model/conditional-diffusion` (alias `staging`)

- Why W&B here
  - Live metrics + rich media + artifacts in one place, with lineage and aliases for governance

- Try it yourself
  - [ ] Run the cell; click the “View run” link
  - [ ] Expand rows in `val_table/validation_table` to inspect media
  - [ ] If 3D is enabled, open `Viz3D/PyVista_Renders/*` at the first, middle, and last epochs
  - [ ] Locate the model artifact and verify its registry alias
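
The checkpoint rule in the training function is a running minimum: a new model artifact is versioned only when `val/total_loss` strictly beats every earlier epoch, which is why the baseline output prints "New best model" for epochs 0 to 3 only. A minimal sketch of that rule with a hypothetical helper `checkpoint_epochs`; the first four losses are the values printed by the baseline run, and the remaining six are illustrative placeholders, not logged results.

```python
def checkpoint_epochs(val_losses):
    """Return the epochs at which a strict improvement triggers a new checkpoint."""
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(val_losses):
        if loss < best:  # strict '<', as in the notebook: a tie does not create a new version
            best = loss
            saved.append(epoch)
    return saved


# Epochs 0-3: losses printed by the baseline run. Epochs 4-9: illustrative values only.
losses = [0.8425, 0.7808, 0.7267, 0.6818, 0.6820, 0.6819, 0.6821, 0.7603, 0.7603, 0.7603]
print(checkpoint_epochs(losses))  # [0, 1, 2, 3]
```

Since an alias points at one version within a collection, each improving checkpoint should take over the `best` alias (and, through `run.link_artifact`, the registry `staging` alias), leaving earlier versions in the history without those aliases.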

In [8]:
# Train with default config
# Use the example config we defined earlier 
train_conditional_diffusion()

[34m[1mwandb[0m:   97 of 97 files downloaded.  


Loaded 24 real geological samples from CigKarst
Selected 3 pre-validated samples for tracking.
Epoch 0: Simulating conditional diffusion training...
New best model at epoch 0! Validation loss: 0.8425
Epoch 1: Simulating conditional diffusion training...
New best model at epoch 1! Validation loss: 0.7808
Epoch 2: Simulating conditional diffusion training...
New best model at epoch 2! Validation loss: 0.7267
Epoch 3: Simulating conditional diffusion training...
New best model at epoch 3! Validation loss: 0.6818
Epoch 4: Simulating conditional diffusion training...
Epoch 5: Simulating conditional diffusion training...
Epoch 6: Simulating conditional diffusion training...
Epoch 7: Simulating conditional diffusion training...
Epoch 8: Simulating conditional diffusion training...
Epoch 9: Simulating conditional diffusion training...


Run history:
train/learning_rate                ██████████████▆▆▆▆▆▆▆▆▆▃▃▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁
train/noise_prediction_loss        ██████▇▇▇▇▆▆▆▆▆▆▆▆▅▅▅▅▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁
train/reconstruction_mse           █████▇▇▇▇▇▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁
train/total_loss                   ████████▇▇▇▆▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▂▂▁▁
val/avg_condition_consistency_mse  █▅▃▁▂▂▂▇▇▇
val/avg_reconstruction_mse         █▆▅▄▂▂▂▁▁▁
val/total_loss                     █▅▃▁▁▁▁▄▄▄

Run summary:
train/learning_rate                0.00086
train/noise_prediction_loss        0.48049
train/reconstruction_mse           0.13947
train/total_loss                   0.34408
val/avg_condition_consistency_mse  1.38225
val/avg_reconstruction_mse         8e-05
val/total_loss                     0.76028

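The run summary can be cross-checked by hand against the formulas in the training function: `train/total_loss = 0.6 * noise + 0.4 * reconstruction`, `val/total_loss = 0.45 * avg_reconstruction + 0.55 * avg_condition_consistency`, and the logged learning rate decays by a factor of 0.95 every 3 epochs, so the last epoch (9) reports `lr * 0.95**3`. The check below assumes the default config uses the 1e-3 learning-rate fallback, which the summary value is consistent with. Summary values are rounded to five significant digits, hence the tolerances.

```python
# Summary values copied from the baseline run output above.
summary = {
    "train/noise_prediction_loss": 0.48049,
    "train/reconstruction_mse": 0.13947,
    "train/total_loss": 0.34408,
    "val/avg_reconstruction_mse": 8e-05,
    "val/avg_condition_consistency_mse": 1.38225,
    "val/total_loss": 0.76028,
    "train/learning_rate": 0.00086,
}

train_total = 0.6 * summary["train/noise_prediction_loss"] + 0.4 * summary["train/reconstruction_mse"]
val_total = (0.45 * summary["val/avg_reconstruction_mse"]
             + 0.55 * summary["val/avg_condition_consistency_mse"])
final_epoch = 9  # 10 epochs, indexed from 0
lr_final = 1e-3 * 0.95 ** (final_epoch // 3)  # assumed base learning rate of 1e-3

print(f"{train_total:.5f} {val_total:.5f} {lr_final:.6f}")  # 0.34408 0.76027 0.000857
```

The weighting also shows why `val/total_loss` follows `val/avg_condition_consistency_mse` almost exactly in the sparklines: the reconstruction term contributes less than 0.0001 to the final value.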

## 5. Hyperparameter Sweeps: Enterprise-Scale Optimization

Tuning hyperparameters one manual run at a time is slow. To search for a better configuration systematically, we explore the hyperparameter space with W&B Sweeps.

- What this does
  - Defines a Bayesian sweep optimizing `val/total_loss`
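  - Defines Hyperband early termination in `early_stop` (iteration == epoch); it is commented out in this workshop, so uncomment `"early_terminate": early_stop` to enable it
    - `min_iter=3`: allow a few epochs before pruning
    - `max_iter`: set equal to the epochs per run (`epochs` in the sweep config)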
    - `eta=3`: keep strongest ~1/3 at each bracket
  - Keeps sweeps fast by disabling 3D (`enable_3d=False`, `enable_high_fidelity_3d=False`, `enable_ipyvolume=False`)
  - Creates the sweep controller in the W&B cloud and returns a `sweep_id`

- Where to look in W&B
  - Sweeps page: centralized view of runs, pruning, and best configs
  - Run pages: each trial’s metrics, media (2D only in sweeps), and config

- Why W&B here
  - Built-in orchestration and visual comparisons reduce custom tooling
  - Early termination saves compute by stopping weak configs mid-training

- Notes
  - Agents derive project/entity from the sweep and may print “Ignoring project/entity” when starting; this is expected.
  - Keep runs short in demos (few epochs) to get visible results quickly.

- Try it yourself
  - [ ] Adjust the search space (e.g., widen `conditioning_strength`)
  - [ ] Re-create the sweep and note the new `sweep_id`
  - [ ] Open the Sweeps page and watch trials in real time (and pruning, once early termination is enabled)
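
W&B's sweep documentation describes `min_iter`-based Hyperband brackets as a geometric series: `min_iter`, `min_iter * eta`, `min_iter * eta**2`, and so on, where an iteration is one logged value of the sweep metric (one epoch here). Assuming that schedule, this sketch (hypothetical helper `hyperband_brackets`) lists the epochs at which trials would be compared if `early_terminate` were enabled:

```python
def hyperband_brackets(min_iter: int, eta: int, max_epochs: int) -> list[int]:
    """Epochs at which min_iter-based Hyperband would compare runs (assumed geometric schedule)."""
    brackets = []
    it = min_iter
    while it <= max_epochs:  # brackets beyond the run length are never reached
        brackets.append(it)
        it *= eta
    return brackets


print(hyperband_brackets(min_iter=3, eta=3, max_epochs=10))  # [3, 9]
print(hyperband_brackets(min_iter=3, eta=2, max_epochs=10))  # [3, 6]
```

With 10 epochs per trial and `eta=3`, a weak run could only be stopped at epoch 3 or 9, so most of the compute saving would come from the first bracket; `eta=2` adds a checkpoint at epoch 6.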

In [9]:
# Early termination: Hyperband settings for epoch-based training
# - Iteration == epoch (we log val/total_loss once per epoch)
# - min_iter: let every run complete at least the first 3 epochs before pruning
# - max_iter: equal to the total epochs per run (10 here, matching "epochs" in sweep_config); if you raise epochs, raise this too
# - eta: 3 keeps roughly the top 1/3 at each bracket; 2 is gentler, 4 is more aggressive
# - Effect: poor configs are stopped partway through training; strong configs run to completion
early_stop = {"type": "hyperband", "min_iter": 3, "max_iter": 10, "eta": 3}

# 1. Define the sweep: search method, target metric, and search space
sweep_config = {
    "method": "bayes",
    "metric": {"name": "val/total_loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-5, "max": 3e-2},
        "batch_size": {"values": [2, 4, 8, 16]},
        "noise_schedule": {"values": ["cosine", "linear", "sigmoid"]},
        "conditioning_strength": {"distribution": "uniform", "min": 0.45, "max": 1.0},

        "epochs": {"value": 10},
        "batches_per_epoch": {"value": 20},
        "seed": {"values": [41, 42, 43]},

        # Disable 3D rendering so sweep trials stay fast
        "enable_3d": {"value": False},
        "enable_high_fidelity_3d": {"value": False},
        "enable_ipyvolume": {"value": False},
    },
    # Early termination is disabled for this workshop; uncomment to enable Hyperband pruning in real-world training
    # "early_terminate": early_stop,
}
# 2. Initialize the Sweep
# This command creates the sweep controller in the W&B cloud. It acts as a
# central coordinator that agents can query for new jobs.
sweep_id = wandb.sweep(
    sweep=sweep_config,
    entity=ENTITY,
    project=PROJECT
)

print(f"✅ Sweep created successfully! Sweep ID: {sweep_id}")
print(f"🧹 View and manage your sweep here: https://wandb.ai/{ENTITY}/{PROJECT}/sweeps/{sweep_id}")



Create sweep with ID: 6c2jgzy1
Sweep URL: https://wandb.ai/wandb_emea/workshop-ex123456789/sweeps/6c2jgzy1
✅ Sweep created successfully! Sweep ID: 6c2jgzy1
🧹 View and manage your sweep here: https://wandb.ai/wandb_emea/workshop-ex123456789/sweeps/6c2jgzy1

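Before spending agent time, it can help to draw a few configurations from the same search space locally, for example to smoke-test extreme values. This sketch mimics the distributions in `sweep_config` (`log_uniform_values` draws values whose logarithm is uniform between `log(min)` and `log(max)`). It is not the sweep controller's Bayesian sampler, and `sample_config` is a hypothetical helper.

```python
import math
import random


def sample_config(rng: random.Random) -> dict:
    """Draw one configuration from a local copy of the workshop's sweep search space."""
    return {
        # log_uniform_values: uniform in log space between 1e-5 and 3e-2
        "learning_rate": math.exp(rng.uniform(math.log(1e-5), math.log(3e-2))),
        "batch_size": rng.choice([2, 4, 8, 16]),
        "noise_schedule": rng.choice(["cosine", "linear", "sigmoid"]),
        "conditioning_strength": rng.uniform(0.45, 1.0),
        "epochs": 10,
        "batches_per_epoch": 20,
        "seed": rng.choice([41, 42, 43]),
        "enable_3d": False,
        "enable_high_fidelity_3d": False,
        "enable_ipyvolume": False,
    }


rng = random.Random(0)
configs = [sample_config(rng) for _ in range(1000)]
# Roughly half the draws should fall below the geometric midpoint sqrt(1e-5 * 3e-2) ~ 5.5e-4.
below = sum(c["learning_rate"] < math.sqrt(1e-5 * 3e-2) for c in configs)
print(below / len(configs))
```

Because `train_conditional_diffusion` falls back to `example_config` only when its `config` argument is empty, passing one of these dictionaries as `config` would run that configuration as an ordinary, non-sweep run.
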

## 6. Launch a W&B Agent

Now that the sweep is initialized, launch an agent. The agent is a stateless worker that will:
1) Request a hyperparameter set from the sweep controller  
2) Invoke your training function (`function=train_conditional_diffusion`) with no positional arguments  
3) Inside the function, `wandb.init()` receives the sweep-provided config; access it via `wandb.config`  
4) Report results and request a new job until the sweep ends

- What this does
  - Starts an agent that executes N trials (`count`) from the sweep
  - Uses the sweep’s metric (`val/total_loss`) for early termination once `early_terminate` is enabled (it is commented out in this workshop)
  - Disables 3D per the sweep config for speed (`enable_3d=False`, `enable_high_fidelity_3d=False`, `enable_ipyvolume=False`)

- How pruning would work here (with `early_terminate` enabled)
  - The sweep’s metric is `val/total_loss`, logged once per epoch
  - “Iterations” for Hyperband equal the number of times that metric is logged (i.e., epochs here)
  - `min_iter`/`max_iter` therefore map to epoch counts for pruning  
    (See docs: early termination brackets are based on the count of the optimized metric logs, not global step)

- Where to look in W&B
  - Sweeps page: view active/finished trials, pruning, and best configs
  - Each run page: trial metrics, config, and media (2D only during sweeps)

- Notes
  - Console may show “Ignoring project/entity” when running a sweep; this is expected (agents inherit from the sweep)
  - Keep `count` small for demos to finish quickly; raise later for broader search
  - Stop the agent with the notebook’s interrupt if needed

- Try it yourself
  - [ ] Run the agent for a few trials (e.g., `count=5`)
  - [ ] With early termination enabled, observe weak configs stopping earlier
  - [ ] Open the best run from the Sweeps page and compare configs
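
The distinction between global steps and Hyperband iterations is easy to check by counting log calls. Each `run.log` call advances the W&B step by default, so the training function produces `epochs * batches_per_epoch` training steps plus one validation step per epoch, yet `val/total_loss` appears only once per epoch. This sketch replays that logging pattern without W&B (`count_logs` is a hypothetical helper):

```python
def count_logs(epochs: int, batches_per_epoch: int, metric: str = "val/total_loss"):
    """Replay the notebook's logging pattern and count global steps vs. metric logs."""
    history = []  # one dict per run.log() call; list index == W&B step
    for _epoch in range(epochs):
        for _batch in range(batches_per_epoch):
            history.append({"train/total_loss": 0.0})  # one training log per batch
        history.append({metric: 0.0})  # one validation log per epoch (with media and table)
    steps = len(history)
    metric_logs = sum(metric in row for row in history)
    return steps, metric_logs


print(count_logs(epochs=10, batches_per_epoch=20))  # (210, 10)
```

Hyperband therefore sees 10 iterations per trial, so `min_iter` and `max_iter` read as epochs rather than as positions among the 210 steps.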

In [10]:
# 3. Run the Sweep Agent
# This call connects a worker to the sweep controller. The agent fetches a
# configuration, executes the training function, and reports the results back,
# with no manual orchestration.
wandb.agent(sweep_id, function=train_conditional_diffusion, count=5)
wandb.teardown()  # wait for W&B to finish and free its resources so later non-sweep runs in this session start cleanly


[34m[1mwandb[0m: Agent Starting Run: vft570iz with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	batches_per_epoch: 20
[34m[1mwandb[0m: 	conditioning_strength: 0.8810594491932708
[34m[1mwandb[0m: 	enable_3d: False
[34m[1mwandb[0m: 	enable_high_fidelity_3d: False
[34m[1mwandb[0m: 	enable_ipyvolume: False
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	learning_rate: 0.00016897362920980895
[34m[1mwandb[0m: 	noise_schedule: sigmoid
[34m[1mwandb[0m: 	seed: 41
[34m[1mwandb[0m: Currently logged in as: [33mallanstevenson[0m ([33mwandb_emea[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   97 of 97 files downloaded.  


Loaded 24 real geological samples from CigKarst
Selected 3 pre-validated samples for tracking.
Epoch 0: Simulating conditional diffusion training...
New best model at epoch 0! Validation loss: 0.8477
Epoch 1: Simulating conditional diffusion training...
New best model at epoch 1! Validation loss: 0.7792
Epoch 2: Simulating conditional diffusion training...
New best model at epoch 2! Validation loss: 0.7228
Epoch 3: Simulating conditional diffusion training...
New best model at epoch 3! Validation loss: 0.6759
Epoch 4: Simulating conditional diffusion training...
Epoch 5: Simulating conditional diffusion training...
Epoch 6: Simulating conditional diffusion training...
Epoch 7: Simulating conditional diffusion training...
Epoch 8: Simulating conditional diffusion training...
Epoch 9: Simulating conditional diffusion training...


Run history:
train/learning_rate                ██████████▆▆▆▆▆▆▆▆▆▆▆▆▆▃▃▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁
train/noise_prediction_loss        ██████████▇▇▇▇▇▆▆▆▆▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▁▁
train/reconstruction_mse           ████████▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▅▅▄▄▃▃▃▃▂▂▂▁▁▁▁
train/total_loss                   ██████████▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▅▅▄▃▃▃▃▃▃▃▁▁▁▁▁
val/avg_condition_consistency_mse  █▅▃▁▃▃▃▇▇▇
val/avg_reconstruction_mse         █▆▅▄▂▂▂▁▁▁
val/total_loss                     █▅▃▁▁▁▁▄▄▄

Run summary:
train/learning_rate                0.00014
train/noise_prediction_loss        0.71495
train/reconstruction_mse           0.16439
train/total_loss                   0.49473
val/avg_condition_consistency_mse  1.38226
val/avg_reconstruction_mse         9e-05
val/total_loss                     0.76028


[34m[1mwandb[0m: Agent Starting Run: 0w8o0xjz with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	batches_per_epoch: 20
[34m[1mwandb[0m: 	conditioning_strength: 0.4804638063221182
[34m[1mwandb[0m: 	enable_3d: False
[34m[1mwandb[0m: 	enable_high_fidelity_3d: False
[34m[1mwandb[0m: 	enable_ipyvolume: False
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	learning_rate: 0.000560502678787149
[34m[1mwandb[0m: 	noise_schedule: sigmoid
[34m[1mwandb[0m: 	seed: 43


[34m[1mwandb[0m:   97 of 97 files downloaded.  


Loaded 24 real geological samples from CigKarst
Selected 3 pre-validated samples for tracking.
Epoch 0: Simulating conditional diffusion training...
New best model at epoch 0! Validation loss: 0.8447
Epoch 1: Simulating conditional diffusion training...
New best model at epoch 1! Validation loss: 0.7988
Epoch 2: Simulating conditional diffusion training...
New best model at epoch 2! Validation loss: 0.7516
Epoch 3: Simulating conditional diffusion training...
New best model at epoch 3! Validation loss: 0.7154
Epoch 4: Simulating conditional diffusion training...
New best model at epoch 4! Validation loss: 0.6872
Epoch 5: Simulating conditional diffusion training...
New best model at epoch 5! Validation loss: 0.6866
Epoch 6: Simulating conditional diffusion training...
Epoch 7: Simulating conditional diffusion training...
Epoch 8: Simulating conditional diffusion training...
Epoch 9: Simulating conditional diffusion training...


Run history:
train/learning_rate                ██████████▆▆▆▆▆▆▆▆▆▆▆▆▃▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁
train/noise_prediction_loss        ██████████▇█▇▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▄▄▄▄▃▃▂▂▂▁▁
train/reconstruction_mse           ██████████▇▇▇▇▇▆▆▆▆▆▆▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁
train/total_loss                   ████████▇▇▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▃▃▃▃▂▂▁▁▁▁
val/avg_condition_consistency_mse  █▆▄▂▁▁▁▆▇▇
val/avg_reconstruction_mse         █▇▅▄▂▂▂▁▁▁
val/total_loss                     █▆▄▂▁▁▁▄▄▄

Run summary:
train/learning_rate                0.00048
train/noise_prediction_loss        0.81044
train/reconstruction_mse           0.2296
train/total_loss                   0.57811
val/avg_condition_consistency_mse  1.38042
val/avg_reconstruction_mse         0.00016
val/total_loss                     0.7593


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: qopt0wkn with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	batches_per_epoch: 20
[34m[1mwandb[0m: 	conditioning_strength: 0.5014788284786833
[34m[1mwandb[0m: 	enable_3d: False
[34m[1mwandb[0m: 	enable_high_fidelity_3d: False
[34m[1mwandb[0m: 	enable_ipyvolume: False
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	learning_rate: 3.521889560306934e-05
[34m[1mwandb[0m: 	noise_schedule: sigmoid
[34m[1mwandb[0m: 	seed: 42


[34m[1mwandb[0m:   97 of 97 files downloaded.  


Loaded 24 real geological samples from CigKarst
Selected 3 pre-validated samples for tracking.
Epoch 0: Simulating conditional diffusion training...
New best model at epoch 0! Validation loss: 0.8424
Epoch 1: Simulating conditional diffusion training...
New best model at epoch 1! Validation loss: 0.7954
Epoch 2: Simulating conditional diffusion training...
New best model at epoch 2! Validation loss: 0.7520
Epoch 3: Simulating conditional diffusion training...
New best model at epoch 3! Validation loss: 0.7134
Epoch 4: Simulating conditional diffusion training...
New best model at epoch 4! Validation loss: 0.6873
Epoch 5: Simulating conditional diffusion training...
New best model at epoch 5! Validation loss: 0.6866
Epoch 6: Simulating conditional diffusion training...
Epoch 7: Simulating conditional diffusion training...
Epoch 8: Simulating conditional diffusion training...
Epoch 9: Simulating conditional diffusion training...


Run history:
train/learning_rate                ██████████████▆▆▆▆▆▆▆▆▆▆▆▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁
train/noise_prediction_loss        ████████▇▇▇▇▇▇▇▇▇▇▇▇▆▆▅▅▅▅▄▄▄▄▄▄▄▃▃▃▂▂▂▁
train/reconstruction_mse           ████████▇▇▇▇▇▇▇▇▇▆▆▆▆▆▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▁▁▁
train/total_loss                   █████████████▇▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▂▁▁
val/avg_condition_consistency_mse  █▆▄▂▁▁▁▆▇▇
val/avg_reconstruction_mse         █▆▅▄▂▂▂▁▁▁
val/total_loss                     █▆▄▂▁▁▁▄▄▄

Run summary:
train/learning_rate                3e-05
train/noise_prediction_loss        0.82877
train/reconstruction_mse           0.23225
train/total_loss                   0.59017
val/avg_condition_consistency_mse  1.38057
val/avg_reconstruction_mse         0.00014
val/total_loss                     0.75938


[34m[1mwandb[0m: Agent Starting Run: pya7ad99 with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	batches_per_epoch: 20
[34m[1mwandb[0m: 	conditioning_strength: 0.456685258855234
[34m[1mwandb[0m: 	enable_3d: False
[34m[1mwandb[0m: 	enable_high_fidelity_3d: False
[34m[1mwandb[0m: 	enable_ipyvolume: False
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	learning_rate: 0.0004068509889883103
[34m[1mwandb[0m: 	noise_schedule: sigmoid
[34m[1mwandb[0m: 	seed: 43


[34m[1mwandb[0m:   97 of 97 files downloaded.  


Loaded 24 real geological samples from CigKarst
Selected 3 pre-validated samples for tracking.
Epoch 0: Simulating conditional diffusion training...
New best model at epoch 0! Validation loss: 0.8447
Epoch 1: Simulating conditional diffusion training...
New best model at epoch 1! Validation loss: 0.7999
Epoch 2: Simulating conditional diffusion training...
New best model at epoch 2! Validation loss: 0.7536
Epoch 3: Simulating conditional diffusion training...
New best model at epoch 3! Validation loss: 0.7182
Epoch 4: Simulating conditional diffusion training...
New best model at epoch 4! Validation loss: 0.6873
Epoch 5: Simulating conditional diffusion training...
New best model at epoch 5! Validation loss: 0.6867
Epoch 6: Simulating conditional diffusion training...
Epoch 7: Simulating conditional diffusion training...
Epoch 8: Simulating conditional diffusion training...
Epoch 9: Simulating conditional diffusion training...


Run history:
train/learning_rate                ███████████▆▆▆▆▆▆▆▆▆▆▆▆▆▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▁
train/noise_prediction_loss        ███████▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▅▅▅▅▅▅▅▄▄▄▃▃▂▂▁▁▁
train/reconstruction_mse           ████████▇▇▇▇▆▆▆▆▆▆▆▆▅▅▅▅▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
train/total_loss                   ██████████▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁
val/avg_condition_consistency_mse  █▆▄▂▁▁▁▆▇▇
val/avg_reconstruction_mse         █▇▅▄▂▂▂▁▁▁
val/total_loss                     █▆▄▂▁▁▁▄▄▄

Run summary:
train/learning_rate                0.00035
train/noise_prediction_loss        0.82217
train/reconstruction_mse           0.23838
train/total_loss                   0.58866
val/avg_condition_consistency_mse  1.38017
val/avg_reconstruction_mse         0.00018
val/total_loss                     0.75917


[34m[1mwandb[0m: Agent Starting Run: jl5tji2l with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	batches_per_epoch: 20
[34m[1mwandb[0m: 	conditioning_strength: 0.4576413774041321
[34m[1mwandb[0m: 	enable_3d: False
[34m[1mwandb[0m: 	enable_high_fidelity_3d: False
[34m[1mwandb[0m: 	enable_ipyvolume: False
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	learning_rate: 0.00045482625942300887
[34m[1mwandb[0m: 	noise_schedule: sigmoid
[34m[1mwandb[0m: 	seed: 43


[34m[1mwandb[0m:   97 of 97 files downloaded.  


Loaded 24 real geological samples from CigKarst
Selected 3 pre-validated samples for tracking.
Epoch 0: Simulating conditional diffusion training...
New best model at epoch 0! Validation loss: 0.8447
Epoch 1: Simulating conditional diffusion training...
New best model at epoch 1! Validation loss: 0.7999
Epoch 2: Simulating conditional diffusion training...
New best model at epoch 2! Validation loss: 0.7535
Epoch 3: Simulating conditional diffusion training...
New best model at epoch 3! Validation loss: 0.7181
Epoch 4: Simulating conditional diffusion training...
New best model at epoch 4! Validation loss: 0.6873
Epoch 5: Simulating conditional diffusion training...
New best model at epoch 5! Validation loss: 0.6867
Epoch 6: Simulating conditional diffusion training...
Epoch 7: Simulating conditional diffusion training...
Epoch 8: Simulating conditional diffusion training...
Epoch 9: Simulating conditional diffusion training...


Run history:
train/learning_rate                ██████████▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁
train/noise_prediction_loss        ███████████▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▃▃▃▃▁
train/reconstruction_mse           ██████████▇▇▇▇▇▇▆▆▆▆▅▅▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁
train/total_loss                   █████████████▇▇▇▇▇▇▇▇▆▆▆▅▅▄▄▄▄▃▂▂▂▂▁▁▁▁▁
val/avg_condition_consistency_mse  █▆▄▂▁▁▁▆▇▇
val/avg_reconstruction_mse         █▇▅▄▂▂▂▁▁▁
val/total_loss                     █▆▄▂▁▁▁▄▄▄

Run summary:
train/learning_rate                0.00039
train/noise_prediction_loss        0.82031
train/reconstruction_mse           0.23697
train/total_loss                   0.58697
val/avg_condition_consistency_mse  1.38018
val/avg_reconstruction_mse         0.00018
val/total_loss                     0.75918


## 7. Programmatic Executive Reports: From Model to Boardroom

The final step bridges the gap between technical results and business stakeholders. This section generates a data-driven executive report directly from your experiments, built around whichever run is the project’s best at the moment the cell is executed.

- What this does
  - Identifies the best finished run by `val/total_loss`
  - Builds a report with headline KPIs, a validation loss chart, and the per‑epoch validation table
  - Saves the report and prints a shareable link; each execution creates a new report, so re-run the cell to refresh the summary after new runs finish

- Where to look in W&B
  - Reports: the generated report appears under the project’s Reports tab
  - Linked Run: the report references the best run and its media/table
  - Tables Panel: uses the logged `val_table/validation_table` from the best run

- Why W&B here
  - Stakeholder-ready, parameterized reports reduce ad-hoc slide work; their charts and tables are rendered from the logged run data rather than pasted copies
  - Reports blend metrics, media, and tables, with live links back to runs, artifacts, and registry items

- Try it yourself
  - [ ] Complete at least one training run (or the sweep above) so that finished runs tagged `conditional-diffusion` with `val/total_loss` exist
  - [ ] Execute the report cell; open the printed URL
  - [ ] Share the link with SMEs; when a new best run appears, re-run the cell to publish an updated report
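The selection rule is worth checking against the sweep above. `run.summary["val/total_loss"]` holds the last logged value (the final epoch) by default, not the minimum, and in these trials validation loss bottoms out around epochs 3 to 5 and then rises again. The sketch below uses only numbers printed in the sweep output: the loss from each run's last “New best model” line and the final summary value. The dictionary and helper are illustrative, not part of the workshop code.

```python
# Values copied from the sweep output above:
#   best_epoch = loss on the run's last "New best model" line (printed to 4 decimals)
#   final      = val/total_loss in the run summary (last logged value)
trials = {
    "vft570iz": {"best_epoch": 0.6759, "final": 0.76028},
    "0w8o0xjz": {"best_epoch": 0.6866, "final": 0.7593},
    "qopt0wkn": {"best_epoch": 0.6866, "final": 0.75938},
    "pya7ad99": {"best_epoch": 0.6867, "final": 0.75917},
    "jl5tji2l": {"best_epoch": 0.6867, "final": 0.75918},
}


def best_run_by(trials: dict, statistic: str) -> str:
    """Return the run ID with the lowest value of the given statistic."""
    return min(trials, key=lambda run_id: trials[run_id][statistic])


best_final = best_run_by(trials, "final")
best_epoch = best_run_by(trials, "best_epoch")
final_values = [t["final"] for t in trials.values()]
final_spread = max(final_values) - min(final_values)

print(best_final)              # pya7ad99
print(best_epoch)              # vft570iz
print(round(final_spread, 5))  # 0.00111
```

Ranking on the final value picks `pya7ad99`, the same run the report cell selects, while ranking on the best epoch picks `vft570iz`. The five final losses span only about 0.0011, so if the best-epoch checkpoint is what you intend to promote, consider ranking on the minimum of each run's per-epoch history instead, for example by iterating the public API's `run.scan_history(keys=["val/total_loss"])`.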

In [11]:
# Programmatic report without Markdown formatting in text blocks
import time
import wandb
import wandb_workspaces.reports.v2 as wr

print("Creating a programmatic report...")
print(f"Using entity: '{ENTITY}', project: '{PROJECT}'")

try:
    # 1) Find the best finished run by val/total_loss
    api = wandb.Api()
    runs = api.runs(
        path=f"{ENTITY}/{PROJECT}",
        order="-created_at",
        filters={"tags": "conditional-diffusion", "state": "finished"},
    )
    if not runs:
        raise ValueError("No finished runs with tag 'conditional-diffusion' found.")

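    # Lowest *final* val/total_loss wins: run.summary holds the last logged value
    # by default. Runs that never logged the metric are ranked last.
    best_run = min(
        runs, key=lambda r: r.summary.get("val/total_loss", float("inf"))
    )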
    best_val = best_run.summary.get("val/total_loss", None)
    best_val_str = f"{best_val:.4f}" if best_val is not None else "N/A"
    print(f"Generating report based on best run: {best_run.name} (ID: {best_run.id})")
    if best_val is not None:
        print(f"Best validation loss: {best_val_str}")

    # 2) Build the report container (full-width page)
    report = wr.Report(
        entity=ENTITY,
        project=PROJECT,
        title=f"Geological ML Model Performance - {time.strftime('%Y-%m-%d')}",
        description=f"Automated summary for the conditional diffusion model. Best run: {best_run.name}.",
        width="fluid",
    )

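    # Restrict the panels to the best run. `name` only labels the run set;
    # `query` applies a run-name search (like the UI search box), so runs with
    # similar names (e.g. worthy-sweep-40) could also match.
    runset = wr.Runset(
        entity=ENTITY,
        project=PROJECT,
        name=f"Best run: {best_run.name}",
        query=best_run.name,
    )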

    # 3) Blocks (plain text; no Markdown formatting)
    report.blocks = [
        wr.H1("Executive Summary: Geological Interpretation Model"),
        wr.P(
            f"The model was trained for {best_run.config.get('epochs', 'N/A')} epochs, "
            f"achieving a final validation loss of {best_val_str}. "
            f"This automated report was generated on {time.strftime('%B %d, %Y')}."
        ),

        wr.H2("Key Performance Metrics"),
        wr.P("Validation loss over time (X axis uses Step)."),

        wr.PanelGrid(
            runsets=[runset],
            panels=[
                wr.LinePlot(
                    title="Validation Loss Over Training",
                    y=["val/total_loss"],
                    layout={"w": 24, "h": 12},
                ),
            ],
        ),

        wr.PanelGrid(
            runsets=[runset],
            panels=[
                wr.ScalarChart(
                    title="Best Final Validation Loss",
                    metric="val/total_loss",
                    layout={"w": 8, "h": 8},
                ),
                wr.ScalarChart(
                    title="Best Avg Reconstruction MSE",
                    metric="val/avg_reconstruction_mse",
                    layout={"w": 8, "h": 8},
                ),
                wr.ScalarChart(
                    title="Best Condition-Consistency MSE",
                    metric="val/avg_condition_consistency_mse",
                    layout={"w": 8, "h": 8},
                ),
            ],
        ),

        wr.H2("Detailed Validation Analysis"),
        wr.P("Interactive table logged each epoch; single-channel slices and well-log comparison."),

        wr.PanelGrid(
            runsets=[runset],
            panels=[
                wr.WeavePanelSummaryTable(
                    table_name="val_table/validation_table",
                    layout={"w": 24, "h": 20},
                )
            ],
        ),

        wr.H2("Model Governance and Next Steps"),
        wr.P("The best model artifact from this run has been versioned and linked to the Model Registry."),
        wr.P(f"Best model run: {best_run.name} ({best_run.id})"),
        # best_run.url is built from the API's configured host, so it also works on dedicated or self-managed W&B
        wr.P(["Run URL: ", wr.Link(best_run.url, url=best_run.url)]),
        wr.P("Dataset artifact version: CigKarst:v0"),
        wr.P("Model artifact name: conditional-diffusion-checkpoint"),
        wr.P("Current registry alias: staging"),
        wr.P("Next steps:"),
        wr.P("1) Review model performance with geological subject matter experts."),
        wr.P("2) Promote the model from Staging to Production if acceptance criteria are met."),
        wr.P("3) Plan deployment to a validation environment."),
    ]

    report.save()
    print("\nReport created successfully!")
    print(f"Title: {report.title}")
    print(f"URL: {report.url}")
    print("This script could be extended to leverage a Teams webhook to send a notification to a channel")

except Exception as e:
    print(f"\nReport creation failed: {e}")

Creating a programmatic report...
Using entity: 'wandb_emea', project: 'workshop-ex123456789'




Generating report based on best run: worthy-sweep-4 (ID: pya7ad99)
Best validation loss: 0.7592


[34m[1mwandb[0m: Saved report to: https://wandb.ai/wandb_emea/workshop-ex123456789/reports/Geological-ML-Model-Performance---2025-08-12--VmlldzoxMzk2ODYwMw==



Report created successfully!
Title: Geological ML Model Performance - 2025-08-12
URL: https://wandb.ai/wandb_emea/workshop-ex123456789/reports/Geological-ML-Model-Performance---2025-08-12--VmlldzoxMzk2ODYwMw==
This script could be extended to leverage a Teams webhook to send a notification to a channel
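The last printed line suggests notifying a Microsoft Teams channel. A minimal standard-library sketch follows. It assumes a legacy Teams incoming webhook that accepts a JSON body with a `text` field and renders basic Markdown; Microsoft is migrating incoming webhooks to Workflows, which may expect a different payload (such as an Adaptive Card), so check your channel's webhook type first. The function names and `TEAMS_WEBHOOK_URL` are hypothetical; keep the real webhook URL in a secret, not in the notebook.

```python
import json
import urllib.request


def build_teams_payload(title: str, report_url: str, best_val: str) -> dict:
    """Build a minimal message body for a (legacy) Teams incoming webhook."""
    # Assumption: the webhook accepts {"text": ...} and renders bold text and links.
    lines = [
        f"**{title}**",
        f"Best val/total_loss: {best_val}",
        f"[Open the W&B report]({report_url})",
    ]
    return {"text": "\n\n".join(lines)}


def post_to_teams(webhook_url: str, payload: dict, timeout: float = 10.0) -> int:
    """POST the payload as JSON and return the HTTP status code."""
    request = urllib.request.Request(
        webhook_url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return response.status


# Usage with the values printed by the report cell; in the notebook you would
# pass report.title, report.url, and best_val_str instead of literals.
payload = build_teams_payload(
    "Geological ML Model Performance - 2025-08-12",
    "https://wandb.ai/wandb_emea/workshop-ex123456789/reports/Geological-ML-Model-Performance---2025-08-12--VmlldzoxMzk2ODYwMw==",
    "0.7592",
)
# post_to_teams(TEAMS_WEBHOOK_URL, payload)  # TEAMS_WEBHOOK_URL is a hypothetical secret
```

Separating payload construction from the HTTP call keeps the message format testable without network access and makes it easy to swap in a different payload shape if your webhook requires one.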
