In [19]:
import copy
import os
import os.path as osp
import csv
import util as ut
import pandas as pd

import torch.nn as nn
import torch.nn.functional as F
from torch.nn import GRU, Linear, ReLU, Sequential

import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.loader import DataLoader
from torch_geometric.nn import NNConv, Set2Set
from torch_geometric.utils import remove_self_loops
from torch_geometric.nn import global_mean_pool
import torch_geometric
import logging

from pathlib import Path
import trainer
import json 

from detanet_model import *
import wandb

In [20]:
import random
random.seed(42)

In [21]:
batch_size = 128
epochs = 60
lr=5e-4
num_freqs=61

high_spec_cutoff = 0.1
low_fraction = 0.008

In [22]:

current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
data_dir = os.path.join(parent_dir, 'data')
csv_path = data_dir + "/ee_polarizabilities_qm9s.csv"

dataset = []
frequencies = ut.load_unique_frequencies(csv_path)

csv_path_geometries = data_dir + "/KITqm9_geometries.csv"
geometries = ut.load_geometry(csv_path_geometries)

csv_spectra = data_dir + "/DATA_QM9_reduced_2025_03_06.csv"
sprectras = ut.load_spectra(csv_spectra)

model_path = "trained_param/ee_polarizabilities_all_freq_KITqm9_freq_emb.pth"

In [23]:
count = 0

with open(csv_path, newline='', encoding='utf-8') as csvfile:
    csv_reader = csv.reader(csvfile, delimiter=',')
    
    # Read the header to identify column indices
    header = next(csv_reader)
    frequency_idx = header.index("frequency")
    matrix_real_idx = header.index("matrix_real")
    matrix_imag_idx = header.index("matrix_imag")
    
    # Read each row
    for row in csv_reader:
        try:
            idx = int(row[0])
        except ValueError:
            print("Can't read index:", row[0])
            continue

        freq_str = row[frequency_idx]
        try:
            freq_val = float(freq_str)
        except ValueError:
            continue

        if freq_val not in frequencies:
            continue

        mol = None
        # Now you can look up any 'idx' in constant time
        if idx in geometries:
            mol = geometries[idx]
        else:
            continue        
        pos = mol.pos
        z = mol.z
        spectrum_value = ut.get_closest_spectrum_value(sprectras, idx, freq_val)

        # Parse JSON for real matrix
        matrix_real_str = row[matrix_real_idx]
        matrix_imag_str = row[matrix_imag_idx]
        try:
            real_3x3 = json.loads(matrix_real_str)  # expected shape [3,3]
        except json.JSONDecodeError:
            print("Warning: Could not parse real part of matrix for idx:", idx)
            continue

        try:
            imag_3x3 = json.loads(matrix_imag_str)  # expected shape [3,3]
        except json.JSONDecodeError:
            print("Warning: Could not parse imaginary part of matrix for idx:", idx)
            continue

        real_mat = torch.tensor(real_3x3, dtype=torch.float32)
        imag_mat = torch.tensor(imag_3x3, dtype=torch.float32)
        
        y = torch.cat([real_mat, imag_mat], dim=-1)  # shape [12]
            
        data_entry = Data(
            idx = mol.idx,
            pos=pos.to(torch.float32),    # Atomic positions
            z=torch.LongTensor(z),        # Atomic numbers
            freq=torch.tensor(float(freq_val), dtype=torch.float32),
            spec=torch.tensor(float(spectrum_value), dtype=torch.float32),
            y=y,  # Polarizability tensor (target)
        )
        if spectrum_value > high_spec_cutoff:
            dataset.append(data_entry)
            count += 1
        else:
         # Randomly sample ~0.2% of the "low-spec" data
            if random.random() < low_fraction:
                dataset.append(data_entry)
                

In [24]:

print(f"Collected {count} high-spec (>0.1) entries.")
print(f"Total dataset length: {len(dataset)}")

ex1 = dataset[0]
ex2 = dataset[5]

print("dataset[0] :", ex1.idx, ex1.freq, ex1.spec)
print("dataset[5] :", ex2.idx, ex2.freq, ex2.spec)

Collected 302 high-spec (>0.1) entries.
Total dataset length: 1269
dataset[0] : 34 tensor(3.0221) tensor(6.2459e-06)
dataset[5] : 381 tensor(3.2546) tensor(9.7373e-06)


In [25]:
spec_values = [item.spec.item() for item in dataset]
spec_mean = np.mean(spec_values)
spec_std = np.std(spec_values)

print("Spec mean, std =", spec_mean, spec_std)
for item in dataset:
    old_val = item.spec.item()
    norm_val = (old_val - spec_mean) / (spec_std + 1e-8)  # avoid div by zero
    item.spec = torch.tensor(norm_val, dtype=torch.float32)

Spec mean, std = 0.0422457436900231 0.0879089541274016


In [26]:

import numpy as np

y_vals = []
epsilon = 1e-8

for item in dataset:
    y = item.y.reshape(-1).tolist()   
    y_vals.extend(y)
# compute mean, std
y_mean, y_std = np.mean(y_vals), np.std(y_vals)

print("y mean, std =", y_mean, y_std)

# Now transform each data entry
for item in dataset:
    y = item.y  # [3,6]
    # 4) do standard z-score
    y_norm = (y - y_mean)/(y_std + epsilon)
    item.y = y_norm

y mean, std = 28.25078968967985 179.75417543967845


In [27]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = DetaNet(num_features=128,
                    act='swish',
                    maxl=3,
                    num_block=3,
                    radial_type='trainable_bessel',
                    num_radial=32,
                    attention_head=8,
                    rc=5.0,
                    dropout=0.0,
                    use_cutoff=False,
                    max_atomic_number=9,
                    atom_ref=None,
                    scale=1.0,
                    scalar_outsize= 4, # 2,#4, 
                    irreps_out= '2x2e', #'2e',# '2e+2e',
                    summation=True,
                    norm=False,
                    out_type='complex_2_tensor', # '2_tensor',
                    grad_type=None,
                    device=device)

In [28]:
state_dict = torch.load(model_path)
model.load_state_dict(state_dict=state_dict)
model.to(device)

DetaNet(
  (Embedding): Embedding(
    (act): Swish()
    (elec_emb): Linear(in_features=16, out_features=128, bias=False)
    (nuclare_emb): Embedding(10, 128)
    (ls): Linear(in_features=128, out_features=128, bias=True)
    (spec_lin): Linear(in_features=1, out_features=128, bias=False)
  )
  (Radial): Radial_Basis(
    (radial): Bessel_Function()
  )
  (blocks): Sequential(
    (0): Interaction_Block(
      (message): Message(
        (Attention): Edge_Attention(
          (actq): Swish()
          (actk): Swish()
          (actv): Swish()
          (acta): Swish()
          (softmax): Softmax(dim=-1)
          (lq): Linear(in_features=128, out_features=128, bias=True)
          (lk): Linear(in_features=128, out_features=128, bias=True)
          (lv): Linear(in_features=128, out_features=256, bias=True)
          (la): Linear(in_features=256, out_features=256, bias=True)
          (lrbf): Linear(in_features=32, out_features=128, bias=False)
          (lkrbf): Linear(in_features=1

In [29]:
sample = dataset[1]


In [42]:
import torch_cluster
import ase
from ase.io import read
from ase.visualize import view
from ase.build import molecule
#from code.util.visualize_polarizability import smiles_to_atoms, visualize_polarizability, compare_polarizabilities_eigen

result = model(pos=sample.pos.to(device), z=sample.z.to(device), spec=sample.spec.to(device))
print(result)



spec_emb  tensor([[-0.0847,  0.0188,  0.0676,  0.0151,  0.0974,  0.0412,  0.0524,  0.0455,
          0.0566,  0.0803, -0.0541,  0.0693,  0.0960,  0.0213, -0.0595, -0.0643,
         -0.0471, -0.0022,  0.0248, -0.0901,  0.0234, -0.0041,  0.0059,  0.0910,
         -0.0809,  0.0914,  0.0207, -0.0507,  0.0871, -0.0870, -0.0361, -0.0091,
         -0.0901, -0.0169,  0.0238, -0.0204, -0.0285,  0.0212, -0.0218,  0.1031,
         -0.0197,  0.0123, -0.0899,  0.0742, -0.0600, -0.0519, -0.0630, -0.0315,
         -0.0074, -0.0608,  0.0343, -0.0563, -0.0783,  0.0773, -0.0300,  0.0495,
          0.0354, -0.0337,  0.0121,  0.0102,  0.0353, -0.0618,  0.0704, -0.0508,
          0.0555, -0.0419,  0.0972, -0.0490,  0.0862, -0.0606,  0.0499,  0.0299,
          0.0024,  0.0181,  0.0420,  0.0332, -0.0942, -0.0582,  0.0826,  0.0248,
         -0.0179, -0.1018,  0.0427,  0.0220, -0.1040, -0.0092,  0.0523, -0.0613,
          0.0141, -0.0304, -0.0796, -0.0046, -0.0240,  0.1031,  0.0777,  0.0264,
          0.0700, 

In [43]:
sample.y

tensor([[ 0.1789, -0.2263, -0.0981, -0.1568, -0.1573, -0.1570],
        [-0.2263,  0.2006, -0.1233, -0.1573, -0.1568, -0.1571],
        [-0.0981, -0.1233,  0.1520, -0.1570, -0.1571, -0.1568]])

In [50]:
pred_y_denorm = result * (y_std + epsilon) + y_mean
pred_y_denorm = np.array(pred_y_denorm.cpu().detach())

y_denorm = sample.y * (y_std + epsilon) + y_mean
y_denorm = np.array(y_denorm.cpu().detach())

print("pred_y ", pred_y_denorm)
print("y ", y_denorm)

pred_y  [[47.115906 62.384483 19.072948 15.633654 32.774036 20.486916]
 [62.384483 38.540894 28.57941  32.774036 37.355232 40.964867]
 [19.072948 28.57941  73.59104  20.486916 40.964867 28.804932]]
y  [[ 6.0405468e+01 -1.2426197e+01  1.0619621e+01  6.3978195e-02
  -2.1568298e-02  2.2394180e-02]
 [-1.2426197e+01  6.4315720e+01  6.0885277e+00 -2.1568298e-02
   6.9450378e-02  1.2605667e-02]
 [ 1.0619621e+01  6.0885277e+00  5.5577129e+01  2.2394180e-02
   1.2605667e-02  5.7174683e-02]]


In [33]:
from ase.build import molecule

qm9s = torch.load("../data/qm9s.pt")

In [34]:
qm9s_dict = {entry.number: entry for entry in qm9s}

In [45]:
print(qm9s_dict[sample.idx].number)
print(qm9s_dict[sample.idx].z)
print(qm9s_dict[sample.idx].smile)
print(sample.idx)
print(sample.z)



202
tensor([6, 6, 6, 7, 6, 7, 1, 1, 1, 1, 1, 1])
CC1=CNC=N1
202
tensor([6, 6, 6, 7, 6, 7, 1, 1, 1, 1, 1, 1])


In [46]:
from util.visualize_polarizability import smiles_to_atoms, visualize_polarizability, compare_polarizabilities_eigen
atoms = smiles_to_atoms(qm9s_dict[sample.idx].smile)
view(atoms, viewer='x3d')

In [52]:
real_part_true_denorm = pred_y_denorm[:,:3]
print(real_part_true_denorm)

visualize_polarizability(atoms, real_part_true_denorm)

real_part_predicted_denorm = y_denorm[:,:3]
print(real_part_predicted_denorm)

[[47.115906 62.384483 19.072948]
 [62.384483 38.540894 28.57941 ]
 [19.072948 28.57941  73.59104 ]]
Raw Eigenvalues (True Tensor): [126.622375 -20.431532  53.057   ]
Real Eigenvalues (True Tensor): [126.622375 -20.431532  53.057   ]


[[ 60.405468  -12.426197   10.619621 ]
 [-12.426197   64.31572     6.0885277]
 [ 10.619621    6.0885277  55.57713  ]]


In [53]:
real_part_true = sample.y[:,:3]
print(real_part_true)

visualize_polarizability(atoms, real_part_true)

real_part_predicted = y[:,:3]
print(real_part_predicted)

tensor([[ 0.1789, -0.2263, -0.0981],
        [-0.2263,  0.2006, -0.1233],
        [-0.0981, -0.1233,  0.1520]])
Raw Eigenvalues (True Tensor): [-0.1249928   0.41849768  0.23803183]
Real Eigenvalues (True Tensor): [-0.1249928   0.41849768  0.23803183]


tensor([[ 60.4055, -12.4262,  10.6196],
        [-12.4262,  64.3157,   6.0885],
        [ 10.6196,   6.0885,  55.5771]])


In [54]:
import plotly.graph_objects as go

def compare_polarizabilities(atoms, true_tensor, predicted_tensor):
    """
    Visualize a molecule with true and predicted polarizability tensors using Plotly.
    
    Parameters:
        atoms: ASE Atoms object
        true_tensor: (3, 3) numpy array representing the true polarizability tensor (Bohr^3)
        predicted_tensor: (3, 3) numpy array representing the predicted polarizability tensor (Bohr^3)
    """
    def create_polarizability_traces(tensor, center, color, name, scale=0.05):

        tensor = tensor.reshape(3,3)

        # Eigen decomposition of the polarizability tensor
        eigenvalues, eigenvectors = np.linalg.eig(tensor)
        
        print("Raw Eigenvalues (True Tensor):", eigenvalues)
        print("Real Eigenvalues (True Tensor):", np.real(eigenvalues))
        # Ensure eigenvalues and eigenvectors are real
        eigenvalues = np.real(eigenvalues)
        eigenvectors = np.real(eigenvectors)


        # Calculate arrows for the principal axes
        arrows = [scale * eigenvalue * eigenvector for eigenvalue, eigenvector in zip(eigenvalues, eigenvectors.T)]

        # Create arrow traces
        arrow_traces = []
        for arrow in arrows:
            arrow_trace = go.Scatter3d(
                x=[center[0], center[0] + arrow[0]],
                y=[center[1], center[1] + arrow[1]],
                z=[center[2], center[2] + arrow[2]],
                mode='lines+markers',
                line=dict(color=color, width=4),
                marker=dict(size=4, color=color),
                name=f"{name} Arrow"
            )
            arrow_traces.append(arrow_trace)

        # Generate ellipsoid points
        eigenvalues = np.abs(eigenvalues) * 0.05  # Adjust scaling factor
        u = np.linspace(0, 2 * np.pi, 50)
        v = np.linspace(0, np.pi, 25)
        x = np.outer(np.cos(u), np.sin(v))
        y = np.outer(np.sin(u), np.sin(v))
        z = np.outer(np.ones_like(u), np.cos(v))

        # Scale and rotate the ellipsoid
        ellipsoid = np.dot(np.column_stack([x.flatten(), y.flatten(), z.flatten()]), np.diag(eigenvalues))
        ellipsoid = np.dot(ellipsoid, eigenvectors.T)
        x_ellipsoid = ellipsoid[:, 0].reshape(x.shape)
        y_ellipsoid = ellipsoid[:, 1].reshape(y.shape)
        z_ellipsoid = ellipsoid[:, 2].reshape(z.shape)

        # Create ellipsoid trace
        ellipsoid_trace = go.Surface(
            x=x_ellipsoid + center[0],
            y=y_ellipsoid + center[1],
            z=z_ellipsoid + center[2],
            colorscale=[[0, color], [1, color]],
            opacity=0.2,
            showscale=False,
            name=f"{name} Ellipsoid"
        )
        return arrow_traces, ellipsoid_trace

    # Calculate center of the molecule
    center = atoms.get_center_of_mass()

    # Create polarizability traces for true tensor
    true_arrows, true_ellipsoid = create_polarizability_traces(true_tensor, center, 'blue', 'True')

    # Create polarizability traces for predicted tensor
    pred_arrows, pred_ellipsoid = create_polarizability_traces(predicted_tensor, center, 'red', 'Predicted')

    # Extract atomic positions and symbols
    positions = atoms.get_positions()
    symbols = atoms.get_chemical_symbols()

    # Define colors for different atom types
    atom_colors = {
        'H': 'grey',
        'C': 'black',
        'O': 'pink',
        'N': 'purple',
        'S': 'yellow',
        'Cl': 'green',
        'F': 'cyan'
    }
    colors = [atom_colors.get(symbol, 'gray') for symbol in symbols]

    # Create scatter plot for the atoms with different colors
    atom_trace = go.Scatter3d(
        x=positions[:, 0],
        y=positions[:, 1],
        z=positions[:, 2],
        mode='markers+text',
        marker=dict(size=6, color=colors),
        text=symbols,
        textposition="top center",
        name="Atoms"
    )

    # Combine all traces
    layout = go.Layout(
        scene=dict(
            xaxis=dict(title='X'),
            yaxis=dict(title='Y'),
            zaxis=dict(title='Z')
        ),
        title="True vs Predicted Polarizability Comparison"
    )

    fig = go.Figure(data=[atom_trace, true_ellipsoid, pred_ellipsoid] + true_arrows + pred_arrows, layout=layout)

    # Display the plot
    fig.show()

In [55]:
compare_polarizabilities(
                atoms,
                real_part_true_denorm,
                real_part_predicted_denorm
            )

Raw Eigenvalues (True Tensor): [126.622375 -20.431532  53.057   ]
Real Eigenvalues (True Tensor): [126.622375 -20.431532  53.057   ]
Raw Eigenvalues (True Tensor): [40.249687 75.30716  64.74147 ]
Real Eigenvalues (True Tensor): [40.249687 75.30716  64.74147 ]


In [39]:
compare_polarizabilities(
                atoms,
                np.array(real_part_true.cpu().detach()),
                np.array(real_part_predicted.cpu().detach())
            )

NameError: name 'real_part_predicted' is not defined