# Permutation trees

## Permutation tree representation v1

In [1]:
import torch
import torch_geometric as pyg

In [2]:
# Quick sanity check that torch is importable and builds an integer tensor.
torch.arange(1, 4)

tensor([1, 2, 3])

In [3]:
# Attach a non-tensor attribute (`tree`) to a Data object and check how it
# behaves when the object is moved to another device.
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

t1 = pyg.data.Data(
    x=torch.tensor([[1, 2, 3], [4, 5, 6]]),
    edge_index=torch.tensor([[0], [1]]),
    tree="hello",
)
print(t1)
print(t1.x)
print(t1.tree)

# Tensors are moved to `device`; the plain string attribute `tree` is unchanged.
t1 = t1.to(device)
print(t1)
print(t1.x)
print(t1.tree)

Data(x=[2, 3], edge_index=[2, 1], tree='hello')
tensor([[1, 2, 3],
        [4, 5, 6]])
hello
Data(x=[2, 3], edge_index=[2, 1], tree='hello')
tensor([[1, 2, 3],
        [4, 5, 6]], device='cuda:0')
hello


In [4]:
# Collate three references to the same graph into one mini-batch.
b1 = pyg.data.Batch.from_data_list([t1] * 3)
b1

DataBatch(x=[6, 3], edge_index=[2, 3], tree=[3], batch=[6], ptr=[4])

In [5]:
# The collated `tree` attribute (one entry per graph) followed by the concatenated node features.
print(b1.tree, b1.x, sep='\n')

['hello', 'hello', 'hello']
tensor([[1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6]], device='cuda:0')


## Creation of custom permutation trees

In [11]:
from ptgnn.dataset import RSDataset
# Dataset options forwarded to RSDataset for the validation split.
ds_config = dict(
    type="rs",
    mask_chiral_tags=True,
    transformation_mode='default',
)
rs_val = RSDataset(split='val', **ds_config)

In [13]:
# Inspect the first sample of the validation split.
# An explicit check replaces the `for ...: break` idiom, which on an empty dataset
# would leave `data` undefined or silently stale from an earlier cell.
data = next(iter(rs_val), None)
if data is None:
    raise ValueError("rs_val is empty")
display(data)

Data(x=[72, 118], edge_index=[2, 192], edge_attr=[192, 80], pos=[72, 6], parallel_node_index=[72], circle_index=[72], y=[1])

In [28]:
from typing import List
import json


def _circle_index_to_primordial_tree(circle_index: List[int], parallel_node: int, self_node: int):
    # if nothing in the circular index return empty string
    if len(circle_index) == 0:
        return json.dumps({"P": [int(parallel_node)]})
    else:
        # not including parallel node index
        # return f"Z{[i for i in circle_index]}"

        # including parallel node index
        # return f"P[{parallel_node}, Z{[i for i in circle_index]}]"
        return json.dumps({
            "S": [
                int(self_node),
                int(parallel_node),
                {
                    "Z": [int(i) for i in circle_index]
                }
            ]
        })

In [29]:
# One serialized permutation tree per node of `data`, built from the node's
# circle index and parallel node.
ptree = [
    _circle_index_to_primordial_tree(
        circle_index=node_circle,
        parallel_node=node_parallel,
        self_node=node_id,
    )
    for node_id, (node_circle, node_parallel) in enumerate(
        zip(data.circle_index, data.parallel_node_index)
    )
]

In [30]:
# Show the serialized permutation tree of every node.
display(ptree)

['{"P": [1]}',
 '{"S": [1, 0, {"Z": [3, 5, 7]}]}',
 '{"S": [2, 3, {"Z": [0, 7, 5]}]}',
 '{"S": [3, 2, {"Z": [9, 11]}]}',
 '{"S": [4, 5, {"Z": [0, 3, 7]}]}',
 '{"S": [5, 4, {"Z": [53, 55]}]}',
 '{"S": [6, 7, {"Z": [0, 5, 3]}]}',
 '{"S": [7, 6, {"Z": [64, 69, 71]}]}',
 '{"S": [8, 9, {"Z": [2, 11]}]}',
 '{"S": [9, 8, {"Z": [13, 15]}]}',
 '{"S": [10, 11, {"Z": [2, 9]}]}',
 '{"S": [11, 10, {"Z": [26, 51]}]}',
 '{"S": [12, 13, {"Z": [8, 15]}]}',
 '{"S": [13, 12, {"Z": [17, 19]}]}',
 '{"S": [14, 15, {"Z": [8, 13]}]}',
 '{"P": [14]}',
 '{"S": [16, 17, {"Z": [12, 19]}]}',
 '{"S": [17, 16, {"Z": [21, 23]}]}',
 '{"S": [18, 19, {"Z": [12, 17]}]}',
 '{"P": [18]}',
 '{"S": [20, 21, {"Z": [16, 23]}]}',
 '{"S": [21, 20, {"Z": [25, 27]}]}',
 '{"S": [22, 23, {"Z": [16, 21]}]}',
 '{"P": [22]}',
 '{"S": [24, 25, {"Z": [20, 27]}]}',
 '{"S": [25, 24, {"Z": [29, 31]}]}',
 '{"S": [26, 27, {"Z": [20, 25]}]}',
 '{"S": [27, 26, {"Z": [10, 51]}]}',
 '{"S": [28, 29, {"Z": [24, 31]}]}',
 '{"S": [29, 28, {"Z": [33, 

## Write a custom edge-graph transformation

In [31]:
# Load the validation split again, this time as a vertex graph (graph_mode='vertex').
rs = RSDataset(graph_mode='vertex', split='val')

Processing...
11748it [01:20, 146.37it/s]
Done!


In [58]:
# Take the first vertex-graph sample; the cells below transform it into an edge graph.
# An explicit check replaces the `for ...: break` idiom: on an empty dataset that
# idiom would silently keep the `data` left over from the rs_val cell above.
data = next(iter(rs), None)
if data is None:
    raise ValueError("rs is empty")
display(data)

Data(x=[34, 52], edge_index=[2, 72], edge_attr=[72, 14], pos=[34, 3], bond_distances=[36], bond_distance_index=[2, 36], bond_angles=[60], bond_angle_index=[3, 60], dihedral_angles=[91], dihedral_angle_index=[4, 91], y=[1])

In [67]:
import torch_geometric
from collections import defaultdict
from ptgnn.transform.edge_graph.chienn.get_circle_index import get_circle_index

In [60]:
# Every edge must exist in both directions (required in the chemical context):
# later cells look up the reverse edge (b, a) for every edge (a, b).
edge_index, edge_attr = (
    (data.edge_index, data.edge_attr)
    if torch_geometric.utils.is_undirected(data.edge_index)
    else torch_geometric.utils.to_undirected(
        edge_index=data.edge_index,
        edge_attr=data.edge_attr,
    )
)

In [61]:
# Turn every directed edge (a -> b) of the vertex graph into a node of the edge graph.
# `node_mapping` maps (a, b) to the new node's index in `node_storage`.
node_storage = []
node_mapping = {}
# The loop variable is deliberately NOT named `edge_attr`: reusing that name would
# overwrite the full edge-attribute tensor with its last row, so re-running this
# cell would iterate over that single row instead of over all edges.
for (a, b), ab_edge_attr in zip(edge_index.T.tolist(), edge_attr):
    # create the embedding for the new node: x_{a, b} = x'_a | e'_{a, b} | x'_b
    embedding_a2b = torch.cat([
        data.x[a],
        ab_edge_attr,
        data.x[b],
    ])

    # the new node's position concatenates the positions of both endpoints
    pos = torch.cat([data.pos[a], data.pos[b]])

    # todo: add only one if duplicate edges/nodes are required
    #   find solution for problem of embedding. either don't care and take one direction or sum up?
    # add to the storages
    node_mapping[(a, b)] = len(node_storage)
    node_storage.append({
        'a': a,
        'b': b,
        'a_attr': data.x[a],
        'node_attr': embedding_a2b,
        'old_edge_attr': ab_edge_attr,
        'pos': pos,
    })

In [63]:
# Helper index: for every vertex `b`, the edge-graph nodes (a -> b) ending in it,
# together with the vertex `a` each of them starts from.
in_nodes = defaultdict(list)
for i, node_dict in enumerate(node_storage):
    a, b = node_dict['a'], node_dict['b']
    incoming = dict(node_idx=i, start_node_idx=a)
    in_nodes[b].append(incoming)

In [64]:
# Connect the edge-graph nodes: node c = (x -> a) gets a directed edge to
# node i = (a -> b) whenever c ends in the vertex where i starts.
new_edges = []

for i, node_dict in enumerate(node_storage):
    # source and target vertex of edge-graph node i = (a -> b)
    a, b = node_dict['a'], node_dict['b']

    # original vertex-graph edge attribute of a -> b
    ab_old_edge_attr = node_dict['old_edge_attr']

    # vertex attributes of the shared vertex a
    a_attr = node_dict['a_attr']

    # edge-graph nodes that end in vertex a
    a_in_nodes_indices = [d['node_idx'] for d in in_nodes[a]]

    for in_node_c in a_in_nodes_indices:
        # original vertex-graph edge attribute of the incoming edge c -> a
        ca_old_edge_attr = node_storage[in_node_c]['old_edge_attr']

        # e_{(c, a), (a, b)} = e'_{c, a} | x'_a | e'_{a, b}
        # Named `cab_edge_attr` (not `edge_attr`) so the vertex-graph edge-attribute
        # tensor used by the node-creation cell is not overwritten.
        cab_edge_attr = torch.cat([ca_old_edge_attr, a_attr, ab_old_edge_attr])
        new_edges.append({'edge': [in_node_c, i], 'edge_attr': cab_edge_attr})

In [65]:
# For each edge-graph node (a -> b), the index of the node built from (b -> a).
parallel_node_index = [node_mapping[(d['b'], d['a'])] for d in node_storage]

In [66]:
# Pack the per-node and per-edge records into dense tensors for the new Data object.
new_x = torch.stack([d['node_attr'] for d in node_storage])
new_pos = torch.stack([d['pos'] for d in node_storage])
new_edge_index = torch.tensor([d['edge'] for d in new_edges]).T
new_edge_attr = torch.stack([d['edge_attr'] for d in new_edges])
parallel_node_index = torch.tensor(parallel_node_index)

In [68]:
# Assemble the edge graph; the circle index is computed from the finished object,
# including its parallel_node_index.
data = torch_geometric.data.Data(
    x=new_x,
    edge_index=new_edge_index,
    edge_attr=new_edge_attr,
    pos=new_pos,
    parallel_node_index=parallel_node_index,
)
data.circle_index = get_circle_index(data, clockwise=False)

In [69]:
# Show the (source, target) -> edge-graph node index mapping.
display(node_mapping)

{(0, 1): 0,
 (1, 0): 1,
 (1, 2): 2,
 (2, 1): 3,
 (1, 14): 4,
 (14, 1): 5,
 (1, 18): 6,
 (18, 1): 7,
 (2, 3): 8,
 (3, 2): 9,
 (2, 13): 10,
 (13, 2): 11,
 (3, 4): 12,
 (4, 3): 13,
 (3, 19): 14,
 (19, 3): 15,
 (4, 5): 16,
 (5, 4): 17,
 (4, 20): 18,
 (20, 4): 19,
 (5, 6): 20,
 (6, 5): 21,
 (5, 21): 22,
 (21, 5): 23,
 (6, 7): 24,
 (7, 6): 25,
 (6, 13): 26,
 (13, 6): 27,
 (7, 8): 28,
 (8, 7): 29,
 (7, 12): 30,
 (12, 7): 31,
 (8, 9): 32,
 (9, 8): 33,
 (8, 22): 34,
 (22, 8): 35,
 (9, 10): 36,
 (10, 9): 37,
 (9, 23): 38,
 (23, 9): 39,
 (10, 11): 40,
 (11, 10): 41,
 (10, 24): 42,
 (24, 10): 43,
 (11, 12): 44,
 (12, 11): 45,
 (11, 25): 46,
 (25, 11): 47,
 (12, 26): 48,
 (26, 12): 49,
 (13, 27): 50,
 (27, 13): 51,
 (14, 15): 52,
 (15, 14): 53,
 (14, 28): 54,
 (28, 14): 55,
 (15, 16): 56,
 (16, 15): 57,
 (15, 29): 58,
 (29, 15): 59,
 (16, 17): 60,
 (17, 16): 61,
 (16, 30): 62,
 (30, 16): 63,
 (17, 18): 64,
 (18, 17): 65,
 (17, 31): 66,
 (31, 17): 67,
 (18, 32): 68,
 (32, 18): 69,
 (18, 33): 70,
 (3