In [None]:
import torch
import pickle
import os

from module.molecule_dataset import DiagonalDataPreprocessor

# Directory containing the per-molecule pickle files, named
# molecule_0.pkl, molecule_1.pkl, ... (see the f-string below).
data_folder_path = './QM9_pyscf'

# How many consecutive molecule files (starting at index 0) to load.
# Kept as a named setting so the subset size can be changed in one place.
num_molecules = 5000

# SECURITY NOTE: pickle.load can execute arbitrary code while deserializing.
# Only point this cell at files you generated yourself or otherwise trust.
raw_data = []
for i in range(num_molecules):
    molecule_path = os.path.join(data_folder_path, f'molecule_{i}.pkl')
    with open(molecule_path, 'rb') as f:
        data = pickle.load(f)
    # The file is already closed here; only the deserialized object is kept.
    raw_data.append(data)

# Preprocessor instance shared by the cells below, constructed with the
# 'sto-3g' basis setting.
preprocessor = DiagonalDataPreprocessor(basis='sto-3g')

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


In [2]:
from tools.AtomAtom import AtomBlockDecomposer

# Basis entries taken from the preprocessor: the "H" entry is used for the
# hydrogen family and the "C" entry for the second-period family.
atom_basis_dict = dict(
    hydrogen=preprocessor.atom_basis_dict["H"],
    second_period=preprocessor.atom_basis_dict["C"],
)


def _join_irreps(irreps_structure):
    """Join each inner group with '+', then join all groups with '+'."""
    return "+".join("+".join(coupled_channel) for coupled_channel in irreps_structure)


# Each decomposer is built from an atom basis paired with itself.
hydrogen_decomposer = AtomBlockDecomposer(
    atom_basis_dict["hydrogen"],
    atom_basis_dict["hydrogen"],
)
second_period_decomposer = AtomBlockDecomposer(
    atom_basis_dict["second_period"],
    atom_basis_dict["second_period"],
)

# Nested irreps structure exposed by each decomposer.
hydrogen_mf_irreps_structure = hydrogen_decomposer.all_decomposed_irreps
second_period_mf_irreps_structure = second_period_decomposer.all_decomposed_irreps

# Flat, '+'-separated string form of the nested structures above.
hydrogen_mf_irreps = _join_irreps(hydrogen_mf_irreps_structure)
second_period_mf_irreps = _join_irreps(second_period_mf_irreps_structure)

In [3]:
def get_carbon_feature_and_label(original_data):
    """Keep only the second-period rows whose one-hot equals [0, 1, 0, 0, 0].

    Args:
        original_data: Mapping with a "second_period" entry that holds the
            tensors "one_hot", "sad_decomposed" and "hf_dm_decomposed".
            "one_hot" is compared row-wise against a 5-element pattern, so it
            is expected to be 2-D with 5 columns. The three tensors are
            presumably aligned along dim 0 (one row per atom) -- verify
            against the preprocessor.

    Returns:
        dict with two entries:
            "feature": rows of "sad_decomposed" selected by the mask.
            "label":   rows of "hf_dm_decomposed" selected by the mask.
    """
    second_period = original_data["second_period"]
    one_hot = second_period['one_hot']

    # Create the reference pattern on the same device as the data. A tensor
    # built on the default device cannot be compared element-wise with data
    # that lives on another device (e.g. a GPU).
    target_one_hot = torch.tensor([0, 1, 0, 0, 0], device=one_hot.device)

    # True for every row whose entries all match the reference pattern.
    mask = torch.all(one_hot == target_one_hot, dim=1)

    return {
        "feature": second_period['sad_decomposed'][mask],
        "label": second_period['hf_dm_decomposed'][mask],
    }

In [7]:
from tqdm import tqdm

def get_all_carbon_features_and_labels(raw_data, cutoff=2.5):
    """Preprocess every molecule and stack the selected features and labels.

    Each item of ``raw_data`` is passed through the module-level
    ``preprocessor`` and then filtered by ``get_carbon_feature_and_label``.
    The per-molecule results are concatenated along dim 0.

    Args:
        raw_data: Iterable of raw molecule records accepted by
            ``preprocessor.preprocess``.
        cutoff: Value forwarded to ``preprocessor.preprocess`` as ``cutoff``.
            Defaults to 2.5, the value this notebook originally hard-coded.
            Its meaning and units are defined by the preprocessor, which is
            not visible here.

    Returns:
        dict with two entries:
            "features": all per-molecule "feature" tensors concatenated.
            "labels":   all per-molecule "label" tensors concatenated.

    Raises:
        Whatever ``torch.cat`` raises for an empty list if ``raw_data`` yields
        no items (there is nothing to concatenate).
    """
    features_per_molecule = []
    labels_per_molecule = []

    for molecule in tqdm(raw_data):
        preprocessed = preprocessor.preprocess(molecule, cutoff=cutoff)
        carbon_data = get_carbon_feature_and_label(preprocessed)
        features_per_molecule.append(carbon_data["feature"])
        labels_per_molecule.append(carbon_data["label"])

    return {
        "features": torch.cat(features_per_molecule, dim=0),
        "labels": torch.cat(labels_per_molecule, dim=0),
    }

In [8]:
# Run preprocessing and filtering over every loaded molecule; `result` is a
# dict with the concatenated "features" and "labels" tensors.
result = get_all_carbon_features_and_labels(raw_data)

100%|██████████| 5000/5000 [01:19<00:00, 62.96it/s]


In [9]:
# Serialize `result` to ./dataset.pkl (relative to the working directory).
with open('./dataset.pkl', mode='wb') as dataset_file:
    pickle.dump(result, dataset_file)

(torch.Size([23769, 25]), torch.Size([23769, 25]))