# Train/test/validation split for an ASE database

In [1]:
from pathlib import Path
import numpy as np
from ase.db import connect

def train_test_val_split(ase_db, ttv=(0.8, 0.1, .1), files=('train.db', 'test.db', 'val.db'), seed=42):
    """Split an ase db into train, test and validation dbs.

    The row ids of ``ase_db`` are shuffled with a seeded random number
    generator and the rows are then copied, in shuffled order, into the
    train, test and validation databases.

    ase_db: path to an ase db containing all the data.
    ttv: a tuple containing the fraction of train, test and val data. This will
         be normalized, so e.g. (8, 1, 1) is equivalent to (0.8, 0.1, 0.1).
         Entries must be non-negative with a positive sum (ValueError otherwise).
    files: a tuple of filenames to write the splits into. A FileExistsError is
           raised if any of these exist. You should delete them first.
    seed: an integer for the random number generator seed.

    The train and test sizes are ``int(N * fraction)`` (rounded down); all
    remaining rows go to the validation set.

    NOTE(review): only ``row.toatoms()`` is written to the split dbs, so
    key-value pairs / data stored on the source rows are likely not copied --
    confirm if those are needed downstream.

    Returns a list with the absolute path to each of files, in the same order.
    """
    # Validate and normalize the fractions first, before touching any files.
    # A float array is required: in-place division of an int array fails.
    fractions = np.asarray(ttv, dtype=float)
    total = fractions.sum()
    if total <= 0 or (fractions < 0).any():
        raise ValueError(f'ttv must be non-negative with a positive sum, got {ttv}.')
    fractions = fractions / total

    for db in files:
        if Path(db).exists():
            raise FileExistsError(f'{db} exists. Please delete it before proceeding.')

    src = connect(ase_db)

    # Use the actual row ids rather than assuming they are 1..N: ids are not
    # renumbered when rows are deleted, so a contiguous range can miss rows.
    ids = np.array([row.id for row in src.select(include_data=False)], dtype=int)
    N = len(ids)

    train_end = int(N * fractions[0])
    test_end = train_end + int(N * fractions[1])

    train = connect(files[0])
    test = connect(files[1])
    val = connect(files[2])

    rng = np.random.default_rng(seed=seed)
    rng.shuffle(ids)

    splits = (ids[:train_end], ids[train_end:test_end], ids[test_end:])
    for dst, split_ids in zip((train, test, val), splits):
        for _id in split_ids:
            dst.write(src.get(id=int(_id)).toatoms())

    return [Path(f).absolute() for f in files]

# Generating a YAML config from an OCP checkpoint

In [126]:
from yaml import load, dump
from yaml import CLoader as Loader, CDumper as Dumper
import torch
import os
from ocpmodels.common.relaxation.ase_utils import OCPCalculator
from io import StringIO
import sys
import contextlib

def _nested_delete(dic, keys):
    """Delete dic[keys[0]]...[keys[-1]]; raise KeyError if any level is missing."""
    node = dic
    for key in keys[:-1]:
        node = node[key]
        if not isinstance(node, dict):
            # Cannot descend further, so the requested key does not exist.
            raise KeyError('.'.join(keys))
    del node[keys[-1]]


def _nested_set(dic, keys, value):
    """Set dic[keys[0]]...[keys[-1]] = value, creating intermediate dicts."""
    for key in keys[:-1]:
        dic = dic.setdefault(key, {})
    dic[keys[-1]] = value


def generate_yml_config(checkpoint_path, yml='run.yml', delete=(), update=()):
    """Generate a yml config file from an existing checkpoint file.

    checkpoint_path: string to path of an existing checkpoint; ``~`` is expanded.
    yml: name of file to write to.
    delete: iterable of keys to remove from the config. Use dot notation
        (e.g. 'optim.batch_size') for nested keys. A single string is treated
        as one key. A KeyError is raised if a key is not present.
    update: dictionary of key: value pairs to set in the config. Use dot
        notation for nested keys; missing intermediate dicts are created.

    Deletions are applied before updates.

    Returns an absolute path to the generated yml file.
    """
    # You can't just read in the checkpoint with torch. The calculator does some things to it.
    # Rather than recreate that here I just reuse the calculator machinery. I don't want to
    # see the output though, so I capture it.
    with contextlib.redirect_stdout(StringIO()):
        config = OCPCalculator(checkpoint=os.path.expanduser(checkpoint_path)).config

    # Guard against a bare string, which would otherwise be iterated char by char.
    if isinstance(delete, str):
        delete = (delete,)
    for key in delete:
        _nested_delete(config, key.split('.'))

    # dict() accepts both the default empty tuple and a mapping.
    for key, value in dict(update).items():
        _nested_set(config, key.split('.'), value)

    out = dump(config)
    with open(yml, 'wb') as f:
        f.write(out.encode('utf-8'))

    return Path(yml).absolute()