In [None]:
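# Starting from the four existing CSV files of broadened QM9 spectra, combine the molecule data and the spectrum data into one saved dataset, normalizing each spectrum individually.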

In [None]:
import pandas as pd
from scipy.special import wofz
import numpy as np
import matplotlib.pyplot as plt
import os
import shutil
from tqdm import trange
import seaborn as sns
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix
import torch
from torch_geometric.data import Data
from torch.utils.data import DataLoader

In [None]:
def one_hot_encoding(x, permitted_list):
    """
    One-hot encode ``x`` against ``permitted_list``.

    A value that does not occur in ``permitted_list`` is encoded as if it were
    the list's last element, so the last slot doubles as a catch-all bucket.

    Returns a list of 0/1 ints with the same length as ``permitted_list``.
    """
    # Fall back to the final entry for anything outside the vocabulary.
    target = x if x in permitted_list else permitted_list[-1]

    return [int(target == candidate) for candidate in permitted_list]
def get_atom_features(atom,
                      use_chirality = True,
                      hydrogens_implicit = True):
    """
    Encode one RDKit atom as a flat numeric feature vector.

    The vector concatenates, in this order: one-hot element symbol, one-hot
    degree, one-hot formal charge, one-hot hybridisation, ring-membership flag,
    aromaticity flag, scaled atomic mass, scaled van der Waals radius and scaled
    covalent radius. A one-hot chirality tag is appended when ``use_chirality``
    equals True, and a one-hot total hydrogen count is appended when
    ``hydrogens_implicit`` equals True. With the default arguments the result
    has 45 entries.

    Parameters
    ----------
    atom : RDKit atom to describe.
    use_chirality : append the chirality one-hot block.
    hydrogens_implicit : when equal to False, 'H' is added to the front of the
        element vocabulary and the hydrogen-count block is left out.

    Returns
    -------
    1-D numpy array holding the concatenated features.
    """
    symbol_vocab = ['C','N','O','S','F','P','Cl','Br','Unknown']
    # The flags are compared with == rather than by truthiness, so only values
    # equal to False/True (not merely falsy/truthy ones) toggle the branches.
    if hydrogens_implicit == False:
        symbol_vocab = ['H'] + symbol_vocab

    # Looked up once and reused for both radius queries below.
    periodic_table = Chem.GetPeriodicTable()
    atomic_number = atom.GetAtomicNum()

    features = []
    features += one_hot_encoding(str(atom.GetSymbol()), symbol_vocab)
    features += one_hot_encoding(int(atom.GetDegree()), [0, 1, 2, 3, 4, "MoreThanFour"])
    features += one_hot_encoding(int(atom.GetFormalCharge()), [-3, -2, -1, 0, 1, 2, 3, "Extreme"])
    features += one_hot_encoding(str(atom.GetHybridization()), ["S", "SP", "SP2", "SP3", "SP3D", "SP3D2", "OTHER"])
    features += [int(atom.IsInRing()), int(atom.GetIsAromatic())]
    # Continuous descriptors, each shifted and divided by fixed constants.
    features += [
        float((atom.GetMass() - 10.812)/116.092),
        float((periodic_table.GetRvdw(atomic_number) - 1.5)/0.6),
        float((periodic_table.GetRcovalent(atomic_number) - 0.64)/0.76),
    ]

    if use_chirality == True:
        features += one_hot_encoding(str(atom.GetChiralTag()), ["CHI_UNSPECIFIED", "CHI_TETRAHEDRAL_CW", "CHI_TETRAHEDRAL_CCW", "CHI_OTHER"])
    if hydrogens_implicit == True:
        features += one_hot_encoding(int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4, "MoreThanFour"])

    return np.array(features)


def get_bond_features(bond, 
                      use_stereochemistry = True):
    """
    Encode one RDKit bond as a flat numeric feature vector.

    The vector holds a one-hot bond type (single, double, triple, aromatic; any
    other type falls into the aromatic slot because it is last in the list), a
    conjugation flag and a ring-membership flag. When ``use_stereochemistry``
    equals True a one-hot stereo descriptor is appended, giving 10 entries in
    total with the default argument.

    Returns a 1-D numpy array.
    """
    bond_kinds = [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    ]

    features = one_hot_encoding(bond.GetBondType(), bond_kinds)
    features = features + [int(bond.GetIsConjugated()), int(bond.IsInRing())]

    # Compared with == rather than by truthiness, matching the flag's contract.
    if use_stereochemistry == True:
        features += one_hot_encoding(str(bond.GetStereo()), ["STEREOZ", "STEREOE", "STEREOANY", "STEREONONE"])

    return np.array(features)

In [None]:
def get_feature_and_labels(i, frame = None):
    """
    Turn row ``i`` of the spectra table into a torch_geometric ``Data`` graph.

    The row is read positionally: column 0 is the SMILES string and every
    following column is one spectrum intensity.

    Parameters
    ----------
    i : int
        Positional row index into the table.
    frame : pandas.DataFrame, optional
        Table to read from. When omitted, the module-level ``data`` DataFrame is
        used, so existing single-argument calls keep working.

    Returns
    -------
    torch_geometric.data.Data with
        x          : (n_atoms, 45) float tensor of atom features
        edge_index : (2, 2 * n_bonds) long tensor, both directions of each bond
        edge_attr  : (2 * n_bonds, 10) float tensor of bond features
        y          : 50 values; the spectrum is divided by its own maximum and
                     then summed over consecutive windows of 90 points
        smi        : the SMILES string of the molecule

    Raises
    ------
    ValueError
        If RDKit cannot parse the SMILES string, or if the spectrum maximum is
        zero (normalisation would divide by zero).
    """
    n_atom_features = 45   # length of get_atom_features() output with default flags
    n_bond_features = 10   # length of get_bond_features() output with default flags
    spectrum_len = 4500    # number of spectrum points consumed by the binning
    bin_width = 90         # points per output bin -> 4500 / 90 = 50 bins

    if frame is None:
        frame = data  # module-level DataFrame assigned by the calling cell

    # Index with a (row, column) pair: frame.iloc[i][0] would use an integer key
    # on a label-indexed Series, which recent pandas versions deprecate.
    smi = frame.iloc[i, 0]
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        raise ValueError(f"row {i}: RDKit could not parse SMILES {smi!r}")

    n_nodes = mol.GetNumAtoms()
    n_edges = 2*mol.GetNumBonds()

    # node feature matrix X of shape (n_nodes, n_atom_features)
    X = np.zeros((n_nodes, n_atom_features))
    for atom in mol.GetAtoms():
        X[atom.GetIdx(), :] = get_atom_features(atom)
    X = torch.tensor(X, dtype = torch.float)

    # edge index array E of shape (2, n_edges); the adjacency matrix is
    # symmetric, so every bond contributes one entry per direction
    (rows, cols) = np.nonzero(GetAdjacencyMatrix(mol))
    E = torch.from_numpy(np.stack([rows, cols]).astype(np.int64))

    # edge feature array EF of shape (n_edges, n_bond_features)
    EF = np.zeros((n_edges, n_bond_features))
    for (k, (src, dst)) in enumerate(zip(rows, cols)):
        EF[k] = get_bond_features(mol.GetBondBetweenAtoms(int(src), int(dst)))
    EF = torch.tensor(EF, dtype = torch.float)

    # target y: per-spectrum max normalisation followed by coarse binning
    intensities = frame.iloc[i, 1:].tolist()
    scale = max(intensities)
    if scale == 0:
        raise ValueError(f"row {i} ({smi!r}): spectrum maximum is 0, cannot normalise")
    normalised = [value / scale for value in intensities]
    binned = [sum(normalised[start:start + bin_width])
              for start in range(0, spectrum_len, bin_width)]
    y = torch.tensor(binned)

    return Data(x=X,edge_index=E,edge_attr=EF,y=y,smi = smi)

In [None]:
# Convert each of the four broadened-spectrum CSV parts into a list of graph
# samples and save every list to its own .pt file.
part_ids = [1,2,3,4]
for part_id in part_ids:
    print(part_id)
    csv_path = f"/home/chengc/workspace/cc/new_exp_0427/data/qm9_ir_broaden/qm9_ir_broaden_part{part_id}.csv"
    # The name `data` must be kept: get_feature_and_labels reads it as a global.
    data = pd.read_csv(csv_path)
    graphs = [get_feature_and_labels(row) for row in trange(len(data))]
    out_path = f"/home/chengc/workspace/cc/new_exp_0427/data/qm9_ir_broaden/qm9_input_part{part_id}.pt"
    torch.save(graphs, out_path)

In [None]:
# Reload the four per-part sample lists, concatenate them in part order and
# save the merged dataset as a single file.
part_template = '/home/chengc/workspace/cc/new_exp_0427/data/qm9_ir_broaden/qm9_input_part{}.pt'
part1, part2, part3, part4 = [torch.load(part_template.format(k)) for k in (1, 2, 3, 4)]
QM9_input = [*part1, *part2, *part3, *part4]
torch.save(QM9_input,'/home/chengc/workspace/cc/new_exp_0427/data/qm9_ir_broaden/qm9_input.pt')

In [None]:
# Sanity check: number of samples in the merged list built in the previous cell.
len(QM9_input)