In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as tick
import matplotlib as mpl
import h5py
import os
from utils import get_train_data, prepare_prediction_data, read_h5_file

# Record the notebook's working directory; the relative paths used later
# ('./poisson_output/...', './data/...') resolve against it.
# The original follow-up `os.chdir(get_dir)` changed into the directory the
# kernel was already in (a no-op), so it has been dropped.
get_dir = os.getcwd()

In [None]:
def get_predictions_and_true_data_poisson(mdir, ddir, indices, context_size=60, batch_size=1024):
    """Run a saved Keras model on selected Poisson samples and collect the results.

    Parameters
    ----------
    mdir : str
        Path to the saved model, loaded with ``tf.keras.models.load_model``.
    ddir : str
        Path to the HDF5 dataset read by ``read_h5_file``.
    indices : iterable of int
        Sample indices to run the model on.
    context_size : int, default 60
        Number of context points; ``np.arange(context_size)`` is passed to
        ``prepare_prediction_data`` as the context indices (the original code
        hard-coded 60).
    batch_size : int, default 1024
        Batch size forwarded to ``model.predict``.

    Returns
    -------
    u_pred, u_true, f, xs : list
        One flattened array per entry of ``indices``, in the same order:
        model predictions, reference solutions, forcing inputs and x inputs.
    """
    context_indices = np.arange(context_size)

    # SECURITY: safe_mode=False allows the saved model to deserialize
    # arbitrary Python code (e.g. Lambda layers). Only load trusted files.
    model = tf.keras.models.load_model(mdir, safe_mode=False)

    # Read the Poisson dataset. The 4th returned value (boundary data) is not
    # used here: the boundary inputs come from prepare_prediction_data below.
    f_all, x_all, u_all, _ = read_h5_file(ddir)

    u_pred = []
    u_true = []
    f = []
    xs = []

    for i in indices:
        print(f"Processing sample {i}...")

        # Build the model inputs and the reference solution for sample i.
        sample_data = prepare_prediction_data(x_all, f_all, u_all, context_indices, i)

        x_input = sample_data['x']
        f_input = sample_data['f']
        xbc = sample_data['xbc']
        fbc = sample_data['fbc']
        ubc = sample_data['ubc']

        # The model takes five inputs in this fixed order.
        u_prediction = model.predict([x_input, f_input, xbc, fbc, ubc],
                                     batch_size=batch_size, verbose=0)

        # Flatten everything so each entry is a 1-D array ready for plotting.
        u_pred.append(u_prediction.flatten())
        u_true.append(sample_data['u_true'].flatten())
        f.append(f_input.flatten())
        xs.append(x_input.flatten())

    print("Prediction generation completed!")
    return u_pred, u_true, f, xs

In [None]:
def plot_poisson_solutions(x, u_pred, u_true, f, indices, plot_save=False, plot_dir='poisson_results', fig_size=(15, 10)):
    """Plot prediction vs. truth, forcing term and error, one column per sample.

    Rows of the figure:
      0. predicted vs. true solution,
      1. forcing term f(x),
      2. pointwise error |u_pred - u_true| / (1 + |u_true|).

    The original version allocated a 4th row that was never drawn on, which
    left an empty row of axes in every figure; it has been removed.

    Parameters
    ----------
    x, u_pred, u_true, f : sequence of 1-D arrays
        Per-sample grid, prediction, reference solution and forcing term;
        all four sequences must have the same length.
    indices : sequence of int
        Dataset index of each sample; used only in the panel titles.
    plot_save : bool, default False
        If True, also save the figure.
    plot_dir : str, default 'poisson_results'
        File stem of the saved figure: it is written to ``f'{plot_dir}.png'``
        (despite the name, this is not a directory).
    fig_size : tuple of float, default (15, 10)
        Figure size in inches passed to ``plt.subplots``.

    Raises
    ------
    ValueError
        If no samples are given.
    """
    n_samples = len(u_pred)
    if n_samples == 0:
        raise ValueError("plot_poisson_solutions needs at least one sample to plot")

    # squeeze=False always returns a 2-D (rows, cols) array of Axes, so a
    # single sample needs no special-case reshape.
    fig, ax = plt.subplots(3, n_samples, figsize=fig_size, sharex=True, squeeze=False)

    for i in range(n_samples):
        sol_ax, force_ax, err_ax = ax[:, i]

        # Solution: prediction vs. reference
        sol_ax.plot(x[i], u_pred[i], 'b-', label='PINTO Prediction', linewidth=2)
        sol_ax.plot(x[i], u_true[i], 'r--', label='True Solution', linewidth=2)
        sol_ax.set_title(f'Solution p(x) - Sample {indices[i]}')
        sol_ax.set_ylabel('u(x)')
        sol_ax.legend()
        sol_ax.grid(True, alpha=0.3)

        # Forcing term f(x)
        force_ax.plot(x[i], f[i], 'm-', label='Forcing f(x)', linewidth=2)
        force_ax.set_title(f'Forcing Term f(x) - Sample {indices[i]}')
        force_ax.set_ylabel('f(x)')
        force_ax.legend()
        force_ax.grid(True, alpha=0.3)

        # The +1 in the denominator keeps the error finite where u_true ~ 0:
        # it behaves like a relative error for |u_true| >> 1 and like an
        # absolute error for |u_true| << 1.
        rel_error = np.abs(u_pred[i] - u_true[i]) / (1 + np.abs(u_true[i]))
        err_ax.plot(x[i], rel_error, 'k-', label='Relative Error', linewidth=2)
        err_ax.set_title(f'Relative Error - Sample {indices[i]}')
        err_ax.set_ylabel('Relative Error')
        err_ax.set_xlabel('x')
        err_ax.legend()
        err_ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if plot_save:
        fig.savefig(f'{plot_dir}.png', dpi=300, bbox_inches='tight', format='png')

    plt.show()

In [None]:
# Paths to the trained model and the Poisson dataset (relative to the notebook).
mdir = './poisson_output/Poisson_model.keras'
ddir = './data/poisson_5000.h5'

# Sample indices to analyse, plotted as two separate groups:
# the first group has seen boundary conditions, the second has unseen ones.
train_indices = [10, 25, 40, 55]
test_indices = [85, 90, 95]
all_indices = [*train_indices, *test_indices]

# Run the model once over every selected sample.
u_pred, u_true, f, xs = get_predictions_and_true_data_poisson(mdir, ddir, all_indices)

# Results are ordered train-then-test, so one cut point splits them.
n_train = len(train_indices)
splits = (
    (slice(None, n_train), train_indices, 'Poisson_Train_Results', (16, 12)),
    (slice(n_train, None), test_indices, 'Poisson_Test_Results', (12, 12)),
)
for part, group_indices, out_name, size in splits:
    plot_poisson_solutions(
        xs[part], u_pred[part], u_true[part], f[part], group_indices,
        plot_save=True, plot_dir=out_name, fig_size=size,
    )