In [1]:
import gc
import pickle
import zlib

import lmdb
import torch

import pandas as pd
from dataloader import lmdb_dataset

In [2]:
# Lookup tables for the IS2RE directory layout under `root`:
# dataset size folder, split mode, and the val/test subset suffix.
dataset_size_list = dict(enumerate(["10k", "100k", "all"]))
mode_list = dict(enumerate(["train", "val", "test"]))
# NOTE: subset keys start at 1, not 0.
dataset_list = dict(enumerate(["id", "ood_ads", "ood_cat", "ood_both"], start=1))

# --- settings: pick the split to convert ---
root = "../../ocp_datasets/data/is2re"
dataset_size = dataset_size_list[0]  # "10k"
mode = mode_list[0]                  # "train"
dataset = dataset_list[4]            # "ood_both" (not used in the path for train)
# ---

def path_build(root, dataset_size, mode, dataset):
    """Return the directory of one data split as ``{root}/{dataset_size}/{mode}``.

    For ``mode == "train"`` the ``dataset`` argument is ignored. For any other
    mode (``"val"`` / ``"test"``) the subset name is appended with an
    underscore, e.g. ``{root}/all/val_ood_both``.

    Args:
        root: Base data directory.
        dataset_size: Size subfolder, e.g. "10k", "100k" or "all".
        mode: Split name: "train", "val" or "test".
        dataset: Subset suffix used for val/test, e.g. "id" or "ood_both".

    Returns:
        The directory path as a string (no trailing slash).
    """
    path = f'{root}/{dataset_size}/{mode}'
    # Compare with the literal rather than the global `mode_list[0]` so this
    # helper does not depend on notebook state defined in another cell.
    if mode != "train":
        path = f'{path}_{dataset}'
    return path
    
# Resolve the split directory once; every input/output file lives inside it.
split_dir = path_build(root, dataset_size, mode, dataset)
dataset_origin_path_pkl = f'{split_dir}/structures.pkl'
dataset_origin_path = f'{split_dir}/data_mod2.lmdb'
dataset_target_path = f'{split_dir}/data_mod2_torch.lmdb'

print(
    f'dataset_origin_path_pkl: {dataset_origin_path_pkl}',
    f'dataset_origin_path: {dataset_origin_path}',
    f'dataset_target_path: {dataset_target_path}',
    sep='\n',
)
# Other split directories look like:
#   <root>/all/val_ood_both
#   <root>/all/test_ood_both


dataset_origin_path_pkl: ../../ocp_datasets/data/is2re/10k/train/structures.pkl
dataset_origin_path: ../../ocp_datasets/data/is2re/10k/train/data_mod2.lmdb
dataset_target_path: ../../ocp_datasets/data/is2re/10k/train/data_mod2_torch.lmdb


In [3]:
# dataset_origin_old = SinglePointLmdbDataset({"src": dataset_origin_old_path})

In [4]:
# dataset_origin_old[0]['distances']

In [5]:
# dataset_origin = pd.read_pickle(dataset_origin_path)

In [5]:
# Open the source LMDB and preview one record to see which features it holds.
data_10k = lmdb_dataset(
    dataset_origin_path,
    compressed=False,
)
data_10k[0]

Data(atomic_numbers=[86], cell=[1, 3, 3], cell_offsets=[2964, 3], cell_offsets_new=[1214, 3], contact_solid_angles=[1214], direct_neighbor=[1214], distances=[2964], distances_new=[1214], edge_angles=[1214], edge_index=[2, 2964], edge_index_new=[2, 1214], fixed=[86], force=[86, 3], natoms=86, pos=[86, 3], pos_relaxed=[86, 3], sid=2472718, spherical_domain_radii=[86], tags=[86], voronoi_surface_areas=[86], voronoi_volumes=[86], y_init=6.282500615000004, y_relaxed=-0.025550085000020317)

### Update dataset: copy LMDB records, substituting `<feature>_new` values

In [7]:
def update_dataset(dataset_target_path, dataset_origin_old_path, dataset_origin_path, features_names=None):
    """Copy an LMDB into a new LMDB, overriding selected features per record.

    For every record ``idx`` of the LMDB at ``dataset_origin_old_path``, each
    name in ``features_names`` is replaced by
    ``torch.from_numpy(dataset_origin[idx][name + '_new'])``. Here
    ``dataset_origin`` is the object loaded with ``pd.read_pickle`` from
    ``dataset_origin_path``. The updated record is pickled (highest protocol)
    and stored under the ASCII key ``str(idx)`` in a new LMDB file at
    ``dataset_target_path``.

    Args:
        dataset_target_path: LMDB file to write (opened with ``subdir=False``).
        dataset_origin_old_path: Source LMDB whose records are copied.
        dataset_origin_path: Pickle providing the ``<feature>_new`` arrays,
            indexable by record position.
            NOTE(review): assumes its positions align 1:1 with the source
            LMDB records — confirm.
        features_names: Iterable of feature names to substitute. ``None``
            (the default) substitutes nothing, so records are copied unchanged.
    """
    if features_names is None:
        features_names = ()

    # Replacement feature values.
    dataset_origin = pd.read_pickle(dataset_origin_path)
    # Records to update: read from the *old* LMDB path, not the pickle path.
    dataset_origin_old = lmdb_dataset(dataset_origin_old_path)
    n_records = len(dataset_origin_old)

    dataset_target = lmdb.open(
        dataset_target_path,
        map_size=int(1e9 * 5),  # ~5 GB maximum database size
        subdir=False,
        meminit=False,
        map_async=True,
    )
    try:
        # Index-based access (not iteration) so each record is fetched once.
        for idx in range(n_records):
            data_object = dataset_origin_old[idx]

            # Substitute, e.g.: edge_index <- edge_index_new
            for feature_name in features_names:
                data_object[feature_name] = torch.from_numpy(
                    dataset_origin[idx][feature_name + '_new']
                )

            # One transaction per record: commits on success, aborts if an
            # exception escapes the block.
            with dataset_target.begin(write=True) as txn:
                txn.put(f"{idx}".encode("ascii"), pickle.dumps(data_object, protocol=-1))

            if idx % 1000 == 0:
                print('{} of {} for file {}'.format(idx, n_records, dataset_target_path))

        dataset_target.sync()
    finally:
        # Always release the environment, even if a record fails.
        dataset_target.close()
    print("done")

#### update_dataset_pyg2dict: re-serialize an LMDB record by record

In [8]:
def update_dataset_pyg2dict(dataset_target_path, dataset_origin_path):
    """Re-serialize every record of an LMDB into a new LMDB file.

    Records are read with ``lmdb_dataset(dataset_origin_path, compressed=False)``
    and written back as ``pickle.dumps(record, protocol=-1)`` under the ASCII
    key ``str(index)`` in a new LMDB file at ``dataset_target_path``. Values
    are stored as plain pickles (not zlib-compressed).

    NOTE(review): despite the name, records are currently copied unchanged.
    The PyG -> dict conversion (``dict(list(element))``) and the removal of
    odd ``edge_angles`` entries were disabled experiments and are not applied.

    Args:
        dataset_target_path: LMDB file to write (opened with ``subdir=False``).
        dataset_origin_path: Source LMDB to read.
    """
    dataset_origin = lmdb_dataset(dataset_origin_path, compressed=False)
    n_records = len(dataset_origin)

    dataset_target = lmdb.open(
        dataset_target_path,
        map_size=int(1e9 * 20),  # ~20 GB maximum database size
        subdir=False,
        meminit=False,
        map_async=True,
    )
    try:
        for idx, element in enumerate(dataset_origin):
            # One transaction per record: commits on success, aborts if an
            # exception escapes the block.
            with dataset_target.begin(write=True) as txn:
                txn.put(f"{idx}".encode("ascii"), pickle.dumps(element, protocol=-1))

            if idx % 1000 == 0:
                print('{} of {} for file {}'.format(idx, n_records, dataset_target_path))

        dataset_target.sync()
    finally:
        # Always release the environment, even if a record fails.
        dataset_target.close()
    print("done")

In [None]:
# Convert data_mod2.lmdb -> data_mod2_torch.lmdb for the split selected above.
update_dataset_pyg2dict(
    dataset_target_path=dataset_target_path,
    dataset_origin_path=dataset_origin_path,
)

0 of 10000 for file ../../ocp_datasets/data/is2re/10k/train/data_mod2_torch.lmdb


In [87]:
# update_dataset_pyg2dict writes plain (not zlib-compressed) pickles, so open
# the result with compressed=False, the same way the source LMDB is opened,
# instead of relying on lmdb_dataset's default.
dataset_target = lmdb_dataset(dataset_target_path, compressed=False)

In [None]:
# Preview the first record of the converted LMDB.
dataset_target[0]

#### Compressed vs. plain pickle: write/read timing for a single record

In [54]:
%%time
with open('data_10k.pkl', 'wb') as f:
    f.write(pickle.dumps(dataset_target[0]))

CPU times: user 217 ms, sys: 0 ns, total: 217 ms
Wall time: 215 ms


In [60]:
%%time
with open('data_10k.pkl', 'rb') as f:
    data = f.read()
    data = pickle.loads(data)

<class 'bytes'>
CPU times: user 106 ms, sys: 14 µs, total: 106 ms
Wall time: 103 ms


In [56]:
%%time
with open('data_10k.pbz2', 'wb') as f:
    f.write(zlib.compress(pickle.dumps(dataset_target[0]), level = 1))

CPU times: user 218 ms, sys: 51 µs, total: 218 ms
Wall time: 216 ms


In [59]:
%%time
with open('data_10k.pbz2', 'rb') as f:
    data = f.read()
    data = pickle.loads(zlib.decompress(data))

<class 'bytes'>
CPU times: user 187 ms, sys: 3.95 ms, total: 191 ms
Wall time: 189 ms


#### remove_odd_edge_angles: drop every odd-indexed edge angle (in place)

In [83]:
data = dataset_target[0]
# `edge` is a reference to the same list, not a copy.
edge = data['edge_angles']

# Delete odd-indexed entries in place; the change is visible through `edge`
# too, so both lengths printed below are equal.
del data['edge_angles'][1::2]

print(len(data['edge_angles']), len(edge), sep='\n')

607
607
