In [1]:
import torch

In [2]:
from ptgnn.runtime_config.run_config import fetch_loaders

In [3]:
# Dataset / loader configuration that is handed to `fetch_loaders`.
data_config = dict(
    dataset=dict(
        type="rs",
        mask_chiral_tags=True,
        transformation_mode="chienn_tree_basic",
    ),
    loader=dict(
        # settings under 'general' sit next to the (empty) per-split sections
        general=dict(
            n_neighbors_in_circle=3,
            batch_size=32,
            num_workers=0,
        ),
        train=dict(),
        test=dict(),
        val=dict(),
    ),
)

In [4]:
# Build the data loaders from the config; the middle return value is discarded.
# NOTE(review): presumably the return order is (train, test, val) - confirm against fetch_loaders.
train_loader, _, val_loader = fetch_loaders(data_config=data_config)

In [5]:
# Take only the first validation batch. `batch` stays bound after the loop
# and is the object all following cells operate on.
for batch in val_loader:
    display(batch)
    break

DataBatch(x=[2356, 118], edge_index=[2, 6416], edge_attr=[6416, 80], pos=[2356, 6], y=[32], batch=[2356], ptr=[33], ptree=[2356])

In [6]:
import json

# decode every string stored in batch.ptree into its nested list/dict tree
permutation_trees = list(map(json.loads, batch.ptree))

In [7]:
# k sizes the first dimension of order_matrix and the number of Linear layers in z_layer.
# NOTE(review): same value as 'n_neighbors_in_circle' in data_config - presumably meant to match; confirm.
k = 3

In [8]:
def depth(d):
    """Return the maximum dict-nesting depth of a JSON-like structure.

    Lists are transparent (they add no level), every dict adds one level,
    and anything else (leaf values) counts as depth 0.

    Args:
        d: Arbitrarily nested combination of lists, dicts and leaf values.

    Returns:
        int: Largest number of dicts encountered on any root-to-leaf path.
        Empty lists contribute 0 and empty dicts contribute 1.
    """
    if isinstance(d, list):
        # default=0 keeps an empty list from raising ValueError in max()
        return max(map(depth, d), default=0)
    if isinstance(d, dict):
        return 1 + max(map(depth, d.values()), default=0)
    return 0

In [9]:
# get number of layers: the deepest dict nesting found in the batch's permutation trees;
# this value is later passed to get_matrix as its `depth` argument
num_layers = depth(permutation_trees)
num_layers

2

In [10]:
# Integer code for each permutation-tree node type (the dict keys in the trees).
# get_matrix pads its type matrix with 0, so 0 is not used as a code here.
type_dict = dict(P=1, Z=2)

In [11]:
# get matrix of permutation tree

In [12]:
def get_matrix(tree, depth, idx_prefix: "list | None" = None, type_prefix: "list | None" = None) -> tuple:
    """Flatten a nested permutation tree into an index matrix and a type matrix.

    The tree is walked recursively:

    * list: every child is visited with its position in the list appended
      to ``idx_prefix``; the children's results are concatenated row-wise.
    * dict: only the first key is used. Its code from the module-level
      ``type_dict`` is appended to ``type_prefix`` and the remaining
      ``depth`` is reduced by one before descending into its value.
    * int (leaf): emits one row ``idx_prefix + [leaf] + [-1] * depth`` and
      one row ``type_prefix + [0] * depth``, so leaves that sit above the
      deepest level are padded with -1 (indices) and 0 (types).

    Args:
        tree: Nested structure of lists, single-key dicts and int leaves.
        depth: Number of dict levels still allowed below the current node.
            Inside this function the name refers to this integer, not to
            the notebook's ``depth`` helper.
        idx_prefix: Index path collected so far (internal to the recursion).
        type_prefix: Type codes collected so far (internal to the recursion).

    Returns:
        tuple: ``(idx_matrix, type_matrix)``, two long tensors with one row
        per leaf.

    Raises:
        ValueError: For an empty list or an empty dict node.
        KeyError: If a dict key is missing from ``type_dict``.
        TypeError: If a node is neither list, dict nor int.
    """
    # None defaults instead of shared mutable default lists
    idx_prefix = [] if idx_prefix is None else idx_prefix
    type_prefix = [] if type_prefix is None else type_prefix

    if isinstance(tree, list):
        if not tree:
            raise ValueError("get_matrix: cannot build a matrix from an empty list node")
        rows = [
            get_matrix(child, depth, idx_prefix + [idx], type_prefix)
            for idx, child in enumerate(tree)
        ]
        idx_parts, type_parts = zip(*rows)
        return torch.cat(idx_parts, dim=0), torch.cat(type_parts, dim=0)

    if isinstance(tree, dict):
        if not tree:
            raise ValueError("get_matrix: dict node must contain a type key")
        # only the first key is considered, as in the original implementation
        key = next(iter(tree))
        return get_matrix(tree[key], depth - 1, idx_prefix, type_prefix=type_prefix + [type_dict[key]])

    if isinstance(tree, int):
        # explicit dtype + unsqueeze keep an empty type row valid (shape (1, 0))
        idx_row = torch.tensor(idx_prefix + [tree] + [-1] * depth, dtype=torch.long)
        type_row = torch.tensor(type_prefix + [0] * depth, dtype=torch.long)
        return idx_row.unsqueeze(0), type_row.unsqueeze(0)

    # previously returned [-1], which broke the tuple unpacking in every caller
    raise TypeError(f"get_matrix: unsupported node type {type(tree).__name__}")

In [13]:
# flatten all permutation trees of the batch into one index and one type matrix
idx_matrix, type_matrix = get_matrix(permutation_trees, depth=num_layers)
display(idx_matrix, type_matrix)

tensor([[   0,    0,    1,   -1],
        [   1,    0,    0,   -1],
        [   1,    1,    0,    3],
        ...,
        [2354,    1,    1, 2351],
        [2354,    1,    2, 2353],
        [2355,    0, 2354,   -1]])

tensor([[1, 0],
        [1, 0],
        [1, 2],
        ...,
        [1, 2],
        [1, 2],
        [1, 0]])

In [14]:
# forward-fill, left to right: every -1 takes the value of the column to its left
# (idx_matrix is modified in place)
n_cols = idx_matrix.shape[1]
for col in range(1, n_cols):
    current = idx_matrix[:, col]
    idx_matrix[:, col] = torch.where(current == -1, idx_matrix[:, col - 1], current)

display(idx_matrix)

tensor([[   0,    0,    1,    1],
        [   1,    0,    0,    0],
        [   1,    1,    0,    3],
        ...,
        [2354,    1,    1, 2351],
        [2354,    1,    2, 2353],
        [2355,    0, 2354, 2354]])

In [15]:
# gather one feature row of batch.x per matrix row, addressed by the last index column
leaf_idx = idx_matrix[:, -1]
data_array = batch.x[leaf_idx]
display(data_array, data_array.shape)

tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.]])

torch.Size([6416, 118])

In [16]:
# get structure to orient to: every column of idx_matrix except the last one
# (the last column was already consumed above to gather data_array)
idx_structure = idx_matrix[:, :-1]

In [34]:
# get indexes for graph pooling: drop the last remaining column, collapse identical
# rows and keep how many rows fell into each unique row (= group size)
# NOTE(review): idx_structure is overwritten from itself, so every execution strips one
# more column. The single-column output shown below cannot come from the first run after
# In [16] (idx_matrix has 4 columns) - this cell was executed more than once.
# NOTE(review): later cells treat the groups as contiguous runs in row order
# (repeat_interleave of the counts); that relies on the rows already being in the
# order torch.unique sorts them into - confirm.
idx_structure, current_layer_pooling_counts = torch.unique(idx_structure[:, :-1], dim=0, return_counts=True)
display(idx_structure)
display(current_layer_pooling_counts)

tensor([[   0],
        [   1],
        [   2],
        ...,
        [2353],
        [2354],
        [2355]])

tensor([1, 2, 2,  ..., 1, 2, 1])

In [35]:
# pooling target for every element: group id g is repeated counts[g] times
current_layer_pooling = torch.arange(
    len(current_layer_pooling_counts)
).repeat_interleave(current_layer_pooling_counts)
display(current_layer_pooling[:10])

tensor([0, 1, 1, 2, 2, 3, 3, 4, 4, 5])

In [36]:
# init circling: k rows, one column per element, every slot starts at -1 (= unused)
# todo: rework for other types
order_matrix = torch.full((k, len(current_layer_pooling)), -1, dtype=torch.int)

In [37]:
# For every group of consecutive elements, give each member one column holding the
# member itself followed by the next members in circular order (at most k entries).
# Slots that are not needed stay at -1.
cur_pos = 0
for group_size in current_layer_pooling_counts.tolist():
    current_k = min(k, group_size)
    r = torch.arange(cur_pos, cur_pos + group_size)
    # iterate over ALL members of the group (not only the first k) and cut the rolled
    # window down to current_k entries; without the cut a group larger than k raised
    # a shape-mismatch error and left its remaining members unfilled
    for j in range(group_size):
        order_matrix[:current_k, cur_pos + j] = torch.roll(r, shifts=-j)[:current_k]
    cur_pos += group_size

In [38]:
# inspect the result; -1 marks slots that were never filled
order_matrix

tensor([[   0,    1,    2,  ..., 4121, 4122, 4123],
        [  -1,    2,    1,  ..., 4122, 4121,   -1],
        [  -1,   -1,   -1,  ...,   -1,   -1,   -1]], dtype=torch.int32)

In [39]:
# add zero padding to data list: prepend an all-zero row and shift every index by one,
# so the former -1 placeholders become 0 and select that zero row
# NOTE(review): both statements build on their own previous result - re-running this
# cell on its own prepends another zero row and shifts the indices again.
data_array = torch.cat([torch.zeros(1, data_array.shape[-1]), data_array], dim=0)
order_matrix += 1

In [40]:
# look up the feature row for every slot of order_matrix
# (slots that were -1 before the shift now pick the zero row at index 0)
temp = data_array[order_matrix]

In [41]:
# shape check: order_matrix's two dimensions followed by the feature dimension
temp.shape

torch.Size([3, 4124, 118])

In [42]:
# apply z layer

In [43]:
# rows whose last type entry equals 2, the code type_dict assigns to "Z"
# NOTE(review): mask_z is not referenced by any later cell visible here.
mask_z = type_matrix[:, -1] == 2

In [44]:
# k independent Linear layers (one per row of order_matrix), plus a final Linear
# that is applied after the per-slot outputs have been summed
# NOTE(review): 118 equals the feature width of data_array shown above (hard-coded here).
# NOTE(review): z_elu is created but not applied in any cell visible here, and no random
# seed is set in the visible cells, so the layer weights differ between kernel runs.
z_layer = torch.nn.ModuleList([
    torch.nn.Linear(118, 118)
    for _ in range(k)
])
z_final_layer = torch.nn.Linear(118, 118)
z_elu = torch.nn.ELU()

In [45]:
# todo: take care of duplicate elements that are sent through linear layer

In [46]:
# shape check after padding: one row more than order_matrix has columns (the zero row)
data_array.shape

torch.Size([4125, 118])

In [47]:
# feed slot i of `temp` through the i-th Linear layer, then stack the results on dim 1
# NOTE(review): padded slots (the all-zero row) also pass through the Linear layers,
# so they contribute that layer's bias to the sum taken afterwards.
slot_outputs = []
for slot_layer, slot_features in zip(z_layer, temp):
    slot_outputs.append(slot_layer(slot_features))
embedding = torch.stack(slot_outputs, dim=1)
display(embedding.shape)

torch.Size([4124, 3, 118])

In [48]:
# collapse the slot dimension (dim 1) by summation
temp3 = torch.sum(embedding, dim=1)
display(temp3.shape)

torch.Size([4124, 118])

In [49]:
# final Linear layer on the summed slot embeddings (no activation is applied here)
temp4 = z_final_layer(temp3)
display(temp4.shape)

torch.Size([4124, 118])

In [50]:
import torch_geometric
# sum the rows of temp4 that share a group id in current_layer_pooling, giving one
# row per group; note that this overwrites data_array with the pooled result
data_array = torch_geometric.nn.global_add_pool(temp4, current_layer_pooling)

In [51]:
# shape check of the pooled result: one row per pooling group
data_array.shape

torch.Size([2356, 118])