In [1]:
import os
import numpy as np
import pandas as pd
import pickle
import lmdb
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from torch_geometric.data import Data

In [2]:
# Read the dataset from JSON. The preview below shows one structure per row with the
# columns atomic_number, ID, pos, force, natoms and E.
df = pd.read_json('zinc_ev.json')
df.head(n=5)

Unnamed: 0,atomic_number,ID,pos,force,natoms,E
0,"[30, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, 1, 6, 1,...",5338_con_7,"[[-2.0293085785, -0.06465452890000001, 0.26844...","[[-0.0068917709, 0.000859557, 0.0125995095], [...",45,-286.736237
1,"[30, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, 1, 6, 1,...",5338_con_3,"[[-2.0085821602, 0.0827816171, 0.2863010338], ...","[[-0.0061126062000000005, -0.0038732978, 0.012...",45,-286.76944
2,"[30, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, 1, 6, 1,...",5338_con_2,"[[-1.9969567119, -0.29290715030000003, 0.22799...","[[-0.005940026, -0.0030673873, 0.0121709923], ...",45,-286.76944
3,"[30, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, 1, 6, 1,...",5338_con_18,"[[-1.9803014145, -0.24409732120000002, 0.29334...","[[-0.0137301872, -0.0010087369, -0.0020015222]...",45,-286.716309
4,"[30, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, 1, 6, 1,...",5338_con_10,"[[-2.0240986138, 0.0766847985, 0.2796353446], ...","[[-0.0064623808000000005, -0.0038127322, 0.011...",45,-286.742889


In [3]:
# Shuffle every row once with a fixed seed so the split is reproducible across runs.
df_shuffled = df.sample(frac=1.0, random_state=42)

# Split layout over the shuffled rows: the first n_train rows are the training split,
# the following n_val rows are the validation split, and whatever remains is the test split.
n_train, n_val = 31_611, 4_000

In [4]:
# Mean of the target E over the training rows only (explicit positional slice).
df_shuffled.iloc[:n_train]['E'].mean()

-301.2988981640536

In [5]:
# Standard deviation of the target E over the training rows only
# (pandas default, i.e. the sample standard deviation).
df_shuffled.iloc[:n_train]['E'].std()

53.383123813548536

In [6]:
# Open the LMDB environment that the next cell fills with the training split.
db = lmdb.open(
    'train.lmdb',
    map_size=2 * 1024 ** 4,  # same integer as 1099511627776 * 2: upper bound on the DB size
    subdir=False,            # 'train.lmdb' is a single file rather than a directory
    meminit=False,
    map_async=True,
)

In [7]:
# Serialize the training split into the already-open `db` environment: one pickled
# torch_geometric `Data` object per row, keyed by the row's position within the split
# ("0", "1", ... encoded as ASCII bytes).
train_rows = df_shuffled[:n_train]

try:
    # A single write transaction for the whole split: it commits when the `with` block
    # exits normally and aborts if any row fails, so this run never leaves a half-written
    # split behind, and we avoid paying for a commit + sync on every record.
    with db.begin(write=True) as txn:
        for i, (_, entry) in enumerate(tqdm(train_rows.iterrows(), total=len(train_rows))):
            natoms = int(entry.natoms)
            data = Data(
                pos=torch.tensor(entry.pos, dtype=torch.float32),
                atomic_numbers=torch.tensor(entry.atomic_number),
                natoms=torch.tensor(natoms),
                sid=entry.ID,
                # One flag per atom, all zero.
                fixed=torch.zeros(natoms, dtype=torch.float32),
                y=torch.tensor(entry.E),
            )
            txn.put(f"{i}".encode("ascii"), pickle.dumps(data, protocol=-1))

    # Flush once after the commit instead of once per record.
    db.sync()
finally:
    # Always release the environment, even when a row could not be converted.
    db.close()

100%|██████████| 31611/31611 [01:20<00:00, 392.71it/s]


In [8]:
# Build val.lmdb: one pickled torch_geometric `Data` object per validation row, keyed by
# the row's position within the split ("0", "1", ... encoded as ASCII bytes).
db = lmdb.open(
    'val.lmdb',
    map_size=1099511627776 * 2,
    subdir=False,
    meminit=False,
    map_async=True,
)

# Validation rows are the n_val shuffled rows that follow the training rows.
val_rows = df_shuffled[n_train:n_train + n_val]

try:
    # A single write transaction for the whole split: it commits when the `with` block
    # exits normally and aborts if any row fails, so this run never leaves a half-written
    # split behind, and we avoid paying for a commit + sync on every record.
    with db.begin(write=True) as txn:
        for i, (_, entry) in enumerate(tqdm(val_rows.iterrows(), total=len(val_rows))):
            natoms = int(entry.natoms)
            data = Data(
                pos=torch.tensor(entry.pos, dtype=torch.float32),
                atomic_numbers=torch.tensor(entry.atomic_number),
                natoms=torch.tensor(natoms),
                sid=entry.ID,
                # One flag per atom, all zero.
                fixed=torch.zeros(natoms, dtype=torch.float32),
                y=torch.tensor(entry.E),
            )
            txn.put(f"{i}".encode("ascii"), pickle.dumps(data, protocol=-1))

    # Flush once after the commit instead of once per record.
    db.sync()
finally:
    # Always release the environment, even when a row could not be converted.
    db.close()

100%|██████████| 4000/4000 [00:09<00:00, 415.31it/s]


In [9]:
# Build test.lmdb: one pickled torch_geometric `Data` object per test row, keyed by the
# row's position within the split ("0", "1", ... encoded as ASCII bytes).
db = lmdb.open(
    'test.lmdb',
    map_size=1099511627776 * 2,
    subdir=False,
    meminit=False,
    map_async=True,
)

# Test rows are every shuffled row left after the training and validation rows.
test_rows = df_shuffled[n_train + n_val:]

try:
    # A single write transaction for the whole split: it commits when the `with` block
    # exits normally and aborts if any row fails, so this run never leaves a half-written
    # split behind, and we avoid paying for a commit + sync on every record.
    with db.begin(write=True) as txn:
        for i, (_, entry) in enumerate(tqdm(test_rows.iterrows(), total=len(test_rows))):
            natoms = int(entry.natoms)
            data = Data(
                pos=torch.tensor(entry.pos, dtype=torch.float32),
                atomic_numbers=torch.tensor(entry.atomic_number),
                natoms=torch.tensor(natoms),
                sid=entry.ID,
                # One flag per atom, all zero.
                fixed=torch.zeros(natoms, dtype=torch.float32),
                y=torch.tensor(entry.E),
            )
            txn.put(f"{i}".encode("ascii"), pickle.dumps(data, protocol=-1))

    # Flush once after the commit instead of once per record.
    db.sync()
finally:
    # Always release the environment, even when a row could not be converted.
    db.close()

100%|██████████| 4000/4000 [00:09<00:00, 405.19it/s]
