In [None]:
%load_ext autoreload
%autoreload 2
import os
# Switch the working directory to the directory two levels above the one the
# kernel started in (presumably the project root, so that the relative
# 'datasets/...' paths and the `modules` package used below resolve — confirm).
# The resolved path is cached in a global: a plain os.chdir("../..") is relative,
# so re-running this cell would climb two more directories on every execution.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
os.chdir(PROJECT_ROOT)
print(os.getcwd())
# Seed passed as `random_state` to the dataset split helpers.
seed = 21

In [None]:
# Build the canonical dataset and its graph representation, persist the
# canonical table, then create train/val/test partitions of both.
print("\n--- Loading Full Dataset ---")
from modules.amine_blend_pipeline import AmineBlendPipeline
import modules.datasplit_module as dsm
# Project-local pipeline, configured with the components table.
pipeline = AmineBlendPipeline(components_csv='datasets/components.csv')
# `canonical_data` is used as a pandas DataFrame by later cells (column access,
# iterrows); `graph_list` is indexed by position and its items are passed to
# torch_geometric utilities. Exact construction lives in the pipeline module.
canonical_data, graph_list = pipeline.run_pipeline(raw_csv='datasets/dataset.csv')
# Write the canonical table to disk via the pipeline's own saver.
pipeline.save_canonical_df(canonical_data, 'datasets/canonical_data.csv')

# Graph-level split: returns three raw partitions, three standardized
# partitions and `stats` (presumably the standardization statistics — confirm
# in datasplit_module).
train_raw, val_raw, test_raw, train_std, val_std, test_std, stats = \
    dsm.standardized_system_disjoint_split(graph_list, random_state=seed)
# Table-level split with the same random_state.
# NOTE(review): presumably this assigns systems to the same partitions as the
# graph-level split above — confirm, since later cells analyse these frames.
train_df, val_df, test_df = dsm.system_disjoint_split(canonical_data, random_state=seed)

In [None]:
from torch_geometric.loader import DataLoader
# One loader per standardized partition, all with identical settings:
# 64 graphs per batch and no shuffling (order of each split is preserved).
train_loader, val_loader, test_loader = (
    DataLoader(split_graphs, batch_size=64, shuffle=False)
    for split_graphs in (train_std, val_std, test_std)
)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

def compute_blend_size(df):
    """Return a copy of ``df`` with an added ``blend_size`` column.

    ``blend_size`` is the number of entries in each row's ``component_list``
    whose upper-cased name is neither 'H2O' nor 'WATER'. The input frame is
    left untouched.
    """
    water_names = ['H2O', 'WATER']

    def count_non_water(components):
        # Tally every component that is not one of the water spellings.
        total = 0
        for name in components:
            if name.upper() not in water_names:
                total += 1
        return total

    sizes = df['component_list'].apply(count_non_water)
    return df.assign(blend_size=sizes)
# Annotate each split with its blend size. Re-running is harmless: the column
# is simply recomputed from `component_list`.
train_df = compute_blend_size(train_df)
val_df = compute_blend_size(val_df)
test_df = compute_blend_size(test_df)

# Number of distinct canonical systems per blend size, one column per split.
# Blend sizes missing from a split show up as NaN after alignment, hence the
# fillna(0) before casting back to int.
split_frames = {'Train': train_df, 'Val': val_df, 'Test': test_df}
system_counts = pd.DataFrame({
    split_name: frame.groupby('blend_size')['canonical_system'].nunique()
    for split_name, frame in split_frames.items()
}).fillna(0).astype(int).sort_index()

# Annotated heatmap without a colorbar. The former `cbar_kws` argument was
# dropped: it has no effect when `cbar=False`, so passing both was contradictory.
fig, ax = plt.subplots(figsize=(6, 6))
sns.heatmap(system_counts, annot=True, fmt='d', cmap='YlGnBu', cbar=False, ax=ax)
ax.set_title('Blend Size Distribution Across Train/Val/Test')
ax.set_xlabel('Split')
ax.set_ylabel('Blend Size (# of amines)')
plt.setp(ax.get_yticklabels(), rotation=0)  # keep blend-size labels horizontal
plt.show()

# List the distinct systems held out for testing, ordered by blend size.
test_systems_df = test_df[['canonical_system', 'blend_size']].drop_duplicates().sort_values('blend_size')
display(test_systems_df)

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx
import numpy as np

def visualize_graph(data, title='Graph', seed=2, edge_width=1.0):
    """Draw ``data`` as an undirected graph with nodes colored per molecule.

    Parameters
    ----------
    data :
        Graph object accepted by ``to_networkx``. It must expose ``mol_batch``,
        a tensor holding one molecule id per node. If it also has
        ``component_names`` (indexable by molecule id), those names are used in
        the legend; otherwise entries read ``'Mol <id>'``.
    title : str
        Figure title.
    seed : int
        Seed for the spring layout, so node positions are reproducible.
    edge_width : float
        Line width used for edges.

    Each node is labelled with two numbers: its index in the whole graph and
    its offset from the first node that carries the same molecule id.
    """
    G = to_networkx(data, to_undirected=True)
    # .detach().cpu() is a no-op for a plain CPU tensor and avoids a failure
    # in .numpy() when the graph lives on another device.
    mol_ids = data.mol_batch.detach().cpu().numpy().tolist()

    # First node index of every molecule, gathered in one pass instead of an
    # np.where scan per node.
    first_node_of = {}
    for node_idx, mol_idx in enumerate(mol_ids):
        first_node_of.setdefault(mol_idx, node_idx)
    unique_mols = sorted(first_node_of)

    # One color per molecule, keyed by its rank among the unique ids. Nodes and
    # legend share this mapping, so they agree even when the ids are not the
    # contiguous range 0..k-1.
    color_map = plt.get_cmap('tab10')
    n_mols = len(unique_mols)
    mol_color = {
        mol_idx: color_map(rank / n_mols)
        for rank, mol_idx in enumerate(unique_mols)
    }
    node_colors = [mol_color[mol_idx] for mol_idx in mol_ids]
    node_labels = {
        node_idx: f'{node_idx}\n{node_idx - first_node_of[mol_idx]}'
        for node_idx, mol_idx in enumerate(mol_ids)
    }

    pos = nx.spring_layout(G, seed=seed, k=0.8, iterations=250)
    plt.figure(figsize=(10,8))
    # No `cmap` here: node_colors are already explicit RGBA tuples, so a
    # colormap would be ignored.
    nx.draw(G, pos, with_labels=False, node_size=400, node_color=node_colors, edge_color='gray', width=edge_width)
    nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10, font_color='black')

    # Proxy artists (marker only, white line) for the per-molecule legend.
    has_names = hasattr(data, 'component_names')
    handles = []
    for mol_idx in unique_mols:
        mol_name = data.component_names[mol_idx] if has_names else f'Mol {mol_idx}'
        handles.append(
            plt.Line2D([0], [0], marker='o', color='w', label=mol_name,
                       markersize=10, markerfacecolor=mol_color[mol_idx])
        )

    plt.legend(handles=handles, title='Molecules', loc='best')
    plt.title(title)
    plt.axis('off')
    plt.show()
# Render a single example graph. The index 4465 is hard-coded (presumably a
# blend picked for illustration — confirm) and requires `graph_list` to hold
# at least 4466 graphs.
viz_tech = graph_list[4465]
visualize_graph(
    data=viz_tech,
    title='Disjoint Union G',
    edge_width=1,
)


In [None]:
# --- COMBINATORIAL DIVERSITY + PRESENCE ANALYSIS ---
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations

def get_amines_excluding_water(components):
    """Return the entries of ``components`` that are not water, as a list.

    An entry counts as water when its upper-cased name is 'H2O' or 'WATER'.
    The order of the remaining entries is preserved.
    """
    non_water = []
    for name in components:
        if name.upper() in ['H2O', 'WATER']:
            continue
        non_water.append(name)
    return non_water

# Derive the water-free component list and its length for every datapoint.
# This adds two columns to `canonical_data` in place; re-running the cell just
# recomputes them.
canonical_data['amines_only'] = canonical_data['component_list'].apply(get_amines_excluding_water)
canonical_data['blend_size'] = canonical_data['amines_only'].apply(len)
all_amines = sorted({amine for amines in canonical_data['amines_only'] for amine in amines})
n_amines = len(all_amines)
print(f"Unique amines (excluding water): {all_amines}")
print(f"Total amines: {n_amines}")

# Enumerate every possible amine combination for blend sizes 1 to 3 (fewer if
# the dataset contains fewer than three distinct amines).
combinatorial_space = {}
for blend_size in range(1, min(4, n_amines+1)):
    combos = list(combinations(all_amines, blend_size))
    combinatorial_space[blend_size] = combos
    print(f"Blend size {blend_size}: {len(combos)} possible systems")

# Collect the distinct observed systems per blend size in a single pass over
# the data instead of rescanning the whole dataset for every blend size.
observed_systems = {}
for amines in canonical_data['amines_only']:
    observed_systems.setdefault(len(amines), set()).add(tuple(sorted(amines)))

coverage = {}
for blend_size, combos in combinatorial_space.items():
    n_observed = len(observed_systems.get(blend_size, ()))
    coverage[blend_size] = {
        'possible': len(combos),
        'observed': n_observed,
        'coverage_pct': n_observed/len(combos)*100
    }

summary_df = (
    pd.DataFrame.from_dict(coverage, orient='index')
    .reset_index()
    .rename(columns={'index':'blend_size'})
)
print("\nCombinatorial coverage summary:")
display(summary_df)

# Presence matrix: 1 where an amine occurs in at least one datapoint of the
# given blend size, 0 otherwise. `all_amines` from above is reused as the index.
blend_sizes = sorted(canonical_data['blend_size'].unique())
presence_table = pd.DataFrame(0, index=all_amines, columns=blend_sizes)

# Only the unique (amine, blend size) pairs need a write, so gather them first
# rather than assigning once per datapoint.
observed_pairs = {
    (amine, size)
    for amines, size in zip(canonical_data['amines_only'], canonical_data['blend_size'])
    for amine in amines
}
for amine, size in observed_pairs:
    presence_table.at[amine, size] = 1

# Figure height grows with the number of amines so row labels stay legible.
fig, ax = plt.subplots(figsize=(12, max(6, len(all_amines)*0.3)))
sns.heatmap(
    presence_table,
    cmap='Greens',
    linewidths=0.5,
    annot=True,
    cbar=False,
    ax=ax
)
ax.set_xlabel('Blend Size')
ax.set_ylabel('Amine')
ax.set_title('Amine Presence Across Blend Sizes')
fig.tight_layout()
plt.show()

In [None]:
# --- Portion analysis ---
# Share of all datapoints contributed by each blend size, in percent.
# The two helper columns are recomputed here, so this cell needs nothing
# beyond `canonical_data` itself.
canonical_data['amines_only'] = canonical_data['component_list'].apply(
    lambda components: list(
        filter(lambda name: name.upper() not in ['H2O', 'WATER'], components)
    )
)
canonical_data['blend_size'] = canonical_data['amines_only'].map(len)

total_datapoints = canonical_data.shape[0]
blend_counts = canonical_data['blend_size'].value_counts().sort_index()
blend_portion = (
    blend_counts
    .div(total_datapoints)
    .mul(100)
    .to_frame(name='portion_pct')
    .rename_axis('blend_size')
)

print("Blend system representation (% of total datapoints):")
display(blend_portion)