# Tree separation

In [1]:
import json

import torch

from ptgnn.dataset import RSDataset

## Fetch sample

In [2]:
# Validation split of RSDataset using the "permutation_tree" transformation.
# NOTE(review): the meaning of k is defined by RSDataset's transformation — confirm.
rs = RSDataset(
    split="val",
    transformation_mode="permutation_tree",
    transformation_parameters=dict(k=3),
)

In [3]:
# Take a single validation sample to experiment on and show its attribute summary.
sample_index = 0
test_elem = rs[sample_index]
test_elem

Data(x=[72, 118], edge_index=[2, 192], edge_attr=[192, 80], pos=[72, 6], parallel_node_index=[72], circle_index=[72], ptree=[72], initial_map=[264], layer0_order_matrix=[1], layer0_type_mask=[1], layer0_pooling=[1], num_layer=[1], layer1_order_matrix=[1], layer1_type_mask=[1], layer1_pooling=[1], y=[1])

## Logic
- Each leaf of a permutation tree is already a node index.
- Every internal node (a nested subtree) also needs an index, so a new node is created for it.
- A new node starts out empty (here: an all-zero feature vector).
- Each node's tree is trimmed to a single layer; every subtree that is trimmed off becomes the tree of a new node.

In [4]:
def separate_permutation_tree(data):
    """Flatten every node's nested permutation tree to a single layer.

    Each entry of ``data.ptree`` is a JSON string whose decoded dict maps its
    first key to a list of children. Integer children are node indices and are
    kept as-is. Every dict child (a nested subtree) is replaced by the index of
    a newly created node, whose own ``ptree`` entry is that subtree. New nodes
    are processed in turn, so afterwards every tree has a single layer of
    integer children.

    Each created node gets an all-zero feature row appended to ``data.x``
    (same dtype and device as the existing rows).

    Args:
        data: graph sample with a 2-D feature matrix ``x`` (one row per node)
            and a list ``ptree`` holding one JSON tree string per row of ``x``.

    Returns:
        The same ``data`` object, with ``ptree`` replaced by a new list and
        ``x`` extended by one zero row per created node.

    Raises:
        ValueError: if ``ptree`` and ``x`` do not have the same number of entries.
        TypeError: if a tree child is neither an int nor a dict.
    """
    num_nodes = data.x.shape[0]
    if len(data.ptree) != num_nodes:
        # New node indices are taken from len(ptrees), so both must start in sync.
        raise ValueError(
            f"expected one ptree per node, got {len(data.ptree)} ptrees for {num_nodes} nodes"
        )

    # Work on a copy so a failure part-way leaves ``data`` untouched and the list
    # handed out by the dataset is not mutated in place.
    # NOTE(review): that list may be shared with the dataset's cached sample — confirm.
    ptrees = list(data.ptree)

    current_idx = 0
    # ``ptrees`` grows while we iterate: split-off subtrees are appended and are
    # processed (and possibly split further) in later iterations.
    while current_idx < len(ptrees):
        tree = json.loads(ptrees[current_idx])
        # Only the first key is processed. NOTE(review): assumes one key per tree.
        key = next(iter(tree))
        children = tree[key]

        for position, child in enumerate(children):
            if isinstance(child, int):
                # Already a leaf (a node index).
                continue
            if not isinstance(child, dict):
                raise TypeError(f"Something went wrong, {child} is neither int nor dict.")
            # Replace the nested subtree with the index of a new node that owns it.
            children[position] = len(ptrees)
            ptrees.append(json.dumps(child))

        ptrees[current_idx] = json.dumps(tree)
        current_idx += 1

    num_new = len(ptrees) - num_nodes
    if num_new > 0:
        # A single concatenation instead of one per new node.
        # new_zeros keeps the dtype and device of x.
        data.x = torch.cat([data.x, data.x.new_zeros(num_new, data.x.shape[1])], dim=0)
    data.ptree = ptrees

    # NOTE(review): other per-node attributes (e.g. pos, parallel_node_index,
    # circle_index) are not extended for the new nodes — confirm this is intended.
    return data


test_elem = separate_permutation_tree(test_elem)

In [5]:
# Show the transformed sample's attribute summary, then every node's tree string.
display(test_elem)
display(test_elem.ptree)

Data(x=[128, 118], edge_index=[2, 192], edge_attr=[192, 80], pos=[72, 6], parallel_node_index=[72], circle_index=[72], ptree=[128], initial_map=[264], layer0_order_matrix=[1], layer0_type_mask=[1], layer0_pooling=[1], num_layer=[1], layer1_order_matrix=[1], layer1_type_mask=[1], layer1_pooling=[1], y=[1])

['{"S": [0, 1]}',
 '{"S": [1, 0, 72]}',
 '{"S": [2, 3, 73]}',
 '{"S": [3, 2, 74]}',
 '{"S": [4, 5, 75]}',
 '{"S": [5, 4, 76]}',
 '{"S": [6, 7, 77]}',
 '{"S": [7, 6, 78]}',
 '{"S": [8, 9, 79]}',
 '{"S": [9, 8, 80]}',
 '{"S": [10, 11, 81]}',
 '{"S": [11, 10, 82]}',
 '{"S": [12, 13, 83]}',
 '{"S": [13, 12, 84]}',
 '{"S": [14, 15, 85]}',
 '{"S": [15, 14]}',
 '{"S": [16, 17, 86]}',
 '{"S": [17, 16, 87]}',
 '{"S": [18, 19, 88]}',
 '{"S": [19, 18]}',
 '{"S": [20, 21, 89]}',
 '{"S": [21, 20, 90]}',
 '{"S": [22, 23, 91]}',
 '{"S": [23, 22]}',
 '{"S": [24, 25, 92]}',
 '{"S": [25, 24, 93]}',
 '{"S": [26, 27, 94]}',
 '{"S": [27, 26, 95]}',
 '{"S": [28, 29, 96]}',
 '{"S": [29, 28, 97]}',
 '{"S": [30, 31, 98]}',
 '{"S": [31, 30, 99]}',
 '{"S": [32, 33, 100]}',
 '{"S": [33, 32, 101]}',
 '{"S": [34, 35, 102]}',
 '{"S": [35, 34]}',
 '{"S": [36, 37, 103]}',
 '{"S": [37, 36, 104]}',
 '{"S": [38, 39, 105]}',
 '{"S": [39, 38]}',
 '{"S": [40, 41, 106]}',
 '{"S": [41, 40, 107]}',
 '{"S": [42, 43, 108]}',
 '{