# Evaluation and Visualization of Seismic Interpolation Model

This notebook demonstrates how to evaluate and visualize the results of the trained seismic interpolation model. We'll go through the following steps:

1. Load the trained model
2. Load test data and perform inference
3. Compute comprehensive evaluation metrics
4. Create visualizations for qualitative assessment
5. Analyze model performance under different masking scenarios
6. Generate publication-quality figures

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import json
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import mlflow
import mlflow.pytorch
from scipy import signal
import seaborn as sns

# Add the project root to path for imports
sys.path.append('..')

# Import project modules
from src.models.transformer import StorSeismicBERTModel
from src.preprocessing.dataset import TransformerSeismicDataset
from src.evaluation.metrics import (
    mean_squared_error, mean_absolute_error, signal_to_noise_ratio,
    correlation_coefficient, coherence_measure, frequency_domain_error,
    amplitude_ratio, evaluate_model
)
from src.utils.logging_utils import setup_logging
from src.utils.plot_utils import (
    plot_trace_comparison, plot_gather_comparison, plot_seismic_trace,
    plot_frequency_comparison, plot_scatter_true_vs_pred, plot_metrics_comparison
)

# Set up logging
# NOTE(review): `logger` is not referenced by any later cell in this notebook —
# progress is reported with print(). Either route those messages through
# `logger` or drop this line.
logger = setup_logging(level='INFO')

## Load Trained Model

Load the transformer model that was trained in the previous notebook.

In [None]:
# Set device: use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Paths (relative to the notebook's working directory)
data_dir = "../data/synthetic/processed/datasets"
models_dir = "../experiments/models"
results_dir = "../experiments/results"
figures_dir = "../papers/figures"

# Ensure output directories exist
Path(results_dir).mkdir(parents=True, exist_ok=True)
Path(figures_dir).mkdir(parents=True, exist_ok=True)

# Load model
experiment_name = "seismic_interpolation_transformer"
model_path = Path(models_dir) / f"{experiment_name}_final_model.pt"

# Fail loudly if the checkpoint is missing: every later cell needs `model`,
# so merely printing and continuing would only surface later as a confusing
# NameError far away from the real cause.
if not model_path.exists():
    raise FileNotFoundError(
        f"Model not found at {model_path}. Please run the training notebook first."
    )

# Load model checkpoint (tensors are mapped onto the selected device)
checkpoint = torch.load(model_path, map_location=device)
model_config = checkpoint['model_config']
dataset_params = checkpoint['dataset_params']

# Print model config
print("Model Configuration:")
for key, value in model_config.items():
    print(f"  {key}: {value}")

# Re-create the architecture with the exact hyper-parameters used in training
model = StorSeismicBERTModel(
    max_channels=model_config['max_channels'],
    time_steps=model_config['time_steps'],
    d_model=model_config['d_model'],
    nhead=model_config['nhead'],
    num_layers=model_config['num_layers'],
    dim_feedforward=model_config['dim_feedforward'],
    dropout=model_config['dropout']
)

# Load trained weights and switch to inference mode
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()

print(f"Loaded model from {model_path}")

## Load Test Data

Load the test data for evaluation.

In [None]:
# Load the pre-windowed arrays and the indices of the held-out test windows
data_root = Path(data_dir)
geophone_windows = np.load(data_root / "geophone_windows.npy")
das_windows = np.load(data_root / "das_windows.npy")
test_indices = np.load(data_root / "test_indices.npy")

# Keep only the test windows (integer-array indexing along the first axis)
test_geo = geophone_windows[test_indices]
test_das = das_windows[test_indices]

print(f"Loaded {len(test_indices)} test samples")
print(f"Test geophone data shape: {test_geo.shape}")
print(f"Test DAS data shape: {test_das.shape}")

In [None]:
# Masking scenarios: every pattern is crossed with every ratio
mask_patterns = ['random', 'regular', 'block']
mask_ratios = [0.1, 0.3, 0.5, 0.7]  # From 10% to 70% channels masked

# test_datasets[pattern][ratio] -> dataset over the test windows with that masking
test_datasets = {
    pattern: {
        ratio: TransformerSeismicDataset(
            test_geo, test_das,
            mask_ratio=ratio,
            mask_pattern=pattern,
            positional_encoding=True,
        )
        for ratio in mask_ratios
    }
    for pattern in mask_patterns
}

# Standard evaluation setting used by the next cells: random masking, 30%
batch_size = 32
standard_test_dataset = test_datasets['random'][0.3]
test_loader = DataLoader(standard_test_dataset, batch_size=batch_size, shuffle=False)

## Perform Inference on Test Data

Run inference on a single batch from the standard test loader (random masking, 30% of channels) to inspect the model's predictions.

In [None]:
# Run inference on one batch
def run_inference(model, batch, device):
    """Run ``model`` on one batch and return the pieces as numpy arrays.

    ``batch`` is unpacked as ``(input_data, attention_mask, positions, target)``;
    ``positions`` may be ``None``. The geophone channels are taken to be the
    trailing ``target.shape[1]`` channels of the model output, so the DAS
    channel count is inferred as ``outputs.shape[1] - target.shape[1]``.

    Returns:
        dict with keys 'input_data', 'attention_mask', 'target',
        'predictions' (geophone slice of the output), 'geo_mask' (True where
        the geophone channel's attention_mask is 0, i.e. it was masked) and
        'n_das_channels' (int).
    """
    model.eval()

    def as_numpy(tensor):
        return tensor.cpu().numpy()

    with torch.no_grad():
        input_data, attention_mask, positions, target = batch
        input_data = input_data.to(device)
        attention_mask = attention_mask.to(device)
        if positions is not None:
            positions = positions.to(device)
        target = target.to(device)

        outputs = model(input_data, attention_mask=attention_mask, position_ids=positions)

        # Everything before the geophone block belongs to the DAS channels
        das_count = outputs.shape[1] - target.shape[1]
        geophone_output = outputs[:, das_count:, :]
        is_masked = ~attention_mask[:, das_count:].bool()

        return {
            'input_data': as_numpy(input_data),
            'attention_mask': as_numpy(attention_mask),
            'target': as_numpy(target),
            'predictions': as_numpy(geophone_output),
            'geo_mask': as_numpy(is_masked),
            'n_das_channels': das_count
        }

# Pull the first batch from the standard (random, 30%) test loader
test_batch = next(iter(test_loader))

# Predict on it; `results` is reused by the metric and plotting cells below
results = run_inference(model, test_batch, device)
print(f"Ran inference on batch of {results['predictions'].shape[0]} samples")

## Compute Evaluation Metrics

Compute comprehensive metrics to evaluate the model's performance.

In [None]:
# Metrics for the first sample of the batch, restricted to masked geophone channels
sample_idx = 0

# Per-sample slices of the batch results
geo_mask, target, predictions = (
    results[key][sample_idx] for key in ('geo_mask', 'target', 'predictions')
)

# Keep only the channels that were hidden from the model
masked_indices = np.where(geo_mask)
masked_targets, masked_predictions = target[masked_indices], predictions[masked_indices]

# One row per masked channel, time along the last axis
masked_targets_flat = masked_targets.reshape(-1, masked_targets.shape[-1])
masked_predictions_flat = masked_predictions.reshape(-1, masked_predictions.shape[-1])

mse = mean_squared_error(masked_targets_flat, masked_predictions_flat)
mae = mean_absolute_error(masked_targets_flat, masked_predictions_flat)
snr = signal_to_noise_ratio(masked_targets_flat, masked_predictions_flat)
# Correlation is computed over all masked samples pooled into one vector
corr = correlation_coefficient(masked_targets_flat.flatten(), masked_predictions_flat.flatten())
amp_ratio = amplitude_ratio(masked_targets_flat, masked_predictions_flat)

print(f"Metrics for Sample {sample_idx}:")
print(f"  Mean Squared Error: {mse:.6f}")
print(f"  Mean Absolute Error: {mae:.6f}")
print(f"  Signal-to-Noise Ratio: {snr:.2f} dB")
print(f"  Correlation Coefficient: {corr:.4f}")
print(f"  Amplitude Ratio: {amp_ratio:.4f}")

In [None]:
# Aggregate metrics over every batch of the standard (random, 30%) test loader
print("Computing metrics on full test set...")
full_metrics = evaluate_model(model, test_loader, device)

# One line per metric, in the order evaluate_model returned them
# NOTE(review): assumes every value is a scalar accepted by the ".6f" format
print("Full Test Set Metrics:")
for metric, value in full_metrics.items():
    print(f"  {metric}: {value:.6f}")

## Visualize Model Predictions

Create visualizations to qualitatively assess the model's predictions.

In [None]:
# Visualize predictions for the first sample
sample_idx = 0

# Get sample data
input_data = results['input_data'][sample_idx]
target = results['target'][sample_idx]
predictions = results['predictions'][sample_idx]
geo_mask = results['geo_mask'][sample_idx]
n_das_channels = results['n_das_channels']

# The model input stacks the DAS rows first, then the geophone rows
das_data = input_data[:n_das_channels]
# Geophone rows exactly as fed to the model (masked channels appear however
# the dataset encoded them — NOTE(review): presumably zeros, confirm)
masked_geophone = input_data[n_das_channels:].copy()

# 'seismic' is a diverging colormap, so use limits symmetric about zero
# (zero amplitude -> white). The three geophone panels share one limit so the
# masked input, prediction and ground truth are visually comparable; the
# `or 1.0` avoids a degenerate vmin == vmax == 0 on all-zero data.
das_vlim = float(np.abs(das_data).max(initial=0.0)) or 1.0
geo_vlim = float(max(np.abs(panel).max(initial=0.0)
                     for panel in (masked_geophone, predictions, target))) or 1.0

panels = [
    (das_data, "DAS Data (Input)", das_vlim),
    (masked_geophone, "Masked Geophone Data (Input)", geo_vlim),
    (predictions, "Predicted Geophone Data (Output)", geo_vlim),
    (target, "True Geophone Data (Ground Truth)", geo_vlim),
]

plt.figure(figsize=(18, 12))
for panel_idx, (panel_data, panel_title, vlim) in enumerate(panels, start=1):
    plt.subplot(len(panels), 1, panel_idx)
    plt.imshow(panel_data, aspect='auto', cmap='seismic', vmin=-vlim, vmax=vlim)
    plt.title(panel_title)
    if panel_idx == len(panels):
        plt.xlabel("Time Sample")
    plt.ylabel("Channel")
    plt.colorbar()

plt.tight_layout()
plt.savefig(Path(figures_dir) / "full_prediction_visualization.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Compare the true and predicted trace of the first masked geophone channel
masked_channel_idx = np.nonzero(geo_mask)[0][0]
true_trace = target[masked_channel_idx]
predicted_trace = predictions[masked_channel_idx]

plt.figure(figsize=(12, 6))
plt.plot(true_trace, 'b-', label='True')
plt.plot(predicted_trace, 'r-', label='Predicted')
plt.title(f"True vs Predicted - Channel {masked_channel_idx} (Masked)")
plt.xlabel("Time Sample")
plt.ylabel("Amplitude")
plt.legend()
plt.grid(True)

# Per-trace metrics shown in a box anchored to the top-left corner of the axes
mse_val = mean_squared_error(true_trace, predicted_trace)
snr_val = signal_to_noise_ratio(true_trace, predicted_trace)
corr_val = correlation_coefficient(true_trace, predicted_trace)

metrics_text = f'MSE: {mse_val:.6f}\nSNR: {snr_val:.2f} dB\nCorr: {corr_val:.4f}'
text_box_style = dict(boxstyle='round', facecolor='white', alpha=0.8)
plt.text(0.02, 0.98, metrics_text, transform=plt.gca().transAxes,
         verticalalignment='top', bbox=text_box_style)

plt.tight_layout()
plt.savefig(Path(figures_dir) / "trace_comparison.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Spectral comparison of the same masked channel.
# NOTE(review): 100 Hz is an assumed sampling rate — confirm against the data.
sampling_rate_hz = 100
f, cxy = plot_frequency_comparison(
    target[masked_channel_idx],
    predictions[masked_channel_idx],
    fs=sampling_rate_hz,
    title=f"Frequency Domain Comparison - Channel {masked_channel_idx}",
    save_path=Path(figures_dir) / "frequency_comparison.png",
)

## Analyze Performance with Different Masking Patterns and Ratios

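Evaluate how the model performs with different masking patterns and ratios. For speed, each combination is evaluated on only the first 20 test samples, so these numbers are indicative rather than final.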

In [None]:
# Function to evaluate on a dataset
def evaluate_dataset(model, dataset, device, num_samples=50, batch_size=32):
    """Average per-sample metrics over the masked geophone channels of a dataset.

    The first ``num_samples`` samples (in dataset order) are processed. For
    each sample, MSE, SNR and correlation are computed between the true and
    predicted traces of its masked geophone channels. Samples with no masked
    channel still count toward ``num_samples`` but contribute no metrics.

    Args:
        model: Model passed to ``run_inference``.
        dataset: Dataset yielding ``(input_data, attention_mask, positions, target)``.
        device: Torch device used for inference.
        num_samples: Maximum number of samples to process.
        batch_size: Batch size of the temporary DataLoader. Defaults to 32,
            the value this notebook uses for ``test_loader``.

    Returns:
        dict mapping 'mse', 'snr' and 'corr' to the mean over scored samples,
        or NaN for a metric when no processed sample had masked channels.
    """
    model.eval()
    metrics = {
        'mse': [],
        'snr': [],
        'corr': []
    }

    # batch_size is an explicit parameter rather than a hidden dependency on
    # the notebook-level `batch_size` global.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    sample_count = 0
    with torch.no_grad():
        for batch in loader:
            batch_results = run_inference(model, batch, device)

            for geo_mask, target, predictions in zip(
                batch_results['geo_mask'],
                batch_results['target'],
                batch_results['predictions'],
            ):
                # Only score samples where at least one channel was masked
                if np.any(geo_mask):
                    # Boolean indexing keeps just the masked channel rows
                    masked_targets = target[geo_mask]
                    masked_predictions = predictions[geo_mask]

                    metrics['mse'].append(mean_squared_error(masked_targets, masked_predictions))
                    metrics['snr'].append(signal_to_noise_ratio(masked_targets, masked_predictions))
                    metrics['corr'].append(correlation_coefficient(
                        masked_targets.flatten(), masked_predictions.flatten()))

                # Every processed sample counts toward the budget
                sample_count += 1
                if sample_count >= num_samples:
                    break

            if sample_count >= num_samples:
                break

    # Explicit NaN for empty lists instead of np.mean's
    # "Mean of empty slice" RuntimeWarning (the value was NaN either way).
    return {
        key: np.mean(values) if values else float('nan')
        for key, values in metrics.items()
    }

In [None]:
# Evaluate every (pattern, ratio) combination; pattern_metrics[pattern][ratio] -> metrics dict
pattern_metrics = {}

for pattern in mask_patterns:
    per_ratio = {}
    for ratio in mask_ratios:
        print(f"Evaluating {pattern} pattern with {ratio*100:.0f}% masking...")
        # Using a small number of samples for speed
        per_ratio[ratio] = evaluate_dataset(
            model, test_datasets[pattern][ratio], device, num_samples=20
        )
    pattern_metrics[pattern] = per_ratio

In [None]:
# One line chart per metric: x = mask ratio (%), one line per masking pattern
metrics_to_plot = ['mse', 'snr', 'corr']
titles = {
    'mse': 'Mean Squared Error',
    'snr': 'Signal-to-Noise Ratio (dB)',
    'corr': 'Correlation Coefficient'
}

for metric_name in metrics_to_plot:
    fig, ax = plt.subplots(figsize=(10, 6))

    for pattern in mask_patterns:
        per_ratio = pattern_metrics[pattern]
        mask_percent = [r*100 for r in per_ratio]
        metric_values = [per_ratio[r][metric_name] for r in per_ratio]
        ax.plot(mask_percent, metric_values, 'o-', label=pattern.capitalize())

    ax.set_title(f"{titles[metric_name]} vs Mask Ratio")
    ax.set_xlabel("Mask Ratio (%)")
    ax.set_ylabel(titles[metric_name])
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    fig.savefig(Path(figures_dir) / f"{metric_name}_vs_mask_ratio.png", dpi=300, bbox_inches='tight')
    plt.show()

## Visualize Masking Patterns

Create visualizations to show how different masking patterns affect the predictions.

In [None]:
# Create a function to visualize masking patterns
def visualize_masking_pattern(pattern, ratio):
    """Return a (1, n_geophone_channels) float array with 1 at masked channels.

    A one-sample TransformerSeismicDataset is built from the first test window
    with the given ``pattern`` and ``ratio``; the geophone part of its
    attention mask is inverted (attention_mask == 0 means masked).
    """
    single_window_dataset = TransformerSeismicDataset(
        test_geo[:1], test_das[:1],  # Just use the first sample
        mask_ratio=ratio,
        mask_pattern=pattern,
        positional_encoding=True
    )

    sample_input, sample_attention, _positions, sample_target = single_window_dataset[0]

    # DAS rows precede the geophone rows in the input
    das_count = sample_input.shape[0] - sample_target.shape[0]
    is_masked = ~sample_attention[das_count:].bool().numpy()

    # Row vector for imshow: 1.0 where masked, 0.0 elsewhere
    return is_masked.astype(float)[np.newaxis, :]

# Visualize different mask patterns with 30% ratio
ratio = 0.3
mask_visualizations = {
    pattern: visualize_masking_pattern(pattern, ratio) for pattern in mask_patterns
}

# One row per pattern; only the bottom row gets the x-axis label
n_rows = len(mask_patterns)
plt.figure(figsize=(12, 6))
for row, pattern in enumerate(mask_patterns, start=1):
    plt.subplot(n_rows, 1, row)
    plt.imshow(mask_visualizations[pattern], aspect='auto', cmap='binary')
    plt.title(f"{pattern.capitalize()} Masking Pattern ({ratio*100:.0f}% masked)")
    plt.ylabel("Mask")
    if row == n_rows:
        plt.xlabel("Channel Index")

plt.tight_layout()
plt.savefig(Path(figures_dir) / "masking_patterns.png", dpi=300, bbox_inches='tight')
plt.show()

## Create Publication-Quality Figures

Create high-quality figures for publication, showing the model's performance on representative examples.

In [None]:
# Create a full figure showing the entire pipeline
def create_pipeline_figure(sample_idx=0, batch_results=None, fs=100):
    """Create a figure showing the whole interpolation pipeline for one sample.

    Rows, top to bottom: DAS input, masked geophone input (masked channels
    marked with red dashed lines), predicted geophones, true geophones, and a
    bottom row with the time-domain and Welch-PSD comparison of the first
    masked channel plus that channel's MSE / SNR / correlation.

    Args:
        sample_idx: Index of the sample within the batch.
        batch_results: Dict returned by ``run_inference``; defaults to the
            notebook-level ``results``.
        fs: Sampling rate (Hz) used for the PSD. NOTE(review): 100 Hz is an
            assumption carried over from the original cell — confirm.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: If the sample has no masked geophone channel to compare.
    """
    if batch_results is None:
        batch_results = results

    input_data = batch_results['input_data'][sample_idx]
    target = batch_results['target'][sample_idx]
    predictions = batch_results['predictions'][sample_idx]
    geo_mask = batch_results['geo_mask'][sample_idx]
    n_das_channels = batch_results['n_das_channels']

    masked_channels = np.nonzero(geo_mask)[0]
    if masked_channels.size == 0:
        raise ValueError(
            f"Sample {sample_idx} has no masked geophone channels to compare."
        )
    masked_channel_idx = masked_channels[0]

    # The model input stacks the DAS rows first, then the geophone rows
    das_data = input_data[:n_das_channels]
    masked_geophone = input_data[n_das_channels:]

    # 'seismic' is diverging: symmetric limits put zero amplitude at white,
    # and one shared limit makes the three geophone panels comparable.
    das_vlim = float(np.abs(das_data).max(initial=0.0)) or 1.0
    geo_vlim = float(max(np.abs(a).max(initial=0.0)
                         for a in (masked_geophone, predictions, target))) or 1.0

    fig = plt.figure(figsize=(18, 14))
    gs = fig.add_gridspec(5, 2, height_ratios=[1, 1, 1, 1, 0.5])

    image_rows = [
        (das_data, "DAS Data (Input)", "DAS Channel", das_vlim),
        (masked_geophone, "Masked Geophone Data (Input)", "Geophone Channel", geo_vlim),
        (predictions, "Predicted Geophone Data (Model Output)", "Geophone Channel", geo_vlim),
        (target, "True Geophone Data (Ground Truth)", "Geophone Channel", geo_vlim),
    ]
    image_axes = []
    for row, (data, title, ylabel, vlim) in enumerate(image_rows):
        ax = fig.add_subplot(gs[row, :])
        im = ax.imshow(data, aspect='auto', cmap='seismic', vmin=-vlim, vmax=vlim)
        ax.set_title(title, fontsize=14)
        ax.set_ylabel(ylabel, fontsize=12)
        fig.colorbar(im, ax=ax)
        image_axes.append(ax)
    image_axes[-1].set_xlabel("Time Sample", fontsize=12)

    # Mark masked channels on the masked-input panel
    for channel in masked_channels:
        image_axes[1].axhline(channel, color='red', linestyle='--', alpha=0.5)

    true_trace = target[masked_channel_idx]
    pred_trace = predictions[masked_channel_idx]

    # Time-domain comparison of the first masked channel
    ax_time = fig.add_subplot(gs[4, 0])
    ax_time.plot(true_trace, 'b-', label='True')
    ax_time.plot(pred_trace, 'r-', label='Predicted')
    ax_time.set_title(f"Channel {masked_channel_idx} Comparison", fontsize=14)
    ax_time.set_xlabel("Time Sample", fontsize=12)
    ax_time.set_ylabel("Amplitude", fontsize=12)
    ax_time.legend(fontsize=10)
    ax_time.grid(True)

    mse_val = mean_squared_error(true_trace, pred_trace)
    snr_val = signal_to_noise_ratio(true_trace, pred_trace)
    corr_val = correlation_coefficient(true_trace, pred_trace)

    # Frequency content (Welch PSD). Cap nperseg at the trace length; scipy
    # would otherwise warn and fall back to the trace length anyway.
    ax_freq = fig.add_subplot(gs[4, 1])
    nperseg = min(256, true_trace.shape[-1])
    freqs, psd_true = signal.welch(true_trace, fs, nperseg=nperseg)
    _, psd_pred = signal.welch(pred_trace, fs, nperseg=nperseg)
    ax_freq.semilogy(freqs, psd_true, 'b-', label='True')
    ax_freq.semilogy(freqs, psd_pred, 'r-', label='Predicted')
    ax_freq.set_title(f"Frequency Content - Channel {masked_channel_idx}", fontsize=14)
    ax_freq.set_xlabel("Frequency (Hz)", fontsize=12)
    ax_freq.set_ylabel("PSD", fontsize=12)
    ax_freq.legend(fontsize=10)
    ax_freq.grid(True)

    # The metrics are for the single compared channel, so say so in the label
    fig.text(0.5, 0.01,
             f"Channel {masked_channel_idx} - MSE: {mse_val:.6f} | SNR: {snr_val:.2f} dB"
             f" | Correlation: {corr_val:.4f}",
             ha="center", fontsize=12,
             bbox={"facecolor": "white", "alpha": 0.5, "pad": 5})

    fig.suptitle("Seismic Interpolation with Transformer Model", fontsize=16)
    fig.tight_layout(rect=[0, 0.03, 1, 0.97])

    return fig

# Create and save the figure
pipeline_fig = create_pipeline_figure(sample_idx=0)
pipeline_fig.savefig(Path(figures_dir) / "complete_pipeline.png", dpi=300, bbox_inches='tight')
pipeline_fig.savefig(Path(figures_dir) / "complete_pipeline.pdf", format='pdf', bbox_inches='tight')

In [None]:
# Create a comparison figure for different masking patterns
def create_masking_comparison_figure(ratio=0.3):
    """Compare masked input, prediction and ground truth for each masking pattern.

    For every pattern in ``mask_patterns`` a one-sample dataset is built from
    the first test window with the given mask ``ratio`` and the model is run
    on it. Each figure row shows masked input / prediction / ground truth,
    with MSE, SNR and correlation over the masked channels overlaid on the
    prediction panel.

    Args:
        ratio: Fraction of geophone channels to mask (default 0.3, as before).

    Returns:
        The matplotlib Figure.
    """
    n_rows = len(mask_patterns)
    # squeeze=False keeps `axes` 2-D even if only one pattern is configured
    fig, axes = plt.subplots(n_rows, 3, figsize=(16, 12), squeeze=False)

    for i, pattern in enumerate(mask_patterns):
        dataset = TransformerSeismicDataset(
            test_geo[:1], test_das[:1],  # Just use the first sample
            mask_ratio=ratio,
            mask_pattern=pattern,
            positional_encoding=True
        )

        # Add a batch dimension to every field; `positions` may be None
        batch = [None if t is None else t.unsqueeze(0) for t in dataset[0]]
        pattern_results = run_inference(model, batch, device)

        n_das = pattern_results['n_das_channels']
        masked_geophone = pattern_results['input_data'][0, n_das:]
        predictions = pattern_results['predictions'][0]
        target = pattern_results['target'][0]
        geo_mask = pattern_results['geo_mask'][0]

        # Symmetric, row-shared limits: zero is white in the diverging
        # 'seismic' colormap and the three panels are directly comparable.
        vlim = float(max(np.abs(a).max(initial=0.0)
                         for a in (masked_geophone, predictions, target))) or 1.0

        panels = (
            (masked_geophone, "Masked Input"),
            (predictions, "Predictions"),
            (target, "Ground Truth"),
        )
        for col, (data, label) in enumerate(panels):
            ax = axes[i, col]
            im = ax.imshow(data, aspect='auto', cmap='seismic', vmin=-vlim, vmax=vlim)
            ax.set_title(f"{pattern.capitalize()} - {label}")
            if i == n_rows - 1:
                ax.set_xlabel("Time Sample")
            fig.colorbar(im, ax=ax)
        axes[i, 0].set_ylabel("Channel")

        # Metrics over the masked channels only; avoid NaN metrics when the
        # random draw happened to mask nothing.
        if np.any(geo_mask):
            masked_targets = target[geo_mask]
            masked_predictions = predictions[geo_mask]
            mse = mean_squared_error(masked_targets, masked_predictions)
            snr = signal_to_noise_ratio(masked_targets, masked_predictions)
            corr = correlation_coefficient(masked_targets.flatten(), masked_predictions.flatten())
            metrics_text = f"MSE: {mse:.6f}\nSNR: {snr:.2f} dB\nCorr: {corr:.4f}"
        else:
            metrics_text = "No masked channels"

        axes[i, 1].text(0.5, 0.05, metrics_text, transform=axes[i, 1].transAxes,
                        ha='center', va='bottom',
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    fig.suptitle(f"Comparison of Masking Patterns ({ratio*100:.0f}% masked)", fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    return fig

# Create and save the figure
masking_fig = create_masking_comparison_figure()
masking_fig.savefig(Path(figures_dir) / "masking_patterns_comparison.png", dpi=300, bbox_inches='tight')
masking_fig.savefig(Path(figures_dir) / "masking_patterns_comparison.pdf", format='pdf', bbox_inches='tight')

In [None]:
# Create a summary figure for the paper
def create_metrics_summary_figure():
    """Create a summary figure of metrics for the paper."""
    # Create a dataframe with pattern, ratio, and metrics
    data = []
    for pattern in mask_patterns:
        for ratio in mask_ratios:
            metrics = pattern_metrics[pattern][ratio]
            data.append({
                'Pattern': pattern.capitalize(),
                'Mask Ratio': ratio * 100,  # Convert to percentage
                'MSE': metrics['mse'],
                'SNR (dB)': metrics['snr'],
                'Correlation': metrics['corr']
            })
    
    df = pd.DataFrame(data)
    
    # Create a figure with three subplots (one for each metric)
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    # Plot MSE
    sns.barplot(x='Pattern', y='MSE', hue='Mask Ratio', data=df, ax=axes[0])
    axes[0].set_title('Mean Squared Error')
    axes[0].set_xlabel('Masking Pattern')
    axes[0].set_ylabel('MSE')
    
    # Plot SNR
    sns.barplot(x='Pattern', y='SNR (dB)', hue='Mask Ratio', data=df, ax=axes[1])
    axes[1].set_title('Signal-to-Noise Ratio')
    axes[1].set_xlabel('Masking Pattern')
    axes[1].set_ylabel('SNR (dB)')
    
    # Plot Correlation
    sns.barplot(x='Pattern', y='Correlation', hue='Mask Ratio', data=df, ax=axes[2])
    axes[2].set_title('Correlation Coefficient')
    axes[2].set_xlabel('Masking Pattern')
    axes[2].set_ylabel('Correlation')
    
    # Adjust legend
    for ax in axes:
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, [f"{r}%" for r in sorted([int(label) for label in labels])], title='Mask Ratio')
    
    plt.suptitle('Performance Metrics for Different Masking Patterns and Ratios', fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    
    return fig

# Create and save the figure
metrics_fig = create_metrics_summary_figure()
metrics_fig.savefig(Path(figures_dir) / "metrics_summary.png", dpi=300, bbox_inches='tight')
metrics_fig.savefig(Path(figures_dir) / "metrics_summary.pdf", format='pdf', bbox_inches='tight')

## Save Evaluation Results

Save the evaluation results to a CSV file for future reference.

In [None]:
# Flatten pattern_metrics into one row per (pattern, ratio) and export to CSV
evaluation_df = pd.DataFrame([
    {
        'Pattern': pattern,
        'Mask Ratio': ratio,
        'MSE': pattern_metrics[pattern][ratio]['mse'],
        'SNR': pattern_metrics[pattern][ratio]['snr'],
        'Correlation': pattern_metrics[pattern][ratio]['corr'],
    }
    for pattern in mask_patterns
    for ratio in mask_ratios
])

# Save to CSV
output_dir = Path(results_dir)
evaluation_df.to_csv(output_dir / "evaluation_results.csv", index=False)
print(f"Saved evaluation results to {Path(results_dir) / 'evaluation_results.csv'}")

## Conclusion

In this notebook, we have:

1. Loaded the trained transformer model for seismic interpolation
2. Run inference on test data and evaluated the model's performance
3. Analyzed the impact of different masking patterns and ratios
4. Created publication-quality figures to visualize the results
5. Demonstrated the effectiveness of the transformer model for interpolating missing geophone data using DAS constraints

Review the metrics and figures above before drawing conclusions: performance is expected to decline as the mask ratio increases, and the pattern comparison shows which masking scenarios are hardest for the model. If the results confirm good performance across patterns and ratios, they highlight the potential of this approach for practical applications in seismic data processing.