# FG_Dataset lmdb creation

## Variables

In [78]:
from pathlib import Path

ROOT_DIR = Path("./datasets/")              # Working directory

TARBALL = ROOT_DIR/"FG_dataset_lite.tar.xz" # Location of the dataset tarball
# TARBALL = None                            # Set to False or None to avoid extraction.
DS_DIR = ROOT_DIR/"FG_dataset_lite"         # Dir of the initial Dataset
DS_DIR_OUT = Path("./lmdb/lmdb_crossval")   # Dir of the output dataset
INITIAL_GEOMETRY = "contcar"                # Either look for contcar or poscar files
SPLIT_CV = { "seed": 42                     # Seed that will be used during the random splitting
            ,"n_splits": 5                  # Number of splits
            ,"val_size": 1                  # Number of splits in the validation set
            ,"test_size": 1 }               # Number of splits in the test set
#SPLIT_CV = False                           # Set it to False or None to avoid splitting

## Extract Tarball

In [79]:
# Extract tarball to DS_DIR location
if TARBALL:
    import tarfile
    tar_ds = tarfile.open(TARBALL, mode="r:xz")
    tar_ds.extractall(DS_DIR)
    tar_ds.close()

## Auxiliary Functions

In [80]:
# Read a file with two columns and transform it to a dictionary
def read_two_columns(filename):
    with open(filename, "r", encoding="utf-8") as f:
        return map(
            lambda l: l.split()
            , f.readlines())

## Read structures and Energies

Read structures, inital and final energies.

In [81]:
from pathlib import Path
from ase.io.vasp import read_vasp, read_vasp_out
from ase.calculators.singlepoint import SinglePointCalculator

# Get energies in files
iener_dict = dict(read_two_columns(DS_DIR/"energies_i.dat"))
fener_dict = dict(read_two_columns(DS_DIR/"energies.dat"))

def get_struct(fname):
    final = read_vasp(fname)
    final._calc = SinglePointCalculator(final, energy=float(fener_dict[fname.stem]))
    return final                                
                                                
strct_dict = dict(map(
    lambda d: (d.stem, get_struct(d))
    , DS_DIR.glob(f"./*/*/*.{INITIAL_GEOMETRY}")))

## Get Groups

In [82]:
from functools import reduce
from itertools import chain

def reduce_grp(d, i):
    match i:
        case (k, v) if k in d: d[k].append(v)
        case (k, v): d[k] = [v]
    return d

groups_direct = dict(read_two_columns(DS_DIR/"groups.dat"))
groups_invert = map(
    lambda xs: xs[::-1]
    , groups_direct.items())

groups_dict = reduce(
    reduce_grp
    , groups_invert
    , {})

## Samples Dictionary

In [87]:
# Apply a filter to avoid collecting the metals
filter_fn = lambda x: "0000" not in x
# Use only final energies if contcar is selected
if INITIAL_GEOMETRY == "contcar": 
    ener_pvt_dict = fener_dict
else:
    ener_pvt_dict = iener_dict

ener_strct_dict = dict(map(
    lambda x: (x, dict(name=x
                       , fener=float(fener_dict[x])
                       , iener=float(ener_pvt_dict[x])
                       , image=strct_dict[x]
                       , group=groups_direct[x]))
    , filter(filter_fn, fener_dict.keys())))

## Extract Structures

In [90]:
from ocpmodels.preprocessing import AtomsToGraphs
import torch

a2g = AtomsToGraphs(
    max_neigh=50,
    radius=6,
    r_energy=True,
    r_forces=False,
    r_distances=False,
    r_fixed=True,
)

def read_entry_extract_features(a2g, strc):
    tags = strc.get_tags()
    data_objects = a2g.convert_all([strc], disable_tqdm=True)
    data_objects[0].tags = torch.LongTensor(tags)
    return data_objects

def model_dict(xs):
    idx = 0
    out_dict = {}
    for key, value in xs.items():
        data_objects = read_entry_extract_features(a2g, value['image'])
        init = data_objects[0]
    
        init.y_init = value["iener"]
        init.y_relaxed = init.y
        del init.y
        # As we are performing a IS2RE the final structure is not needed.
        init.pos_relaxed = init.pos 
    
        init.sid = idx
        # Saving name and group for later identification.
        init.name = value["name"]
        init.group = value["group"]
        
        if init.edge_index.shape[1] == 0:
            print("no neighbors", idx)
            continue
        idx += 1
        out_dict[key] = init
    return out_dict

In [91]:
ase_dict = model_dict(ener_strct_dict)

In [7]:
print(f"Samples in the dataset: {len(ase_dict)}")

Samples in the dataset: 3255


## Process Structures

## Split Sets

split_ds function splits the dataset into n_split 

In [8]:
def split_ds(ase_dict, groups_dict, seed=42, n_splits=5, val_size=1, test_size=1):
    from random import seed, shuffle
    from collections import deque
    from numpy import array_split
    from itertools import chain, combinations, product
    
    seed(seed)

    # Randomly shuffle the values stored in groups dict
    deque(map(
        shuffle
        , groups_dict.values())
        , maxlen=0)

    # Filter structures that are in groups_dict but are not present
    # in ase_dict
    filtered_groups_dict = dict(map(
        lambda xs: (xs[0]
                    , tuple(filter(lambda x: x in ase_dict.keys()
                             , xs[1])))
        , groups_dict.items()))
    
    # Split the groups entries equally into n_splits slices
    slices = reduce(
        lambda l, t: map(lambda x: tuple(chain.from_iterable(x))
                         , zip(l, t))
        , map(lambda x: array_split(x, n_splits)
            , filtered_groups_dict.values())
        , [[]]*n_splits)

    k_sets = set(map(
        lambda x: tuple(map(ase_dict.get, x))
        , slices))
    
    val_set = combinations(k_sets, val_size)
    test_set = combinations(k_sets, test_size)
    # Quick filter to discard combinations that lead to intersections between
    # validation and test datasets.
    val_test_comb = filter(
        lambda xs: not set(xs[0]).intersection(set(xs[1]))
        , product(val_set, test_set))
    
    # Chain the slices into training test and val
    chain_n_tuple = lambda xs: tuple(chain.from_iterable(xs))
    return map(
        lambda xs: (chain_n_tuple(k_sets.difference(set(set(chain.from_iterable(xs)))))
                    , chain_n_tuple(xs[0])
                    , chain_n_tuple(xs[1]))
        , val_test_comb)

## Write data to LMDB

Write the three datasets into the lmdb format

In [9]:
import lmdb
import pickle
from pathlib import Path

def dump_db(xs, db_path):
    db = lmdb.open(
        str(db_path),
        map_size=1099511627776 * 2,
        subdir=False,
        meminit=False,
        map_async=True,
    )
    idx = 0
    for value in xs:
        txn = db.begin(write=True)
        txn.put(f"{idx}".encode("ascii"), pickle.dumps(value, protocol=-1))
        txn.commit()
        db.sync()
        idx += 1
    db.close()

In [10]:
from pathlib import Path
from os import makedirs

mkdir_p = lambda p: p.is_dir() or makedirs(p)

# Write three different lmdb for each of the splittings.
if SPLIT_CV:
    splitted_sets = split_ds(ase_dict, groups_dict, **SPLIT_CV)
    for idx, n_set in enumerate(splitted_sets):
        train, test, val = n_set
        dpath = DS_DIR_OUT/str(idx)
        mkdir_p(dpath)
        dump_db(train, dpath/"train.lmdb")
        dump_db(test, dpath/"test.lmdb")
        dump_db(val, dpath/"val.lmdb")
else:
    dpath = DS_DIR_OUT
    mkdir_p(dpath)
    dump_db(ase_dict, dpath/"test.lmdb")

## Metrics

### Compute Metrics

In [62]:
import numpy as np
import pandas as pd
from ocpmodels.datasets import SinglePointLmdbDataset
from pathlib import Path

if SPLIT_CV:
    target_glob = "./*/*.lmdb"
else:
    target_glob = "./*.lmdb"

# Extract some useful metrics from the datasets
# This step is not needed and can be done before without
# needing to read the datasets again, but I used it as a
# sanity check, as it is easy to detect errors.
def get_metrics(lmdb_ds_path):
    ds_arr = np.asarray(tuple(
        map(
            lambda x: x.y_relaxed
            , SinglePointLmdbDataset({"src": str(lmdb_ds_path)})))
        , dtype=float)
    return {
        "mean": np.mean(ds_arr)
        , "std": np.std(ds_arr)
        , "idx": lmdb_ds_path.parent.name
        , "split": lmdb_ds_path.stem
        , "path": lmdb_ds_path
        , "samples": ds_arr.shape[0]
    }

metrics_df = pd.DataFrame(map(
    get_metrics
    , DS_DIR_OUT.glob(("./*.lmdb", "./*/*.lmdb")[bool(SPLIT_CV)])))

  data = list(data)


In [125]:
from collections import deque

deque(map(lambda xs: xs[1].to_csv( Path(xs[1]["path"].iloc[0]).parent/"metrics.csv"
                            , header=False
                            , index=False)
    , metrics_df.groupby("idx")[["split", "path", "mean", "std"]]))

print(f"Written metrics for the ds in {DS_DIR_OUT}")

Written metrics for the ds in data/toy_ds


### Show Metrics

In [83]:
metrics_df

Unnamed: 0,mean,std,idx,split,path,samples
0,-321.685422,150.912688,4,val,data/toy_ds/4/val.lmdb,650
1,-332.431521,146.242296,4,test,data/toy_ds/4/test.lmdb,650
2,-318.989351,152.354499,4,train,data/toy_ds/4/train.lmdb,1955
3,-321.417012,153.527547,14,test,data/toy_ds/14/test.lmdb,651
4,-321.964564,149.681418,14,train,data/toy_ds/14/train.lmdb,1952
5,-323.746765,152.143822,14,val,data/toy_ds/14/val.lmdb,652
6,-321.417012,153.527547,13,test,data/toy_ds/13/test.lmdb,651
7,-319.077396,151.483806,13,train,data/toy_ds/13/train.lmdb,1954
8,-332.431521,146.242296,13,val,data/toy_ds/13/val.lmdb,650
9,-311.808,151.121444,19,test,data/toy_ds/19/test.lmdb,652
