In [None]:
%load_ext autoreload
%autoreload 2
import os
os.chdir("..")
print(os.getcwd())
import torch

from torch_geometric.loader import DataLoader
import numpy as np
import random

os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 
seed = 21
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

In [None]:
from modules.data_pipeline import DataPipeline

# Input tables for the pipeline (paths relative to the current working directory).
COMPONENTS_CSV = 'datasets/components.csv'
RAW_CSV = 'datasets/dataset.csv'

# Configure the pipeline with the component table, then run it over the raw dataset.
# It returns the canonical data and the list of graphs that the split cell consumes.
pipeline = DataPipeline(components_csv=COMPONENTS_CSV)
canonical_data, graph_list = pipeline.run_pipeline(raw_csv=RAW_CSV)

In [None]:
import random

import modules.datasplit_module as dsm

# --- Split graphs ---
# Shuffle a *copy* with a dedicated, freshly seeded RNG. The original in-place
# `random.shuffle(graph_list)` mutated the pipeline output and advanced the global RNG,
# so re-running this cell produced a different order (and hence a different split).
sampled_graph_list = list(graph_list)
random.Random(seed).shuffle(sampled_graph_list)

train, val, test = dsm.system_disjoint_split(
    sampled_graph_list, random_state=seed, stratify_by_components=True
)

In [None]:
from torch_geometric.loader import DataLoader

# Debug-sized runs: only the first MAX_GRAPHS_PER_SPLIT graphs of each split are used.
# Set to None to train/evaluate on the full splits (`lst[:None]` is the whole list).
MAX_GRAPHS_PER_SPLIT = 1000
BATCH_SIZE = 512


def make_loader(split, shuffle, max_graphs=MAX_GRAPHS_PER_SPLIT, batch_size=BATCH_SIZE):
    """Build a PyG DataLoader over (a prefix of) one data split.

    Args:
        split: list of graph objects (one of train/val/test).
        shuffle: whether to reshuffle the graphs every epoch.
        max_graphs: keep only the first ``max_graphs`` graphs; None keeps all.
        batch_size: number of graphs per mini-batch.

    ``follow_batch=['component_batch']`` makes PyG add a ``component_batch_batch``
    attribute to every mini-batch (the batch-assignment vector for ``component_batch``).
    """
    return DataLoader(
        dataset=split[:max_graphs],
        batch_size=batch_size,
        shuffle=shuffle,
        follow_batch=['component_batch'],
    )


train_loader = make_loader(train, shuffle=True)  # only the training data is shuffled
val_loader = make_loader(val, shuffle=False)
test_loader = make_loader(test, shuffle=False)



In [None]:
from modules.dtmpnn import DTMPNN

# Smoke test: the first batch of a loader over train[:10] with batch_size=5 holds 5 mixtures.
test_batch = DataLoader(train[:10], batch_size=5, shuffle=False, follow_batch=['component_batch'])
batch = next(iter(test_batch))

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DTMPNN(
    node_dim=batch.x.shape[1],
    edge_dim=batch.edge_attr.shape[1],
    graph_hidden_dim=5,
    latent_dim=5,
    context_dim=5,
    graph_layers=2,
    track_grad=True,
    constraint_type='hard'
).to(device)

batch = batch.to(device)

print("=== INPUT DATA STRUCTURE ===")
print(f"batch.x: {batch.x.shape}")
print(f"batch.edge_index: {batch.edge_index.shape}")
print(f"batch.edge_attr: {batch.edge_attr.shape}")
print(f"batch.mol_batch: {batch.mol_batch.shape}, unique: {batch.mol_batch.unique()}")
print(f"batch.component_batch: {batch.component_batch.shape}")
print(f"batch.component_batch_batch: {batch.component_batch_batch.shape}")
print(f"batch.component_mole_frac: {batch.component_mole_frac.shape}")

print("\n=== FORWARD PASS ===")
# NOTE(review): the saved output shows a raw tensor printed during the forward pass,
# which looks like a leftover debug print inside DTMPNN.
try:
    gamma_calc, latent_vectors, comp_emb = model(batch)
except Exception as e:
    print(f"[ERROR] Forward pass failed!")
    print(f"Error: {e}")
    # Re-raise: swallowing the error only defers it to a NameError further down.
    raise
print("[OK] Forward pass successful!")
print(f"gamma_calc shape: {gamma_calc.shape}")
print(f"latent_vectors shape: {latent_vectors.shape}")
print(f"comp_emb shape: {comp_emb.shape}")
print(f"gradient tracking: {comp_emb.requires_grad}")

print("\n=== EXPECTED vs ACTUAL ===")
# NOTE(review): in the saved run this count is 12 for a 5-mixture batch, i.e. it equals
# the number of component entries, so the "systems" label may be a misnomer — confirm.
print(f"Expected predictions: {batch.component_gammas.shape[0]} systems")
print(f"Actual predictions:   {gamma_calc.shape[0]}")
print(f"Expected latent_vectors: {batch.component_batch.shape[0]} components")
print(f"Actual latent_vectors:   {latent_vectors.shape[0]}")

print("\n=== GRADIENT TRACKING ===")
print("mole_frac.requires_grad =", batch.component_mole_frac.requires_grad)

=== INPUT DATA STRUCTURE ===
batch.x: torch.Size([146, 18])
batch.edge_index: torch.Size([2, 274])
batch.edge_attr: torch.Size([274, 4])
batch.mol_batch: torch.Size([146]), unique: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11], device='cuda:0')
batch.component_batch: torch.Size([12])
batch.component_batch_batch: torch.Size([12])
batch.component_mole_frac: torch.Size([12])

=== FORWARD PASS ===
tensor([0.6084, 0.6077, 0.6123, 0.6148, 0.6072, 0.6154, 0.6060, 0.6100, 0.6096,
        0.6095, 0.6069, 0.6146], device='cuda:0', grad_fn=<SqueezeBackward1>)
[OK] Forward pass successful!
gamma_calc shape: torch.Size([12])
latent_vectors shape: torch.Size([12, 5])
comp_emb shape: torch.Size([12, 5])
gradient tracking: torch.Size([12, 5])

=== EXPECTED vs ACTUAL ===
Expected predictions: 12 systems
Expected latent_vectors: 12 components

=== GRADIENT TRACKING ===
mole_frac.requires_grad = True


In [None]:
# GDMPNN: a minimal DTMPNN (hidden/latent/context dims = 1, one graph layer) used to
# exercise the Gibbs-Duhem loss end to end.
from modules.dtmpnn import DTMPNN
from modules.loss_func import GibbsDuhemLoss
from torch_geometric.loader import DataLoader

# Smoke test: the first batch of a loader over train[:10] with batch_size=5 holds 5 mixtures.
test_batch = DataLoader(train[:10], batch_size=5, shuffle=False, follow_batch=['component_batch'])
batch = next(iter(test_batch))

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DTMPNN(
    node_dim=batch.x.shape[1],
    edge_dim=batch.edge_attr.shape[1],
    graph_hidden_dim=1,
    latent_dim=1,
    context_dim=1,
    graph_layers=1,
    track_grad=True,
    constraint_type='hard'
).to(device)

# Gibbs-Duhem loss
gd_loss_fn = GibbsDuhemLoss()
batch = batch.to(device)

# Forward pass through the model
gamma_calc, latent_vectors, comp_emb = model(batch)

print("gamma_calc shape:", gamma_calc.shape)
print("latent_vectors shape:", latent_vectors.shape)
print("comp_emb shape:", comp_emb.shape)

print("\n=== Diagnostic Info ===")
print("Mole frac requires_grad:", batch.component_mole_frac.requires_grad)
print("Latent vectors requires_grad:", latent_vectors.requires_grad)
# One pass over the ids instead of a separate equality scan per id; counts come back
# aligned with the sorted unique ids, matching the previous per-id loop.
graph_ids, components_per_graph = torch.unique(batch.component_batch_batch, return_counts=True)
print("Number of unique batches:", graph_ids.numel())
print("Components per batch:", components_per_graph.tolist())
gd_loss = gd_loss_fn(batch, gamma_calc)
print("GD loss:", gd_loss.item())

In [None]:
def test_gibbs_duhem_losses(batch, gamma_calc):
    import torch.autograd as autograd
    """Test Gibbs-Duhem loss for different loss types."""
    loss_types = ['explicit', 'optimized']
    results = {}
    
    for loss_type in loss_types:
        print(f"\n=== Testing loss type: {loss_type} ===")
        gd_loss_fn = GibbsDuhemLoss(loss_type=loss_type)
        gd_loss = gd_loss_fn(batch, gamma_calc)
        loss_value = gd_loss.item()
        print(f"GD loss computed: {loss_value}")
        results[loss_type] = loss_value
    
    return results

# Compare the Gibbs-Duhem loss formulations on the batch and predictions from the cell above.
results = test_gibbs_duhem_losses(batch=batch, gamma_calc=gamma_calc)