In [None]:
%load_ext autoreload
%autoreload 2
import os
# Step up from the notebook's folder to its parent so that the relative
# 'modules' / 'datasets' paths used further down resolve (presumably the
# project root -- confirm against the repository layout).
# The sentinel makes the cell idempotent: an unguarded chdir("..") climbs one
# more directory level every time the cell is re-run in the same kernel.
if "PROJECT_ROOT" not in globals():
    os.chdir("..")
    PROJECT_ROOT = os.getcwd()
print(os.getcwd())
import torch

from torch_geometric.loader import DataLoader
import numpy as np
import random

# cuBLAS workspace setting described in the PyTorch reproducibility notes;
# it is only relevant when deterministic algorithms are requested.
os.environ.update({'CUBLAS_WORKSPACE_CONFIG': ':4096:8'})

# One seed shared by every random number generator used in this notebook.
seed = 21
for seed_rng in (
    np.random.seed,
    random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_rng(seed)

# NOTE(review): determinism is switched off below and cuDNN autotuning is on,
# so GPU results may differ between runs despite the fixed seed -- confirm
# that this speed-over-reproducibility trade-off is intended.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.use_deterministic_algorithms(False)

In [None]:
from modules.data_pipeline import DataPipeline
# Build the pipeline from the component table, run it on the raw dataset and
# persist the first of its two outputs. The second output (graph_list) is the
# collection that gets split into train/val/test further down.
pipeline = DataPipeline(components_csv='datasets/components.csv')
pipeline_outputs = pipeline.run_pipeline(raw_csv='datasets/dataset.csv')
canonical_data, graph_list = pipeline_outputs
pipeline.save_canonical_df(canonical_data, 'datasets/canonical_data.csv')

In [None]:
import modules.datasplit_module as dsm
# --- Split graphs ---
# Shuffle into a *new* list with a dedicated, seeded RNG. An in-place
# random.shuffle(graph_list) reorders the pipeline output again on every re-run
# of this cell and depends on how much of the global RNG stream earlier cells
# consumed, so the resulting split would not be reproducible.
sampled_graph_list = random.Random(seed).sample(graph_list, k=len(graph_list))
train, val, test = dsm.system_disjoint_split(
    sampled_graph_list,
    random_state=seed,
    stratify_by_components=True,
)

In [None]:
import pandas as pd
# Component table; its "smiles_can" column is scanned for element symbols below.
compset = pd.read_csv(filepath_or_buffer='datasets/components.csv')

In [None]:
from rdkit import Chem 
# Parse every non-null canonical SMILES lazily; RDKit returns None for strings
# it cannot parse, and those entries are skipped.
parsed_mols = (Chem.MolFromSmiles(smiles) for smiles in compset["smiles_can"].dropna())

# Element symbol -> atomic number for every atom found in the component set.
atoms = {
    atom.GetSymbol(): atom.GetAtomicNum()
    for parsed in parsed_mols
    if parsed is not None
    for atom in parsed.GetAtoms()
}

# Symbols only, ordered by increasing atomic number.
sorted_atoms = sorted(atoms, key=atoms.get)

print(sorted_atoms)

In [None]:
# Inspect the attributes of a single (un-batched) training graph.
sample = train[0]

print("=== Node-level (Structural) ===")
for tensor_name in ("x", "edge_index", "edge_attr", "mol_batch"):
    print(f"{tensor_name}.shape:", getattr(sample, tensor_name).shape)
print("mol_batch unique:", sample.mol_batch.unique())

print("\n=== Component-level (Contextual) ===")
mole_fractions = sample.component_mole_frac
names = sample.component_names
print("component_mole_frac:", mole_fractions)
print("component_mole_frac.shape:", mole_fractions.shape)
print("component_names:", names)
print("Number of components:", len(names))
print("component_gammas.shape:", sample.component_gammas.shape)

print("\n=== Batch structure check ===")
component_batch = sample.component_batch
print("component_batch:", component_batch)
print("component_batch.shape:", component_batch.shape)
# One component_batch entry is expected per named component.
sizes_agree = component_batch.shape[0] == len(names)
print("Does component_batch match num_components?", sizes_agree)

In [None]:
# Preview the first few training systems, individually and as one batch.
# follow_batch=['component_batch'] asks the loader to also build the
# component_batch_batch / component_batch_ptr vectors for that attribute.
preview_graphs = train[:3]
loader = DataLoader(preview_graphs, batch_size=3, shuffle=False, follow_batch=['component_batch'])
batch = next(iter(loader))

print("="*70)
print("INTUITIVE BATCH STRUCTURE WITH NAMES")
print("="*70)

# Show individual systems
print("\n### INDIVIDUAL SYSTEMS (before batching) ###\n")
# Iterate over the same slice that was batched rather than a hard-coded
# range(3), so a training split with fewer than three systems does not raise
# an IndexError.
for i, g in enumerate(preview_graphs):
    print(f"System {i}: {' / '.join(g.component_names)}")
    print(f"  Type: {g.system_type}")
    print(f"  Components: {g.component_names}")
    print(f"  Mole fractions: {g.component_mole_frac.tolist()}")
    print(f"  Atoms: {g.x.shape[0]}, Molecules: {len(g.component_names)}")
    print(f"  mol_batch unique: {g.mol_batch.unique().tolist()}")
    print()

# Show the batched structure
print("\n### BATCHED STRUCTURE (what the model sees) ###\n")

# Distinct molecule ids across the whole batch; computed once and reused for
# both the count and the per-molecule loop.
molecule_ids = batch.mol_batch.unique()

print(f"Total atoms in batch: {batch.x.shape[0]}")
print(f"Total molecules in batch: {molecule_ids.shape[0]}")
print(f"Total systems in batch: {batch.batch.unique().shape[0]}")

print("\n--- MOLECULE-LEVEL VIEW ---")
for mol_id in molecule_ids:
    # Atoms that belong to this molecule
    atom_mask = batch.mol_batch == mol_id
    num_atoms = atom_mask.sum().item()

    # The system id of the molecule's first atom identifies its system
    system_id = batch.batch[atom_mask][0].item()

    print(f"Molecule {mol_id}: {num_atoms} atoms, belongs to System {system_id}")

print("\n--- COMPONENT-LEVEL VIEW (with names!) ---")
comp_ptr = batch.component_batch_ptr
# Consecutive pointer entries delimit one system's slice of the flattened
# per-component tensors.
ptr_values = comp_ptr.tolist()
for sys_idx, (start, end) in enumerate(zip(ptr_values[:-1], ptr_values[1:])):
    system_components = batch.component_names[sys_idx]

    print(f"\nSystem {sys_idx}:")
    print(f"  Components: {system_components}")
    print(f"  Component indices in batch: [{start}:{end}]")

    for comp_global_idx in range(start, end):
        # Position of the component inside its own system
        comp_local_idx = comp_global_idx - start
        comp_name = system_components[comp_local_idx]
        mole_frac = batch.component_mole_frac[comp_global_idx].item()
        print(f"    Component {comp_global_idx} (local {comp_local_idx}): {comp_name}, x={mole_frac:.4f}")

print("\n--- SYSTEM-LEVEL VIEW ---")
num_systems = int(batch.batch.max()) + 1
for sys_idx in range(num_systems):
    # Atom-level and component-level membership masks for this system
    atom_mask = batch.batch.eq(sys_idx)
    comp_mask = batch.component_batch_batch.eq(sys_idx)

    # Ids of the molecules and components that fall inside the system
    mol_ids = torch.unique(batch.mol_batch[atom_mask])
    comp_indices = comp_mask.nonzero(as_tuple=True)[0]

    system_name = ' / '.join(batch.component_names[sys_idx])
    num_atoms = int(atom_mask.sum())

    print(f"\nSystem {sys_idx}: {system_name}")
    print(f"  Atoms: {num_atoms}")
    print(f"  Molecules (global IDs): {mol_ids.tolist()}")
    print(f"  Components (global IDs): {comp_indices.tolist()}")

In [None]:
# batch_size (5) exceeds the three graphs supplied, so the loader's first
# batch holds all of them.
loader = DataLoader(train[:3], batch_size=5, shuffle=False, follow_batch=['component_batch'])
batch = next(iter(loader))

num_systems = batch.batch.max().item() + 1

print("=== Complete batch structure ===")
print(f"Batch size: {num_systems} systems")
print(f"Total atoms: {batch.x.shape[0]}")
print(f"Total components: {batch.component_batch.shape[0]}")

print("\n=== Critical mappings ===")
index_vectors = (
    ("1. batch.batch (atoms→systems):", batch.batch.shape),
    ("2. batch.mol_batch (atoms→molecules):", batch.mol_batch.shape),
    ("3. batch.component_batch_batch (components→systems):", batch.component_batch_batch.shape),
    ("4. batch.component_batch_ptr:", batch.component_batch_ptr),
)
for label, value in index_vectors:
    print(label, value)

print("\n=== Verify per-system aggregation ===")
for sys_id in range(min(3, num_systems)):
    # Select the components of this system and report their mole fractions
    comp_mask = batch.component_batch_batch == sys_id
    system_fracs = batch.component_mole_frac[comp_mask]
    print(f"\nSystem {sys_id}:")
    print(f"  Components: {comp_mask.sum().item()}")
    print(f"  Mole fracs: {system_fracs}")
    print(f"  Sum: {system_fracs.sum():.6f}")

In [None]:
# Contrast the two atom-level index vectors of a batch: both have one entry
# per atom, but one points at systems and the other at molecules.
loader = DataLoader(train[:3], batch_size=3, shuffle=False, follow_batch=['component_batch'])
batch = next(iter(loader))

print("=== Same size, different meanings ===")
for attr_name in ("batch", "mol_batch"):
    print(f"batch.{attr_name}.shape: {getattr(batch, attr_name).shape}")

print("\n=== batch.batch (atoms → systems) ===")
system_of_atom = batch.batch
print(system_of_atom)
print(f"Unique values: {system_of_atom.unique()}")
print("Meaning: Which SYSTEM does each atom belong to?")

print("\n=== batch.mol_batch (atoms → molecules) ===")
molecule_of_atom = batch.mol_batch
print(molecule_of_atom)
print(f"Unique values: {molecule_of_atom.unique()}")
print("Meaning: Which MOLECULE does each atom belong to?")

print("\n=== Key difference ===")
# Molecule ids seen among the atoms of system 0
print("System 0 has how many molecules?")
sys0_atoms = batch.batch.eq(0)
sys0_molecules = torch.unique(batch.mol_batch[sys0_atoms])
print(f"  System 0 molecules: {sys0_molecules}")

# Molecule ids seen among the atoms of system 1
print("\nSystem 1 has how many molecules?")
sys1_atoms = batch.batch.eq(1)
sys1_molecules = torch.unique(batch.mol_batch[sys1_atoms])
print(f"  System 1 molecules: {sys1_molecules}")

print("\n=== Insight ===")
print("mol_batch indices are GLOBAL (0,1,2,3,4,5,6...)")
print("batch indices are SYSTEM IDs (0,0,0..., 1,1,1..., 2,2,2...)")

In [None]:
# Batch the first few training graphs and report each graph's own edge-index
# range before looking at the batched version.
preview_graphs = train[:3]
loader = DataLoader(preview_graphs, batch_size=3, shuffle=False, follow_batch=['component_batch'])
batch = next(iter(loader))

print("=== Individual graphs ===")
# Iterate over the same slice that was batched rather than a hard-coded
# range(3), so a training split with fewer than three graphs does not raise
# an IndexError.
for i, g in enumerate(preview_graphs):
    print(f"\nGraph {i}:")
    print(f"  Atoms: {g.x.shape[0]}")
    print(f"  Edges: {g.edge_index.shape[1]}")
    print(f"  edge_index min: {g.edge_index.min()}, max: {g.edge_index.max()}")
    print(f"  Edge range: [0, {g.x.shape[0]-1}] (should match atom count)")

print("\n=== Batched structure ===")
total_atoms = batch.x.shape[0]
print(f"Total atoms: {total_atoms}")
print(f"Total edges: {batch.edge_index.shape[1]}")
print(f"edge_index min: {batch.edge_index.min()}, max: {batch.edge_index.max()}")
print(f"Edge range: [0, {total_atoms-1}] expected")

print("\n=== Are edges correctly offset? ===")
print("First 10 edges in batch:")
print(batch.edge_index[:, :10])

print("\nLast 10 edges in batch:")
print(batch.edge_index[:, -10:])

print("\n=== Check cross-system edges (should be NONE!) ===")
# Look up the system id of both endpoints of every edge in one indexing
# operation; an edge whose endpoints disagree crosses a system boundary.
edge_src_system, edge_dst_system = batch.batch[batch.edge_index]
cross_system_edges = edge_src_system.ne(edge_dst_system).sum()
print(f"Cross-system edges: {cross_system_edges} (should be 0!)")

print("\n=== Check cross-molecule edges ===")
# Same check at molecule granularity.
edge_src_mol, edge_dst_mol = batch.mol_batch[batch.edge_index]
cross_mol_edges = edge_src_mol.ne(edge_dst_mol).sum()
print(f"Cross-molecule edges: {cross_mol_edges} (should be 0!)")