In [1]:
# Max Model: Delta-Learning for CCSD(T) Corrections from MP2
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, r2_score
import matplotlib.pyplot as plt
from ast import literal_eval
import shap

In [None]:
# Load and Preprocess the Dataset

# Model inputs and the regression target column.
feature_cols = [
    "mu", "alpha", "homo", "lumo", "gap", "r2",
    "mol_weight", "Z_sum", "n_atoms", "corr_MP2"
]
target_col = "delta_corr"

# Read the merged molecular table and derive the extra columns in one chain:
#   atomic_numbers : the stored string parsed into a Python list
#   Z_sum, n_atoms : sum and length of that list
#   delta_corr     : corr_CCSDT - corr_MP2 (the correction the model learns)
df = (
    pd.read_csv("merged_with_total_E.csv")
    .assign(atomic_numbers=lambda frame: frame["atomic_numbers"].apply(literal_eval))
    .assign(
        Z_sum=lambda frame: frame["atomic_numbers"].apply(sum),
        n_atoms=lambda frame: frame["atomic_numbers"].apply(len),
        delta_corr=lambda frame: frame["corr_CCSDT"] - frame["corr_MP2"],
    )
)

# Keep only rows where every feature and the target are present.
df = df.dropna(subset=[*feature_cols, target_col])

# Feature matrix (n_rows, n_features) and target as a single column (n_rows, 1).
X = df[feature_cols].to_numpy()
y = df[[target_col]].to_numpy()

# Standardize features and target; the fitted scalers are kept so predictions
# can be mapped back to the original target scale later.
# NOTE(review): both scalers are fit on the full table, i.e. before the
# train/test split performed later in the notebook, so test-set statistics
# influence the scaling. Consider fitting on the training rows only.
scaler_X = StandardScaler().fit(X)
scaler_y = StandardScaler().fit(y)
X_scaled = scaler_X.transform(X)
y_scaled = scaler_y.transform(y)

# Convert to float32 PyTorch tensors
X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
y_tensor = torch.tensor(y_scaled, dtype=torch.float32)

In [None]:
# Dataset Splitting and Loading

# Fixed seed for the split so the same rows land in train/test on every
# Restart & Run All; without it the reported metrics change from run to run.
SPLIT_SEED = 42

# 80/20 train-test split
dataset = TensorDataset(X_tensor, y_tensor)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_set, test_set = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
assert len(train_set) + len(test_set) == len(dataset)

# Training batches are reshuffled each epoch; test batches keep the order of
# test_set.indices, which later cells can rely on to map back to rows.
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
test_loader = DataLoader(test_set, batch_size=16)

In [None]:
 # Define Neural Network Architecture

class MLP(nn.Module):
    """Feed-forward regressor: input_size -> 128 -> 64 -> 32 -> 1.

    Every hidden layer is followed by ReLU; Dropout(0.2) is applied after the
    first two hidden layers only. The final Linear layer has no activation.
    """

    def __init__(self, input_size):
        """Build the layer stack for inputs with `input_size` features."""
        super().__init__()
        widths = (input_size, 128, 64, 32)
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
            # Only the first two hidden layers are regularized with dropout.
            if depth < 2:
                layers.append(nn.Dropout(0.2))
        layers.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for input batch `x`."""
        return self.net(x)

In [None]:
# Train the Model

# Seed PyTorch's global RNG so weight initialization, dropout masks and the
# training loader's shuffle order repeat across runs of this cell.
TRAIN_SEED = 42
N_EPOCHS = 70
torch.manual_seed(TRAIN_SEED)

model = MLP(input_size=X_tensor.shape[1])
loss_fn = nn.SmoothL1Loss()  # Huber loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)

train_losses = []
test_losses = []
best_loss = float("inf")
patience = 10  # epochs without test-loss improvement before stopping
counter = 0    # consecutive epochs without improvement

# NOTE(review): the test loader drives the LR schedule, early stopping and
# checkpoint selection below, so the final test metrics are optimistically
# biased. A separate validation split would avoid this.
for epoch in range(N_EPOCHS):
    # ---- optimization pass over the training batches ----
    model.train()
    running_loss = 0.0
    n_train = 0
    for xb, yb in train_loader:
        pred = model(xb)
        loss = loss_fn(pred, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # loss is a batch mean; weight it by batch size so a short final
        # batch does not count as much as a full one.
        running_loss += loss.item() * xb.size(0)
        n_train += xb.size(0)
    avg_train_loss = running_loss / n_train

    # ---- evaluation pass (dropout off, no gradients) ----
    model.eval()
    test_loss = 0.0
    n_test = 0
    with torch.no_grad():
        for xb, yb in test_loader:
            pred = model(xb)
            test_loss += loss_fn(pred, yb).item() * xb.size(0)
            n_test += xb.size(0)
    avg_test_loss = test_loss / n_test
    scheduler.step(avg_test_loss)

    train_losses.append(avg_train_loss)
    test_losses.append(avg_test_loss)

    print(f"Epoch {epoch+1:2d}/{N_EPOCHS} — Train Loss: {avg_train_loss:.6f} — Test Loss: {avg_test_loss:.6f}")

    # Checkpoint on improvement; otherwise count towards early stopping.
    if avg_test_loss < best_loss:
        best_loss = avg_test_loss
        counter = 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        counter += 1
        if counter >= patience:
            print(f"⏹️  Early stopping at epoch {epoch+1}")
            break

In [None]:
# Evaluate Final Model

# Restore the checkpoint written during training and switch off dropout.
model.load_state_dict(torch.load("best_model.pt"))
model.eval()

# Collect scaled predictions and targets batch by batch, in loader order.
pred_batches, true_batches = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        pred_batches.append(model(xb).numpy())
        true_batches.append(yb.numpy())

# Map standardized values back to the original delta_corr scale.
y_true = scaler_y.inverse_transform(np.vstack(true_batches))
y_pred = scaler_y.inverse_transform(np.vstack(pred_batches))

# Reconstruct full CCSD(T) predictions
# test_loader (no shuffling) yields samples in the order of test_set.indices,
# which is a random permutation, NOT ascending order. Using the indices as
# stored keeps each corr_MP2 value paired with its own prediction. The same
# offset is added to both arrays, so MAE does not depend on this pairing, but
# R², the saved CSV and the parity plot do.
test_indices = list(test_set.indices)
assert len(test_indices) == len(y_true)
corr_MP2_test = df.iloc[test_indices]["corr_MP2"].values.reshape(-1, 1)
ccsdt_true = y_true + corr_MP2_test
ccsdt_pred = y_pred + corr_MP2_test

# Compute metrics
mae = mean_absolute_error(ccsdt_true, ccsdt_pred)
r2 = r2_score(ccsdt_true, ccsdt_pred)
print(f"\n📉 Final MAE: {mae:.6f} Hartree")
print(f"📈 R² Score: {r2:.3f}")

# Save results
results_df = pd.DataFrame({
    "corr_CCSDT_true": np.ravel(ccsdt_true),
    "corr_CCSDT_pred": np.ravel(ccsdt_pred)
})
results_df.to_csv("ccsdt_predictions.csv", index=False)

# Plot parity
parity_limits = [ccsdt_true.min(), ccsdt_true.max()]
plt.figure(figsize=(6, 6))
plt.scatter(ccsdt_true, ccsdt_pred, alpha=0.7)
plt.plot(parity_limits, parity_limits, 'r--')  # y = x reference line
plt.xlabel("True corr_CCSDT (Hartree)")
plt.ylabel("Predicted corr_CCSDT (Hartree)")
plt.grid(True)
plt.tight_layout()
plt.savefig("ccsdt_parity_plot.png", dpi=300)
plt.show()

# Plot training and testing losses
plt.figure(figsize=(8, 5))
plt.plot(train_losses, label='Training Loss', marker='o')
plt.plot(test_losses, label='Testing Loss', marker='s')
plt.xlabel("Epoch")
# The loss is computed on standardized targets, so it is not in Hartree.
plt.ylabel("Smooth L1 loss (standardized units)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig("ccsdt_loss_curve.png", dpi=300)
plt.show()

In [None]:
# SHAP Feature Importance Analysis

def model_predict(x_numpy):
    """Evaluate the module-level `model` on a NumPy batch.

    The input is converted to a float32 tensor, passed through `model` with
    gradient tracking disabled, and the output is returned as a NumPy array.
    """
    with torch.no_grad():
        batch = torch.tensor(x_numpy, dtype=torch.float32)
        outputs = model(batch)
    return outputs.numpy()

# Use SHAP KernelExplainer
# The first 50 scaled rows serve as the background set; the first 100 scaled
# rows are the samples being explained.
shap_background = X_scaled[:50]
shap_samples = X_scaled[:100]
explainer = shap.KernelExplainer(model_predict, shap_background)
shap_values = explainer.shap_values(shap_samples, nsamples=100)

# If single-output, squeeze
if isinstance(shap_values, list):
    shap_values = shap_values[0]
shap_values = np.squeeze(shap_values)

# Plot SHAP summary
# show=False keeps the figure open: by default summary_plot displays the
# figure itself, and the savefig call that follows would then write an
# empty image.
shap.summary_plot(shap_values, shap_samples, feature_names=feature_cols, show=False)
plt.tight_layout()
plt.savefig("shap_summary_plot.png", dpi=300)
plt.show()