In [1]:
import os
import pickle
from ase.db import connect
from ase.visualize import view
import networkx as nx

from GAMERNet.rnet.networks.reaction_network import ReactionNetwork
from GAMERNet.rnet.networks.surface import Surface

# Load the pickled reaction-network dictionary from disk.
# NOTE(review): pickle.load can execute arbitrary code; only open files from a trusted source.
with open('../scripts/c1au111/rxn_net.pkl', 'rb') as pickle_file:
    content = pickle.load(pickle_file)  # dict of elementary reactions

# Rebuild the ReactionNetwork from the dictionary and print (#intermediates, #reactions).
rxn_net = ReactionNetwork().from_dict(content)
print(len(rxn_net.intermediates), len(rxn_net.reactions))
# graph = rxn_net.graph
# # Fix shortest path detection
# Path query on the network; `x` is later passed as `highlight=` to write_dotgraph / gen_graph.
x = rxn_net.get_shortest_path2(['1-0-1-1-1-g', '0-2-0-1-1-g'], '1-4-1-1-1-g')
print(x)

58 105
['1-0-1-1-1-g', '1-0-1-1-1-*', '1-1-1-1-1-*', '1-2-1-1-1-*', '1-3-1-1-1-*', '1-4-1-1-1-*', '1-4-1-1-1-g']


# Delete the CH2O4* intermediate (code 1-2-4-1-1-*)

In [2]:
# Remove the intermediate with code 1-2-4-1-1-* and print the network summary.
rxn_net.del_intermediates(["1-2-4-1-1-*"])
print(rxn_net)

Deleted 1 intermediates and 2 elementary reactions
Number of intermediates before: 58, after: 57
Number of reactions before: 105, after: 103
ReactionNetwork(42 intermediates, 14 closed-shell molecules, 103 reactions)
Surface: Au48(111)
Network Carbon cutoff: C1



# Delete an elementary reaction

In [2]:
# Delete the first entry of the network's reaction list.
rxn_net.del_reactions([rxn_net.reactions[0]])

Deleted 0 intermediates and 1 elementary reactions
Number of reactions before: 105, after: 104


In [3]:
# Inspect the first elementary step: its class, its printed form, and the type of its first component.
step = rxn_net.reactions[0]
print(type(step))
print(step)
print(type(step.components[0]))

<class 'GAMERNet.rnet.networks.elementary_reaction.ElementaryReaction'>
0-0-0-0-0-*(Au48*)+1-4-0-1-1-*(CH4*)<->1-3-0-1-1-*(CH3*)+0-1-0-1-1-*(H*)
<class 'frozenset'>


In [4]:
# Number of intermediates currently stored in the network.
print(len(rxn_net.intermediates))

57


In [4]:
# Display the network's list of elementary reactions.
rxn_net.reactions

[1-4-0-1-1-*(CH4*)+0-0-0-0-0-*(Au48*)<->1-3-0-1-1-*(CH3*)+0-1-0-1-1-*(H*),
 1-3-0-1-1-*(CH3*)+0-0-0-0-0-*(Au48*)<->1-2-0-1-1-*(CH2*)+0-1-0-1-1-*(H*),
 1-2-0-1-1-*(CH2*)+0-0-0-0-0-*(Au48*)<->1-1-0-1-1-*(CH*)+0-1-0-1-1-*(H*),
 1-1-0-1-1-*(CH*)+0-0-0-0-0-*(Au48*)<->1-0-0-1-1-*(C*)+0-1-0-1-1-*(H*),
 1-4-4-1-1-*(CH4O4*)+0-0-0-0-0-*(Au48*)<->1-3-4-1-1-*(CH3O4*)+0-1-0-1-1-*(H*),
 1-2-4-1-1-*(CH2O4*)+0-0-0-0-0-*(Au48*)<->1-1-4-1-1-*(CHO4*)+0-1-0-1-1-*(H*),
 1-1-4-1-1-*(CHO4*)+0-0-0-0-0-*(Au48*)<->1-0-4-1-1-*(CO4*)+0-1-0-1-1-*(H*),
 1-4-3-1-1-*(CH4O3*)+0-0-0-0-0-*(Au48*)<->1-3-3-1-1-*(CH3O3*)+0-1-0-1-1-*(H*),
 1-4-3-1-1-*(CH4O3*)+0-0-0-0-0-*(Au48*)<->1-3-3-1-2-*(CH3O3*)+0-1-0-1-1-*(H*),
 1-3-3-1-1-*(CH3O3*)+0-0-0-0-0-*(Au48*)<->1-2-3-1-1-*(CH2O3*)+0-1-0-1-1-*(H*),
 1-3-3-1-1-*(CH3O3*)+0-0-0-0-0-*(Au48*)<->1-2-3-1-2-*(CH2O3*)+0-1-0-1-1-*(H*),
 1-3-3-1-2-*(CH3O3*)+0-0-0-0-0-*(Au48*)<->1-2-3-1-1-*(CH2O3*)+0-1-0-1-1-*(H*),
 1-2-3-1-1-*(CH2O3*)+0-0-0-0-0-*(Au48*)<->1-1-3-1-1-*(CHO3*)+0-1-0-1-1-*(H*)

In [5]:
# Display the graph object held by the network.
rxn_net.graph

<networkx.classes.digraph.DiGraph at 0x7fdf581b7e50>

In [5]:
# Build the graph and print every node next to its attribute dictionary.
graph = rxn_net.gen_graph(del_surf=True)
for node, attributes in graph.nodes(data=True):
    print(node, attributes)

0-0-0-0-0-* {'category': 'surf', 'gas_atoms': Atoms(symbols='Au48', pbc=True, cell=[[10.1802357174757, 0.0, 0.0], [5.09011785873783, 8.81634274784763, 0.0], [0.0, 0.0, 20.0]], conn_pairs=..., constraint=FixAtoms(indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23])), 'code': '0-0-0-0-0-*', 'formula': 'Au48*', 'fig_path': '/home/smorandi/care/notebooks/tmp/0-0-0-0-0-*.png', 'switch': None}
1-4-0-1-1-* {'category': 'ads', 'gas_atoms': Atoms(symbols='CH4', pbc=False, conn_pairs=...), 'code': '1-4-0-1-1-*', 'formula': 'CH4*', 'fig_path': '/home/smorandi/care/notebooks/tmp/1-4-0-1-1-*.png', 'switch': None}
1-3-0-1-1-* {'category': 'ads', 'gas_atoms': Atoms(symbols='CH3', pbc=False, conn_pairs=...), 'code': '1-3-0-1-1-*', 'formula': 'CH3*', 'fig_path': '/home/smorandi/care/notebooks/tmp/1-3-0-1-1-*.png', 'switch': None}
1-2-0-1-1-* {'category': 'ads', 'gas_atoms': Atoms(symbols='CH2', pbc=False, conn_pairs=...), 'code': '1-2-0-1-1-*', 'formula': 'CH2

In [3]:
# Write the network plot; arguments are presumably (output directory, file name) - confirm against write_dotgraph.
# NOTE(review): the output recorded for this cell is "KeyError: 'category'".
rxn_net.write_dotgraph(".", "test_del.png", show_steps=False)

KeyError: 'category'

In [2]:
# Same plot call with `highlight=x`, where `x` is the path returned by get_shortest_path2.
rxn_net.write_dotgraph(".", 'HIGHLIGHT_test.png', del_surf=True,  show_steps=False, highlight=x)

# interactive plotly graph

In [4]:
# Rebuild `graph` with the path stored in `x` passed as `highlight`.
graph = rxn_net.gen_graph(del_surf=True, show_steps=False, highlight=x)

In [5]:
import plotly.graph_objects as go

def gen_interactive_graph(graph: nx.DiGraph):
    """
    Show a graph as an interactive Plotly figure.

    Node positions are computed with a Kamada-Kawai layout and are also
    written back to ``graph`` as the node attribute 'pos'. Edges are drawn
    as thin grey lines; nodes are drawn as markers coloured by the length of
    their adjacency mapping.

    Parameters
    ----------
    graph : nx.DiGraph
        Graph to draw. Every node must carry the attributes 'fig_path' and
        'formula'; both are inserted into the hover text.

    Returns
    -------
    None
        The figure is displayed through ``fig.show()``.
    """
    pos = nx.kamada_kawai_layout(graph)
    nx.set_node_attributes(graph, pos, 'pos')

    # Each edge contributes (start, end, None); the None entry separates
    # consecutive segments so that distinct edges are not joined together.
    edge_x = []
    edge_y = []
    for start, end in graph.edges():
        x0, y0 = pos[start]
        x1, y1 = pos[end]
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    node_x = [pos[node][0] for node in graph.nodes]
    node_y = [pos[node][1] for node in graph.nodes]

    hover_texts = []
    for _, data in graph.nodes(data=True):
        img_path = data['fig_path']
        formula = data['formula']
        hover_text = f"""<
                <TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">
                <TR>
                <TD><IMG SRC="{img_path}"/></TD>
                </TR>
                <TR>
                <TD>{formula}</TD>
                </TR>
                </TABLE>>"""
        hover_texts.append(hover_text)

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        hovertext=hover_texts,
        marker=dict(
            showscale=True,
            size=30,
            colorbar=dict(
                thickness=15,
                # `titleside` is deprecated in favour of `title.side`;
                # the nested form is the supported spelling.
                title=dict(text='Node Connections', side='right'),
                xanchor='left',
            )
        )
    )

    # Colour value and text label per node, in the iteration order of graph.adjacency().
    node_adjacencies = [len(neighbours) for _, neighbours in graph.adjacency()]
    node_text = [f"# of connections: {count}" for count in node_adjacencies]

    node_trace.marker.color = node_adjacencies
    node_trace.text = node_text

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            xaxis=dict(showgrid=False, zeroline=False),
            yaxis=dict(showgrid=False, zeroline=False),
        ),
    )
    # Deliberately not returned: a returned figure would be rendered a second
    # time when the call is the last expression of a notebook cell.
    fig.show()
    

In [6]:
# Show the interactive figure for `graph`.
gen_interactive_graph(graph)

In [None]:
import networkx as nx
from collections import deque

def constrained_shortest_path(graph, source, target, intermediates_to_avoid):
    """
    Breadth-first search for a path from ``source`` to ``target``.

    A neighbour is not enqueued when
      * its node data contains a 'type' entry other than 'elementary_reaction', or
      * it is contained in ``intermediates_to_avoid``, or
      * it has been seen before.

    Returns the list of nodes from ``source`` to ``target`` (both included),
    or None when the search runs out of nodes.
    """
    seen = set()
    frontier = deque()
    frontier.append((source, [source]))  # entries are (node, path leading to it)

    while len(frontier) > 0:
        node, trail = frontier.popleft()
        if node == target:
            return trail

        seen.add(node)
        for candidate in graph.neighbors(node):
            attrs = graph.nodes[candidate]
            blocked_by_type = 'type' in attrs and attrs['type'] != 'elementary_reaction'
            if blocked_by_type or candidate in intermediates_to_avoid or candidate in seen:
                continue
            seen.add(candidate)
            frontier.append((candidate, trail + [candidate]))

    return None  # search exhausted without reaching the target


In [None]:
import networkx as nx
from collections import deque

# Toy undirected graph: nodes A-D are tagged 'intermediate', nodes R1-R3 'elementary_reaction'.
G = nx.Graph()
toy_edges = [
    ("A", "R1"), ("R1", "B"),
    ("B", "R2"), ("R2", "C"),
    ("B", "R3"), ("R3", "D"),
]
for first, second in toy_edges:
    G.add_edge(first, second)

# Tag every node with its type.
node_types = {
    "A": 'intermediate', "B": 'intermediate', "C": 'intermediate', "D": 'intermediate',
    "R1": 'elementary_reaction', "R2": 'elementary_reaction', "R3": 'elementary_reaction',
}
for node, node_type in node_types.items():
    G.nodes[node]['type'] = node_type

def custom_shortest_path(rxn_net, graph, source, target):
    """
    Breadth-first search over the undirected view of ``graph``.

    A neighbouring node (``step``) is only followed when every reactant and
    product of ``rxn_net.reactions[step]`` has already been visited. Nodes
    whose 'category' is not one of 'ads', 'sur', 'gas' are appended to the
    returned path.

    Parameters
    ----------
    rxn_net : object whose ``reactions`` can be indexed with a graph node and
        yields objects exposing ``reactants`` and ``products``.
    graph : graph exposing ``to_undirected()``, ``neighbors()`` and ``nodes``.
    source, target : start and end node.

    Returns
    -------
    list or None
        Path accumulated on the way to ``target``; None if it is never reached.

    NOTE(review): another function named ``custom_shortest_path`` is defined
    further down in the same cell and replaces this one when the cell runs.
    """
    queue = deque([(source, [])])
    graph = graph.to_undirected()
    intermediates_visited = set()
    reactions_visited = set()
    intermediate_categories = ('ads', 'sur', 'gas')

    while queue:
        current_node, path_so_far = queue.popleft()

        # Check if the target node is reached
        if current_node == target:
            return path_so_far

        intermediates_visited.add(current_node)

        for step in graph.neighbors(current_node):
            if step in reactions_visited:
                continue

            reaction = rxn_net.reactions[step]
            inters = list(reaction.reactants) + list(reaction.products)
            # Follow the step only when none of its species is still unvisited.
            if any(inter not in intermediates_visited for inter in inters):
                continue

            if graph.nodes[step]['category'] not in intermediate_categories:
                new_path = path_so_far + [step]
            else:
                new_path = path_so_far

            queue.append((step, new_path))
            # Previously `visited.add(neighbor)`: neither name exists in this
            # function (NameError), and `reactions_visited` was never filled.
            reactions_visited.add(step)

    return None  # Path not found

def custom_shortest_path(net, graph, source, target):
    """
    Breadth-first search that only crosses a reaction node once all of its
    reactants and products are in the set of visited intermediates.

    Nodes whose 'category' is 'ads', 'sur', 'surf' or 'gas' are treated as
    intermediates; every other node is treated as a reaction and is looked up
    in ``net.reactions``. Only reaction nodes are appended to the returned path.

    Parameters
    ----------
    net : object whose ``reactions`` can be indexed with a reaction node and
        yields objects exposing ``index``, ``reactants`` and ``products``.
    graph : graph exposing ``neighbors()`` and ``nodes[node]['category']``.
    source, target : start and end node.

    Returns
    -------
    list or None
        Reaction nodes crossed on the way to ``target``; None if unreachable.

    NOTE(review): the two pre-seeded codes below use a different format from
    the '0-0-0-0-0-*' style codes printed elsewhere in this notebook - confirm.
    """
    # 'surf' is the category printed for the surface node by gen_graph;
    # 'sur' is kept so that graphs using the old spelling still work.
    intermediate_categories = ('ads', 'sur', 'surf', 'gas')
    visited = set()
    visited_inters = {'000000*', '010101*'}
    queue = deque([(source, [])])

    while queue:
        current_node, path_so_far = queue.popleft()

        # Check if the target node is reached
        if current_node == target:
            return path_so_far

        visited.add(current_node)
        if graph.nodes[current_node]['category'] in intermediate_categories:
            visited_inters.add(current_node)

        for neighbor in graph.neighbors(current_node):
            # Skip if already visited
            if neighbor in visited:
                continue

            is_reaction = graph.nodes[neighbor]['category'] not in intermediate_categories
            if is_reaction:
                # Use the `net` argument (the original read the global `rxn_net` here).
                index = net.reactions[neighbor].index
                inters = list(net.reactions[index].reactants) + list(net.reactions[index].products)
                if not all(inter in visited_inters for inter in inters):
                    continue

            # Only reaction nodes are recorded. The original tested a 'type'
            # attribute here, while the rest of the function relies on 'category'.
            new_path = path_so_far + [neighbor] if is_reaction else path_so_far

            queue.append((neighbor, new_path))
            visited.add(neighbor)

    return None  # Path not found

def shortest_path_sm(graph, source, target):
    """
    Unfinished draft of a shortest-path search.

    Collects the nodes whose 'category' is not 'ads', 'sur' or 'gas' and
    defines two helper predicates, but performs no search yet: the function
    always falls through and returns None.
    """
    intermediate_categories = ('ads', 'sur', 'gas')
    visited_intermediates = set()

    # select all nodes that are not intermediates
    rxn_nodes = []
    for node in graph.nodes:
        if graph.nodes[node]['category'] in intermediate_categories:
            continue
        rxn_nodes.append(node)

    def break_condition(node):
        # True when the target is contained in the given node.
        return target in node

    def cc_condition(node):
        # True when every member of the given node has been visited.
        return all(intermediate in visited_intermediates for intermediate in node)




In [None]:
# Keep a handle on the network's graph (its repr earlier in the notebook is a networkx DiGraph).
nx_graph = rxn_net.graph

In [None]:
# Run the custom search between two codes and display the result.
# NOTE(review): this rebinds `x`, which other cells pass as `highlight=`.
x = custom_shortest_path(rxn_net, nx_graph, '102101g', '141101g')
x

In [None]:
# Display the network's list of elementary reactions.
rxn_net.reactions

# Other

In [None]:
# Print the SMILES string of every intermediate flagged as closed-shell.
for inter in rxn_net.intermediates.values():
    if not inter.closed_shell:
        continue
    print(inter.smiles)

In [None]:
# Call add_eley_rideal with three species codes and print the network summary.
# NOTE(review): the meaning/order of the three codes is not visible here - confirm against the method signature.
rxn_net.add_eley_rideal('101101g', '001101*', '102101*')
print(rxn_net)

In [None]:
# Number of elementary reactions currently stored in the network.
print(len(rxn_net.reactions))

In [None]:
# Number of elementary reactions currently stored in the network.
print(len(rxn_net.reactions))

In [None]:
# Numbered listing of every reaction: running index, code, components and reaction type.
counter = 0  # stays 0 when the network holds no reactions
for counter, reaction in enumerate(rxn_net.reactions, start=1):
    print(counter, reaction.code, reaction.components, reaction.r_type)

In [None]:
# For every reaction, print its components, then the code of each intermediate inside each component.
for reaction in rxn_net.reactions:
    print(reaction.components)
    for component in reaction.components:
        for inter in component:
            print(inter.code)

In [None]:
# Number of intermediates currently stored in the network.
print(len(rxn_net.intermediates))

In [None]:
# Collect the `molecule` of every closed-shell intermediate and report how many were found.
# The flag is tested by truthiness (not `== True`), as in the SMILES-printing cell.
closed_shell_atoms = [
    inter.molecule
    for inter in rxn_net.intermediates.values()
    if inter.closed_shell
]
print(len(closed_shell_atoms))

In [None]:
# Collect the `molecule` of every closed-shell intermediate and report how many were found.
# The flag is tested by truthiness (not `== True`), as in the SMILES-printing cell.
closed_shell_atoms = [
    inter.molecule
    for inter in rxn_net.intermediates.values()
    if inter.closed_shell
]
print(len(closed_shell_atoms))

In [None]:
# Number of "H" symbols in the first collected closed-shell molecule.
closed_shell_atoms[0].get_chemical_symbols().count("H")

In [None]:
# Open the ASE viewer (ase.visualize.view) on the collected molecules.
view(closed_shell_atoms)

In [None]:
# Generate the graph with default options, print each (node, attributes) pair, then the graph itself.
y = rxn_net.gen_graph()
# y.remove_node("")
for node in y.nodes(data=True):
    print(node)
print(y)

In [None]:
# Display the `facet` attribute of the network's surface.
rxn_net.surface.facet

In [None]:
# Write the network plot with del_surf=True; arguments are presumably (output directory, file name) - confirm.
rxn_net.write_dotgraph(".", 'OLIV_test.png', del_surf=True)

Look for intermediates with specified composition

In [None]:
# Query with the element counts C=1, H=2, O=2.
rxn_net.search_inter_by_elements({'C':1, 'H':2, 'O':2})

Look for all elementary steps involving a specific intermediate

In [None]:
# Query search_ts with a single code.
# NOTE(review): this code has no phase suffix, unlike the '...g' / '...*' codes used in other cells - confirm the expected format.
rxn_net.search_ts(["222101"])

In [None]:
# Distinct `r_type` values found among the network's reactions.
types = [reaction.r_type for reaction in rxn_net.reactions]
print(set(types))

In [None]:
# Distinct `phase` values found among the network's intermediates.
types = [inter.phase for inter in rxn_net.intermediates.values()]
print(set(types))

# closed shell

In [None]:
def is_closed_shell_santi(self):
    """
    Check if a molecule CxHyOz is closed-shell or not.

    Compositions without hydrogen (Cx, Oz, CxOz) and pure hydrogen (Hy) are
    decided from the atom counts alone. Every other composition containing H
    is decided on the molecular graph: each atom starts with a bond count
    equal to its graph degree, and bond orders between neighbouring
    unsaturated atoms are raised one at a time until either every atom
    reaches its valence (closed-shell) or an unsaturated atom is left without
    an unsaturated neighbour (open-shell).

    Parameters
    ----------
    self : object exposing
        ``graph``    - graph whose nodes store the element symbol under "elem";
        ``molecule`` - object providing ``get_chemical_symbols()``.

    Returns
    -------
    bool
        True if the molecule is considered closed-shell, False otherwise
        (including compositions that contain no C, H or O at all).
    """
    valence_electrons = {'C': 4, 'H': 1, 'O': 2}
    graph = self.graph.to_undirected()
    symbols = self.molecule.get_chemical_symbols()
    n_c, n_h, n_o = symbols.count('C'), symbols.count('H'), symbols.count('O')

    if n_h == 0:
        if n_c and n_o:  # CxOz: only CO and CO2
            return n_c == 1 and n_o in (1, 2)
        if n_c:  # Cx
            return False
        if n_o:  # Oz: only O2
            return n_o == 2
        return False  # no C, H or O at all
    if n_c == 0 and n_o == 0:  # Hy: only H2
        return n_h == 2

    # From here on: H together with C and/or O (HyOz is handled as well).
    # node -> [current bond count, valence]; lists, so the count can be raised in place.
    bonds = {
        node: [graph.degree(node), valence_electrons.get(graph.nodes[node]["elem"], 0)]
        for node in graph.nodes()
    }

    def unsaturated():
        # Atoms whose bond count is still below their valence.
        return [node for node, (count, valence) in bonds.items() if count < valence]

    unsat_nodes = unsaturated()
    if not unsat_nodes:  # all atoms are saturated
        return True
    if len(unsat_nodes) == 1:  # only one unsaturated atom
        return False

    # Every pass either returns or raises at least one bond count, so the loop terminates.
    while unsat_nodes:
        unsat_oxygens = [node for node in unsat_nodes if graph.nodes[node]["elem"] == 'O']
        if unsat_oxygens:
            for oxygen in unsat_oxygens:
                if bonds[oxygen][0] >= bonds[oxygen][1]:
                    continue  # saturated earlier in this pass
                neighbours = list(graph.neighbors(oxygen))
                if not neighbours:
                    return False  # isolated oxygen: no partner for an extra bond
                bonds[oxygen][0] += 1
                # Raise the bond count of the oxygen's neighbour as well.
                for neighbour in neighbours:
                    if bonds[neighbour][0] < bonds[neighbour][1]:
                        bonds[neighbour][0] += 1
                    else:
                        return False  # O neighbour is saturated already
        else:
            # Pick the unsaturated node with the highest bond count ...
            max_degree = max(bonds[node][0] for node in unsat_nodes)
            max_degree_node = next(node for node in unsat_nodes if bonds[node][0] == max_degree)
            # ... and pair it with one unsaturated neighbour.
            partners = [nb for nb in graph.neighbors(max_degree_node) if nb in unsat_nodes]
            if not partners:  # all neighbours are saturated
                return False
            bonds[max_degree_node][0] += 1
            bonds[partners[0]][0] += 1
        unsat_nodes = unsaturated()

    return True
                         
                            

                
        
        # # Getting the unsaturated nodes (if there are not unsaturated nodes, the molecule is closed-shell)
        # unsat_nodes = [node for node in graph.nodes() if graph.degree(node) < valence_electrons.get(graph.nodes[node]["elem"], 0)]

        # # If the graph only has Carbon as an element and not H or O, then it is open-shell
        # if not 'H' and 'O' in molecule.get_chemical_formula():
        #     print(f'System {molecule.get_chemical_formula()} is open-shell: only C atoms')
        #     return False 
        
        # # Specific case for O2
        # if not 'C' and 'H' in molecule.get_chemical_formula() and len(unsat_nodes) == 2:
        #     print(f'System {molecule.get_chemical_formula()} is closed-shell: Oxygen')
        #     return True 
        
        # # CO and CO2
        # if not 'H' in molecule.get_chemical_formula() and len(molecule.get_chemical_symbols()['C']) == 1 and len(molecule.get_chemical_symbols()['O']) in (1,2):
        #     print(f'System {molecule.get_chemical_formula()} is closed-shell: CO or CO2')
        #     return True
        
        # if unsat_nodes:
        #     # If the molecule has only one unsaturated node, then it is open-shell
        #     if len(unsat_nodes) == 1:
        #         print(f'System {molecule.get_chemical_formula()} is open-shell')
        #         return False 
        #     else:
        #         # Checking if there is one unsaturated node that does not have as neighbour another unsaturated node
        #         for node in unsat_nodes:
        #             # If the molecule has only one unsaturated node, then it is open-shell
        #             if not [n for n in graph.neighbors(node) if n in unsat_nodes]:
        #                 print(f'System {molecule.get_chemical_formula()} is open-shell: one node is unsaturated but does not have as neighbour another unsaturated node')
        #                 return False 
        #             else:
        #                 # Case for molecules where an unsaturated node is oxygen
        #                 if graph.nodes[node]["elem"] == 'O':
        #                     # Adding one bond order (valence electrons) to the oxygen node by adding it to the unsat_nodes list

In [None]:
# Apply the draft closed-shell check to the intermediate stored under key '121101'.
is_closed_shell_santi(rxn_net.intermediates['121101'])