In [1]:
import torch
import torch_geometric

from ptgnn.transform.ptree_matrix import permutation_tree_to_matrix

In [2]:
from ptgnn.runtime_config.run_config import fetch_loaders

In [3]:
# Configuration handed to fetch_loaders: one section describing the dataset
# and one describing the data loaders.
data_config = {}

# Dataset section.
data_config['dataset'] = dict(
    type="rs",
    mask_chiral_tags=True,
    transformation_mode='chienn_tree_basic',
)

# Loader section: 'general' holds the shared settings, while the per-split
# sections are left empty (each split gets its own, separate empty dict).
# NOTE(review): 'n_neighbors_in_circle' presumably has to agree with the `k`
# used further down in this notebook — confirm.
data_config['loader'] = dict(
    general=dict(
        n_neighbors_in_circle=3,
        batch_size=32,
        num_workers=0,
    ),
    train=dict(),
    test=dict(),
    val=dict(),
)

In [4]:
# Build the loaders described by data_config. Three objects come back; the
# middle one is discarded.
# NOTE(review): the unpacking suggests the order (train, <unused>, val) —
# confirm against fetch_loaders which split the middle element is.
loaders = fetch_loaders(data_config=data_config)
train_loader, _, val_loader = loaders

In [5]:
# Grab only the first validation batch as a working example: the loop is left
# immediately after displaying it, and `batch` stays bound for the cells below.
# NOTE: if val_loader yields nothing, `batch` is never assigned.
for batch in val_loader:
    display(batch)
    break

DataBatch(x=[2356, 118], edge_index=[2, 6416], edge_attr=[6416, 80], pos=[2356, 6], y=[32], batch=[2356], ptr=[33], ptree=[2356])

# First part: done in collation; parts of it could be moved to preprocessing

In [6]:
# Passed to permutation_tree_to_matrix and reused below as the number of rows
# of the order matrix and the number of per-position linear layers.
# NOTE(review): presumably mirrors 'n_neighbors_in_circle' from data_config — confirm.
k = 3

In [7]:
# Convert the permutation trees stored on the batch into matrix form: an index
# matrix and a type matrix, both shown below.
ptree_matrices = permutation_tree_to_matrix(batch.ptree, k)
idx_matrix, type_matrix = ptree_matrices
display(idx_matrix, type_matrix)

tensor([[   0,    0,    1,    1],
        [   1,    0,    0,    0],
        [   1,    1,    0,    3],
        ...,
        [2354,    1,    1, 2351],
        [2354,    1,    2, 2353],
        [2355,    0, 2354, 2354]])

tensor([[1, 0],
        [1, 0],
        [1, 2],
        ...,
        [1, 2],
        [1, 2],
        [1, 0]])

## From here on: the second part, which runs either in the model or in preprocessing

In [8]:
# Initial feature rows: the last column of idx_matrix is used as a row index
# into batch.x, giving one feature row per row of idx_matrix.
leaf_index = idx_matrix[:, -1]
data_array = batch.x[leaf_index]

In [9]:
# Structure to orient on: every column of idx_matrix except the last one
# (the last column was already consumed as the feature index).
idx_structure = idx_matrix[:, :-1]

### for loop starts here

In [10]:
# Move one level up: drop the last remaining index column and collapse
# identical rows. The counts say how many consecutive entries of the current
# level fall into each resulting row.
# NOTE: this overwrites idx_structure with its reduced form, so the cell is
# not idempotent (each run strips one more column).
# NOTE(review): torch.unique returns the rows sorted; the pooling index built
# from the counts assumes the input rows are already grouped in that same
# order — confirm.
parent_rows = idx_structure[:, :-1]
idx_structure, current_layer_pooling_counts = torch.unique(
    parent_rows,
    return_counts=True,
    dim=0,
)
display(idx_structure, current_layer_pooling_counts)

tensor([[   0,    0],
        [   1,    0],
        [   1,    1],
        ...,
        [2354,    0],
        [2354,    1],
        [2355,    0]])

tensor([1, 1, 3,  ..., 1, 3, 1])

In [11]:
# Expand the per-group counts into a pooling index: group id g is repeated
# counts[g] times, so entry i holds the id of the group that entry i belongs to.
current_layer_pooling = torch.repeat_interleave(current_layer_pooling_counts)
display(current_layer_pooling[:10])

tensor([0, 1, 2, 2, 2, 3, 4, 4, 4, 5])

In [12]:
# init circling: k rows, one column per entry of the pooling index, every slot
# starting out as -1 (meaning "not filled")
# todo: rework for other types
order_matrix = torch.full((k, len(current_layer_pooling)), -1, dtype=torch.int)

In [13]:
# Fill order_matrix group by group. For a group of `group_size` consecutive
# entries starting at `cur_pos`, column `cur_pos + j` receives the group's
# indices cyclically rotated by j (entry j first, then its successors, wrapping
# around inside the group), truncated to at most k rows. Rows beyond the group
# size are not written and keep their previous value.
cur_pos = 0
# .tolist() yields plain Python ints, so the loop arithmetic below does not
# operate on 0-dim tensors.
for group_size in current_layer_pooling_counts.tolist():
    current_k = min(k, group_size)
    r = torch.arange(cur_pos, cur_pos + group_size)
    # Visit *every* member of the group (not just the first k) and cut each
    # rotation down to current_k entries. For groups of size <= k this is the
    # same as before; for larger groups it avoids the shape mismatch between
    # the k-row target slice and the group_size-long rotation.
    for j in range(group_size):
        order_matrix[:current_k, cur_pos + j] = torch.roll(r, shifts=-j)[:current_k]
    cur_pos += group_size

In [14]:
# Inspect the filled order matrix; slots still holding -1 were not written by
# the fill loop.
order_matrix

tensor([[   0,    1,    2,  ..., 6413, 6414, 6415],
        [  -1,   -1,    3,  ..., 6414, 6412,   -1],
        [  -1,   -1,    4,  ..., 6412, 6413,   -1]], dtype=torch.int32)

In [15]:
# Prepend one all-zero feature row to the data list, then shift every index by
# one: the former -1 placeholders become 0 and therefore select the zero row.
# NOTE: not idempotent — re-running this cell prepends another row and shifts
# the indices again.
zero_row = torch.zeros(1, data_array.shape[-1])
data_array = torch.cat((zero_row, data_array), dim=0)
order_matrix += 1

In [16]:
# Gather feature rows through the order matrix: the result has one leading
# slice per row of order_matrix, each holding one feature row per column
# (see the shape shown in the next cell).
temp = data_array[order_matrix]

In [17]:
# Shape check of the gathered tensor.
temp.shape

torch.Size([3, 6416, 118])

In [18]:
# apply z layer

In [19]:
# Boolean mask over the rows of type_matrix whose last entry equals 2.
# NOTE(review): presumably 2 is the code of the "Z" node type — confirm.
# The mask is not consumed by any of the cells visible here.
last_type_column = type_matrix[:, -1]
mask_z = last_type_column == 2

In [20]:
# Layers of the Z block: k separate linear maps (one per row of the order
# matrix), a final linear map and an ELU activation. All linear maps keep the
# feature width of 118.
# NOTE(review): z_elu is created here but not applied in the cells visible
# here — confirm where the activation is meant to go.
feature_dim = 118

z_layer = torch.nn.ModuleList()
for _ in range(k):
    z_layer.append(torch.nn.Linear(feature_dim, feature_dim))

z_final_layer = torch.nn.Linear(feature_dim, feature_dim)
z_elu = torch.nn.ELU()

In [21]:
# todo: take care of duplicate elements that are sent through linear layer

In [22]:
# Shape check: data_array now carries one extra leading row from the zero padding.
data_array.shape

torch.Size([6417, 118])

In [23]:
# Send the i-th slice of `temp` through the i-th linear layer and stack the
# results along a new dim 1.
# NOTE(review): padded slots hold zero rows, so after the linear layer they
# contribute that layer's bias rather than zero — confirm this is intended.
per_position_outputs = []
for position_layer, position_features in zip(z_layer, temp):
    per_position_outputs.append(position_layer(position_features))
embedding = torch.stack(per_position_outputs, dim=1)
display(embedding.shape)

torch.Size([6416, 3, 118])

In [24]:
# Sum over dim 1, i.e. over the axis along which the per-layer outputs were stacked.
temp3 = embedding.sum(dim=1)
display(temp3.shape)

torch.Size([6416, 118])

In [25]:
# Final linear map applied to the summed embedding.
# NOTE(review): no activation is applied between the sum and this layer in the
# cells visible here — confirm whether z_elu is meant to be used.
temp4 = z_final_layer(temp3)
display(temp4.shape)

torch.Size([6416, 118])

In [26]:
# Sum-pool the per-entry embeddings with current_layer_pooling as the group
# index, producing the data list for the next level up.
# (torch_geometric is imported in the first cell together with the other
# imports, so later cells no longer depend on this cell having been run.)
data_array = torch_geometric.nn.global_add_pool(temp4, current_layer_pooling)

In [27]:
# Shape check after pooling.
# NOTE(review): presumably one row per unique row of idx_structure — confirm.
data_array.shape

torch.Size([4124, 118])

In [53]:
# Attach the matrix form of the permutation trees to the batch object.
# Uses the shared `k` instead of a literal 3, so the matrices stay consistent
# with the order matrix and the per-position layers that are sized by `k`.
# NOTE(review): presumably stored on the batch so that they are moved along by
# batch.to(device) — confirm.
batch.idx_matrix, batch.type_matrix = permutation_tree_to_matrix(batch.ptree, k)

In [54]:
# Move the batch and all layers to the compute device. Fall back to the CPU
# when no CUDA device is available, so the notebook still runs on machines
# without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch = batch.to(device)
z_elu = z_elu.to(device)
z_layer = z_layer.to(device)
z_final_layer = z_final_layer.to(device)

In [55]:
# Work on the index matrix stored on the batch.
idx_matrix = batch.idx_matrix

# Initial feature rows: the last column of idx_matrix is used as a row index
# into batch.x.
leaf_index = idx_matrix[:, -1]
data_array = batch.x[leaf_index]

# Structure to orient on: all columns except the last one.
idx_structure = idx_matrix[:, :-1]

In [56]:
# Move one level up: drop the last remaining index column and collapse
# identical rows; the counts give the size of each resulting group.
# NOTE: idx_structure is overwritten, so this cell is not idempotent.
parent_rows = idx_structure[:, :-1]
idx_structure, current_layer_pooling_counts = torch.unique(
    parent_rows,
    return_counts=True,
    dim=0,
)

# Pooling index: group id g repeated counts[g] times.
current_layer_pooling = torch.repeat_interleave(current_layer_pooling_counts)

# init circling: k rows, one column per entry, every slot starting out as -1
# todo: rework for other types - treat everything as Z
order_matrix = torch.full((k, len(current_layer_pooling)), -1, dtype=torch.int)

In [57]:
# Fill order_matrix group by group. For a group of `group_size` consecutive
# entries starting at `cur_pos`, column `cur_pos + j` receives the group's
# indices cyclically rotated by j (entry j first, then its successors, wrapping
# around inside the group), truncated to at most k rows. Rows beyond the group
# size are not written and keep their previous value.
cur_pos = 0
# .tolist() converts the counts to plain Python ints in one go, so the loop
# does not iterate over individual tensor elements.
for group_size in current_layer_pooling_counts.tolist():
    current_k = min(k, group_size)
    r = torch.arange(cur_pos, cur_pos + group_size)
    # Visit *every* member of the group (not just the first k) and cut each
    # rotation down to current_k entries. For groups of size <= k this is the
    # same as before; for larger groups it avoids the shape mismatch between
    # the k-row target slice and the group_size-long rotation.
    for j in range(group_size):
        order_matrix[:current_k, cur_pos + j] = torch.roll(r, shifts=-j)[:current_k]
    cur_pos += group_size

In [58]:
# Prepend one all-zero feature row (created on data_array's device) and shift
# every index by one, so the former -1 placeholders select that zero row.
# NOTE: not idempotent — re-running pads and shifts again.
zero_row = torch.zeros(1, data_array.shape[-1], device=data_array.device)
data_array = torch.cat((zero_row, data_array), dim=0)
order_matrix += 1

# Gather feature rows through the order matrix.
# NOTE: batch.type_matrix is not used yet (see the todo about treating
# everything as Z).
embedding = data_array[order_matrix]

# Send the i-th slice through the i-th linear layer and stack along dim 1.
per_position_outputs = []
for position_layer, position_features in zip(z_layer, embedding):
    per_position_outputs.append(position_layer(position_features))
embedding = torch.stack(per_position_outputs, dim=1)

In [59]:
# Sum over the stacked per-layer outputs (dim 1), then apply the final linear map.
# NOTE: `embedding` is overwritten with the result, so this cell is not idempotent.
embedding = z_final_layer(embedding.sum(dim=1))

# global pooling: sum the entries of each group into the next-level data list
data_array = torch_geometric.nn.global_add_pool(embedding, current_layer_pooling)