In [None]:
import sys
sys.path.append('../src')  # make the project package under ../src importable
import random
import resource


def raise_nofile_limit(target=65536):
    """Raise the soft RLIMIT_NOFILE (max open file descriptors) towards ``target``.

    Only the soft limit is changed, and never above the current hard limit, so the
    call works without privileges and never irreversibly lowers the hard limit.
    An already larger (or unlimited) soft limit is left untouched.

    Args:
        target: desired soft limit.

    Returns:
        The soft limit in effect after the call.
    """
    import warnings

    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    # Unprivileged processes may not set the soft limit above the hard limit.
    if hard != resource.RLIM_INFINITY:
        target = min(target, hard)
    if soft == resource.RLIM_INFINITY or soft >= target:
        return soft
    try:
        resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
    except (ValueError, OSError) as err:
        # Best effort: some platforms cap the soft limit below the reported hard limit.
        warnings.warn(f"Could not raise RLIMIT_NOFILE soft limit to {target}: {err}")
        return soft
    return target


raise_nofile_limit()
import matplotlib.pyplot as plt


from fats.dataset import AdsorptionGraphDataset
from fats.graph_tools import graph_plotter

## Generate graph dataset

In [None]:
# Input ASE database and the directory the processed graph dataset is written to.
ASE_DB_PATH = "../data/fg.db"
GRAPH_DATASET_PATH = "../data"

# Settings passed through to AdsorptionGraphDataset under the "structure" key.
STRUCTURE_DICT = dict(tolerance=0.25, scaling_factor=1.25, second_order=True)
# Node-feature switches: only "gcn" is enabled.
FEATURES_DICT = dict(
    adsorbate=False,
    radical=False,
    valence=False,
    gcn=True,
    magnetization=False,
)
GRAPH_PARAMS = dict(structure=STRUCTURE_DICT, features=FEATURES_DICT, target="scaled_energy")
DB_KEY = ''  # NOTE(review): empty key — presumably selects no DB subset; confirm in AdsorptionGraphDataset

dataset = AdsorptionGraphDataset(ASE_DB_PATH, GRAPH_DATASET_PATH, GRAPH_PARAMS, DB_KEY)

In [None]:
# Graphs built only from C, H, O, N and S atoms, i.e. with no metal/surface atom.
GAS_PHASE_ELEMENTS = ('C', 'H', 'O', 'N', 'S')
gas_graphs = [
    graph
    for graph in dataset
    if all(elem in GAS_PHASE_ELEMENTS for elem in graph.elem)
]
print(len(gas_graphs))

# Check failed graphs (graphs that contain no metal atoms)

In [None]:
# Pick one metal-free graph at random for a closer look.
# Set RANDOM_SEED to an int to make the pick reproducible across kernel restarts;
# None draws a fresh random choice each run.
RANDOM_SEED = None
if not gas_graphs:
    raise ValueError("gas_graphs is empty: every graph contains at least one atom other than C/H/O/N/S.")
random_graph = random.Random(RANDOM_SEED).choice(gas_graphs)
graph_plotter(random_graph, dataset.ohe_elements, node_index=True)

In [None]:
# Re-draw the selected graph and export the current figure as SVG.
GRAPH_FIGURE_PATH = "graph_ts.svg"
graph_plotter(random_graph, dataset.ohe_elements, node_index=True)
plt.savefig(GRAPH_FIGURE_PATH)

In [None]:
# Element symbol attached to each node of the selected graph.
random_graph_elements = random_graph.elem
print(random_graph_elements)

In [None]:
# Last column of the node-feature matrix.
# NOTE(review): presumably the "gcn" feature (the only one enabled in FEATURES_DICT) — confirm.
last_node_feature = random_graph.x[:, -1]
print(last_node_feature)

In [None]:
# Open the atomic structure behind the selected graph in ASE's viewer.
# (`view` is also used by later cells, so it is imported by name.)
from ase.visualize import view

random_atoms = random_graph.atoms_obj
view(random_atoms)

In [None]:
# Display the `type` attribute of the selected graph (last expression = cell output).
random_graph.type

In [None]:
# Display (rich repr rather than print) the last node-feature column of the selected graph.
random_graph.x[:,-1]

# Dataset statistics and visualization of graph and atoms objects


In [None]:
# Collect per-graph metadata used by the dataset statistics cells.
metal_list, type_list, bb_type_list = [], [], []
adsorbate_size_list, facet_list, energy_list = [], [], []
nC_list, nO_list = [], []

for graph in dataset:
    metal_list.append(graph.metal)
    type_list.append(graph.type)
    bb_type_list.append(graph.bb_type)

    symbols = graph.atoms_obj.get_chemical_symbols()  # fetched once, counted several times
    n_carbon, n_hydrogen, n_oxygen = (symbols.count(sym) for sym in ("C", "H", "O"))
    nC_list.append(n_carbon)
    nO_list.append(n_oxygen)
    # NOTE(review): "size" counts only C, H and O atoms; N and S atoms are not included — confirm intent.
    adsorbate_size_list.append(n_carbon + n_hydrogen + n_oxygen)

    facet_list.append(graph.facet)
    energy_list.append(graph.target.item())

In [None]:
from collections import Counter


def count_values(values):
    """Return a plain ``{value: number of occurrences}`` dict for ``values``.

    Uses a single O(n) pass instead of one ``list.count`` per distinct value.
    Keys are sorted when mutually comparable so plots and printed summaries have a
    stable order across sessions (iterating a ``set`` of strings is hash-randomised);
    otherwise first-occurrence order is kept.
    """
    counts = Counter(values)
    try:
        return dict(sorted(counts.items()))
    except TypeError:  # keys of mixed / unorderable types
        return dict(counts)


metal_count = count_values(metal_list)
bb_type_count = count_values(bb_type_list)
facet_count = count_values(facet_list)
adsorbate_size_count = count_values(adsorbate_size_list)
nC_count = count_values(nC_list)
nO_count = count_values(nO_list)
type_count = count_values(type_list)


In [None]:
# Display how many graphs fall into each `type` value.
type_count

In [None]:
# Bar charts summarising the dataset composition, one panel per property.
# Each entry: (counts dict, bar colour, x-axis label, bar width, annotate bars with counts?)
stats_panels = [
    (bb_type_count, "C0", "Bond-breaking", 0.8, True),
    (metal_count, "C1", "Metal", 0.7, False),  # per-bar count labels intentionally left off
    (adsorbate_size_count, "C2", "Adsorbate size", 0.8, True),
    (facet_count, "C3", "Surface", 0.5, True),
    (nC_count, "C4", "Adsorbate C count", 0.8, True),
    (nO_count, "C5", "Adsorbate O count", 0.8, True),
    (type_count, "C6", "Type", 0.8, True),
]
SAVE_STATS_FIGURE = False  # set True to write TS_dataset_stats.svg

fig, ax = plt.subplots(4, 2, figsize=(18/2.54, 20/2.54))  # figsize given in cm
fig.suptitle("Transition state dataset")

for panel_idx, (panel_ax, (counts, colour, xlabel, width, annotate)) in enumerate(zip(ax.ravel(), stats_panels)):
    panel_ax.bar(list(counts.keys()), list(counts.values()), color=colour, width=width, align="center")
    if annotate:
        for key, value in counts.items():
            panel_ax.text(key, value, str(value), ha="center", va="bottom")
    panel_ax.set_xlabel(xlabel)
    # Only the left-hand column carries the y-axis label.
    panel_ax.set_ylabel("Count" if panel_idx % 2 == 0 else "")
    # 20 % head-room so the count labels above the tallest bar stay inside the axes.
    ymin, ymax = panel_ax.get_ylim()
    panel_ax.set_ylim(ymin, ymax * 1.2)

# The 4x2 grid has more slots than panels; hide the unused axes instead of leaving empty frames.
for spare_ax in ax.ravel()[len(stats_panels):]:
    spare_ax.axis("off")

fig.tight_layout()
if SAVE_STATS_FIGURE:
    fig.savefig("TS_dataset_stats.svg")

In [None]:
# NOTE(review): creates an empty 2x2 figure that nothing in this cell draws on — candidate for removal.
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(18 / 2.54, 14 / 2.54))

In [None]:
# Plot the distribution of the target energies collected in `energy_list`.
fig, ax = plt.subplots(figsize=(18/2.54, 10/2.54))
# Plain matplotlib histogram (bins="auto", like seaborn's histplot default):
# the notebook never imports seaborn, so `sns.histplot` raised a NameError.
ax.hist(energy_list, bins="auto", color="C0", edgecolor="white")
# NOTE(review): GRAPH_PARAMS uses target="scaled_energy" — confirm this axis label matches it.
ax.set_xlabel("$E_{tot}-E_{slab}$ / eV")
ax.set_ylabel("Count")
ax.set_title("TS dataset energy distribution")
fig.tight_layout()
fig.savefig("TS_dataset_energy_distribution.svg")

In [None]:
# List graphs whose target energy is non-negative; indices are positions in `dataset`
# (`TS_dataset` was never defined in this notebook).
nonnegative_target_idx = []
for i, graph in enumerate(dataset):
    target = graph.target.item()
    if target >= 0:
        nonnegative_target_idx.append(i)
        print(i, graph.formula, target, graph.type)

In [None]:
from ase.visualize import view

# Inspect a few hand-picked structures side by side (`TS_dataset` was undefined; use `dataset`).
# NOTE(review): indices presumably come from the non-negative-target listing — re-check after regenerating the dataset.
view([dataset[idx].atoms_obj for idx in (3414, 3416, 3423)])

In [None]:
# Plot one hand-picked graph (`TS_dataset` was undefined; use `dataset`).
graph_plotter(dataset[3423], dataset.ohe_elements, node_index=False)

In [None]:
# Optionally export a plot of every graph in the dataset as an SVG file.
EXPORT_GRAPH_PLOTS = False  # set True to write <out_dir>/graph_<i>.svg for every graph


def export_graph_plots(graphs, ohe_elements, out_dir="../data/plots"):
    """Save one SVG per graph, titled "<formula> (<bond-breaking type>)".

    Args:
        graphs: iterable of graphs exposing ``atoms_obj`` and ``bb_type``.
        ohe_elements: element encoding passed through to ``graph_plotter``.
        out_dir: output directory, created if missing.

    Returns:
        List of written file paths, in input order.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    written = []
    for i, graph in enumerate(graphs):
        formula = graph.atoms_obj.get_chemical_formula()
        graph_plotter(graph, ohe_elements, node_index=False)
        # f-string so a non-str bb_type does not break the title.
        plt.title(f"{formula} ({graph.bb_type})")
        path = os.path.join(out_dir, f"graph_{i}.svg")
        plt.savefig(path)
        plt.close()  # avoid accumulating one open figure per graph
        written.append(path)
    return written


if EXPORT_GRAPH_PLOTS:
    export_graph_plots(dataset, dataset.ohe_elements)