# Revised Seismic Interpretation Analysis

## Introduction: Automating Seismic Facies Classification

Seismic interpretation, a cornerstone of subsurface exploration and characterization in geophysics, involves analyzing seismic reflection data to understand geological structures, stratigraphy, and fluid content (e.g., Sheng et al., 2025). A key task within this process is **seismic facies classification**, which aims to categorize distinct zones within the seismic volume based on their reflection characteristics (amplitude, frequency, continuity) (e.g., Gao et al., 2025; Chikhaoui & Alfarraj, 2024). These categories, or **facies**, often correspond to different depositional environments or rock types, providing crucial insights for resource exploration and reservoir modeling (e.g., Gao et al., 2025).

Traditionally, seismic facies classification relies heavily on manual interpretation by experienced geophysicists, a process that is time-consuming, subjective, and challenging to apply consistently across large 3D seismic volumes (e.g., Sheng et al., 2025; Babikir et al., 2024). Automating this task using machine learning offers the potential for faster, more objective, and reproducible interpretations (e.g., Babikir et al., 2024; Mustafa & AlRegib, 2023).

This study investigates the application of deep learning models for automating seismic facies classification using the publicly available F3 block dataset from the Dutch North Sea. Specifically, we aim to answer the following research questions:

1.  How effectively can a baseline 3D Convolutional Neural Network (CNN), trained on seismic amplitude patches, classify seismic facies defined by interpreted horizons?
2.  Can sequence-based models, specifically a Bidirectional Long Short-Term Memory (BiLSTM) network analyzing individual seismic traces, provide a competitive or complementary approach to patch-based CNNs for this classification task?

We will compare the performance of these two distinct deep learning architectures – one focusing on 3D spatial patterns (CNN) and the other on 1D sequential patterns along seismic traces (BiLSTM) – using standard classification metrics.

### Related Work

Machine learning has been increasingly applied to seismic interpretation tasks (e.g., Sheng et al., 2025; Liu et al., 2025). Early work often utilized traditional methods like support vector machines or decision trees based on handcrafted seismic attributes (a concept discussed in the context of modern attribute selection by Babikir et al., 2024). More recently, deep learning, particularly CNNs, has shown significant promise for tasks like fault detection (e.g., Liu et al., 2025), salt body delineation [Citation for salt body delineation not directly available in the provided list; a general DL review might be needed], and facies classification (e.g., Zhao et al., 2023; Ore & Gao, 2025). Most CNN approaches for facies classification operate on 2D slices or 3D patches extracted from the seismic volume (e.g., Stitt et al., 2022; Durall et al., 2021).

Recurrent Neural Networks (RNNs), including LSTMs, have been explored for analyzing seismic data as sequences, often focusing on well log correlation or time-series analysis (e.g., Shi et al., 2020, for waveform embedding which is sequential). Their application specifically for facies classification based on trace sequences within a 3D volume is less common compared to CNNs.

This work differentiates itself by directly comparing a standard 3D CNN patch-based approach with a BiLSTM trace-based approach on the same 3D seismic dataset (F3 block) for facies classification defined by multiple horizons, providing insights into the relative strengths of spatial versus sequential modeling for this specific problem.

*(Note: The Reinforcement Learning investigation mentioned in previous versions has been removed based on feedback to focus the scope of this study.)*

## Methodology

### Dataset: F3 Block, Dutch North Sea

This study utilizes the F3 block seismic dataset, a publicly available 3D seismic survey from the Dutch sector of the North Sea. The data is provided in the standard **SEG-Y format** (`Seismic_data.sgy`), which stores seismic trace data along with header information defining trace locations (inline, crossline, X/Y coordinates) and recording parameters.

Accompanying the seismic volume are interpreted **horizons**, provided as text files (`.xyt`). Each file contains X, Y spatial coordinates (Easting, Northing in meters) and the corresponding two-way travel time (TWT in milliseconds) defining a specific geological boundary surface within the seismic volume. These horizons delineate the boundaries between different seismic facies, serving as the ground truth for our classification task.

**Table 1: F3 Dataset Summary**

| Parameter         | Value                                   |
| :---------------- | :-------------------------------------- |
| Data Format       | SEG-Y                                   |
| Dimensions (ILxXLxSamples) | 651 x 951 x 462                         |
| Inline Range      | [Specific Range, e.g., 100-750]         |
| Crossline Range   | [Specific Range, e.g., 300-1250]        |
| Sample Interval   | 4 ms                                    |
| Time Range        | 0 - 1844 ms                             |
| Horizon Files     | 5 (`F3-Horizon-FS4.xyt`, etc.)          |
| Derived Classes   | 6 (Regions between/outside 5 horizons) |

Derived classes are based off of the 5 horizons(fs4, msf4, fs6, fs7, and fs8 respectfully) and a class to be considered not a horizon.


*(Note: Specific Inline/Crossline ranges need verification from data loading)*

### Data Preprocessing

The primary goal of preprocessing is to convert the raw SEG-Y data and horizon picks into labeled data suitable for model training: 3D patches for the CNN and 1D traces for the BiLSTM.

1.  **SEG-Y Loading & Coordinate Scaling:**
    *   The `segyio` library is used to load the seismic volume, mapping trace headers (Inline, Crossline, CDP_X, CDP_Y) and amplitude samples.
    *   Coordinate scaling factors from the SEG-Y binary header are applied to ensure CDP_X/CDP_Y coordinates match the units (meters) of the horizon files.
2.  **Seismic Amplitude Processing:**
    *   A bandpass filter (e.g., 5-60 Hz) is applied to each trace to remove noise outside the typical seismic frequency band. [This is a standard geophysical processing step; specific citation from the list is difficult. Many papers like Babikir et al. (2024) would implicitly use such processing before attribute analysis or ML.]
    *   The Hilbert transform is used to compute the instantaneous amplitude (envelope) of the filtered trace, enhancing reflection strength information. [Also a standard geophysical processing step; papers like Han et al. (2024) focusing on multi-attribute learning might utilize envelope, but a direct citation for the transform itself from this list is challenging.]
    *   The resulting envelope volume is normalized (e.g., clipping extreme percentiles and standard scaling) to stabilize model training.
3.  **Horizon Mapping:**
    *   Horizon `.xyt` files are loaded using `pandas`.
    *   A KD-Tree built from the seismic trace coordinates is used to efficiently map each horizon pick (X, Y) to the nearest seismic trace (inline, crossline).
    *   The horizon pick time (TWT) is converted to a sample index using linear interpolation based on the seismic trace\\s sample times.
    *   The mapped sample indices for all horizons are stored in a 3D stack (`horizon_stack`) indexed by (horizon_index, inline_index, crossline_index).
4.  **Patch/Trace Extraction and Labeling:**
    *   **For CNN:** 3D patches (e.g., 32x32x32 samples) are extracted from the processed amplitude volume with a defined stride.
    *   **For BiLSTM:** 1D traces (vertical sequences of amplitude values) are extracted at various inline/crossline locations.
    *   **Labeling:** For both patches and traces, the label is determined by the geological interval the patch center (or trace location) falls into, based on the `horizon_stack`. The sample index of the patch center (or each sample along the trace for sequence labeling, though patch-center labeling is simpler here) is compared to the sorted horizon depths at that inline/crossline location using `np.searchsorted`. This assigns an integer label (0 to N, where N is the number of horizons) representing the facies.
5.  **Train/Validation Split:**
    *   The extracted, labeled data (patches for CNN, traces for BiLSTM) is split into training (80%) and validation (20%) sets.
    *   **Stratification:** The split is stratified based on the facies labels (`y`) to ensure that the proportion of each class is approximately maintained in both the training and validation sets. This is crucial for handling potentially imbalanced datasets.
    *   **Assumption:** We assume the data points (patches/traces) are sufficiently independent for a random split. No explicit spatial partitioning (e.g., splitting by inline/crossline ranges) is performed, which could be considered for future work to better assess spatial generalization.



In [1]:
# Code Block 1: Imports and Initial Setup (Optimized)
import os
import numpy as np
import torch
import segyio
import pandas as pd
from scipy.signal import butter, filtfilt, hilbert
from scipy.interpolate import interp1d
from scipy.spatial import cKDTree
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
import torchvision.models as models
import copy
import struct
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false" 

#seed because life
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

#config
data_dir = "F3_Demo_2020" 
segy_path = "F3_Demo_2020/Rawdata/Seismic_data.sgy"
horizon_dir = os.path.join(data_dir, "Rawdata", "Surface_data")
horizon_files = [
    os.path.join(horizon_dir, "F3-Horizon-FS4.xyt.bz2"),
    os.path.join(horizon_dir, "F3-Horizon-MFS4.xyt"),
    os.path.join(horizon_dir, "F3-Horizon-FS6.xyt"),
    os.path.join(horizon_dir, "F3-Horizon-FS7.xyt"),
    os.path.join(horizon_dir, "F3-Horizon-FS8.xyt"),
    os.path.join(horizon_dir, "F3-Horizon-Shallow.xyt"),
    os.path.join(horizon_dir, "F3-Horizon-Top-Foresets.xyt")
]

PATCH_SIZE = 32
STRIDE = 16 
MAX_PATCHES = 50000
BATCH_SIZE = 32
NUM_EPOCHS_CNN = 25
NUM_EPOCHS_BILSTM = 40
device = torch.device("cpu")
print(f"Using device: {device}")

#helpers
def bandpass_filter(trace, lowcut=5, highcut=60, fs=250, order=4):
    nyq = 0.5 * fs
    # Check if trace is constant
    if np.all(trace == trace[0]):
        return trace # Return constant trace as is
    # Check for NaN/inf
    if np.isnan(trace).any() or np.isinf(trace).any():
        return np.zeros_like(trace) # Return zeros if invalid data
    try:
        b, a = butter(order, [lowcut/nyq, highcut/nyq], btype="band")
        return filtfilt(b, a, trace)
    except ValueError as e:
        # Handle potential issues with short traces or edge cases
        print(f"Warning: Filtering failed for a trace - {e}. Returning zeros.")
        return np.zeros_like(trace)
print("Setup complete.")

Using device: cpu
Setup complete.


## Methodology: Detailed Approach

This section should provide a comprehensive overview of the methods employed. While the core model architectures are defined in the code cells, this markdown section should elaborate on their design choices, the data preprocessing pipeline, and the training and evaluation strategy.

### Data Preprocessing and Feature Engineering


*   **Data Loading:** Describe the source of the F3 block data and the horizon labels.
*   **SEG-Y Data Parsing:** Mention the extraction of seismic amplitudes and relevant metadata (e.g., inline, crossline, time/depth samples).
*   **Horizon Mapping:** Explain how horizon data (.txt files) are used to label the seismic data, including any scaling or coordinate transformation steps.
*   **Patch Extraction:** Detail the process of extracting 3D patches (e.g., 64x64x64 voxels) around labeled points for the CNN and corresponding 1D traces (e.g., center trace of 64 samples) for the BiLSTM. Justify the patch/trace dimensions.
*   **Normalization:** Describe the normalization technique applied to the seismic amplitudes (e.g., per-patch or global mean/std normalization) and its importance.
*   **Train/Validation/Test Split:** Explain the strategy for splitting the data, ensuring that the splits are representative and prevent data leakage (e.g., splitting by seismic lines if spatial correlation is a concern, or random split if appropriate).

### Model Architectures

**1. 3D Convolutional Neural Network (CNN) for Seismic Facies Classification:**

*   **Rationale:** Explain why 3D CNNs are suitable for volumetric seismic data, highlighting their ability to learn spatial hierarchies of features directly from the 3D patches.
*   **Architecture Details:** Describe the layers of your `SeismicCNN3D` model:
    *   Convolutional Layers: Specify the number of layers, filter sizes (e.g., 3x3x3), number of filters, activation functions (e.g., ReLU), and any use of padding.
    *   Pooling Layers: Detail the type of pooling (e.g., MaxPooling3D), pool size, and stride.
    *   Batch Normalization/Dropout: Mention if and where these regularization techniques are used and their purpose (e.g., Batch Normalization after convolutions, Dropout before fully connected layers).
    *   Flattening Layer: How the 3D feature maps are converted to a 1D vector.
    *   Fully Connected (Dense) Layers: Specify the number of dense layers, number of units, and activation functions leading to the output layer.
    *   Output Layer: Activation function (e.g., Softmax for multi-class classification) and number of units (corresponding to the number of seismic facies).

**2. Bidirectional Long Short-Term Memory (BiLSTM) Network for Seismic Trace Analysis:**

*   **Rationale:** Explain the choice of BiLSTMs for analyzing 1D seismic traces, emphasizing their ability to capture sequential dependencies and contextual information from both past and future samples in the trace.
*   **Architecture Details:** Describe the layers of your `SeismicBiLSTM` model:
    *   LSTM Layers: Specify the number of BiLSTM layers, hidden units per LSTM, and whether `batch_first=True` is used.
    *   Dropout: Mention if dropout is applied within or after the LSTM layers for regularization.
    *   Fully Connected Layers: Detail any dense layers following the LSTM output, their units, and activation functions.
    *   Output Layer: Activation function (e.g., Softmax) and number of units.

**3. Hybrid CNN-BiLSTM Model:**

*   **Rationale:** Explain the motivation for combining the CNN and BiLSTM models. The hybrid approach aims to leverage the CNN's strength in capturing 3D spatial context from patches and the BiLSTM's proficiency in understanding 1D sequential patterns within individual traces. The hypothesis is that combining these complementary features will lead to a more robust and accurate classification.
*   **Feature Fusion Strategy:** Describe how features from the CNN and BiLSTM are combined:
    *   Feature Extraction: Explain that the pre-final layers of the trained (or partially trained) CNN and BiLSTM models are used as feature extractors.
    *   Concatenation: The feature vectors from both models are concatenated.
    *   Classification Head: The concatenated feature vector is then fed into one or more new fully connected layers that form the classification head of the hybrid model, culminating in a Softmax output layer.
*   **Input to Hybrid Model:** Clarify that the `HybridSeismicDataset` provides both the 3D patch (for the CNN part) and the corresponding 1D trace (for the BiLSTM part) for each sample.

### Training Strategy

*   **Loss Function:** Specify the loss function used (e.g., Cross-Entropy Loss for multi-class classification).
*   **Optimizer:** Name the optimizer (e.g., Adam, SGD) and its key hyperparameters (e.g., learning rate, weight decay).
*   **Learning Rate Scheduling:** Mention if any learning rate scheduler (e.g., StepLR, ReduceLROnPlateau) was used and its parameters.
*   **Batch Size:** State the batch size used for training.
*   **Number of Epochs:** Specify the total number of training epochs.
*   **Early Stopping:** If used, describe the criteria for early stopping (e.g., monitoring validation loss or accuracy with a certain patience).
*   **Model Saving:** Explain that the model with the best validation performance (e.g., highest accuracy or lowest loss) is saved during training.
*   **Hardware/Software:** Briefly mention the computational environment (e.g., GPU type if used, PyTorch version).

### Evaluation Metrics

Clearly list and define the metrics used to evaluate model performance:
*   Accuracy
*   Precision (specify if micro, macro, or weighted average is reported, and why)
*   Recall (specify if micro, macro, or weighted average is reported, and why)
*   F1-Score (specify if micro, macro, or weighted average is reported, and why)
*   Confusion Matrix: Explain its utility in understanding class-wise performance and misclassifications.


In [2]:
EPSILON = 1e-8
DEFAULT_LOWCUT = 5
DEFAULT_HIGHCUT = 60
DEFAULT_FILTER_ORDER = 4

#coordinate Scaling Factor
#Based on debug output comparison:
#horizon x/y 1/10th of SEG-Y Raw CDP X/Y so factor up
HORIZON_COORD_SCALE_FACTOR = 10.0

def bandpass_filter(trace, lowcut=DEFAULT_LOWCUT, highcut=DEFAULT_HIGHCUT, fs=250, order=DEFAULT_FILTER_ORDER):
    """Applies a bandpass filter to a single trace."""
    nyq = 0.5 * fs
    if np.all(trace == trace[0]):
        return trace
    if np.isnan(trace).any() or np.isinf(trace).any():
        print("Warning: NaN or Inf found in trace, returning zeros.")
        return np.zeros_like(trace)
    if lowcut <= 0 or highcut >= nyq:
        print(f"Warning: Invalid frequency cuts ({lowcut}, {highcut}) for Nyquist {nyq}. Returning zeros.")
        return np.zeros_like(trace)
    try:
        b, a = butter(order, [lowcut / nyq, highcut / nyq], btype="band")
        return filtfilt(b, a, trace)
    except ValueError as e:
        print(f"Warning: Filtering failed for a trace - {e}. Returning zeros.")
        return np.zeros_like(trace)

def load_and_preprocess_data(segy_path, horizon_files, patch_size, stride, max_patches):
    #Loads SEG-Y data and horizon picks (X,Y,time), processes amplitudes,
    #extracts 3D patches, and returns X (N,1,D,H,W), y (N,), num_classes.
    
    if not os.path.exists(segy_path):
        raise FileNotFoundError(f"SEG-Y file not found: {segy_path}")
    print("Checking horizon file paths…")
    valid_horizons = []
    for hf in horizon_files:
        if os.path.exists(hf):
            print(f"  {hf} → FOUND")
            valid_horizons.append(hf)
        else:
            print(f"  {hf} → MISSING")
    if not valid_horizons:
        raise ValueError(f"No valid horizon files found; checked: {horizon_files}")

    #read SEG-Y volume
    print("Loading SEG-Y data...")
    try:
        with segyio.open(segy_path, "r", ignore_geometry=True) as f:
            f.mmap()
            inlines = f.attributes(segyio.TraceField.INLINE_3D)[:]
            xlines = f.attributes(segyio.TraceField.CROSSLINE_3D)[:]
            raw_cdpX = f.attributes(segyio.TraceField.CDP_X)[:].astype(float)
            raw_cdpY = f.attributes(segyio.TraceField.CDP_Y)[:].astype(float)
            samples = np.array(f.samples)
            sample_rate = segyio.tools.dt(f) / 1000.0
            fs = 1000.0 / sample_rate

            print(f"  Sample rate: {sample_rate} ms, Freq: {fs:.1f} Hz")
            print(f"  Raw CDP X range: {raw_cdpX.min():.3f}–{raw_cdpX.max():.3f}, "
                  f"Raw CDP Y range: {raw_cdpY.min():.3f}–{raw_cdpY.max():.3f}")

            uni_il, il_counts = np.unique(inlines, return_counts=True)
            uni_xl, xl_counts = np.unique(xlines, return_counts=True)
            if not (il_counts.size > 0 and xl_counts.size > 0 and
                    np.all(il_counts == il_counts[0]) and np.all(xl_counts == xl_counts[0])):
                print("  Warning: Irregular inline/xline distribution detected.")
            else:
                 print(f"  Grid appears regular: {il_counts[0]} traces per inline, {xl_counts[0]} traces per xline.")

            n_ilines, n_xlines, n_samples = len(uni_il), len(uni_xl), len(samples)
            print(f"  Volume dims (IL, XL, Samples): {n_ilines}, {n_xlines}, {n_samples}")

            with open(segy_path, "rb") as raw_file:
                raw_file.seek(3216)
                scalar_bytes = raw_file.read(2)
                if len(scalar_bytes) == 2:
                    scalar = struct.unpack(">h", scalar_bytes)[0]
                    print(f"  Coordinate scalar found: {scalar}")
                else:
                    scalar = 1
                    print("  Warning: Could not read coordinate scalar, assuming 1.")

            #ppply scalar for info purposes, but KDTree uses raw
            if scalar != 1:
                scaled_cdpX = raw_cdpX.copy()
                scaled_cdpY = raw_cdpY.copy()
                if scalar > 0:
                    scaled_cdpX *= scalar
                    scaled_cdpY *= scalar
                elif scalar < 0:
                    scaled_cdpX /= abs(scalar)
                    scaled_cdpY /= abs(scalar)
                print(f"  Scaled CDP X range: {scaled_cdpX.min():.3f}–{scaled_cdpX.max():.3f}, "
                      f"Scaled CDP Y range: {scaled_cdpY.min():.3f}–{scaled_cdpY.max():.3f}")

            print("Processing amplitudes…")
            volume = np.zeros((n_ilines, n_xlines, n_samples), dtype=np.float32)
            trace_map = {}
            for i in range(f.tracecount):
                trace_map[(inlines[i], xlines[i])] = i

            for il_idx, il_val in enumerate(tqdm(uni_il, desc="    Inlines")):
                for xl_idx, xl_val in enumerate(uni_xl):
                    trace_index = trace_map.get((il_val, xl_val))
                    if trace_index is not None:
                        trace = f.trace.raw[trace_index].astype(np.float32)
                        filtered_trace = bandpass_filter(trace, fs=fs)
                        env = np.abs(hilbert(filtered_trace))
                        volume[il_idx, xl_idx, :] = env

            mask = volume != 0
            mean_val, std_val = 0.0, 1.0
            if mask.any():
                p1, p99 = np.percentile(volume[mask], [1, 99])
                volume = np.clip(volume, p1, p99)
                mean_val = volume[mask].mean()
                std_val = volume[mask].std()
                volume[mask] = (volume[mask] - mean_val) / (std_val + EPSILON)
            print(f"  Normalized stats (mean/std): {mean_val:.3f}/{std_val:.3f}")

    except Exception as e:
        print(f"Error loading or processing SEG-Y file: {e}")
        raise

    print("Building KD-Tree on raw SEG-Y coordinates...")
    coords = np.column_stack((raw_cdpX, raw_cdpY))
    try:
        tree = cKDTree(coords)
    except Exception as e:
        print(f"Error building KDTree: {e}")
        raise

    time_to_idx = interp1d(samples, np.arange(n_samples),
                           kind="nearest", bounds_error=False, fill_value=-1)

    print("Loading and mapping horizons…")
    horizon_stack = np.full((len(valid_horizons), n_ilines, n_xlines),
                            np.nan, dtype=float)
    il_map = {val: idx for idx, val in enumerate(uni_il)}
    xl_map = {val: idx for idx, val in enumerate(uni_xl)}

    for h_idx, hf in enumerate(valid_horizons):
        try:
            df = pd.read_csv(hf, sep=r'\s+', header=None,
                             names=["X", "Y", "time_ms"], engine='python', compression='infer')
            if df.isnull().values.any():
                 print(f"  Warning: NaNs detected in {os.path.basename(hf)}, dropping rows.")
                 df = df.dropna()
            if df.empty:
                print(f"  Warning: No valid data found in {os.path.basename(hf)}")
                continue

            #FIX: Scale horizon coordinates to match SEG-Y raw coordinates
            print(f"  Applying scale factor ({HORIZON_COORD_SCALE_FACTOR}) to coordinates from {os.path.basename(hf)}")
            df["X"] *= HORIZON_COORD_SCALE_FACTOR
            df["Y"] *= HORIZON_COORD_SCALE_FACTOR
            

        except Exception as e:
            print(f"  Error reading or scaling horizon file {hf}: {e}. Skipping.")
            continue

        tops = np.full((n_ilines, n_xlines), np.nan, dtype=float)
        mapped_points = 0
        for Xval, Yval, t_ms in df.itertuples(index=False):
            try:
                dist, tidx = tree.query([Xval, Yval])
                if tidx < 0 or tidx >= len(raw_cdpX):
                    continue
                il_hdr, xl_hdr = inlines[tidx], xlines[tidx]
                i_idx = il_map.get(il_hdr)
                x_idx = xl_map.get(xl_hdr)
                if i_idx is None or x_idx is None:
                    continue
                s_idx = int(time_to_idx(t_ms))
                if 0 <= s_idx < n_samples and np.isnan(tops[i_idx, x_idx]):
                    tops[i_idx, x_idx] = s_idx
                    mapped_points += 1
            except Exception as e:
                print(f"  Warning: Error processing point ({Xval}, {Yval}, {t_ms}) from {os.path.basename(hf)}: {e}")
                continue

        horizon_stack[h_idx] = tops
        coverage = (~np.isnan(tops)).sum() / tops.size * 100 if tops.size > 0 else 0
        print(f"  {os.path.basename(hf)}: {coverage:.2f}% coverage ({mapped_points} points mapped)")

    if np.all(np.isnan(horizon_stack)):
         raise ValueError("Horizon mapping resulted in zero coverage for all files. Check coordinate systems and file contents.")

    print("Extracting patches and labels…")
    patches, labels = [], []
    half_patch = patch_size // 2
    for il in range(0, n_ilines - patch_size + 1, stride):
        for xl in range(0, n_xlines - patch_size + 1, stride):
            ctr_il, ctr_xl = il + half_patch, xl + half_patch
            depths_at_center = horizon_stack[:, ctr_il, ctr_xl]
            valid_depths = np.sort(depths_at_center[~np.isnan(depths_at_center)])
            if valid_depths.size == 0:
                continue
            for sm in range(0, n_samples - patch_size + 1, stride):
                patch = volume[il : il + patch_size,
                               xl : xl + patch_size,
                               sm : sm + patch_size]
                if np.isfinite(patch).all():
                    center_depth_sample = sm + half_patch
                    label = np.searchsorted(valid_depths, center_depth_sample, side='right')
                    patches.append(patch[np.newaxis, ...])
                    labels.append(label)
                    if len(patches) >= max_patches:
                        break
            if len(patches) >= max_patches:
                break
        if len(patches) >= max_patches:
            break

    if not patches:
        raise ValueError("No valid patches extracted. Check patch_size, stride, horizon coverage, and volume normalization.")
    print(f"  Extracted {len(patches)} patches.")

    X = np.stack(patches).astype(np.float32)
    y = np.array(labels, dtype=np.int64)
    unique_labels = np.unique(y)
    num_classes = len(unique_labels)
    label_map = {lab: i for i, lab in enumerate(unique_labels)}
    y_mapped = np.vectorize(label_map.get)(y)

    print(f"Final shapes → X: {X.shape}, y: {y_mapped.shape}, classes: {num_classes} (mapped from {unique_labels})")

    return X, y_mapped, num_classes

try:
    X, y, num_classes = load_and_preprocess_data(
        segy_path, horizon_files, PATCH_SIZE, STRIDE, MAX_PATCHES
    )
    print("\nPreprocessing complete (with coordinate fix).")
    print("X shape:", X.shape)
    print("y shape:", y.shape)
    print("Number of classes:", num_classes)
except Exception as e:
    print(f"\nAn error occurred during preprocessing: {e}")

Checking horizon file paths…
  F3_Demo_2020/Rawdata/Surface_data/F3-Horizon-FS4.xyt.bz2 → FOUND
  F3_Demo_2020/Rawdata/Surface_data/F3-Horizon-MFS4.xyt → FOUND
  F3_Demo_2020/Rawdata/Surface_data/F3-Horizon-FS6.xyt → FOUND
  F3_Demo_2020/Rawdata/Surface_data/F3-Horizon-FS7.xyt → FOUND
  F3_Demo_2020/Rawdata/Surface_data/F3-Horizon-FS8.xyt → FOUND
  F3_Demo_2020/Rawdata/Surface_data/F3-Horizon-Shallow.xyt → FOUND
  F3_Demo_2020/Rawdata/Surface_data/F3-Horizon-Top-Foresets.xyt → FOUND
Loading SEG-Y data...
  Sample rate: 4.0 ms, Freq: 250.0 Hz
  Raw CDP X range: 6054167.000–6295763.000, Raw CDP Y range: 60735564.000–60904632.000
  Volume dims (IL, XL, Samples): 651, 951, 462
  Coordinate scalar found: 4000
  Scaled CDP X range: 24216668000.000–25183052000.000, Scaled CDP Y range: 242942256000.000–243618528000.000
Processing amplitudes…


    Inlines:   0%|          | 0/651 [00:00<?, ?it/s]

  Normalized stats (mean/std): 2326.829/1714.808
Building KD-Tree on raw SEG-Y coordinates...
Loading and mapping horizons…
  Applying scale factor (10.0) to coordinates from F3-Horizon-FS4.xyt.bz2
  F3-Horizon-FS4.xyt.bz2: 95.71% coverage (592541 points mapped)
  Applying scale factor (10.0) to coordinates from F3-Horizon-MFS4.xyt
  F3-Horizon-MFS4.xyt: 95.65% coverage (592177 points mapped)
  Applying scale factor (10.0) to coordinates from F3-Horizon-FS6.xyt
  F3-Horizon-FS6.xyt: 49.04% coverage (303622 points mapped)
  Applying scale factor (10.0) to coordinates from F3-Horizon-FS7.xyt
  F3-Horizon-FS7.xyt: 95.31% coverage (590096 points mapped)
  Applying scale factor (10.0) to coordinates from F3-Horizon-FS8.xyt
  F3-Horizon-FS8.xyt: 95.50% coverage (591259 points mapped)
  Applying scale factor (10.0) to coordinates from F3-Horizon-Shallow.xyt
  F3-Horizon-Shallow.xyt: 95.33% coverage (590203 points mapped)
  Applying scale factor (10.0) to coordinates from F3-Horizon-Top-Forese

In [3]:
#Train/Validation Split and DataLoaders
X_train_cnn, X_val_cnn, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=SEED, stratify=y)

#create TensorDatasets and DataLoaders for CNN
train_ds_cnn = TensorDataset(torch.tensor(X_train_cnn, dtype=torch.float32),
                             torch.tensor(y_train, dtype=torch.long))
val_ds_cnn = TensorDataset(torch.tensor(X_val_cnn, dtype=torch.float32),
                           torch.tensor(y_val, dtype=torch.long))
train_loader_cnn = DataLoader(train_ds_cnn, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
val_loader_cnn = DataLoader(val_ds_cnn, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

print(f"CNN DataLoaders: Train batches={len(train_loader_cnn)}, Val batches={len(val_loader_cnn)}")

#Data Preparation for BiLSTM using traces
def extract_traces_from_patches(patches, num_traces_per_patch=5):
    N, C, D, H, W = patches.shape
    all_traces = np.zeros((N, D), dtype=np.float32)
    for i in range(N):
        patch_traces = np.zeros((num_traces_per_patch, D), dtype=np.float32)
        for j in range(num_traces_per_patch):
            h, w = np.random.randint(0, H), np.random.randint(0, W)
            patch_traces[j] = patches[i, 0, :, h, w]
        all_traces[i] = np.mean(patch_traces, axis=0)
    return all_traces

print("Extracting traces for BiLSTM...")
X_train_traces = extract_traces_from_patches(X_train_cnn)
X_val_traces = extract_traces_from_patches(X_val_cnn)

#for use later
def extract_center_trace(patch):
    C, D, H, W = patch.shape
    random_h = np.random.randint(0, H)
    random_w = np.random.randint(0, W)
    trace = patch[0, :, random_h, random_w]
    return trace.astype(np.float32)

#create TensorDatasets and DataLoaders for BiLSTM
train_ds_bilstm = TensorDataset(torch.tensor(X_train_traces, dtype=torch.float32),
                                torch.tensor(y_train, dtype=torch.long))
val_ds_bilstm = TensorDataset(torch.tensor(X_val_traces, dtype=torch.float32),
                              torch.tensor(y_val, dtype=torch.long))
train_loader_bilstm = DataLoader(train_ds_bilstm, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
val_loader_bilstm = DataLoader(val_ds_bilstm, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

print(f"BiLSTM DataLoaders: Train batches={len(train_loader_bilstm)}, Val batches={len(val_loader_bilstm)}")
print(f"Sample trace shape: {X_train_traces[0].shape}")

CNN DataLoaders: Train batches=1250, Val batches=313
Extracting traces for BiLSTM...
BiLSTM DataLoaders: Train batches=1250, Val batches=313
Sample trace shape: (32,)


In [4]:
#Hybrid Dataset
from torch.utils.data import Dataset
import numpy as np
import torch

class HybridSeismicDataset(Dataset):
    def __init__(self, patches, labels):
    #patches (numpy.ndarray): The 3D patches (N, C, D, H, W).
    #labels (numpy.ndarray): The corresponding labels (N,).
        #convert to tensors to avoid issues in DataLoader workers
        self.patches = torch.tensor(patches, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        print(f"HybridSeismicDataset: Initialized with {len(self.patches)} samples.")
        if len(self.patches) != len(self.labels):
             raise ValueError("Patches and labels must have the same number of samples.")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        patch = self.patches[idx]
        label = self.labels[idx]
        
        #extract the corresponding trace from the patch
        #convert patch tensor back to numpy temporarily for extraction function
        trace_np = extract_center_trace(patch.numpy())
        trace = torch.tensor(trace_np, dtype=torch.float32)
        
        return patch, trace, label



train_ds_hybrid = HybridSeismicDataset(X_train_cnn, y_train)
val_ds_hybrid = HybridSeismicDataset(X_val_cnn, y_val)
BATCH_SIZE_HYBRID = 32 
train_loader_hybrid = DataLoader(train_ds_hybrid, batch_size=BATCH_SIZE_HYBRID, shuffle=True, num_workers=0, pin_memory=True)
val_loader_hybrid = DataLoader(val_ds_hybrid, batch_size=BATCH_SIZE_HYBRID, shuffle=False, num_workers=0, pin_memory=True)

print(f"Hybrid DataLoaders: Train batches={len(train_loader_hybrid)}, Val batches={len(val_loader_hybrid)}")

#verify
sample_patch, sample_trace, sample_label = next(iter(train_loader_hybrid))
print(f"Sample Batch Shapes -> Patch: {sample_patch.shape}, Trace: {sample_trace.shape}, Label: {sample_label.shape}")

HybridSeismicDataset: Initialized with 40000 samples.
HybridSeismicDataset: Initialized with 10000 samples.
Hybrid DataLoaders: Train batches=1250, Val batches=313
Sample Batch Shapes -> Patch: torch.Size([32, 1, 32, 32, 32]), Trace: torch.Size([32, 32]), Label: torch.Size([32])


### Model Architectures

Two primary architectures are compared:

1.  **3D Convolutional Neural Network (CNN):**
    *   **Input:** 3D seismic amplitude patches (e.g., 1 x 32 x 32 x 32).
    *   **Architecture:** A simple CNN with two blocks, each containing a 3D Convolution layer, Batch Normalization, ReLU activation, and Max Pooling. This is followed by a Flatten layer and a fully connected classifier head with Dropout for regularization.
    *   **Rationale:** CNNs excel at learning spatial hierarchies of features directly from grid-like data, making them suitable for analyzing 3D seismic patches.

2.  **Bidirectional Long Short-Term Memory (BiLSTM):**
    *   **Input:** 1D seismic amplitude traces (sequences of amplitude values along the depth/time axis).
    *   **Architecture:** A BiLSTM network processes the input sequence. LSTMs are a type of Recurrent Neural Network (RNN) specifically designed to handle long-range dependencies in sequential data. The bidirectional nature allows the model to learn from both past (earlier samples) and future (later samples) context within the trace. The output from the BiLSTM layers is then passed to a fully connected classifier head.
    *   **Rationale:** Seismic traces represent sequences where the amplitude at one sample is related to its neighbors. LSTMs are adept at capturing such temporal/sequential dependencies, which might be relevant for distinguishing facies based on vertical reflection patterns. BiLSTMs enhance this by considering the full trace context.

In [5]:
#Model Definitions

#Simple 3D CNN
class SeismicCNN3D(nn.Module):
    def __init__(self, in_channels=1, num_classes=num_classes, patch_depth=PATCH_SIZE):
        super().__init__()
        self.features = nn.Sequential(
            # Block 1
            nn.Conv3d(in_channels, 16, kernel_size=3, padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2), # D/2, H/2, W/2
            # Block 2
            nn.Conv3d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2) # D/4, H/4, W/4
        )
        #calculate flattened size dynamically based on input patch depth
        feat_depth = patch_depth // 4
        feat_h = PATCH_SIZE // 4 
        feat_w = PATCH_SIZE // 4
        flat_size = 32 * feat_depth * feat_h * feat_w

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_size, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        # Input x shape: (N, C, D, H, W) - Ensure D is samples/depth
        x = self.features(x)
        x = self.classifier(x)
        return x

#BiLSTM Model
class SeismicBiLSTM(nn.Module):
    def __init__(self, input_size=1, hidden_size=64, num_layers=2, num_classes=num_classes, dropout=0.3):
        super().__init__()
        #input size is 1 if feeding raw amplitudes, could be embedding dim if using binned values
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            batch_first=True, #expects (batch, seq_len, features)
                            bidirectional=True,
                            dropout=dropout if num_layers > 1 else 0)
        #classifier input features: hidden_size * 2 (bidirectional)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        #input x shape: (batch, seq_len) need (batch, seq_len, features=1)
        if x.ndim == 2:
            x = x.unsqueeze(-1)

        #LSTM returns output, (h_n, c_n)
        #output shape: (batch, seq_len, hidden_size * 2)
        #h_n shape: (num_layers * 2, batch, hidden_size)
        lstm_out, (hidden, cell) = self.lstm(x)

        #use the output of the last time step for classification
        #could use hidden state or pooling over sequence
        last_step_output = lstm_out[:, -1, :] #shape: (batch, hidden_size * 2)

        logits = self.classifier(last_step_output)
        return logits

cnn_model = SeismicCNN3D(num_classes=num_classes, patch_depth=X.shape[2]).to(device)
bilstm_model = SeismicBiLSTM(input_size=1, num_classes=num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer_cnn = optim.Adam(cnn_model.parameters(), lr=1e-4)
optimizer_bilstm = optim.Adam(bilstm_model.parameters(), lr=1e-3)

class HybridCNNBiLSTM(nn.Module):
    def __init__(self, cnn_model, bilstm_model, num_classes):
        super().__init__()
        print("Initializing HybridCNNBiLSTM model...")
        
        #CNN Branch
        self.cnn_features = cnn_model.features
        try:
            cnn_flat_size = cnn_model.classifier[1].in_features
            print(f"  Detected CNN feature size (flattened): {cnn_flat_size}")
        except (AttributeError, IndexError):
             # Fallback calculation if structure differs (use with caution)
             print("  Warning: Could not automatically detect CNN feature size. Calculating fallback...")

        #BiLSTM Branch
        #ue the LSTM part of the provided BiLSTM model
        self.bilstm = bilstm_model.lstm
        bilstm_feature_size = bilstm_model.lstm.hidden_size * 2
        print(f"  Detected BiLSTM feature size: {bilstm_feature_size}")

        #combined Classifier Head
        combined_feature_size = cnn_flat_size + bilstm_feature_size
        print(f"  Combined feature size: {combined_feature_size}")
        self.classifier = nn.Sequential(
            nn.Linear(combined_feature_size, 128), # Intermediate layer
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
        print(f"  Added combined classifier head.")

        #otional,Freeze the original model parts
        #for param in self.cnn_features.parameters():
        #     param.requires_grad = False
        #for param in self.bilstm.parameters():
        #     param.requires_grad = False
        #print("  Note: CNN and BiLSTM feature extractors frozen.") 

    def forward(self, patch, trace):
        #patch shape: (N, C, D, H, W)
        #trace shape: (N, SeqLen) or (N, SeqLen, 1)
        
        # Process
        cnn_out = self.cnn_features(patch)
        cnn_out_flat = torch.flatten(cnn_out, 1) # Flatten starting from dim 1
        if trace.ndim == 2:
            trace = trace.unsqueeze(-1) # (N, SeqLen) -> (N, SeqLen, 1)
        lstm_out, (hidden, cell) = self.bilstm(trace)
        # Use the output of the last time step
        bilstm_out = lstm_out[:, -1, :] # Shape: (N, hidden_size * 2)
        combined_features = torch.cat((cnn_out_flat, bilstm_out), dim=1)
        #classify
        logits = self.classifier(combined_features)
        
        return logits
    
hybrid_model = HybridCNNBiLSTM(cnn_model, bilstm_model, num_classes).to(device)

print("--- CNN Model ---")
print(cnn_model)
print("--- BiLSTM Model ---")
print(bilstm_model)
print("--- Hybrid Model ---")
print(hybrid_model)

Initializing HybridCNNBiLSTM model...
  Detected CNN feature size (flattened): 16384
  Detected BiLSTM feature size: 128
  Combined feature size: 16512
  Added combined classifier head.
--- CNN Model ---
SeismicCNN3D(
  (features): Sequential(
    (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (5): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=16384, out_features=64, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5

### Training Procedure

Both models are trained using the Adam optimizer and CrossEntropyLoss function. Standard training loops are employed, iterating through the training data in batches for a fixed number of epochs. After each epoch, the models are evaluated on the validation set to monitor performance and prevent overfitting. Training and validation loss and accuracy are recorded for later analysis.

In [None]:
#Training and Eval Loop

#debug step
device = torch.device("cpu")

#same issue with device about not supporting 3d on mps
def train_evaluate_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs, model_name="Model"):
    train_loss_history, val_loss_history = [], []
    train_acc_history, val_acc_history = [], []

    print(f"\n--- Training {model_name} for {num_epochs} epochs ---")
    for epoch in range(num_epochs):
        #train
        model.train()
        running_loss = 0.0
        correct_train, total_train = 0, 0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=False)

        for inputs, labels in train_pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()
            train_pbar.set_postfix({"Loss": running_loss / total_train, "Acc": correct_train / total_train})

        epoch_train_loss = running_loss / total_train
        epoch_train_acc = correct_train / total_train
        train_loss_history.append(epoch_train_loss)
        train_acc_history.append(epoch_train_acc)

        #val
        model.eval()
        running_val_loss = 0.0
        correct_val, total_val = 0, 0
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]", leave=False)
        with torch.no_grad():
            for inputs, labels in val_pbar:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                running_val_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()
                val_pbar.set_postfix({"Loss": running_val_loss / total_val, "Acc": correct_val / total_val})

        epoch_val_loss = running_val_loss / total_val
        epoch_val_acc = correct_val / total_val
        val_loss_history.append(epoch_val_loss)
        val_acc_history.append(epoch_val_acc)

        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {epoch_train_loss:.4f}, Acc: {epoch_train_acc:.4f} | Val Loss: {epoch_val_loss:.4f}, Acc: {epoch_val_acc:.4f}")

    print(f"Finished training {model_name}.")
    history = {
        "train_loss": train_loss_history, "val_loss": val_loss_history,
        "train_acc": train_acc_history, "val_acc": val_acc_history
    }
    return model, history

#train
cnn_model, cnn_history = train_evaluate_model(cnn_model, train_loader_cnn, val_loader_cnn,
                                            criterion, optimizer_cnn, device, NUM_EPOCHS_CNN, "CNN")
bilstm_model, bilstm_history = train_evaluate_model(bilstm_model, train_loader_bilstm, val_loader_bilstm,criterion, optimizer_bilstm, device, NUM_EPOCHS_BILSTM, "BiLSTM")


--- Training CNN for 25 epochs ---


Epoch 1/25 [Train]:   0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 1/25 [Val]:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 1/25 - Train Loss: 1.1232, Acc: 0.5917 | Val Loss: 0.8138, Acc: 0.7080


Epoch 2/25 [Train]:   0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 2/25 [Val]:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 2/25 - Train Loss: 0.8905, Acc: 0.6601 | Val Loss: 0.7082, Acc: 0.7167


Epoch 3/25 [Train]:   0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 3/25 [Val]:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 3/25 - Train Loss: 0.8021, Acc: 0.6900 | Val Loss: 0.6365, Acc: 0.7592


Epoch 4/25 [Train]:   0%|          | 0/1250 [00:00<?, ?it/s]

In [None]:
def train_hybrid_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    print(f"Starting hybrid training for {num_epochs} epochs on {device}...")
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 10)

        for phase in ["train", "val"]:
            if phase == "train":
                model.train()
                dataloader = train_loader
                print("Training Phase:")
            else:
                model.eval()
                dataloader = val_loader
                print("Validation Phase:")

            running_loss = 0.0
            running_corrects = 0
            total_samples = 0

            progress_bar = tqdm(dataloader, desc=f"{phase.capitalize()} Epoch {epoch+1}", leave=False)
            # Modified loop to unpack three items: patch, trace, label
            for patches, traces, labels in progress_bar:
                patches = patches.to(device)
                traces = traces.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == "train"):
                    # Pass both inputs to the hybrid model
                    outputs = model(patches, traces)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                batch_loss = loss.item() * patches.size(0)
                batch_corrects = torch.sum(preds == labels.data)
                running_loss += batch_loss
                running_corrects += batch_corrects
                total_samples += patches.size(0)

                progress_bar.set_postfix({
                    'loss': f"{batch_loss/patches.size(0):.4f}",
                    'acc': f"{batch_corrects.double()/patches.size(0):.4f}"
                })

            epoch_loss = running_loss / total_samples
            epoch_acc = running_corrects.double() / total_samples

            print(f"{phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

            if phase == "train":
                history["train_loss"].append(epoch_loss)
                history["train_acc"].append(epoch_acc.item())
            else:
                history["val_loss"].append(epoch_loss)
                history["val_acc"].append(epoch_acc.item())
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    print(f"  New best validation accuracy: {best_acc:.4f}")

    print(f"\nTraining complete. Best validation Acc: {best_acc:4f}")
    model.load_state_dict(best_model_wts)
    return model, history

def evaluate_hybrid_model(model, dataloader, criterion, device):
    print(f"Evaluating hybrid model on {device}...")
    model.eval()
    all_preds = []
    all_labels = []
    running_loss = 0.0
    total_samples = 0

    progress_bar = tqdm(dataloader, desc="Evaluation", leave=False)
    with torch.no_grad():
        for patches, traces, labels in progress_bar:
            patches = patches.to(device)
            traces = traces.to(device)
            labels = labels.to(device)

            outputs = model(patches, traces)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)

            running_loss += loss.item() * patches.size(0)
            total_samples += patches.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            progress_bar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'acc': f"{torch.sum(preds == labels.data).double()/patches.size(0):.4f}"
            })

    avg_loss = running_loss / total_samples
    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted', zero_division=0)
    print(f"Evaluation Loss: {avg_loss:.4f}")
    print(f"Evaluation Accuracy: {accuracy:.4f}")
    print(f"Evaluation Precision (Weighted): {precision:.4f}")
    print(f"Evaluation Recall (Weighted): {recall:.4f}")
    print(f"Evaluation F1-Score (Weighted): {f1:.4f}")
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, zero_division=0))
    conf_matrix = confusion_matrix(all_labels, all_preds)
    return avg_loss, accuracy, precision, recall, f1, conf_matrix

In [None]:
#Evaluation Function and Execution
def evaluate_final_model(model, dataloader, device, model_name="Model"):
    print(f"\n--- Evaluating {model_name} ---")
    model.eval()
    all_labels, all_preds = [], []
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc=f"Evaluating {model_name}", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)

    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="weighted", zero_division=0)
    report = classification_report(all_labels, all_preds, zero_division=0)
    cm = confusion_matrix(all_labels, all_preds)

    print(f"{model_name} Validation Metrics:")
    print(f"  Accuracy:  {accuracy:.4f}")
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall:    {recall:.4f}")
    print(f"  F1-Score:  {f1:.4f}")
    print("Classification Report:")
    print(report)

    results = {
        "accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1,
        "report": report, "confusion_matrix": cm
    }
    return results
#eval
cnn_results = evaluate_final_model(cnn_model, val_loader_cnn, device, "CNN")
bilstm_results = evaluate_final_model(bilstm_model, val_loader_bilstm, device, "BiLSTM")


In [None]:
#Create Hybrid Datasets and DataLoaders
train_loader_hybrid = DataLoader(train_ds_hybrid, batch_size=BATCH_SIZE_HYBRID, shuffle=True, num_workers=0, pin_memory=True)
val_loader_hybrid = DataLoader(val_ds_hybrid, batch_size=BATCH_SIZE_HYBRID, shuffle=False, num_workers=0, pin_memory=True)
print(f"Hybrid DataLoaders: Train batches={len(train_loader_hybrid)}, Val batches={len(val_loader_hybrid)}")

print("Instantiating Hybrid Model...")
#pass  cnn_model and bilstm_model instances
hybrid_model = HybridCNNBiLSTM(cnn_model.to(device), bilstm_model.to(device), num_classes).to(device)
print(hybrid_model)

#define loss and optimizer
LEARNING_RATE_HYBRID = 1e-4 # Adjust as needed
NUM_EPOCHS_HYBRID = 5     # Adjust as needed

criterion_hybrid = nn.CrossEntropyLoss()

# Optimize all parameters in the hybrid model (including CNN/BiLSTM parts)
optimizer_hybrid = optim.Adam(hybrid_model.parameters(), lr=LEARNING_RATE_HYBRID)
print("Optimizer set to train all parameters of the hybrid model.")

#train
hybrid_model, history_hybrid = train_hybrid_model(
    hybrid_model,
    train_loader_hybrid,
    val_loader_hybrid,
    criterion_hybrid,
    optimizer_hybrid,
    num_epochs=NUM_EPOCHS_HYBRID,
    device=device
)

#eval
print("\n--- Evaluating Best Hybrid Model ---")
val_loss_hybrid, val_acc_hybrid, prec_hybrid, recall_hybrid, f1_hybrid, conf_matrix_hybrid = evaluate_hybrid_model(hybrid_model, val_loader_hybrid, criterion_hybrid, device )
hybrid_results = { "accuracy": val_acc_hybrid, "precision": prec_hybrid, "recall": recall_hybrid, "f1": f1_hybrid, "confusion_matrix": conf_matrix_hybrid} 
print("Hybrid model evaluation results stored.") 
all_results = {"CNN": cnn_results, "BiLSTM": bilstm_results, "Hybrid": hybrid_results}

## Results and Analysis

We evaluate the trained CNN and BiLSTM models on the held-out validation set using standard classification metrics: accuracy, precision, recall, and F1-score (weighted average), along with confusion matrices.

In [None]:
#Plots
def plot_training_history(history, model_name):
    if not history or not isinstance(history, dict) or not all(k in history for k in ["train_loss", "val_loss", "train_acc", "val_acc"]):
        print(f"Warning: Invalid or incomplete history data for {model_name}. Skipping history plot.")
        return
        
    epochs = range(1, len(history["train_loss"]) + 1)
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(epochs, history["train_loss"], "bo-", label="Train Loss")
    plt.plot(epochs, history["val_loss"], "ro-", label="Val Loss")
    plt.title(f"{model_name} Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(epochs, history["train_acc"], "bo-", label="Train Accuracy")
    plt.plot(epochs, history["val_acc"], "ro-", label="Val Accuracy")
    plt.title(f"{model_name} Accuracy")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()

print("Plotting Training Histories...")
if "cnn_history" in locals():
    plot_training_history(cnn_history, "CNN")
else:
    print("Skipping CNN history plot (cnn_history not found).")
    
if "bilstm_history" in locals():
    plot_training_history(bilstm_history, "BiLSTM")
else:
    print("Skipping BiLSTM history plot (bilstm_history not found).")

if "history_hybrid" in locals():
    plot_training_history(history_hybrid, "Hybrid")
else:
    print("Skipping Hybrid history plot (history_hybrid not found).")

print("\nPlotting Performance Comparison...")
if "all_results" not in locals() or not isinstance(all_results, dict):
    print("Error: `all_results` dictionary not found or invalid. Cannot create comparison plot.")
else:
    metrics_to_plot = ["accuracy", "precision", "recall", "f1"] # Use lowercase keys as likely stored from evaluate function
    model_names = list(all_results.keys()) # Should include "CNN", "BiLSTM", "Hybrid"
    num_models = len(model_names)
    num_metrics = len(metrics_to_plot)
    
    # Check if all models and metrics exist
    valid_plot = True
    for model in model_names:
        if model not in all_results:
             print(f"Warning: Model 	'{model}	' not found in all_results. Skipping comparison plot.")
             valid_plot = False
             break
        for metric in metrics_to_plot:
             if metric not in all_results[model]:
                 print(f"Warning: Metric 	'{metric}	' not found for model 	'{model}	' in all_results. Skipping comparison plot.")
                 valid_plot = False
                 break
        if not valid_plot: break

    if valid_plot:
        x = np.arange(num_metrics)
        width = 0.25 

        fig, ax = plt.subplots(figsize=(12, 7))
        
        offset_cnn = -width
        offset_bilstm = 0
        offset_hybrid = width
        
        cnn_metrics = [all_results.get("CNN", {}).get(metric, 0) for metric in metrics_to_plot]
        bilstm_metrics = [all_results.get("BiLSTM", {}).get(metric, 0) for metric in metrics_to_plot]
        hybrid_metrics = [all_results.get("Hybrid", {}).get(metric, 0) for metric in metrics_to_plot]

        rects1 = ax.bar(x + offset_cnn, cnn_metrics, width, label="CNN")
        rects2 = ax.bar(x + offset_bilstm, bilstm_metrics, width, label="BiLSTM")
        rects3 = ax.bar(x + offset_hybrid, hybrid_metrics, width, label="Hybrid") 

        ax.set_ylabel("Score")
        ax.set_title("Model Performance Comparison (Validation Set)")
        ax.set_xticks(x) # Keep ticks centered between the groups
        ax.set_xticklabels([m.replace("_", " ").title() for m in metrics_to_plot])
        ax.legend()
        ax.set_ylim(0, 1.1)

        # Add labels to all bars
        ax.bar_label(rects1, padding=3, fmt="%.3f")
        ax.bar_label(rects2, padding=3, fmt="%.3f")
        ax.bar_label(rects3, padding=3, fmt="%.3f")

        fig.tight_layout()
        plt.show()


#plot confusion matricies
print("\nPlotting Confusion Matrices...")
if "all_results" not in locals() or not isinstance(all_results, dict):
    print("Error: `all_results` dictionary not found or invalid. Cannot plot confusion matrices.")
else:
    model_names = list(all_results.keys())
    num_models = len(model_names)
    valid_plot = True
    cms_to_plot = {}
    for model in model_names:
        if model not in all_results or "confusion_matrix" not in all_results[model]:
            print(f"Warning: Confusion matrix not found for model 	'{model}	' in all_results. Skipping its matrix plot.")
            break
        else:
            cms_to_plot[model] = all_results[model]["confusion_matrix"]
            
    num_valid_cms = len(cms_to_plot)

    if num_valid_cms > 0:
        fig, axes = plt.subplots(1, num_valid_cms, figsize=(7 * num_valid_cms, 6)) 
        if num_valid_cms == 1:
            axes = [axes]
            
        plot_idx = 0
        for model_name, cm in cms_to_plot.items():
            if cm is not None and isinstance(cm, np.ndarray):
                sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", ax=axes[plot_idx],
                            xticklabels=range(num_classes), yticklabels=range(num_classes)) 
                axes[plot_idx].set_title(f"{model_name} Confusion Matrix")
                axes[plot_idx].set_xlabel("Predicted Label")
                axes[plot_idx].set_ylabel("True Label")
                plot_idx += 1
            else:
                 print(f"Skipping invalid confusion matrix data for {model_name}.")

        plt.tight_layout()
        plt.show()
    else:
        print("No valid confusion matrices found to plot.")

## Interpreting Model Performance

This section is crucial for presenting and interpreting the findings from your model evaluations. It should go beyond simply stating the metrics and delve into what they mean in the context of your seismic interpretation task.

**(Ensure all plots and metrics mentioned below are generated by your code cells and displayed in the notebook before this markdown section.)**

### Performance of Individual Models

**1. 3D CNN Model:**
*   **Training History:** Discuss the learning curves (loss and accuracy vs. epochs for both training and validation sets). Look for signs of overfitting (large gap between training and validation performance), underfitting (both training and validation performance are low), or good convergence.
*   **Validation Metrics:** Present the final accuracy, precision, recall, and F1-score on the validation set. Interpret these scores. For instance, what does the precision for a specific facies tell you? What about recall?
*   **Confusion Matrix:** Analyze the confusion matrix. Which facies are well-classified? Which ones are commonly confused with each other? Provide geological or data-driven hypotheses for these confusions (e.g., subtle acoustic differences, imbalanced class representation, ambiguous labeling).

**2. BiLSTM Model:**
*   **Training History:** Similar to the CNN, discuss its learning curves.
*   **Validation Metrics:** Present and interpret its performance metrics on the validation set.
*   **Confusion Matrix:** Analyze its confusion matrix, comparing its strengths and weaknesses to the CNN.

### Performance of the Hybrid CNN-BiLSTM Model

*   **Training History:** Discuss the learning curves for the hybrid model.
*   **Validation Metrics:** Present the final accuracy, precision, recall, and F1-score. Critically compare these to the individual CNN and BiLSTM model performances. Did the hybrid model achieve the hypothesized improvement? By how much?
*   **Confusion Matrix:** Analyze the hybrid model's confusion matrix. Did it resolve some of the confusions observed in the individual models? Are there new patterns of misclassification?

### Comparative Analysis

*   **Overall Comparison Table/Plot:** Include a table or bar chart (as generated by your plotting code) that directly compares the key metrics (Accuracy, Precision, Recall, F1-score) across all three models (CNN, BiLSTM, Hybrid) on the validation set.
*   **Statistical Significance (Optional but Recommended for Rigorous Research):** If feasible, consider performing statistical tests (e.g., McNemar's test for comparing two classifiers, or ANOVA for multiple) to determine if the performance differences between models are statistically significant.
*   **Discussion of Strengths and Weaknesses:** Based on the results, discuss the relative strengths and weaknesses of each approach. For example:
    *   Did the 3D CNN excel at capturing certain types of spatial features that the BiLSTM missed?
    *   Was the BiLSTM better at distinguishing facies with subtle vertical variations along traces?
    *   How did the hybrid model balance or improve upon these aspects?
*   **Error Analysis:** Go deeper into specific examples of misclassifications if possible. Visualizing some misclassified patches/traces alongside their true and predicted labels can provide valuable insights into why the models made errors.

### Answering Research Questions

Explicitly revisit your research questions (stated in the Introduction/Methodology) and discuss how the results address them. For example:
*   How effective are deep learning models (specifically your CNN, BiLSTM, and Hybrid) for automated seismic facies classification in the F3 block?
*   Does the hybrid approach combining spatial and sequential feature extraction offer superior performance compared to individual models?

**(Ensure this section is rich with interpretation and connects back to the geological context of the F3 dataset.)**


### Performance Comparison

*(Note: The following discussion assumes hypothetical results. Actual results from running the code should replace this.)*

The 3D CNN model achieved a validation accuracy of approximately **[Insert CNN Accuracy, e.g., 0.6821]** and a weighted F1-score of **[Insert CNN F1, e.g., 0.6750]**. The confusion matrix (Figure 2a) shows **[Describe CNN CM - e.g., good performance on class X, confusion between classes Y and Z]**.

The BiLSTM model, operating on individual traces, achieved an accuracy of **[Insert BiLSTM Accuracy, e.g., 0.6255]** and an F1-score of **[Insert BiLSTM F1, e.g., 0.6180]**. Its confusion matrix (Figure 2b) indicates **[Describe BiLSTM CM - e.g., struggles with class A, better at class B compared to CNN]**.

**Table 2: Model Performance Summary**

| Model   | Accuracy | Precision (w) | Recall (w) | F1-Score (w) |
| :------ | :------- | :------------ | :--------- | :----------- |
| 3D CNN  | [CNN Acc]  | [CNN Prec]    | [CNN Rec]  | [CNN F1]     |
| BiLSTM  | [BiLSTM Acc]| [BiLSTM Prec] | [BiLSTM Rec]| [BiLSTM F1]  |

*(Replace bracketed values with actual results)*

Comparing the two models (Figure 1), the 3D CNN appears to slightly outperform the BiLSTM on this dataset based on overall accuracy and F1-score. This suggests that, for this specific task and data representation, analyzing the 3D spatial context within patches might be more informative than analyzing individual 1D trace sequences alone.

### Relation to Research Questions

1.  **CNN Effectiveness:** The 3D CNN demonstrated moderate effectiveness, achieving an accuracy significantly better than random chance but falling short of high performance (e.g., >90%). This indicates that while 3D patches contain useful information, the simple CNN architecture used might not be sufficient to fully capture the complex variations defining the facies, or the data itself might be inherently challenging. The discrepancy noted in the professor's feedback (claim of 90% vs. actual ~66%) has been corrected; the model achieves around **[Reiterate CNN Accuracy]**.
2.  **BiLSTM Competitiveness:** The BiLSTM model provided a reasonable baseline but did not outperform the CNN. This suggests that, at least with this implementation (using single traces), the sequential information alone might be less discriminative than the 3D spatial information captured by the CNN. Further investigation could explore using multiple traces per location or different sequence representations for the BiLSTM.

### Limitations and Caveats

It is important to acknowledge the limitations of this study:
*   **Simple Architectures:** The CNN and BiLSTM models used are relatively basic. More complex architectures or hyperparameter tuning might yield better results.
*   **Single Dataset:** Findings are based solely on the F3 block dataset and may not generalize to other geological settings.
*   **Trace Representation for BiLSTM:** The method of extracting and representing traces for the BiLSTM was simple (single random trace per patch location). More sophisticated feature extraction or sequence representation might improve BiLSTM performance.
*   **Preprocessing Impact:** The specific preprocessing steps (filtering, normalization) can influence model performance.
*   **Horizon Accuracy:** The analysis relies on the accuracy of the provided horizon interpretations.

The results should be interpreted with caution. While the CNN showed slightly better performance here, neither model achieved exceptionally high accuracy, indicating the task remains challenging. The choice between spatial (CNN) and sequential (BiLSTM) approaches may depend on the specific geological context and data characteristics.

## Conclusion

This study explored the automation of seismic facies classification on the F3 block dataset using two deep learning approaches: a 3D CNN operating on seismic patches and a BiLSTM operating on seismic traces. The preprocessing pipeline successfully transformed raw SEG-Y and horizon data into labeled inputs for both models.

Our results indicate that the 3D CNN achieved slightly better performance (Accuracy: **[CNN Acc]**, F1: **[CNN F1]**) compared to the BiLSTM (Accuracy: **[BiLSTM Acc]**, F1: **[BiLSTM F1]**) under the tested configurations. This suggests that leveraging 3D spatial context might be more advantageous for this specific classification task than relying solely on 1D sequential patterns within individual traces.

However, neither model achieved high-end performance, highlighting the inherent complexity of seismic facies classification. Future work could explore:
*   More advanced CNN architectures (e.g., ResNets adapted for 3D).
*   Hybrid models combining CNN features with RNNs/Transformers.
*   Improved sequence representations for the BiLSTM (e.g., using multiple traces, incorporating spatial location encoding).
*   Hyperparameter optimization and data augmentation techniques.
*   Testing on different datasets and geological settings.

This work provides a baseline comparison between spatial and sequential deep learning approaches for seismic facies classification, contributing to the ongoing effort to develop more efficient and objective interpretation workflows in geophysics.

*(Remember to replace placeholders like [Citation Needed] and bracketed metric values with actual information.)*
