In [None]:

# %% [markdown]
# ### Stage 1: Initial Training

# %%
# Stage 1 settings, gathered in one mapping so every knob is visible at a glance.
# NOTE(review): the same `key` is also passed to the Stage 2 call below —
# confirm that reusing it across both stages is intended.
_stage1_settings = dict(
    key=key,
    model=model,
    writer=None,             # optional TensorBoard writer, left unset here
    train_data=train_data,
    valid_data=valid_data,
    num_epochs=30,           # training epochs
    learning_rate=1e-3,      # starting learning rate
    batch_size=1,
    ndcm=model.n_dcm,        # DCM site count, read from the model
    esp_w=10000.0,           # weight of the ESP term in the loss
)

# `params` is reused by the later cells (fine-tuning and analysis).
params, valid_loss = train_model(**_stage1_settings)

print(f"Stage 1 validation loss: {valid_loss:.6f}")

# %% [markdown]
# ### Stage 2: Fine-tuning

# %%
# Stage 2 settings. Compared with Stage 1 this runs longer (100 epochs), with a
# smaller learning rate and a smaller ESP weight, and starts from the
# parameters currently stored in `params`.
# NOTE(review): the same `key` is also passed to the Stage 1 call above —
# confirm that reusing it across both stages is intended.
_stage2_settings = dict(
    key=key,
    model=model,
    writer=None,
    train_data=train_data,
    valid_data=valid_data,
    num_epochs=100,          # extended training
    learning_rate=5e-4,      # reduced learning rate
    batch_size=1,
    ndcm=model.n_dcm,
    esp_w=1000.0,            # reduced ESP weight
    restart_params=params,   # warm start from the existing parameters
)

# Overwrites `params` / `valid_loss` with the fine-tuned results.
params, valid_loss = train_model(**_stage2_settings)

print(f"Stage 2 validation loss: {valid_loss:.6f}")

# %% [markdown]
# ## Model Analysis and Visualization

# %% [markdown]
# ### Prepare Test Batch

# %%
# Fetch entry 0 from `data_path` as the batch used for all analysis below.
batch = prepare_batch(data_path, index=0)

# Reference point stored under "com": the plain, unweighted mean of the
# coordinates in batch["R"]. No atomic masses enter, so this is a geometric
# centre rather than a mass-weighted one.
# NOTE(review): if batch["R"] contains padded rows they are averaged in as
# well — confirm against prepare_batch.
batch["com"] = batch["R"].T.mean(axis=-1)

# Coordinates expressed relative to that reference point.
batch["Dxyz"] = batch["R"] - batch["com"]

# Evaluate the trained model on this batch and report both RMSE figures.
output = dcmnet_analysis(params, model, batch)

print(f"RMSE (all points): {output['rmse_model']:.6f}")
print(f"RMSE (masked): {output['rmse_model_masked']:.6f}")

# %% [markdown]
# ### Visualize Molecular Structure

# %%
# Build an ASE Atoms object from the first `n_atoms` entries of the batch.
# `n_atoms` is reused by later cells to slice model outputs.
n_atoms = int(batch["N"])
_leading = slice(None, n_atoms)

atoms = ase.Atoms(batch["Z"][_leading], batch["R"][_leading])  # (atomic numbers, positions)

# Kept as the final expression so the notebook renders the viewer output.
view(atoms, viewer="x3d")

# %% [markdown]
# ### ESP Prediction Quality

# %%
VMAX = 0.01              # symmetric colour limit (+/-) shared by the surface panels
N_SURFACE_POINTS = 4150  # only this many leading grid points are drawn per panel

# --- Panel 1: predicted vs. reference ESP over every point -------------------
correlation_ax = pw.Brick()
correlation_ax.scatter(batch["esp"], output['esp_pred'], s=1, alpha=0.5)

# Largest absolute value across both arrays; sets the extent of the y = x guide.
_peak_sq_true = np.max(batch["esp"]**2)
_peak_sq_pred = np.max(output['esp_pred']**2)
max_val = np.sqrt(max(_peak_sq_true, _peak_sq_pred))

_diagonal = np.linspace(-max_val, max_val, 100)
correlation_ax.plot(_diagonal, _diagonal, 'r--', linewidth=2)
correlation_ax.set_aspect('equal')
correlation_ax.set_xlabel('True ESP')
correlation_ax.set_ylabel('Predicted ESP')
correlation_ax.set_title('ESP Correlation')

# --- Shared inputs for the three surface panels ------------------------------
# Columns 0 and 1 of the first surface in the batch are used as plot x / y.
_surface = batch["vdw_surface"][0]
_xs = _surface[:N_SURFACE_POINTS, 0]
_ys = _surface[:N_SURFACE_POINTS, 1]
_esp_true = batch["esp"][0][:N_SURFACE_POINTS]
_esp_model = output['esp_pred'][:N_SURFACE_POINTS]
_surface_style = dict(s=0.01, vmin=-VMAX, vmax=VMAX, cmap='coolwarm')

# --- Panel 2: reference ESP ---------------------------------------------------
ax_true = pw.Brick()
ax_true.scatter(_xs, _ys, c=_esp_true, **_surface_style)
ax_true.set_aspect('equal')
ax_true.set_title('True ESP')

# --- Panel 3: model ESP -------------------------------------------------------
ax_pred = pw.Brick()
ax_pred.scatter(_xs, _ys, c=_esp_model, **_surface_style)
ax_pred.set_aspect('equal')
ax_pred.set_title('Predicted ESP')

# --- Panel 4: signed error (reference minus model) ---------------------------
ax_diff = pw.Brick()
ax_diff.scatter(_xs, _ys, c=_esp_true - _esp_model, **_surface_style)
ax_diff.set_aspect('equal')
ax_diff.set_title('Error (True - Pred)')

# Lay the four panels out side by side; final expression so it is displayed.
correlation_ax | ax_true | ax_pred | ax_diff

# %% [markdown]
# ### Predicted Charge Distribution

# %%
# Monopole values of batch entry 0, restricted to the first `n_atoms` rows.
_mono = output["mono"][0][:n_atoms]
_charge_style = dict(vmin=-1, vmax=1, cmap='RdBu_r')

# Heat map of the individual values (rows labelled as atoms, columns as
# multipole index).
charge_ax = pw.Brick()
charge_ax.matshow(_mono, **_charge_style)
charge_ax.set_title('Charge Multipoles')
charge_ax.set_xlabel('Multipole index')
charge_ax.set_ylabel('Atom index')

# Single-column heat map holding each row's sum over the last axis.
sum_charge_ax = pw.Brick()
_row_totals = _mono.sum(axis=-1)
sum_charge_ax.matshow(_row_totals[:, None], **_charge_style)
sum_charge_ax.axis("off")

# Place the summed column to the left of the full map and attach a colorbar.
# NOTE(review): no cmap is passed to add_colorbar and its return value is not
# kept — confirm the bar matches 'RdBu_r' and that the figure is displayed.
fig_charges = (sum_charge_ax | charge_ax)
fig_charges.add_colorbar(vmin=-1, vmax=1)

# Sum over every entry of the sliced monopole array.
total_charge = _mono.sum()
print(f"Total molecular charge: {total_charge:.4f}")

# %% [markdown]
# ### Visualize DCM Sites

# %%
# Show the predicted DCM sites as pseudo-atoms and report how many there are.

# Sites per atom, read from the model itself: this is the same value that was
# handed to train_model (ndcm=model.n_dcm), so the slice below cannot drift
# from the trained model and the cell no longer depends on a separately
# defined NDCM constant.
n_dcm = model.n_dcm

# Keep the leading n_atoms * n_dcm rows of output["dipo"].
# NOTE(review): assumes output["dipo"] holds one position row per DCM site,
# ordered atom by atom — confirm against dcmnet_analysis.
dcm_positions = output["dipo"][:n_atoms * n_dcm]

# Every site is drawn as a hydrogen pseudo-atom (atomic number 1).
dcm_z = np.ones(len(dcm_positions), dtype=int)
dcm_atoms = ase.Atoms(dcm_z, dcm_positions)

print(f"Number of DCM sites: {len(dcm_positions)}")
print(f"DCM sites per atom: {n_dcm}")

# The x3d viewer hands back a displayable object instead of opening a window,
# so the call has to be the cell's final expression for the notebook to render
# it; previously it sat before the prints and its result was discarded.
view(dcm_atoms, viewer="x3d")

In [None]:

# Report the validation loss currently held in `valid_loss`, labelled as the
# Stage 2 result. Nothing is computed here: the value must already exist in
# the kernel from an earlier training cell.
print(f"Stage 2 validation loss: {valid_loss:.6f}")

# %% [markdown]
# ## Model Analysis and Visualization

# %% [markdown]
# ### Prepare Test Batch

# %%
# Fetch entry 0 from `data_path` as the batch used for all analysis below.
batch = prepare_batch(data_path, index=0)

# Reference point stored under "com": the plain, unweighted mean of the
# coordinates in batch["R"]. No atomic masses enter, so this is a geometric
# centre rather than a mass-weighted one.
# NOTE(review): if batch["R"] contains padded rows they are averaged in as
# well — confirm against prepare_batch.
batch["com"] = batch["R"].T.mean(axis=-1)

# Coordinates expressed relative to that reference point.
batch["Dxyz"] = batch["R"] - batch["com"]

# Evaluate the trained model on this batch and report both RMSE figures.
output = dcmnet_analysis(params, model, batch)

print(f"RMSE (all points): {output['rmse_model']:.6f}")
print(f"RMSE (masked): {output['rmse_model_masked']:.6f}")

# %% [markdown]
# ### Visualize Molecular Structure

# %%
# Build an ASE Atoms object from the first `n_atoms` entries of the batch.
# `n_atoms` is reused by later cells to slice model outputs.
n_atoms = int(batch["N"])
_leading = slice(None, n_atoms)

atoms = ase.Atoms(batch["Z"][_leading], batch["R"][_leading])  # (atomic numbers, positions)

# Kept as the final expression so the notebook renders the viewer output.
view(atoms, viewer="x3d")

# %% [markdown]
# ### ESP Prediction Quality

# %%
VMAX = 0.01              # symmetric colour limit (+/-) shared by the surface panels
N_SURFACE_POINTS = 4150  # only this many leading grid points are drawn per panel

# --- Panel 1: predicted vs. reference ESP over every point -------------------
correlation_ax = pw.Brick()
correlation_ax.scatter(batch["esp"], output['esp_pred'], s=1, alpha=0.5)

# Largest absolute value across both arrays; sets the extent of the y = x guide.
_peak_sq_true = np.max(batch["esp"]**2)
_peak_sq_pred = np.max(output['esp_pred']**2)
max_val = np.sqrt(max(_peak_sq_true, _peak_sq_pred))

_diagonal = np.linspace(-max_val, max_val, 100)
correlation_ax.plot(_diagonal, _diagonal, 'r--', linewidth=2)
correlation_ax.set_aspect('equal')
correlation_ax.set_xlabel('True ESP')
correlation_ax.set_ylabel('Predicted ESP')
correlation_ax.set_title('ESP Correlation')

# --- Shared inputs for the three surface panels ------------------------------
# Columns 0 and 1 of the first surface in the batch are used as plot x / y.
_surface = batch["vdw_surface"][0]
_xs = _surface[:N_SURFACE_POINTS, 0]
_ys = _surface[:N_SURFACE_POINTS, 1]
_esp_true = batch["esp"][0][:N_SURFACE_POINTS]
_esp_model = output['esp_pred'][:N_SURFACE_POINTS]
_surface_style = dict(s=0.01, vmin=-VMAX, vmax=VMAX, cmap='coolwarm')

# --- Panel 2: reference ESP ---------------------------------------------------
ax_true = pw.Brick()
ax_true.scatter(_xs, _ys, c=_esp_true, **_surface_style)
ax_true.set_aspect('equal')
ax_true.set_title('True ESP')

# --- Panel 3: model ESP -------------------------------------------------------
ax_pred = pw.Brick()
ax_pred.scatter(_xs, _ys, c=_esp_model, **_surface_style)
ax_pred.set_aspect('equal')
ax_pred.set_title('Predicted ESP')

# --- Panel 4: signed error (reference minus model) ---------------------------
ax_diff = pw.Brick()
ax_diff.scatter(_xs, _ys, c=_esp_true - _esp_model, **_surface_style)
ax_diff.set_aspect('equal')
ax_diff.set_title('Error (True - Pred)')

# Lay the four panels out side by side; final expression so it is displayed.
correlation_ax | ax_true | ax_pred | ax_diff

# %% [markdown]
# ### Predicted Charge Distribution

# %%
# Monopole values of batch entry 0, restricted to the first `n_atoms` rows.
_mono = output["mono"][0][:n_atoms]
_charge_style = dict(vmin=-1, vmax=1, cmap='RdBu_r')

# Heat map of the individual values (rows labelled as atoms, columns as
# multipole index).
charge_ax = pw.Brick()
charge_ax.matshow(_mono, **_charge_style)
charge_ax.set_title('Charge Multipoles')
charge_ax.set_xlabel('Multipole index')
charge_ax.set_ylabel('Atom index')

# Single-column heat map holding each row's sum over the last axis.
sum_charge_ax = pw.Brick()
_row_totals = _mono.sum(axis=-1)
sum_charge_ax.matshow(_row_totals[:, None], **_charge_style)
sum_charge_ax.axis("off")

# Place the summed column to the left of the full map and attach a colorbar.
# NOTE(review): no cmap is passed to add_colorbar and its return value is not
# kept — confirm the bar matches 'RdBu_r' and that the figure is displayed.
fig_charges = (sum_charge_ax | charge_ax)
fig_charges.add_colorbar(vmin=-1, vmax=1)

# Sum over every entry of the sliced monopole array.
total_charge = _mono.sum()
print(f"Total molecular charge: {total_charge:.4f}")

# %% [markdown]
# ### Visualize DCM Sites

# %%
# Show the predicted DCM sites as pseudo-atoms and report how many there are.

# Sites per atom, read from the model itself: this is the same value that was
# handed to train_model (ndcm=model.n_dcm), so the slice below cannot drift
# from the trained model and the cell no longer depends on a separately
# defined NDCM constant.
n_dcm = model.n_dcm

# Keep the leading n_atoms * n_dcm rows of output["dipo"].
# NOTE(review): assumes output["dipo"] holds one position row per DCM site,
# ordered atom by atom — confirm against dcmnet_analysis.
dcm_positions = output["dipo"][:n_atoms * n_dcm]

# Every site is drawn as a hydrogen pseudo-atom (atomic number 1).
dcm_z = np.ones(len(dcm_positions), dtype=int)
dcm_atoms = ase.Atoms(dcm_z, dcm_positions)

print(f"Number of DCM sites: {len(dcm_positions)}")
print(f"DCM sites per atom: {n_dcm}")

# The x3d viewer hands back a displayable object instead of opening a window,
# so the call has to be the cell's final expression for the notebook to render
# it; previously it sat before the prints and its result was discarded.
view(dcm_atoms, viewer="x3d")