In [57]:
import copy
import os
import os.path as osp
import csv
import util as ut
import pandas as pd

import torch.nn as nn
import torch.nn.functional as F
from torch.nn import GRU, Linear, ReLU, Sequential

import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.loader import DataLoader
from torch_geometric.nn import NNConv, Set2Set
from torch_geometric.utils import remove_self_loops
from torch_geometric.nn import global_mean_pool
import torch_geometric
import logging

from pathlib import Path
import trainer
import json 

from detanet_model import *
import wandb

In [58]:
import random

# Fix Python's global RNG so the random down-sampling of low-spectrum rows
# (which draws from `random.random()`) selects the same subset on every run.
random.seed(a=42)

In [59]:
# Training hyper-parameters
batch_size, epochs = 128, 60
lr = 5e-4
num_freqs = 61  # presumably the number of frequency points per molecule — TODO confirm

# Dataset down-sampling by spectrum intensity
high_spec_cutoff = 0.1  # rows whose spectrum value exceeds this are always kept
low_fraction = 0.006    # keep-probability for all other rows (~0.6%)

In [60]:

# Data lives in `<parent of the working directory>/data`.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
data_dir = os.path.join(parent_dir, 'data')

dataset = []  # populated by the CSV-parsing cell below

# Unique frequencies occurring in the polarizability table.
csv_path = f"{data_dir}/ee_polarizabilities_qm9s.csv"
frequencies = ut.load_unique_frequencies(csv_path)

# Molecular geometries, looked up later by molecule index.
csv_path_geometries = f"{data_dir}/KITqm9_geometries.csv"
geometries = ut.load_geometry(csv_path_geometries)

# Spectra; the (misspelled) name `sprectras` is kept because later cells use it.
csv_spectra = f"{data_dir}/DATA_QM9_reduced_2025_03_06.csv"
sprectras = ut.load_spectra(csv_spectra)

# Trained DetaNet weights, relative to the working directory.
model_path = "trained_param/ee_polarizabilities_all_freq_KITqm9_low_fraction_0_006.pth"

In [61]:
# Build the PyG dataset: one `Data` entry per (molecule, frequency) row of the
# polarizability CSV. Rows whose spectrum value exceeds `high_spec_cutoff` are
# always kept; the remaining "low-spec" rows are randomly kept with probability
# `low_fraction` (0.006 -> ~0.6%) using the seeded `random` RNG.
count = 0  # number of high-spec entries kept

with open(csv_path, newline='', encoding='utf-8') as csvfile:
    csv_reader = csv.reader(csvfile, delimiter=',')

    # Read the header to identify column indices
    header = next(csv_reader)
    frequency_idx = header.index("frequency")
    matrix_real_idx = header.index("matrix_real")
    matrix_imag_idx = header.index("matrix_imag")

    for row in csv_reader:
        # The first column holds the molecule index.
        try:
            idx = int(row[0])
        except ValueError:
            print("Can't read index:", row[0])
            continue

        try:
            freq_val = float(row[frequency_idx])
        except ValueError:
            continue

        # Only keep frequencies from the unique-frequency list.
        if freq_val not in frequencies:
            continue

        # Skip molecules that have no geometry entry.
        if idx not in geometries:
            continue
        mol = geometries[idx]

        spectrum_value = ut.get_closest_spectrum_value(sprectras, idx, freq_val)

        # Parse the JSON-encoded real and imaginary matrices (expected 3x3 each).
        try:
            real_3x3 = json.loads(row[matrix_real_idx])
        except json.JSONDecodeError:
            print("Warning: Could not parse real part of matrix for idx:", idx)
            continue

        try:
            imag_3x3 = json.loads(row[matrix_imag_idx])
        except json.JSONDecodeError:
            print("Warning: Could not parse imaginary part of matrix for idx:", idx)
            continue

        # Decide whether to keep the row *before* building any tensors, so the
        # ~99% of discarded low-spec rows cost nothing. The random draw happens
        # at the same point and under the same condition as before (only for
        # successfully parsed low-spec rows), so the sampled subset is unchanged.
        is_high_spec = spectrum_value > high_spec_cutoff
        if not is_high_spec and random.random() >= low_fraction:
            continue

        real_mat = torch.tensor(real_3x3, dtype=torch.float32)
        imag_mat = torch.tensor(imag_3x3, dtype=torch.float32)
        # Concatenate along the last dim: [3, 3] + [3, 3] -> [3, 6]
        # (columns 0-2 real part, columns 3-5 imaginary part).
        y = torch.cat([real_mat, imag_mat], dim=-1)

        data_entry = Data(
            idx=mol.idx,
            pos=mol.pos.to(torch.float32),               # atomic positions
            z=torch.as_tensor(mol.z, dtype=torch.long),  # atomic numbers
            freq=torch.tensor(freq_val, dtype=torch.float32),
            spec=torch.tensor(float(spectrum_value), dtype=torch.float32),
            y=y,  # polarizability tensor (target)
        )
        dataset.append(data_entry)
        if is_high_spec:
            count += 1
                

In [62]:

# Summary of the sampled dataset. The cutoff comes from the config cell instead
# of being hard-coded in the message, so the text stays correct if it changes.
print(f"Collected {count} high-spec (>{high_spec_cutoff}) entries.")
print(f"Total dataset length: {len(dataset)}")

# Spot-check two entries; fail with a clear message if sampling produced too few.
assert len(dataset) > 5, f"expected at least 6 entries, got {len(dataset)}"
ex1 = dataset[0]
ex2 = dataset[5]

print("dataset[0] :", ex1.idx, ex1.freq, ex1.spec)
print("dataset[5] :", ex2.idx, ex2.freq, ex2.spec)

Collected 302 high-spec (>0.1) entries.
Total dataset length: 1027
dataset[0] : 202 tensor(1.5498) tensor(1.4367e-05)
dataset[5] : 416 tensor(5.1918) tensor(1.9698e-05)


In [63]:
# Imported here so this cell does not silently depend on numpy being imported
# by a later cell (or leaking in from earlier kernel state).
import numpy as np

# Z-score normalise the spectrum feature using statistics of the collected dataset.
# NOTE: mutates `dataset` in place — re-running this cell normalises twice.
# NOTE(review): for inference with a pre-trained model these statistics should
# equal the ones used at training time; recomputing them here only matches if
# the sampled dataset is identical — confirm.
spec_values = [item.spec.item() for item in dataset]
spec_mean = np.mean(spec_values)
spec_std = np.std(spec_values)  # population std (ddof=0)

print("Spec mean, std =", spec_mean, spec_std)
for item in dataset:
    norm_val = (item.spec.item() - spec_mean) / (spec_std + 1e-8)  # epsilon avoids div by zero
    item.spec = torch.tensor(norm_val, dtype=torch.float32)

Spec mean, std = 0.051924891750768254 0.0951434506295883


In [64]:

import numpy as np

# Z-score normalise the targets. Each `item.y` is [3, 6]: columns 0-2 hold the
# real part and columns 3-5 the imaginary part of the 3x3 polarizability.
# One (mean, std) pair is computed per part, pooled over all 9 components of
# every entry.
# NOTE: mutates `dataset` in place — re-running this cell normalises twice.

# Stack all targets once ([N, 3, 6]) instead of growing Python lists element by
# element. float64 + row-major flattening reproduces exactly the values and
# order the list-based version fed to np.mean / np.std.
all_y = torch.stack([item.y for item in dataset]).to(torch.float64).numpy()
real_vals = all_y[..., :3].reshape(-1)
imag_vals = all_y[..., 3:].reshape(-1)

# compute mean, std (population std, ddof=0)
real_mean, real_std = np.mean(real_vals), np.std(real_vals)
imag_mean, imag_std = np.mean(imag_vals), np.std(imag_vals)

print("Real mean, std =", real_mean, real_std)
print("Imag mean, std =", imag_mean, imag_std)

# Transform each entry in place; the real and imaginary column blocks do not
# overlap, so the order of the two assignments does not matter.
for item in dataset:
    y = item.y  # [3, 6]
    y[:, :3] = (y[:, :3] - real_mean) / (real_std + 1e-8)
    y[:, 3:] = (y[:, 3:] - imag_mean) / (imag_std + 1e-8)

Real mean, std = 31.399704151449814 165.56148659246568
Imag mean, std = 30.816447998693935 226.93200375170045


In [65]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DetaNet architecture. `load_state_dict` (strict by default) in the checkpoint
# cell only succeeds if these settings match the saved model.
model = DetaNet(
    num_features=128,
    act='swish',
    maxl=3,
    num_block=3,
    radial_type='trainable_bessel',
    num_radial=32,
    attention_head=8,
    rc=5.0,
    dropout=0.0,
    use_cutoff=False,
    max_atomic_number=9,
    atom_ref=None,
    scale=1.0,
    # Output head: presumably real + imaginary rank-2 tensors. Earlier
    # alternatives were scalar_outsize=2, irreps_out='2e', out_type='2_tensor'.
    scalar_outsize=4,
    irreps_out='2x2e',
    summation=True,
    norm=False,
    out_type='complex_2_tensor',
    grad_type=None,
    device=device,
)

In [66]:
# Load the trained weights. `map_location` remaps stored tensors onto `device`,
# so a checkpoint saved on GPU also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (newer PyTorch versions accept weights_only=True to restrict unpickling).
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict=state_dict)
model.to(device)
# Inference mode: disables training-only behaviour (dropout, etc.); returns the model.
model.eval()

DetaNet(
  (Embedding): Embedding(
    (act): Swish()
    (elec_emb): Linear(in_features=16, out_features=128, bias=False)
    (nuclare_emb): Embedding(10, 128)
    (ls): Linear(in_features=128, out_features=128, bias=True)
  )
  (Radial): Radial_Basis(
    (radial): Bessel_Function()
  )
  (blocks): Sequential(
    (0): Interaction_Block(
      (message): Message(
        (Attention): Edge_Attention(
          (actq): Swish()
          (actk): Swish()
          (actv): Swish()
          (acta): Swish()
          (softmax): Softmax(dim=-1)
          (lq): Linear(in_features=128, out_features=128, bias=True)
          (lk): Linear(in_features=128, out_features=128, bias=True)
          (lv): Linear(in_features=128, out_features=256, bias=True)
          (la): Linear(in_features=256, out_features=256, bias=True)
          (lrbf): Linear(in_features=32, out_features=128, bias=False)
          (lkrbf): Linear(in_features=128, out_features=128, bias=False)
          (lvrbf): Linear(in_featu

In [67]:
SAMPLE_INDEX = 30  # arbitrary entry to inspect; change to look at another molecule/frequency
assert SAMPLE_INDEX < len(dataset), f"SAMPLE_INDEX={SAMPLE_INDEX} but dataset has {len(dataset)} entries"
sample = dataset[SAMPLE_INDEX]


In [None]:
import torch_cluster
import ase
from ase.io import read
from ase.visualize import view
from ase.build import molecule
# NOTE(review): `code` is also the name of a stdlib module; if it is already in
# sys.modules (e.g. imported via pdb/IPython) this import can fail with
# "'code' is not a package". Consider renaming the local package.
from code.util.visualize_polarizability import smiles_to_atoms, visualize_polarizability, compare_polarizabilities_eigen

# Single-molecule prediction. no_grad avoids building an autograd graph, which
# is not needed for inference. The output is presumably in the same z-score
# normalised units as `sample.y`, so the two can be compared directly.
with torch.no_grad():
    result = model(pos=sample.pos.to(device), z=sample.z.to(device), spec=sample.spec.to(device))
print(result)



tensor([[ 0.5754,  0.3170, -0.0187, -0.1610, -0.5234,  0.6148],
        [ 0.3170, -0.6707, -0.1661, -0.5234,  2.0553,  0.5919],
        [-0.0187, -0.1661,  0.4377,  0.6148,  0.5919,  2.3250]],
       device='cuda:0', grad_fn=<MulBackward0>)


In [69]:
# Ground-truth target for the same sample (normalised [3, 6]: real | imaginary).
sample["y"]

tensor([[ 0.2032, -0.3479, -0.0881,  0.3966, -0.7132,  1.2752],
        [-0.3479,  0.8839, -0.0064, -0.7132,  0.5099, -1.6574],
        [-0.0881, -0.0064,  0.6884,  1.2752, -1.6574,  3.6398]])