In [None]:
%load_ext autoreload
%autoreload 2
import os
# Move the working directory two levels up before anything else runs; the
# relative paths used later in this notebook ('datasets/...', 'checkpoints/...')
# are resolved against it.
# NOTE(review): presumably the notebook lives two directories below the project
# root and the `modules` imports below rely on this location — confirm.
# This call is not idempotent: re-running the cell without restarting the
# kernel climbs two more levels.
os.chdir("../..")
print(os.getcwd())  # echo the resulting directory so a wrong location is visible immediately

import torch
from modules.amine_blend_pipeline import AmineBlendPipeline
import modules.datasplit_module as dsm
from modules.model_loader import load_model
from torch_geometric.loader import DataLoader
import numpy as np
import random

# --- Reproducibility settings ---
# cuBLAS only produces deterministic results when its workspace size is pinned;
# this must be in the environment before the first CUDA call.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
seed = 21
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
# Seeds every visible GPU (covers what a per-device torch.cuda.manual_seed would do).
torch.cuda.manual_seed_all(seed)
# Prefer deterministic kernels. warn_only=True keeps ops that have no
# deterministic implementation running (with a warning) instead of raising.
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
# The cuDNN autotuner may pick a different algorithm on each run, so it is
# switched off for run-to-run reproducibility.
torch.backends.cudnn.benchmark = False

In [None]:
# --- Process data ---
# Build the canonical table and the graph objects from the raw CSV, then
# persist the canonical table next to the inputs.
pipeline = AmineBlendPipeline(components_csv='datasets/components.csv')
canonical_data, graph_list = pipeline.run_pipeline(raw_csv='datasets/dataset.csv')
pipeline.save_canonical_df(canonical_data, 'datasets/canonical_data.csv')

# In-place shuffle driven by the `random` module's global state.
random.shuffle(graph_list)

# The splitter hands back seven values, unpacked here as the train/val/test
# sets under their `_raw` and `_std` names, followed by `stats`.
(
    train_raw, val_raw, test_raw,
    train_std, val_std, test_std,
    stats,
) = dsm.standardized_system_disjoint_split(graph_list, random_state=seed)

In [None]:
# Load model
# NOTE(review): this rebinds `stats`, replacing the value returned by the
# split in the previous cell. The datasets were standardized by that split,
# not by the checkpoint's statistics — confirm the two agree.
model, stats = load_model('checkpoints/01_dummy.pt', return_stats=True)
# Fall back to the CPU when no GPU is visible instead of failing on .to('cuda').
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

# Load data
# Settings shared by all three loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=128, follow_batch=['component_batch'])
train_loader = DataLoader(dataset=train_std, shuffle=True, **loader_kwargs)
val_loader = DataLoader(dataset=val_std, shuffle=False, **loader_kwargs)
test_loader = DataLoader(dataset=test_std, shuffle=False, **loader_kwargs)

In [None]:
def analyze_batch(model, batched_data, device='cpu'):
    """
    Run a forward pass on one batch, print a formatted report and return the values.

    Args:
        model: Module called as ``model(batched_data)``; it must return the
            3-tuple ``(y_pred, latent_vectors, comp_emb)``. It is switched to
            eval mode but NOT moved to ``device`` — the caller has to make sure
            it already lives there.
        batched_data: Batch object exposing ``.to(device)`` and the attributes
            ``T``, ``T_std``, ``T_mean``, ``pco2``, ``pco2_std``, ``pco2_mean``,
            ``aco2``, ``component_names``, ``component_mass_frac`` and
            ``component_mole_frac``.
        device: Device the batch is moved to ('cpu' or 'cuda').

    Returns:
        dict: Inputs, predictions, error metrics and intermediate values.
        The relative error is ``|pred - actual| / |actual| * 100`` so it is
        never negative; a zero target yields ``inf``/``nan``.

    Notes:
        * ``torch.set_printoptions`` is changed globally and stays changed
          after the call.
        * NOTE(review): the excess-energy term broadcasts a ``(1, n)`` row
          against a ``(batch, 1)`` column, which pairs every component with
          every temperature. That looks meaningful only for a batch holding a
          single system — confirm before using larger batches.
    """
    def _banner(title):
        # Section header framed by two 120-character rules.
        print("\n" + "="*120)
        print(title)
        print("="*120)

    model.eval()
    batched_data = batched_data.to(device)

    with torch.no_grad():
        y_pred, latent_vectors, comp_emb = model(batched_data)
    # Detach once here so everything derived below is free of autograd history.
    y_pred = y_pred.detach()
    latent_vectors = latent_vectors.detach()
    comp_emb = comp_emb.detach()

    # Unstandardize values
    T_actual = batched_data.T * batched_data.T_std + batched_data.T_mean
    pco2_actual = batched_data.pco2 * batched_data.pco2_std + batched_data.pco2_mean

    # Calculate thermodynamic quantities
    R = 8.314  # Gas constant J/(mol·K)
    gamma_i = latent_vectors.mean(dim=1).view(1, -1)
    G_excess = gamma_i * R * T_actual.view(-1, 1)
    G_excess_sum = G_excess.sum(dim=1)
    comp_emb_mean = comp_emb.mean(dim=1).view(1, -1)

    # Error metrics, computed once and reused for both the report and the result.
    aco2_actual = batched_data.aco2
    abs_error = torch.abs(y_pred - aco2_actual)
    # Normalise by the magnitude of the target: dividing by a signed value
    # would report a negative "relative error" for negative targets.
    rel_error = abs_error / torch.abs(aco2_actual) * 100

    torch.set_printoptions(precision=2, sci_mode=False)
    # Print formatted output
    _banner("FORWARD PASS INPUTS:")
    print(f"Component names                     : {batched_data.component_names}")
    print(f"Component mass fractions            : {batched_data.component_mass_frac}")
    print(f"Component mole fractions            : {batched_data.component_mole_frac}")
    print(f"System temperature                  : {T_actual} Kelvin")
    print(f"Partial pressure of CO2             : {pco2_actual} kPa")

    _banner("FORWARD PASS OUTPUTS:")
    print(f"Predicted CO2 solubility            : {y_pred}")
    print(f"Actual CO2 solubility               : {aco2_actual}")
    print(f"Absolute error                      : {abs_error}")
    print(f"Relative error (%)                  : {rel_error}")

    _banner("SOLVENT SUBSYSTEM ANALYSIS:")
    print(f"Aggregated latent vector (γᵢ)       : {gamma_i}")
    print(f"Excess Gibbs Energy (Gᵢ_excess)     : {G_excess} J/mol")
    print(f"Sum of Excess Gibbs Energy          : {G_excess_sum} J/mol")
    print(f"Aggregated component embedding      : {comp_emb_mean}")
    print("="*120 + "\n")

    # Return results as dict
    results = {
        'component_names': batched_data.component_names,
        'mass_fractions': batched_data.component_mass_frac,
        'mole_fractions': batched_data.component_mole_frac,
        'temperature': T_actual,
        'pco2': pco2_actual,
        'predicted_aco2': y_pred,
        'actual_aco2': aco2_actual,
        'absolute_error': abs_error,
        'relative_error': rel_error,
        'latent_vectors': latent_vectors,
        'gamma_i': gamma_i,
        'G_excess': G_excess,
        'G_excess_sum': G_excess_sum,
        'component_embeddings': comp_emb,
        'comp_emb_mean': comp_emb_mean
    }

    return results


def analyze_sample(model, dataset, sample_idx, device='cpu'):
    """
    Analyze a single sample of a dataset, selected by position.

    The sample is wrapped in a one-element loader so that it is collated the
    same way as a regular batch (including the ``follow_batch`` bookkeeping)
    without iterating over the samples that precede it.

    Args:
        model: Model forwarded to ``analyze_batch``
        dataset: Indexable, sized collection of graph samples
        sample_idx: Zero-based index of the sample to analyze
        device: Device to run on ('cpu' or 'cuda')

    Returns:
        dict: Whatever ``analyze_batch`` returns for the one-sample batch

    Raises:
        IndexError: If ``sample_idx`` is negative or not smaller than
            ``len(dataset)``.
    """
    # Reject bad indices up front; negative indices are refused rather than
    # wrapped around, matching the position-based lookup this replaces.
    if not 0 <= sample_idx < len(dataset):
        raise IndexError(f"Sample index {sample_idx} not found in dataset of size {len(dataset)}")

    single_loader = DataLoader(
        dataset=[dataset[sample_idx]],
        batch_size=1,
        shuffle=False,
        follow_batch=['component_batch']
    )
    return analyze_batch(model, next(iter(single_loader)), device)

In [None]:
# Analyze one test sample on whichever device the model's weights already
# live on, so the inputs and the weights cannot end up on different devices.
results = analyze_sample(
    model,
    test_std,
    sample_idx=45,
    device=next(model.parameters()).device,
)