In [1]:
import os
import collections
import numpy as np
import pandas as pd
import pickle
import shutil
import torch
import matplotlib.pyplot as plt

In [2]:
# Input / output directories used throughout the notebook.
ibm_path = 'data/finetune_datasets_from_molformer/qm9/'  # holds the qm9*.csv files
ori_path = 'data/dsgdb9nsd.xyz/'  # holds the per-molecule files parsed below
save_path = 'data/qm9_z_and_pos/'  # destination of the generated pickles

# CSV file names inside ibm_path: the full table and the three splits.
ibm_all, ibm_train, ibm_valid, ibm_test = (
    'qm9.csv',
    'qm9_train.csv',
    'qm9_valid.csv',
    'qm9_test.csv',
)

# Directory listing in sorted order. Later cells index this list with
# (numeric molecule id - 1), so the sorted position must line up with the id.
ori_file = os.listdir(ori_path)
ori_file.sort()

# Element symbol -> atomic number.
ele2num = dict(H=1, C=6, N=7, O=8, F=9, S=16)

In [3]:
# Count how many source files contain the '*^' marker anywhere in their text.
cnt = 0

for file_name in ori_file:
    with open(os.path.join(ori_path, file_name), 'r') as f:
        s = ''.join(f.readlines())
    # A bool adds as 0 or 1, so this bumps the counter only on a hit.
    cnt += '*^' in s

cnt  # 176, just a little

176

In [4]:
def get_u0(s):
    """Return the u0 value of one molecule file.

    Parameters
    ----------
    s: list of str
        Lines of the file, as returned by ``readlines()``.

    Returns
    -------
    float
        The 13th whitespace-separated field (index 12) of the second line.
    """
    property_fields = s[1].split()
    return float(property_fields[12])


def get_smile(s):
    """Return the SMILES string of one molecule file.

    Parameters
    ----------
    s: list of str
        Lines of the file, as returned by ``readlines()``.

    Returns
    -------
    str
        The first whitespace-separated token of the second-to-last line.
    """
    tokens = s[-2].split()
    return tokens[0]


def get_z(s):
    """Return the atomic numbers of all atoms in one molecule file.

    Parameters
    ----------
    s: list of str
        Lines of the file, as returned by ``readlines()``. The first token of
        the first line is the atom count; the atom records start at the third
        line, each beginning with an element symbol.

    Returns
    -------
    list of int
        One atomic number per atom, looked up in ``ele2num``, in file order.
    """
    atom_count = int(s[0].split()[0])
    # Atom records occupy rows 2 .. atom_count + 1 (rows 0 and 1 are headers).
    return [ele2num[s[row].split()[0]] for row in range(2, atom_count + 2)]


def get_pos(s):
    """Return the xyz coordinates of all atoms in one molecule file.

    Parameters
    ----------
    s: list of str
        Lines of the file, as returned by ``readlines()``. The first token of
        the first line is the atom count; the atom records start at the third
        line and carry the coordinates in the three fields that follow the
        element symbol.

    Returns
    -------
    list of list of float
        One ``[x, y, z]`` triple per atom, in file order.

    Raises
    ------
    IndexError
        If an atom record is missing or has fewer than three coordinate fields.
    ValueError
        If a coordinate field cannot be parsed as a number.
    """
    num_atoms = int(s[0].split()[0])

    def process(token):
        """Parse one coordinate field, treating ``*^`` as the exponent marker ``e``.

        It is not entirely clear what ``*^`` stands for; here it is read as
        scientific notation. The reasoning: the data behind the
        torch_geometric QM9 dataset (link below) stores every entry that
        contains ``*^`` as 0.0, so the values must be tiny. Reading ``*`` as
        "remaining digits omitted" and ``^`` as the exponent would not fit,
        because e.g. gdb_10422's ``-1.0119*^-6`` would then come out at about
        -0.9, which is far too large.

        Link:
            https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip

        Notes on the sdf file contained in that archive:
            https://blog.csdn.net/weixin_43135178/article/details/128566390
        """
        # Let float() parse "<mantissa>e<exponent>" in one step. Computing
        # mantissa * 10**exponent by hand rounds twice (once for the power of
        # ten, once for the product) and can be off in the last bit.
        return float(token.replace('*^', 'e'))

    ret = []
    for i in range(num_atoms):
        fields = s[i + 2].split()
        # fields[0] is the element symbol; the next three are x, y, z.
        ret.append([process(fields[1]), process(fields[2]), process(fields[3])])

    return ret


# Spot-check get_pos: index 0 is a plain file, in the other four files *^ exists.
for sample_index in (0, 211, 4723, 10421, 17521):
    with open(os.path.join(ori_path, ori_file[sample_index]), 'r') as f:
        s = f.readlines()
        print(get_pos(s))

[[-0.0126981359, 1.0858041578, 0.0080009958], [0.002150416, -0.0060313176, 0.0019761204], [1.0117308433, 1.4637511618, 0.0002765748], [-0.540815069, 1.4475266138, -0.8766437152], [-0.5238136345, 1.4379326443, 0.9063972942]]
[[2.1997e-06, 1.4462618059, 0.0098312216], [-0.0395853465, -0.0001481182, 0.0022628164], [-0.0102031441, -0.7914294915, -1.1197711939], [0.0424480479, -2.1071237431, -0.7205292439], [0.051416093, -2.1148079393, 0.7011850782], [0.0029673515, -0.8035395846, 1.1152367693], [1.027926131, 1.828132226, 0.0028554018], [-0.5211791052, 1.8337374247, -0.8690329378], [-0.5054481987, 1.8237838058, 0.9020802045], [-0.0304306812, -0.3464381728, -2.1023011685], [0.0571095205, -2.9656370805, -1.3743309873], [0.0741869784, -2.9803348886, 1.3454340259], [-0.0055890963, -0.3691945839, 2.1026885337]]
[[0.0133507793, 1.5040546487, 0.006989626], [1.6733e-06, 0.0133769095, -0.0262146453], [-0.078210901, -0.8319277405, -1.0873479499], [-0.056644575, -2.1143654586, -0.6279941804], [0.035907

In [5]:
# inspect ibm all
# Compare every source file, in sorted order, against the same-numbered CSV row.
df = pd.read_csv(os.path.join(ibm_path, ibm_all))
ibm_smile, ibm_u0 = df['smiles'].tolist(), df['u0'].tolist()

# Mismatching values are collected pairwise (ibm side / ori side).
diff_smile_ibm, diff_smile_ori = [], []
diff_u0_ibm, diff_u0_ori = [], []

for i, file_name in enumerate(ori_file):
    with open(os.path.join(ori_path, file_name), 'r') as f:
        s = f.readlines()

    u0, smile = get_u0(s), get_smile(s)

    if u0 != ibm_u0[i]:
        diff_u0_ori.append(u0)
        diff_u0_ibm.append(ibm_u0[i])

    if smile != ibm_smile[i]:
        diff_smile_ori.append(smile)
        diff_smile_ibm.append(ibm_smile[i])

print('all')
print(f'total: ibm {len(df)}, ori {len(ori_file)}')  # total: ibm 133885, ori 133885
print(f'{len(diff_u0_ibm)} u0 diff, {len(diff_smile_ibm)} smile diff')  # u0: 0, smile: 26059
print(f'u0: ibm{diff_u0_ibm}, ori{diff_u0_ori}')
print(f'smile head 3: ibm{diff_smile_ibm[:3]}, ori{diff_smile_ori[:3]}')
print(f'smile tail 3: ibm{diff_smile_ibm[-3:]}, ori{diff_smile_ori[-3:]}')

all
total: ibm 133885, ori 133885
0 u0 diff, 26059 smile diff
u0: ibm[], ori[]
smile head 3: ibm['C(=O)N', 'CC(=O)C', 'CC(=O)N'], ori['NC=O', 'CC(C)=O', 'CC(N)=O']
smile tail 3: ibm['C(CCO)CC(F)(F)F', 'C(CO)NCC(F)(F)F', 'C(COCC(F)(F)F)O'], ori['OCCCCC(F)(F)F', 'OCCNCC(F)(F)F', 'OCCOCC(F)(F)F']


In [6]:
# inspect ibm train
# Each CSV row is matched to its source file through the number in mol_id.
df = pd.read_csv(os.path.join(ibm_path, ibm_train))
ibm_smile, ibm_u0 = df['smiles'].tolist(), df['u0'].tolist()
ibm_id = df['mol_id']

# Mismatching values are collected pairwise (ibm side / ori side).
diff_smile_ibm, diff_smile_ori = [], []
diff_u0_ibm, diff_u0_ori = [], []

for i in range(len(df)):
    # Drop the first four characters of mol_id (presumably a 'gdb_' prefix)
    # and turn the 1-based number into a 0-based position in ori_file.
    file_position = int(ibm_id[i][4:]) - 1
    with open(os.path.join(ori_path, ori_file[file_position]), 'r') as f:
        s = f.readlines()

    u0, smile = get_u0(s), get_smile(s)

    if u0 != ibm_u0[i]:
        diff_u0_ori.append(u0)
        diff_u0_ibm.append(ibm_u0[i])

    if smile != ibm_smile[i]:
        diff_smile_ori.append(smile)
        diff_smile_ibm.append(ibm_smile[i])

print('train')
print(f'total: ibm {len(df)}, ori {len(df)}')  # ibm 108446, ori 108446
print(f'{len(diff_u0_ibm)} u0 diff, {len(diff_smile_ibm)} smile diff')  # u0: 0, smile: 21152
print(f'u0: ibm{diff_u0_ibm}, ori{diff_u0_ori}')
print(f'smile head 3: ibm{diff_smile_ibm[:3]}, ori{diff_smile_ori[:3]}')
print(f'smile tail 3: ibm{diff_smile_ibm[-3:]}, ori{diff_smile_ori[-3:]}')

train
total: ibm 108446, ori 108446
0 u0 diff, 21152 smile diff
u0: ibm[], ori[]
smile head 3: ibm['Cc1c2c(c[nH]1)CC=C2', 'C1CC1CC2CC2', 'c1(noc(=O)o1)N'], ori['CC1=C2C=CCC2=CN1', 'C(C1CC1)C1CC1', 'NC1=NOC(=O)O1']
smile tail 3: ibm['c1nc(c2n1CCO2)O', 'c1c(oc(n1)CO)C#N', 'C(#CC#N)C#CC(=O)N'], ori['OC1=C2OCCN2C=N1', 'OCC1=NC=C(O1)C#N', 'NC(=O)C#CC#CC#N']


In [7]:
# inspect ibm valid
# Each CSV row is matched to its source file through the number in mol_id.
df = pd.read_csv(os.path.join(ibm_path, ibm_valid))
ibm_smile, ibm_u0 = df['smiles'].tolist(), df['u0'].tolist()
ibm_id = df['mol_id']

# Mismatching values are collected pairwise (ibm side / ori side).
diff_smile_ibm, diff_smile_ori = [], []
diff_u0_ibm, diff_u0_ori = [], []

for i in range(len(df)):
    # Drop the first four characters of mol_id (presumably a 'gdb_' prefix)
    # and turn the 1-based number into a 0-based position in ori_file.
    file_position = int(ibm_id[i][4:]) - 1
    with open(os.path.join(ori_path, ori_file[file_position]), 'r') as f:
        s = f.readlines()

    u0, smile = get_u0(s), get_smile(s)

    if u0 != ibm_u0[i]:
        diff_u0_ori.append(u0)
        diff_u0_ibm.append(ibm_u0[i])

    if smile != ibm_smile[i]:
        diff_smile_ori.append(smile)
        diff_smile_ibm.append(ibm_smile[i])

print('valid')
print(f'total: ibm {len(df)}, ori {len(df)}')  # ibm 12050, ori 12050
print(f'{len(diff_u0_ibm)} u0 diff, {len(diff_smile_ibm)} smile diff')  # u0: 0, smile: 2311
print(f'u0: ibm{diff_u0_ibm}, ori{diff_u0_ori}')
print(f'smile head 3: ibm{diff_smile_ibm[:3]}, ori{diff_smile_ori[:3]}')
print(f'smile tail 3: ibm{diff_smile_ibm[-3:]}, ori{diff_smile_ori[-3:]}')

valid
total: ibm 12050, ori 12050
0 u0 diff, 2311 smile diff
u0: ibm[], ori[]
smile head 3: ibm['c1([nH]c(=O)nc(n1)O)N', 'c1noc(=NCCO)o1', 'CCC(CC)OCC'], ori['NC1=NC(O)=NC(=O)N1', 'OCCN=C1OC=NO1', 'CCOC(CC)CC']
smile tail 3: ibm['Cc1c(ncnc1F)N', 'c1c(c(c[nH]1)C=O)CO', 'C#Cc1c([nH]nn1)CO'], ori['CC1=C(N)N=CN=C1F', 'OCC1=CNC=C1C=O', 'OCC1=C(N=NN1)C#C']


In [8]:
# inspect ibm test
# Each CSV row is matched to its source file through the number in mol_id.
df = pd.read_csv(os.path.join(ibm_path, ibm_test))
ibm_smile, ibm_u0 = df['smiles'].tolist(), df['u0'].tolist()
ibm_id = df['mol_id']

# Mismatching values are collected pairwise (ibm side / ori side).
diff_smile_ibm, diff_smile_ori = [], []
diff_u0_ibm, diff_u0_ori = [], []

for i in range(len(df)):
    # Drop the first four characters of mol_id (presumably a 'gdb_' prefix)
    # and turn the 1-based number into a 0-based position in ori_file.
    file_position = int(ibm_id[i][4:]) - 1
    with open(os.path.join(ori_path, ori_file[file_position]), 'r') as f:
        s = f.readlines()

    u0, smile = get_u0(s), get_smile(s)

    if u0 != ibm_u0[i]:
        diff_u0_ori.append(u0)
        diff_u0_ibm.append(ibm_u0[i])

    if smile != ibm_smile[i]:
        diff_smile_ori.append(smile)
        diff_smile_ibm.append(ibm_smile[i])

print('test')
print(f'total: ibm {len(df)}, ori {len(df)}')  # ibm 13389, ori 13389
print(f'{len(diff_u0_ibm)} u0 diff, {len(diff_smile_ibm)} smile diff')  # u0: 0, smile: 2596
print(f'u0: ibm{diff_u0_ibm}, ori{diff_u0_ori}')
print(f'smile head 3: ibm{diff_smile_ibm[:3]}, ori{diff_smile_ori[:3]}')
print(f'smile tail 3: ibm{diff_smile_ibm[-3:]}, ori{diff_smile_ori[-3:]}')

test
total: ibm 13389, ori 13389
0 u0 diff, 2596 smile diff
u0: ibm[], ori[]
smile head 3: ibm['CC(C)COC=NC', 'CN=C(C(=O)N)N1CC1', 'CC1(CC(=O)CO1)C'], ori['CN=COCC(C)C', 'CN=C(N1CC1)C(N)=O', 'CC1(C)CC(=O)CO1']
smile tail 3: ibm['C1CC(=NO)C(=C1)C=O', 'CCOc1c(nno1)N', 'CNc1ncnn1C'], ori['ON=C1CCC=C1C=O', 'CCOC1=C(N)N=NO1', 'CNC1=NC=NN1C']


In [9]:
def process_and_save(mode):
    """Collect z and pos for one split and pickle them to ``save_path``.

    For every row of the split's CSV, the numeric part of ``mol_id`` selects
    the matching entry of ``ori_file``; that file is parsed with ``get_z`` and
    ``get_pos``. The result is written to ``qm9_{mode}.pkl`` inside
    ``save_path`` as a dict with two keys:

    z is a 2D list, z[0] represents the first molecule, z[0][0] is the atomic
    number of one of the atoms in z[0].
    pos is a 3D list, pos[0] represents the first molecule, pos[0][0] is the
    xyz-coordinate of one of the atoms in the first molecule, pos[0][0][0] is
    x-coordinate (unit: Angstrom).

    Parameters
    ----------
    mode: str
        One of 'train', 'valid' or 'test'.

    Returns
    -------
    None
        The data is only written to disk; nothing is returned.

    Raises
    ------
    RuntimeError
        If ``mode`` is not one of the three supported splits.
    """
    if mode == 'train':
        csv_name = ibm_train
    elif mode == 'valid':
        csv_name = ibm_valid
    elif mode == 'test':
        csv_name = ibm_test
    else:
        # f-string so the offending mode actually shows up in the message.
        raise RuntimeError(f'No {mode} mode.')

    df = pd.read_csv(os.path.join(ibm_path, csv_name))

    z = []
    pos = []

    for mol_id in df['mol_id']:
        # Drop the first four characters of mol_id (presumably a 'gdb_'
        # prefix); the remaining number is 1-based, ori_file is 0-based.
        file_position = int(mol_id[4:]) - 1
        with open(os.path.join(ori_path, ori_file[file_position]), 'r') as xyz_file:
            lines = xyz_file.readlines()
        z.append(get_z(lines))
        pos.append(get_pos(lines))

    save_dict = {'z': z, 'pos': pos}
    with open(os.path.join(save_path, f'qm9_{mode}.pkl'), 'wb') as pkl_file:
        pickle.dump(save_dict, pkl_file)

In [10]:
# Always start from an empty output directory, then rebuild every split.
previous_output_present = os.path.exists(save_path)
if previous_output_present:
    print('Remove the original save file and regenerate.')
    shutil.rmtree(save_path)

os.mkdir(save_path)

for split_name in ('train', 'valid', 'test'):
    process_and_save(split_name)

Remove the original save file and regenerate.


In [13]:
# inspect pkl
# Reads back the file this notebook wrote itself; pickle.load is only safe on
# trusted files like this one.
with open(os.path.join(save_path, 'qm9_train.pkl'), 'rb') as f:
    data = pickle.load(f)
    # Show atomic numbers and coordinates of the first, second and last molecule.
    for molecule_index in (0, 1, -1):
        print(data['z'][molecule_index])
        print(data['pos'][molecule_index])

[6, 8, 6, 6, 6, 8, 6, 6, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[[0.0963776586, 1.2907754001, 0.3875361606], [-0.0134906138, -0.0040809446, -0.1752918024], [-1.3039041799, -0.5175242395, -0.142579665], [-1.8554497207, -1.0766970244, 1.155527066], [-1.5133211443, -1.9900183774, -0.0054290001], [-2.5815626996, -2.390722721, -0.8174030835], [-2.7793779018, -1.387405183, -1.8331084743], [-4.2243363016, -1.4381600614, -2.2933760041], [-2.3216309068, -0.0561030174, -1.1923438233], [1.1515541068, 1.5694953214, 0.3413062617], [-0.490824695, 2.0322110605, -0.1732546198], [-0.2357921833, 1.3055387235, 1.435057744], [-1.1758583873, -1.0972006743, 2.0016471808], [-2.9036424929, -0.9302694912, 1.3944811698], [-0.6875244151, -2.6899992002, 0.0729807649], [-2.1127372941, -1.6177672702, -2.6794364372], [-4.8991139651, -1.20951188, -1.4624383976], [-4.3961202888, -0.7111189737, -3.0933997854], [-4.4724655925, -2.432956593, -2.6737181517], [-3.1687904259, 0.4807952524, -0.7509536733], [-1.8425652369, 0.