# LIBRARY IMPORT

In [None]:
# IMPORTING LIBRARIES
import copy
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # or ':16:8'
from torch_geometric.loader import DataLoader
# FUNCTIONS
from data_processing import load_dataset, process_dataset
from path_helpers import get_path
from stats_compute import compute_statistics, scale_graphs
from EnhancedDataSplit import DataSplitter
from rdkit import Chem
import matplotlib.pyplot as plt
import seaborn as sns
# DIRECTORY SETUP
# `current_directory` is the process working directory; `parent_directory`
# is the folder that contains it (head of the split path).
current_directory = os.getcwd()
parent_directory = os.path.split(current_directory)[0]

# PRE-ANALYSIS

In [None]:
# LOAD & GRAPH GENERATION
# Component table and the abbreviation -> SMILES lookup derived from it.
df_components = load_dataset(
    get_path(file_name='components_set.csv', folder_name='datasets')
)
smiles_dict = dict(zip(df_components['Abbreviation'], df_components['SMILES']))
# System table that the graphs are built from.
df_systems = load_dataset(
    get_path(file_name='systems_set.csv', folder_name='datasets')
)
smiles_list = df_components["SMILES"].dropna().tolist()
mol_name_dict = dict(smiles_dict)

# GRAPH
system_graphs = process_dataset(df_systems, smiles_dict)

# LOAD DATASET
splitter = DataSplitter(system_graphs, random_state=42)
# Available strategies: stratified_random_split (train/val/test uncontaminated)
# or rarity_aware_unseen_amine_split (used here).
train_data, val_data, test_data = splitter.rarity_aware_unseen_amine_split()

# Statistics are computed on the training split only and reused for every split.
stats = compute_statistics(train_data)
(conc_mean, conc_std,
 temp_mean, temp_std,
 pco2_mean, pco2_std) = (stats[position] for position in range(6))

# Keep untouched deep copies of each split before any scaling is applied
# (presumably because scale_graphs may modify graphs in place - TODO confirm).
original_train_data, original_val_data, original_test_data = (
    copy.deepcopy(split) for split in (train_data, val_data, test_data)
)
combined_original_data = original_train_data + original_val_data + original_test_data

# Scale train, validation and test with the same training-set statistics.
scaling_stats = (conc_mean, conc_std, temp_mean, temp_std, pco2_mean, pco2_std)
train_data = scale_graphs(train_data, *scaling_stats)
val_data = scale_graphs(val_data, *scaling_stats)
test_data = scale_graphs(test_data, *scaling_stats)

# Wrap each split in a DataLoader; only the training loader shuffles.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader, test_loader = (
    DataLoader(split, batch_size=32, shuffle=False)
    for split in (val_data, test_data)
)

In [None]:
# Keep one graph per distinct `name`, taken from the unscaled copies.
unique_named_graphs = {}
for graph in combined_original_data:
    name = graph['name']
    # setdefault stores only when the key is new, so the first occurrence wins.
    unique_named_graphs.setdefault(name, graph)

# Optional: list view of the kept graphs, in first-seen order.
unique_graph_list = [*unique_named_graphs.values()]
unique_graph_list

In [None]:
# DATA ANALYTICS
# Duplicate cross-check: the helpers defined below look for component SMILES
# strings that canonicalize to the same molecule.
# The dashed line is only a visual separator in the cell output.
print("---------------------------------------------------------------------------------------")
def smiles_to_mol(smiles):
    """Parse *smiles* with RDKit's ``Chem.MolFromSmiles`` and return its result."""
    mol = Chem.MolFromSmiles(smiles)
    return mol

def standardize_smiles(smiles):
    """Return the canonical SMILES string for *smiles*.

    The input is parsed through ``smiles_to_mol``; when that yields a falsy
    value (presumably ``None`` for unparsable input - verify against RDKit),
    ``None`` is returned instead of a string.
    """
    mol = smiles_to_mol(smiles)
    if not mol:
        return None
    return Chem.MolToSmiles(mol, canonical=True)

def find_duplicate_molecules(smiles_list, canonicalize=None):
    """Return the entries of *smiles_list* that repeat an earlier molecule.

    Two entries count as the same molecule when ``canonicalize`` maps them to
    the same value. The first occurrence of each molecule is kept silently;
    every later occurrence is reported, in input order, using its original
    spelling.

    The input is walked as a list rather than being folded into a dict keyed
    by the original string: a dict would merge byte-identical entries into a
    single key, so exact repeats would never be reported.

    Args:
        smiles_list: Iterable of SMILES strings.
        canonicalize: Optional callable mapping one entry to its canonical
            form. Defaults to ``standardize_smiles``. Entries for which it
            returns a falsy value (e.g. ``None``) are skipped.

    Returns:
        List of the original strings that duplicate an earlier entry.
    """
    if canonicalize is None:
        canonicalize = standardize_smiles

    seen_canonical = set()
    duplicates = []
    for original in smiles_list:
        canonical = canonicalize(original)
        if not canonical:
            # Could not be canonicalized; cannot be compared, so ignore it.
            continue
        if canonical in seen_canonical:
            duplicates.append(original)
        else:
            seen_canonical.add(canonical)
    return duplicates

# Collect the non-missing component SMILES strings
smiles_list = df_components["SMILES"].dropna().tolist()
# Find duplicate molecules
duplicate_molecules = find_duplicate_molecules(smiles_list)
# Display results
if duplicate_molecules:
    print("Duplicate molecules found:")
    for smiles in duplicate_molecules:
        print(f"Duplicate: {smiles}")
else:
    print("No duplicate molecules found.")
print("---------------------------------------------------------------------------------------")
print(df_systems.head())  # Display the first few rows
# DataFrame.info() writes its report itself and returns None, so it must not
# be wrapped in print() (that would emit an extra "None" line).
df_systems.info()  # Check column types and missing values
print(df_systems.describe())  # Summary statistics
print("---------------------------------------------------------------------------------------")
for col in df_systems.select_dtypes(include='object').columns:
    print(f"{col}: {df_systems[col].nunique()} unique values")
    print(df_systems[col].unique()[:10])  # Show the first 10 unique values
print("---------------------------------------------------------------------------------------")
# Numerical data distribution
numerical_cols = df_systems.select_dtypes(include=['number']).columns
if len(numerical_cols) > 0:
    # squeeze=False always returns a 2-D grid, so a single column needs no
    # special case; the first (only) row holds one axes per column.
    fig, axes_grid = plt.subplots(
        1, len(numerical_cols), figsize=(5 * len(numerical_cols), 5), squeeze=False
    )
    axes = axes_grid[0]
    for ax, col in zip(axes, numerical_cols):
        sns.histplot(df_systems[col], bins=30, kde=True, ax=ax)
        ax.set_title(f"Distribution of {col}")
    plt.tight_layout()
    plt.show()
else:
    # plt.subplots(1, 0) would raise, so skip the figure entirely.
    print("No numerical columns to plot.")
print("---------------------------------------------------------------------------------------")
# NUMBER OF DATA POINTS PER AMINE SYSTEM (whole systems table, not per split)
df_systems["System"].value_counts().plot(kind="bar", figsize=(12, 5))
plt.title("Number of Data Points per Amine System")
plt.xlabel("Amine System")
plt.ylabel("Count")
plt.xticks(rotation=90)
plt.show()

In [None]:
# Save one isotherm figure per (system, concentration) pair into a
# "saved_plots" folder located one level above the current working directory.
output_dir = os.path.join(
    os.path.dirname(os.getcwd()),
    "saved_plots"
)
os.makedirs(output_dir, exist_ok=True)

# Path separators and characters rejected by common filesystems are replaced
# with "_" so that a system name such as "A/B" cannot break the output path.
unsafe_filename_chars = str.maketrans({char: "_" for char in '\\/:*?"<>|'})

# Nested single-key groupbys with sort=False visit systems, concentrations,
# references and temperatures in order of first appearance. Rows whose
# grouping value is missing are dropped by groupby, so no empty figure is
# written for a NaN concentration.
for system, system_data in df_systems.groupby('System', sort=False):
    for conc, conc_data in system_data.groupby('Conc (mol/L)', sort=False):
        # One figure per concentration of the current system
        fig, ax = plt.subplots(figsize=(12, 8))
        # One line per (reference, temperature) isotherm
        for ref, ref_data in conc_data.groupby('Reference', sort=False):
            for temp, temp_data in ref_data.groupby('T (K)', sort=False):
                # Points are joined by a line, so order them by pressure to
                # avoid zig-zags when the rows are not already sorted.
                isotherm = temp_data.sort_values('PCO2 (kPa)')
                ax.plot(isotherm['PCO2 (kPa)'], isotherm['aCO2 (mol CO2/mol amine)'],
                        label=f'Ref = {ref}, T = {temp} K', marker='o', linestyle='-', markersize=6)
        # Labels, title and legend for this figure
        ax.set_title(f'{system} AT {conc} mol/L', fontsize=16)
        ax.set_xlabel('PCO2 (kPa)', fontsize=14)
        ax.set_ylabel('aCO2 (mol CO2/mol amine)', fontsize=14)
        ax.set_xscale('log')
        ax.legend(title='Reference & Temperature', fontsize=12, bbox_to_anchor=(1.05, 1), loc='upper left')
        # Save the plot as an image file
        plot_filename = f"{system}_at_{conc}mol_L.png".translate(unsafe_filename_chars)
        plot_path = os.path.join(output_dir, plot_filename)
        fig.tight_layout()
        fig.savefig(plot_path)
        plt.close(fig)  # Close the figure to free memory
print(f"Plots saved to: {output_dir}")

In [None]:
# Select the one Data object whose node-feature matrix `x` has the fewest rows
# (one row per atom, going by the printed label).
def _atom_count(data):
    return data.x.shape[0]


smallest_graph = min(system_graphs, key=_atom_count)
print(f"Name: {smallest_graph.name}, Atom count: {_atom_count(smallest_graph)}")
smallest_graph.edge_attr

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

# Function to visualize (and optionally save) a PyTorch Geometric graph
def visualize_and_save_graph(graph, filename="graph.png", title="Molecular Graph", *, save=False):
    """Draw *graph* with NetworkX and show it; optionally write it to disk.

    Args:
        graph: PyTorch Geometric graph; it is converted with ``to_networkx``
            carrying the ``x`` node attribute and ``edge_attr`` edge attribute.
        filename: Target image path. Only used when ``save`` is true.
        title: Title placed above the drawing.
        save: When true, the figure is written to ``filename`` with a
            transparent background at 300 dpi before being shown. Defaults to
            ``False``, which only displays the figure.
    """
    # Convert PyG graph to NetworkX
    nx_graph = to_networkx(graph, node_attrs=["x"], edge_attrs=["edge_attr"])

    # Create figure with transparent background
    fig, ax = plt.subplots(figsize=(9, 5), facecolor="none")
    ax.set_facecolor("none")  # Ensure axes background is also transparent

    # Fixed seed keeps the layout reproducible between runs
    pos = nx.spring_layout(nx_graph, seed=13)

    # Draw nodes and edges; labels are added separately below to control the font.
    # NOTE(review): nx.draw may reset the figure facecolor; the
    # transparent=True flag of savefig is what guarantees transparency in the
    # saved file - confirm against the installed NetworkX version.
    nx.draw(nx_graph, pos, with_labels=False, node_size=1500,
            node_color="lightblue", edge_color="gray", ax=ax)

    # Label every node with its index, explicitly on the same axes
    labels = {node: str(node) for node in nx_graph.nodes()}
    nx.draw_networkx_labels(nx_graph, pos, labels, font_family="Times New Roman",
                            font_size=20, font_color="black", font_weight="bold", ax=ax)
    ax.set_title(title, fontsize=14, fontweight="bold", color="black")

    # Save before show(): some backends release the figure once it is shown
    if save:
        fig.savefig(filename, transparent=True, dpi=300)
    plt.show()

A = smallest_graph
# Visualize the selected graph. The title and file name are derived from the
# graph's own name, because the smallest graph is chosen dynamically and is
# not guaranteed to be MEA.
visualize_and_save_graph(A, filename=f"{A.name}_graph.png", title=f"{A.name} Molecular Graph")