In [None]:
import gc
import pickle
import zlib
import bz2
# import lz4
import tqdm

import lmdb
import torch
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import pandas as pd

from DataClasses import lmdb_dataset
from joblib import Parallel, delayed

np.set_printoptions(linewidth=100, precision=4, suppress=True)

from ModelFunctions import to_bins_torch, convert_angles, restore_edge_angles, preprocessing, my_reshape
from torch_geometric.data import Data
from functools import partial

#feather

#import pyarrow.feather as feather

### Datasets

In [None]:
# Lookup tables for the selection made in the settings cell; the chosen values
# end up as path components in path_build().
# Dataset size variants (original note: "for train sample").
dataset_size_list = {
    0: "10k",
    1: "100k",
    2: "all"
}
# Data split.
mode_list = {
    0: "train",
    1: "val",
    2: "test"
}
# Evaluation subset; path_build() appends it only when the mode is not mode_list[0].
# Note that the keys of this table start at 1, not 0.
dataset_list = {
    1: "id",
    2: "ood_ads",
    3: "ood_cat",
    4: "ood_both"
}

In [None]:
# setting section
# Pick one entry from each lookup table; the file paths are derived from these values.
root = "../../ocp_datasets/data/is2re"  # relative to the notebook's working directory
dataset_size = dataset_size_list[0]  # "10k"
mode = mode_list[0]  # "train"
dataset = dataset_list[4]  # "ood_both"; not used in the path while mode is "train"
#

In [None]:
def path_build(root, dataset_size, mode, dataset):
    """Build the directory path of one dataset split.

    Returns ``'<root>/<dataset_size>/<mode>'``. For every mode other than
    ``mode_list[0]`` ("train") the suffix ``'_<dataset>'`` is appended, e.g.
    ``'.../all/val_ood_both'``.
    """
    subset_suffix = '' if mode == mode_list[0] else f'_{dataset}'
    return f'{root}/{dataset_size}/{mode}{subset_suffix}'
    
# Derive the three file locations from the directory of the selected split.
dataset_origin_path_pkl, dataset_origin_path, dataset_target_path = (
    f'{path_build(root, dataset_size, mode, dataset)}/{file_name}'
    for file_name in ('structures.pkl', 'data_mod.lmdb', 'data_mod_conv.lmdb')
)

# Echo the resolved paths, one per line, so a wrong selection is noticed early.
print('\n'.join(
    f'{label}: {location}'
    for label, location in (
        ('dataset_origin_path_pkl', dataset_origin_path_pkl),
        ('dataset_origin_path', dataset_origin_path),
        ('dataset_target_path', dataset_target_path),
    )
))
#/home/alex/Documents/ocp_datasets/data/is2re/all/val_ood_both
#/home/alex/Documents/ocp_datasets/data/is2re/all/test_ood_both

In [None]:
# dataset_origin_old = SinglePointLmdbDataset({"src": dataset_origin_old_path})

In [None]:
# dataset_origin_old[0]['distances']

In [None]:
# dataset_origin = pd.read_pickle(dataset_origin_path)

In [None]:
# Open the source LMDB chosen in the settings cell and display its first record.
# NOTE(review): the variable is called data_10k, but the file actually follows
# `dataset_size` — confirm before reusing this cell with another size.
data_10k = lmdb_dataset(dataset_origin_path, compressed=False)
data_10k[0]

### update dataset

In [None]:
def update_dataset(dataset_target_path, dataset_origin_old_path, dataset_origin_path, features_names=None):
    """Copy an LMDB dataset, replacing selected features with their ``*_new`` values.

    Every record of the LMDB at ``dataset_origin_old_path`` is read, each feature
    named in ``features_names`` is overwritten with
    ``torch.from_numpy(dataset_origin[i][feature_name + '_new'])`` (taken from the
    pickle at ``dataset_origin_path``), and the record is pickled into a new LMDB
    under the key ``str(i)``.

    Args:
        dataset_target_path: File path of the LMDB to write (opened with ``subdir=False``).
        dataset_origin_old_path: Path of the existing LMDB whose records are copied.
        dataset_origin_path: Path of the pickle holding the replacement features,
            indexed by record position and then by ``'<feature>_new'``.
        features_names: Iterable of feature names to substitute. ``None`` (the
            default) substitutes nothing, so records are copied unchanged.
    """
    # Iterating over the default of None used to raise TypeError.
    if features_names is None:
        features_names = ()

    dataset_origin = pd.read_pickle(dataset_origin_path)

    # Bug fix: the records to copy live in the *old* LMDB. This used to open
    # `dataset_origin_path` (the pickle) and left `dataset_origin_old_path` unused.
    dataset_origin_old = lmdb_dataset(dataset_origin_old_path)

    dataset_target = lmdb.open(
        dataset_target_path,
        map_size=int(1e9*5),  # ~ 5 Gbyte
        subdir=False,
        meminit=False,
        map_async=True,
    )

    try:
        total = len(dataset_origin_old)
        for idx, data_object in enumerate(dataset_origin_old):
            # Substitute: <feature> -> <feature>_new
            for feature_name in features_names:
                data_object[feature_name] = torch.from_numpy(dataset_origin[idx][feature_name + '_new'])

            # Write to LMDB; the transaction commits when the block exits normally
            # and is aborted if an exception escapes.
            with dataset_target.begin(write=True) as txn:
                txn.put(f"{idx}".encode("ascii"), pickle.dumps(data_object, protocol=-1))
            dataset_target.sync()

            if idx % 1000 == 0:
                print('{} of {} for file {}'.format(idx, total, dataset_target_path))
    finally:
        # Release the environment even when a record fails to convert.
        dataset_target.close()
    print("done")

### update_dataset_pyg2dict

In [None]:
def update_dataset_pyg2dict(dataset_target_path, dataset_origin_path, max_items=11):
    """Convert records of an LMDB dataset to plain dicts and store them zlib-compressed.

    For each source record the function
      1. builds a ``dict`` from the record's ``(key, value)`` pairs,
      2. keeps only every second entry of ``'edge_angles'`` (indices 0, 2, 4, ...),
      3. replaces each kept entry by ``entry.reset_index().values``
         (presumably pandas objects turned into NumPy arrays — confirm against the data),
      4. writes ``zlib.compress(pickle.dumps(record))`` under the key ``str(index)``.

    Args:
        dataset_target_path: File path of the LMDB to write (opened with ``subdir=False``).
        dataset_origin_path: Path of the source LMDB, read with ``compressed=False``.
        max_items: Maximum number of records to convert. The default of 11 reproduces
            the previously hard-coded ``if idx == 10: break`` limit; pass ``None`` to
            convert the whole dataset.
    """
    from itertools import islice  # local import keeps this cell self-contained

    dataset_origin = lmdb_dataset(dataset_origin_path, compressed=False)

    dataset_target = lmdb.open(
        dataset_target_path,
        map_size=int(1e12),  # upper bound of the memory map: ~1 TB (not 5 GB)
        subdir=False,
        meminit=False,
        map_async=False,
    )

    try:
        total = len(dataset_origin)
        for idx, element in enumerate(islice(dataset_origin, max_items)):
            element = dict(list(element))
            # Drop the odd-indexed entries and convert the remaining ones.
            element['edge_angles'] = [
                angles.reset_index().values for angles in element['edge_angles'][::2]
            ]

            # Write to LMDB; the transaction commits when the block exits normally.
            with dataset_target.begin(write=True) as txn:
                txn.put(f"{idx}".encode("ascii"), zlib.compress(pickle.dumps(element, protocol=-1), level=1))
            dataset_target.sync()

            if idx % 1000 == 0:
                print('{} of {} for file {}'.format(idx, total, dataset_target_path))

        print(dataset_target.info())
    finally:
        # Release the environment even when a record fails to convert.
        dataset_target.close()
    print("done")

In [None]:
def update_dataset_pyg2dict_1(dataset_target_path, dataset_origin_path):
    """Alias of ``update_dataset_pyg2dict``, kept so existing calls keep working.

    This function was a line-for-line copy of ``update_dataset_pyg2dict``. It now
    delegates to it, so the two cannot drift apart.

    Args:
        dataset_target_path: File path of the LMDB to write.
        dataset_origin_path: Path of the source LMDB.
    """
    return update_dataset_pyg2dict(dataset_target_path, dataset_origin_path)

In [None]:
# Convert the source LMDB selected in the settings cell into the dict-based target file.
update_dataset_pyg2dict(dataset_target_path, dataset_origin_path)

In [None]:
# Base path and extension from which the benchmark cells build their file names.
# NOTE(review): the base already ends in '.lmdb', so base + '_orig' + suffix gives
# '.../data_mod.lmdb_orig.lmdb' — confirm this matches the files on disk.
dataset_target_path_test ='../../ocp_datasets/data/is2re/10k/train/data_mod.lmdb'
suffix = '.lmdb'

### Benchmark of different options for .lmdb

In [None]:
# Module-level sink that the benchmark cells rebind.
a = list()


def bench_dataset(el):
    """No-op benchmark payload: accept one record and discard it.

    The record is bound to a local name only, so the module-level ``a`` is not
    modified, and the function returns ``None``.
    """
    local_sink = el
    return None

In [None]:
%%time
# Baseline: open the target LMDB with lmdb_dataset's default options and walk every record.
dataset_target = lmdb_dataset(dataset_target_path)
for el in enumerate(dataset_target):
    a = el  # el is an (index, record) tuple; rebinding `a` only consumes it

In [None]:
%%time
# Same file, but each record is read here (dataset_target[i]) and then handed to a
# joblib task running the no-op bench_dataset; `a` becomes the list of task results.
dataset_target = lmdb_dataset(dataset_target_path)
a = Parallel(n_jobs=-1)(delayed(bench_dataset)(dataset_target[i]) for i in range(len(dataset_target)))

In [None]:
%%time
# Variant '_orig', read with compressed=False.
dataset_target = lmdb_dataset(dataset_target_path_test+'_orig'+suffix, compressed=False)
for el in enumerate(dataset_target):
    a = el  # consume the (index, record) tuple

In [None]:
%%time
# Variant '_dict', read with compressed=False.
dataset_target = lmdb_dataset(dataset_target_path_test+'_dict'+suffix, compressed=False)
for el in enumerate(dataset_target):
    a = el  # consume the (index, record) tuple

In [None]:
%%time
# Variant '_dict_short', read with compressed=False.
dataset_target = lmdb_dataset(dataset_target_path_test+'_dict_short'+suffix, compressed=False)
for el in enumerate(dataset_target):
    a = el  # consume the (index, record) tuple

In [None]:
%%time
# Variant '_dict_short_numpy', read with compressed=False.
dataset_target = lmdb_dataset(dataset_target_path_test+'_dict_short_numpy'+suffix, compressed=False)
for el in enumerate(dataset_target):
    a = el  # consume the (index, record) tuple

In [None]:
%%time
# Variant '_dict_short_numpy_zip'; the only benchmarked variant read with compressed=True.
dataset_target = lmdb_dataset(dataset_target_path_test+'_dict_short_numpy_zip'+suffix, compressed=True)
for el in enumerate(dataset_target):
    a = el  # consume the (index, record) tuple

In [None]:
# Open the compressed variant and show the first 'edge_angles' entry of record 0.
# Fix: the '+' between the base path and the variant suffix was missing (SyntaxError).
dataset_target = lmdb_dataset(dataset_target_path_test + '_dict_short_numpy_zip' + suffix, compressed=True)
dataset_target[0]['edge_angles'][0]

#### Compressed pickle

In [None]:
%%time
# Write record 0 of the currently opened dataset as a plain (uncompressed) pickle.
with open('data_10k.pkl', 'wb') as f:
    f.write(pickle.dumps(dataset_target[0]))

In [None]:
%%time
# Read the plain pickle back; `data` holds the raw bytes first, then the decoded object.
with open('data_10k.pkl', 'rb') as f:
    data = f.read()
    data = pickle.loads(data)

In [None]:
%%time
# Write the same record as a zlib-compressed pickle (compression level 1).
# NOTE(review): the '.pbz2' extension suggests bz2, but the payload is zlib.
with open('data_10k.pbz2', 'wb') as f:
    f.write(zlib.compress(pickle.dumps(dataset_target[0]), level = 1))

In [None]:
%%time
# Read the zlib-compressed pickle back: raw bytes, then decompress and unpickle.
with open('data_10k.pbz2', 'rb') as f:
    data = f.read()
    data = pickle.loads(zlib.decompress(data))

#### Feather file-format

#### restore_angles

In [None]:
# NOTE: this cell rebinds convert_angles and restore_edge_angles, which the first
# cell also imports from ModelFunctions.
def convert_angles(array):
    """Transform columns 1 and 3 of a 2-D array in place and return the same array.

    Column 1 becomes ``pi - column 1`` and column 3 is negated; all other columns
    are left untouched.
    """
    mirrored_col1 = np.pi - array[:, 1]
    negated_col3 = -array[:, 3]
    array[:, 1] = mirrored_col1
    array[:, 3] = negated_col3
    return array


def restore_edge_angles(list_of_arrays):
    """Return a list in which every input array is followed by its converted copy.

    The inputs are not modified: ``convert_angles`` is applied to a copy of each
    array, so the result is twice as long as the input.
    """
    restored = []
    for source_array in list_of_arrays:
        restored.extend((source_array, convert_angles(source_array.copy())))
    return restored

### Benchmark preprocessing

#### mod2

In [None]:
# Select the mod2 file and open it with lmdb_dataset's default options.
dataset_target_path = '../../ocp_datasets/data/is2re/10k/train/data_mod2.lmdbz'
dataset_target = lmdb_dataset(dataset_target_path)

# The text after the last dot of the path is the file extension.
suffix = dataset_target_path.rsplit('.', 1)[-1]
print(suffix)

# The '.lmdbz' extension marks the compressed variant.
compressed = suffix == 'lmdbz'
print(compressed)

dataset_target.stat()

In [None]:
%%time
print(dataset_target_path)
dataset_target = lmdb_dataset(dataset_target_path, compressed=True)
# Serial reference run: preprocess every record in this process; only the last result is kept.
for el in dataset_target:
    a = preprocessing(el)

**multiprocessing (does not work with n_jobs > 1)**

In [None]:
print(dataset_target_path)
dataset_target = lmdb_dataset(dataset_target_path, compressed=True)
# Earlier attempt, kept for reference: iterate the dataset object directly.
#a = Parallel(n_jobs=-1)(delayed(preprocessing)(el) for el in dataset_target)
# Current attempt: read each record by index here and pass the record to the joblib task.
a = Parallel(n_jobs=-1)(delayed(preprocessing)(dataset_target[i]) for i in range(len(dataset_target)))

#### mod1

In [None]:
# Switch the benchmark to the mod1 file; this rebinds the notebook-level dataset_target_path.
dataset_target_path ='../../ocp_datasets/data/is2re/10k/train/data_mod.lmdb'
suffix = '.lmdb'

In [None]:
# Open the mod1 LMDB with lmdb_dataset's default options.
dataset_target = lmdb_dataset(dataset_target_path)

In [None]:
# Display whatever the dataset's stat() method reports for the opened file.
dataset_target.stat()

In [None]:
%%time
print(dataset_target_path)
dataset_target = lmdb_dataset(dataset_target_path)
# Serial reference run with opt='edges_only'; only the last result is kept in `a`.
for el in dataset_target:
    a = preprocessing(el, opt='edges_only')

In [None]:
%%time
print(dataset_target_path)
dataset_target = lmdb_dataset(dataset_target_path)
# joblib run of the same work: records are read by index here and passed to the tasks.
a = Parallel(n_jobs=-1)(delayed(preprocessing)(dataset_target[i], opt='edges_only') for i in range(len(dataset_target)))

In [None]:
%%time
# Imported here because the notebook only imports Pool in a later cell.
from multiprocessing import Pool

print(dataset_target_path)
dataset_target = lmdb_dataset(dataset_target_path)
# Fixes: `With` -> `with`; the keyword argument is bound with functools.partial
# because Pool.map passes exactly one positional argument per item; the records
# are materialised by index, as in the joblib cell.
with Pool(12) as p:
    a = p.map(
        partial(preprocessing, opt='edges_only'),
        [dataset_target[i] for i in range(len(dataset_target))],
    )

### Parallel test

In [None]:
# Draft of the workload, kept for reference; the live version is imported below.
# def func(x):
#     #return sqrt(x)
#     return np.sin(x)/np.cos(x)
# NOTE(review): `test` is presumably a local test.py next to the notebook; the
# standard library also ships a package named `test`, so confirm which one resolves.
from test import func

n = 1000000  # number of inputs each benchmark evaluates (range(n))
c = 8  # worker count passed to the parallel benchmarks

In [None]:
import time
from math import sqrt
from joblib import Parallel, delayed
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool

from concurrent.futures import ThreadPoolExecutor
import numpy as np

In [None]:
# Serial baseline: evaluate func for every input in the current process.
# Fix: the timing was printed as 'Parallel', which mislabelled the baseline.
start_t = time.time()
list_comprehension = [func(i) for i in range(n)]
print('Serial: {} s'.format(time.time() - start_t))

In [None]:
# joblib run with c workers. The clock starts before the guard, so the reported
# time covers everything inside it.
start_t = time.time()
if __name__ == '__main__':
    list_from_parallel = Parallel(n_jobs=c)(delayed(func)(i) for i in range(n))
    print('Parallel: {} s'.format(time.time() - start_t))

In [None]:
%%time
from joblib import Parallel, delayed, parallel_backend

# joblib run with the "loky" backend selected explicitly and inner_max_num_threads=1.
with parallel_backend("loky", inner_max_num_threads=1):
    results = Parallel(n_jobs=c)(delayed(func)(x) for x in range(n))    

In [None]:
%%time
# Thread-based run. Fix: the executor was created but never used (the body ran
# joblib's Parallel again), so this cell did not measure threads at all.
with ThreadPoolExecutor(c) as executor:
    results = list(executor.map(func, range(n)))

In [None]:
# Serial baseline for the scaling measurement below.
start_t = time.time()
list_comprehension = [func(i) for i in range(n)]
print('Parallel: 0 Pool, {} s'.format(time.time() - start_t))

if __name__ == '__main__':

    # Scaling run over 1..12 worker processes. A dedicated loop variable is used:
    # the loop previously reused `c`, which left the notebook-level worker count
    # at 12 for any cell run afterwards.
    for n_workers in range(1, 13):
        start_t = time.time()
        with Pool(n_workers) as p:
            result = p.map(func, list(range(n)))
            print('Parallel: {} core, {} s'.format(n_workers, time.time() - start_t))

In [None]:
import numpy as np
from matplotlib.path import Path
from joblib import Parallel, delayed
import time
import sys

## Check if one line segment contains another. 

def check_paths(path):
    """Report whether any path in the module-level collection ``a`` contains ``path``.

    Args:
        path: The ``matplotlib.path.Path`` to test.

    Returns:
        ``'cross'`` as soon as one element of ``a`` reports ``contains_path(path)``,
        otherwise ``'no cross'``. An empty ``a`` also gives ``'no cross'``; it
        previously raised ``UnboundLocalError`` because the result was only
        assigned inside the loop.
    """
    for other_path in a:
        if Path(other_path).contains_path(path) == 1:
            return 'cross'
    return 'no cross'

if __name__ == '__main__':
    ## Create pairs of points for line segments
    # `a` (5000 reference segments) is read as a global by check_paths().
    # Note: this rebinds the notebook-level names `a` and `c`, and the random
    # input is not seeded.
    a = zip(np.random.rand(5000,2),np.random.rand(5000,2))
    a = [Path(x) for x in a]
    # `b` (300 query segments) is a one-shot zip iterator; it can be consumed only once.
    b = zip(np.random.rand(300,2),np.random.rand(300,2))
    c = 2
    now = time.time()

    # With two or more workers the queries are distributed with joblib; otherwise they run serially.
    if c >= 2:
        res = Parallel(n_jobs=c) (delayed(check_paths) (Path(points)) for points in b)
    else:
        res = [check_paths(Path(points)) for points in b]
    print("Finished in", time.time()-now , "sec")

### merge datasets

In [None]:
def merge_datasets(dataset_origin_path_list, dataset_target_path, root_dir="",
                   raw_copy_files=("batch10_data_mod1.lmdbz",)):
    """Concatenate several LMDB datasets into one LMDB with consecutive integer keys.

    The sources are processed in list order and their records are written under
    the keys ``"0"``, ``"1"``, ... across all files, inside a single write
    transaction.

    Args:
        dataset_origin_path_list: Source file names, each resolved as
            ``root_dir + name``.
        dataset_target_path: Target file name, resolved as ``root_dir + name``.
        root_dir: Prefix prepended verbatim to every path (no separator is added).
        raw_copy_files: Source file names whose records are opened with
            ``byte=True`` and stored exactly as read (presumably already
            compressed pickles — confirm in DataClasses). Records of all other
            sources are converted to a ``dict``, pickled and zlib-compressed.
            The default reproduces the previously hard-coded special case.
    """
    dataset_target = lmdb.open(
        f'{root_dir}{dataset_target_path}',
        map_size=int(1e12),  # upper bound of the memory map: ~1 TB (not 5 GB)
        subdir=False,
        meminit=False,
        map_async=True,
    )

    idx = 0

    try:
        with dataset_target.begin(write=True) as txn:
            for dataset_origin_file in dataset_origin_path_list:
                raw_copy = dataset_origin_file in raw_copy_files
                dataset_origin = lmdb_dataset(f'{root_dir}{dataset_origin_file}', byte=raw_copy)

                for element in dataset_origin:
                    if raw_copy:
                        payload = element
                    else:
                        payload = zlib.compress(pickle.dumps(dict(list(element)), protocol=-1), level=1)

                    # Write to LMDB
                    txn.put(f"{idx}".encode("ascii"), payload)
                    idx += 1

                    if idx % 5000 == 0:
                        # idx counts records across all sources, so it is reported as a
                        # running total instead of being compared with the length of
                        # the current source only.
                        print('{} records written to {} (reading {})'.format(
                            idx, dataset_target_path, dataset_origin_file))

        print(dataset_target.info())
    finally:
        # Release the environment even when a source fails to load or convert.
        dataset_target.close()
    print("done")

In [None]:
# Source files to merge; their order fixes the key numbering in the merged file.
dataset_origin_path_list = ['data_mod1_0_50039.lmdbz',
'batch10_data_mod1.lmdbz',
'data_mod1_last.lmdbz']
# Rebinds the notebook-level dataset_target_path used by earlier cells.
dataset_target_path = 'merged/data_mod1.lmdbz'
# Absolute, machine-specific location; merge_datasets() prepends it verbatim, so
# the trailing slash is required.
root_dir = '/share/catalyst/ocp_datasets/is2re_test_challenge_2021/'

In [None]:
# Run the merge with the configuration from the previous cell.
merge_datasets(dataset_origin_path_list, dataset_target_path, root_dir)

In [None]:
%%time
# check lmdb file
# Walk every record of the merged file (opened with byte=True); `a[0]` ends up
# holding the last one.
a = [0]
data_10 = lmdb_dataset(root_dir+dataset_target_path, byte=True)
for element in data_10:
    a[0] = element

### get keys of lmdb file

In [None]:
# Inspect one source LMDB directly: list its keys and decode one record by hand.
data_10_path = root_dir + dataset_origin_path_list[1]
data_10_path_env = lmdb.open(
    data_10_path,
    subdir=False,
    readonly=True,
    lock=False,
    readahead=True,
    meminit=False,
    max_readers=1000,
)

# This was a bare expression in the middle of the cell, so its value was never shown.
print('entries:', data_10_path_env.stat()["entries"])

# One read transaction serves both the key scan and the record lookup and is
# released when the block exits. The lookup used to open a second transaction
# that was never closed.
with data_10_path_env.begin() as txn:
    with txn.cursor() as curs:
        keys = [key for key, value in curs]
        # print('key is:', curs.get('key'.encode('ascii')))

    print(keys[1])

    data_10_path_datapoint_pickled = txn.get(keys[1])

print(type(data_10_path_datapoint_pickled))

# The stored value is decoded as a zlib-compressed pickle.
data_10_obj = pickle.loads(zlib.decompress(data_10_path_datapoint_pickled))
print(data_10_obj['natoms'])
# The environment is left open on purpose so it can be queried further interactively.