# SPECFEM3D Model Design and Visualization

This notebook provides an interactive environment for designing and visualizing seismic models, sources, and receiver configurations for SPECFEM3D simulations.

## Workflow:
1. Create and configure velocity model parameters
2. Define source location and characteristics
3. Set up receiver arrays (geophones and DAS fiber)
4. Visualize the complete model setup
5. Export settings for simulation

In [None]:
import os
import sys
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
from pathlib import Path
from IPython.display import display, HTML

# Add parent directory to path for imports
sys.path.append('..')

# Import project modules
from param_manager import ParamManager
from src.utils.logging_utils import setup_logging

# Initialise logging through the project's setup_logging helper (INFO verbosity)
logger = setup_logging(level='INFO')

## 1. Project Paths Setup

In [None]:
# Define paths.
# "__file__" is deliberately a string literal: notebooks have no __file__, so
# abspath("__file__") resolves against the current working directory and
# dirname() of it yields that directory (normally the notebook's folder).
notebook_dir = os.path.dirname(os.path.abspath("__file__"))
project_root = os.path.abspath(os.path.join(notebook_dir, ".."))
specfem_dir = os.path.expanduser("~/specfem3d")  # Path to SPECFEM3D installation
param_sets_dir = os.path.join(project_root, "parameter_sets")
output_dir = os.path.join(project_root, "data/synthetic/raw/simulation_design")

# Ensure output directory exists
Path(output_dir).mkdir(parents=True, exist_ok=True)

# List available parameter sets. os.listdir() returns entries in arbitrary
# order, so sort them to keep the numbered list stable between runs.
if os.path.isdir(param_sets_dir):
    param_files = sorted(name for name in os.listdir(param_sets_dir) if name.endswith('.json'))
else:
    param_files = []
    print(f"Parameter set directory not found: {param_sets_dir}")
print("Available parameter templates:")
for i, param_file in enumerate(param_files):
    print(f"  {i+1}. {param_file}")

## 2. Load Parameter Template

In [None]:
# Choose the JSON template to start from (edit template_name to switch)
template_name = "standard_simulation.json"
template_path = os.path.join(param_sets_dir, template_name)

# Parse the template into the mutable `params` dict that the widgets below
# read their defaults from and write their values back into
with open(template_path, 'r') as f:
    params = json.load(f)

# Helper used later to write the SOURCE, Par_file and Mesh_Par_file outputs
param_manager = ParamManager()

# Summarise the template; extents are stored in metres and shown in km
print("Model dimensions:")
for axis_label, extent_key in (("X", "xmax"), ("Y", "ymax"), ("Z", "zmax")):
    print(f"  {axis_label} (km): {params[extent_key]/1000:.2f}")
print("\nTime settings:")
print(f"  Simulation time (s): {params['simulation_time']}")
print(f"  Time step (s): {params['time_step']}")

## 3. Model Configuration

### 3.1 Define Model Structure

In [None]:
# Define model structure with interactive widgets
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# Model-extent sliders: params holds metres, the sliders show kilometres
x_slider = widgets.FloatSlider(
    value=params['xmax'] / 1000,
    min=1,
    max=10,
    step=0.5,
    description='X (km):',
)
y_slider = widgets.FloatSlider(
    value=params['ymax'] / 1000,
    min=1,
    max=10,
    step=0.5,
    description='Y (km):',
)
z_slider = widgets.FloatSlider(
    value=params['zmax'] / 1000,
    min=0.5,
    max=5,
    step=0.5,
    description='Z (km):',
)

# Layer count; always starts at 3 (not read from the template)
layer_slider = widgets.IntSlider(
    value=3,
    min=1,
    max=5,
    step=1,
    description='Layers:',
)

display(x_slider, y_slider, z_slider, layer_slider)

def update_model_dimensions():
    """Copy the extent and layer-count sliders into ``params``.

    The sliders are in kilometres; ``xmax``/``ymax``/``zmax`` are stored in
    metres. ``num_layers`` is read later when the layer widgets are rebuilt.
    """
    km_to_m = 1000
    for key, slider in (('xmax', x_slider), ('ymax', y_slider), ('zmax', z_slider)):
        params[key] = slider.value * km_to_m

    # Consumed by the layer-property cell
    params['num_layers'] = layer_slider.value

    dims = f"{x_slider.value:.1f} km × {y_slider.value:.1f} km × {z_slider.value:.1f} km"
    print(f"Model dimensions updated to: {dims}")
    print(f"Number of layers: {layer_slider.value}")

# Button that commits the extent/layer-count sliders into params
update_button = widgets.Button(description="Update Dimensions")
update_button.on_click(lambda _clicked: update_model_dimensions())
display(update_button)

### 3.2 Define Layer Properties

In [None]:
def create_layer_widgets(num_layers=3):
    """Create the property sliders for ``num_layers`` layers.

    Side effect: ``params['layers']`` is seeded with evenly spaced defaults
    when missing. It is padded with extra default layers when it holds fewer
    than ``num_layers`` entries.

    Args:
        num_layers: How many layer slider groups to build.

    Returns:
        list[dict]: One dict per layer with the sliders under ``'depth'``
        (percent of ``zmax``), ``'vp'``, ``'vs'`` and ``'rho'``, plus
        ``'box'``, a VBox grouping them for display.
    """
    zmax = params['zmax']

    # Seed defaults: interfaces evenly spaced through the model, properties
    # increasing with depth.
    if 'layers' not in params:
        params['layers'] = [
            {
                'depth': (i + 1) * zmax / (num_layers + 1),
                'vp': 2000 + i * 500,
                'vs': 1000 + i * 300,
                'rho': 2000 + i * 100,
            }
            for i in range(num_layers)
        ]

    # Pad when more layers are requested than exist. Each new interface sits
    # halfway between the deepest existing interface and the model bottom.
    # Padded layers therefore get distinct depths instead of all stacking at
    # one depth, which would create zero-thickness layers.
    while len(params['layers']) < num_layers:
        deepest = max((layer.get('depth', 0.0) for layer in params['layers']), default=0.0)
        params['layers'].append({
            'depth': (deepest + zmax) / 2,
            'vp': 3500,
            'vs': 2000,
            'rho': 2500,
        })

    layer_widgets = []
    # After padding there are always at least num_layers entries
    for i, layer_i in enumerate(params['layers'][:num_layers]):
        # Interface depth is edited as a percentage of zmax; ipywidgets clamps
        # values outside [5, 95] to the slider range.
        depth_slider = widgets.FloatSlider(
            value=layer_i.get('depth', zmax * 0.5) / zmax * 100,
            min=5,
            max=95,
            step=5,
            description=f'Layer {i+1} depth (%):',
        )

        # P-wave velocity
        vp_slider = widgets.FloatSlider(
            value=layer_i.get('vp', 2000 + i * 500),
            min=1500,
            max=6000,
            step=100,
            description='Vp (m/s):',
        )

        # S-wave velocity
        vs_slider = widgets.FloatSlider(
            value=layer_i.get('vs', 1000 + i * 300),
            min=0,
            max=3500,
            step=100,
            description='Vs (m/s):',
        )

        # Density
        rho_slider = widgets.FloatSlider(
            value=layer_i.get('rho', 2000 + i * 100),
            min=1000,
            max=3000,
            step=50,
            description='ρ (kg/m³):',
        )

        layer_box = widgets.VBox([
            widgets.HTML(f"<b>Layer {i+1}</b>"),
            depth_slider,
            vp_slider,
            vs_slider,
            rho_slider,
            widgets.HTML("<hr>"),
        ])

        layer_widgets.append({
            'depth': depth_slider,
            'vp': vp_slider,
            'vs': vs_slider,
            'rho': rho_slider,
            'box': layer_box,
        })

    return layer_widgets

# Build the per-layer slider groups; the template's 'num_layers' wins, else 3
layer_widgets = create_layer_widgets(params.get('num_layers', 3))

# Render each layer's slider group
for layer_group in layer_widgets:
    display(layer_group['box'])

def update_layer_properties():
    """Copy the layer sliders into ``params['layers']`` and print them.

    The layer count comes from ``params['num_layers']``, which is set by
    ``update_model_dimensions``. If it differs from the number of existing
    layer widget groups, the widgets are rebuilt and the whole form is
    redrawn before the values are read.
    """
    global layer_widgets

    # 'num_layers' exists only after update_model_dimensions() has run. Fall
    # back to the current widget count so this button also works on its own
    # instead of raising KeyError.
    num_layers = params.get('num_layers', len(layer_widgets))
    if len(layer_widgets) != num_layers:
        layer_widgets = create_layer_widgets(num_layers)
        # Clear output and re-display
        from IPython.display import clear_output
        clear_output(wait=True)
        display_all_widgets()

    # Depth sliders hold a percentage of zmax; store absolute depths instead
    params['layers'] = [
        {
            'depth': group['depth'].value / 100 * params['zmax'],
            'vp': group['vp'].value,
            'vs': group['vs'].value,
            'rho': group['rho'].value,
        }
        for group in layer_widgets
    ]

    print("Layer properties updated")
    for i, layer in enumerate(params['layers']):
        print(f"Layer {i+1}: depth={layer['depth']:.1f}m, Vp={layer['vp']}m/s, Vs={layer['vs']}m/s, ρ={layer['rho']}kg/m³")

def display_all_widgets():
    """Re-render the full model form after the layer widgets are rebuilt.

    Shows the extent/layer-count sliders and their button, every layer's
    slider group, and finally the "Update Layer Properties" button.
    """
    dimension_controls = (x_slider, y_slider, z_slider, layer_slider)
    display(*dimension_controls)
    display(update_button)

    for group in layer_widgets:
        display(group['box'])

    display(update_layers_button)

# Button that copies the layer sliders into params
update_layers_button = widgets.Button(description="Update Layer Properties")
update_layers_button.on_click(lambda _clicked: update_layer_properties())
display(update_layers_button)

## 4. Source Configuration

In [None]:
# Source widgets. X/Y default to the model centre and step in 1/20ths of the
# extent; depth is capped at zmax/5 (sources are expected to be shallow).
source_x_slider = widgets.FloatSlider(
    value=params.get('source_x', params['xmax'] / 2),
    min=0,
    max=params['xmax'],
    step=params['xmax'] / 20,
    description='X position (m):',
)
source_y_slider = widgets.FloatSlider(
    value=params.get('source_y', params['ymax'] / 2),
    min=0,
    max=params['ymax'],
    step=params['ymax'] / 20,
    description='Y position (m):',
)
source_z_slider = widgets.FloatSlider(
    value=params.get('source_z', 10),
    min=0,
    max=params['zmax'] / 5,
    step=10,
    description='Z position (m):',
)
source_freq_slider = widgets.FloatSlider(
    value=params.get('source_freq', 10),
    min=1,
    max=50,
    step=1,
    description='Frequency (Hz):',
)
source_type_dropdown = widgets.Dropdown(
    options=['explosion', 'force_solution'],
    value=params.get('source_type', 'explosion'),
    description='Source type:',
)

display(widgets.HTML("<h3>Source Configuration</h3>"))
display(
    source_x_slider,
    source_y_slider,
    source_z_slider,
    source_freq_slider,
    source_type_dropdown,
)

def update_source_params():
    """Store the source widget values in ``params`` and print a summary."""
    params.update(
        source_x=source_x_slider.value,
        source_y=source_y_slider.value,
        source_z=source_z_slider.value,
        source_freq=source_freq_slider.value,
        source_type=source_type_dropdown.value,
    )

    summary = (
        "Source parameters updated:",
        f"  Position: ({params['source_x']}, {params['source_y']}, {params['source_z']}) m",
        f"  Frequency: {params['source_freq']} Hz",
        f"  Type: {params['source_type']}",
    )
    for line in summary:
        print(line)

# Button that copies the source widgets into params
update_source_button = widgets.Button(description="Update Source")
update_source_button.on_click(lambda _clicked: update_source_params())
display(update_source_button)

## 5. Receiver Configuration

### 5.1 Geophone Array

In [None]:
# Geophone array widgets. 'line' places Count receivers along +x; 'grid' uses
# Grid Rows x Grid Columns with the same spacing in x and y.
layout_dropdown = widgets.Dropdown(
    options=['line', 'grid'],
    value=params.get('receiver_layout', 'line'),
    description='Layout:',
)
num_receivers_slider = widgets.IntSlider(
    value=params.get('num_receivers', 20),
    min=5, max=100, step=5,
    description='Count:',
)
receiver_spacing_slider = widgets.FloatSlider(
    value=params.get('receiver_spacing', 50),
    min=10, max=200, step=10,
    description='Spacing (m):',
)
receiver_depth_slider = widgets.FloatSlider(
    value=params.get('receiver_depth', 0),
    min=0, max=100, step=5,
    description='Depth (m):',
)

# First receiver position; defaults to x = xmax/4 on the y mid-line
start_x_slider = widgets.FloatSlider(
    value=params.get('receiver_start_x', params['xmax'] / 4),
    min=0, max=params['xmax'], step=params['xmax'] / 20,
    description='Start X (m):',
)
start_y_slider = widgets.FloatSlider(
    value=params.get('receiver_start_y', params['ymax'] / 2),
    min=0, max=params['ymax'], step=params['ymax'] / 20,
    description='Start Y (m):',
)

# Grid shape (only used by the 'grid' layout)
grid_rows_slider = widgets.IntSlider(
    value=params.get('grid_rows', 5),
    min=2, max=20, step=1,
    description='Grid Rows:',
)
grid_cols_slider = widgets.IntSlider(
    value=params.get('grid_cols', 5),
    min=2, max=20, step=1,
    description='Grid Columns:',
)

display(widgets.HTML("<h3>Geophone Array Configuration</h3>"))
display(layout_dropdown, num_receivers_slider, receiver_spacing_slider, receiver_depth_slider)
display(start_x_slider, start_y_slider)
display(grid_rows_slider, grid_cols_slider)

def update_geophone_params():
    """Store the geophone widget values in ``params`` and print a summary.

    For the 'grid' layout, ``num_receivers`` is overwritten with
    ``grid_rows * grid_cols`` (the Count slider is ignored).
    """
    params.update({
        'receiver_layout': layout_dropdown.value,
        'num_receivers': num_receivers_slider.value,
        'receiver_spacing': receiver_spacing_slider.value,
        'receiver_depth': receiver_depth_slider.value,
        'receiver_start_x': start_x_slider.value,
        'receiver_start_y': start_y_slider.value,
        'grid_rows': grid_rows_slider.value,
        'grid_cols': grid_cols_slider.value,
    })

    is_grid = params['receiver_layout'] == 'grid'
    if is_grid:
        params['num_receivers'] = params['grid_rows'] * params['grid_cols']

    print("Geophone parameters updated:")
    print(f"  Layout: {params['receiver_layout']}")
    print(f"  Number of receivers: {params['num_receivers']}")
    print(f"  Spacing: {params['receiver_spacing']} m")
    print(f"  Start position: ({params['receiver_start_x']}, {params['receiver_start_y']}) m")
    print(f"  Depth: {params['receiver_depth']} m")
    if is_grid:
        print(f"  Grid: {params['grid_rows']} rows × {params['grid_cols']} columns")

# Button that copies the geophone widgets into params
update_geophone_button = widgets.Button(description="Update Geophones")
update_geophone_button.on_click(lambda _clicked: update_geophone_params())
display(update_geophone_button)

### 5.2 DAS Fiber Configuration

In [None]:
# DAS fiber widgets: channel count/spacing, gauge length, burial depth, layout
das_enabled_checkbox = widgets.Checkbox(
    value=params.get('das_enabled', True),
    description='Enable DAS fiber',
)
das_layout_dropdown = widgets.Dropdown(
    options=['straight', 'zigzag'],
    value=params.get('das_layout', 'straight'),
    description='Layout:',
)
das_channels_slider = widgets.IntSlider(
    value=params.get('das_channels', 40),
    min=10, max=200, step=10,
    description='Channels:',
)
das_spacing_slider = widgets.FloatSlider(
    value=params.get('das_spacing', 10),
    min=1, max=50, step=1,
    description='Spacing (m):',
)
das_gauge_length_slider = widgets.FloatSlider(
    value=params.get('das_gauge_length', 10),
    min=1, max=50, step=1,
    description='Gauge Length (m):',
)
das_depth_slider = widgets.FloatSlider(
    value=params.get('das_depth', 0),
    min=0, max=100, step=5,
    description='Depth (m):',
)

# Fiber start; by default x = xmax/4, offset 50 m in +y from the y mid-line
das_start_x_slider = widgets.FloatSlider(
    value=params.get('das_start_x', params['xmax'] / 4),
    min=0, max=params['xmax'], step=params['xmax'] / 20,
    description='Start X (m):',
)
das_start_y_slider = widgets.FloatSlider(
    value=params.get('das_start_y', params['ymax'] / 2 + 50),
    min=0, max=params['ymax'], step=params['ymax'] / 20,
    description='Start Y (m):',
)

display(widgets.HTML("<h3>DAS Fiber Configuration</h3>"))
display(das_enabled_checkbox)
display(das_layout_dropdown, das_channels_slider, das_spacing_slider, das_gauge_length_slider, das_depth_slider)
display(das_start_x_slider, das_start_y_slider)

def update_das_params():
    """Store the DAS widget values in ``params`` and print a summary.

    Beyond the enabled flag, details are printed only when DAS is enabled.
    """
    params.update({
        'das_enabled': das_enabled_checkbox.value,
        'das_layout': das_layout_dropdown.value,
        'das_channels': das_channels_slider.value,
        'das_spacing': das_spacing_slider.value,
        'das_gauge_length': das_gauge_length_slider.value,
        'das_depth': das_depth_slider.value,
        'das_start_x': das_start_x_slider.value,
        'das_start_y': das_start_y_slider.value,
    })

    print("DAS parameters updated:")
    print(f"  Enabled: {params['das_enabled']}")
    if not params['das_enabled']:
        return

    print(f"  Layout: {params['das_layout']}")
    print(f"  Channels: {params['das_channels']}")
    print(f"  Channel spacing: {params['das_spacing']} m")
    print(f"  Gauge length: {params['das_gauge_length']} m")
    print(f"  Start position: ({params['das_start_x']}, {params['das_start_y']}) m")
    print(f"  Depth: {params['das_depth']} m")

# Button that copies the DAS widgets into params
update_das_button = widgets.Button(description="Update DAS Fiber")
update_das_button.on_click(lambda _clicked: update_das_params())
display(update_das_button)

## 6. Model Visualization

In [None]:
def _geophone_coords():
    """Return an (N, 3) array of geophone x, y, z positions built from ``params``.

    'line': ``num_receivers`` stations along +x from (receiver_start_x,
    receiver_start_y), ``receiver_spacing`` apart.
    'grid': ``grid_rows`` x ``grid_cols`` stations, columns along +x and rows
    along +y, both ``receiver_spacing`` apart.
    All stations sit at z = ``receiver_depth``.

    Raises:
        ValueError: if ``receiver_layout`` is neither 'line' nor 'grid'.
    """
    layout = params['receiver_layout']
    x0, y0 = params['receiver_start_x'], params['receiver_start_y']
    spacing = params['receiver_spacing']

    if layout == 'line':
        n = params['num_receivers']
        x = np.linspace(x0, x0 + (n - 1) * spacing, n)
        y = np.full(n, y0, dtype=float)
    elif layout == 'grid':
        rows, cols = params['grid_rows'], params['grid_cols']
        col_x = np.linspace(x0, x0 + (cols - 1) * spacing, cols)
        row_y = np.linspace(y0, y0 + (rows - 1) * spacing, rows)
        xx, yy = np.meshgrid(col_x, row_y)
        x, y = xx.flatten(), yy.flatten()
    else:
        raise ValueError(f"Unknown receiver_layout: {layout!r}")

    z = np.full(x.shape, params['receiver_depth'], dtype=float)
    return np.column_stack((x, y, z))


def _zigzag_path(x0, y0, n, spacing, segments=4):
    """Return x, y arrays for ``n`` points on an alternating +x / +y staircase.

    The ``n - 1`` steps between consecutive points are split into
    ``segments`` legs of ``(n - 1) // segments`` steps. Leftover steps extend
    the last leg. Even legs move along +x and odd legs along +y, and every
    step has length ``spacing``. Exactly ``n`` distinct points are returned
    for any ``n >= 1``.
    """
    if n <= 0:
        return np.empty(0), np.empty(0)
    steps = np.arange(n - 1)
    steps_per_leg = max((n - 1) // segments, 1)
    leg = np.minimum(steps // steps_per_leg, segments - 1)
    along_x = leg % 2 == 0
    dx = np.where(along_x, spacing, 0.0)
    dy = np.where(along_x, 0.0, spacing)
    # Cumulative steps give positions; the leading 0 is the start point itself
    x = x0 + np.concatenate(([0.0], np.cumsum(dx)))
    y = y0 + np.concatenate(([0.0], np.cumsum(dy)))
    return x, y


def _das_coords():
    """Return a (das_channels, 3) array of DAS channel positions from ``params``.

    'straight': channels along +x from (das_start_x, das_start_y).
    'zigzag': the same channel count on a 4-leg +x/+y staircase.
    Consecutive channels are ``das_spacing`` apart in both layouts, and all
    channels sit at z = ``das_depth``.

    Raises:
        ValueError: if ``das_layout`` is neither 'straight' nor 'zigzag'.
    """
    n = params['das_channels']
    spacing = params['das_spacing']
    x0, y0 = params['das_start_x'], params['das_start_y']
    layout = params['das_layout']

    if layout == 'straight':
        x = np.linspace(x0, x0 + (n - 1) * spacing, n)
        y = np.full(n, y0, dtype=float)
    elif layout == 'zigzag':
        # Built step by step so the point count always equals n. The old
        # per-segment linspace produced 4 * (n // 4) points, which broke
        # column_stack, duplicated corner points and distorted the spacing.
        x, y = _zigzag_path(x0, y0, n, spacing, segments=4)
    else:
        raise ValueError(f"Unknown das_layout: {layout!r}")

    z = np.full(n, params['das_depth'], dtype=float)
    return np.column_stack((x, y, z))


def generate_receiver_coords():
    """Sync the receiver widgets into ``params`` and build receiver positions.

    Calls ``update_geophone_params()`` and ``update_das_params()`` first, so
    the latest widget values are used (and their summaries are printed).

    Returns:
        tuple: ``(geophone_coords, das_coords)``, each an ``(N, 3)`` array of
        x, y, z in metres. ``das_coords`` is empty, with shape ``(0, 3)``,
        when DAS is disabled.

    Raises:
        ValueError: if ``receiver_layout`` or ``das_layout`` is unknown.
    """
    update_geophone_params()
    update_das_params()

    geophone_coords = _geophone_coords()
    das_coords = _das_coords() if params['das_enabled'] else np.empty((0, 3))
    return geophone_coords, das_coords

def _plot_top_view(ax, geophone_coords, das_coords, show_das):
    """Draw the X-Y map view: model outline, source, geophones and DAS fiber."""
    ax.set_title('Top View (X-Y Plane)')
    ax.set_xlabel('X (m)')
    ax.set_ylabel('Y (m)')

    xmax, ymax = params['xmax'], params['ymax']
    ax.plot([0, xmax, xmax, 0, 0], [0, 0, ymax, ymax, 0], 'k-', alpha=0.5)
    ax.plot(params['source_x'], params['source_y'], 'r*', markersize=15, label='Source')
    if len(geophone_coords) > 0:
        ax.plot(geophone_coords[:, 0], geophone_coords[:, 1], 'bo', markersize=5, label='Geophones')
    if show_das:
        ax.plot(das_coords[:, 0], das_coords[:, 1], 'g-', linewidth=2, label='DAS Fiber')
        ax.plot(das_coords[:, 0], das_coords[:, 1], 'go', markersize=3)

    ax.legend()
    ax.grid(True)


def _layer_annotation(layer):
    """Return the Vp / Vs / rho label drawn inside a layer band."""
    return f"Vp={layer['vp']}\nVs={layer['vs']}\nρ={layer['rho']}"


def _plot_side_view(ax, sorted_layers, geophone_coords, das_coords, show_das):
    """Draw the X-Z section with shaded layer bands, source, geophones and DAS."""
    ax.set_title('Side View (X-Z Plane)')
    ax.set_xlabel('X (m)')
    ax.set_ylabel('Z (m)')

    xmax, zmax = params['xmax'], params['zmax']
    ax.plot([0, xmax, xmax, 0, 0], [0, 0, zmax, zmax, 0], 'k-', alpha=0.5)

    colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(sorted_layers)))
    # Layer i fills from the previous interface down to its own 'depth'
    z_top = 0
    for i, layer in enumerate(sorted_layers):
        ax.fill_between([0, xmax], z_top, layer['depth'],
                        color=colors[i], alpha=0.5, label=f'Layer {i+1}')
        ax.text(xmax / 2, (z_top + layer['depth']) / 2, _layer_annotation(layer),
                ha='center', va='center')
        z_top = layer['depth']

    # Below the deepest interface, the deepest layer's properties continue to zmax
    if sorted_layers:
        ax.fill_between([0, xmax], z_top, zmax, color=colors[-1], alpha=0.5)
        ax.text(xmax / 2, (z_top + zmax) / 2, _layer_annotation(sorted_layers[-1]),
                ha='center', va='center')

    ax.plot(params['source_x'], params['source_z'], 'r*', markersize=15)
    if len(geophone_coords) > 0:
        ax.plot(geophone_coords[:, 0], geophone_coords[:, 2], 'bo', markersize=5)
    if show_das:
        # Sort by x so the projected fiber is drawn as one monotone line
        order = np.argsort(das_coords[:, 0])
        ax.plot(das_coords[order, 0], das_coords[order, 2], 'g-', linewidth=2)
        ax.plot(das_coords[:, 0], das_coords[:, 2], 'go', markersize=3)

    ax.invert_yaxis()  # Depth increases downward
    ax.grid(True)


def _plot_vp_section(ax, sorted_layers, geophone_coords, das_coords, show_das, nx=100, nz=50):
    """Rasterise the layered Vp model on an ``nz`` x ``nx`` grid and plot it."""
    ax.set_title('Vp Model Cross-Section')
    ax.set_xlabel('X (m)')
    ax.set_ylabel('Z (m)')

    zmax = params['zmax']
    default_vp = sorted_layers[0]['vp'] if sorted_layers else 2000
    vp_model = np.full((nz, nx), default_vp, dtype=float)

    # Same banding as the side view: rows between consecutive interfaces take
    # that layer's Vp; rows below the deepest interface keep the deepest Vp.
    z_top = 0
    for layer in sorted_layers:
        row_top = int(z_top / zmax * nz)
        row_bottom = int(layer['depth'] / zmax * nz)
        vp_model[row_top:row_bottom, :] = layer['vp']
        z_top = layer['depth']
    if sorted_layers:
        vp_model[int(z_top / zmax * nz):, :] = sorted_layers[-1]['vp']

    # Colour range matches the Vp slider bounds (1500-6000 m/s)
    im = ax.imshow(vp_model, extent=[0, params['xmax'], zmax, 0],
                   aspect='auto', cmap='viridis', vmin=1500, vmax=6000)
    plt.colorbar(im, ax=ax, label='Vp (m/s)')

    ax.plot(params['source_x'], params['source_z'], 'r*', markersize=10)
    if len(geophone_coords) > 0:
        ax.plot(geophone_coords[:, 0], geophone_coords[:, 2], 'wo', markersize=3)
    if show_das:
        order = np.argsort(das_coords[:, 0])
        ax.plot(das_coords[order, 0], das_coords[order, 2], 'g-', linewidth=2)


def _plot_distances(ax, geophone_coords, das_coords, show_das):
    """Plot the straight-line source distance of every geophone and DAS channel."""
    ax.set_title('Source-Receiver Distances')
    ax.set_xlabel('Receiver Number')
    ax.set_ylabel('Distance from Source (m)')

    source = np.array([params['source_x'], params['source_y'], params['source_z']], dtype=float)
    if len(geophone_coords) > 0:
        geo_distances = np.sqrt(((geophone_coords - source) ** 2).sum(axis=1))
        ax.plot(range(1, len(geo_distances) + 1), geo_distances, 'bo-', label='Geophones')
    if show_das:
        das_distances = np.sqrt(((das_coords - source) ** 2).sum(axis=1))
        ax.plot(range(1, len(das_distances) + 1), das_distances, 'go-', label='DAS Channels')

    ax.grid(True)
    # Skip legend() when nothing labelled was drawn (matplotlib would warn)
    if ax.get_legend_handles_labels()[0]:
        ax.legend()


def visualize_model():
    """Sync every widget into ``params`` and draw a 2x2 overview of the setup.

    Panels: top view (X-Y), side view (X-Z) with layer bands, a gridded Vp
    cross-section, and source-receiver distances.

    Returns:
        tuple: ``(geophone_coords, das_coords)`` from ``generate_receiver_coords``.
    """
    update_model_dimensions()
    update_layer_properties()
    update_source_params()

    geophone_coords, das_coords = generate_receiver_coords()
    show_das = params['das_enabled'] and len(das_coords) > 0
    sorted_layers = sorted(params['layers'], key=lambda layer: layer['depth'])

    fig = plt.figure(figsize=(15, 10))
    _plot_top_view(fig.add_subplot(221), geophone_coords, das_coords, show_das)
    _plot_side_view(fig.add_subplot(222), sorted_layers, geophone_coords, das_coords, show_das)
    _plot_vp_section(fig.add_subplot(223), sorted_layers, geophone_coords, das_coords, show_das)
    _plot_distances(fig.add_subplot(224), geophone_coords, das_coords, show_das)

    plt.tight_layout()
    # This runs inside a button callback, not a cell execution, so the inline
    # backend's end-of-cell auto-display never fires; show the figure explicitly.
    plt.show()

    return geophone_coords, das_coords

# Button that syncs all widgets and draws the model overview
visualize_button = widgets.Button(description="Visualize Model", button_style='success')
visualize_button.on_click(lambda _clicked: visualize_model())
display(visualize_button)

## 7. Export Configuration Files

In [None]:
def export_parameter_files():
    """Sync every widget into ``params`` and write the simulation input files.

    Files written to ``output_dir``:
      * ``model_parameters.json``: the complete ``params`` dict
      * ``STATIONS``: one geophone per line (name, network, y, x, z, burial)
      * ``das_channels.csv``: DAS channel positions, only when DAS is enabled
      * ``SOURCE``, ``Par_file``, ``Mesh_Par_file``: written via ``param_manager``
      * ``interface<i>.dat`` per layer, plus an ``interfaces.dat`` index

    Returns:
        str: ``output_dir``.
    """
    update_model_dimensions()
    update_layer_properties()
    update_source_params()
    # generate_receiver_coords() itself calls update_geophone_params() and
    # update_das_params(), so they are not called separately here (that only
    # printed each summary twice).
    geophone_coords, das_coords = generate_receiver_coords()

    # Full parameter set (the input run_specfem.py is pointed at)
    param_file = os.path.join(output_dir, "model_parameters.json")
    with open(param_file, 'w', encoding='utf-8') as f:
        json.dump(params, f, indent=2)

    # STATIONS columns: name, network, latitude, longitude, elevation, burial.
    # With Cartesian/UTM coordinates, latitude is y and longitude is x. This
    # matches the SOURCE dict below (latitude=source_y, longitude=source_x),
    # so stations and source share one frame.
    n_geo = len(geophone_coords)
    stations_df = pd.DataFrame({
        'name': [f'ST{i:03d}' for i in range(1, n_geo + 1)],
        'network': ['GE'] * n_geo,
        'lat': geophone_coords[:, 1],  # Y coordinates
        'lon': geophone_coords[:, 0],  # X coordinates
        # NOTE(review): receiver depth (positive down) is written as elevation
        # with zero burial; confirm against the Par_file's receiver-z convention.
        'elevation': geophone_coords[:, 2],
        'burial': [0.0] * n_geo,
    })
    stations_file = os.path.join(output_dir, 'STATIONS')
    stations_df.to_csv(stations_file, sep=' ', index=False, header=False)

    # DAS channel table (only when a fiber is configured)
    if params['das_enabled'] and len(das_coords) > 0:
        das_df = pd.DataFrame({
            'channel': [f'DAS{i:03d}' for i in range(1, len(das_coords) + 1)],
            'x': das_coords[:, 0],
            'y': das_coords[:, 1],
            'z': das_coords[:, 2],
        })
        das_df.to_csv(os.path.join(output_dir, 'das_channels.csv'), index=False)

    # SOURCE: an explosion uses an isotropic moment tensor (equal diagonal,
    # zero off-diagonal terms).
    # NOTE(review): any other source_type (e.g. 'force_solution') gets an
    # all-zero tensor here; confirm param_manager handles force sources.
    iso = 1.0 if params['source_type'] == 'explosion' else 0.0
    source_dict = {
        'latitude': params['source_y'],  # In SPECFEM, lat/lon are actually y/x in UTM
        'longitude': params['source_x'],
        'depth': params['source_z'],
        'Mrr': iso,
        'Mtt': iso,
        'Mpp': iso,
        'Mrt': 0.0,
        'Mrp': 0.0,
        'Mtp': 0.0,
        'time_shift': 0.0,
        'half_duration': 1.0 / params['source_freq'],
    }
    param_manager.write_source_file(os.path.join(output_dir, 'SOURCE'), source_dict)

    # One constant-depth interface file per layer, sampled at 10 points along
    # x on the y mid-line, plus an index file listing them.
    layers = params['layers']
    if layers:
        for i, layer in enumerate(layers, start=1):
            interface_data = np.column_stack([
                np.linspace(0, params['xmax'], 10),  # X coordinates
                np.full(10, params['ymax'] / 2),     # Y coordinates (middle of model)
                np.full(10, layer['depth']),         # Z coordinates (constant depth)
            ])
            np.savetxt(os.path.join(output_dir, f'interface{i}.dat'), interface_data, fmt='%.2f')

        with open(os.path.join(output_dir, 'interfaces.dat'), 'w', encoding='utf-8') as f:
            f.write(f"{len(layers)}\n")
            f.writelines(f"interface{i}.dat\n" for i in range(1, len(layers) + 1))

    # Basic Par_file and Mesh_Par_file
    par_file_dict = {
        'SIMULATION_TYPE': 1,  # Forward simulation
        'NPROC': 4,            # Equals NPROC_XI * NPROC_ETA in the mesh dict below
        # round() rather than int() truncation. Float division can land just
        # below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996), which would
        # silently drop a time step.
        'NSTEP': int(round(params['simulation_time'] / params['time_step'])),
        'DT': params['time_step'],
        'MODEL': 'default',
        'SAVE_FORWARD': False,
    }
    param_manager.write_par_file(os.path.join(output_dir, 'Par_file'), par_file_dict)

    mesh_par_dict = {
        'NPROC_XI': 2,
        'NPROC_ETA': 2,
        'NGNOD': 8,
        'NEX_XI': 40,
        'NEX_ETA': 40,
        'NER': 15,
    }
    param_manager.write_mesh_par_file(os.path.join(output_dir, 'Mesh_Par_file'), mesh_par_dict)

    print(f"\nConfiguration files exported to: {output_dir}")
    for name in sorted(os.listdir(output_dir)):
        print(f"  - {name}")

    return output_dir

# Button that writes every configuration file into output_dir
export_button = widgets.Button(description="Export Configuration Files", button_style='primary')
export_button.on_click(lambda _clicked: export_parameter_files())
display(export_button)

## 8. Running Simulations

Next, run a full SPECFEM3D simulation with these configuration files: execute the `run_specfem.py` script from the project root, passing the exported `model_parameters.json` via `--param_file` (the button below prints the exact command).

In [None]:
def generate_simulation_command():
    """Print the shell command that runs ``run_specfem.py`` on the exported parameters.

    Needs ``model_parameters.json`` in ``output_dir`` (written by the export
    step); if it is missing, asks the user to export first.
    """
    import shlex

    param_file = os.path.join(output_dir, "model_parameters.json")
    if not os.path.exists(param_file):
        print("Please export configuration files first.")
        return

    # Quote both paths so the command still works when they contain spaces or
    # shell metacharacters (plain paths are left unchanged by shlex.quote).
    cmd = (f"cd {shlex.quote(project_root)} && "
           f"python run_specfem.py --param_file={shlex.quote(param_file)}")
    print("To run the simulation, execute the following command in a terminal:")
    print(f"\n{cmd}\n")
    print("Note: This simulation may take significant time depending on your hardware.")

# Button that prints the command for launching the simulation
sim_cmd_button = widgets.Button(description="Get Simulation Command")
sim_cmd_button.on_click(lambda _clicked: generate_simulation_command())
display(sim_cmd_button)