In [1]:
import glob
import logging
from contextlib import contextmanager
from pathlib import Path

import dask.dataframe as dd
import numpy as np
import pandas as pd
import tensorflow as tf
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from pymatgen.core import Structure
from tqdm.notebook import tqdm

# Register tqdm's `progress_apply` on pandas objects; it is used for the small
# local test batch later in the notebook.
tqdm.pandas()
# No name is passed, so this is the root logger.
logger = logging.getLogger()

  from distributed.utils import tmpfile
2022-06-22 09:40:00.465996: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /nopt/slurm/current/lib:/nopt/slurm/current/lib::/home/pstjohn/lib:/home/pstjohn/lib
2022-06-22 09:40:00.466035: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [2]:
import os

# Work from the directory one level above the one the kernel started in; the
# relative paths used below ("inputs/...", "config/...", "../../") assume it.
# The launch directory is remembered the first time this cell runs so that
# re-running the cell does not climb another level each time.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()

os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))
os.getcwd()

'/home/pstjohn/Research/rlmolecule/examples/crystal_energy'

In [3]:
import sys

# Prepend the directory two levels up to the import search path (relative to
# the current working directory); presumably this is where the `rlmolecule`
# package imported in the next cell lives.
sys.path[:0] = ["../../"]

In [4]:
from rlmolecule.sql.run_config import RunConfig
from scripts import compute_reward_decors as rew_decors

INFO:rdkit:Enabling RDKit 2022.03.3 jupyter extensions


In [5]:
# Run configuration for this search (path relative to the working directory).
config_file = "config/20220617_lt15stoich_battclust0_01/r_90.yaml"
# Path to the saved energy model (HDF5 file).
# NOTE(review): `load_rewarder` later hard-codes this same path again instead
# of reading this variable.
energy_model_file = (
    "/projects/rlmolecule/pstjohn/models/20220607_icsd_and_battery/best_model.hdf5"
)

# Parse the config; later cells read `run_config.problem_config` from it.
run_config = RunConfig(config_file)

24 actions_to_ignore


In [6]:
# Gather the decoration IDs that have already been computed. Each file
# contributes its "states" column; the IDs are pooled into one set so that
# duplicates across files collapse.
strc_ids_files = [
    "/projects/rlmolecule/jlaw/logs/crystal_energy/20220617-batt-icsd-vol-r90-2/states_seen.csv.gz",
    "/projects/rlmolecule/jlaw/logs/crystal_energy/20220617-batt-icsd-vol-r90-no-cond-ion-2/states_seen.csv.gz",
    "/projects/rlmolecule/jlaw/logs/crystal_energy/20220617-batt-icsd-vol-r90-no-halides-2/states_seen.csv.gz",
]

states_seen = set()
for seen_path in strc_ids_files:
    file_states = set(pd.read_csv(seen_path)["states"])
    states_seen |= file_states
    # Report the per-file count as soon as the file has been read.
    print(f"{len(file_states)} states read from {seen_path}")
print(f"{len(states_seen)} total")

4197573 states read from /projects/rlmolecule/jlaw/logs/crystal_energy/20220617-batt-icsd-vol-r90-2/states_seen.csv.gz
4177129 states read from /projects/rlmolecule/jlaw/logs/crystal_energy/20220617-batt-icsd-vol-r90-no-cond-ion-2/states_seen.csv.gz
3381483 states read from /projects/rlmolecule/jlaw/logs/crystal_energy/20220617-batt-icsd-vol-r90-no-halides-2/states_seen.csv.gz
6578563 total


In [7]:
# Load the competing-phase table that is later handed to the reward calculator.
competing_phases = rew_decors.load_competing_phases("inputs/competing_phases.csv")

# load the icsd prototype structures
prob_config = run_config.problem_config
prototypes_file = prob_config["prototypes_file"]
prototype_structures = rew_decors.read_structures_file(prototypes_file)
# make sure the prototype structures don't have oxidation states
# NOTE(review): this import would normally live in the top import cell.
from pymatgen.transformations.standard_transformations import (
    OxidationStateRemovalTransformation,
)

oxidation_remover = OxidationStateRemovalTransformation()
# Rebuild the mapping with the same keys, passing every structure through the
# oxidation-state remover.
prototype_structures = {
    s_id: oxidation_remover.apply_transformation(s)
    for s_id, s in prototype_structures.items()
}

preprocessor = rew_decors.AtomicNumberPreprocessor()
# The energy model is not loaded in this cell (the line is left commented out);
# the rewarder built from these objects receives None in its place.
# energy_model = rew_decors.load_model(energy_model_file)

	12682 lines
  sortedformula   icsdnum  energyperatom reduced_composition
0    Ag10Br3Te4  173116.0      -1.718985          Ag10Br3Te4
1   Ag11K1O16V4  391344.0      -4.797702         Ag11K1O16V4


INFO:scripts.compute_reward_decors:reading ../../rlmolecule/crystal/inputs/icsd_train_and_proto_max_comp_atoms15/KLiNa_add_clust0_01_min10prototypes.json.gz


	12682 entries


INFO:scripts.compute_reward_decors:	14494 structures read


### Generate all possible decoration IDs

In [8]:
from rlmolecule.crystal.crystal_reward import CrystalStateReward

# Rewarder built without an energy model (third argument is None). It is the
# module-level `rewarder` that `create_structure` uses to generate structures.
rewarder = CrystalStateReward(
    competing_phases, prototype_structures, None, preprocessor
)

# generate all the decoration IDs
# The builder is configured from the same problem config as the prototypes.
prob_config = run_config.problem_config
builder = rew_decors.CrystalBuilder(
    G=prob_config.get("action_graph1"),
    G2=prob_config.get("action_graph2"),
    actions_to_ignore=prob_config.get("actions_to_ignore"),
)

gen_decors = rew_decors.GenerateDecorations(builder)
decor_ids = gen_decors.generate_all_decorations()
# Convert to a set so the already-seen IDs can be removed with a set difference.
decor_ids = set(decor_ids)

Read G1: ../../rlmolecule/crystal/inputs/icsd_train_and_proto_max_comp_atoms15/KLiNa_add_clust0_01_min10eles_to_comps.edgelist.gz (236167 nodes, 250002 edges)
Read G2: ../../rlmolecule/crystal/inputs/icsd_train_and_proto_max_comp_atoms15/KLiNa_add_clust0_01_min10comp_type_to_decors.edgelist.gz (49108 nodes, 48931 edges)
24 and 0 actions to ignore in G and G2, respectively


  0%|          | 0/15000000 [00:00<?, ?it/s]

In [9]:
# Decorations that were generated above but do not appear in any of the
# already-computed state files.
new_decors = decor_ids.difference(states_seen)
print(f"{len(new_decors)} new decorations")

7704569 new decorations


## Compute the reward for each decoration

### Example for a few structures

In [10]:
# Put the decoration IDs in a one-column frame and peek at a couple of them.
# decoration_rewards = rew_decors.compute_rewards(list(decor_ids)[:10], rewarder, info_to_keep=info_to_keep)
# NOTE: `decor_ids` is a set, so the row order is not guaranteed to repeat
# across kernel restarts.
df_ids = pd.DataFrame({"decor_id": list(decor_ids)})
print(df_ids.head(2))

                                         decor_id
0  Na1Zr1O1P1|_1_1_1_1|orthorhombic|icsd_028355|7
1  K2Ti1Hg2S5|_1_2_2_5|orthorhombic|icsd_422825|1


In [11]:
# df_ids.to_parquet('/projects/rlmolecule/pstjohn/crystal_dask/20220622_decorations.parquet')

In [12]:
# First decoration ID in the frame, kept as a sample input for the helpers below.
decor_id = df_ids["decor_id"].iloc[0]

In [13]:
from rlmolecule.crystal.crystal_state import CrystalState


def create_structure(decor_id):
    """Generate the structure for one decoration ID.

    The ID is split at its first ``|``: the text before it is passed as the
    composition, and everything after it (which may itself contain ``|``) as
    the action node of a terminal ``CrystalState``. The module-level
    ``rewarder`` then turns that state into a structure.

    Parameters
    ----------
    decor_id : str
        Decoration ID, e.g. ``"K3Cl1S1|_1_1_3|trigonal|icsd_038068|1"``.

    Returns
    -------
    Whatever ``rewarder.generate_structure`` returns for the state.
    """
    composition, _, node = decor_id.partition("|")
    terminal_state = CrystalState(node, composition=composition, terminal=True)
    return rewarder.generate_structure(terminal_state)

In [24]:
def load_rewarder():
    """Build a ``CrystalStateReward`` that includes the energy model.

    Everything the rewarder needs is loaded inside the function (competing
    phases, energy model, prototype structures, preprocessor). The only
    notebook-level state it reads is ``run_config`` (for the prototypes file)
    and the ``rew_decors`` module.

    Returns
    -------
    rew_decors.CrystalStateReward
        Rewarder constructed from the competing phases, the oxidation-free
        prototype structures, the loaded energy model and an
        ``AtomicNumberPreprocessor``.
    """
    # Imported here so the function carries its own dependency.
    from pymatgen.transformations.standard_transformations import (
        OxidationStateRemovalTransformation,
    )

    competing_phases = rew_decors.load_competing_phases("inputs/competing_phases.csv")

    energy_model_file = (
        "/projects/rlmolecule/pstjohn/models/20220607_icsd_and_battery/best_model.hdf5"
    )
    # Load the model once per call. This function runs once per Dask partition
    # (see `batch_calc_rewards`), so a redundant second load is paid many times.
    energy_model = rew_decors.load_model(energy_model_file)

    # load the icsd prototype structures
    prob_config = run_config.problem_config
    prototypes_file = prob_config["prototypes_file"]
    prototype_structures = rew_decors.read_structures_file(prototypes_file)

    # make sure the prototype structures don't have oxidation states
    oxidation_remover = OxidationStateRemovalTransformation()
    prototype_structures = {
        s_id: oxidation_remover.apply_transformation(s)
        for s_id, s in prototype_structures.items()
    }

    preprocessor = rew_decors.AtomicNumberPreprocessor()

    rewarder = rew_decors.CrystalStateReward(
        competing_phases, prototype_structures, energy_model, preprocessor
    )
    return rewarder


# Rewarder instance used for the small local test batch below.
test_rewarder = load_rewarder()

	12682 lines
  sortedformula   icsdnum  energyperatom reduced_composition
0    Ag10Br3Te4  173116.0      -1.718985          Ag10Br3Te4
1   Ag11K1O16V4  391344.0      -4.797702         Ag11K1O16V4
	12682 entries
Reading /projects/rlmolecule/pstjohn/models/20220607_icsd_and_battery/best_model.hdf5


INFO:scripts.compute_reward_decors:reading ../../rlmolecule/crystal/inputs/icsd_train_and_proto_max_comp_atoms15/KLiNa_add_clust0_01_min10prototypes.json.gz
INFO:scripts.compute_reward_decors:	14494 structures read


Reading /projects/rlmolecule/pstjohn/models/20220607_icsd_and_battery/best_model.hdf5


In [25]:
# Extra quantities requested from `compute_reward`; they become the columns
# that follow "id" and "reward" in the result table.
info_to_keep = [
    "predicted_energy",
    "decomp_energy",
    "cond_ion_frac",
    "reduction",
    "oxidation",
    "stability_window",
]

# Score the first ten decoration IDs in this process, with a progress bar.
decoration_rewards = df_ids.head(10).decor_id.progress_apply(
    lambda decor: rew_decors.compute_reward(
        decor,
        test_rewarder,
        info_to_keep=info_to_keep,
    )
)

cols = ["id", "reward", *info_to_keep]

# One column per position in each returned result; a result shorter than
# `cols` yields NaN for the positions it lacks.
results = {
    col: decoration_rewards.map(
        lambda reward_row, pos=pos: reward_row[pos] if pos < len(reward_row) else np.nan
    )
    for pos, col in enumerate(cols)
}

# Also serves as the `meta` template for the Dask `map_partitions` call.
test_results = pd.DataFrame(results)

  0%|          | 0/10 [00:00<?, ?it/s]

In [36]:
from contextlib import contextmanager

import dask.dataframe as dd
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Multiplier for the worker count: the cluster is scaled to
# 36 * n_nodes worker processes (presumably one 36-process job per node).
n_nodes = 10


@contextmanager
def dask_cluster():
    """Start a SLURM-backed Dask cluster and tear it down on exit.

    Creates a ``SLURMCluster``, prints its generated job script, connects a
    ``Client`` and requests ``n_processes * n_nodes`` workers (``n_nodes`` is
    the module-level setting).

    Yields
    ------
    tuple
        ``(client, cluster)``.

    The client and the cluster are closed when the ``with`` block exits, and
    also when connecting or scaling fails before the block is entered.
    """

    n_processes = 36  # number of processes to run on each node
    memory = 90000  # to fit on a standard node; ask for 184,000 for a bigmem node

    cluster = SLURMCluster(
        project="rlmolecule",
        walltime="180",  # 30 minutes to fit in the debug queue; 180 to fit in short
        job_mem=str(memory),
        job_cpu=36,
        interface="ib0",
        local_directory="/tmp/scratch/dask-worker-space",
        cores=36,
        processes=n_processes,
        memory="{}MB".format(memory),
        extra=["--lifetime-stagger", "60m"],
        # job_extra=["--partition debug"],
    )

    client = None
    try:
        # Everything after cluster creation sits inside the try so a failure
        # while connecting or scaling cannot leave the cluster running.
        print(cluster.job_script())

        client = Client(cluster)
        cluster.scale(n_processes * n_nodes)

        yield client, cluster

    finally:
        # Close the client first, while the scheduler is still reachable, and
        # always stop the cluster even if closing the client raises.
        try:
            if client is not None:
                client.close()
        finally:
            cluster.close()

In [None]:
def batch_calc_rewards(x):
    """Compute rewards for one batch of decoration IDs.

    Used as the ``map_partitions`` callback. A fresh rewarder is built with
    ``load_rewarder()`` on every call, then each element of ``x`` is scored
    with ``rew_decors.compute_reward``.

    Parameters
    ----------
    x : pandas.Series
        Decoration IDs to score.

    Returns
    -------
    pandas.DataFrame
        Columns ``id``, ``reward`` and the six ``info_to_keep`` entries, one
        row per element of ``x``. A result shorter than the column list
        yields NaN for the positions it lacks.
    """
    partition_rewarder = load_rewarder()

    info_to_keep = [
        "predicted_energy",
        "decomp_energy",
        "cond_ion_frac",
        "reduction",
        "oxidation",
        "stability_window",
    ]

    def score(decor):
        return rew_decors.compute_reward(
            decor,
            partition_rewarder,
            info_to_keep=info_to_keep,
        )

    reward_rows = x.apply(score)

    cols = ["id", "reward", *info_to_keep]

    def field(pos):
        # Pull position `pos` out of every result, padding short results with NaN.
        return reward_rows.map(lambda row: row[pos] if pos < len(row) else np.nan)

    return pd.DataFrame({col: field(pos) for pos, col in enumerate(cols)})


# Split the ID frame into (n_nodes * 36) // 5 partitions; each partition is one
# call to `batch_calc_rewards`, which loads its own rewarder.
# NOTE(review): with n_nodes = 10 this is 72 partitions, while `dask_cluster`
# scales to 36 * n_nodes = 360 workers — confirm the 1:5 ratio is intended.
df_dask = dd.from_pandas(df_ids, npartitions=(n_nodes * 36)//5)
# Lazy graph only; `test_results` is passed as the output schema (`meta`).
results = df_dask.decor_id.map_partitions(
    batch_calc_rewards,
    meta=test_results,
)

# The cluster lives only for the duration of this block. The Client created
# inside `dask_cluster` is not referenced here; `compute()` presumably picks it
# up as the default client.
with dask_cluster():
    finished = results.compute()


# df_ids["structure"] = finished
# df_ids.to_parquet(
#     "/projects/rlmolecule/pstjohn/crystal_dask/20220622_decorated_structures.parquet"
# )

#!/usr/bin/env bash

#SBATCH -J dask-worker
#SBATCH -A rlmolecule
#SBATCH -n 1
#SBATCH --cpus-per-task=36
#SBATCH --mem=90000
#SBATCH -t 180

/home/pstjohn/mambaforge/envs/rlmol/bin/python -m distributed.cli.dask_worker tcp://10.148.7.241:41040 --nthreads 1 --nprocs 36 --memory-limit 2.33GiB --name dummy-name --nanny --death-timeout 60 --local-directory /tmp/scratch/dask-worker-space --lifetime-stagger 60m --interface ib0 --protocol tcp://



In [34]:
# Display the rewards table assigned to `finished` by `results.compute()`.
finished

Unnamed: 0,id,reward,predicted_energy,decomp_energy,cond_ion_frac,reduction,oxidation,stability_window
0,Na1Zr1O1P1|_1_1_1_1|orthorhombic|icsd_028355|7,0.036,-1.354,4.984,0.250,,,
1,K2Ti1Hg2S5|_1_2_2_5|orthorhombic|icsd_422825|1,0.054,0.752,4.534,0.200,,,
2,K3Cl1S1|_1_1_3|trigonal|icsd_038068|1,0.332,-1.252,2.025,0.600,,,
3,Li1Zn1Sn2Br7O2|_1_1_2_2_7|orthorhombic|icsd_07...,0.236,-2.663,0.938,0.077,,,
4,Li1Hf2B3O9|_1_2_3_9|trigonal|icsd_202666|1,0.194,-6.735,1.677,0.067,,,
...,...,...,...,...,...,...,...,...
95,Na1Cd1B1O3|_1_1_1_3|orthorhombic|icsd_422802|6,0.149,-3.062,2.739,0.167,,,
96,Li2Hf1Ge1F4Br6|_1_1_2_4_6|orthorhombic|icsd_24...,0.195,-2.608,1.851,0.143,,,
97,Li1Cd1Hg3S3N1|_1_1_1_3_3|orthorhombic|icsd_416...,0.094,0.484,3.578,0.111,,,
98,Na1La1Hf1S4|_1_1_1_4|orthorhombic|icsd_097767|4,0.245,-5.079,0.954,0.143,,,
