# Permutation trees

## Permutation tree representation v1

In [1]:
import torch
import torch_geometric as pyg

In [2]:
# Build a 1-D tensor from a Python list; as the cell's last expression its repr is echoed.
torch.tensor([1,2,3])

tensor([1, 2, 3])

In [3]:
# Two-node graph with one directed edge (0 -> 1) and an extra, non-tensor
# attribute `tree` attached to the Data object.
t1 = pyg.data.Data(
    x=torch.tensor([[1, 2, 3], [4, 5, 6]]),
    edge_index=torch.tensor([[0], [1]]),
    tree="hello",
)
print(t1)
print(t1.x)
print(t1.tree)

# Move the graph to the GPU when one is available and stay on the CPU otherwise,
# so the cell also runs on machines without CUDA instead of raising.
t1 = t1.to('cuda' if torch.cuda.is_available() else 'cpu')
print(t1)
print(t1.x)
print(t1.tree)

Data(x=[2, 3], edge_index=[2, 1], tree='hello')
tensor([[1, 2, 3],
        [4, 5, 6]])
hello
Data(x=[2, 3], edge_index=[2, 1], tree='hello')
tensor([[1, 2, 3],
        [4, 5, 6]], device='cuda:0')
hello


In [4]:
# Collate three references to the same graph into a single batch object,
# then echo it as the cell result.
b1 = pyg.data.Batch.from_data_list([t1] * 3)
b1

DataBatch(x=[6, 3], edge_index=[2, 3], tree=[3], batch=[6], ptr=[4])

In [5]:
# Show the batched `tree` attribute and the batched node features, one per line.
print(b1.tree, b1.x, sep="\n")

['hello', 'hello', 'hello']
tensor([[1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6]], device='cuda:0')


## Creation of custom permutation trees

In [11]:
from ptgnn.dataset import RSDataset
# Options forwarded as keyword arguments to the dataset constructor.
ds_config = dict(
    type="rs",
    mask_chiral_tags=True,
    transformation_mode='default',
)

# Dataset instance for the 'val' split, built with the options above.
rs_val = RSDataset(split='val', **ds_config)

In [13]:
# Fetch the first sample of the validation set explicitly (instead of a
# loop-and-break). `data` is read again further down when the trees are built,
# so the binding is made visible here; on an empty dataset `next` raises
# StopIteration rather than leaving `data` undefined or stale from an earlier run.
data = next(iter(rs_val))
display(data)

Data(x=[72, 118], edge_index=[2, 192], edge_attr=[192, 80], pos=[72, 6], parallel_node_index=[72], circle_index=[72], y=[1])

In [28]:
from typing import List
import json


def _circle_index_to_primordial_tree(circle_index: List[int], parallel_node: int, self_node: int):
    # if nothing in the circular index return empty string
    if len(circle_index) == 0:
        return json.dumps({"P": [int(parallel_node)]})
    else:
        # not including parallel node index
        # return f"Z{[i for i in circle_index]}"

        # including parallel node index
        # return f"P[{parallel_node}, Z{[i for i in circle_index]}]"
        return json.dumps({
            "S": [
                int(self_node),
                int(parallel_node),
                {
                    "Z": [int(i) for i in circle_index]
                }
            ]
        })

In [29]:
# One JSON tree string per node: the position in the zipped sequence is passed
# as the node's own index.
ptree = [
    _circle_index_to_primordial_tree(
        circle_index=ring,
        parallel_node=partner,
        self_node=node_id,
    )
    for node_id, (ring, partner) in enumerate(
        zip(data.circle_index, data.parallel_node_index)
    )
]

In [30]:
# Show the full list of per-node JSON tree strings.
display(ptree)

['{"P": [1]}',
 '{"S": [1, 0, {"Z": [3, 5, 7]}]}',
 '{"S": [2, 3, {"Z": [0, 7, 5]}]}',
 '{"S": [3, 2, {"Z": [9, 11]}]}',
 '{"S": [4, 5, {"Z": [0, 3, 7]}]}',
 '{"S": [5, 4, {"Z": [53, 55]}]}',
 '{"S": [6, 7, {"Z": [0, 5, 3]}]}',
 '{"S": [7, 6, {"Z": [64, 69, 71]}]}',
 '{"S": [8, 9, {"Z": [2, 11]}]}',
 '{"S": [9, 8, {"Z": [13, 15]}]}',
 '{"S": [10, 11, {"Z": [2, 9]}]}',
 '{"S": [11, 10, {"Z": [26, 51]}]}',
 '{"S": [12, 13, {"Z": [8, 15]}]}',
 '{"S": [13, 12, {"Z": [17, 19]}]}',
 '{"S": [14, 15, {"Z": [8, 13]}]}',
 '{"P": [14]}',
 '{"S": [16, 17, {"Z": [12, 19]}]}',
 '{"S": [17, 16, {"Z": [21, 23]}]}',
 '{"S": [18, 19, {"Z": [12, 17]}]}',
 '{"P": [18]}',
 '{"S": [20, 21, {"Z": [16, 23]}]}',
 '{"S": [21, 20, {"Z": [25, 27]}]}',
 '{"S": [22, 23, {"Z": [16, 21]}]}',
 '{"P": [22]}',
 '{"S": [24, 25, {"Z": [20, 27]}]}',
 '{"S": [25, 24, {"Z": [29, 31]}]}',
 '{"S": [26, 27, {"Z": [20, 25]}]}',
 '{"S": [27, 26, {"Z": [10, 51]}]}',
 '{"S": [28, 29, {"Z": [24, 31]}]}',
 '{"S": [29, 28, {"Z": [33, 