# W&B Enterprise Workshop: Advanced Geological AI

Welcome to the W&B Enterprise Workshop. In this session, we will demonstrate how Weights & Biases serves as the indispensable **system of record** for a complex, enterprise-grade machine learning workflow. By establishing a centralized hub for all our activities, we can break down silos between geoscientists, ML engineers, and stakeholders, creating a single source of truth for the entire project lifecycle.

We will simulate the training of a **conditional diffusion model** for geological structure generation. This allows us to focus on the MLOps challenges—collaboration, monitoring, governance, and reporting—that W&B is designed to solve, without waiting hours for a real model to train.

Our first step is to install the required libraries and import our dependencies.

In [1]:
# Install minimal dependencies for the W&B workshop.
# wandb-workspaces provides programmatic workspaces and reports, which we use
# later to generate stakeholder-ready W&B Reports automatically.
%pip install wandb numpy tqdm wandb-workspaces plotly pillow scikit-image pyvista ipyvolume ipython-genutils -q

import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from skimage.metrics import structural_similarity as ssim
from skimage.transform import resize
from tqdm import tqdm
import wandb
import wandb_workspaces.reports.v2 as wr

# 3D visualization libraries
import pyvista as pv
import ipyvolume as ipv

# Turn off PyVista's notebook display backend; the renders below use
# off-screen plotters that are exported to HTML instead
pv.set_jupyter_backend(None)



Note: you may need to restart the kernel to use updated packages.


## 1. W&B Project Configuration

Here, we'll log in to W&B and define the `ENTITY` (your team or organization) and the `PROJECT` for this workshop. 

A W&B Project is a collaborative workspace where your entire team—from geoscientists to ML engineers—can track experiments, compare results, and share insights in real-time.

Centralizing our work here is the first step toward building a reliable system of record. We also define a sample configuration dictionary that holds our model's hyperparameters.
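
The configuration dictionary in the next cell sets `"noise_schedule": "cosine"` and `"timesteps": 1000`. Assuming these refer to the widely used cosine schedule of Nichol and Dhariwal (2021), the sketch below shows what such a schedule computes. The simulated training in this workshop does not use it, and the function name `cosine_alpha_bar` is hypothetical; the offset `s = 0.008` and the 0.999 clip follow that paper.

```python
import numpy as np

def cosine_alpha_bar(timesteps: int = 1000, s: float = 0.008) -> np.ndarray:
    """Cumulative signal fraction alpha_bar_t for t = 0..timesteps under a cosine schedule."""
    t = np.arange(timesteps + 1) / timesteps
    f = np.cos((t + s) / (1 + s) * np.pi / 2) ** 2
    return f / f[0]

alpha_bar = cosine_alpha_bar(timesteps=1000)
# Per-step noise variances beta_t, clipped at 0.999 to avoid a singular final step
betas = np.clip(1.0 - alpha_bar[1:] / alpha_bar[:-1], 0.0, 0.999)

# Fraction of the original signal left at a few points along the schedule
for step in (0, 250, 500, 750, 1000):
    print(step, round(float(alpha_bar[step]), 4))
```

In a real diffusion model, a noised volume at step t would be formed as sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise.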

In [2]:

ENTITY = "wandb_emea"  # Replace with your W&B team entity
PROJECT = "workshop-scratchpad"

# Ensure you have logged in and defined ENTITY and PROJECT in a previous cell
assert "ENTITY" in locals(), "Please define the ENTITY variable"
assert "PROJECT" in locals(), "Please define the PROJECT variable"

# W&B Team Authentication
wandb.login()


print(f"W&B Authentication: Entity = {ENTITY}")
print(f"Project: {PROJECT}")

# Configuration for the (simulated) conditional diffusion training.
# Passing this dict to wandb.init(config=...) stores the hyperparameters with
# every run, which supports reproducibility and lets W&B Sweeps override them.

example_config = {
    # Model Architecture
    "model_architecture": "ConditionalUNet3D",
    "task": "geological_structure_generation",
    "input_modality": "seismic_amplitude",
    "output_modality": "karst_structures",
    
    # Training Parameters
    "epochs": 10,
    "batches_per_epoch": 20,
    "learning_rate": 1e-3,
    "batch_size": 4,
    
    # Diffusion Parameters
    "timesteps": 1000,
    "noise_schedule": "cosine",
    "conditioning_strength": 0.8,
    
    # Geological Domain
    "voxel_resolution": "25m",
    "depth_range": "0-800m",
    "geological_context": "karst_detection"
}

print("Training configuration loaded")

[34m[1mwandb[0m: Currently logged in as: [33mallanstevenson[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


W&B Authentication: Entity = wandb_emea
Project: workshop-scratchpad
Training configuration loaded
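
The config stores physical quantities as strings (`"25m"`, `"0-800m"`), while the training loop later hard-codes a 25 m spacing for the well-log depth axis. Here is a minimal sketch of two hypothetical helpers, `parse_length_m` and `parse_depth_range_m`, that turn those strings into numbers so downstream code could read the spacing from the config. It assumes the strings always use meters with an `m` suffix.

```python
import re
from typing import Tuple

_NUMBER = r"(\d+(?:\.\d+)?)"

def parse_length_m(text: str) -> float:
    """Parse a length such as '25m' into meters (hypothetical helper)."""
    match = re.fullmatch(rf"\s*{_NUMBER}\s*m\s*", text)
    if match is None:
        raise ValueError(f"Unrecognized length: {text!r}")
    return float(match.group(1))

def parse_depth_range_m(text: str) -> Tuple[float, float]:
    """Parse a range such as '0-800m' into (top, bottom) meters (hypothetical helper)."""
    match = re.fullmatch(rf"\s*{_NUMBER}\s*-\s*{_NUMBER}\s*m\s*", text)
    if match is None:
        raise ValueError(f"Unrecognized depth range: {text!r}")
    top, bottom = float(match.group(1)), float(match.group(2))
    if bottom <= top:
        raise ValueError("The bottom of the range must be deeper than the top")
    return top, bottom

config = {"voxel_resolution": "25m", "depth_range": "0-800m"}
voxel_m = parse_length_m(config["voxel_resolution"])
top_m, bottom_m = parse_depth_range_m(config["depth_range"])
# Number of voxels needed to cover the depth range at this resolution
n_voxels = int((bottom_m - top_m) / voxel_m)
print(voxel_m, (top_m, bottom_m), n_voxels)
```

With the example values, the 0-800 m depth range spans 32 voxels at 25 m.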


## 2. Simulation and Visualization Helpers

This cell contains the helper functions for our workshop. **These helpers are not part of the core W&B integration.**

- `generate_synthetic_loss` and `simulate_conditional_diffusion_progress`: mimic the behavior of a real conditional diffusion model, producing decreasing losses and progressively better geological predictions with each epoch.
- `simulate_forward_model`: a stand-in for the forward model y = f(x) that maps a predicted structure back to a predicted seismic condition.
- `plot_well_log_comparison` and `normalize_for_visualization`: utilities for creating interactive charts and preparing images for visualization.

We use a simulation because real generative models of this kind can train for days or weeks. This lets us focus on the operational challenges, such as monitoring, debugging, and reporting, where teams lose the most time and money, especially when long-running jobs fail silently.

By isolating the simulation logic here, we keep the rest of the notebook focused purely on the MLOps workflow powered by W&B.
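
For reference, the synthetic losses produced by `generate_synthetic_loss` and combined in the training loop of Section 3 can be written as follows, where $s$ is the global batch step, $e$ the epoch, and the noise term is drawn after seeding NumPy with $\lfloor s \rfloor$ (so a given step always yields the same value):

$$
\mathcal{L}(s) = \max\left(0.01,\; 1.5\,e^{-s/150} + 0.05 + \varepsilon_s\right), \qquad \varepsilon_s \sim \mathcal{N}\left(0,\, 0.02^2\right)
$$

$$
\mathcal{L}_{\text{total}}(s) = 0.5\,\mathcal{L}(s) + 0.3 \cdot 0.8\,\mathcal{L}(1.5\,s), \qquad \mathcal{L}_{\text{val}}(e) = 0.85\,\mathcal{L}(50\,e)
$$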

In [3]:
# By returning a wandb.Html object, we can log custom, interactive visualizations
# like this Plotly chart directly into the W&B dashboard. This allows domain experts
# to analyze results without switching contexts or downloading files.
def plot_well_log_comparison(gt_log: np.ndarray, pred_log: np.ndarray, depth: np.ndarray) -> wandb.Html:
    """Creates an interactive Plotly chart comparing ground truth and predicted well logs."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=gt_log, y=depth, mode='lines', name='Ground Truth Log', line=dict(color='black')))
    fig.add_trace(go.Scatter(x=pred_log, y=depth, mode='lines', name='Predicted Log', line=dict(color='crimson', dash='dash')))
    fig.update_layout(
        title="Well Log Comparison",
        xaxis_title="Signal Amplitude",
        yaxis_title="Depth (m)",
        yaxis_autorange='reversed' # Depth increases downwards
    )
    html = pio.to_html(fig)
    return wandb.Html(html)

def generate_synthetic_loss(step: int) -> float:
    """Generate realistic MSE loss that improves over time."""
    np.random.seed(int(step))
    base_loss = 1.5 * np.exp(-step / 150.0) + 0.05
    noise = np.random.normal(0, 0.02)
    return max(0.01, base_loss + noise)

def simulate_conditional_diffusion_progress(seismic_condition: np.ndarray, karst_target: np.ndarray, epoch: int, total_epochs: int = 10) -> np.ndarray:
    """
    Simulate realistic conditional diffusion model learning progression.
    Early epochs: noisy predictions with some seismic influence
    Later epochs: cleaner karst structures conditioned on seismic input
    """
    # Stage boundaries below assume the default 10-epoch schedule (epochs 0-9);
    # total_epochs is accepted but not used by this simulation.

    # Set deterministic seed based on epoch for reproducible progression
    np.random.seed(epoch * 42)

    # Stage 1 (epochs 0-3): Learning basic seismic-karst correlations
    if epoch <= 3:
        # Start with mostly noise, gradually incorporate seismic patterns
        noise_level = 0.8 - (epoch / 3) * 0.4  # 0.8 → 0.4
        structural_learning = epoch / 3 * 0.3   # 0.0 → 0.3

        # Generate structural noise that respects seismic boundaries
        structure_mask = (seismic_condition > np.percentile(seismic_condition, 60)).astype(float)
        noise = np.random.normal(0, noise_level, seismic_condition.shape)

        # Prediction combines noise with emerging structural understanding
        prediction = (noise * 0.7 +
                     seismic_condition * structural_learning +
                     karst_target * structure_mask * 0.1)

    # Stage 2 (epochs 4-6): Refining geological structures
    elif epoch <= 6:
        stage_progress = (epoch - 4) / 2  # 0.0 → 1.0 for epochs 4-6

        # Better geological understanding - focus on karst-forming regions
        karst_regions = (seismic_condition > np.percentile(seismic_condition, 40)) & (seismic_condition < np.percentile(seismic_condition, 85))

        # Noise reduces significantly
        noise_level = 0.4 - stage_progress * 0.25  # 0.4 → 0.15
        noise = np.random.normal(0, noise_level, seismic_condition.shape)

        # Geological structure emergence
        geological_understanding = 0.3 + stage_progress * 0.4  # 0.3 → 0.7

        prediction = (noise * 0.3 +
                     karst_target * karst_regions * geological_understanding +
                     seismic_condition * 0.2 * (1 - karst_regions))

    # Stage 3 (epochs 7-9): Fine-tuning and detail refinement
    else:
        stage_progress = min((epoch - 7) / 2, 1.0)  # 0.0 → 1.0 for epochs 7-9, capped at 1.0 for longer runs

        # High geological accuracy with fine detail learning
        noise_level = 0.15 - stage_progress * 0.10  # 0.15 → 0.05
        noise = np.random.normal(0, noise_level, seismic_condition.shape)

        # Near-target accuracy with realistic imperfections
        accuracy = 0.7 + stage_progress * 0.25  # 0.7 → 0.95

        # Add some realistic geological interpretation uncertainty
        uncertainty_mask = np.random.random(seismic_condition.shape) < 0.1

        prediction = (karst_target * accuracy +
                     noise * 0.1 +
                     seismic_condition * uncertainty_mask * 0.05)

    # Ensure realistic value ranges
    prediction = np.clip(prediction, 0, 1)
    # Add a minuscule amount of noise to break mathematical perfection
    # This ensures (prediction - ground_truth) is never exactly all zeros.
    prediction += np.random.uniform(-1e-9, 1e-9, prediction.shape)

    return prediction

def normalize_for_visualization(data: np.ndarray) -> np.ndarray:
    """Normalize geological data for W&B visualization with proper contrast."""
    data_min, data_max = np.min(data), np.max(data)
    if data_max == data_min:
        return np.full_like(data, 0.5)
    normalized = (data - data_min) / (data_max - data_min)
    # Apply gamma correction for better geological feature visibility
    return np.power(normalized, 0.7)

def simulate_forward_model(x_pred: np.ndarray) -> np.ndarray:
    """
    Simulates a forward model y=f(x) to generate a predicted condition y_pred.
    In a real scenario, this would be a physics-based or learned model.
    Here, we simulate it by applying a slight blur and adding minor noise.
    """
    # Apply a simple blurring effect (convolution with a small kernel)
    kernel = np.ones((3, 3, 3)) / 27.0
    # Use scipy for the convolution if available; otherwise skip the blur
    try:
        from scipy.ndimage import convolve
        blurred = convolve(x_pred, kernel, mode='reflect')
    except ImportError:
        # Fallback if scipy is not installed: no blur is applied
        blurred = x_pred
        
    # Add a small amount of random noise
    noise = np.random.normal(0, 0.02, x_pred.shape)
    y_pred = blurred + noise
    
    return np.clip(y_pred, 0, 1)
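
The three-stage noise schedule inside `simulate_conditional_diffusion_progress` can be tabulated with the hypothetical helper `stage_noise_level` below, which restates the noise-level formulas from that function. It caps the stage-3 progress at 1.0. Without such a cap, runs with more than 10 epochs (for example, a sweep that raises `epochs`) would drive the noise level to zero or below, and `np.random.normal` rejects a negative `scale`.

```python
def stage_noise_level(epoch: int) -> float:
    """Noise standard deviation used at a given epoch (restates the simulation's schedule)."""
    if epoch <= 3:
        # Stage 1: 0.8 at epoch 0 down to 0.4 at epoch 3
        return 0.8 - (epoch / 3) * 0.4
    if epoch <= 6:
        # Stage 2: 0.4 at epoch 4 down to 0.15 at epoch 6
        stage_progress = (epoch - 4) / 2
        return 0.4 - stage_progress * 0.25
    # Stage 3: 0.15 at epoch 7 down to 0.05 at epoch 9, held there afterwards
    stage_progress = min((epoch - 7) / 2, 1.0)
    return 0.15 - stage_progress * 0.10

for epoch in range(12):
    print(epoch, round(stage_noise_level(epoch), 3))
```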

### 3D Validation: Downsampled Overview

**The Goal:** To get a quick, high-level "sanity check" of the model's 3D output across the entire volume.

**The Technique:** We take each of the high-resolution volumes (e.g., 128x128x128) and downsample them to a much smaller size (e.g., 64x64x64). These smaller volumes are then rendered side-by-side using Plotly.

**Why it's useful:**
* **Performance:** Smaller volumes render very quickly in the browser, providing immediate feedback without performance lag.
* **Full Context:** You can see the entire spatial domain at once, which is great for identifying large-scale structural problems or biases in the model's output.
* **Browser-Friendly:** This is one of the most reliable ways to visualize multiple 3D volumes on a wide range of hardware, because it keeps memory and GPU usage low.

**The Trade-off:**
* **Loss of Detail:** Fine-grained features, sharp edges, and subtle details in the geology will be blurred or lost during the downsampling process. This view is not suitable for detailed analysis.
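
To make the loss-of-detail trade-off concrete, the sketch below downsamples a 128³ volume by averaging non-overlapping 2×2×2 blocks. This block-mean approach is a simplified stand-in for `skimage.transform.resize` with `anti_aliasing=True` (which smooths and interpolates instead), and `block_mean_downsample` is a hypothetical helper. A single bright voxel, standing in for a one-voxel-wide feature, keeps only one eighth of its amplitude:

```python
import numpy as np

def block_mean_downsample(volume: np.ndarray, factor: int) -> np.ndarray:
    """Downsample a 3D volume by averaging non-overlapping factor^3 blocks."""
    nx, ny, nz = volume.shape
    if nx % factor or ny % factor or nz % factor:
        raise ValueError("Each dimension must be divisible by factor")
    # Split each axis into (blocks, within-block) and average over the within-block axes
    blocks = volume.reshape(nx // factor, factor, ny // factor, factor, nz // factor, factor)
    return blocks.mean(axis=(1, 3, 5))

# A single bright voxel, standing in for a thin geological feature
volume = np.zeros((128, 128, 128))
volume[64, 64, 64] = 1.0
small = block_mean_downsample(volume, 2)
print(small.shape, small.max())  # → (64, 64, 64) 0.125
```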

In [11]:
def create_downsampled_comparison(volumes: Dict[str, np.ndarray]) -> go.Figure:
    """Downsamples each volume and returns a side-by-side Plotly figure."""
    target_shape = (64, 64, 64)
    # preserve_range=True keeps the original amplitude range, so no rescaling is needed afterwards
    downsampled_volumes = {
        name: resize(data, target_shape, anti_aliasing=True, preserve_range=True)
        for name, data in volumes.items()
    }
    n = len(downsampled_volumes)
    fig = make_subplots(rows=1, cols=n, specs=[[{'type': 'volume'}] * n],
                        subplot_titles=list(downsampled_volumes.keys()))
    for i, (name, data) in enumerate(downsampled_volumes.items()):
        # go.Volume expects one (x, y, z) coordinate per value, so build the full index grid
        xx, yy, zz = np.meshgrid(*(np.arange(s) for s in data.shape), indexing='ij')
        fig.add_trace(go.Volume(
            x=xx.flatten(), y=yy.flatten(), z=zz.flatten(), value=data.flatten(),
            isomin=float(np.min(data)), isomax=float(np.max(data)),
            opacity=0.1, surface_count=15,
            colorscale='RdBu' if 'Residual' in name else 'viridis'),
            row=1, col=i + 1)
    fig.update_layout(title_text="Downsampled 3D Comparison", height=400, margin=dict(t=50, b=10, l=10, r=10))
    return fig

### 3D Validation: Cropped High-Detail View

**The Goal:** To inspect a specific region of the 3D volume at full resolution, preserving all the fine details.

**The Technique:** Instead of resizing the entire volume, we extract a smaller sub-volume (e.g., a 64x64x64 cube) from a consistent location (like the center) of the original, high-resolution data. These full-detail crops are then rendered side-by-side.

**Why it's useful:**
* **Maximum Detail:** You see the data exactly as it is, with no loss of resolution. This is crucial for validating the texture, sharpness, and small-scale accuracy of the model's predictions.
* **Targeted Analysis:** It allows you to focus on a specific known feature or a problematic area identified in the downsampled view.

**The Trade-off:**
* **Loss of Context:** You are only seeing a small fraction of the total volume and lose the broader structural context. An issue might seem small in the crop but could be part of a much larger problem not visible in the limited view.
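
The center-crop arithmetic, and how much of the volume a crop actually shows, can be sketched with the hypothetical helper `center_crop_slices`. It follows the same start/end logic as the cropping function below, but raises an error instead of falling back to the full volume when the crop does not fit:

```python
from typing import Tuple

def center_crop_slices(shape: Tuple[int, ...], crop_size: int) -> Tuple[slice, ...]:
    """Slices selecting a crop_size cube centered in a volume of the given shape."""
    if any(dim < crop_size for dim in shape):
        raise ValueError("crop_size exceeds a volume dimension")
    starts = [dim // 2 - crop_size // 2 for dim in shape]
    return tuple(slice(start, start + crop_size) for start in starts)

shape = (128, 128, 128)
slices = center_crop_slices(shape, 64)
# Fraction of the full volume that the crop covers
visible_fraction = 64 ** 3 / (128 ** 3)
print(slices, visible_fraction)
```

For a 128³ volume, a 64³ crop covers only one eighth of the data, which is why this view loses the broader structural context.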

In [10]:
def create_cropped_comparison(volumes: Dict[str, np.ndarray]) -> go.Figure:
    """Crops the center of each volume and returns a side-by-side Plotly figure."""
    crop_size = 64
    cropped_volumes = {}
    for name, data in volumes.items():
        if all(dim >= crop_size for dim in data.shape):
            # Extract a full-resolution crop_size^3 cube centered in the volume
            start = [dim // 2 - crop_size // 2 for dim in data.shape]
            cropped_volumes[name] = data[tuple(slice(s, s + crop_size) for s in start)]
        else:
            # Volume is smaller than the crop; show it unchanged
            cropped_volumes[name] = data
    n = len(cropped_volumes)
    fig = make_subplots(rows=1, cols=n, specs=[[{'type': 'volume'}] * n],
                        subplot_titles=list(cropped_volumes.keys()))
    for i, (name, data) in enumerate(cropped_volumes.items()):
        # go.Volume expects one (x, y, z) coordinate per value, so build the full index grid
        xx, yy, zz = np.meshgrid(*(np.arange(s) for s in data.shape), indexing='ij')
        fig.add_trace(go.Volume(
            x=xx.flatten(), y=yy.flatten(), z=zz.flatten(), value=data.flatten(),
            isomin=float(np.min(data)), isomax=float(np.max(data)),
            opacity=0.1, surface_count=15,
            colorscale='RdBu' if 'Residual' in name else 'viridis'),
            row=1, col=i + 1)
    fig.update_layout(title_text="Cropped (Full-Res) 3D Comparison", height=400, margin=dict(t=50, b=10, l=10, r=10))
    return fig

### 3D Validation: High-Fidelity Single Volume Render

**The Goal:** To generate the highest-quality interactive visualization possible for a single, critical 3D volume.

**The Technique:** We use a specialized scientific visualization library like PyVista or ipyvolume. These libraries are built for performance and offer advanced rendering features. The process involves creating a render, exporting it to a self-contained HTML file, and logging that file to W&B.

**Why it's useful:**
* **Best Visual Quality:** These libraries use sophisticated rendering techniques (e.g., volume ray casting) to produce much clearer and more detailed visualizations than general-purpose plotting tools.
* **Advanced Controls:** They often provide more advanced tools for manipulating the view, such as adjusting color maps and lighting, which can be crucial for geological interpretation.

**The Trade-off:**
* **Single-Volume Focus:** This method is best for inspecting one volume at a time. Comparing multiple volumes requires logging separate viewers, which can be less convenient than a side-by-side view.
* **Larger Artifacts:** The resulting HTML files can be larger than a simple Plotly JSON object.
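
`create_pyvista_render` flattens the volume with `order="F"` before assigning it to `pv.ImageData`. This is because VTK image data numbers its points with the x index varying fastest, then y, then z. Assuming the NumPy axes (0, 1, 2) map to VTK's (x, y, z), the NumPy-only check below confirms that Fortran order matches that numbering and NumPy's default C order does not. `vtk_point_id` is a hypothetical helper written for this check.

```python
import numpy as np

nx, ny, nz = 4, 3, 2
volume = np.arange(nx * ny * nz, dtype=float).reshape(nx, ny, nz)

def vtk_point_id(i: int, j: int, k: int) -> int:
    """Linear point index in VTK image data: x varies fastest, then y, then z."""
    return i + nx * (j + ny * k)

flat_f = volume.flatten(order="F")  # the ordering create_pyvista_render uses
flat_c = volume.flatten()           # NumPy's default C ordering, for comparison

indices = [(i, j, k) for i in range(nx) for j in range(ny) for k in range(nz)]
fortran_matches = all(flat_f[vtk_point_id(i, j, k)] == volume[i, j, k] for i, j, k in indices)
c_matches = all(flat_c[vtk_point_id(i, j, k)] == volume[i, j, k] for i, j, k in indices)
print(fortran_matches, c_matches)  # → True False
```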

In [15]:
def create_pyvista_render(volume: np.ndarray, title: str) -> str:
    """Renders a single volume using PyVista and returns the HTML as a string."""
    try:
        # VTK image data expects point values in Fortran (x-fastest) order
        grid = pv.ImageData(dimensions=volume.shape)
        grid["values"] = volume.flatten(order="F")
        plotter = pv.Plotter(off_screen=True)
        plotter.add_volume(grid, cmap='RdBu' if 'Residual' in title else 'viridis', opacity='sigmoid', shade=True)
        plotter.add_axes()
        plotter.camera_position = 'iso'
        html_filename = f"{title}.html"
        plotter.export_html(html_filename)
        plotter.close()
        with open(html_filename, 'r') as f:
            html_content = f.read()
        os.remove(html_filename)
        return html_content
    except Exception as e:
        return f"<p>PyVista rendering failed: {e}</p>"

def create_ipyvolume_render(volume: np.ndarray, title: str) -> str:
    """Renders a single volume using ipyvolume and returns the HTML as a string."""
    try:
        ipv.clear()
        ipv.quickvolshow(volume, level_width=0.1, opacity=0.03,
                         data_min=np.min(volume), data_max=np.max(volume))
        html_filename = f"{title}_ipv.html"
        ipv.save(html_filename)
        with open(html_filename, 'r') as f:
            html_content = f.read()
        os.remove(html_filename)
        return html_content
    except Exception as e:
        return f"<p>ipyvolume rendering failed: {e}</p>"


## 3. The Core Training Function with W&B Integration

This function encapsulates the entire training and evaluation workflow. It serves as a blueprint for a reproducible, auditable, and transparent ML lifecycle. Each section is clearly marked to show how a specific W&B feature is used to track, version, and evaluate the model, turning a standard training script into an enterprise-ready system of record.

In [16]:
def train_conditional_diffusion(config=None):
    """
    Train a conditional diffusion model for geological structure generation.
    This function is instrumented with W&B to track the entire lifecycle.
    """
    # ===================================================================
    # 1. W&B SETUP: Initialize Run and Centralize Configuration
    #    - wandb.init() starts a new run to track this experiment.
    #    - wandb.config holds the hyperparameters; in a sweep, W&B overrides them for each run.
    # ===================================================================
    # Fall back to the example config defined earlier when none is passed in
    if config is None:
        config = example_config

    # Initialize a new W&B run
    run = wandb.init(
        entity=ENTITY,
        project=PROJECT,
        config=config, 
        # job_type and tags are powerful organizational tools that make it easy
        # to filter, group, and compare runs across a large project.
        job_type="training",
        tags=["conditional-diffusion"]
    )

    # Use the config from W&B (this allows sweeps to override values)
    config = wandb.config

    # ===================================================================
    # 2. W&B ARTIFACTS: Versioning and Loading the Dataset
    #    - run.use_artifact() declares a dependency on a specific dataset version.
    #    - This creates a lineage graph, giving us a full audit trail.
    # ===================================================================
    # Load the CigKarst dataset from the W&B Registry.
    # run.use_artifact() links this run to the exact dataset version it used, which
    # W&B turns into a visual data-to-model lineage graph for auditing and debugging.
    artifact = run.use_artifact('wandb-registry-dataset/CigKarst:v0', type='dataset')
    artifact_dir = artifact.download()

    # Load geological samples from the dataset
    geological_samples = []
    metadata_path = Path(artifact_dir) / "dataset_metadata.npz"

    if metadata_path.exists():
        metadata = np.load(metadata_path, allow_pickle=True)
        samples_data = metadata['samples'].tolist()
        for sample_info in samples_data:
            seismic_file = Path(artifact_dir) / sample_info['seismic_file']
            karst_file = Path(artifact_dir) / sample_info['karst_file']
            if seismic_file.exists() and karst_file.exists():
                seismic_data = np.load(seismic_file)
                karst_data = np.load(karst_file)
                sample = {
                    'seismic': seismic_data['patch'],
                    'karst': karst_data['patch'],
                    'sample_id': sample_info['sample_id'],
                    'coordinates': sample_info['coordinates'],
                    'source': sample_info['source_volume']
                }
                geological_samples.append(sample)
        print(f"Loaded {len(geological_samples)} real geological samples from CigKarst")
    else:
        print("Metadata file not found in artifact")
        run.finish()
        return None

    # Select a fixed set of samples for consistent validation across runs
    good_sample_ids = ["volume_0_patch_1", "volume_0_patch_3", "volume_1_patch_0"]
    fixed_samples = [s for s in geological_samples if s['sample_id'] in good_sample_ids]
    if len(fixed_samples) != len(good_sample_ids):
        missing = sorted(set(good_sample_ids) - {s['sample_id'] for s in fixed_samples})
        raise ValueError(f"Could not find all required validation samples; missing: {missing}")
    print(f"Selected {len(fixed_samples)} pre-validated samples for tracking.")

    # Track best model
    best_validation_loss = float('inf')

    # ===================================================================
    # 3. W&B TRAINING LOOP: Logging Live Metrics
    #    - run.log() is called inside the loop to stream metrics in real-time.
    #    - This allows us to monitor model performance and system utilization live from the dashboard.
    # ===================================================================
    for epoch in range(config.get("epochs", 10)):
        print(f"Epoch {epoch}: Simulating conditional diffusion training...")
        batches_per_epoch = config.get("batches_per_epoch", 20)
        
        for batch in range(batches_per_epoch):
            step = epoch * batches_per_epoch + batch
            # Simulate training progress and calculate synthetic losses
            noise_pred_loss = generate_synthetic_loss(step)
            condition_mse = generate_synthetic_loss(step * 1.5) * 0.8
            total_loss = 0.5 * noise_pred_loss + 0.3 * condition_mse

            # Prepare log dictionary for training metrics
            log_dict = {
                "train/noise_prediction_loss": noise_pred_loss,
                "train/condition_mse": condition_mse,
                "train/total_loss": total_loss,
                "train/learning_rate": config.get("learning_rate", 1e-3) * (0.95 ** (epoch // 3)),
                "epoch": epoch
            }
            
            # SINGLE wandb.log call per training step. W&B automatically captures system
            # metrics (CPU/GPU utilization, memory) alongside your custom metrics, providing
            # a holistic view to diagnose performance bottlenecks in real-time.
            run.log(log_dict)

        # ===================================================================
        # 4. W&B VALIDATION: Logging Rich Media and Tables
        #    - At the end of each epoch, we log qualitative results.
        #    - wandb.Image() for visual comparison.
        #    - wandb.Html() for custom, interactive Plotly charts.
        #    - wandb.Table() to create structured, sortable tables of predictions.
        # ===================================================================
        val_total_loss = generate_synthetic_loss(epoch * 50) * 0.85
        
        # wandb.Table creates a rich, interactive table in the UI. This allows for
        # sorting and filtering results, and comparing images, plots, and metrics
        # side-by-side across different runs—all within a single view.
        validation_table = wandb.Table(columns=[
            "epoch", "sample_id", "seismic_input", "ground_truth",
            "prediction", "residual_map", "well_log_comparison", "condition_mse", "ssim_score", "log_correlation"])
            
        total_condition_mse = 0

        # Collect 3D visualizations for each fixed validation sample, keyed by sample_id
        visualizations_log = {}

        for sample in fixed_samples:
            sample_id = sample['sample_id']
            # Simulated 3D prediction for this sample (used for both the 3D views and the 2D slices)
            prediction = simulate_conditional_diffusion_progress(
                sample['seismic'], sample['karst'], epoch, config.get("epochs", 10))

            # Generate the volumes to visualize for this sample
            prediction_3d = prediction
            y_pred_3d = simulate_forward_model(prediction_3d)
            volumes_for_viz = {
                "1_Input_Condition_Y": sample['seismic'],
                "2_Ground_Truth_X": sample['karst'],
                "3_AI_Prediction_X_pred": prediction_3d,
                "4_AI_Condition_Y_pred": y_pred_3d,
                "5_Residual_Y_pred_minus_Y": y_pred_3d - sample['seismic']
            }
 
            # Add each view to the log dictionary under a unique key that includes the sample_id
            visualizations_log[f"3D_Views/{sample_id}/Downsampled"] = create_downsampled_comparison(volumes_for_viz)
            visualizations_log[f"3D_Views/{sample_id}/Cropped"] = create_cropped_comparison(volumes_for_viz)
            visualizations_log[f"3D_Views/{sample_id}/PyVista_Residual"] = wandb.Html(create_pyvista_render(volumes_for_viz["5_Residual_Y_pred_minus_Y"], f"Residual_{sample_id}"))
            visualizations_log[f"3D_Views/{sample_id}/ipyvolume_Prediction"] = wandb.Html(create_ipyvolume_render(volumes_for_viz["3_AI_Prediction_X_pred"], f"Prediction_{sample_id}"))

            # Prepare data for logging (2D slices, logs, etc.)
            slice_idx = sample['seismic'].shape[2] // 2
            seismic_slice = sample['seismic'][:, :, slice_idx]
            gt_slice = sample['karst'][:, :, slice_idx]
            pred_slice = prediction[:, :, slice_idx]
            residual_slice = pred_slice - gt_slice
            max_abs_val = np.max(np.abs(residual_slice))
            residual_norm = (residual_slice + max_abs_val) / (2 * max_abs_val) if max_abs_val > 0 else np.zeros_like(residual_slice)
            
            # Metrics on the central slice (prediction vs. ground-truth karst)
            condition_mse = np.mean((pred_slice - gt_slice) ** 2)
            total_condition_mse += condition_mse
            ssim_score = ssim(gt_slice, pred_slice, data_range=1.0)
            
            # Well Logs
            well_log_depth = np.arange(prediction.shape[0]) * 25
            well_x, well_y = prediction.shape[1] // 2, prediction.shape[2] // 2
            gt_well_log = sample['karst'][:, well_x, well_y]
            pred_well_log = prediction[:, well_x, well_y]
            log_correlation = np.corrcoef(gt_well_log, pred_well_log)[0, 1] if np.std(gt_well_log) > 0 and np.std(pred_well_log) > 0 else 0.0
            
            # Add data to the validation table
            validation_table.add_data(
                epoch,
                sample['sample_id'],
                wandb.Image(normalize_for_visualization(seismic_slice)),
                wandb.Image(normalize_for_visualization(gt_slice)),
                wandb.Image(normalize_for_visualization(pred_slice)),
                wandb.Image(residual_norm, caption="Residual Map"),
                plot_well_log_comparison(gt_well_log, pred_well_log, well_log_depth),
                condition_mse,
                ssim_score,
                log_correlation
            )

        # Log all epoch-level validation data in a single call, once per epoch
        epoch_log_data = {"val/total_loss": val_total_loss, "val/avg_condition_mse": total_condition_mse / len(fixed_samples), "advanced_validation_table": validation_table, "epoch": epoch}
        epoch_log_data.update(visualizations_log) # Add the dictionary of 3D views
        run.log(epoch_log_data)

        # ===================================================================
        # 5. W&B MODEL REGISTRY: Versioning Models as Artifacts
        #    - We check if the model has improved and, if so, save it.
        #    - wandb.Artifact() creates a versioned package of model files and metadata.
        #    - run.link_artifact() registers the model in the W&B Model Registry,
        #      assigning it an alias like "best" or "staging" for easy retrieval.
        # ===================================================================
        if val_total_loss < best_validation_loss:
            best_validation_loss = val_total_loss
            print(f"New best model at epoch {epoch}! Validation loss: {val_total_loss:.4f}")
            
            checkpoint_artifact = wandb.Artifact(
                name="conditional-diffusion-checkpoint",
                type="model",
                description=f"Best conditional diffusion model - epoch {epoch}",
                metadata={
                    "epoch": epoch,
                    "validation_loss": round(val_total_loss, 4),
                    "dataset_artifact": artifact.name
                }
            )
            
            checkpoint_path = f"best_model_epoch_{epoch}.pth"
            Path(checkpoint_path).write_text(f"Best model checkpoint at epoch {epoch}")
            checkpoint_artifact.add_file(checkpoint_path)
            
            # Log the artifact and link it to the registry
            logged_artifact = run.log_artifact(checkpoint_artifact, aliases=["best"])
           
            # run.link_artifact() registers the model in the W&B Model Registry. Aliases
            # like "staging" or "production" create pointers for CI/CD systems, automating
            # the path from training to deployment.
            run.link_artifact(
                artifact=logged_artifact,
                target_path="wandb-registry-model/conditional-diffusion",
                aliases=["staging"]
            )
            os.remove(checkpoint_path)

    # ===================================================================
    # 6. W&B CLEANUP: Finalizing the Run
    #    - run.finish() marks the run as complete and uploads any remaining data.
    # ===================================================================
    run.finish()
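
The residual map and well-log correlation in the validation table rely on two small guards: a symmetric normalization of the residual, so equal-magnitude over- and under-predictions sit at equal distances from the midpoint, and a correlation that falls back to 0.0 when a log is constant. Below is a minimal standalone sketch of both, using the hypothetical helper names `normalize_residual` and `safe_correlation`. One deliberate difference from the training code: an all-zero residual maps to 0.5 (neutral) rather than 0.0.

```python
import numpy as np

def normalize_residual(residual: np.ndarray) -> np.ndarray:
    """Map residuals symmetrically into [0, 1] so that zero error lands at 0.5."""
    max_abs = np.max(np.abs(residual))
    if max_abs == 0:
        return np.full_like(residual, 0.5, dtype=float)
    return (residual + max_abs) / (2 * max_abs)

def safe_correlation(a: np.ndarray, b: np.ndarray) -> float:
    """Pearson correlation of two logs, or 0.0 when either log is constant."""
    if np.std(a) == 0 or np.std(b) == 0:
        return 0.0
    return float(np.corrcoef(a, b)[0, 1])

print(normalize_residual(np.array([-2.0, 0.0, 2.0])))
print(safe_correlation(np.arange(5.0), 2 * np.arange(5.0) + 1))
```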

## 4. Execute a Single Training Run

Now we call our main function to execute a single, baseline training run.

**Action:** When you run this cell, click the W&B link that appears in the output. This will take you to the live dashboard where you can see all the metrics, images, and tables being logged in real-time. This is the central hub for our experiment.

What to look for in the UI:

- **Live Metrics:** Watch the loss curves update in real time; there is no need to wait for the job to finish.
- **System Monitoring:** Open the System tab to see CPU/GPU utilization charts that W&B captures automatically.
- **Interactive Tables:** At the end of each epoch, a new version of `advanced_validation_table` is logged with one row per validation sample. Click the images to expand them and hover over the charts to interact.

In [17]:
# Train with default config
train_conditional_diffusion()

Run history (sparkline charts): `epoch`, `train/condition_mse`, `train/learning_rate`, `train/noise_prediction_loss`, `train/total_loss`

Run summary:

| Metric | Value |
|---|---|
| epoch | 0.0 |
| train/condition_mse | 1.03949 |
| train/learning_rate | 0.001 |
| train/noise_prediction_loss | 1.37596 |
| train/total_loss | 0.99983 |


[34m[1mwandb[0m:   49 of 49 files downloaded.  


Loaded 24 real geological samples from CigKarst
Selected 3 pre-validated samples for tracking.
Epoch 0: Simulating conditional diffusion training...
New best model at epoch 0! Validation loss: 1.3475
Epoch 1: Simulating conditional diffusion training...
New best model at epoch 1! Validation loss: 0.9296
Epoch 2: Simulating conditional diffusion training...
New best model at epoch 2! Validation loss: 0.6674
Epoch 3: Simulating conditional diffusion training...
New best model at epoch 3! Validation loss: 0.5076
Epoch 4: Simulating conditional diffusion training...
New best model at epoch 4! Validation loss: 0.3539
Epoch 5: Simulating conditional diffusion training...
New best model at epoch 5! Validation loss: 0.2714
Epoch 6: Simulating conditional diffusion training...
New best model at epoch 6! Validation loss: 0.1898
Epoch 7: Simulating conditional diffusion training...
Epoch 8: Simulating conditional diffusion training...
New best model at epoch 8! Validation loss: 0.1119
Epoch 9: Si

Run history:

| Metric | History |
|---|---|
| epoch | ▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▄▄▅▅▆▆▆▆▆▆▆▆▆▇▇███ |
| train/condition_mse | ████▇▇▆▆▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁ |
| train/learning_rate | █████████▆▆▆▆▆▆▆▆▆▆▆▆▆▃▃▃▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁ |
| train/noise_prediction_loss | ███▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁ |
| train/total_loss | ██▆▆▆▆▆▆▆▅▄▄▄▄▄▄▃▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁ |
| val/avg_condition_mse | █▇██▅▅▄▁▁▁ |
| val/total_loss | █▆▄▃▂▂▁▁▁▁ |

Run summary:

| Metric | Value |
|---|---|
| epoch | 9.0 |
| train/condition_mse | 0.21489 |
| train/learning_rate | 0.00086 |
| train/noise_prediction_loss | 0.4702 |
| train/total_loss | 0.29957 |
| val/avg_condition_mse | 0.00075 |
| val/total_loss | 0.116 |
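
The training log printed by this cell reports "New best model" only when validation loss drops below the best value seen so far (epoch 7 printed no such message). This is presumably the point at which the checkpoint is logged and linked. A minimal sketch of that best-so-far bookkeeping follows. `new_best_epochs` is a hypothetical helper, and the loss values in the example are illustrative, not taken from this run.

```python
from typing import List


def new_best_epochs(val_losses: List[float]) -> List[int]:
    """Return the epochs at which validation loss improves on the best so far."""
    best = float("inf")
    improved = []
    for epoch, loss in enumerate(val_losses):
        if loss < best:  # strict improvement only; ties do not count
            best = loss
            improved.append(epoch)  # a checkpoint would be saved here
    return improved


# Illustrative values only: epoch 2 is worse than epoch 1, so it is skipped.
print(new_best_epochs([1.3, 0.9, 0.95, 0.7]))  # [0, 1, 3]
```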


## 5. Hyperparameter Sweeps: Enterprise-Scale Optimization

Running experiments one at a time by hand does not scale. To find the optimal model, we need to explore the hyperparameter space systematically. W&B Sweeps automate this process and are fully integrated with the rest of the platform, so every trial is tracked like any other run.

Instead of writing custom loops or relying on external optimization libraries, you define a search strategy in a simple configuration (here, a Python dictionary). W&B then coordinates the search, distributes jobs to any number of agents, and provides visualizations to track the results in real time.

First, we define a `sweep_config`. This configuration specifies the search strategy (Bayesian), the metric to optimize (validation loss), and the range of hyperparameters to test. We also include an early termination strategy to save compute resources by stopping underperforming runs.
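
Hyperband's brackets are the checkpoints at which the sweep controller compares runs and may stop the laggards. According to the W&B sweep configuration documentation, brackets fall at `min_iter`, `min_iter * eta`, `min_iter * eta**2`, and so on. `eta` defaults to 3, and an iteration is one logged value of the target metric. The sketch below computes that schedule for 10-epoch runs. `hyperband_brackets` is a hypothetical helper; it models only the schedule, not the controller's stopping decision.

```python
from typing import List


def hyperband_brackets(min_iter: int, max_iter: int, eta: int = 3) -> List[int]:
    """Iterations at which Hyperband evaluates runs, up to max_iter (inclusive)."""
    if min_iter < 1 or eta < 2:
        raise ValueError("min_iter must be >= 1 and eta must be >= 2")
    brackets = []
    it = min_iter
    while it <= max_iter:
        brackets.append(it)
        it *= eta  # each bracket is eta times later than the previous one
    return brackets


# val/total_loss is logged once per epoch, and each run lasts 10 epochs.
print(hyperband_brackets(min_iter=3, max_iter=10))  # [3, 9]
```

With only 10 epochs per run, runs can be cut at two points, after the 3rd and 9th epoch. A larger `epochs` value would add later brackets.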

In [None]:
# 1. Define Sweep Configuration
sweep_config = {
    "method": "bayes",  # Bayesian optimization for efficient search
    "metric": {
        "name": "val/total_loss",
        "goal": "minimize"
    },
    "parameters": {
        # Key training parameters
        "learning_rate": {
            "distribution": "log_uniform_values",
            "min": 1e-4,
            "max": 5e-3
        },
        "batch_size": {
            "values": [4, 8, 16]
        },

        # Diffusion-specific parameters
        "noise_schedule": {
            "values": ["cosine", "linear", "sigmoid"]
        },
        "conditioning_strength": {
            "distribution": "uniform",
            "min": 0.6,
            "max": 0.95
        },

        # Fixed parameters for the workshop
        "epochs": {
            "value": 10
        },
        "batches_per_epoch": {
            "value": 20
        }
    },
    # Add early termination to use resources more efficiently. This tells the
    # W&B sweep controller to automatically stop poorly performing runs early,
    # saving significant compute time and cost.
    "early_terminate": {
        "type": "hyperband",
        # Brackets start at the 3rd logged value of the target metric (here,
        # after the 3rd epoch) and recur at min_iter * eta**k (eta defaults
        # to 3). Underperforming runs can be stopped at each bracket.
        "min_iter": 3,
    }
}

# 2. Initialize the Sweep
# This command creates the sweep controller in the W&B cloud. It acts as a
# central coordinator that agents can query for new jobs.
sweep_id = wandb.sweep(
    sweep=sweep_config,
    entity=ENTITY,
    project=PROJECT
)

print(f"✅ Sweep created successfully! Sweep ID: {sweep_id}")
print(f"🧹 View and manage your sweep here: https://wandb.ai/{ENTITY}/{PROJECT}/sweeps/{sweep_id}")
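
The `learning_rate` entry uses `log_uniform_values`, which treats the logarithm of the value as uniformly distributed between `log(min)` and `log(max)`. For random search that is the sampling distribution; Bayesian search uses it to define the search space. Either way, trials spread evenly across orders of magnitude instead of crowding near `max`. The standard-library sketch below reproduces the sampling idea locally. `sample_log_uniform` is a hypothetical helper, not the sweep controller's code.

```python
import math
import random
from typing import List


def sample_log_uniform(lo: float, hi: float, n: int, seed: int = 0) -> List[float]:
    """Draw n values whose natural log is uniform on [log(lo), log(hi)]."""
    if not 0 < lo < hi:
        raise ValueError("need 0 < lo < hi")
    rng = random.Random(seed)  # local RNG keeps the example reproducible
    log_lo, log_hi = math.log(lo), math.log(hi)
    return [math.exp(rng.uniform(log_lo, log_hi)) for _ in range(n)]


# Same bounds as the learning_rate entry in sweep_config.
samples = sample_log_uniform(1e-4, 5e-3, n=10_000)

# About half the draws land below the geometric mean of the bounds (~7.1e-4);
# a plain uniform draw on [1e-4, 5e-3] would put roughly 88% of them above it.
geo_mean = math.sqrt(1e-4 * 5e-3)
share_below = sum(s < geo_mean for s in samples) / len(samples)
print(f"share below geometric mean: {share_below:.2f}")
```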

## 6. Launch a W&B Agent

Now that the sweep is initialized, we launch an agent. The agent is a simple, stateless worker that will:
1.  Ask the W&B sweep server for a set of hyperparameters.
2.  Run our `train_conditional_diffusion` function with those hyperparameters.
3.  Repeat until the sweep is finished.

This pull-based architecture scales horizontally: you can launch agents on your laptop, on a fleet of cloud VMs, or in a Kubernetes cluster, and they all coordinate through the central sweep controller to work on the same optimization problem.

We will start one agent to run 5 experiments for this demo.

In [None]:
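# 3. Run the Sweep Agent
# This single command connects a worker to the sweep controller. The agent
# automatically fetches a configuration, executes the training function,
# and reports the results back, with no manual orchestration required.
# Passing entity and project makes the agent look up the sweep in the same
# project where it was created, regardless of local defaults.
wandb.agent(sweep_id, function=train_conditional_diffusion, entity=ENTITY, project=PROJECT, count=5)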
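
The three agent steps listed above form a pull loop: request a configuration, run the function, report the result, and repeat until `count` runs are done or the controller has nothing left to hand out. The toy, in-process sketch below mimics that protocol with a random-search controller. `ToyController`, `run_agent`, and `fake_train` are hypothetical stand-ins. The real agent talks to the W&B server, and this sweep uses Bayesian search.

```python
import random
from typing import Callable, Dict, List, Optional, Tuple

Config = Dict[str, object]


class ToyController:
    """Hypothetical in-process stand-in for the W&B sweep controller."""

    def __init__(self, max_runs: int, seed: int = 0):
        self.rng = random.Random(seed)
        self.max_runs = max_runs
        self.issued = 0
        self.results: List[Tuple[Config, float]] = []

    def next_config(self) -> Optional[Config]:
        if self.issued >= self.max_runs:
            return None  # sweep finished: nothing left to hand out
        self.issued += 1
        return {
            "batch_size": self.rng.choice([4, 8, 16]),
            "noise_schedule": self.rng.choice(["cosine", "linear", "sigmoid"]),
        }

    def report(self, config: Config, metric: float) -> None:
        self.results.append((config, metric))


def run_agent(controller: ToyController, function: Callable[[Config], float], count: int) -> int:
    """Pull configs and run `function` until `count` runs finish or the sweep ends."""
    runs = 0
    while runs < count:
        config = controller.next_config()
        if config is None:
            break
        controller.report(config, function(config))
        runs += 1
    return runs


def fake_train(config: Config) -> float:
    # Placeholder objective standing in for train_conditional_diffusion.
    return 1.0 / int(config["batch_size"])


controller = ToyController(max_runs=100)
print(run_agent(controller, fake_train, count=5))  # 5
```

Calling `run_agent` repeatedly on the same controller mirrors several agents sharing one sweep. Because the controller owns all the state, each agent can stay stateless.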

## 7. Programmatic Executive Reports: From Model to Boardroom

The final step is to bridge the gap between technical results and business stakeholders. This is often the "last mile" problem in ML, where valuable insights get lost in translation. Manually creating reports is slow, error-prone, and creates static documents that are quickly outdated.

The W&B Report API lets us generate dynamic, data-driven reports programmatically, directly from our experiments. It turns raw experimental results into a persistent, shareable decision-making asset: a transparent, repeatable, and up-to-date summary for governance and high-stakes investment decisions.

In [None]:
print("Attempting to create a programmatic report...")
print(f"Using entity: '{ENTITY}', project: '{PROJECT}'")

try:
    # --- 1. Find the Best Run ---
    # We'll create the report based on the best-performing run from our latest experiments.
    # The W&B Public API allows powerful, programmatic querying of all experimental data.
    # Here, we filter for our tagged, finished runs and then pick the best by the target metric.
    api = wandb.Api()
    runs = list(api.runs(
        path=f"{ENTITY}/{PROJECT}",
        order="-created_at",
        filters={"tags": {"$in": ["conditional-diffusion"]}, "state": "finished"}
    ))

    if not runs:
        raise ValueError("No finished runs with the 'conditional-diffusion' tag found.")

    # Runs that never logged the metric are treated as infinitely bad.
    best_run = min(runs, key=lambda run: run.summary.get("val/total_loss", float("inf")))
    best_loss = best_run.summary.get("val/total_loss")
    if best_loss is None:
        raise ValueError("None of the matching runs logged 'val/total_loss'.")
    print(f"Generating report based on best run: {best_run.name} (ID: {best_run.id})")
    print(f"Best validation loss: {best_loss:.4f}")

    # --- 2. Create the Report Object ---
    report = wr.Report(
        entity=ENTITY,
        project=PROJECT,
        title=f"Geological ML Model Performance - {time.strftime('%Y-%m-%d')}",
        description=f"Automated summary for the conditional diffusion model. Best run: {best_run.name}.",
    )

    # --- 3. Define the Report Structure ---
    report.blocks = [
        wr.H1("Executive Summary: Geological Interpretation Model"),
        # MarkdownBlock renders the **bold** markup; a plain wr.P would show it literally.
        wr.MarkdownBlock(text=(
            "This report summarizes the performance of the conditional diffusion model "
            "trained for geological structure generation. "
            f"The model was trained for **{best_run.config.get('epochs', 'N/A')} epochs**, "
            f"achieving a final validation loss of **{best_run.summary.get('val/total_loss', 0):.4f}**.\n\n"
            f"This automated report was generated on {time.strftime('%B %d, %Y')}."
        )),
        
        wr.H2("Key Performance Metrics"),
        wr.P("The following charts show the model's learning progression and final performance on the validation set."),
        
        # --- 4. Add Panels for Metrics ---
        # PanelGrids pull visualizations directly from one or more runs. The report
        # remains a live document; if you re-run the experiment, the charts in the
        # report can be updated automatically.
        wr.PanelGrid(
            runsets=[wr.Runset(entity=ENTITY, project=PROJECT, name=best_run.id)],
            panels=[
                wr.LinePlot(
                    title="Validation Loss Over Epochs",
                    x="epoch",
                    y=["val/total_loss"],
                    layout={'w': 24, 'h': 12} # Make this chart wider
                ),
                wr.ScalarChart(
                    title="Best Final Validation Loss",
                    metric="val/total_loss",
                ),
            ]
        ),
        
        wr.H2("Detailed Validation Analysis"),
        wr.P("The table below provides a detailed, interactive breakdown of the model's performance on our validation samples from the final epoch."),
        
        # --- 5. Add Panel for the Validation Table ---
        # We can embed the entire interactive media table we created during validation
        # directly into the final report for detailed, qualitative analysis.
        wr.PanelGrid(
            runsets=[wr.Runset(entity=ENTITY, project=PROJECT, name=best_run.id)],
            panels=[
                wr.WeavePanelSummaryTable(
                    table_name="advanced_validation_table", 
                    layout={'w': 24, 'h': 16}
                )
            ]
        ),
        
        wr.H2("Model Governance & Next Steps"),
        # MarkdownBlock renders the bullet list, bold labels, and link as markdown.
        wr.MarkdownBlock(text=(
            "The best model artifact from this run has been versioned and linked to the "
            "Model Registry, ensuring a complete audit trail.\n\n"
            f"- **Best Model Run Link**: You can view the full experiment details here: "
            f"[`{best_run.name}`](https://wandb.ai/{ENTITY}/{PROJECT}/runs/{best_run.id})\n"
            "- **Model Artifact Name**: `conditional-diffusion-checkpoint`\n\n"
            "**Next Steps**:\n\n"
            "1. Review model performance with geological subject matter experts.\n"
            "2. Promote the model from the `staging` alias to a `production` alias in the registry for further testing.\n"
            "3. Begin planning for deployment to a validation environment.\n"
        ))
    ]

    # --- 6. Save and Publish the Report ---
    # report.save() publishes the report to your W&B workspace, generating a
    # stable, shareable URL for stakeholders.
    report.save()

    print(f"\n✅ Executive report created successfully!")
    print(f"📊 Title: {report.title}")
    print(f"🔗 URL: {report.url}")
    print(f"📧 Ready for stakeholder distribution.")

except Exception as e:
    print(f"\n⚠️ Report creation failed: {e}")
    print("This might be due to:")
    print("- No finished runs found in the project with the correct tag.")
    print("- Incorrect API permissions or W&B team/enterprise settings.")
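
The best-run selection in the cell above reduces to a minimum over run summaries. The pure-Python sketch below isolates that logic on plain dictionaries so it can be tested without a W&B connection, and adds a `goal` switch that matches the sweep's `minimize`/`maximize` setting. `pick_best` and the sample summaries are hypothetical.

```python
import math
from typing import Dict, Mapping, Optional


def pick_best(summaries: Dict[str, Mapping[str, float]], metric: str,
              goal: str = "minimize") -> Optional[str]:
    """Return the ID of the run with the best `metric`, or None if no run logged it."""
    if goal not in ("minimize", "maximize"):
        raise ValueError("goal must be 'minimize' or 'maximize'")
    sign = 1.0 if goal == "minimize" else -1.0
    # Runs without the metric (or with NaN) are skipped rather than scored as
    # infinity, so a project where no run logged it yields None.
    scored = {
        run_id: sign * summary[metric]
        for run_id, summary in summaries.items()
        if metric in summary and not math.isnan(summary[metric])
    }
    if not scored:
        return None
    return min(scored, key=lambda run_id: scored[run_id])


# Hypothetical summaries; in practice these would come from run.summary.
summaries = {
    "run-a": {"val/total_loss": 0.31},
    "run-b": {"val/total_loss": 0.12},
    "run-c": {},  # a run that never logged the metric
}
print(pick_best(summaries, "val/total_loss"))  # run-b
```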