In [None]:
%load_ext autoreload
%autoreload 2
import os
os.chdir("../..")
print(os.getcwd())
import torch
import numpy as np
import random

os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 
seed = 21
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

In [None]:
from modules.data_pipeline import DataPipeline

# Build the pipeline from the components table, then process the raw dataset.
# Both CSV paths are relative to the current working directory.
pipeline = DataPipeline(
    components_csv='datasets/components.csv',
)
# run_pipeline returns a pair; the second element is the list of graphs that
# the split cell consumes.
canonical_data, graph_list = pipeline.run_pipeline(
    raw_csv='datasets/dataset.csv',
)

In [None]:
import modules.datasplit_module as dsm

# --- Split graphs ---
# Shuffle a COPY of the pipeline output. Shuffling `graph_list` in place made
# this cell non-idempotent: every re-run reshuffled an already shuffled list
# and destroyed the original pipeline order. On a fresh kernel the permutation
# is the same as before, because random.shuffle depends only on the RNG state
# and the sequence length.
# NOTE: the shuffle still draws from the global `random` stream seeded in the
# setup cell; call random.seed(seed) first if a re-run must reproduce the
# exact same order.
sampled_graph_list = list(graph_list)
random.shuffle(sampled_graph_list)

# Train/validation/test partition produced by the project's split helper.
train, val, test = dsm.system_disjoint_split(
    sampled_graph_list,
    random_state=seed,
    stratify_by_components=True,
)

In [None]:
from torch_geometric.loader import DataLoader


def _make_loader(graphs, shuffle):
    """Wrap one split in a DataLoader with the settings shared by all splits.

    Every loader uses batches of 1024 graphs and asks the loader to follow
    'component_batch'; only the dataset and the shuffle flag differ.
    """
    return DataLoader(
        dataset=graphs,
        batch_size=1024,
        shuffle=shuffle,
        follow_batch=['component_batch'],
    )


# Only the training split is shuffled; validation and test keep their order.
train_loader = _make_loader(train, shuffle=True)
val_loader = _make_loader(val, shuffle=False)
test_loader = _make_loader(test, shuffle=False)

In [None]:
import modules.trainer_module as tm
import modules.dtmpnn as gm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# To involve GD in training, BOTH switches below must be True:
# gradient tracking in the model (`track_grad`) and the GD term in the
# trainer (`include_gd`).
# NOTE(review): "GD" presumably refers to a Gibbs-Duhem term - confirm.
track_grad = True
include_gd = False
gd_weight = 1

# Every artifact of this run (checkpoints, log file, history plot) goes here.
checkpoint_dir = 'notebooks/training_phase/single_checkpoints'
# Harmless if the trainer also creates the directory itself.
os.makedirs(checkpoint_dir, exist_ok=True)

# Create model. Feature dimensions are read from the first training graph.
model = gm.DTMPNN(
    node_dim=train[0].x.shape[1],
    edge_dim=train[0].edge_attr.shape[1],
    graph_hidden_dim=256,
    latent_dim=256,
    context_dim=256,
    graph_layers=3,
    # Use the switch defined above; this used to be a hard-coded True, so
    # editing `track_grad` had no effect on the model.
    track_grad=track_grad
).to(device)

# Initialize trainer
trainer = tm.DTMPNNTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    include_gd=include_gd,
    device=device,
    lr=0.0001,
    weight_decay=0,
    data_driven_weight=1.0,
    gd_weight=gd_weight
)

# Train the model
history = trainer.train(
    epochs=100,
    save_dir=checkpoint_dir,
    log_file_path=f'{checkpoint_dir}/training_run_log.txt',
    save_best=True,
    save_every=25
)

# Plot training curves; the file name records whether GD was included.
trainer.plot_history(save_path=f'{checkpoint_dir}/training_history_GD_Backprop_{include_gd}.png')
torch.cuda.empty_cache()

In [None]:
# Plot the training history again, this time without a save_path
# (presumably for inline display only - confirm against the trainer).
trainer.plot_history()

In [None]:
# Test batching: inspect the structure of the first training batch and check
# that the model output lines up with the targets.
# Only the first batch is needed. next() raises StopIteration if the loader is
# empty, which cannot happen here because the model was built from train[0].
batch = next(iter(train_loader))
# Move batch, not loader. Use the device selected when the model was created
# instead of a hard-coded 'cuda', so the check also runs on CPU-only machines.
batch = batch.to(device)

print("=== Batch Structure ===")
print(f"Number of graphs (mixtures): {batch.num_graphs}")
print(f"Total nodes: {batch.x.shape[0]}")
print(f"Total components: {len(batch.component_mole_frac)}")
print(f"Total gammas: {len(batch.component_gammas)}")
print(f"\nmol_batch (node-to-molecule): {batch.mol_batch}")
print(f"mol_batch unique values: {batch.mol_batch.unique()}")
print(f"\ncomponent_batch_batch (component-to-mixture): {batch.component_batch_batch}")
print(f"component_batch_batch unique: {batch.component_batch_batch.unique()}")

# Run model
y_pred, latent_vecs, comp_emb = model(batch)
print("\n=== Model Outputs ===")
print(f"y_pred shape: {y_pred.shape}")
print(f"component_gammas shape: {batch.component_gammas.shape}")
print(f"comp_emb shape: {comp_emb.shape}")
print(f"latent_vecs shape: {latent_vecs.shape}")

# Predictions and targets must have identical shapes; otherwise an
# element-wise loss would broadcast silently.
print("\n=== Alignment Check ===")
print(f"Do shapes match? y_pred vs gammas: {y_pred.shape == batch.component_gammas.shape}")

In [None]:
# Value check: compare predictions with targets on the first training batch.
# next() raises StopIteration if the loader is empty, which cannot happen here
# because the model was built from train[0].
batch = next(iter(train_loader))
# Use the device selected when the model was created instead of a hard-coded
# 'cuda', so the check also runs on CPU-only machines.
batch = batch.to(device)
y_pred, latent_vecs, comp_emb = model(batch)

print("=== Value Check ===")
print(f"y_pred: {y_pred}")
print(f"y_true: {batch.component_gammas}")
print(f"Difference: {y_pred - batch.component_gammas}")

# Check if predictions are reasonable
print(f"\ny_pred range: [{y_pred.min():.4f}, {y_pred.max():.4f}]")
print(f"y_true range: [{batch.component_gammas.min():.4f}, {batch.component_gammas.max():.4f}]")

# Compute loss manually: mean squared error over every element of the batch.
loss = ((y_pred - batch.component_gammas) ** 2).mean()
print(f"\nManual MSE: {loss.item():.6f}")

In [None]:
""" import modules.trainer_module as tm
# Load best model and evaluate
print("Loading best model for final evaluation...")
print("="*70)
trainer.load_checkpoint('notebooks/training_phase/single_checkpoints/checkpoint_epoch_150.pt')
# --- Refactored Final Evaluation Script ---
def evaluate_and_print_results(trainer_obj, loader, set_name):
    """
    Runs validation and prints formatted results for a given data set,
    respecting the dynamic data loss name and GD printing symmetry.
    """
    _, data_loss, gd_loss, rmse, mae, r2, mape = trainer_obj.validate(loader)

    gd_print = f"Test GD      : {gd_loss:.4f}"


    data_loss_name = trainer_obj.datadriven_loss_name

    print(f"\nFinal Results {set_name}:")
    print(f"Test {data_loss_name:<8}: {data_loss:.4f}")
    print(gd_print)
    print(f"Test RMSE    : {rmse:.4f}")
    print(f"Test MAE     : {mae:.4f}")
    print(f"Test R2      : {r2:.4f}")
    print(f"Test MAPE    : {mape:.4f}%")
    print("="*70)


# Evaluate across all three sets using the new function
evaluate_and_print_results(trainer, trainer.train_loader, "TRAINING SET")
evaluate_and_print_results(trainer, trainer.val_loader, "VALIDATION SET")
evaluate_and_print_results(trainer, trainer.test_loader, "TESTING SET") """

In [None]:
""" # Create model
model = model
# --- Forward pass + backprop ---
trying_loader = DataLoader(
    dataset=test,
    batch_size=1,
    shuffle=False,
    follow_batch=['component_batch']
)
model.eval()
target_batch = 9
for i, batched_data in enumerate(trying_loader):
    if i == target_batch:
        batched_data = batched_data.to(device)
        y_pred, latent_vectors, comp_emb = model(batched_data)
        print("\n" + "="*120)
        print("Forward Pass inputs:")
        print(f"Component names                                     : {batched_data.component_names}")
        print(f"Component mole fractions                            : {batched_data.component_mole_frac}")
        print("\n" + "="*120)
        print("Forward Pass Outputs:")        

        print("\n" + "="*120)
        print("Solvent subsystem analysis:")            
        print(f"Aggregated latent vector (gamma_i)                  : {(latent_vectors.mean(dim=1).view(1,-1)).detach()}")
        print(f"Hypothetical Excess Gibss Energy (G_i_excess)       : {(latent_vectors.mean(dim=1).view(1,-1) * 8.314 * (batched_data.T * batched_data.T_std + batched_data.T_mean).view(-1,1)).detach()} J/mol")
        print(f"Hypothetical Sum of Excess Gibss Energy             : {((latent_vectors.mean(dim=1).view(1,-1) * 8.314 * (batched_data.T * batched_data.T_std + batched_data.T_mean).view(-1,1)).detach()).sum(dim=1)} J/mol")
        print(f"Aggregated components embedding                     : {(comp_emb.mean(dim=1).view(1,-1)).detach()}")
        break """