In [1]:
import copy
import os
import os.path as osp
import csv
import utils as ut
import util_load_data as ud
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import GRU, Linear, ReLU, Sequential

import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import NNConv, Set2Set
from torch_geometric.utils import remove_self_loops
from torch_geometric.nn import global_mean_pool
import torch_geometric
import matplotlib.pyplot as plt

from pathlib import Path
import trainer
import json 
import wandb
import random

from detanet_model import *

# -------------------------------
# Config
# -------------------------------
random.seed(42)
batch_size = 64
epochs = 10
lr = 5e-4

normalize = False
fine_tune = False 
pol_type = 'ee'

current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
data_dir   = os.path.join(parent_dir, 'data')

# -------------------------------
# Build each dataset
# -------------------------------
HOPV_dataset = ud.build_KITDatasets(
    data_dir,
    geometry_file="HOPV_geometries.csv", 
    spectra_file="HarvardOPV_original_2025_04_12_40states.csv",
    polarizability_file="polarizabilities_HOPV.csv",
    fun_type='l',
    pol_type=pol_type
)

KITqm9_dataset = ud.build_KITDatasets(
    data_dir,
    geometry_file="KITqm9_geometries.csv",
    spectra_file="DATA_QM9_reduced_2025_03_06.csv",
    polarizability_file="polarizabilities_qm9.csv",
    fun_type='l',
    pol_type=pol_type
)

# Merge into a single dataset
dataset = HOPV_dataset + KITqm9_dataset
print(f"Combined dataset size: {len(dataset)}")

ex1 = dataset[0]
print("ex1.y.shape:", ex1.polar.shape)  # Only if y is defined


Loaded 347 polarizability entries.
Loaded 2027 polarizability entries.
Combined dataset size: 2373
ex1.y.shape: torch.Size([62, 3, 6])


In [5]:
# -------------------------------
# Create Model
# -------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = DetaNet(
    num_features=128,
    act='swish',
    maxl=3,
    num_block=3,
    radial_type='trainable_bessel',
    num_radial=32,
    attention_head=8,
    rc=5.0,
    dropout=0.0,
    use_cutoff=False,
    max_atomic_number=34,
    atom_ref=None,
    scale=1.0,
    scalar_outsize=(4*62),
    irreps_out='124x2e',
    summation=True,
    norm=False,
    out_type='complex_61_tensor', # e.g. your custom config
    grad_type=None,
    device=device
)

params='/media/maria/work_space/detanet-complex/code/trained_param/pol_spec.pth'
state_dict = torch.load(params)
model.to(device)
model.load_state_dict(state_dict=state_dict)

<All keys matched successfully>

In [6]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch

def plot_polarizability_elements(dataset, model, out_dir="."):
    """
    Assumptions:
      - data_entry.real -> shape [N, 3, 3]
      - data_entry.imag -> shape [N, 3, 3]
      - The model returns shape [N, 3, 6], where:
          [:, :, 0:3] => real part (3 columns -> [N, 3, 3])
          [:, :, 3:6] => imag part (3 columns -> [N, 3, 3])
    We then plot the 6 unique elements of a symmetric 3×3 matrix:
      (0,0)=xx, (0,1)=xy, (0,2)=xz, (1,1)=yy, (1,2)=yz, (2,2)=zz
    """

    # 1) Find the correct data entry
    data_entry = dataset[0]  # Assuming you want to plot the first entry
    mol_idx = data_entry.mol_idx
    print(f"Mol idx = {mol_idx}")

    # 2) Convert GT real & imag to NumPy arrays of shape [N, 3, 3]
    real_list = np.array(data_entry.real)  # shape [N, 3, 3]
    imag_list = np.array(data_entry.imag)  # shape [N, 3, 3]
    N = real_list.shape[0]
    print(f"Number of frequency steps = {N}")

    # 3) Get predicted polarizabilities from the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    z = data_entry.z.to(device)
    pos = data_entry.pos.to(device)

    with torch.no_grad():
        # pol_spec has shape [N, 3, 6]
        pol_spec = model(z=z, pos=pos).cpu().numpy()
        print("pol_spec", pol_spec)
        print("gt", real_list[0])

    # Split out real vs. imaginary from the last dimension
    # first 3 columns => real => shape [N, 3, 3]
    # last 3 columns  => imag => shape [N, 3, 3]
    pred_pol_real = pol_spec[..., 0:3]
    pred_pol_imag = pol_spec[..., 3:6]

    # 4) We only need the 6 unique elements of the symmetrical 3×3
    pairs = [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]
    labels = ["xx", "xy", "xz", "yy", "yz", "zz"]

    # ---------------- PLOT REAL PART ----------------
    fig_real, axes_real = plt.subplots(nrows=3, ncols=2, figsize=(10, 12))
    fig_real.suptitle(f"Polarizability (Real) - Mol idx={mol_idx}", fontsize=14)

        # --- 4) Determine global min/max for Real part ---
    real_values = []
    real_pred_values = []
    for (i, j) in pairs:
        # Gather GT values and predicted values
        gt_vals = real_list[:, i, j]
        pred_vals = pred_pol_real[:, i, j]  # If shape is [N, 3, 3], check indexing
        real_values.extend(gt_vals)
        real_pred_values.extend(pred_vals)


    global_min_real = -50 #min(real_values)
    global_max_real = 200 #max(real_values)


    for subplot_idx, (i, j) in enumerate(pairs):
        row = subplot_idx // 2
        col = subplot_idx % 2
        ax = axes_real[row, col]

        # Ground-truth
        gt_vals = real_list[:, i, j]          # shape [N]
        # Predicted
        pred_vals = pred_pol_real[:, i, j]    # shape [N]

        ax.scatter(range(N), gt_vals,  label="GT Real",  color="purple", s=0.9)
        ax.scatter(range(N), pred_vals, label="Pred Real", color="red",  alpha=1)

        ax.set_title(f"{labels[subplot_idx]} (real)")
        ax.set_xlabel("Index (freq step)")
        ax.set_ylabel("Polarizability")
        ax.legend()
        ax.grid(True)
        ax.set_ylim(global_min_real, global_max_real)

    plt.tight_layout()
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, f"polar_elements_real_{mol_idx}.png"), dpi=300)
    plt.show()
    plt.close(fig_real)

    # ---------------- PLOT IMAG PART ----------------
    fig_imag, axes_imag = plt.subplots(nrows=3, ncols=2, figsize=(10, 12))
    fig_imag.suptitle(f"Polarizability (Imag) - Mol idx={mol_idx}", fontsize=14)

        # --- 6) Determine global min/max for Imag part ---
    imag_values = []
    for (i, j) in pairs:
        gt_vals = imag_list[:, i, j]
        pred_vals = pred_pol_imag[i, j]
        imag_values.extend(gt_vals)
        if np.isscalar(pred_vals):
            imag_values.append(pred_vals)
        else:
            imag_values.extend(pred_vals)

    global_min_imag = -50# min(imag_values)
    global_max_imag = 100 # max(imag_values)

    for subplot_idx, (i, j) in enumerate(pairs):
        row = subplot_idx // 2
        col = subplot_idx % 2
        ax = axes_imag[row, col]

        # Ground-truth
        gt_vals = imag_list[:, i, j]          # shape [N]

        # Predicted
        pred_vals = pred_pol_imag[:, i, j]    # shape [N]

        ax.scatter(range(N), gt_vals,  label="GT Imag",  color="purple", s=0.9)
        ax.scatter(range(N), pred_vals, label="Pred Imag", color="red",  alpha=1)

        ax.set_title(f"α_{labels[subplot_idx]} (imag)")
        ax.set_xlabel("Index (freq step)")
        ax.set_ylabel("Polarizability")
        ax.legend()
        ax.grid(True)
        ax.set_ylim(global_min_imag, global_max_imag)

    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, f"polar_elements_imag_{mol_idx}.png"), dpi=300)
    plt.show()
    plt.close(fig_imag)


In [7]:

plot_polarizability_elements(dataset,model, out_dir="/media/maria/work_space/detanet-complex/code/plots")

AttributeError: 'GlobalStorage' object has no attribute 'mol_idx'