In [None]:
# IMPORTING LIBRARIES
import copy
import os
import random
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # or ':16:8'
import numpy as np
import torch
from torch_geometric.nn import global_mean_pool
from torch_geometric.loader import DataLoader
# FUNCTIONS
from data_processing import load_dataset, smiles_to_graph, process_dataset
from path_helpers import get_path
from stats_compute import compute_statistics, scale_graphs
from smart_loader import load_model_for_inference
from EnhancedDataSplit import DataSplitter
# DIRECTORY SETUP
# Notebook working directory and the directory one level above it.
current_directory = os.getcwd()
parent_directory, _ = os.path.split(current_directory)

In [None]:
# HYPERPARAMETER SETTINGS
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Reproducibility settings
seed = 21
split_seed = 42
# True  -> deterministic CUDA kernels (reproducible, slower; relies on CUBLAS_WORKSPACE_CONFIG set above).
# False -> non-deterministic kernels with cuDNN autotuning (fast).
deterministic = False
# Hyperparameters
batch_size = 32
runtime = timestamp

########### LOCATE MODEL FILE ###############
model_inference_dir = os.path.join(
    os.path.dirname(os.getcwd()),
    "models",
    "models_root",
    "model_for_inference",
    "ras_pinn"
)
# Pick the first regular, non-hidden file (sorted so the choice is deterministic).
# Without the filter, entries such as ".DS_Store" or ".gitkeep" sort before any
# letter or digit and would be handed to the loader, and directories (for example
# ".ipynb_checkpoints") could be picked as well.
files = sorted(
    entry for entry in os.listdir(model_inference_dir)
    if not entry.startswith('.')
    and os.path.isfile(os.path.join(model_inference_dir, entry))
)
if not files:
    raise FileNotFoundError(f"No files found in {model_inference_dir}")
model_file = files[0]
print(f"Using model file: {model_file}")
path = os.path.join(model_inference_dir, model_file)

########### IMPORTING MODEL ###############
# Fall back to the CPU so the notebook still runs on machines without a GPU.
selected_device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(selected_device)
model = load_model_for_inference(path, device=device)

# Seed every RNG *after* loading the model. Downstream random state (for example
# the shuffled train DataLoader) then does not depend on how many random numbers
# model construction happened to consume.
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.use_deterministic_algorithms(deterministic)
torch.backends.cudnn.deterministic = deterministic
torch.backends.cudnn.benchmark = not deterministic


print('device           :', device)
print('seed             :', seed)
print('split seed       :', split_seed)

In [None]:
# LOAD & GRAPH GENERATION FOR EITHER SRS OR RAS
df_components = load_dataset(get_path(file_name='components_set.csv', folder_name='datasets'))
smiles_dict = dict(zip(df_components['Abbreviation'], df_components['SMILES']))
df_systems = load_dataset(get_path(file_name='systems_set.csv', folder_name='datasets'))
smiles_list = df_components["SMILES"].dropna().tolist()
mol_name_dict = smiles_dict.copy()

# Build the system graphs from the systems table and the abbreviation -> SMILES map.
system_graphs = process_dataset(df_systems, smiles_dict)

# Split into train / validation / test.
splitter = DataSplitter(system_graphs, random_state=split_seed)
splitter.print_dataset_stats()
# Options: rarity_aware_unseen_amine_split stratified_random_split
train_data, val_data, test_data = splitter.rarity_aware_unseen_amine_split()

# Scaling statistics are computed on the training split only.
stats = compute_statistics(train_data)
(conc_mean, conc_std,
 temp_mean, temp_std,
 pco2_mean, pco2_std) = (stats[k] for k in range(6))

# Snapshot the unscaled splits first. Deep copies mean later changes to the
# scaled objects cannot leak into these.
original_train_data, original_val_data, original_test_data = (
    copy.deepcopy(split) for split in (train_data, val_data, test_data)
)
combined_original_data = original_train_data + original_val_data + original_test_data

# Apply the training-set scaling to all three splits (train, then val, then test).
train_data, val_data, test_data = (
    scale_graphs(split, conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std)
    for split in (train_data, val_data, test_data)
)

# Wrap each split in a DataLoader; only the training batches are shuffled.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [None]:
# Mol name processing: group the scaled test / validation graphs by molecule name.
# One pass per split (linear in the number of samples). Groups appear in order of
# each name's first appearance in the split, which is reproducible across runs.
# Iterating a set of strings is not, because its order depends on hash randomisation.
molecule_groups_test = {}
for graph in test_data:
    molecule_groups_test.setdefault(graph.name, []).append(graph)
test_molecules = list(molecule_groups_test)

molecule_groups_val = {}
for graph in val_data:
    molecule_groups_val.setdefault(graph.name, []).append(graph)
val_molecules = list(molecule_groups_val)

In [None]:
# MODEL WRAPPER
class EmbeddingsExtractor(torch.nn.Module):
    """Graph-embedding view of a trained model.

    Keeps references to (not copies of) ``full_model.graph_block`` and, when
    ``full_model.use_adaptive_pooling`` is true, ``full_model.adaptive_pool``.
    The forward pass runs message passing and pooling only; the model's fully
    connected head is never called.
    """

    def __init__(self, full_model):
        super().__init__()
        self.graph_block = full_model.graph_block
        self.use_adaptive_pooling = full_model.use_adaptive_pooling
        if self.use_adaptive_pooling:
            self.adaptive_pool = full_model.adaptive_pool

    def forward(self, data, extract_embeddings=False, include_conditions=False):
        """Return pooled graph embeddings, optionally followed by the process conditions.

        Args:
            data: PyG batch exposing ``x``, ``edge_index``, ``edge_attr`` and ``batch``.
                It also needs ``conc``, ``temp`` and ``pco2`` when conditions are appended.
            extract_embeddings: if True, return only the pooled embedding. This takes
                precedence over ``include_conditions``.
            include_conditions: append ``[conc, temp, pco2]`` as three extra float columns.

        Returns:
            The pooled tensor (presumably one row per graph), possibly widened by 3 columns.
        """
        node_feats, edge_index = data.x, data.edge_index
        edge_feats, batch_index = data.edge_attr, data.batch
        node_repr = self.graph_block(node_feats, edge_index, edge_feats)

        if self.use_adaptive_pooling:
            pooled = self.adaptive_pool(node_repr, batch_index)
        else:
            pooled = global_mean_pool(node_repr, batch_index)

        if extract_embeddings or not include_conditions:
            return pooled

        conditions = torch.stack([data.conc, data.temp, data.pco2], dim=1).float()
        return torch.cat([pooled, conditions], dim=1)

class FCNNWrapper:
    """NumPy-in / NumPy-out callable around a PyTorch head (e.g. for SHAP explainers).

    Args:
        fc_block: torch module mapping a float feature matrix to predictions;
            it is switched to eval mode on construction.
        device: device the inputs are moved to before the forward pass.
    """

    def __init__(self, fc_block, device):
        self.fc_block = fc_block
        self.device = device
        self.fc_block.eval()

    def __call__(self, x_numpy):
        """Run ``fc_block`` on ``x_numpy`` without gradients and return a NumPy array.

        ``x_numpy`` may be any array-like (ndarray, nested list, DataFrame, ...).
        It is converted to float32 first. A C-contiguous copy is made when needed,
        because ``torch.from_numpy`` rejects arrays with negative strides such as
        ``arr[::-1]``.
        """
        x_array = np.asarray(x_numpy, dtype=np.float32)
        if not x_array.flags['C_CONTIGUOUS']:
            x_array = x_array.copy()
        x_tensor = torch.from_numpy(x_array).to(self.device)
        with torch.no_grad():
            y = self.fc_block(x_tensor)
        return y.cpu().numpy()
# NumPy-callable wrapper around the trained model's fully connected head. The
# (currently disabled) SHAP cell passes it to shap.KernelExplainer as the model function.
fc_wrapper = FCNNWrapper(model.fc_block, device)

In [None]:
# Filter molecular graph from combined dataset  & define embedding_generator
# Keep the first graph seen for each molecule name (one graph per molecule).
unique_named_graphs = {}
for graph in combined_original_data:
    unique_named_graphs.setdefault(graph['name'], graph)
unique_graph_list = list(unique_named_graphs.values())
mol_loader = DataLoader(unique_graph_list, batch_size=batch_size, shuffle=False)

# Pooled graph embeddings (no process conditions) for every unique molecule.
embedding_generator = EmbeddingsExtractor(model).to(device)
embedding_generator.eval()
all_embeddings, all_amines, all_data = [], [], []
with torch.no_grad():
    for data in mol_loader:
        data = data.to(device)
        combined_embedding = embedding_generator(
            data, include_conditions=False, extract_embeddings=True
        )
        all_embeddings.append(combined_embedding.cpu())
        all_amines.extend(data.name)
        # Split the batch back into per-molecule Data objects on the CPU.
        all_data.extend(data.cpu().to_data_list())

all_embeddings = torch.cat(all_embeddings)
index_to_name = dict(enumerate(all_amines))

# Assigning extracted embeddings to Data object as tensor
# Attach each molecule's pooled embedding and its SMILES to every scaled sample.
name_to_embedding = dict(zip(all_amines, all_embeddings))
for dataset in (train_data, val_data, test_data):
    for data in dataset:
        data.embedding = name_to_embedding[data.name]
        data.smiles = smiles_dict.get(data.name)  # None when the name has no SMILES entry

def extract_combined_vectors(loader, model, embedding_generator, device):
    """Stack ``[graph embedding | conc | temp | pco2]`` for every sample in ``loader``.

    Args:
        loader: iterable of batches exposing ``.to(device)``, ``conc``, ``temp`` and ``pco2``.
        model: switched to eval mode before extraction.
        embedding_generator: called as ``embedding_generator(batch, extract_embeddings=True)``.
        device: device each batch is moved to.

    Returns:
        NumPy array with one row per sample, in loader order. The last three columns
        are the conditions.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            embedding = embedding_generator(batch, extract_embeddings=True)
            conditions = torch.stack([batch.conc, batch.temp, batch.pco2], dim=1).float()
            chunks.append(torch.cat([embedding, conditions], dim=1).cpu())
    return torch.cat(chunks).numpy()
# Combined [embedding | conc | temp | pco2] matrices per split (train rows serve as background).
background_vectors = extract_combined_vectors(train_loader, model, embedding_generator, device)
val_vectors = extract_combined_vectors(val_loader, model, embedding_generator, device)
test_vectors = extract_combined_vectors(test_loader, model, embedding_generator, device)
# Every column except the trailing three conditions is an embedding dimension.
num_emb_dims = test_vectors.shape[1] - 3
feature_names = [*(f"embedding_{i}" for i in range(num_emb_dims)), 'conc', 'temp', 'pco2']

# One scaled graph per molecule name (first occurrence across train -> val -> test).
combined_data_with_embeddings = train_data + val_data + test_data
unique_named_graphs_with_embeddings = {}
for graph in combined_data_with_embeddings:
    unique_named_graphs_with_embeddings.setdefault(graph['name'], graph)
unique_graph_list_with_embeddings = list(unique_named_graphs_with_embeddings.values())

In [None]:
""" from collections import defaultdict
import numpy as np
import warnings
import shap

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*NumPy global RNG.*")

# ---- 1) Prepare background once ----
background = np.stack([
    np.concatenate([data.embedding.numpy(), 
                    np.array([data.conc, data.temp, data.pco2])])
    for data in train_data
])
background = background[np.random.choice(background.shape[0], 100, replace=False)]

explainer = shap.KernelExplainer(fc_wrapper, background)

# ---- 2) Prepare full test set ----
test_vectors = np.stack([
    np.concatenate([data.embedding.numpy(), 
                    np.array([data.conc, data.temp, data.pco2])])
    for data in val_data
])
molecule_names = [data.name for data in val_data]

# ---- 3) Compute SHAP for full test set ----
shap_values = explainer.shap_values(test_vectors)

# ---- 4) Plot summary for all test samples ----
shap.summary_plot(
    np.squeeze(shap_values),
    test_vectors,
    feature_names=feature_names,
    plot_type="dot",
    show=True,
    max_display=10
)

# ---- 5) Extract top 5 embeddings per molecule ----
embedding_size = len(train_data[0].embedding.numpy())
molecule_top_embeddings = defaultdict(dict)

for molecule_name in set(molecule_names):
    indices_mask = [i for i, name in enumerate(molecule_names) if name == molecule_name]
    molecule_vectors = test_vectors[indices_mask]
    molecule_shap = np.abs(np.squeeze(shap_values)[indices_mask, :embedding_size]).mean(axis=0)
    top5_indices = np.argsort(molecule_shap)[-5:][::-1]
    avg_embeddings = molecule_vectors[:, :embedding_size].mean(axis=0)
    top5_embeddings = avg_embeddings[top5_indices]
    
    molecule_top_embeddings[molecule_name] = {
        'indices': top5_indices,
        'values': top5_embeddings,
        'importance_scores': molecule_shap[top5_indices]
    }

    print(f"\n=== {molecule_name} ===")
    print(f"Top 5 embedding indices: {top5_indices}")
    print(f"Top 5 embedding values: {top5_embeddings}")
    print(f"Importance scores: {molecule_shap[top5_indices]}") """

In [113]:
# Hypothesis generator
from torch_geometric.data import Batch
import torch

def predict_from_smiles(model, smiles_list, conc_list, temp_list, pco2_list,
                        conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std,
                        device='cuda'):
    """
    Predict alpha_CO2 from raw SMILES + concentration/temp/pco2,
    using scaling from training set. Supports CUDA.

    Each condition is standardised as ``(value - mean) / std`` with the given
    training-set statistics before the forward pass.

    Args:
        model: trained network taking a PyG batch; set to eval mode and moved to ``device``.
        smiles_list: SMILES strings, one per prediction.
        conc_list, temp_list, pco2_list: raw conditions, same length as ``smiles_list``.
        conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std: scaling statistics.
        device: torch device (string or ``torch.device``).

    Returns:
        NumPy array with one prediction per SMILES. It is 1-D with length N for a
        single-output model, including N == 1 (the molecule axis is never squeezed
        away), and empty for empty input.

    Raises:
        ValueError: if the four input sequences differ in length. Previously ``zip``
            silently dropped the extra entries.
    """
    smiles_list, conc_list = list(smiles_list), list(conc_list)
    temp_list, pco2_list = list(temp_list), list(pco2_list)
    n_molecules = len(smiles_list)
    if not (len(conc_list) == len(temp_list) == len(pco2_list) == n_molecules):
        raise ValueError(
            "smiles_list, conc_list, temp_list and pco2_list must have the same length; "
            f"got {n_molecules}, {len(conc_list)}, {len(temp_list)}, {len(pco2_list)}"
        )

    model.eval()
    model.to(device)
    if n_molecules == 0:
        # Batch.from_data_list cannot build a batch from an empty list.
        return np.empty(0, dtype=np.float32)

    data_list = []
    for smi, conc, temp, pco2 in zip(smiles_list, conc_list, temp_list, pco2_list):
        # Convert SMILES -> PyG Data
        data = smiles_to_graph(smi, None)

        # Apply scaling
        conc_scaled = (conc - conc_mean) / conc_std
        temp_scaled = (temp - temp_mean) / temp_std
        pco2_scaled = (pco2 - pco2_mean) / pco2_std

        # Attach features as tensors on the same device
        data.conc = torch.tensor(conc_scaled, dtype=torch.float, device=device)
        data.temp = torch.tensor(temp_scaled, dtype=torch.float, device=device)
        data.pco2 = torch.tensor(pco2_scaled, dtype=torch.float, device=device)

        data_list.append(data)

    # Batch all molecules and move to device
    batch_data = Batch.from_data_list(data_list).to(device)

    with torch.no_grad():
        raw_preds = model(batch_data)

    # One row per molecule (assumes the model returns one output row per graph);
    # drop only a trailing singleton output dim. A bare .squeeze() turns a
    # single-molecule (1, 1) output into a 0-d array that callers cannot iterate.
    preds = raw_preds.reshape(n_molecules, -1)
    if preds.size(1) == 1:
        preds = preds.squeeze(1)
    return preds.cpu().numpy()
def round_significant(x, sig=3):
    """
    Round array or number x to `sig` significant digits.

    Each element is formatted with the ``g`` format at ``sig`` digits and parsed back.

    Args:
        x: number or array-like of numbers, of any shape.
        sig: number of significant digits to keep.

    Returns:
        float ndarray with the same shape as ``x`` (0-d for a scalar input).
        Previously scalars and multi-dimensional input raised, because the
        function iterated over the outermost axis only.
    """
    values = np.asarray(x, dtype=float)
    rounded = [float(f"{v:.{sig}g}") for v in values.flat]
    return np.array(rounded, dtype=float).reshape(values.shape)

# Prediction maker
# Query molecules and the process conditions to evaluate them at.
smiles_list = ["CC(CN(C)C)O", "CC(CN(C)C)O"]
conc = 2        # M
temp = 313.15   # K
pco2 = 15       # kPa

# Repeat the scalars for each SMILES
conc_list = [conc] * len(smiles_list)
temp_list = [temp] * len(smiles_list)
pco2_list = [pco2] * len(smiles_list)

preds = round_significant(
    predict_from_smiles(
        model, smiles_list, conc_list, temp_list, pco2_list,
        conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std,
        device=device,  # the device the model was loaded on, instead of a hard-coded 'cuda'
    ),
    sig=5
)
for smi, pred in zip(smiles_list, preds):
    print(40 * '-')
    print(f"SMILES -> {smi} \n Conc: {conc} M \n Temp: {temp} K \n pCO2: {pco2} kPa \n α Pred -> {pred} mol/mol")
    print(40 * '-')

----------------------------------------
SMILES -> CC(CN(C)C)O 
 Conc: 2 M 
 Temp: 313.15 K 
 pCO2: 15 kPa 
 α Pred -> 0.64809 mol/mol
----------------------------------------
----------------------------------------
SMILES -> CC(CN(C)C)O 
 Conc: 2 M 
 Temp: 313.15 K 
 pCO2: 15 kPa 
 α Pred -> 0.64809 mol/mol
----------------------------------------


# INTEGRATED GRADIENTS ATTRIBUTION (ALL)

In [None]:
# Rebuild one graph per molecule over all splits (first occurrence wins) so this
# section can be run on its own.
unique_named_graphs = {}
for graph in combined_original_data:
    unique_named_graphs.setdefault(graph['name'], graph)
unique_graph_list = list(unique_named_graphs.values())
mol_loader = DataLoader(unique_graph_list, batch_size=batch_size, shuffle=False)

# Pooled graph embeddings (no process conditions) for every unique molecule.
embedding_generator = EmbeddingsExtractor(model).to(device)
embedding_generator.eval()
all_embeddings, all_amines, all_data = [], [], []
with torch.no_grad():
    for data in mol_loader:
        data = data.to(device)
        combined_embedding = embedding_generator(
            data, include_conditions=False, extract_embeddings=True
        )
        all_embeddings.append(combined_embedding.cpu())
        all_amines.extend(data.name)
        # Split the batch back into per-molecule Data objects on the CPU.
        all_data.extend(data.cpu().to_data_list())

all_embeddings = torch.cat(all_embeddings)
index_to_name = dict(enumerate(all_amines))

In [None]:
from captum.attr import IntegratedGradients
import torch
import numpy as np
from rdkit.Chem import Draw
from rdkit.Chem.Draw import rdMolDraw2D
from matplotlib import cm, colors
import matplotlib.pyplot as plt
from PIL import Image
import io

# Eval mode for the model. This also covers the sub-modules that embedding_generator
# shares with it (EmbeddingsExtractor stores references to model.graph_block).
model.eval()

# Forward used for IG: node-level graph_block outputs (NOT pooled), all embedding dims.
def embedding_forward_all(x, data):
    """Run ``embedding_generator.graph_block`` on ``x`` with ``data``'s edges.

    Captum's IntegratedGradients evaluates all ``n_steps`` interpolations of
    ``x`` at once, concatenated along dim 0. A non-tensor additional forward
    argument such as ``data`` is passed through unchanged. Sending the stacked
    tensor through ``graph_block`` in one call would therefore give edges to the
    first copy only. Instead, ``x`` is split into chunks of ``data.x.size(0)``
    rows and each chunk runs as its own graph.

    Args:
        x: node features with ``k * num_nodes`` rows for some integer ``k >= 1``.
        data: the molecule graph; ``data.x`` fixes ``num_nodes``, and ``edge_index`` /
            ``edge_attr`` are reused for every chunk.

    Returns:
        Node embeddings, one row per row of ``x`` (chunks concatenated in order).

    Raises:
        ValueError: if ``x.size(0)`` is not a multiple of ``num_nodes``.
    """
    num_nodes = data.x.size(0)
    graph_block = embedding_generator.graph_block
    if num_nodes == 0 or x.size(0) == num_nodes:
        return graph_block(x, data.edge_index, data.edge_attr)
    if x.size(0) % num_nodes != 0:
        raise ValueError(
            f"x has {x.size(0)} rows, not a multiple of the graph's {num_nodes} nodes"
        )
    return torch.cat(
        [graph_block(chunk, data.edge_index, data.edge_attr)
         for chunk in torch.split(x, num_nodes, dim=0)],
        dim=0,
    )

molecule_node_attributions = {}

# Attribution target: the node embeddings summed over embedding dims gives one
# value per node, and Captum sums those when taking gradients. The IG object does
# not depend on the molecule, so it is built once.
ig = IntegratedGradients(lambda x, data: embedding_forward_all(x, data).sum(dim=1))

# all_data already holds one graph per unique name, so iterate it directly
# instead of re-searching the list for every name.
for graph_data in all_data:
    molecule_name = graph_data.name
    print(f"\n=== Captum Analysis for {molecule_name} (ALL embeddings) ===")

    graph_data = graph_data.to(device)
    x = graph_data.x.clone().detach().float().to(device)
    x.requires_grad = True
    baseline = torch.zeros_like(x)  # all-zero node features as the IG reference point

    attr = ig.attribute(
        inputs=x,
        baselines=baseline,
        additional_forward_args=(graph_data,),
        n_steps=50
    )  # shape: (num_nodes, num_node_features)

    # Collapse feature-level attributions into one signed score per node.
    node_signed = attr.sum(dim=1).detach().cpu().numpy()
    molecule_node_attributions[molecule_name] = {
        'signed': node_signed,
        'magnitude': np.abs(node_signed)
    }


In [None]:
from pathlib import Path
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import AllChem
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
from PIL import Image, ImageChops
import io

def trim_white(img):
    """Return ``img`` cropped to the bounding box of its non-white pixels.

    A completely white image is returned unchanged.
    """
    white_canvas = Image.new(img.mode, img.size, (255, 255, 255))
    content_box = ImageChops.difference(img, white_canvas).getbbox()
    return img.crop(content_box) if content_box else img

# === Modified IG Visualization ===
def visualize_combined_molecule(molecule_data, node_attributions, figsize=(800, 800), save_path=None, show=True):
    """Draw a molecule with atoms and bonds coloured by signed node attributions.

    Args:
        molecule_data: graph object exposing ``mol`` (RDKit molecule or None), ``name`` and ``type``.
        node_attributions: dict whose ``'signed'`` entry is a NumPy array with one value
            per node. Node ``i`` is coloured as RDKit atom ``i`` (assumes the featuriser
            keeps RDKit atom order; TODO confirm).
        figsize: (width, height) in pixels of the RDKit drawing canvas.
        save_path: optional path without extension; the figure goes to ``f"{save_path}.png"``.
            It is saved whether or not ``show`` is True.
        show: display the figure with ``plt.show()``.

    Returns:
        The trimmed PIL image of the drawing, or None if ``molecule_data.mol`` is None.
    """
    mol = molecule_data.mol
    if mol is None:
        print(f"No RDKit molecule available for {molecule_data.name}")
        return

    mol = Chem.Mol(mol)  # work on a copy so adding 2D coordinates leaves the original untouched
    if mol.GetNumConformers() == 0:
        AllChem.Compute2DCoords(mol)

    # Scale by the largest magnitude so values fall in [-1, 1] with their sign kept.
    node_attribute = node_attributions['signed']
    max_abs_val = max(abs(node_attribute.min()), abs(node_attribute.max()))
    if max_abs_val > 0:
        node_attribute = node_attribute / max_abs_val
    norm = colors.Normalize(vmin=-1, vmax=1)
    cmap = plt.colormaps.get_cmap("bwr")  # blue = negative, red = positive

    atom_colors = {i: cmap(norm(val))[:3] for i, val in enumerate(node_attribute)}

    # Bond colour = colour of the mean attribution of its two atoms.
    bond_colors = {}
    for bond in mol.GetBonds():
        idx1, idx2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        avg_val = (node_attribute[idx1] + node_attribute[idx2]) / 2
        bond_colors[bond.GetIdx()] = cmap(norm(avg_val))[:3]

    plt.rcParams['font.family'] = 'Times New Roman'

    drawer = rdMolDraw2D.MolDraw2DCairo(figsize[0], figsize[1])
    opts = drawer.drawOptions()
    opts.useBWAtomPalette()
    opts.bondLineWidth = 1.5
    opts.atomLabelFontSize = 16
    # Default highlight colour only; every atom and bond below gets an explicit colour.
    opts.highlightColour = (1.0, 1.0, 1.0)
    opts.fillHighlights = True
    opts.highlightsAreCircles = False

    drawer.DrawMolecule(
        mol,
        highlightAtoms=list(atom_colors.keys()),
        highlightAtomColors=atom_colors,
        highlightBonds=list(bond_colors.keys()),
        highlightBondColors=bond_colors
    )
    drawer.FinishDrawing()

    png = drawer.GetDrawingText()
    image = trim_white(Image.open(io.BytesIO(png)).convert("RGB"))

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # Build the figure when it is shown OR saved; previously save_path was ignored with show=False.
    if show or save_path:
        from matplotlib.ticker import MultipleLocator

        fig, ax = plt.subplots(figsize=(6, 6), dpi=150)
        ax.imshow(image)
        ax.axis('off')
        ax.set_title(f"{molecule_data.name} — Type: {molecule_data.type}", pad=25)

        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, orientation='horizontal', fraction=0.046, pad=0.04)
        cbar.set_label('Signed Attribution', fontsize=12)
        cbar.ax.xaxis.set_major_locator(MultipleLocator(0.5))  # major ticks every 0.5
        if save_path:
            fig.savefig(f"{save_path}.png", dpi=300, bbox_inches='tight')
        if show:
            plt.show()
        plt.close(fig)  # avoid piling up open figures when called once per molecule

    return image

# === Run visualization ===
# Look up each molecule's graph once (first graph per name) instead of rescanning all_data.
graph_by_name = {}
for graph in all_data:
    graph_by_name.setdefault(graph.name, graph)

for molecule_name, attributions in molecule_node_attributions.items():
    mol_data = graph_by_name[molecule_name]
    mol_type = mol_data.type
    save_path = f"IntegratedGradients_fig/all/{molecule_name}_{mol_type}"
    visualize_combined_molecule(mol_data, attributions, save_path=save_path)

# INTEGRATED GRADIENTS ATTRIBUTION (TRAIN)

In [None]:
# One graph per molecule in the (unscaled) training split; first occurrence wins.
unique_named_graphs = {}
for graph in original_train_data:
    unique_named_graphs.setdefault(graph['name'], graph)
unique_graph_list = list(unique_named_graphs.values())
mol_loader = DataLoader(unique_graph_list, batch_size=batch_size, shuffle=False)

# Pooled graph embeddings (no process conditions) for every unique molecule.
embedding_generator = EmbeddingsExtractor(model).to(device)
embedding_generator.eval()
all_embeddings, all_amines, all_data = [], [], []
with torch.no_grad():
    for data in mol_loader:
        data = data.to(device)
        combined_embedding = embedding_generator(
            data, include_conditions=False, extract_embeddings=True
        )
        all_embeddings.append(combined_embedding.cpu())
        all_amines.extend(data.name)
        # Split the batch back into per-molecule Data objects on the CPU.
        all_data.extend(data.cpu().to_data_list())

all_embeddings = torch.cat(all_embeddings)
index_to_name = dict(enumerate(all_amines))

In [None]:
from captum.attr import IntegratedGradients
import torch
import numpy as np
from rdkit.Chem import Draw
from rdkit.Chem.Draw import rdMolDraw2D
from matplotlib import cm, colors
import matplotlib.pyplot as plt
from PIL import Image
import io

# Eval mode for the model. This also covers the sub-modules that embedding_generator
# shares with it (EmbeddingsExtractor stores references to model.graph_block).
model.eval()

# Forward used for IG: node-level graph_block outputs (NOT pooled), all embedding dims.
def embedding_forward_all(x, data):
    """Run ``embedding_generator.graph_block`` on ``x`` with ``data``'s edges.

    Captum's IntegratedGradients evaluates all ``n_steps`` interpolations of
    ``x`` at once, concatenated along dim 0. A non-tensor additional forward
    argument such as ``data`` is passed through unchanged. Sending the stacked
    tensor through ``graph_block`` in one call would therefore give edges to the
    first copy only. Instead, ``x`` is split into chunks of ``data.x.size(0)``
    rows and each chunk runs as its own graph.

    Args:
        x: node features with ``k * num_nodes`` rows for some integer ``k >= 1``.
        data: the molecule graph; ``data.x`` fixes ``num_nodes``, and ``edge_index`` /
            ``edge_attr`` are reused for every chunk.

    Returns:
        Node embeddings, one row per row of ``x`` (chunks concatenated in order).

    Raises:
        ValueError: if ``x.size(0)`` is not a multiple of ``num_nodes``.
    """
    num_nodes = data.x.size(0)
    graph_block = embedding_generator.graph_block
    if num_nodes == 0 or x.size(0) == num_nodes:
        return graph_block(x, data.edge_index, data.edge_attr)
    if x.size(0) % num_nodes != 0:
        raise ValueError(
            f"x has {x.size(0)} rows, not a multiple of the graph's {num_nodes} nodes"
        )
    return torch.cat(
        [graph_block(chunk, data.edge_index, data.edge_attr)
         for chunk in torch.split(x, num_nodes, dim=0)],
        dim=0,
    )

molecule_node_attributions = {}

# Attribution target: the node embeddings summed over embedding dims gives one
# value per node, and Captum sums those when taking gradients. The IG object does
# not depend on the molecule, so it is built once.
ig = IntegratedGradients(lambda x, data: embedding_forward_all(x, data).sum(dim=1))

# all_data already holds one graph per unique name, so iterate it directly
# instead of re-searching the list for every name.
for graph_data in all_data:
    molecule_name = graph_data.name
    print(f"\n=== Captum Analysis for {molecule_name} (ALL embeddings) ===")

    graph_data = graph_data.to(device)
    x = graph_data.x.clone().detach().float().to(device)
    x.requires_grad = True
    baseline = torch.zeros_like(x)  # all-zero node features as the IG reference point

    attr = ig.attribute(
        inputs=x,
        baselines=baseline,
        additional_forward_args=(graph_data,),
        n_steps=50
    )  # shape: (num_nodes, num_node_features)

    # Collapse feature-level attributions into one signed score per node.
    node_signed = attr.sum(dim=1).detach().cpu().numpy()
    molecule_node_attributions[molecule_name] = {
        'signed': node_signed,
        'magnitude': np.abs(node_signed)
    }


In [None]:
from pathlib import Path
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import AllChem
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
from PIL import Image, ImageChops
import io

def trim_white(img):
    """Return ``img`` cropped to the bounding box of its non-white pixels.

    A completely white image is returned unchanged.
    """
    white_canvas = Image.new(img.mode, img.size, (255, 255, 255))
    content_box = ImageChops.difference(img, white_canvas).getbbox()
    return img.crop(content_box) if content_box else img

# === Modified IG Visualization ===
def visualize_combined_molecule(molecule_data, node_attributions, figsize=(800, 800), save_path=None, show=True):
    """Draw a molecule with atoms and bonds coloured by signed node attributions.

    Args:
        molecule_data: graph object exposing ``mol`` (RDKit molecule or None), ``name`` and ``type``.
        node_attributions: dict whose ``'signed'`` entry is a NumPy array with one value
            per node. Node ``i`` is coloured as RDKit atom ``i`` (assumes the featuriser
            keeps RDKit atom order; TODO confirm).
        figsize: (width, height) in pixels of the RDKit drawing canvas.
        save_path: optional path without extension; the figure goes to ``f"{save_path}.png"``.
            It is saved whether or not ``show`` is True.
        show: display the figure with ``plt.show()``.

    Returns:
        The trimmed PIL image of the drawing, or None if ``molecule_data.mol`` is None.
    """
    mol = molecule_data.mol
    if mol is None:
        print(f"No RDKit molecule available for {molecule_data.name}")
        return

    mol = Chem.Mol(mol)  # work on a copy so adding 2D coordinates leaves the original untouched
    if mol.GetNumConformers() == 0:
        AllChem.Compute2DCoords(mol)

    # Scale by the largest magnitude so values fall in [-1, 1] with their sign kept.
    node_attribute = node_attributions['signed']
    max_abs_val = max(abs(node_attribute.min()), abs(node_attribute.max()))
    if max_abs_val > 0:
        node_attribute = node_attribute / max_abs_val
    norm = colors.Normalize(vmin=-1, vmax=1)
    cmap = plt.colormaps.get_cmap("bwr")  # blue = negative, red = positive

    atom_colors = {i: cmap(norm(val))[:3] for i, val in enumerate(node_attribute)}

    # Bond colour = colour of the mean attribution of its two atoms.
    bond_colors = {}
    for bond in mol.GetBonds():
        idx1, idx2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        avg_val = (node_attribute[idx1] + node_attribute[idx2]) / 2
        bond_colors[bond.GetIdx()] = cmap(norm(avg_val))[:3]

    plt.rcParams['font.family'] = 'Times New Roman'

    drawer = rdMolDraw2D.MolDraw2DCairo(figsize[0], figsize[1])
    opts = drawer.drawOptions()
    opts.useBWAtomPalette()
    opts.bondLineWidth = 1.5
    opts.atomLabelFontSize = 16
    # Default highlight colour only; every atom and bond below gets an explicit colour.
    opts.highlightColour = (1.0, 1.0, 1.0)
    opts.fillHighlights = True
    opts.highlightsAreCircles = False

    drawer.DrawMolecule(
        mol,
        highlightAtoms=list(atom_colors.keys()),
        highlightAtomColors=atom_colors,
        highlightBonds=list(bond_colors.keys()),
        highlightBondColors=bond_colors
    )
    drawer.FinishDrawing()

    png = drawer.GetDrawingText()
    image = trim_white(Image.open(io.BytesIO(png)).convert("RGB"))

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # Build the figure when it is shown OR saved; previously save_path was ignored with show=False.
    if show or save_path:
        from matplotlib.ticker import MultipleLocator

        fig, ax = plt.subplots(figsize=(6, 6), dpi=150)
        ax.imshow(image)
        ax.axis('off')
        ax.set_title(f"{molecule_data.name} — Type: {molecule_data.type}", pad=25)

        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, orientation='horizontal', fraction=0.046, pad=0.04)
        cbar.set_label('Signed Attribution', fontsize=12)
        cbar.ax.xaxis.set_major_locator(MultipleLocator(0.5))  # major ticks every 0.5
        if save_path:
            fig.savefig(f"{save_path}.png", dpi=300, bbox_inches='tight')
        if show:
            plt.show()
        plt.close(fig)  # avoid piling up open figures when called once per molecule

    return image

# === Run visualization ===
# Look up each molecule's graph once (first graph per name) instead of rescanning all_data.
graph_by_name = {}
for graph in all_data:
    graph_by_name.setdefault(graph.name, graph)

for molecule_name, attributions in molecule_node_attributions.items():
    mol_data = graph_by_name[molecule_name]
    mol_type = mol_data.type
    save_path = f"IntegratedGradients_fig/train/{molecule_name}_{mol_type}"
    visualize_combined_molecule(mol_data, attributions, save_path=save_path)

# INTEGRATED GRADIENTS ATTRIBUTION (VAL)

In [None]:
# One graph per molecule in the (unscaled) validation split; first occurrence wins.
unique_named_graphs = {}
for graph in original_val_data:
    unique_named_graphs.setdefault(graph['name'], graph)
unique_graph_list = list(unique_named_graphs.values())
mol_loader = DataLoader(unique_graph_list, batch_size=batch_size, shuffle=False)

# Pooled graph embeddings (no process conditions) for every unique molecule.
embedding_generator = EmbeddingsExtractor(model).to(device)
embedding_generator.eval()
all_embeddings, all_amines, all_data = [], [], []
with torch.no_grad():
    for data in mol_loader:
        data = data.to(device)
        combined_embedding = embedding_generator(
            data, include_conditions=False, extract_embeddings=True
        )
        all_embeddings.append(combined_embedding.cpu())
        all_amines.extend(data.name)
        # Split the batch back into per-molecule Data objects on the CPU.
        all_data.extend(data.cpu().to_data_list())

all_embeddings = torch.cat(all_embeddings)
index_to_name = dict(enumerate(all_amines))

In [None]:
from captum.attr import IntegratedGradients
import torch
import numpy as np
from rdkit.Chem import Draw
from rdkit.Chem.Draw import rdMolDraw2D
from matplotlib import cm, colors
import matplotlib.pyplot as plt
from PIL import Image
import io

# Eval mode for the model. This also covers the sub-modules that embedding_generator
# shares with it (EmbeddingsExtractor stores references to model.graph_block).
model.eval()

# Forward used for IG: node-level graph_block outputs (NOT pooled), all embedding dims.
def embedding_forward_all(x, data):
    """Run ``embedding_generator.graph_block`` on ``x`` with ``data``'s edges.

    Captum's IntegratedGradients evaluates all ``n_steps`` interpolations of
    ``x`` at once, concatenated along dim 0. A non-tensor additional forward
    argument such as ``data`` is passed through unchanged. Sending the stacked
    tensor through ``graph_block`` in one call would therefore give edges to the
    first copy only. Instead, ``x`` is split into chunks of ``data.x.size(0)``
    rows and each chunk runs as its own graph.

    Args:
        x: node features with ``k * num_nodes`` rows for some integer ``k >= 1``.
        data: the molecule graph; ``data.x`` fixes ``num_nodes``, and ``edge_index`` /
            ``edge_attr`` are reused for every chunk.

    Returns:
        Node embeddings, one row per row of ``x`` (chunks concatenated in order).

    Raises:
        ValueError: if ``x.size(0)`` is not a multiple of ``num_nodes``.
    """
    num_nodes = data.x.size(0)
    graph_block = embedding_generator.graph_block
    if num_nodes == 0 or x.size(0) == num_nodes:
        return graph_block(x, data.edge_index, data.edge_attr)
    if x.size(0) % num_nodes != 0:
        raise ValueError(
            f"x has {x.size(0)} rows, not a multiple of the graph's {num_nodes} nodes"
        )
    return torch.cat(
        [graph_block(chunk, data.edge_index, data.edge_attr)
         for chunk in torch.split(x, num_nodes, dim=0)],
        dim=0,
    )

molecule_node_attributions = {}

# Attribution target: the node embeddings summed over embedding dims gives one
# value per node, and Captum sums those when taking gradients. The IG object does
# not depend on the molecule, so it is built once.
ig = IntegratedGradients(lambda x, data: embedding_forward_all(x, data).sum(dim=1))

# all_data already holds one graph per unique name, so iterate it directly
# instead of re-searching the list for every name.
for graph_data in all_data:
    molecule_name = graph_data.name
    print(f"\n=== Captum Analysis for {molecule_name} (ALL embeddings) ===")

    graph_data = graph_data.to(device)
    x = graph_data.x.clone().detach().float().to(device)
    x.requires_grad = True
    baseline = torch.zeros_like(x)  # all-zero node features as the IG reference point

    attr = ig.attribute(
        inputs=x,
        baselines=baseline,
        additional_forward_args=(graph_data,),
        n_steps=50
    )  # shape: (num_nodes, num_node_features)

    # Collapse feature-level attributions into one signed score per node.
    node_signed = attr.sum(dim=1).detach().cpu().numpy()
    molecule_node_attributions[molecule_name] = {
        'signed': node_signed,
        'magnitude': np.abs(node_signed)
    }


In [None]:
from pathlib import Path
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import AllChem
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
from PIL import Image, ImageChops
import io

def trim_white(img):
    """Return ``img`` cropped to the bounding box of its non-white pixels.

    A completely white image is returned unchanged.
    """
    white_canvas = Image.new(img.mode, img.size, (255, 255, 255))
    content_box = ImageChops.difference(img, white_canvas).getbbox()
    return img.crop(content_box) if content_box else img

# === Modified IG Visualization ===
def visualize_combined_molecule(molecule_data, node_attributions, figsize=(800, 800), save_path=None, show=True):
    """Draw a molecule with atoms and bonds coloured by signed node attributions.

    Args:
        molecule_data: graph object exposing ``mol`` (RDKit molecule or None), ``name`` and ``type``.
        node_attributions: dict whose ``'signed'`` entry is a NumPy array with one value
            per node. Node ``i`` is coloured as RDKit atom ``i`` (assumes the featuriser
            keeps RDKit atom order; TODO confirm).
        figsize: (width, height) in pixels of the RDKit drawing canvas.
        save_path: optional path without extension; the figure goes to ``f"{save_path}.png"``.
            It is saved whether or not ``show`` is True.
        show: display the figure with ``plt.show()``.

    Returns:
        The trimmed PIL image of the drawing, or None if ``molecule_data.mol`` is None.
    """
    mol = molecule_data.mol
    if mol is None:
        print(f"No RDKit molecule available for {molecule_data.name}")
        return

    mol = Chem.Mol(mol)  # work on a copy so adding 2D coordinates leaves the original untouched
    if mol.GetNumConformers() == 0:
        AllChem.Compute2DCoords(mol)

    # Scale by the largest magnitude so values fall in [-1, 1] with their sign kept.
    node_attribute = node_attributions['signed']
    max_abs_val = max(abs(node_attribute.min()), abs(node_attribute.max()))
    if max_abs_val > 0:
        node_attribute = node_attribute / max_abs_val
    norm = colors.Normalize(vmin=-1, vmax=1)
    cmap = plt.colormaps.get_cmap("bwr")  # blue = negative, red = positive

    atom_colors = {i: cmap(norm(val))[:3] for i, val in enumerate(node_attribute)}

    # Bond colour = colour of the mean attribution of its two atoms.
    bond_colors = {}
    for bond in mol.GetBonds():
        idx1, idx2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        avg_val = (node_attribute[idx1] + node_attribute[idx2]) / 2
        bond_colors[bond.GetIdx()] = cmap(norm(avg_val))[:3]

    plt.rcParams['font.family'] = 'Times New Roman'

    drawer = rdMolDraw2D.MolDraw2DCairo(figsize[0], figsize[1])
    opts = drawer.drawOptions()
    opts.useBWAtomPalette()
    opts.bondLineWidth = 1.5
    opts.atomLabelFontSize = 16
    # Default highlight colour only; every atom and bond below gets an explicit colour.
    opts.highlightColour = (1.0, 1.0, 1.0)
    opts.fillHighlights = True
    opts.highlightsAreCircles = False

    drawer.DrawMolecule(
        mol,
        highlightAtoms=list(atom_colors.keys()),
        highlightAtomColors=atom_colors,
        highlightBonds=list(bond_colors.keys()),
        highlightBondColors=bond_colors
    )
    drawer.FinishDrawing()

    png = drawer.GetDrawingText()
    image = trim_white(Image.open(io.BytesIO(png)).convert("RGB"))

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # Build the figure when it is shown OR saved; previously save_path was ignored with show=False.
    if show or save_path:
        from matplotlib.ticker import MultipleLocator

        fig, ax = plt.subplots(figsize=(6, 6), dpi=150)
        ax.imshow(image)
        ax.axis('off')
        ax.set_title(f"{molecule_data.name} — Type: {molecule_data.type}", pad=25)

        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, orientation='horizontal', fraction=0.046, pad=0.04)
        cbar.set_label('Signed Attribution', fontsize=12)
        cbar.ax.xaxis.set_major_locator(MultipleLocator(0.5))  # major ticks every 0.5
        if save_path:
            fig.savefig(f"{save_path}.png", dpi=300, bbox_inches='tight')
        if show:
            plt.show()
        plt.close(fig)  # avoid piling up open figures when called once per molecule

    return image

# === Run visualization ===
# Look up each molecule's graph once (first graph per name) instead of rescanning all_data.
graph_by_name = {}
for graph in all_data:
    graph_by_name.setdefault(graph.name, graph)

for molecule_name, attributions in molecule_node_attributions.items():
    mol_data = graph_by_name[molecule_name]
    mol_type = mol_data.type
    save_path = f"IntegratedGradients_fig/val/{molecule_name}_{mol_type}"
    visualize_combined_molecule(mol_data, attributions, save_path=save_path)

# INTEGRATED GRADIENTS ATTRIBUTION (TEST)

In [None]:
# One graph per molecule in the (unscaled) test split; first occurrence wins.
unique_named_graphs = {}
for graph in original_test_data:
    unique_named_graphs.setdefault(graph['name'], graph)
unique_graph_list = list(unique_named_graphs.values())
mol_loader = DataLoader(unique_graph_list, batch_size=batch_size, shuffle=False)

# Pooled graph embeddings (no process conditions) for every unique molecule.
embedding_generator = EmbeddingsExtractor(model).to(device)
embedding_generator.eval()
all_embeddings, all_amines, all_data = [], [], []
with torch.no_grad():
    for data in mol_loader:
        data = data.to(device)
        combined_embedding = embedding_generator(
            data, include_conditions=False, extract_embeddings=True
        )
        all_embeddings.append(combined_embedding.cpu())
        all_amines.extend(data.name)
        # Split the batch back into per-molecule Data objects on the CPU.
        all_data.extend(data.cpu().to_data_list())

all_embeddings = torch.cat(all_embeddings)
index_to_name = dict(enumerate(all_amines))

In [None]:
from captum.attr import IntegratedGradients
import torch
import numpy as np

# Switch the model to inference mode before attribution, so layers such as
# dropout / batch-norm (if the model has any) behave as they do at prediction
# time during every Integrated Gradients forward pass.
model.eval()

def model_output_with_node_input(x_nodes, data):
    """Forward wrapper for Captum's IntegratedGradients.

    Substitutes ``x_nodes`` for the graph's node features and runs the
    module-level ``model`` on the result.

    Parameters
    ----------
    x_nodes : torch.Tensor
        Node-feature matrix that replaces ``data.x`` (presumably
        ``(num_nodes, node_dim)``; Captum passes interpolated copies of it).
    data : torch_geometric Data-like object
        Single-molecule graph. It must provide ``clone()`` and the scalar
        condition attributes ``conc``, ``temp`` and ``pco2`` (Python numbers
        or tensors).

    Returns
    -------
    torch.Tensor
        Model output with at least two dimensions: a 0-d output becomes
        ``[1, 1]`` and a 1-d output gains a leading batch dimension.
    """
    # Work on a copy so Captum's repeated calls never mutate the caller's graph.
    data_modified = data.clone()
    data_modified.x = x_nodes

    # Put the scalar conditions on x_nodes' device as float tensors.
    # torch.as_tensor accepts both Python numbers and tensors; torch.tensor()
    # on an existing tensor copy-constructs and warns on every IG step.
    # Reading from the clone guarantees nothing aliases the caller's `data`.
    target_device = x_nodes.device
    for key in ("conc", "temp", "pco2"):
        value = getattr(data_modified, key)
        setattr(data_modified, key,
                torch.as_tensor(value, dtype=torch.float, device=target_device))

    out = model(data_modified)

    # Captum expects (batch x output); promote 0-d / 1-d outputs accordingly.
    if out.dim() == 0:
        out = out.reshape(1, 1)
    elif out.dim() == 1:
        out = out.unsqueeze(0)

    return out

# Dictionary to store node-level attributions, keyed by molecule name.
molecule_node_attributions = {}

IG_STEPS = 50  # number of interpolation steps from the baseline to the input
# A single explainer suffices: it only wraps the forward function.
ig = IntegratedGradients(model_output_with_node_input)

for data in all_data:
    molecule_name = data.name
    print(f"\n=== Captum Analysis for {molecule_name} (Final output) ===")

    data = data.to(device)
    x = data.x.clone().detach().float().to(device)

    # Zero baseline for node features.
    baseline = torch.zeros_like(x)

    # Captum stacks all interpolation steps along dim 0 before calling the
    # forward function, but `data` (edge_index, ...) is not a tensor and is
    # NOT replicated. One forward over n_steps * num_nodes rows would therefore
    # run the model on a single malformed graph. internal_batch_size=num_nodes
    # makes every forward call see exactly one interpolated copy of the nodes.
    attr = ig.attribute(
        inputs=x,
        baselines=baseline,
        additional_forward_args=(data,),
        n_steps=IG_STEPS,
        internal_batch_size=x.shape[0],
    )  # shape: (num_nodes, num_node_features)

    # Completeness sanity check: IG attributions should sum (approximately) to
    # F(x) - F(baseline). A large delta means too few steps or a broken wrapper.
    with torch.no_grad():
        output_gap = (
            model_output_with_node_input(x, data)
            - model_output_with_node_input(baseline, data)
        ).sum().item()
    convergence_delta = attr.sum().item() - output_gap

    # Aggregate feature-level attributions per node.
    node_signed = attr.sum(dim=1).detach().cpu().numpy()      # signed attributions
    node_magnitude = np.abs(node_signed)                     # magnitude

    molecule_node_attributions[molecule_name] = {
        'signed': node_signed,
        'magnitude': node_magnitude
    }

    print(f"Node attributions (signed): {node_signed}")
    print(f"IG completeness delta: {convergence_delta:.4e}")


In [None]:
from pathlib import Path
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import AllChem
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
from PIL import Image, ImageChops
import io

def trim_white(img):
    """Return ``img`` cropped to the bounding box of its non-white pixels.

    The image is compared against a solid white canvas of the same mode and
    size. If nothing differs from white (empty bounding box), the image is
    returned unchanged.
    """
    white_canvas = Image.new(img.mode, img.size, (255, 255, 255))
    content_box = ImageChops.difference(img, white_canvas).getbbox()
    return img.crop(content_box) if content_box else img

# === Modified IG Visualization ===
def _normalize_signed(values):
    """Scale signed attributions into [-1, 1] by their largest absolute value.

    All-zero and empty inputs are returned unchanged. This avoids division by
    zero and calling ``min()``/``max()`` on an empty array.
    """
    values = np.asarray(values)
    max_abs = np.abs(values).max(initial=0)
    if max_abs > 0:
        return values / max_abs
    return values


def _attribution_colors(mol, node_attribute, cmap, norm):
    """Return ``(atom_colors, bond_colors)`` RGB dicts for RDKit highlighting.

    Atom ``i`` is coloured by ``node_attribute[i]``. Each bond is coloured by
    the mean attribution of its two end atoms. Bonds touching an atom without
    an attribution are left uncoloured instead of raising ``IndexError``.
    """
    atom_colors = {i: cmap(norm(val))[:3] for i, val in enumerate(node_attribute)}

    n_attributed = len(node_attribute)
    bond_colors = {}
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        if begin >= n_attributed or end >= n_attributed:
            continue
        avg_val = (node_attribute[begin] + node_attribute[end]) / 2
        bond_colors[bond.GetIdx()] = cmap(norm(avg_val))[:3]
    return atom_colors, bond_colors


# === Modified IG Visualization ===
def visualize_combined_molecule(molecule_data, node_attributions, figsize=(800, 800), save_path=None, show=True):
    """Draw a molecule with atoms/bonds coloured by signed IG attributions.

    NOTE(review): an earlier (validation) cell appears to define a function
    with this same name; once this cell runs, this definition shadows it.

    Parameters
    ----------
    molecule_data : graph object
        Must expose ``mol`` (an RDKit molecule or None), ``name`` and ``type``.
    node_attributions : dict
        Must contain ``'signed'``, with one signed attribution per graph node.
        Node ``i`` is assumed to be RDKit atom ``i`` — TODO confirm against
        the graph featurizer.
    figsize : tuple of int
        Pixel size of the RDKit canvas.
    save_path : str or None
        If given, the figure (molecule + colour bar) is written to
        ``f"{save_path}.png"`` whether or not it is shown.
    show : bool
        Display the figure with ``plt.show()``.

    Returns
    -------
    PIL.Image.Image or None
        The trimmed RGB molecule drawing, or None if no RDKit molecule exists.
    """
    import warnings

    mol = molecule_data.mol
    if mol is None:
        print(f"No RDKit molecule available for {molecule_data.name}")
        return

    # Copy so that generating 2-D coordinates never touches the stored molecule.
    mol = Chem.Mol(mol)
    if mol.GetNumConformers() == 0:
        AllChem.Compute2DCoords(mol)

    # Normalize the IG results to [-1, 1] (over all nodes) after computation.
    node_attribute = _normalize_signed(node_attributions['signed'])
    num_atoms = mol.GetNumAtoms()
    if len(node_attribute) != num_atoms:
        warnings.warn(
            f"{molecule_data.name}: {len(node_attribute)} node attributions "
            f"for {num_atoms} atoms; only matching indices are coloured."
        )

    norm = colors.Normalize(vmin=-1, vmax=1)
    cmap = plt.colormaps.get_cmap("bwr")  # blue = negative, red = positive
    # Slice so RDKit never receives atom indices that do not exist in `mol`.
    atom_colors, bond_colors = _attribution_colors(mol, node_attribute[:num_atoms], cmap, norm)

    # Draw the molecule with the attribution highlights.
    drawer = rdMolDraw2D.MolDraw2DCairo(figsize[0], figsize[1])
    opts = drawer.drawOptions()
    opts.useBWAtomPalette()
    opts.bondLineWidth = 1.5
    opts.atomLabelFontSize = 16
    # Default highlight colour (white); the per-atom / per-bond colours passed
    # to DrawMolecule below take precedence over it.
    opts.highlightColour = (1.0, 1.0, 1.0)
    opts.fillHighlights = True
    opts.highlightsAreCircles = False

    drawer.DrawMolecule(
        mol,
        highlightAtoms=list(atom_colors.keys()),
        highlightAtomColors=atom_colors,
        highlightBonds=list(bond_colors.keys()),
        highlightBondColors=bond_colors
    )
    drawer.FinishDrawing()

    image = Image.open(io.BytesIO(drawer.GetDrawingText())).convert("RGB")
    image = trim_white(image)

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # Build the figure whenever it is either shown or saved; previously
    # save_path was silently ignored when show=False.
    if show or save_path:
        from matplotlib.ticker import MultipleLocator

        # Scope the font change to this figure instead of mutating global rcParams.
        with plt.rc_context({'font.family': 'Times New Roman'}):
            fig, ax = plt.subplots(figsize=(6, 6), dpi=150)
            ax.imshow(image)
            ax.axis('off')
            ax.set_title(f"{molecule_data.name} — Type: {molecule_data.type}", pad=25)

            # Colour bar for the shared [-1, 1] attribution scale.
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            sm.set_array([])
            cbar = fig.colorbar(sm, ax=ax, orientation='horizontal', fraction=0.046, pad=0.04)
            cbar.set_label('Signed Attribution', fontsize=12)
            # Major ticks every 0.5.
            cbar.ax.xaxis.set_major_locator(MultipleLocator(0.5))

            if save_path:
                fig.savefig(f"{save_path}.png", dpi=300, bbox_inches='tight')
            if show:
                plt.show()
        # Release the figure; otherwise figures accumulate when called in a loop.
        plt.close(fig)

    return image

# === Run visualization ===
# Index the graphs by molecule name once (first occurrence wins, exactly like the
# previous `next(...)` scan) instead of rescanning `all_data` for every molecule.
data_by_name = {}
for graph_data in all_data:
    data_by_name.setdefault(graph_data.name, graph_data)

for molecule_name, attributions in molecule_node_attributions.items():
    # A KeyError here names the molecule that has attributions but no graph.
    mol_data = data_by_name[molecule_name]
    mol_type = mol_data.type
    save_path = f"IntegratedGradients_fig/test/{molecule_name}_{mol_type}"

    visualize_combined_molecule(
        mol_data,
        attributions,
        save_path=save_path
    )

# BACKUP: archived embedding-correlation analysis (wrapped in strings, not executed)

In [None]:
# ARCHIVED (not executed): this whole cell is a bare string literal, so none of
# the code inside it runs. If re-enabled, it would stack the per-molecule
# embeddings of `unique_graph_list_with_embeddings` (defined elsewhere, if at
# all — TODO confirm it still exists), plot the correlation matrix between
# embedding dimensions, and t-test the off-diagonal correlations against 0.
# NOTE(review): the off-diagonal entries of one correlation matrix are not
# independent samples, so the one-sample t-test p-value is not a valid
# significance measure for orthogonality.
""" import numpy as np
if not hasattr(np, "bool"):
    np.bool = np.bool_
import matplotlib.pyplot as plt
import seaborn as sns

# Step 1: Stack all embeddings into a 2D array (num_molecules x 64)
embeddings = np.stack([d.embedding if isinstance(d.embedding, np.ndarray) else d.embedding.numpy()
                       for d in unique_graph_list_with_embeddings])

# Step 2: Compute correlation matrix (64 x 64)
corr_matrix = np.corrcoef(embeddings, rowvar=False)

# Step 3: Plot heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(corr_matrix, cmap='coolwarm', center=0, square=True)
plt.xlabel("Embedding Dimension")
plt.ylabel("Embedding Dimension")
plt.show()
# Extract off-diagonal correlations
off_diag_corrs = corr_matrix[np.triu_indices_from(corr_matrix, k=1)]

# Test against null hypothesis of orthogonality (mean ≈ 0)
from scipy import stats
t_stat, p_value = stats.ttest_1samp(off_diag_corrs, 0)

# Also check distribution statistics
print(f"Mean absolute correlation: {np.mean(np.abs(off_diag_corrs)):.3f}")
print(f"Std of correlations: {np.std(off_diag_corrs):.3f}") """

In [None]:
# ARCHIVED (not executed): bare string literal. If re-enabled, it needs
# `off_diag_corrs`, which is only defined inside the previous string-wrapped
# (and therefore also never executed) cell.
# NOTE(review): np.abs(off_diag_corrs) is non-negative, so a one-sample t-test
# of it against 0 rejects for essentially any non-zero correlations. It does
# not meaningfully test (non-)orthogonality.
""" # Quick significance test you can run:
from scipy import stats

# Test if mean significantly different from orthogonal (0)
t_stat, p_value = stats.ttest_1samp(np.abs(off_diag_corrs), 0)
print(f"P-value for non-orthogonality: {p_value}")

# Also check what fraction are "strong" correlations
strong_corr_fraction = np.mean(np.abs(off_diag_corrs) > 0.3)
print(f"Fraction with |r| > 0.3: {strong_corr_fraction:.2%}") """