In [192]:
# ==============================================================================
# SIMPLIFIED & ROBUST DATA LOADING
# ==============================================================================
import pandas as pd
import numpy as np
import os

# --- 1. Helper function to create dummy CSV files for verification ---
def create_dummy_csv_files():
    """Creates two dummy CSV files to simulate the experimental data."""
    print("Creating dummy CSV files for verification...")

    # Data for Experiment 1
    data1 = {
        'time_min': [0, 13, 54, 120, 224, 436],
        'C_out': [0.0, 0.0, 0.87, 8.34, 61.8, 98.48]
    }
    df1 = pd.DataFrame(data1)
    df1.to_csv("exp_1.csv", index=False)

    # Data for Experiment 2
    data2 = {
        'time_min': [0, 10, 40, 80, 150, 200],
        'C_out': [0.0, 1.0, 15.0, 65.0, 98.0, 99.0]
    }
    df2 = pd.DataFrame(data2)
    df2.to_csv("exp_2.csv", index=False)
    print("Dummy files 'exp_1.csv' and 'exp_2.csv' created.")


# --- 2. Main function to load all data into a single DataFrame ---
def create_master_dataframe(file_paths, experiment_params_list):
    """
    Reads multiple CSV files and combines them with their experimental
    parameters into a single, master Pandas DataFrame.

    Args:
        file_paths (list of str): A list of paths to the CSV data files.
        experiment_params_list (list of dict): A list of dictionaries, where each
                                               dictionary contains the parameters
                                               for the corresponding file.
    Returns:
        pandas.DataFrame: A single DataFrame containing all experimental data.
    """
    all_experiments_df_list = []

    for i, file_path in enumerate(file_paths):
        # --- A. Read the time-series data from the CSV ---
        temp_df = pd.read_csv(file_path)
        # --> IMPORTANT: Adjust these column names if yours are different
        temp_df.rename(columns={'time_min': 't_min', 'C_out': 'C'}, inplace=True)

        # --- B. Get the corresponding parameters for this file ---
        params = experiment_params_list[i]

        # --- C. Add parameter columns to the DataFrame ---
        temp_df['experiment_id'] = i + 1 # Unique ID for each experiment
        temp_df['porosity'] = params['porosity_middle_layer']
        temp_df['inflow_rate_Q'] = params['inflow_rate_Q']
        temp_df['C0'] = params['C0']
        
        all_experiments_df_list.append(temp_df)

    # --- D. Concatenate all individual DataFrames into one master DataFrame ---
    master_df = pd.concat(all_experiments_df_list, ignore_index=True).astype('float32')

    return master_df

# --- 3. Define the physical constants and the experimental parameters ---

# Physical constants
DIAMETER = 16.0
RADIUS = DIAMETER / 2.0
AREA_MM2 = np.pi * RADIUS**2

# ==============================================================================
# >> USER ACTION REQUIRED HERE <<
# Manually define the parameters for each of your 9 CSV files below.
# I have created 2 dummy examples. You should replace this with your 9 files.
# ==============================================================================

# List of your CSV file paths
csv_files = ["exp_1.csv", "exp_2.csv", "exp_3.csv", "exp_4.csv", "exp_5.csv", "exp_6.csv", "exp_7.csv", "exp_8.csv", "exp_9.csv"] # <-- REPLACE WITH YOUR 9 FILE PATHS

# List of corresponding parameters for each file
# Each dictionary represents one experiment.
params_list = [
    # Parameters for exp_1.csv
    {
        "inflow_rate_Q": 0.4,   # ml/min
        "C0": 100.0,            # mg/L
        "porosity_middle_layer": 0.371
    },
    # Parameters for exp_2.csv
    {
        "inflow_rate_Q": 1.0,   # ml/min
        "C0": 100.0,            # mg/L
        "porosity_middle_layer": 0.371
    },
    # Parameters for exp_3.csv
    {
        "inflow_rate_Q": 2.5,   # ml/min
        "C0": 100.0,            # mg/L
        "porosity_middle_layer": 0.371
    },
    # Parameters for exp_4.csv
    {
        "inflow_rate_Q": 0.4,   # ml/min
        "C0": 150.0,            # mg/L
        "porosity_middle_layer": 0.371
    },
    # Parameters for exp_5.csv
    {
        "inflow_rate_Q": 0.4,   # ml/min
        "C0": 200.0,            # mg/L
        "porosity_middle_layer": 0.371
    },
    # Parameters for exp_6.csv
    {
        "inflow_rate_Q": 0.4,   # ml/min
        "C0": 250.0,            # mg/L
        "porosity_middle_layer": 0.371
    },
    # Parameters for exp_7.csv
    {
        "inflow_rate_Q": 0.4,   # ml/min
        "C0": 300.0,            # mg/L
        "porosity_middle_layer": 0.371
    },
    # Parameters for exp_8.csv
    {
        "inflow_rate_Q": 0.4,   # ml/min
        "C0": 100.0,            # mg/L
        "porosity_middle_layer": 0.424
    },
    # Parameters for exp_9.csv
    {
        "inflow_rate_Q": 0.4,   # ml/min
        "C0": 100.0,            # mg/L
        "porosity_middle_layer": 0.399
    }
]

# --- 4. Main execution block ---
if __name__ == '__main__':

    # Load all data into the master DataFrame
    master_df = create_master_dataframe(csv_files, params_list)
    
    # --- 5. Calculate derived values needed for the model ---
    # Calculate pore water velocity 'v_x' (mm/min) for each experiment
    # Note: 1 ml = 1000 mm^3
    master_df['v_x'] = (master_df['inflow_rate_Q'] * 1000.0) / (AREA_MM2 * master_df['porosity'])

    # Calculate the dimensionless concentration 'C/C0'
    master_df['C_over_C0'] = master_df['C'] / master_df['C0']
    
    # Display the final, processed DataFrame
    print("\n--- Master DataFrame successfully created ---")
    print(master_df.head())
    print("\nDataFrame Info:")
    master_df.info()

    P_MIN = torch.tensor(master_df[['porosity', 'v_x', 'C0']].min().values, dtype=torch.float32, device=device)
    P_RANGE = torch.tensor(master_df[['porosity', 'v_x', 'C0']].max().values - master_df[['porosity', 'v_x', 'C0']].min().values, dtype=torch.float32, device=device)
    T_MAX = torch.tensor(master_df['t_min'].max(), dtype=torch.float32, device=device)
    L_MAX = torch.tensor(TOTAL_LENGTH, dtype=torch.float32, device=device)

    print("\n--- Normalization Constants Defined ---")
    print(f"Time will be scaled by: {T_MAX.item():.2f}")
    print(f"Position will be scaled by: {L_MAX.item():.2f}")

    


--- Master DataFrame successfully created ---
   t_min         C  experiment_id  porosity  inflow_rate_Q     C0       v_x  \
0    0.0  0.000000            1.0     0.371            0.4  100.0  5.362363   
1   13.0  0.000000            1.0     0.371            0.4  100.0  5.362363   
2   27.0  0.000000            1.0     0.371            0.4  100.0  5.362363   
3   40.0  0.000000            1.0     0.371            0.4  100.0  5.362363   
4   54.0  0.872917            1.0     0.371            0.4  100.0  5.362363   

   C_over_C0  
0   0.000000  
1   0.000000  
2   0.000000  
3   0.000000  
4   0.008729  

DataFrame Info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 329 entries, 0 to 328
Data columns (total 8 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   t_min          329 non-null    float32
 1   C              329 non-null    float32
 2   experiment_id  329 non-null    float32
 3   porosity       329 non-null    float32
 4 

In [143]:
master_df.to_csv('master_df.csv', index=False)

In [184]:
# ==============================================================================
# PHASE 1, STEP 2: FULL FILM-BASED BNN ARCHITECTURE (PYTORCH)
# ==============================================================================
import torch
import torch.nn as nn
import numpy as np

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Constants (from your TF code) ---
NUM_KNOWN_PARAMS = 3
LATENT_DIM = 64
COORD_DIM = 2
OUTPUT_DIM = 1
LAYER_WIDTH = 50
NUM_HIDDEN_LAYERS = 2

# --- Custom Bayesian Layer (Verified from last step) ---
class DenseBayesian(nn.Module):
    def __init__(self, in_features, out_features, activation=None):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.kernel_mu = nn.Parameter(torch.empty(in_features, out_features))
        self.kernel_rho = nn.Parameter(torch.empty(in_features, out_features))
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_rho = nn.Parameter(torch.empty(out_features))
        nn.init.xavier_uniform_(self.kernel_mu)
        nn.init.normal_(self.kernel_rho, mean=-2., std=.1)
        nn.init.zeros_(self.bias_mu)
        nn.init.normal_(self.bias_rho, mean=--2., std=.1)
        self.prior = torch.distributions.Laplace(0, .1)
        self.activation = activation() if activation else nn.Identity()

    def forward(self, x):
        kernel_sigma = torch.nn.functional.softplus(self.kernel_rho)
        bias_sigma = torch.nn.functional.softplus(self.bias_rho)
        kernel_posterior = torch.distributions.Normal(self.kernel_mu, kernel_sigma)
        bias_posterior = torch.distributions.Normal(self.bias_mu, bias_sigma)
        w, b = kernel_posterior.rsample(), bias_posterior.rsample()
        kl_divergence = (kernel_posterior.log_prob(w) - self.prior.log_prob(w)).sum() + \
                        (bias_posterior.log_prob(b) - self.prior.log_prob(b)).sum()
        return self.activation(torch.matmul(x, w) + b), kl_divergence

# --- 1. Modulator and FiLM Layer Definitions (PyTorch) ---
class ModulatorNetwork(nn.Module):
    """Encodes known experimental parameters [ε, v_x, C₀] into a latent vector 'z'."""
    def __init__(self, input_param_dim, latent_dim):
        super().__init__()
        # self.net = nn.Sequential(
        #     nn.Linear(input_param_dim, 128),
        #     nn.Tanh(),
        #     nn.Linear(128, 128),
        #     nn.Tanh(),
        #     nn.Linear(128, latent_dim)
        # )
        self.net = nn.Sequential(
            nn.Linear(input_param_dim, latent_dim)
        )
    def forward(self, physical_params):
        return self.net(physical_params)

class FiLMLayer(nn.Module):
    """Feature-wise Linear Modulation Layer."""
    def __init__(self, num_features, latent_dim):
        super().__init__()
        self.num_features = num_features
        self.projection = nn.Linear(latent_dim, num_features * 2)

    def forward(self,x,z):
        gamma,beta=self.projection(z).chunk(2,dim=-1);
        return gamma*x+beta


# --- 2. The Main Bayesian Solver Model ---
class BayesianSolverWithFiLM(nn.Module):
    """A Bayesian Solver that uses FiLM layers for conditioning."""
    def __init__(self, coord_dim, output_dim, layer_width, latent_dim, num_hidden_layers):
        super().__init__()
        self.output_dim = output_dim
        
        self.input_dense = DenseBayesian(coord_dim, layer_width, activation=nn.Tanh)
        
        self.hidden_blocks = nn.ModuleList()
        for _ in range(num_hidden_layers):
            block = nn.ModuleDict({
                'dense': DenseBayesian(layer_width, layer_width),
                'film': FiLMLayer(layer_width, latent_dim),
                'activation': nn.Tanh()
            })
            self.hidden_blocks.append(block)

        self.output_dense = DenseBayesian(layer_width, output_dim * 2)

    def forward(self, coords, z):
        total_kl_loss = 0.
        
        x, kl = self.input_dense(coords)
        total_kl_loss += kl
        
        for block in self.hidden_blocks:
            x_res, kl = block['dense'](x)
            total_kl_loss += kl
            
            x_mod = block['film'](x_res, z)
            # Add a residual connection (as in your TF code)
            x = block['activation'](x_mod) + x_res

        raw_output, kl = self.output_dense(x)
        total_kl_loss += kl
        
        mu = raw_output[..., :self.output_dim]
        sigma = torch.nn.functional.softplus(raw_output[..., self.output_dim:]) + 1e-6
        
        return torch.distributions.Normal(loc=mu, scale=sigma), total_kl_loss

# --- 3. Sanity Check Cell ---
if __name__ == '__main__':
    print("\n--- Testing Full FiLM-based BNN Architecture ---")

    # 1. Instantiate models
    modulator = ModulatorNetwork(NUM_KNOWN_PARAMS, LATENT_DIM).to(device)
    solver = BayesianSolverWithFiLM(COORD_DIM, OUTPUT_DIM, LAYER_WIDTH, LATENT_DIM, NUM_HIDDEN_LAYERS).to(device)

    # 2. Create dummy data
    batch_size = 32
    dummy_params = torch.randn(batch_size, NUM_KNOWN_PARAMS).to(device)
    dummy_coords = torch.randn(batch_size, COORD_DIM).to(device)

    # 3. Perform a forward pass
    z_out = modulator(dummy_params)
    output_dist, total_kl = solver(dummy_coords, z_out)
    mu_out = output_dist.mean

    # 4. Verify outputs
    print(f"\nModulator 'z' output shape: {z_out.shape}")
    assert z_out.shape == (batch_size, LATENT_DIM)

    print(f"Solver 'mu' output shape: {mu_out.shape}")
    assert mu_out.shape == (batch_size, OUTPUT_DIM)
    
    print(f"Total KL loss aggregated: {total_kl.item():.4f}")
    assert total_kl.ndim == 0

    print("\nSUCCESS: The full PyTorch architecture is defined and verified.")

Using device: cpu

--- Testing Full FiLM-based BNN Architecture ---

Modulator 'z' output shape: torch.Size([32, 64])
Solver 'mu' output shape: torch.Size([32, 1])
Total KL loss aggregated: 5127.1465

SUCCESS: The full PyTorch architecture is defined and verified.


In [144]:
# ==============================================================================
# PHASE 2: DATA HANDLING WITH PYTORCH DATASET AND DATALOADER
# ==============================================================================
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np

# --- 1. The PyTorch Dataset Class ---
class BreakthroughDataset(Dataset):
    """Custom PyTorch Dataset for the breakthrough curve data."""
    def __init__(self, dataframe):
        self.df = dataframe

        # Extract columns into numpy arrays for efficiency
        self.params = self.df[['porosity', 'v_x', 'C0']].values
        self.times = self.df[['t_min']].values
        self.concentrations = self.df[['C_over_C0']].values

    def __len__(self):
        # The total number of samples is the number of rows in the DataFrame
        return len(self.df)

    def __getitem__(self, idx):
        # Retrieve a single sample by its index
        params = self.params[idx]
        time = self.times[idx]
        concentration = self.concentrations[idx]
        
        # Return as PyTorch tensors of type float32
        return {
            'params': torch.tensor(params, dtype=torch.float32),
            't': torch.tensor(time, dtype=torch.float32),
            'c_over_c0': torch.tensor(concentration, dtype=torch.float32)
        }

# --- 2. Sanity Check Cell ---
if __name__ == '__main__':
    print("\n--- Testing the PyTorch Dataset and DataLoader ---")
    
    # Assume master_df has been created from the previous data loading script
    # For this test, we'll create a small dummy DataFrame
    dummy_data = {
        'porosity': np.random.rand(100),
        'v_x': np.random.rand(100) * 5 + 5,
        'C0': np.random.rand(100) * 100 + 100,
        't_min': np.random.rand(100) * 500,
        'C_over_C0': np.random.rand(100)
    }
    master_df_dummy = pd.DataFrame(dummy_data).astype(np.float32)

    # 1. Instantiate the Dataset
    dataset = BreakthroughDataset(master_df_dummy)
    print(f"Dataset created with {len(dataset)} samples.")
    
    # 2. Instantiate the DataLoader
    batch_size = 16
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print(f"DataLoader created with batch size {batch_size}.")

    # 3. Retrieve and inspect one batch
    first_batch = next(iter(data_loader))

    print("\n--- Inspecting the first batch ---")
    p = first_batch['params']
    t = first_batch['t']
    c = first_batch['c_over_c0']

    print(f"Parameters shape: {p.shape}")
    print(f"Time shape:       {t.shape}")
    print(f"Conc shape:       {c.shape}")

    assert p.shape == (batch_size, 3)
    assert t.shape == (batch_size, 1)
    assert c.shape == (batch_size, 1)
    assert p.dtype == torch.float32

    print("\nSUCCESS: The data handling pipeline is correctly defined and verified.")


--- Testing the PyTorch Dataset and DataLoader ---
Dataset created with 100 samples.
DataLoader created with batch size 16.

--- Inspecting the first batch ---
Parameters shape: torch.Size([16, 3])
Time shape:       torch.Size([16, 1])
Conc shape:       torch.Size([16, 1])

SUCCESS: The data handling pipeline is correctly defined and verified.


In [145]:
# =osta===========================================================================
# PHASE 3, STEP 1 (CORRECTED): PHYSICS LOSS WITH REAL ARCHITECTURE
# ==============================================================================
import torch
import torch.nn as nn
import numpy as np

# --- 1. Constants and Architecture ---
device = torch.device("cpu")
print(f"Using device: {device}")

# Physical Constants
SAND_LAYER_1_L, MIDDLE_LAYER_L, TOTAL_LENGTH = 21.0, 50.0, 92.0
POROSITY_SAND, RHO_B = 0.43, 1.5e-3

# Model Hyperparameters
NUM_KNOWN_PARAMS, LATENT_DIM, COORD_DIM, OUTPUT_DIM, LAYER_WIDTH, NUM_HIDDEN_LAYERS = 3, 64, 2, 1, 256, 4

# Learnable Physical Parameter Distributions
d_l_dist = torch.distributions.LogNormal(torch.tensor(np.log(0.01)), torch.tensor(0.5))
alpha_dist = torch.distributions.LogNormal(torch.tensor(np.log(0.015)), torch.tensor(0.5))
beta_dist = torch.distributions.LogNormal(torch.tensor(np.log(29.185)), torch.tensor(0.5))

# --- 2. The Physics Loss Function ---
# def compute_pde_residual(modulator, solver, coords_phys, params_phys):
#     coords_phys.requires_grad_(True)
    
#     # Use the real model, not a dummy
#     z = modulator(params_phys)
#     c = solver(coords_phys, z)[0].mean
    
#     c_derivs = torch.autograd.grad(c, coords_phys, grad_outputs=torch.ones_like(c), create_graph=True)[0]
#     c_x, c_t = c_derivs[:, 0:1], c_derivs[:, 1:2]
    
#     # THE FIX: Set allow_unused=True. It is possible that c_x does not depend on t,
#     # and this was breaking the gradient graph.
#     c_xx_grad_output = torch.autograd.grad(c_x, coords_phys, grad_outputs=torch.ones_like(c_x), create_graph=True, allow_unused=True)[0]
#     if c_xx_grad_output is None:
#         # Handle the case where the graph is broken for a legitimate reason (e.g., zero gradients)
#         c_xx = torch.zeros_like(c_x)
#     else:
#         c_xx = c_xx_grad_output[:, 0:1]

#     x_coord, porosity_middle, v_x = coords_phys[:, 0:1], params_phys[:, 0:1], params_phys[:, 1:2]
#     D_L, alpha, beta = d_l_dist.rsample(), alpha_dist.rsample(), beta_dist.rsample()
    
#     interface1, interface2 = SAND_LAYER_1_L, SAND_LAYER_1_L + MIDDLE_LAYER_L
#     is_middle_layer = (x_coord > interface1) & (x_coord < interface2)
#     epsilon = torch.where(is_middle_layer, porosity_middle, POROSITY_SAND)
#     alpha = torch.where(is_middle_layer, alpha, 0.0)
#     beta = torch.where(is_middle_layer, beta, 0.0)
    
#     retardation_factor = 1.0 + (RHO_B / epsilon) * ((alpha * beta) / (1.0 + alpha * c)**2)
#     pde_residual = c_t * retardation_factor - D_L * c_xx + v_x * c_x
    
#     return torch.mean(pde_residual**2)


# ==============================================================================
# BLOCK 1: REPLACE THE `compute_pde_residual` FUNCTION
# ==============================================================================
def compute_pde_residual(modulator, solver, coords_phys_norm, params_phys_norm):
    """
    Calculates the physics-informed residual using NORMALIZED inputs.
    """
    coords_phys_norm.requires_grad_(True)
    
    # 1. De-normalize inputs to physical values for use inside the PDE calculation
    coords_phys = coords_phys_norm * torch.tensor([L_MAX, T_MAX], device=device)
    params_phys = params_phys_norm * P_RANGE + P_MIN

    # 2. Forward pass uses NORMALIZED inputs
    z = modulator(params_phys_norm)
    c_pred_dist, kl_loss = solver(coords_phys_norm, z)
    c = c_pred_dist.mean
    
    # 3. Compute gradients with respect to NORMALIZED coordinates
    c_derivs = torch.autograd.grad(c, coords_phys_norm, grad_outputs=torch.ones_like(c), create_graph=True)[0]
    c_x_norm, c_t_norm = c_derivs[:, 0:1], c_derivs[:, 1:2]
    
    # 4. Apply the Chain Rule to get derivatives in PHYSICAL dimensions
    c_t = c_t_norm / T_MAX
    c_x = c_x_norm / L_MAX
    
    c_xx_norm_grad = torch.autograd.grad(c_x_norm, coords_phys_norm, grad_outputs=torch.ones_like(c_x_norm), create_graph=True, allow_unused=True)[0]
    
    if c_xx_norm_grad is None:
        c_xx = torch.zeros_like(c_x)
    else:
        c_xx = c_xx_norm_grad[:, 0:1] / (L_MAX**2)

    # --- Assemble the PDE (using de-normalized physical values) ---
    x_coord = coords_phys[:, 0:1]
    porosity_middle = params_phys[:, 0:1]
    v_x = params_phys[:, 1:2]
    
    D_L, alpha, beta = d_l_dist.rsample(), alpha_dist.rsample(), beta_dist.rsample()
    
    interface1, interface2 = SAND_LAYER_1_L, SAND_LAYER_1_L + MIDDLE_LAYER_L
    is_middle_layer = (x_coord > interface1) & (x_coord < interface2)
    epsilon = torch.where(is_middle_layer, porosity_middle, POROSITY_SAND)
    alpha = torch.where(is_middle_layer, alpha, 0.0)
    beta = torch.where(is_middle_layer, beta, 0.0)
    
    retardation_factor = 1.0 + (RHO_B / epsilon) * ((alpha * beta) / (1.0 + alpha * c)**2)
    pde_residual = c_t * retardation_factor - D_L * c_xx + v_x * c_x
    
    return torch.mean(pde_residual**2)

# # --- 3. Sanity Check Cell ---
# if __name__ == '__main__':
#     print("\n--- Testing Physics Loss & Gradients with REAL Architecture ---")
    
#     modulator = ModulatorNetwork(NUM_KNOWN_PARAMS, LATENT_DIM).to(device)
#     solver = BayesianSolverWithFiLM(COORD_DIM, OUTPUT_DIM, LAYER_WIDTH, LATENT_DIM, NUM_HIDDEN_LAYERS).to(device)
    
#     dummy_coords = torch.rand(128, COORD_DIM, device=device) * torch.tensor([TOTAL_LENGTH, 500.0], device=device)
#     dummy_params = torch.rand(128, NUM_KNOWN_PARAMS, device=device) * 0.5 + 0.5
    
#     loss = compute_pde_residual(modulator, solver, dummy_coords, dummy_params)
#     print(f"Physics loss computed: {loss.item():.4e}")
#     assert not torch.isnan(loss) and not torch.isinf(loss)

#     print("\n--- Verifying Gradient Flow ---")
#     params_to_check = list(modulator.parameters()) + list(solver.parameters())
#     loss.backward()

#     if all(p.grad is not None for p in params_to_check):
#         print("SUCCESS: Gradients are flowing correctly to all model parameters.")
#     else:
#         for name, p in modulator.named_parameters():
#             if p.grad is None: print(f"FAIL: Modulator param '{name}' has NO gradient.")
#         for name, p in solver.named_parameters():
#             if p.grad is None: print(f"FAIL: Solver param '{name}' has NO gradient.")


# ==============================================================================
# BLOCK 2: RUN THE TEST FOR THE NEW `compute_pde_residual`
# ==============================================================================
if __name__ == '__main__':
    print("\n--- Testing NORMALIZED Physics Loss ---")
    
    modulator = ModulatorNetwork(NUM_KNOWN_PARAMS, LATENT_DIM).to(device)
    solver = BayesianSolverWithFiLM(COORD_DIM, OUTPUT_DIM, LAYER_WIDTH, LATENT_DIM, NUM_HIDDEN_LAYERS).to(device)
    
    # Create NORMALIZED dummy data in the [0, 1] range
    dummy_coords_norm = torch.rand(128, COORD_DIM, device=device)
    dummy_params_norm = torch.rand(128, NUM_KNOWN_PARAMS, device=device)
    
    modulator.zero_grad()
    solver.zero_grad()

    loss = compute_pde_residual(modulator, solver, dummy_coords_norm, dummy_params_norm)
    (loss).backward() # Backpropagate combined loss for gradient check

    print(f"Physics loss computed: {loss.item():.4e}")
    grad_check = all(p.grad is not None for p in list(modulator.parameters()) + list(solver.parameters()))
    print(f"Gradient check: {'SUCCESS' if grad_check else 'FAIL'}")
    

Using device: cpu

--- Testing NORMALIZED Physics Loss ---
Physics loss computed: 7.0604e-05
Gradient check: SUCCESS


In [146]:
# ==============================================================================
# PHASE 3, STEP 2 (CORRECTED): IC/BC LOSSES WITH GRADIENT TEST
# ==============================================================================

# --- 1. Loss Function Definitions (Unchanged) ---

def compute_ic_loss(modulator, solver, coords_ic, params_ic):
    c_pred_mean = solver(coords_ic, modulator(params_ic))[0].mean
    return torch.mean(c_pred_mean**2)

def compute_bc_loss(modulator, solver, coords_bc, params_bc):
    c_pred_mean = solver(coords_bc, modulator(params_bc))[0].mean
    return torch.mean((c_pred_mean - 1.0)**2)

# --- 2. Sanity Check Cell with Gradient Verification ---
if __name__ == '__main__':
    print("\n--- Testing IC and BC Loss Functions and Gradients ---")
    
    # 1. Instantiate the real models
    modulator = ModulatorNetwork(NUM_KNOWN_PARAMS, LATENT_DIM).to(device)
    solver = BayesianSolverWithFiLM(COORD_DIM, OUTPUT_DIM, LAYER_WIDTH, LATENT_DIM, NUM_HIDDEN_LAYERS).to(device)
    
    # --- IC Test ---
    print("\nVerifying IC Loss and Gradients (t=0)...")
    batch_size_ic = 64
    ic_coords = torch.cat([torch.rand(batch_size_ic, 1) * TOTAL_LENGTH, torch.zeros(batch_size_ic, 1)], dim=1).to(device)
    ic_params = torch.rand(batch_size_ic, NUM_KNOWN_PARAMS, device=device)
    
    # Zero out any previous gradients
    modulator.zero_grad()
    solver.zero_grad()
    
    loss_ic = compute_ic_loss(modulator, solver, ic_coords, ic_params)
    loss_ic.backward() # Backpropagate the loss

    print(f"IC loss computed: {loss_ic.item():.4e}")
    assert not torch.isnan(loss_ic)
    
    # Check for gradients
    ic_grad_check_passed = all(p.grad is not None for p in list(modulator.parameters()) + list(solver.parameters()))
    if ic_grad_check_passed:
        print("SUCCESS: IC gradients are flowing correctly.")
    else:
        print("FAILURE: IC gradients are missing for some parameters.")

    # --- BC Test ---
    print("\nVerifying BC Loss and Gradients (x=0)...")
    batch_size_bc = 64
    bc_coords = torch.cat([torch.zeros(batch_size_bc, 1), torch.rand(batch_size_bc, 1) * 500.0], dim=1).to(device)
    bc_params = torch.rand(batch_size_bc, NUM_KNOWN_PARAMS, device=device)
    
    # Zero out gradients again
    modulator.zero_grad()
    solver.zero_grad()

    loss_bc = compute_bc_loss(modulator, solver, bc_coords, bc_params)
    loss_bc.backward() # Backpropagate the loss

    print(f"BC loss computed: {loss_bc.item():.4e}")
    assert not torch.isnan(loss_bc)

    # Check for gradients
    bc_grad_check_passed = all(p.grad is not None for p in list(modulator.parameters()) + list(solver.parameters()))
    if bc_grad_check_passed:
        print("SUCCESS: BC gradients are flowing correctly.")
    else:
        print("FAILURE: BC gradients are missing for some parameters.")

    print("\n--- Verification Complete ---")


--- Testing IC and BC Loss Functions and Gradients ---

Verifying IC Loss and Gradients (t=0)...
IC loss computed: 2.0421e+00
SUCCESS: IC gradients are flowing correctly.

Verifying BC Loss and Gradients (x=0)...
BC loss computed: 6.8219e+00
SUCCESS: BC gradients are flowing correctly.

--- Verification Complete ---


In [147]:
# ==============================================================================
# PHASE 3, STEP 3: DATA LOSS (NEGATIVE LOG-LIKELIHOOD)
# ==============================================================================
# Assume all previous code blocks (imports, constants, data loading, models) are in scope.

# --- 1. Data Loss Function Definition ---
def compute_data_loss(modulator, solver, data_batch):
    """
    Computes the data-driven loss (Negative Log-Likelihood) on experimental data.
    
    Args:
        modulator (nn.Module): The modulator network.
        solver (nn.Module): The solver network.
        data_batch (dict): A dictionary of tensors from our DataLoader, containing
                           'params', 't', and 'c_over_c0'.
    Returns:
        Tensor: A scalar tensor representing the NLL loss.
    """
    # Unpack the batch
    params_data = data_batch['params']
    t_data = data_batch['t']
    c_data = data_batch['c_over_c0']
    
    # Create the spatial coordinate 'x'. All experimental data is at the column outlet.
    x_coords = torch.ones_like(t_data, device=t_data.device) * TOTAL_LENGTH
    
    # Combine (x, t) into the coordinate tensor
    coords_data = torch.cat([x_coords, t_data], dim=1)

    # --- Forward Pass ---
    # Get the latent code and the predicted distribution for C/C0
    z = modulator(params_data)
    predicted_distribution, kl_loss = solver(coords_data, z)
    
    # --- Loss Calculation ---
    # The loss is the negative log-likelihood of the true data under the predicted distribution.
    # We take the mean across the batch.
    nll_loss = -torch.mean(predicted_distribution.log_prob(c_data))
    
    # We return both the NLL and the KL from the BNN layers,
    # as both were computed during this single forward pass.
    return nll_loss, kl_loss

# --- 2. Sanity Check Cell ---
if __name__ == '__main__':
    print("\n--- Testing Data Loss Function and Gradients ---")

    # 1. Instantiate models
    modulator = ModulatorNetwork(NUM_KNOWN_PARAMS, LATENT_DIM).to(device)
    solver = BayesianSolverWithFiLM(COORD_DIM, OUTPUT_DIM, LAYER_WIDTH, LATENT_DIM, NUM_HIDDEN_LAYERS).to(device)
    
    # 2. Create a dummy DataLoader
    dummy_data = {
        'porosity': np.random.rand(100), 'v_x': np.random.rand(100) * 5 + 5,
        'C0': np.random.rand(100) * 100 + 100, 't_min': np.random.rand(100) * 500,
        'C_over_C0': np.random.rand(100)
    }
    master_df_dummy = pd.DataFrame(dummy_data).astype(np.float32)
    dummy_dataset = BreakthroughDataset(master_df_dummy)
    dummy_loader = DataLoader(dummy_dataset, batch_size=16)
    
    # 3. Get one batch and compute loss
    first_batch = next(iter(dummy_loader))
    # Move batch to the correct device
    for key in first_batch:
        first_batch[key] = first_batch[key].to(device)
        
    # Zero out gradients
    modulator.zero_grad()
    solver.zero_grad()
    
    data_loss, kl = compute_data_loss(modulator, solver, first_batch)
    total_loss = data_loss + kl # For the gradient test, we combine them
    
    # 4. Verify loss values
    print(f"Data loss (NLL) computed: {data_loss.item():.4e}")
    print(f"Internal KL loss computed: {kl.item():.4e}")
    assert not torch.isnan(data_loss) and not torch.isnan(kl)
    
    # 5. Gradient Test
    print("\n--- Verifying Gradient Flow ---")
    total_loss.backward()
    
    data_grad_check_passed = all(p.grad is not None for p in list(modulator.parameters()) + list(solver.parameters()))
    if data_grad_check_passed:
        print("SUCCESS: Data loss gradients are flowing correctly.")
    else:
        print("FAILURE: Data loss gradients are missing.")
        
    print("\n--- Verification Complete ---")


--- Testing Data Loss Function and Gradients ---
Data loss (NLL) computed: 1.6119e+00
Internal KL loss computed: 9.2859e+05

--- Verifying Gradient Flow ---
SUCCESS: Data loss gradients are flowing correctly.

--- Verification Complete ---


In [199]:
# ==============================================================================
# FINAL, COMPLETE AND VERIFIED TRAINING SCRIPT
# ==============================================================================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import os
import time

# --- 1. SETUP: Constants and Device ---
device = torch.device("cpu")
print(f"Using device: {device}")

DIAMETER, SAND_LAYER_1_L, MIDDLE_LAYER_L, TOTAL_LENGTH = 16.0, 21.0, 50.0, 92.0
AREA_MM2 = np.pi * (DIAMETER / 2.0)**2
POROSITY_SAND, RHO_B = 0.43, 1.5e-3

NUM_KNOWN_PARAMS = 3
LATENT_DIM = 32
COORD_DIM = 2
OUTPUT_DIM = 1
LAYER_WIDTH = 128
NUM_HIDDEN_LAYERS = 4
EPOCHS = 1001
BATCH_SIZE_DATA = 32
BATCH_SIZE_BC = 64
BATCH_SIZE_IC = 64
BATCH_SIZE_PHYSICS = 128
LEARNING_RATE = 1e-3
KL_WEIGHT = 1.0 / BATCH_SIZE_PHYSICS
KL_ANNEALING_EPOCHS = 400


# --- 2. DATA LOADING (Your Verified Code) ---
master_df = create_master_dataframe(csv_files, params_list)
master_df['v_x'] = (master_df['inflow_rate_Q'] * 1000.0) / (AREA_MM2 * master_df['porosity'])
master_df['C_over_C0'] = master_df['C'] / master_df['C0']
master_df['x_dummy'] = 0.0

P_MIN = torch.tensor(master_df[['porosity', 'v_x', 'C0']].min().values, dtype=torch.float32, device=device)
P_RANGE = torch.tensor(master_df[['porosity', 'v_x', 'C0']].max().values - p_min_vals, dtype=torch.float32, device=device)
T_MAX = torch.tensor(master_df['t_min'].max(), dtype=torch.float32, device=device)
L_MAX = torch.tensor(TOTAL_LENGTH, dtype=torch.float32, device=device)

class BreakthroughDataset(Dataset):
    # Your verified Dataset class here...
    def __init__(self, df): self.df=df; self.p=df[['porosity','v_x','C0']].values; self.t=df[['t_min']].values; self.c=df[['C_over_C0']].values
    def __len__(self): return len(self.df)
    def __getitem__(self, idx): return {'params':torch.tensor(self.p[idx],dtype=torch.float32),'t':torch.tensor(self.t[idx],dtype=torch.float32),'c_over_c0':torch.tensor(self.c[idx],dtype=torch.float32)}

data_dataset = BreakthroughDataset(master_df)
data_loader = DataLoader(data_dataset, batch_size=BATCH_SIZE_DATA, shuffle=True)

# --- 3. ARCHITECTURE ---


class DenseBayesian(nn.Module):
    def __init__(self, in_features, out_features, activation=None):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.kernel_mu = nn.Parameter(torch.empty(in_features, out_features))
        self.kernel_rho = nn.Parameter(torch.empty(in_features, out_features))
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_rho = nn.Parameter(torch.empty(out_features))
        nn.init.xavier_uniform_(self.kernel_mu)
        nn.init.normal_(self.kernel_rho, mean=-3., std=1.)
        nn.init.zeros_(self.bias_mu)
        nn.init.normal_(self.bias_rho, mean=-3., std=1.)
        self.prior = torch.distributions.Laplace(0, 1.)
        self.activation = activation() if activation else nn.Identity()

    def forward(self, x):
        kernel_sigma = torch.nn.functional.softplus(self.kernel_rho)
        bias_sigma = torch.nn.functional.softplus(self.bias_rho)
        kernel_posterior = torch.distributions.Normal(self.kernel_mu, kernel_sigma)
        bias_posterior = torch.distributions.Normal(self.bias_mu, bias_sigma)
        w, b = kernel_posterior.rsample(), bias_posterior.rsample()
        kl_divergence = (kernel_posterior.log_prob(w) - self.prior.log_prob(w)).sum() + \
                        (bias_posterior.log_prob(b) - self.prior.log_prob(b)).sum()
        return self.activation(torch.matmul(x, w) + b), kl_divergence
        

class ModulatorNetwork(nn.Module):
    """Encodes known experimental parameters [ε, v_x, C₀] into a latent vector 'z'."""
    def __init__(self, input_param_dim, latent_dim):
        super().__init__()
        # self.net = nn.Sequential(
        #     nn.Linear(input_param_dim, 64),
        #     nn.Tanh(),
        #     nn.Linear(64, 64),
        #     nn.Tanh(),
        #     nn.Linear(64, latent_dim)
        # )
        self.net = nn.Sequential(
            nn.Linear(input_param_dim, latent_dim),
            nn.Linear(latent_dim, latent_dim)
        )
    def forward(self, physical_params):
        return self.net(physical_params)
        

class FiLMLayer(nn.Module):
    """Feature-wise Linear Modulation Layer."""
    def __init__(self, num_features, latent_dim):
        super().__init__()
        self.num_features = num_features
        self.projection = nn.Linear(latent_dim, num_features * 2)

    def forward(self,x,z):
        gamma,beta=self.projection(z).chunk(2,dim=-1);
        return gamma*x+beta

class BayesianSolverWithFiLM(nn.Module):
    """A Bayesian Solver that uses FiLM layers for conditioning."""
    def __init__(self, coord_dim, output_dim, layer_width, latent_dim, num_hidden_layers):
        super().__init__()
        self.output_dim = output_dim
        
        self.input_dense = DenseBayesian(coord_dim, layer_width, activation=nn.Tanh)


        self.hidden_blocks = nn.ModuleList([
            nn.ModuleDict({
                'dense': DenseBayesian(layer_width, layer_width),
                'film': FiLMLayer(layer_width, latent_dim),
                'activation': nn.Tanh()
            }) for _ in range(num_hidden_layers)
        ])

        self.output_dense = DenseBayesian(layer_width, output_dim * 2)

    def forward(self, coords, z):
        total_kl_loss = 0.
        
        x, kl = self.input_dense(coords)
        total_kl_loss += kl
        
        for block in self.hidden_blocks:
            x_res, kl = block['dense'](x)
            total_kl_loss += kl
            
            x_mod = block['film'](x_res, z)
            # Add a residual connection (as in your TF code)
            x = block['activation'](x_mod) + x_res

        raw_output, kl = self.output_dense(x)
        total_kl_loss += kl
        
        mu = raw_output[..., :self.output_dim]
        sigma_raw = raw_output[..., self.output_dim:]
        sigma = torch.nn.functional.softplus(sigma_raw) + 1e-6
        
        return torch.distributions.Normal(loc=mu, scale=sigma), total_kl_loss

# --- 4. Loss Functions ---

def compute_pde_residual(modulator, solver, coords_norm, params_norm):
    # Create a fresh copy and set requires_grad
    coords_norm = coords_norm.clone().detach().requires_grad_(True)
    
    c, kl_phys = solver(coords_norm, modulator(params_norm))
    c_mean = c.mean
    
    c_derivs = torch.autograd.grad(c_mean, coords_norm, torch.ones_like(c_mean), create_graph=True)[0]
    c_x_norm, c_t_norm = c_derivs[:, 0:1], c_derivs[:, 1:2]
    
    # Use create_graph=True but not retain_graph
    c_xx_norm_grad = torch.autograd.grad(c_x_norm, coords_norm, torch.ones_like(c_x_norm), create_graph=True, allow_unused=True)[0]
    
    # De-normalize everything for the physical calculation
    coords = coords_norm * torch.tensor([L_MAX, T_MAX], device=device)
    params = params_norm * P_RANGE + P_MIN
    c_t = c_t_norm / T_MAX
    c_x = c_x_norm / L_MAX
    c_xx = c_xx_norm_grad[:, 0:1] / (L_MAX**2) if c_xx_norm_grad is not None else torch.zeros_like(c_x)

    x_coord, porosity_middle, v_x = coords[:, 0:1], params[:, 0:1], params[:, 1:2]
    
    # Sample from distributions without modifying them
    D_L = d_l_dist.rsample()
    alpha = alpha_dist.rsample()
    beta = beta_dist.rsample()
    
    i1, i2 = SAND_LAYER_1_L, SAND_LAYER_1_L + MIDDLE_LAYER_L
    is_mid = (x_coord > i1) & (x_coord < i2)
    
    # Use where instead of in-place operations
    eps = torch.where(is_mid, porosity_middle, torch.tensor(POROSITY_SAND, device=device))
    alpha_val = torch.where(is_mid, alpha, torch.tensor(0., device=device))
    beta_val = torch.where(is_mid, beta, torch.tensor(0., device=device))
    
    R = 1.0 + (RHO_B / eps) * ((alpha_val * beta_val) / (1.0 + alpha_val * c_mean)**2)
    pde_residual = c_t * R - D_L * c_xx + v_x * c_x
    
    return torch.mean(pde_residual**2), kl_phys

def compute_data_loss(modulator, solver, coords_norm, params_norm, c_data):
    pred_dist, kl_loss = solver(coords_norm, modulator(params_norm))
    nll = -torch.mean(pred_dist.log_prob(c_data))
    return nll, kl_loss

def compute_ic_loss(modulator, solver, coords_norm, params_norm):
    c_pred, kl_ic = solver(coords_norm, modulator(params_norm))
    c_pred = c_pred.mean
    return torch.mean(c_pred**2), kl_ic

def compute_bc_loss(modulator, solver, coords_norm, params_norm):
    c_pred, kl_bc = solver(coords_norm, modulator(params_norm))
    c_pred = c_pred.mean
    return torch.mean((c_pred - 1.0)**2), kl_bc

# --- 5. Training Step and Loop ---

def train_step(modulator, solver, optimizer, batch, current_kl_weight):
    # Clear gradients
    optimizer.zero_grad()
    
    # --- Loss Calculation ---
    # Create fresh copies of tensors
    coords_data_norm = batch['coords_data_norm'].clone().detach()
    params_data_norm = batch['params_data_norm'].clone().detach()
    c_data = batch['c_data'].clone().detach()
    
    coords_phys_norm = batch['coords_phys_norm'].clone().detach()
    params_phys_norm = batch['params_phys_norm'].clone().detach()
    
    coords_ic_norm = batch['coords_ic_norm'].clone().detach()
    params_ic_norm = batch['params_ic_norm'].clone().detach()
    
    coords_bc_norm = batch['coords_bc_norm'].clone().detach()
    params_bc_norm = batch['params_bc_norm'].clone().detach()
    
    # Calculate losses with cloned tensors
    loss_data, kl_from_data_pass = compute_data_loss(modulator, solver, coords_data_norm, params_data_norm, c_data)
    loss_phys, kl_phys = compute_pde_residual(modulator, solver, coords_phys_norm, params_phys_norm)
    loss_ic, kl_ic = compute_ic_loss(modulator, solver, coords_ic_norm, params_ic_norm)
    loss_bc, kl_bc = compute_bc_loss(modulator, solver, coords_bc_norm, params_bc_norm)

    # Aggregate KL from all forward passes
    kl_weights = kl_from_data_pass + kl_phys + kl_ic + kl_bc
    
    # KL for our learnable physical parameters
    kl_params = (torch.distributions.kl_divergence(d_l_dist, torch.distributions.LogNormal(0., 1.)) +
                 torch.distributions.kl_divergence(alpha_dist, torch.distributions.LogNormal(0., 1.)) +
                 torch.distributions.kl_divergence(beta_dist, torch.distributions.LogNormal(0., 1.)))

    
    # --- Total Loss ---
    total_loss = loss_data + loss_phys + loss_ic + loss_bc + current_kl_weight * (kl_weights + kl_params)

    # --- Backpropagation ---
    total_loss.backward()
    optimizer.step()
    
    # Create a new dictionary with detached tensors
    loss_dict = {
        "total": total_loss.detach().clone(),
        "data": loss_data.detach().clone(), 
        "phys": loss_phys.detach().clone(), 
        "ic": loss_ic.detach().clone(), 
        "bc": loss_bc.detach().clone(), 
        "kl": (kl_weights + kl_params).detach().clone()
    }
    
    return loss_dict

if __name__ == '__main__':

    log_D_L_loc = torch.nn.Parameter(torch.log(torch.tensor(0.01)))
    log_D_L_scale = torch.nn.Parameter(torch.tensor(0.5))
    log_alpha_loc = torch.nn.Parameter(torch.log(torch.tensor(0.015)))
    log_alpha_scale = torch.nn.Parameter(torch.tensor(0.5))
    log_beta_loc = torch.nn.Parameter(torch.log(torch.tensor(29.185)))
    log_beta_scale = torch.nn.Parameter(torch.tensor(0.5))
    
    # d_l_dist = torch.distributions.LogNormal(log_D_L_loc, torch.nn.functional.softplus(log_D_L_scale))
    # alpha_dist = torch.distributions.LogNormal(log_alpha_loc, torch.nn.functional.softplus(log_alpha_scale))
    # beta_dist = torch.distributions.LogNormal(log_beta_loc, torch.nn.functional.softplus(log_beta_scale))

    print("\n--- Unified Test of All Loss Components and train_step ---")

    modulator = ModulatorNetwork(NUM_KNOWN_PARAMS, LATENT_DIM).to(device)
    solver = BayesianSolverWithFiLM(COORD_DIM, OUTPUT_DIM, LAYER_WIDTH, LATENT_DIM, NUM_HIDDEN_LAYERS).to(device)
    
    # The list of all trainable parameters must be passed to the optimizer
    all_trainable_vars = list(modulator.parameters()) + list(solver.parameters()) + \
                         [log_D_L_loc, log_D_L_scale, log_alpha_loc, log_alpha_scale, log_beta_loc, log_beta_scale]
    optimizer = torch.optim.Adam(all_trainable_vars, lr=LEARNING_RATE)

    # --- Initialize Sobol Samplers ---
    sobol_engine_phys = torch.quasirandom.SobolEngine(dimension=2)
    sobol_engine_params = torch.quasirandom.SobolEngine(dimension=3)

    
    # # --- Prepare a single, comprehensive batch dictionary ---
    # batch = {}

    # # 1. Data points (using standard random sampling from dataframe)
    # data_sample = master_df.sample(n=BATCH_SIZE_DATA)
    # params_data = torch.tensor(data_sample[['porosity', 'v_x', 'C0']].values, device=device)
    # t_data = torch.tensor(data_sample[['t_min']].values, device=device)
    # c_data = torch.tensor(data_sample[['C_over_C0']].values, device=device)
    # x_data = torch.ones_like(t_data) * TOTAL_LENGTH
    
    # # Normalize and add to batch
    # batch['c_data'] = c_data
    # batch['coords_data_norm'] = (torch.cat([x_data, t_data], dim=1) / torch.tensor([L_MAX, T_MAX], device=device))
    # batch['params_data_norm'] = ((params_data - P_MIN) / P_RANGE)

    # # 2. Physics points (using Sobol sampling)
    # phys_params_norm = sobol_engine_params.draw(BATCH_SIZE_PHYSICS).to(device)
    # phys_coords_norm = sobol_engine_phys.draw(BATCH_SIZE_PHYSICS).to(device)
    # batch['params_phys_norm'] = phys_params_norm
    # batch['coords_phys_norm'] = phys_coords_norm
    
    # # 3. IC points (using Sobol sampling)
    # ic_params_norm = sobol_engine_params.draw(BATCH_SIZE_IC).to(device)
    # # Sobol for x, combined with zeros for t
    # ic_coords_norm = torch.cat([sobol_engine_phys.draw(BATCH_SIZE_IC)[:, 0:1].to(device), torch.zeros(BATCH_SIZE_IC, 1, device=device)], dim=1)
    # batch['params_ic_norm'] = ic_params_norm
    # batch['coords_ic_norm'] = ic_coords_norm

    # # 4. BC points (using Sobol sampling)
    # bc_params_norm = sobol_engine_params.draw(BATCH_SIZE_BC).to(device)
    # # Zeros for x, Sobol for t
    # bc_coords_norm = torch.cat([torch.zeros(BATCH_SIZE_BC, 1, device=device), sobol_engine_phys.draw(BATCH_SIZE_BC)[:, 1:2].to(device)], dim=1)
    # batch['params_bc_norm'] = bc_params_norm
    # batch['coords_bc_norm'] = bc_coords_norm
    
    # print("Batch prepared successfully using Sobol sampling.")
    
    # # --- Execute and Test ---
    # # The train_step function itself is correct now that the decorator is removed.
    # loss_dict = train_step(modulator, solver, optimizer, batch, current_kl_weight=0.1)

    # print("\n--- Test Results ---")
    # for key, val in loss_dict.items():
    #     print(f"Loss {key}: {val.item():.4e}")
    #     assert not torch.isnan(val)
    # print("\nSUCCESS: The full, normalized train_step is verified with Sobol sampling.")


        # --- Execute and Test ---
    # Run multiple iterations safely
    # --- Execute and Test ---
    print("Running multiple training iterations...")
    
    for i in range(10):
        print(f"Iteration {i+1}")
        
        # Recreate the distributions for each iteration
        d_l_dist = torch.distributions.LogNormal(log_D_L_loc, torch.nn.functional.softplus(log_D_L_scale))
        alpha_dist = torch.distributions.LogNormal(log_alpha_loc, torch.nn.functional.softplus(log_alpha_scale))
        beta_dist = torch.distributions.LogNormal(log_beta_loc, torch.nn.functional.softplus(log_beta_scale))
        
        # Create a fresh batch for each iteration
        fresh_batch = {}
        
        # 1. Data points
        data_sample = master_df.sample(n=BATCH_SIZE_DATA)
        params_data = torch.tensor(data_sample[['porosity', 'v_x', 'C0']].values, device=device)
        t_data = torch.tensor(data_sample[['t_min']].values, device=device)
        c_data = torch.tensor(data_sample[['C_over_C0']].values, device=device)
        x_data = torch.ones_like(t_data) * TOTAL_LENGTH
        
        fresh_batch['c_data'] = c_data
        fresh_batch['coords_data_norm'] = (torch.cat([x_data, t_data], dim=1) / torch.tensor([L_MAX, T_MAX], device=device))
        fresh_batch['params_data_norm'] = ((params_data - P_MIN) / P_RANGE)
    
        # 2. Physics points
        fresh_batch['params_phys_norm'] = sobol_engine_params.draw(BATCH_SIZE_PHYSICS).to(device)
        fresh_batch['coords_phys_norm'] = sobol_engine_phys.draw(BATCH_SIZE_PHYSICS).to(device)
        
        # 3. IC points
        fresh_batch['params_ic_norm'] = sobol_engine_params.draw(BATCH_SIZE_IC).to(device)
        fresh_batch['coords_ic_norm'] = torch.cat([
            sobol_engine_phys.draw(BATCH_SIZE_IC)[:, 0:1].to(device), 
            torch.zeros(BATCH_SIZE_IC, 1, device=device)
        ], dim=1)
    
        # 4. BC points
        fresh_batch['params_bc_norm'] = sobol_engine_params.draw(BATCH_SIZE_BC).to(device)
        fresh_batch['coords_bc_norm'] = torch.cat([
            torch.zeros(BATCH_SIZE_BC, 1, device=device), 
            sobol_engine_phys.draw(BATCH_SIZE_BC)[:, 1:2].to(device)
        ], dim=1)
        
        # Run a single training step with the fresh batch
        loss_dict = train_step(modulator, solver, optimizer, fresh_batch, current_kl_weight=1e-6)
        print(f"  Total loss: {loss_dict['total'].item():.4e}")
        
        # Clear any cached tensors
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    print("\n--- Final Test Results ---")
    for key, val in loss_dict.items():
        print(f"Loss {key}: {val.item():.4e}")
        assert not torch.isnan(val)
    print("\nSUCCESS: The full, normalized train_step is verified with multiple iterations.")
    

Using device: cpu

--- Unified Test of All Loss Components and train_step ---
Running multiple training iterations...
Iteration 1
  Total loss: 2.8968e+05
Iteration 2
  Total loss: 5.4394e+00
Iteration 3
  Total loss: 4.0111e+01
Iteration 4
  Total loss: 8.9989e+06
Iteration 5
  Total loss: 7.9592e+00
Iteration 6
  Total loss: 1.0615e+01
Iteration 7
  Total loss: 2.6060e+02
Iteration 8
  Total loss: 9.0806e+00
Iteration 9
  Total loss: 2.3000e+06
Iteration 10
  Total loss: 5.5874e+01

--- Final Test Results ---
Loss total: 5.5874e+01
Loss data: 5.1589e+00
Loss phys: 1.3582e-01
Loss ic: 2.2340e+01
Loss bc: 2.7595e+01
Loss kl: 6.4548e+05

SUCCESS: The full, normalized train_step is verified with multiple iterations.
