

<p align="center">
    <img src="https://github.com/GeostatsGuy/GeostatsPy/blob/master/TCG_color_logo.png?raw=true" width="220" height="240" />

</p>

## DIRECT 6th Annual Consortium 

### Generalized Conditioning of Generative Artificial Intelligence for History Matching Subsurface Models

#### Ahmed Merzoug, PhD Student, The University of Texas at Austin
#### [LinkedIn](https://www.linkedin.com/in/ahmed-merzoug/) | [GitHub](https://github.com/amerzoug) | [GoogleScholar](https://scholar.google.com/citations?user=Ppx0Y1sAAAAJ&hl=en&oi=ao)

#### Honggeun Jo, Assistant Professor, Inha University
#### [LinkedIn](https://www.linkedin.com/in/honggeun-jo/) | [GoogleScholar](https://scholar.google.com/citations?user=u0OE5CIAAAAJ&hl=en)


#### Michael Pyrcz, Professor, The University of Texas at Austin
#### [Twitter](https://twitter.com/geostatsguy) | [GitHub](https://github.com/GeostatsGuy) | [Website](http://michaelpyrcz.com) | [GoogleScholar](https://scholar.google.com/citations?user=QVZ20eQAAAAJ&hl=en&oi=ao) | [Book](https://www.amazon.com/Geostatistical-Reservoir-Modeling-Michael-Pyrcz/dp/0199731446) | [YouTube](https://www.youtube.com/channel/UCLqEr-xV-ceHdXXXrTId5ig) | [LinkedIn](https://www.linkedin.com/in/michael-pyrcz-61a648a1) | [GeostatsPy](https://github.com/GeostatsGuy/GeostatsPy)


#### Work completed as part of the DIRECT consortium for Subsurface Data Analytics and Machine Learning

### Overview

This Jupyter Notebook contains the Python code implementing the generalized conditioning approach for history matching subsurface models using Generative Artificial Intelligence (GenAI) and an Inference Network, as described in the associated research paper.

The core idea is to train a primary GenAI model (specifically, a WGAN-GP) on **unconditioned** subsurface realizations. A separate, smaller Inference Network is then trained to map a latent space to the input space of the GenAI's generator such that the generated models are conditioned to specific well data (hard data and dynamic response). For history matching, the combined Inference Network and Generator are used within an Ensemble Smoother with Multiple Data Assimilation (ES-MDA) framework.

The key advantage of this approach is its **generalizability**. When new well data becomes available (e.g., due to new wells being drilled), only the computationally inexpensive Inference Network needs to be retrained, while the expensive primary GenAI model remains fixed. This significantly reduces computational costs compared to methods that require retraining the entire generative model.

The notebook demonstrates this workflow on a 3D fluvial channel reservoir case study. It runs forward simulations with an external reservoir simulator (CMG IMEX) and assimilates dynamic production data (bottomhole pressure, BHP) to update the latent space and generate history-matched reservoir models.

### Project Structure

The code is designed to be run in a Jupyter Notebook environment. It requires several Python libraries and external components:

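1.  **Python Libraries:** `numpy`, `pandas`, `torch` (including `torch.nn`), `scikit-learn` (`sklearn.metrics.mean_squared_error`), `scipy` (`scipy.interpolate.interp1d`), `matplotlib`, and the standard-library modules `multiprocessing`, `time`, `os`, `pathlib`, `shutil`, and `subprocess`.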
2.  **External CMG Simulator:** The code interacts with the CMG IMEX simulator executable to run forward simulations. The path to the executable must be configured (`CMG` variable in `worker` function).
3.  **`Sr3Reader.py` Module:** A custom Python module is required to read CMG simulation output files (`.sr3`). This module should contain `read_SR3` and `get_wells_timeseries` functions.
4.  **Pre-trained Models:** The code loads a pre-trained Generator model (`checkpoint_epoch_*.pt`) and a pre-trained Inference Network model (`inference_net_epoch_*.pt`). These are assumed to be located in the `base_dir`.
5.  **Simulation Template Files:** A base CMG input file template (`CMGBuilder00.dat`) and a file containing target days for interpolation (`Da2.csv`) are required in the `base_dir`.
6.  **Truth Data:** CSV files containing the observed (truth) well production/pressure data (`Truth/Inj 1_CMG_aligned.csv`, `Truth/Prod 1_CMG_aligned.csv`, etc.) are required.
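
Because a missing file typically only surfaces once the simulations are already running, a quick pre-flight check can save a wasted ensemble run. The sketch below is a hypothetical helper (`missing_inputs` is not part of the notebook) that looks for the files listed above. It assumes the `Truth/` folder sits under `base_dir`, takes the truth files as an argument because the full set of wells is not listed here, and matches the checkpoints by glob pattern because the epoch number varies.

```python
from pathlib import Path
from typing import Iterable, List


def missing_inputs(base_dir: str, truth_files: Iterable[str]) -> List[str]:
    """Return the required inputs not found under base_dir (hypothetical helper)."""
    base = Path(base_dir)

    # Fixed-name files: simulation template, interpolation days, observed data
    # (assumption: truth files are given relative to base_dir)
    required = ["CMGBuilder00.dat", "Da2.csv", *truth_files]
    missing = [name for name in required if not (base / name).is_file()]

    # Pre-trained checkpoints are matched by pattern, since the epoch varies
    for pattern in ("checkpoint_epoch_*.pt", "inference_net_epoch_*.pt"):
        if not any(base.glob(pattern)):
            missing.append(pattern)
    return missing
```

An empty list means every listed input was found; otherwise the returned names point at what still needs to be copied into `base_dir`.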

### Code Functionality

The notebook defines and executes the following steps:

1.  **Setup and Configuration:** Imports necessary libraries, defines the Generator and Inference Network architectures (based on Appendices A and C of the paper), loads pre-trained models, and sets up key parameters in a `CONFIG` dictionary (paths, network dimensions, ESMDA settings).
2.  **Helper Functions:**
    *   `load_generator`: Loads a PyTorch Generator model from a checkpoint.
    *   `flatteningRealization`: Reshapes a 3D realization into a 1D array slice by slice and maps facies values (0, 1, 2) to permeability values (0.01, 150, 50) suitable for the simulator input.
    *   `create_dat_file`: Writes the flattened permeability data to a `.dat` file with a CMG-specific header.
    *   `modify_dat_file`: Modifies the main CMG input template to include the generated permeability file name.
    *   `worker`: This is the core simulation worker function. It takes a single realization, prepares the necessary CMG input files (`.dat`), runs the CMG simulator using `subprocess`, waits for the simulation output (`.sr3`), reads the output using `Sr3Reader`, extracts well timeseries data, interpolates it to target days (from `Da2.csv`), saves the original and aligned timeseries, and cleans up intermediate files. It includes error handling for simulation failure and file processing issues.
    *   `threshold_samples`: Applies a 0.5 threshold to the continuous generator output to obtain discrete facies values (0 or 1).
3.  **ESMDA Execution:**
    *   The main execution block (`if __name__ == '__main__':`) orchestrates the ESMDA process.
    *   **Load and Standardize Observations:** Reads the truth well data and standardizes it with its own mean and standard deviation. Standardization puts measurements of different magnitudes on a common scale before they enter the covariance estimates of the assimilation step.
    *   **Initialize Ensemble:** Initializes the latent variable ensemble (`z_ensemble`) from a standard normal distribution.
    *   **ESMDA Loop:** Iterates for the specified number of ESMDA steps (`n_iter`).
        *   **Map Latent to Generator Input:** Uses the `inference_net` to transform `z_ensemble` into `w_ensemble`. (`w_ensemble = inference_net(z_ensemble)`)
        *   **Generate Models:** Uses the `generator` to create subsurface realizations from `w_ensemble`. (`generated_batch = generator(w_ensemble)`)
        *   **Threshold Models:** Applies `threshold_samples` to get discrete facies/permeability models.
        *   **Run Simulations:** Executes the `worker` function for each ensemble member in parallel using `multiprocessing`.
        *   **Collect and Standardize Results:** Reads the aligned well timeseries data saved by the `worker` functions and standardizes it using the *truth* data's mean and standard deviation.
        *   **Compute MSE:** For diagnostics, calculates the Mean Squared Error (MSE) between the *original* (unstandardized) truth data and the *original* simulation results, and appends it to an MSE history. This step also backs up the simulation outputs and then cleans them up.
        *   **Compute Kalman Gain:** Calculates the sample covariances (`Cov_zd_std`, `Cov_dd_std`) between the latent variable anomalies (`A_z`) and the standardized simulation data anomalies (`A_d_std`), and computes the Kalman Gain (`K_std`).
        *   **Update Latent Ensemble:** Applies the ESMDA update formula to each member of the latent ensemble (`z_ensemble`), incorporating a noisy observation (`d_obs_std + eps_std`) and the simulated data (`d_model_std[i]`).
        *   **Save Ensembles:** Saves the updated `w_ensemble` and `z_ensemble` to `.pt` files whose names include the iteration number.
    *   **Save MSE History:** After the loop, saves the recorded MSE history to a CSV file.
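
The Kalman-gain and update steps above can be sketched in NumPy as a single function. This is a minimal illustration, not the notebook's code: `esmda_update`, `obs_std`, and `rng` are hypothetical names, and the diagonal observation-error covariance is an assumption. In the standard ES-MDA formulation, the inflation coefficients alpha_i are chosen so that the sum of 1/alpha_i over all iterations equals 1 (for example, alpha_i = n_iter at every step).

```python
import numpy as np


def esmda_update(z, d_model, d_obs, obs_std, alpha, rng):
    """One ES-MDA update of a latent ensemble (sketch).

    z        : (n_ens, n_z) latent ensemble before the update
    d_model  : (n_ens, n_d) simulated data, standardized with the truth mean/std
    d_obs    : (n_d,) standardized observations
    obs_std  : (n_d,) observation-error std in standardized units
               (assumption: diagonal error covariance C_D)
    alpha    : inflation coefficient for this iteration
    rng      : numpy Generator used to perturb the observations
    """
    n_ens = z.shape[0]

    # Anomalies: deviation of each member from the ensemble mean
    A_z = z - z.mean(axis=0)
    A_d = d_model - d_model.mean(axis=0)

    # Sample cross-covariance (latent vs. data) and data auto-covariance
    C_zd = A_z.T @ A_d / (n_ens - 1)   # (n_z, n_d)
    C_dd = A_d.T @ A_d / (n_ens - 1)   # (n_d, n_d)
    C_D = np.diag(obs_std ** 2)

    # Kalman gain K = C_zd (C_dd + alpha C_D)^-1, computed with a linear solve;
    # the matrix is symmetric, so solving against C_zd.T and transposing gives K
    K = np.linalg.solve(C_dd + alpha * C_D, C_zd.T).T   # (n_z, n_d)

    # Perturbed observations d_obs + eps, with eps ~ N(0, alpha C_D), one draw per member
    eps = rng.standard_normal(d_model.shape) * np.sqrt(alpha) * obs_std

    # Update every member: z_i + K (d_obs + eps_i - d_i)
    return z + (d_obs + eps - d_model) @ K.T
```

Following the description above, `d_obs` would be `(truth - truth_mean) / truth_std`, and every row of `d_model` would be standardized with the same `truth_mean` and `truth_std`. Solving the linear system is preferred over forming an explicit inverse for numerical stability.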

### Import Necessary Libraries

In [None]:
import numpy as np
import multiprocessing
import time
import os
import torch
import torch.nn as nn
import pandas as pd
from sklearn.metrics import mean_squared_error
from pathlib import Path
import math
from typing import List, Tuple
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pathlib
import subprocess # Used for running external programs like the CMG simulator
import shutil # Used for file operations
# Libraries for data handling and interpolation (pandas is imported above)
from scipy.interpolate import interp1d
# Sr3Reader.py must be in the same directory or on PYTHONPATH.
# It provides functions for reading CMG .sr3 output files.
# If the import fails, both names are set to None so that the worker's
# availability check reports the problem instead of raising a NameError.
try:
    from Sr3Reader import read_SR3, get_wells_timeseries
except ImportError:
    read_SR3 = None
    get_wells_timeseries = None
    print("Warning: Sr3Reader module not found. Simulation output processing will fail.")



#### Define Functions for History Matching 


This section defines components and utilities for a workflow involving geological model generation, processing, and reservoir simulation:

1.  **`Generator` Class:** Defines a neural network that takes a low-dimensional latent vector and transforms it into a 3D geological volume, using a linear projection followed by transposed 3D convolutions. A final sigmoid keeps the output continuous in the range 0 to 1, representing a geological property or facies probability.
2.  **`load_generator` Function:** A utility to load a pre-trained instance of the `Generator` model from a saved file (`.pt` or `.pth`). It sets the model to evaluation mode.
3.  **`InferenceNet` Class:** Defines a neural network that maps a latent vector (`z`) to an input vector (`w`) for the Generator. It is trained so that the generated models honor the conditioning well data (hard data and dynamic response, as described in the Overview), while the Generator itself stays fixed.
4.  **Data Processing for Simulation:** Several functions handle converting the generated 3D volumes into a format required by reservoir simulators (like CMG):
    *   **`threshold_samples`:** Converts continuous generator output into discrete values (e.g., facies indices like 0 or 1).
    *   **`flatteningRealization`:** Takes a 3D array, flattens it slice-by-slice, and applies a mapping (e.g., facies to permeability values) to prepare it for simulation input.
    *   **`create_dat_file`:** Writes the processed 1D data array into a `.dat` file with a specific header, suitable for simulator input.
    *   **`modify_dat_file`:** Modifies a template simulation input file (`.dat`) to point to the newly created data file.
5.  **Simulation Workflow:** The **`worker`** function encapsulates the process of taking a 3D realization, preparing its simulation input files, running an external reservoir simulator (using `subprocess`), processing the simulator's output (extracting and interpolating well timeseries data using external `Sr3Reader` functions and `interp1d`), and cleaning up intermediate files.
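
The size of each generated volume follows from the transposed-convolution output-size rule in the PyTorch `ConvTranspose3d` documentation, which with dilation 1 and no output padding reduces to out = (in - 1) * stride - 2 * padding + kernel_size. The sketch below applies this rule layer by layer to the `Generator` defined below, starting from the (2, 8, 8) volume produced by `nn.Unflatten`. It is only a bookkeeping check and does not need PyTorch.

```python
from typing import Sequence, Tuple


def conv_transpose_out(size: Sequence[int], kernel: Sequence[int],
                       stride: Sequence[int], padding: Sequence[int]) -> Tuple[int, ...]:
    # ConvTranspose output size with dilation=1 and output_padding=0
    return tuple((s - 1) * st - 2 * p + k
                 for s, k, st, p in zip(size, kernel, stride, padding))


shape = (2, 8, 8)  # (depth, height, width) after nn.Unflatten
layers = [
    ((4, 4, 4), (2, 2, 2), (1, 1, 1)),  # 512 -> 256 channels
    ((4, 4, 4), (2, 2, 2), (1, 1, 1)),  # 256 -> 128 channels
    ((4, 4, 4), (2, 2, 2), (1, 1, 1)),  # 128 -> 64 channels
    ((3, 4, 4), (1, 2, 2), (1, 1, 1)),  # 64 -> 1 channel: depth kept, plane doubled
]
for kernel, stride, padding in layers:
    shape = conv_transpose_out(shape, kernel, stride, padding)
    print(shape)  # (4, 16, 16), (8, 32, 32), (16, 64, 64), (16, 128, 128)
```

With these layer settings, each sample is therefore 16 layers of 128 x 128 cells. Because `flatteningRealization` expects depth on the last axis, a generator sample needs its channel axis dropped and its depth axis moved last before it is passed in.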

In essence, this code provides the tools to generate 3D geological models, prepare them for reservoir simulation, run the simulation, and process the results, combining a pre-trained generator and a trained inference network in a simulation-based workflow.
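
The thresholding and flattening steps listed above can also be written without a per-slice Python loop. The sketch below uses hypothetical names (`threshold` and `flatten_vectorized` are not functions from the notebook). It assumes a [nx, ny, nz] array with depth on the last axis and, since `threshold_samples` is not shown here, assumes that values at or above 0.5 become facies 1. Like `flatteningRealization`, it reverses the depth axis, flattens each slice row by row, and maps 0 to 0.01, 1 to 150, and 2 to 50.

```python
import numpy as np

# Facies code -> permeability, matching flatteningRealization
FACIES_TO_PERM = {0: 0.01, 1: 150.0, 2: 50.0}


def threshold(volume: np.ndarray, cut: float = 0.5) -> np.ndarray:
    # Assumption: values at or above the cut become facies 1, the rest facies 0
    return (volume >= cut).astype(np.int64)


def flatten_vectorized(realization: np.ndarray) -> np.ndarray:
    """Loop-free equivalent of flatteningRealization for a [nx, ny, nz] array (sketch)."""
    # Reverse the depth axis, then move it to the front: shape [nz, nx, ny]
    ordered = np.moveaxis(realization[:, :, ::-1], -1, 0)
    # C-order ravel flattens each slice row by row and concatenates the slices
    codes = ordered.ravel()
    perm = codes.astype(np.float64)
    # Match against the original codes so one mapping cannot feed into the next
    for code, value in FACIES_TO_PERM.items():
        perm[codes == code] = value
    return perm
```

Matching on the original codes rather than applying `np.where` sequentially guards against chained replacements. With codes 0, 1, 2 and targets 0.01, 150, 50, the sequential chain in `flatteningRealization` happens to give the same result.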

In [2]:

class Generator(nn.Module):
    """
    Generator network definition.
    This network takes a latent vector (noise) and outputs a 3D volume realization.
    It uses a series of ConvTranspose3d layers to upsample the input.
    """

    def __init__(self, latent_dim=100):
        """
        Initializes the Generator network.

        Args:
            latent_dim (int): Dimension of the input latent vector.
        """
        super().__init__()
        self.net = nn.Sequential(
            # Project and reshape the latent vector to an initial 3D feature map
            nn.Linear(latent_dim, 512 * 2 * 8 * 8),
            nn.LeakyReLU(0.2),
            # Unflatten the linear output into a 3D volume (channels, depth, height, width)
            nn.Unflatten(1, (512, 2, 8, 8)),

            # Upsampling block 1: Increase depth, height, and width by factor of 2
            nn.ConvTranspose3d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(256), # Normalize each channel over the batch and spatial dimensions
            nn.LeakyReLU(0.2),

            # Upsampling block 2: Increase depth, height, and width by factor of 2
            nn.ConvTranspose3d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(128),
            nn.LeakyReLU(0.2),

            # Upsampling block 3: Increase depth, height, and width by factor of 2
            nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(64),
            nn.LeakyReLU(0.2),

            # Final upsampling block: Output 1 channel.
            # Kernel size (3, 4, 4), Stride (1, 2, 2). Stride 1 in depth direction, 2 in spatial.
            nn.ConvTranspose3d(64, 1, kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=1),
            nn.Sigmoid() # Output values between 0 and 1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the generator network.

        Args:
            x (torch.Tensor): Input latent vector (batch_size, latent_dim).

        Returns:
            torch.Tensor: Generated 3D volume (batch_size, 1, depth, height, width),
                          i.e. (batch_size, 1, 16, 128, 128) with the layers above.
                          Values are between 0 and 1.
        """
        return self.net(x)


def load_generator(checkpoint_path: str, device: torch.device) -> nn.Module:
    """
    Load the trained generator from a checkpoint file.

    Args:
        checkpoint_path (str): Path to the generator checkpoint file (.pt or .pth).
        device (torch.device): Device (e.g., 'cuda', 'cpu') to load the generator onto.

    Returns:
        nn.Module: Loaded generator model in evaluation mode.
    """
    # Initialize the generator model with default latent dimension
    generator = Generator()
    # Move the model to the specified device
    generator.to(device)
    
    # Load the checkpoint file from disk
    # map_location ensures the tensor is loaded onto the correct device,
    # regardless of where it was saved.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Check if the checkpoint is a dictionary containing the state dict
    # (common practice when saving optimizer state, epoch, etc.)
    if 'generator_state_dict' in checkpoint:
        print("Loading generator state dict from checkpoint['generator_state_dict']")
        # Load the generator's state dictionary from the specific key
        generator.load_state_dict(checkpoint['generator_state_dict'])
    else:
        print("Loading generator state dict directly from checkpoint (assuming it's just the state dict)")
        # Assume the checkpoint file *is* just the generator's state dict
        generator.load_state_dict(checkpoint)
    
    # Set the model to evaluation mode (BatchNorm uses its running statistics and stops updating them; dropout, if any, is disabled)
    generator.eval()
    print(f"Generator loaded successfully from {checkpoint_path}")
    return generator

def flatteningRealization(realization: np.ndarray) -> np.ndarray:
    """
    Flattens a 3D realization array slice by slice along the last axis (depth/z),
    processing from the top slice (index shape[-1]-1) down to the bottom slice (index 0).
    Applies a mapping to discrete values: 0 -> 0.01, 1 -> 150, 2 -> 50.
    This format is often required by reservoir simulators like CMG.

    Args:
        realization (np.ndarray): A 3D NumPy array representing the realization
                                 (e.g., shape [nx, ny, nz]).

    Returns:
        np.ndarray: A 1D NumPy array containing the flattened and mapped values,
                    ordered slice by slice from top to bottom.
    """
    T = [] # List to store flattened slices
    # Get the number of slices along the last dimension (depth)
    num_slices = realization.shape[-1]
    
    print(f"[{time.strftime('%H:%M:%S')}] flatteningRealization: Input shape {realization.shape}. Processing {num_slices} slices.", flush=True)

    # Process slices in reverse order (from the last slice index down to 0)
    # This assumes the last dimension corresponds to depth and simulation input
    # expects data from top layers first. Check this assumption based on your simulator.
    for reverse_idx in range(num_slices):
        # Calculate the actual slice index (e.g., for num_slices=10, reverse_idx 0 -> idx 9, reverse_idx 9 -> idx 0)
        idx = num_slices - 1 - reverse_idx
        
        # Extract the 2D slice at the current depth index
        current_slice = realization[:, :, idx]

        print(f"[{time.strftime('%H:%M:%S')}] flatteningRealization: Processing slice {idx} (reverse order {reverse_idx})", flush=True)

        # Flatten the 2D slice into a 1D array row by row (C order: the second axis varies fastest)
        flattened_array = current_slice.flatten()

        # Apply the mapping based on facies values (assuming values 0, 1, 2 represent facies)
        # Replace 0 with 0.01 (e.g., low permeability)
        flattened_array = np.where(flattened_array == 0, 0.01, flattened_array)
        # Replace 1 with 150 (e.g., high permeability)
        flattened_array = np.where(flattened_array == 1, 150, flattened_array)
        # Replace 2 with 50 (e.g., medium permeability)
        flattened_array = np.where(flattened_array == 2, 50, flattened_array)
        
        # Append the processed and flattened slice to the list
        T.append(flattened_array)
    
    # Concatenate all flattened slices into a single 1D array
    result = np.concatenate(T)
    print(f"[{time.strftime('%H:%M:%S')}] flatteningRealization: Finished. Result length = {len(result)}", flush=True)
    return result

def create_dat_file(filename: str, data_array: np.ndarray):
    """
    Creates a .dat file formatted for reservoir simulation input (e.g., CMG).
    Writes a header line followed by the data values, one value per line.

    Args:
        filename (str): The full path and name of the .dat file to create.
        data_array (np.ndarray): A 1D NumPy array containing the data values to write.
    """
    print(f"[{time.strftime('%H:%M:%S')}] create_dat_file: Creating file '{filename}' with {len(data_array)} values.", flush=True)
    try:
        # Open the file in write mode ('w')
        with open(filename, 'w') as f:
            # Write the required header line for permeability data in CMG
            f.write("*PERMI *ALL\n")
            # Write each data value on a new line
            for value in data_array:
                f.write(str(value) + "\n")
        print(f"[{time.strftime('%H:%M:%S')}] create_dat_file: '{filename}' created successfully.", flush=True)
    except Exception as e:
        # Print an error message if file creation fails
        print(f"[{time.strftime('%H:%M:%S')}] create_dat_file: Error creating '{filename}': {e}", flush=True)

def modify_dat_file(input_file: str, output_file: str, old_keyword: str, new_keyword: str):
    """
    Reads the content of an input file, replaces all occurrences of a specific
    string (old_keyword) with another string (new_keyword), and writes the
    modified content to an output file.

    This is typically used to update a master simulation input file (.dat)
    to point to a newly created data file (e.g., the permeability file).

    Args:
        input_file (str): Path to the original template .dat file.
        output_file (str): Path where the modified .dat file will be saved.
        old_keyword (str): The string to search for and replace (e.g., a placeholder filename).
        new_keyword (str): The string to replace with (e.g., the actual filename).
    """
    print(f"[{time.strftime('%H:%M:%S')}] modify_dat_file: Modifying file '{input_file}' and saving to '{output_file}'.", flush=True)
    
    # Check if the input file exists
    if not os.path.exists(input_file):
        print(f"[{time.strftime('%H:%M:%S')}] modify_dat_file: ERROR - The input file '{input_file}' does not exist.", flush=True)
        return # Exit the function if the input file is not found
        
    try:
        # Open the input file in read mode ('r')
        with open(input_file, 'r') as file:
            # Read the entire content of the file
            data = file.read()
            
        # Perform the string replacement
        modified_data = data.replace(old_keyword, new_keyword)
        
        # Open the output file in write mode ('w')
        with open(output_file, 'w') as file:
            # Write the modified content to the output file
            file.write(modified_data)
            
        print(f"[{time.strftime('%H:%M:%S')}] modify_dat_file: Successfully replaced '{old_keyword}' with '{new_keyword}'.", flush=True)
        
    except Exception as e:
        # Print an error message if reading, replacing, or writing fails
        print(f"[{time.strftime('%H:%M:%S')}] modify_dat_file: Error during modification: {e}", flush=True)


def worker(i: int, realization: np.ndarray):
    """
    Worker function designed to process a single 3D realization, simulate
    reservoir flow using an external simulator (CMG), extract well data
    from the simulation output, process and save the data, and clean up
    intermediate files. It is intended to be run once per ensemble member,
    possibly in parallel (the ESMDA loop uses multiprocessing).

    Args:
        i (int): An index or identifier for this specific realization/worker.
                 Used for generating unique filenames.
        realization (np.ndarray): A 3D NumPy array representing the realization.
                                  Expected shape [nx, ny, nz].
    """
    try:
        print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Started processing realization.", flush=True)
        start_time = time.time() # Record start time for performance tracking
        
        # --- Configuration ---
        # Path to the CMG IMEX executable
        CMG = "C:\\Program Files\\CMG\\IMEX\\2023.30\\Win_x64\\EXE\\mx202330.exe"
        

        # Base directory for input/output files
        base_dir = "D:/Generalizedized_GAN/ESMDA/"
        
        # Template CMG input file (needs to be modified)
        template_file = os.path.join(base_dir, 'CMGBuilder00.dat')

        # Directories for saving well timeseries data
        output_dir = os.path.join(base_dir, 'wells_timeseries')
        alined_dir = os.path.join(base_dir, 'wells_timeseries_alined')
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(alined_dir, exist_ok=True)


        # --- Prepare Permeability Input File ---
        # Flatten and map the realization data to permeability values for CMG input
        data = flatteningRealization(realization)
        
        # Define the unique filename for the permeability input file
        dat_filename = os.path.join(base_dir, f"IN_PERM_X{i}.DAT")
        
        # Create the .dat file with the processed permeability data
        create_dat_file(dat_filename, data)
        
        # --- Modify Simulation Input File ---
        # Define the output file path for the modified CMG input file
        output_file = os.path.join(base_dir, f"Ensemble{i}.dat")
        # Define the placeholder keyword in the template and the new keyword (the actual filename)
        old_keyword = 'IN_PERM_X.DAT' # This keyword is expected in CMGBuilder00.dat
        new_keyword = f'IN_PERM_X{i}.DAT' # This is the name of the file we just created
        
        # Modify the template file to point to the new permeability file and save as Ensemble{i}.dat
        modify_dat_file(template_file, output_file, old_keyword, new_keyword)
 
        # --- Run Simulation ---
        print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Running CMG simulation with '{output_file}'.", flush=True)
        # Execute the CMG simulator. The -f flag specifies the input file, -np 1 specifies 1 processor.
        # subprocess.call waits for the command to complete.
        # Assumes CMG executable is in the system's PATH or full path is provided.
        # Assumes the command is correctly formed for your CMG installation.
        ret_code = subprocess.call([CMG, "-f", output_file, "-np", "1"])
        
        # Check the return code of the subprocess call
        if ret_code != 0:
            # If the return code is non-zero, the simulation failed
            print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Simulation failed with return code {ret_code}. Skipping SR3 processing and cleanup.", flush=True)
            # Optionally, clean up input files even on failure:
            # if os.path.exists(dat_filename): os.remove(dat_filename)
            # if os.path.exists(output_file): os.remove(output_file)
            return # Exit the worker function if simulation failed
        print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: CMG simulation completed successfully (return code {ret_code}).", flush=True)

        # -------------------------------------------------
        # --- Process Simulation Output (.sr3 file) ---
        # -------------------------------------------------
        # We assume that running the simulation with input file "Ensemble{i}.dat"
        # automatically produces an SR3 file named "Ensemble{i}.sr3" in the same directory.
        sr3_filename = os.path.join(base_dir, f"Ensemble{i}.sr3")
        
        # Wait loop to ensure the SR3 file has been created by the simulator
        timeout = 300 # Max time to wait for SR3 file (e.g., 5 minutes)
        wait_interval = 5 # seconds between checks
        waited_time = 0
        
        print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Waiting for SR3 file '{sr3_filename}'.", flush=True)
        while not os.path.exists(sr3_filename) and waited_time < timeout:
            time.sleep(wait_interval)
            waited_time += wait_interval
            print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Waited {waited_time}/{timeout}s for '{sr3_filename}'.", flush=True)
            
        # Check if the SR3 file exists after waiting
        if os.path.exists(sr3_filename):
            try:
                # Check if the Sr3Reader functions were successfully imported earlier
                if read_SR3 is None or get_wells_timeseries is None:
                    raise ImportError("Sr3Reader functions not available.")

                print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: SR3 file found. Processing '{sr3_filename}'.", flush=True)
                # Read the SR3 file content using the Sr3Reader library
                sr3 = read_SR3(sr3_filename)
                
                # Extract well timeseries data from the SR3 object
                wells_ts = get_wells_timeseries(sr3)
                print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Extracted timeseries data for {len(wells_ts)} wells.", flush=True)
                # print(wells_ts) # Optional: print extracted data structure

                # --- Save and Process Well Timeseries Data ---
                # Define a unique identifier for this simulation run
                run_id = f"Ensemble{i}"
                
                # Load target days for interpolation (assuming Da2.csv exists and has a 'Days' column)
                try:
                    new_days_df = pd.read_csv(os.path.join(base_dir, "Da2.csv"))
                    new_days = new_days_df["Days"].values
                    print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Loaded {len(new_days)} target days for interpolation.", flush=True)
                except FileNotFoundError:
                    print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: ERROR - Da2.csv not found at {os.path.join(base_dir, 'Da2.csv')}. Skipping interpolation.", flush=True)
                    new_days = None # Set to None to skip interpolation

                # Iterate through each well's timeseries data
                for well_name, ts_data in wells_ts.items():
                    print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Processing well '{well_name}'.", flush=True)
                    
                    # Ensure the 'Days' column exists and is sorted
                    if "Days" not in ts_data.columns:
                        print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: WARNING - 'Days' column not found for well '{well_name}'. Skipping.", flush=True)
                        continue # Skip this well if 'Days' column is missing

                    # Sort data by Days, just in case
                    ts_data = ts_data.sort_values(by="Days").reset_index(drop=True)

                    # Save the original extracted timeseries data as a CSV
                    unique_filename = f"{well_name}_{run_id}.csv"
                    output_path = os.path.join(output_dir, unique_filename)
                    ts_data.to_csv(output_path, index=False)
                    print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Saved original timeseries for '{well_name}' to '{output_path}'.", flush=True)

                    # --- Interpolate Timeseries Data ---
                    if new_days is not None:
                        # Define the columns from the timeseries data that need to be interpolated
                        columns_to_model = ["BHP", "OILVOLSC", "GASVOLSC", "WATVOLSC", "LIQVOLSC", 
                                          "OILRATSC", "GASRATSC", "WATRATSC", "LIQRATSC"]
                        
                        # Filter for columns actually present in the DataFrame
                        columns_to_interpolate = [col for col in columns_to_model if col in ts_data.columns]
                        
                        if not columns_to_interpolate:
                            print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: WARNING - None of the target columns {columns_to_model} found for well '{well_name}'. Skipping interpolation.", flush=True)
                        else:
                            # Prepare dictionary to hold interpolated results
                            predictions = {"Days": new_days}
                            
                            # Perform linear interpolation for each specified column
                            source_days = ts_data["Days"].values # Original days from simulation output
                            
                            for col in columns_to_interpolate:
                                try:
                                    # Create an interpolation function. fill_value='extrapolate' handles days outside the original range.
                                    f = interp1d(source_days, ts_data[col].values, 
                                               kind='linear', fill_value='extrapolate')
                                    # Apply the interpolation function to the new target days
                                    predictions[col] = f(new_days)
                                except ValueError as e:
                                    print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: ERROR - Interpolation failed for well '{well_name}', column '{col}': {e}. Data might not be suitable for interpolation.", flush=True)
                                    # Skip interpolation for this column but continue with others
                                    predictions[col] = np.nan # Column is written as NaN in the aligned CSV

                            # Save the interpolated (aligned) predictions as a CSV
                            predictions_df = pd.DataFrame(predictions)
                            unique_filename2 = f"{well_name}_{run_id}_Aligned.csv"
                            predictions_output_path = os.path.join(alined_dir, unique_filename2)
                            predictions_df.to_csv(predictions_output_path, index=False)
                            print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Saved aligned timeseries for '{well_name}' to '{predictions_output_path}'.", flush=True)


                # --- Clean Up Intermediate Files ---
                print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Cleaning up intermediate files.", flush=True)
                del sr3 # Release the SR3 object memory
                # Remove the SR3 output file
                if os.path.exists(sr3_filename):
                    os.remove(sr3_filename)
                    print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Removed '{sr3_filename}'.", flush=True)
                # Remove the generated permeability input file
                if os.path.exists(dat_filename):
                    os.remove(dat_filename)
                    print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Removed '{dat_filename}'.", flush=True)
                # Remove the modified simulation input file
                if os.path.exists(output_file):
                    os.remove(output_file)
                    print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Removed '{output_file}'.", flush=True)

            except Exception as e:
                # Catch any errors during SR3 processing or file handling
                print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Exception occurred during SR3 processing or cleanup: {e}", flush=True)
                # Note: Files might not be cleaned up if an error occurs here.

        else:
            # Message if the SR3 file was not found after the timeout
            print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: ERROR - SR3 file '{sr3_filename}' not found after simulation or timeout.", flush=True)
            # Optionally, clean up input files even if SR3 wasn't found:
            # if os.path.exists(dat_filename): os.remove(dat_filename)
            # if os.path.exists(output_file): os.remove(output_file)
        
        # Calculate and print total time taken for this worker
        elapsed = time.time() - start_time
        print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: Finished processing in {elapsed:.2f} seconds.", flush=True)

    except Exception as e:
        # Catch any errors that occurred outside the specific SR3 processing block
        print(f"[{time.strftime('%H:%M:%S')}] Worker {i}: An unexpected Exception occurred: {e}", flush=True)


def threshold_samples(samples: np.ndarray) -> np.ndarray:
    """
    Applies a simple binary threshold to the generated samples.
    Values below 0.5 are set to 0, and values >= 0.5 are set to 1.
    This is typically used to convert continuous generator outputs (0-1 range)
    into discrete facies indices (e.g., 0 or 1).

    Args:
        samples (np.ndarray): A NumPy array of continuous values (e.g., output of a Sigmoid).

    Returns:
        np.ndarray: A NumPy array with values thresholded to 0 or 1.
    """
    print(f"[{time.strftime('%H:%M:%S')}] threshold_samples: Applying binary threshold at 0.5.", flush=True)
    # Use numpy's where function for efficient element-wise conditional replacement
    # The condition is `samples >= 0.5`.
    # If True, assign 1.
    # If False, assign 0.
    samples = np.where(samples >= 0.5, 1, 0)
    print(f"[{time.strftime('%H:%M:%S')}] threshold_samples: Thresholding complete. Values are now 0 or 1.", flush=True)
    return samples

class InferenceNet(nn.Module):
    """
    Inference network that transforms an input noise vector (w) into an output
    latent vector (z) that is fed to the frozen generator. The network takes no
    conditioning data as input: conditioning on the well data is imposed through
    the training objective (well likelihood, Gaussian prior, and entropy terms),
    so that the distribution of z approximates the posterior P(z | d_obs).
    
    The architecture uses a sequence of linear layers with BatchNorm and SELU activations.
    """

    def __init__(self, w_dim: int = 100, z_dim: int = 100, hidden_dim: int = 256):
        """
        Initializes the Inference network.
        
        Args:
            w_dim (int): Dimension of the input noise vector (w).
            z_dim (int): Dimension of the output latent vector (z).
            hidden_dim (int): Dimension of the hidden layers.
        """
        super(InferenceNet, self).__init__()
        self.net = nn.Sequential(
            # Linear layer to project input noise to hidden dimension
            nn.Linear(w_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim), # Batch normalization
            nn.SELU(), # Scaled Exponential Linear Unit activation

            # Additional hidden layers
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.SELU(),

            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.SELU(),

            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.SELU(),

            # Final linear layer to project to the output latent dimension
            nn.Linear(hidden_dim, z_dim)  # Output is a vector of size z_dim
        )
        
        # Initialize network weights for potentially better convergence
        self._init_weights()
        print(f"InferenceNet initialized with w_dim={w_dim}, z_dim={z_dim}, hidden_dim={hidden_dim}")
        
    def _init_weights(self):
        """Initializes network weights using Xavier Normal and biases to zeros."""
        print("Initializing InferenceNet weights...")
        for m in self.modules(): # Iterate through all modules in the network
            if isinstance(m, nn.Linear): # Check if the module is a Linear layer
                # Initialize weights using Xavier Normal distribution
                nn.init.xavier_normal_(m.weight)
                # Initialize biases to zeros
                nn.init.zeros_(m.bias)
        print("InferenceNet weights initialized.")
        
    def forward(self, w: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the Inference network.
        
        Args:
            w (torch.Tensor): Input noise vector (batch_size, w_dim), typically from a Gaussian distribution.
            
        Returns:
            torch.Tensor: Output latent vector (batch_size, z_dim).
        """
        # Pass the input noise vector through the sequential network
        z = self.net(w)
        return z

#### Training the Inference Network 


This section focuses on generating diverse geological realizations that honor observed data at specific well locations. It achieves this by training a separate **Inference Network** while using a pre-trained, **frozen** image **Generator**.

1.  **Frozen Generator:** A pre-trained generator model is loaded and its parameters are fixed. It acts as a fixed mapping from latent vectors (`z`) to geological images.
2.  **Inference Network:** A new neural network (`InferenceNet`) is defined and trained. It takes simple random noise (`w`) as input and outputs a latent vector (`z`).
3.  **Conditioning Data:** Observed geological values (`d_obs`) at specific well locations are loaded and used as the conditioning data.
4.  **Training Objective (Cost Function):** The `InferenceNet` is trained so that the distribution `q(z)` of the latent vectors it produces approximates the posterior distribution `P(z | d_obs)`. The loss minimizes the Kullback-Leibler divergence `KL(q || P(z | d_obs))`, which (up to a constant) combines three key components:
    *   **Well Likelihood:** Measures how well the images produced by the *frozen* generator (using `z` from the Inference Net) match the observed data at the well locations, via a binary cross-entropy scaled by `1/alpha`. This term drives the network to generate data-consistent images.
    *   **Gaussian Prior:** Encourages the latent vectors `z` to follow a standard Gaussian distribution, regularizing the latent space.
    *   **Entropy Estimation:** Estimates the entropy of the generated `z` distribution with the Kozachenko-Leonenko k-nearest-neighbor estimator. Maximizing this term promotes diversity in the generated `z` vectors, so that a wide range of plausible realizations is sampled.
5.  **Outcome:** After training, the `InferenceNet` can be used to sample new random noise vectors (`w`) and transform them into conditioned latent vectors (`z`). Feeding these `z` vectors into the frozen generator produces diverse geological realizations that are consistent with the observed well data.
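The notebook's `kozachenko_leonenko_entropy` keeps only the distance-dependent term of the estimator. As a standalone check of what that term belongs to, the sketch below computes the full Kozachenko-Leonenko estimate $\hat H = \psi(N) - \psi(k) + \log V_d + \frac{d}{N}\sum_i \log r_{k,i}$ in NumPy, where $r_{k,i}$ is the Euclidean distance from sample $i$ to its $k$-th nearest neighbor and $V_d$ is the volume of the $d$-dimensional unit ball, and compares it with the analytic entropy of a standard Gaussian. The helper names (`kl_entropy_full`, `digamma_int`, `gaussian_entropy`) are hypothetical and not part of the notebook.

```python
import math

import numpy as np


def digamma_int(n: int) -> float:
    """Digamma for a positive integer: psi(n) = -euler_gamma + sum_{i=1}^{n-1} 1/i."""
    euler_gamma = 0.5772156649015329
    return -euler_gamma + sum(1.0 / i for i in range(1, n))


def kl_entropy_full(z: np.ndarray, k: int = 5) -> float:
    """Kozachenko-Leonenko k-NN entropy estimate (nats), constants included.

    z: (n, d) samples; Euclidean distances are used.
    """
    n, d = z.shape
    # Squared distances via ||a||^2 + ||b||^2 - 2 a.b, clamped at zero
    sq = (z ** 2).sum(axis=1)
    dist2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * z @ z.T, 0.0)
    np.fill_diagonal(dist2, np.inf)  # a point is not its own neighbor
    # Distance to the k-th nearest neighbor of each point
    r_k = np.sqrt(np.partition(dist2, k - 1, axis=1)[:, k - 1])
    # Log volume of the d-dimensional unit ball: pi^(d/2) / Gamma(d/2 + 1)
    log_vd = 0.5 * d * math.log(math.pi) - math.lgamma(0.5 * d + 1.0)
    return digamma_int(n) - digamma_int(k) + log_vd + d * float(np.mean(np.log(r_k)))


def gaussian_entropy(d: int, sigma: float = 1.0) -> float:
    """Analytic entropy (nats) of N(0, sigma^2 I) in d dimensions."""
    return 0.5 * d * math.log(2.0 * math.pi * math.e * sigma ** 2)
```

The terms $\psi(N)$, $\psi(k)$, and $\log V_d$ do not depend on `z`, so dropping them, as the notebook does, leaves the gradient unchanged. A low-dimensional check like this is only a sanity test: k-NN entropy estimators tend to become biased in high dimensions such as `z_dim = 100`, so during training the estimate is best read as a relative diversity signal.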

In [None]:

# Configuration
CONFIG = {
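    # Paths
    "base_dir": "D:/DA Project3/Corrected Bew results",
    "checkpoint_path": "checkpoint_epoch_3000.pt",
    # Placeholder: section 2 reads CONFIG["data_path"], which the original configuration
    # never defines; set this to the .npy file holding the reference (truth) model.
    "data_path": "TRUTH_MODEL_PLACEHOLDER.npy",
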
    # Well locations (row, col) format
    "well_locations": [(59, 37), (86, 40), (109, 39), (72, 72), (20,80), (95, 95) ],
    
    # Network parameters
    "w_dim": 100,          # Noise dimension
    "z_dim": 100,          # Latent dimension
    "hidden_dim": 256,     # Hidden layer dimension
    
    # Training parameters
    "alpha": 10,           # Likelihood scaling: summed BCE is divided by alpha (larger alpha weakens the well-data term)
    "lr": 1e-4,            # Learning rate
    "batch_size": 50,      # Batch size
    "total_cases": 1000,   # Total number of generated examples
    "num_epochs": 400,     # Number of epochs
    "log_interval": 10,    # Log every N batches
    
    # Threshold parameters
    "threshold": 0.2,      # Threshold for binarizing generator output in check_well_conditioning
}

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

###################################
# 1. Load the Frozen Generator
###################################
checkpoint_path = os.path.join(CONFIG["base_dir"], CONFIG["checkpoint_path"])
generator = load_generator(checkpoint_path, device)
generator.eval()  # Set to evaluation mode

# Freeze generator parameters
for param in generator.parameters():
    param.requires_grad = False
print("Generator loaded and parameters frozen.")

###################################
# 2. Load Observations (Well Data)
###################################
data_path = os.path.join(CONFIG["base_dir"], CONFIG["data_path"])
CA = np.load(data_path)
Truth = CA.transpose(1, 2, 0)  # shape: [H, W, Channels]

# Extract observations at well locations
facies_obs = []
for row, col in CONFIG["well_locations"]:
    facies_obs.append(Truth[row, col, :])

facies_obs = np.stack(facies_obs, axis=0)
d_obs = torch.tensor(facies_obs, dtype=torch.float32, device=device)

print(f"Shape of facies_obs (num_wells, Channels) = {facies_obs.shape}")
print(f"Shape of d_obs = {d_obs.shape}")

###################################
# 3. Define the Inference Network for Gaussian Sampling
###################################
class InferenceNet(nn.Module):
    def __init__(self, w_dim: int = 100, z_dim: int = 100, hidden_dim: int = 256):
        """
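        Inference network that transforms random Gaussian noise vectors (w) into latent
        vectors (z) for the frozen generator. Conditioning on the well data comes from the
        training objective, not from an extra input: training pushes the distribution of z
        toward the posterior P(z | d_obs), while the Gaussian prior term keeps it close to N(0, I).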
        
        Args:
            w_dim: Dimension of input noise vector
            z_dim: Dimension of output latent vector
            hidden_dim: Dimension of hidden layers
        """
        super(InferenceNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(w_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, z_dim)  # Output is a vector of size z_dim
        )
        
        # Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        """Initialize network weights for better convergence."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)
        
    def forward(self, w: torch.Tensor) -> torch.Tensor:
        """
        Transform input noise to output latent vector.
        
        Args:
            w: Input noise vector with Gaussian distribution
            
        Returns:
            z: Output latent vector (batch_size, z_dim), intended to follow the approximate posterior P(z | d_obs)
        """
        z = self.net(w)
        return z

# Initialize inference network
inference_net = InferenceNet(
    CONFIG["w_dim"], 
    CONFIG["z_dim"], 
    CONFIG["hidden_dim"]
).to(device)

print(f"InferenceNet initialized with output size {CONFIG['z_dim']}")

###################################
# 4. Define Loss Functions and Entropy Estimator
###################################
def well_likelihood(z_batch: torch.Tensor, alpha: float = CONFIG["alpha"]) -> torch.Tensor:
    """
    Compute the well data likelihood for a batch of latent vectors z.
    
    Args:
        z_batch: Batch of latent vectors, shape [batch_size, z_dim]
        alpha: Likelihood scaling factor (the summed BCE is divided by alpha)
        
    Returns:
        Log likelihood for each sample in the batch
    """
    # Pass z through the generator
    x = generator(z_batch).squeeze()  # Expected shape: [batch, Channels, H, W]
    
    # Extract simulated observations at the well locations
    gen = x.permute(0, 2, 3, 1)  # [batch, H, W, Channels]
    
    # Collect simulated measurements at well locations
    sim_obs_list = []
    for row, col in CONFIG["well_locations"]:
        sim_obs_list.append(gen[:, row, col, :])
    
    # Stack to get shape [batch, num_wells, Channels]
    sim_obs = torch.stack(sim_obs_list, dim=1)
    
    # Binary cross-entropy loss
    bce_loss = F.binary_cross_entropy(sim_obs, d_obs.unsqueeze(0).expand_as(sim_obs), reduction='none')
    log_likelihood = -bce_loss.view(z_batch.shape[0], -1).sum(dim=1) / alpha
    
    return log_likelihood

def gaussian_prior(z_batch: torch.Tensor) -> torch.Tensor:
    """
    Compute the standard Gaussian prior for a batch of latent vectors.
    
    Args:
        z_batch: Batch of latent vectors
        
    Returns:
        Log prior for each sample in the batch
    """
    # Standard normal prior: -0.5 * sum(z^2)
    log_prior = -0.5 * (z_batch**2).view(z_batch.shape[0], -1).sum(dim=1)
    return log_prior

def compute_statistics(z_batch: torch.Tensor) -> tuple:
    """
    Compute the mean, variance, and covariance of the batch.
    
    Args:
        z_batch: Batch of latent vectors
        
    Returns:
        Tuple of (mean, variance, covariance matrix)
    """
    mean = torch.mean(z_batch, dim=0)
    centered = z_batch - mean.unsqueeze(0)
    variance = torch.mean(centered**2, dim=0)
    cov = torch.matmul(centered.T, centered) / z_batch.shape[0]
    return mean, variance, cov

def pairwise_distances(z_batch: torch.Tensor) -> torch.Tensor:
    """
    Compute pairwise squared distances between all points in the batch.
    
    Args:
        z_batch: Batch of vectors, shape [batch_size, dim]
        
    Returns:
        Matrix of squared distances, shape [batch_size, batch_size]
    """
    # Compute squared Euclidean distance between all pairs
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a · b
    
    # Compute squared norms for each vector: ||a||^2
    z_norm = (z_batch**2).sum(1).view(-1, 1)
    
    # Transpose for matrix multiplication
    z_t = torch.transpose(z_batch, 0, 1)
    
    # Transpose squared norms to match matrix shape: ||b||^2
    z_t_norm = z_norm.view(1, -1)
    
    # Calculate ||a||^2 + ||b||^2 - 2 * a · b
    dist = z_norm + z_t_norm - 2.0 * torch.mm(z_batch, z_t)
    
    # Ensure no negative distances due to numerical issues
    dist = torch.clamp(dist, 0.0, float('inf'))
    
    return dist

def kozachenko_leonenko_entropy(z_batch: torch.Tensor, k: int = 5) -> torch.Tensor:
    """
    Estimate the entropy of a distribution using the Kozachenko-Leonenko estimator.
    
    Args:
        z_batch: Batch of vectors sampled from distribution, shape [batch_size, dim]
        k: Number of nearest neighbors to use (the training loop uses sqrt(batch_size))
        
    Returns:
        Entropy estimate
    """
    batch_size, dim = z_batch.shape
    
    # Compute pairwise distances
    distances = pairwise_distances(z_batch)
    
    # Set diagonal to infinity (exclude self)
    distances = distances + torch.diag(torch.ones(batch_size, device=z_batch.device) * float('inf'))
    
    # Get the k smallest (squared) distances for each row
    knn_dist, _ = torch.topk(distances, k=k, dim=1, largest=False)
    
    # Get the k-th nearest squared distance r_k^2
    kth_distance = knn_dist[:, k-1]
    
    # Log distance (with small epsilon for numerical stability)
    log_dist = torch.log(kth_distance + 1e-8)
    
    # Entropy estimate up to additive constants. Because log_dist = log(r_k^2),
    # dim * mean(log_dist) = 2 * dim * mean(log r_k), i.e. twice the
    # distance-dependent term of the Kozachenko-Leonenko estimator.
    entropy = dim * torch.mean(log_dist)
    
    return entropy

###################################
# 5. Training Loop (Modified to include Entropy)
###################################
def train():
    """Execute the training loop for the inference network."""
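    # Adam optimizer with AMSGrad and GAN-style betas
    optimizer = optim.Adam(
        inference_net.parameters(), 
        lr=CONFIG["lr"], 
        amsgrad=True, 
        betas=(0.5, 0.9)
    )
    
    # Halve the learning rate when the epoch loss plateaus for 10 epochs.
    # verbose is omitted: it is deprecated in recent PyTorch releases, and the
    # epoch summary below already prints the current learning rate.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=10, factor=0.5
    )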
    
    # Calculate batches per epoch
    num_batches_per_epoch = CONFIG["total_cases"] // CONFIG["batch_size"]
    batch_size = CONFIG["batch_size"]
    
    # Track metrics
    metrics = {
        'epoch': [],
        'loss': [],
        'well_likelihood': [],
        'prior': [],
        'entropy': [],
        'gaussian_metrics': []
    }
    
    print(f"Starting training for {CONFIG['num_epochs']} epochs")
    print(f"Batches per epoch: {num_batches_per_epoch}")
    print(f"Using Kozachenko-Leonenko entropy estimator with k=sqrt(batch_size)")
    
    # Get k for nearest neighbor entropy estimator (sqrt of batch size)
    k = max(1, int(math.sqrt(batch_size)))
    print(f"Using k={k} for entropy estimation")
    
    for epoch in range(CONFIG["num_epochs"]):
        epoch_loss = 0
        epoch_well_likelihood = 0
        epoch_prior = 0
        epoch_entropy = 0
        epoch_gaussian_metrics = []
        
        for batch_idx in range(num_batches_per_epoch):
            optimizer.zero_grad()
            
            # Sample random noise vector w ~ N(0,I)
            w = torch.randn(batch_size, CONFIG["w_dim"], device=device)
            
            # Get latent code from inference network
            z = inference_net(w)  # [batch_size, z_dim]
            
            # Compute well data log-likelihood (higher means a better match)
            ll = well_likelihood(z)
            
            # Compute Gaussian prior
            prior = gaussian_prior(z)
            
            # Compute expected negative log posterior (loss term)
            expected_loss = -(ll + prior).mean()
            
            # Compute entropy using Kozachenko-Leonenko estimator
            entropy = kozachenko_leonenko_entropy(z, k)
            
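            # Total loss: expected negative log posterior minus entropy.
            # This follows the Kullback-Leibler divergence minimization approach:
            # KL(q||p) = E_q[log q] - E_q[log p] = -H(q) + E_q[-log p]
            # pairwise_distances returns squared distances, so 0.5*entropy equals
            # dim * mean(log r_k) (up to the 1e-8 stabilizer), the distance-dependent
            # term of the Kozachenko-Leonenko estimate of H(q).
            total_loss = expected_loss - 0.5*entropy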
            
            # Backpropagation and optimization
            total_loss.backward()
            # Clip the gradient norm to 1.0 to stabilize training
            torch.nn.utils.clip_grad_norm_(inference_net.parameters(), max_norm=1.0)
            optimizer.step()
            
            # Compute Gaussian metrics for monitoring
            with torch.no_grad():
                mean, variance, _ = compute_statistics(z)
                mean_error = torch.mean(torch.abs(mean))
                var_error = torch.mean(torch.abs(variance - 1.0))
                epoch_gaussian_metrics.append((mean_error.item(), var_error.item()))
            
            # Update running metrics
            epoch_loss += total_loss.item()
            epoch_well_likelihood += ll.mean().item()
            epoch_prior += prior.mean().item()
            epoch_entropy += entropy.item()
            
            # Log progress
            if batch_idx % CONFIG["log_interval"] == 0:
                print(f"Epoch {epoch}, Batch {batch_idx} | "
                      f"well_ll: {ll.mean().item():.4f}, "
                      f"prior: {prior.mean().item():.4f}, "
                      f"entropy: {entropy.item():.4f}, "
                      f"loss: {total_loss.item():.4f}")
        
        # Calculate epoch averages
        avg_loss = epoch_loss / num_batches_per_epoch
        avg_well_likelihood = epoch_well_likelihood / num_batches_per_epoch
        avg_prior = epoch_prior / num_batches_per_epoch
        avg_entropy = epoch_entropy / num_batches_per_epoch
        avg_mean_error, avg_var_error = np.mean(epoch_gaussian_metrics, axis=0)
        
        # Update metrics history
        metrics['epoch'].append(epoch)
        metrics['loss'].append(avg_loss)
        metrics['well_likelihood'].append(avg_well_likelihood)
        metrics['prior'].append(avg_prior)
        metrics['entropy'].append(avg_entropy)
        metrics['gaussian_metrics'].append((avg_mean_error, avg_var_error))
        
        # Update learning rate scheduler
        scheduler.step(avg_loss)
        
        # Log epoch summary
        print(f"Epoch {epoch} Summary | "
              f"Avg well_ll: {avg_well_likelihood:.4f}, "
              f"Avg prior: {avg_prior:.4f}, "
              f"Avg entropy: {avg_entropy:.4f}, "
              f"Mean error: {avg_mean_error:.4f}, "
              f"Var error: {avg_var_error:.4f}, "
              f"Avg loss: {avg_loss:.4f}, "
              f"LR: {optimizer.param_groups[0]['lr']:.6f}")
        
        # Save checkpoint every 100 epochs
        if (epoch + 1) % 100 == 0:
            save_path = os.path.join(CONFIG["base_dir"], f"inference_net_epoch_{epoch+1}.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': inference_net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'metrics': metrics
            }, save_path)
            print(f"Checkpoint saved at {save_path}")
    
    return metrics

###################################
# 6. Visualization Functions
###################################
def generate_and_visualize(num_samples: int = 5):
    """Generate samples using the trained inference network and visualize results."""
    inference_net.eval()
    
    with torch.no_grad():
        # Generate latent codes through the inference network
        w = torch.randn(num_samples, CONFIG["w_dim"], device=device)
        z = inference_net(w)
        
        # Generate images
        samples = generator(z).cpu().squeeze()
        
        # Plot generated samples
        fig, axes = plt.subplots(1, num_samples, figsize=(num_samples*4, 4))
        
        for i in range(num_samples):
            # samples has shape [batch, channels, H, W]; display channel index 15
            axes[i].imshow(samples[i, 15].numpy(), cmap='viridis')
            axes[i].set_title(f"Sample {i+1}")
            axes[i].axis('off')
        
        plt.tight_layout()
        plt.savefig(os.path.join(CONFIG["base_dir"], "generated_samples.png"))
        plt.show()

def visualize_latent_distribution(num_samples: int = 1000):
    """Visualize the distribution of the latent space to verify it's Gaussian."""
    inference_net.eval()
    
    with torch.no_grad():
        # Generate latent codes through the inference network
        w = torch.randn(num_samples, CONFIG["w_dim"], device=device)
        z = inference_net(w)
        
        # Move to CPU for visualization
        z_np = z.cpu().numpy()
        
        # Check if the distribution is Gaussian by plotting several dimensions
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        axes = axes.flatten()
        
        # Plot histograms of the first 6 dimensions
        for i in range(min(6, CONFIG["z_dim"])):
            axes[i].hist(z_np[:, i], bins=50, alpha=0.7, density=True, label='Generated')
            axes[i].set_title(f"Distribution of z[{i}]")
            
            # Plot the standard normal density for comparison
            x = np.linspace(-3, 3, 100)
            axes[i].plot(x, np.exp(-x**2/2) / np.sqrt(2*np.pi), 'r-', lw=2, label='Standard Normal')
            # Use explicit labels so each legend entry matches its artist
            axes[i].legend()
        
        plt.tight_layout()
        plt.savefig(os.path.join(CONFIG["base_dir"], "latent_distribution.png"))
        plt.show()
        
        # Compute statistics of the latent distribution
        z_mean = z_np.mean(axis=0)
        z_var = z_np.var(axis=0)
        
        print(f"Mean of z: {np.mean(z_mean):.4f} (should be close to 0)")
        print(f"Variance of z: {np.mean(z_var):.4f} (should be close to 1)")
        
        # Check covariance structure
        z_cov = np.cov(z_np.T)
        plt.figure(figsize=(10, 8))
        plt.imshow(z_cov, cmap='coolwarm')
        plt.colorbar()
        plt.title("Covariance Matrix of Latent Variables")
        plt.savefig(os.path.join(CONFIG["base_dir"], "latent_covariance.png"))
        plt.show()
        
        # Check if values match well data
        print("Checking conditioning on well data...")
        check_well_conditioning(z, num_samples=5)

def check_well_conditioning(z: torch.Tensor, num_samples: int = 5):
    """Check if generated samples match the well data observations."""
    with torch.no_grad():
        # Generate images from latent codes
        x = generator(z[:num_samples]).squeeze()  # [batch, channels, H, W]
        
        # Extract simulated observations at the well locations
        gen = x.permute(0, 2, 3, 1)  # [batch, H, W, channels]
        
        # Collect simulated measurements at well locations
        for i in range(num_samples):
            print(f"\nSample {i+1} Well Values:")
            for j, (row, col) in enumerate(CONFIG["well_locations"]):
                sim_val = gen[i, row, col, :].cpu().numpy()
                # Apply hard threshold for comparison
                sim_val_binary = (sim_val > CONFIG["threshold"]).astype(np.float32)
                true_val = d_obs[j].cpu().numpy()
                print(f"  Well {j+1}:")
                print(f"    Raw simulated = {sim_val}")
                print(f"    Thresholded   = {sim_val_binary}")
                print(f"    Observed      = {true_val}")
                
                # Calculate match percentage
                match_pct = np.mean(sim_val_binary == true_val) * 100
                print(f"    Match: {match_pct:.1f}%")

def plot_training_curves(metrics):
    """Plot training metrics over time."""
    fig, axes = plt.subplots(4, 1, figsize=(10, 16), sharex=True)
    
    # Plot loss
    axes[0].plot(metrics['epoch'], metrics['loss'])
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training Loss')
    axes[0].grid(True)
    
    # Plot likelihood and prior
    axes[1].plot(metrics['epoch'], metrics['well_likelihood'], label='Well Likelihood')
    axes[1].plot(metrics['epoch'], metrics['prior'], label='Prior')
    axes[1].set_ylabel('Value')
    axes[1].set_title('Well Likelihood and Prior')
    axes[1].legend()
    axes[1].grid(True)
    
    # Plot entropy
    axes[2].plot(metrics['epoch'], metrics['entropy'])
    axes[2].set_ylabel('Entropy')
    axes[2].set_title('Estimated Entropy')
    axes[2].grid(True)
    
    # Plot Gaussian metrics
    mean_errors = [m[0] for m in metrics['gaussian_metrics']]
    var_errors = [m[1] for m in metrics['gaussian_metrics']]
    axes[3].plot(metrics['epoch'], mean_errors, label='Mean Error')
    axes[3].plot(metrics['epoch'], var_errors, label='Variance Error')
    axes[3].set_ylabel('Error')
    axes[3].set_xlabel('Epoch')
    axes[3].set_title('Gaussian Distribution Metrics')
    axes[3].legend()
    axes[3].grid(True)
    
    plt.tight_layout()
    plt.savefig(os.path.join(CONFIG["base_dir"], "training_curves.png"))
    plt.show()

###################################
# 7. Main Execution
###################################
if __name__ == "__main__":
    # Run training
    metrics = train()
    
    # Plot training progress
    plot_training_curves(metrics)
    
    # Generate and visualize some samples
    generate_and_visualize(num_samples=5)
    
    # Visualize latent distribution
    visualize_latent_distribution()
    
    print("Training and visualization complete!")

#### History Matching Process 

This section implements the **Ensemble Smoother with Multiple Data Assimilation (ESMDA)** algorithm to condition geological models generated by a GAN to observed reservoir data.

*   **Goal:** Iteratively update an ensemble of **latent vectors (`z`)** so that the geological models produced from these vectors by a pre-trained Generator match observed data from reservoir simulations.
*   **Key Components:**
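    *   A **frozen, pre-trained Generator** that maps its input vectors (`w`) to 3D geological models.
    *   A **frozen, pre-trained Inference Network** which, in this script, maps the latent vectors (`z`) to the generator input vectors (`w`). Note that this section swaps the symbols used in the training section: here `z` is the Inference Network input that ES-MDA updates, and `w` is its output fed to the Generator.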
    *   Observed **BHP timeseries data** from wells (`d_obs`).
*   **ESMDA Process (Main Loop):**
    1.  Start with an initial ensemble of latent vectors (`z_ensemble`).
    2.  For a set number of iterations:
        *   Map the current `z_ensemble` to a `w_ensemble` using the **Inference Network**.
        *   Generate a batch of 3D geological models from the `w_ensemble` using the **Generator**.
        *   Run **parallel reservoir simulations** for each generated model to obtain simulated well data (`d_model`).
        *   Compare the simulated data to the observed data (`d_obs`).
        *   Compute the **Kalman Gain** from the ensemble cross-covariance between the latent vectors (`z`) and the simulated data (`d_model`), together with the covariance of `d_model` and the observation-error covariance.
        *   **Update** each latent vector in the `z_ensemble` using a Kalman-like update equation that combines the original latent vector, the Kalman Gain, the difference between the observed and simulated data, and added noise.
        *   Save the updated `z_ensemble` and the corresponding `w_ensemble`.
    3.  Track and save performance metrics like Mean Squared Error (MSE) between simulated and observed data.
*   **Outcome:** The process yields a final ensemble of latent vectors (`z_ensemble`) and their corresponding `w_ensemble` vectors. These vectors, when fed into the Generator, produce geological realizations that are conditioned to the observed well data while maintaining diversity within the ensemble.
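The assimilation step described above can be sketched in NumPy as a generic ES-MDA update, $K = C_{zd}(C_{dd} + \alpha_i C_D)^{-1}$ and $z_j \leftarrow z_j + K(d_{obs} + \sqrt{\alpha_i}\, e_j - d_j)$ with $e_j \sim N(0, C_D)$, where the inflation coefficients satisfy $\sum_i 1/\alpha_i = 1$. This is a sketch under stated assumptions, not the notebook's implementation: it assumes a diagonal observation-error covariance, stores the ensemble row-wise, and uses a hypothetical helper name `esmda_update`.

```python
import numpy as np


def esmda_update(z_ens: np.ndarray, d_ens: np.ndarray, d_obs: np.ndarray,
                 obs_std: np.ndarray, alpha: float,
                 rng: np.random.Generator) -> np.ndarray:
    """One ES-MDA analysis step (sketch; assumes a diagonal C_D).

    z_ens:   (n_ens, n_z) ensemble of latent vectors
    d_ens:   (n_ens, n_d) simulated data for each member
    d_obs:   (n_d,) observed data
    obs_std: (n_d,) observation-error standard deviations
    alpha:   inflation coefficient for this iteration
    """
    n_ens = z_ens.shape[0]
    # Ensemble anomalies (deviations from the ensemble mean)
    dz = z_ens - z_ens.mean(axis=0)
    dd = d_ens - d_ens.mean(axis=0)
    # Cross-covariance C_zd (n_z, n_d) and data auto-covariance C_dd (n_d, n_d)
    c_zd = dz.T @ dd / (n_ens - 1)
    c_dd = dd.T @ dd / (n_ens - 1)
    c_d = np.diag(obs_std ** 2)
    # Kalman gain K = C_zd (C_dd + alpha C_D)^-1 via a linear solve (the matrix is symmetric)
    gain = np.linalg.solve(c_dd + alpha * c_d, c_zd.T).T
    # Perturbed observations: d_obs + sqrt(alpha) * e, with e ~ N(0, C_D)
    d_pert = d_obs + np.sqrt(alpha) * obs_std * rng.standard_normal(d_ens.shape)
    # Member-wise update z_j + K (d_pert_j - d_j), applied to all rows at once
    return z_ens + (d_pert - d_ens) @ gain.T
```

In the notebook's workflow, `d_ens` would be the standardized simulated BHP of each member, obtained by passing `z` through the Inference Network and Generator and running the reservoir simulator. Solving the linear system instead of forming an explicit inverse is the usual numerically safer choice. A linear toy forward model is a convenient way to test the step, because for linear-Gaussian problems ES-MDA with $\sum_i 1/\alpha_i = 1$ approaches the exact posterior as the ensemble grows.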

In [None]:


# List to store MSE results per iteration for later analysis.
mse_history = []

# Configuration dictionary holding all parameters for the script.
CONFIG = {
    # --- Paths ---
    "base_dir": "D:/Generalizedized_GAN/ESMDA/", # Base directory for checkpoints and data
    "checkpoint_path": "checkpoint_epoch_3000.pt", # Path to the Generator checkpoint

    # --- Network parameters ---
    # Dimensions for the GAN (Generator) and Inference Network.
    "w_dim": 100,          # Dimension of the noise vector fed into the Generator
    "z_dim": 100,          # Dimension of the latent space used by the Inference Network
    "hidden_dim": 256,     # Dimension of hidden layers in networks (if applicable)
    "Final_iter_IN": 400,    # Number of epochs for training (used to find inference net checkpoint name)
}

# --- Device Setup ---
# Determine whether to use GPU (CUDA) if available, otherwise use CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Model Loading ---
# Load the pre-trained Generator model.
# This model takes 'w' and generates subsurface models.
generator_checkpoint_path = os.path.join(CONFIG["base_dir"], CONFIG["checkpoint_path"])
print(f"Loading generator from {generator_checkpoint_path}")
generator = load_generator(generator_checkpoint_path, device)
generator.eval() # Set the generator to evaluation mode

# Load the pre-trained Inference Network model.
# This model takes 'z' and outputs 'w'.
inference_net_checkpoint_path = os.path.join(CONFIG["base_dir"], f"inference_net_epoch_{CONFIG['Final_iter_IN']}.pt")
print(f"Loading inference network from {inference_net_checkpoint_path}")
inference_net = InferenceNet(
    CONFIG["w_dim"], 
    CONFIG["z_dim"], 
    CONFIG["hidden_dim"]
).to(device)

# Load model weights into the inference network.
# Re-raise on failure instead of calling exit(): inside a Jupyter kernel, exit() stops the kernel.
try:
    checkpoint = torch.load(inference_net_checkpoint_path, map_location=device)
    inference_net.load_state_dict(checkpoint["model_state_dict"])
    inference_net.eval()  # Set the inference network to evaluation mode
except FileNotFoundError:
    print(f"Error: Inference network checkpoint not found at {inference_net_checkpoint_path}")
    raise
except KeyError:
    print(f"Error: 'model_state_dict' not found in checkpoint file {inference_net_checkpoint_path}. Check the checkpoint structure.")
    raise


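# --- Main Execution Block ---
# In a notebook, __name__ is always '__main__', so this block runs whenever the
# cell runs. The guard matters when the code is saved as a script and
# multiprocessing uses the 'spawn' start method (the default on Windows and
# macOS), which re-imports the main module in every child process.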
if __name__ == '__main__':
    print("Main: Starting ESMDA process.")

    # -------------------------------------------------------------------------
    # 1. LOAD AND STANDARDIZE OBSERVATION DATA
    # -------------------------------------------------------------------------
    print("Loading and standardizing observation data...")
    # Load the observed BHP data from CSV files.
    # We assume these are the 'truth' observations to assimilate against.
    
    # Load Injector 1 BHP data
    InJ_data = pd.read_csv("Truth/Inj 1_CMG_aligned.csv")['BHP'].values
    # Standardize the data: Subtract mean and divide by standard deviation.
    # Store the mean and std of the *truth* data, as simulated data will be
    # standardized using these same values to make them comparable.
    InJ_mean, InJ_std = np.mean(InJ_data), np.std(InJ_data)
    InJ_obs = (InJ_data - InJ_mean) / InJ_std

    # Load Producer 1 BHP data
    P1_data = pd.read_csv("Truth/Prod 1_CMG_aligned.csv")['BHP'].values
    P1_mean, P1_std = np.mean(P1_data), np.std(P1_data)
    P1_obs = (P1_data - P1_mean) / P1_std
    
    # Load Producer 2 BHP data
    P2_data = pd.read_csv("Truth/Prod 2_CMG_aligned.csv")['BHP'].values
    P2_mean, P2_std = np.mean(P2_data), np.std(P2_data)
    P2_obs = (P2_data - P2_mean) / P2_std

    # Concatenate the observations into single vectors: original units (for the
    # MSE diagnostic) and standardized (for assimilation).
    d_obs = np.concatenate([InJ_data, P1_data, P2_data])
    d_obs_std = np.concatenate([InJ_obs, P1_obs, P2_obs])
    n_obs = d_obs_std.shape[0] # Total number of observations

    print(f"Loaded {n_obs} observation points.")

    # -------------------------------------------------------------------------
    # 2. ESMDA PARAMETERS & INITIAL ENSEMBLE SETUP
    # -------------------------------------------------------------------------
    print("Setting up ESMDA parameters...")
    n_iter = 25 # Number of assimilation (ESMDA) iterations to perform
    
    # Define the sequence of alpha (inflation) values for each iteration.
    # ES-MDA requires sum_k 1/alpha_k = 1; the common constant choice
    # alpha_k = N_iter for all k = 1, ..., N_iter satisfies this.
    alpha_list = [n_iter] * n_iter 

    # Define the observation-error standard deviation (sigma).
    # This represents the uncertainty in the measured data.
    sigma = 0.06
    # Observation-error covariance matrix (Gamma), assumed diagonal
    # (uncorrelated errors with equal variance).
    # The update step works in the *standardized* data space, where each well's
    # BHP series was scaled by its own std dev. sigma is therefore interpreted
    # directly in standardized units, i.e. as a fraction (here 6%) of each
    # well's BHP standard deviation. If the measurement error is known in
    # original pressure units, divide it by the matching well's std dev
    # (InJ_std, P1_std, P2_std) to obtain per-well standardized values.
    sigma_std = sigma
    Gamma_std = np.diag([sigma_std**2] * n_obs) # Diagonal covariance matrix

    # Define the ensemble size.
    ensemble_size = 200
    latent_dim = CONFIG["z_dim"] # Dimension of the latent space (z)

    # Initialize the latent variable ensemble (z_ensemble).
    # Each row is a member of the ensemble.
    # Initialized from a standard normal distribution.
    # The commented-out line shows how to load a previously saved ensemble.
    z_ensemble = torch.randn(ensemble_size, latent_dim, device=device) 
    # z_ensemble = torch.load("z_ensemble_iter14.pt", map_location=device)
    
    print(f"Ensemble size: {ensemble_size}, Latent dimension: {latent_dim}")
    print(f"Number of ESMDA iterations: {n_iter}, Initial sigma: {sigma}")


    # -------------------------------------------------------------------------
    # 3. MAIN LOOP OVER ESMDA ITERATIONS
    # -------------------------------------------------------------------------
    print("Starting ESMDA iterations...")
    
    # Adjust starting iteration if resuming (e.g., from loaded z_ensemble)
    start_iter = 0 # Change this if resuming from a specific iteration
    
    for k in range(start_iter, n_iter):
        current_iteration = k + 1
        print(f"\n--- ESMDA Iteration {current_iteration}/{n_iter} ---")

        # --- Map latent variable 'z' to generator input 'w' using the Inference Network ---
        # This is specific to this GAN setup where G(w) generates samples and
        # the inference network learns a mapping z -> w.
        # torch.no_grad() skips building the autograd graph, which inference does
        # not need and which would otherwise hold memory for the whole ensemble.
        with torch.no_grad():
            w_ensemble = inference_net(z_ensemble)

            # --- Generate subsurface models from the 'w' ensemble using the Generator ---
            # generator(w_ensemble) produces outputs likely of shape (ensemble_size, C, H, W);
            # .cpu().numpy() converts the tensor to a NumPy array on the CPU.
            generated_batch = generator(w_ensemble).cpu().numpy()
        
        # Apply a thresholding function if needed (e.g., for binary or categorical models).
        # Assumes threshold_samples is defined in Utilities.
        generated_batch = threshold_samples(generated_batch.squeeze())
        
        # Re-order dimensions for the simulation code.
        # Assuming the generator output is (N, 1, H, W), squeeze() above gives
        # (N, H, W), and the simulator expects (N, H, W, 1), so the channel axis
        # is appended back.
        Z = generated_batch[:, :, :, np.newaxis]

        # Set the number of processors for parallel simulations.
        # Limit to ensemble size and a reasonable maximum (e.g., 128).
        num_processors = min(multiprocessing.cpu_count(), ensemble_size, 128)
        print(f"Using {num_processors} processors for forward simulations.")

        # ---------------------------------------------------------------------
        # 3A. RUN FORWARD SIMULATIONS IN PARALLEL
        # ---------------------------------------------------------------------
        print("Running forward simulations...")
        processes = []
        completed_indices = []
        batch_size = num_processors # Process in batches equal to the number of available processors
        
        for batch_start in range(0, ensemble_size, batch_size):
            batch_end = min(batch_start + batch_size, ensemble_size)
            current_batch_indices = range(batch_start, batch_end)
            print(f"Launching simulation batch for indices {batch_start} to {batch_end-1}...")
            
            # Launch processes for the current batch
            for i in current_batch_indices:
                # The 'worker' function (assumed from Utilities) takes the ensemble
                # index and the specific subsurface model data for that member.
                # It's responsible for setting up and running the simulator, saving
                # the results (e.g., to CSV files).
                p = multiprocessing.Process(target=worker, args=(i, Z[i]))
                p.start()
                processes.append((i, p))
            
            # Wait for all processes in this batch to complete and record which succeeded
            for idx, p in processes:
                p.join() # Wait for process to finish
                if p.exitcode == 0:
                    completed_indices.append(idx)
                else:
                    print(f"Warning: simulation for ensemble member {idx} exited with code {p.exitcode}.")
            processes = [] # Clear the list for the next batch
            
            print(f"Completed simulation batch starting at index {batch_start}.")
        
        # Verify that the simulation process for every ensemble member exited successfully.
        if len(completed_indices) != ensemble_size:
            raise RuntimeError(
                f"Expected {ensemble_size} successful simulations, but only {len(completed_indices)} succeeded. Check simulation logs."
            )
        print("All forward simulations completed.")
        
        # ---------------------------------------------------------------------
        # 3B. COLLECT SIMULATION RESULTS & STANDARDIZE
        # ---------------------------------------------------------------------
        print("Collecting and standardizing simulation results...")
        d_model = []
        # Define a backup path for the simulation results
        backup_path = "simulation_results_backup/" 
        os.makedirs(backup_path, exist_ok=True) # Create backup directory if it doesn't exist

        for i in range(ensemble_size):
            try:
                # Define paths to the simulation output files for ensemble member i
                # These paths should match where the 'worker' function saves the results.
                sim_InJ_file = f"wells_timeseries_alined/Inj 1_Ensemble{i}_Aligned.csv"
                sim_P1_file = f"wells_timeseries_alined/Prod 1_Ensemble{i}_Aligned.csv"
                sim_P2_file = f"wells_timeseries_alined/Prod 2_Ensemble{i}_Aligned.csv"

                # Load the simulated BHP data
                sim_InJ_data = pd.read_csv(sim_InJ_file)['BHP'].values
                sim_P1_data = pd.read_csv(sim_P1_file)['BHP'].values
                sim_P2_data = pd.read_csv(sim_P2_file)['BHP'].values
                
                # Standardize the *simulated* data using the mean and std dev of the *truth* data.
                # This makes the simulated data statistically comparable to the observed data.
                sim_InJ_std = (sim_InJ_data - InJ_mean) / InJ_std
                sim_P1_std = (sim_P1_data - P1_mean) / P1_std
                sim_P2_std = (sim_P2_data - P2_mean) / P2_std

                # Concatenate the standardized simulated data for this ensemble member.
                sim_concat_std = np.concatenate([sim_InJ_std, sim_P1_std, sim_P2_std])
                d_model.append(sim_concat_std) # Append the standardized result
                
                # Backup the original simulation output files
                # Appending iteration number to filename helps track history
                shutil.copy(sim_InJ_file, os.path.join(backup_path, f"Inj 1_Ensemble{i}_Aligned_iter{current_iteration}.csv"))
                shutil.copy(sim_P1_file, os.path.join(backup_path, f"Prod 1_Ensemble{i}_Aligned_iter{current_iteration}.csv"))
                shutil.copy(sim_P2_file, os.path.join(backup_path, f"Prod 2_Ensemble{i}_Aligned_iter{current_iteration}.csv"))

                # Clean up the original simulation output files to save space
                os.remove(sim_InJ_file)
                os.remove(sim_P1_file)
                os.remove(sim_P2_file)

            except FileNotFoundError as e:
                # A missing file means this member's simulation produced no output;
                # the assimilation cannot proceed, so re-raise the error.
                print(f"Error: Simulation output file not found for ensemble {i}: {e}")
                raise
            except Exception as e:
                print(f"Error loading or processing simulation output for ensemble {i}: {e}")
                raise
        
        # Convert the list of simulation results to a numpy array.
        d_model_std = np.array(d_model)  # shape: (ensemble_size, n_obs)
        print("Simulation results collected and standardized.")

        # ---------------------------------------------------------------------
        # 3C. COMPUTE MSE IN ORIGINAL SPACE (FOR DIAGNOSTICS)
        # ---------------------------------------------------------------------
        # Note: This calculates MSE on the *unstandardized* data (d_obs vs.
        # *unstandardized* d_model) for interpretability, even though the
        # assimilation happens on the standardized data. The original output
        # files were copied to the backup folder and deleted above, so they are
        # reloaded from the backup here.
        d_model_orig = []
        for i in range(ensemble_size):
            # Reload original (unstandardized) data from backup
            sim_InJ_file = os.path.join(backup_path, f"Inj 1_Ensemble{i}_Aligned_iter{current_iteration}.csv")
            sim_P1_file = os.path.join(backup_path, f"Prod 1_Ensemble{i}_Aligned_iter{current_iteration}.csv")
            sim_P2_file = os.path.join(backup_path, f"Prod 2_Ensemble{i}_Aligned_iter{current_iteration}.csv")
            
            sim_InJ_data = pd.read_csv(sim_InJ_file)['BHP'].values
            sim_P1_data = pd.read_csv(sim_P1_file)['BHP'].values
            sim_P2_data = pd.read_csv(sim_P2_file)['BHP'].values
            d_model_orig.append(np.concatenate([sim_InJ_data, sim_P1_data, sim_P2_data]))
            
        d_model_orig = np.array(d_model_orig)

        # Calculate Mean Squared Error between truth observations and each simulation.
        mse_errors = [mean_squared_error(d_obs, d_model_orig[i]) for i in range(ensemble_size)]
        mean_mse = np.mean(mse_errors)
        print(f"Mean MSE (original space) for iteration {current_iteration}: {mean_mse:.6f}")
        # Store the MSE values for this iteration.
        mse_history.append({'iteration': current_iteration, 'mean_mse': mean_mse, 'all_mses': mse_errors})


        # ---------------------------------------------------------------------
        # 3D. COMPUTE STATISTICS & FORM KALMAN GAIN IN STANDARDIZED SPACE
        # ---------------------------------------------------------------------
        print("Computing ensemble statistics and Kalman Gain...")
        # Convert latent ensemble tensor to numpy for calculations.
        z_np = z_ensemble.cpu().numpy()
        
        # Compute ensemble mean of the latent variables (z).
        z_mean = np.mean(z_np, axis=0)
        # Compute anomalies (deviations from the mean) for z.
        A_z = z_np - z_mean
        
        # Compute ensemble mean of the standardized simulation results (d_model_std).
        d_mean_std = np.mean(d_model_std, axis=0)
        # Compute anomalies (deviations from the mean) for d_model_std.
        A_d_std = d_model_std - d_mean_std
        
        # Compute sample covariance between z anomalies and d anomalies.
        # Pzd = (A_z.T @ A_d_std) / (N-1)
        Cov_zd_std = (A_z.T @ A_d_std) / (ensemble_size - 1)
        
        # Compute sample covariance of d anomalies.
        # Pdd = (A_d_std.T @ A_d_std) / (N-1)
        Cov_dd_std = (A_d_std.T @ A_d_std) / (ensemble_size - 1)
        
        # Get the current alpha value for this ESMDA iteration.
        alpha = alpha_list[k]
        
        # Compute the Kalman Gain (K) in the standardized data space.
        # K = Pzd @ (Pdd + alpha * Gamma_std)^-1
        # Using np.linalg.pinv for robustness against singular matrices.
        K_std = Cov_zd_std @ np.linalg.pinv(Cov_dd_std + alpha * Gamma_std)
        print("Kalman Gain computed.")

        # ---------------------------------------------------------------------
        # 3E. UPDATE EACH ENSEMBLE MEMBER IN STANDARDIZED SPACE
        # ---------------------------------------------------------------------
        print("Updating latent ensemble members...")
        # Create a copy of the ensemble before update to measure change.
        z_np_before_update = np.copy(z_np) 

        # Apply the ESMDA update formula to each ensemble member:
        #   z_i_new = z_i_old + K (d_obs_std + sqrt(alpha) * eps_i - d_i_old_std),
        # with eps_i ~ N(0, Gamma_std). Sampling the perturbation directly from
        # N(0, alpha * Gamma_std), as below, is equivalent.
        # Because Gamma_std = sigma_std^2 * I is diagonal with equal variances, the
        # perturbations are drawn as independent normals with std sqrt(alpha) * sigma_std.
        # This avoids the factorization of the (n_obs x n_obs) covariance that
        # np.random.multivariate_normal performs on every call.
        eps_std = np.random.normal(
            loc=0.0, scale=np.sqrt(alpha) * sigma_std, size=(ensemble_size, n_obs)
        )
        # Innovation for every member (perturbed observations minus simulated data),
        # shape (ensemble_size, n_obs).
        innovation_std = d_obs_std + eps_std - d_model_std
        # Row-wise form of z_i += K_std @ innovation_i, applied to all members at once.
        z_np = z_np + innovation_std @ K_std.T
            
        # Calculate the average magnitude of the change in z for diagnostics.
        z_change = np.mean(np.abs(z_np - z_np_before_update))
        print(f"Average |Δz| for iteration {current_iteration}: {z_change:.6f}")
        
        # ---------------------------------------------------------------------
        # 3F. SAVE UPDATED LATENT ENSEMBLE AND W ENSEMBLE
        # ---------------------------------------------------------------------
        print("Saving updated ensembles...")
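        # Convert the updated numpy array back to a torch tensor for the next iteration.
        z_ensemble = torch.tensor(z_np, device=device, dtype=torch.float32)
        # Recompute w from the *updated* z so that the saved w_ensemble corresponds
        # to the saved z_ensemble (the w computed at the start of this iteration
        # belongs to the pre-update ensemble).
        with torch.no_grad():
            w_ensemble = inference_net(z_ensemble)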
        
        # Define filenames for saving.
        # Appending iteration number to track history.
        # Saving w_ensemble is useful for debugging or if you wanted to regenerate
        # models directly from w without re-running the inference network.
        w_ensemble_filename = f"w_ensemble_iter{current_iteration}.pt"
        z_ensemble_filename = f"z_ensemble_iter{current_iteration}.pt"
        
        # Save the ensembles to files.
        # .cpu() moves the tensor to CPU memory before saving, which is generally recommended.
        torch.save(w_ensemble.cpu(), w_ensemble_filename)
        torch.save(z_ensemble.cpu(), z_ensemble_filename)
        
        print(f"Ensembles saved: {w_ensemble_filename}, {z_ensemble_filename}")

    # -------------------------------------------------------------------------
    # 4. SAVE MSE HISTORY TO CSV
    # -------------------------------------------------------------------------
    print("\nESMDA process finished.")
    print("Saving MSE history...")
    # Convert the list of MSE dictionaries into a pandas DataFrame.
    # Note: 'all_mses' will store lists, which might make the CSV less readable
    # directly but preserves all individual MSE values. You might choose to
    # save only 'mean_mse' if desired.
    mse_df = pd.DataFrame(mse_history)
    
    # Save the DataFrame to a CSV file.
    mse_csv_path = "mse_history.csv"
    mse_df.to_csv(mse_csv_path, index=False)
    print(f"MSE history saved to {mse_csv_path}")
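The cell above fixes every inflation coefficient at `n_iter`. ES-MDA only requires $\sum_{k=1}^{N_{iter}} 1/\alpha_k = 1$, so a decreasing schedule that inflates the observation error most in the early iterations is an alternative. The sketch below assumes a geometric decay; `geometric_alphas` and `is_valid_schedule` are hypothetical helpers, not part of the notebook.

```python
import numpy as np


def geometric_alphas(n_iter, ratio=0.5):
    """Decreasing ES-MDA inflation schedule normalized so that sum(1/alpha) == 1.

    alpha_k = s * ratio**k for k = 0..n_iter-1; ratio < 1 makes alpha decrease.
    """
    powers = ratio ** np.arange(n_iter)
    s = np.sum(1.0 / powers)          # normalization constant
    return s * powers


def is_valid_schedule(alphas):
    """Check the ES-MDA condition sum_k 1/alpha_k == 1."""
    return bool(np.isclose(np.sum(1.0 / np.asarray(alphas, dtype=float)), 1.0))


constant = [25] * 25                  # the constant schedule used in the ES-MDA cell
decreasing = geometric_alphas(4)      # array([15., 7.5, 3.75, 1.875])
print(is_valid_schedule(constant), is_valid_schedule(decreasing))
```

Any schedule that passes this check could be assigned to `alpha_list` in place of `[n_iter] * n_iter`, as long as its length equals `n_iter`.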
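The ES-MDA cell treats `sigma` as already expressed in standardized units. If the measurement error is instead known in original pressure units, a per-well standardized `Gamma` can be built as in the minimal sketch below. The well series and the 10-unit error are hypothetical values chosen only for illustration; the standardization mirrors the cell (mean and population standard deviation of each truth series, as `np.mean` and `np.std` compute them).

```python
import numpy as np

# Hypothetical observed BHP series per well (original pressure units)
wells = {
    "Inj 1": np.array([3000.0, 3100.0, 3200.0, 3300.0]),
    "Prod 1": np.array([1500.0, 1450.0, 1400.0, 1380.0]),
}
sigma_orig = 10.0   # hypothetical measurement-error std dev, original units

d_obs_std_parts, var_parts = [], []
for obs in wells.values():
    mu, sd = np.mean(obs), np.std(obs)        # same statistics the notebook uses
    d_obs_std_parts.append((obs - mu) / sd)   # standardized observations
    # The same physical error is larger, in standardized units, for a well
    # whose BHP varies less
    var_parts.append(np.full(obs.size, (sigma_orig / sd) ** 2))

d_obs_std = np.concatenate(d_obs_std_parts)
Gamma_std = np.diag(np.concatenate(var_parts))  # diagonal, per-well variances
print(np.round(np.diag(Gamma_std), 4))
```

In the notebook, such a `Gamma_std` would replace `np.diag([sigma_std**2] * n_obs)`, with the blocks ordered like the concatenation of `d_obs_std` (Inj 1, Prod 1, Prod 2). The noise draw in the update step would then need per-observation standard deviations `np.sqrt(alpha * np.diag(Gamma_std))` rather than a single scalar.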