In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

In [3]:

import sys
import os
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
sys.path.append(parent_dir)
from DataPipeline.dataset import ZincSubgraphDatasetStep, custom_collate_GNN1
from DataPipeline.preprocessing import plot_graph

from torch_geometric.utils import to_networkx
import torch

In [4]:
def torch_geometric_to_networkx(data):
    """
    Convert a torch_geometric.data.Data object into a networkx.Graph object.

    Args:
    data (torch_geometric.data.Data): A PyTorch Geometric Data object representing the molecule.

    Returns:
    G (networkx.Graph): A NetworkX Graph object representing the molecule.
    """
    # Modify node features to take the argmax, excluding the last element
    if data.x.shape[1] > 1:
        data.x = torch.argmax(data.x[:, :-1], dim=1).unsqueeze(1)

    # Modify edge features to take the argmax
    if data.edge_attr.shape[1] > 1:
        data.edge_attr = torch.argmax(data.edge_attr, dim=1).unsqueeze(1)

    G = to_networkx(data, node_attrs=['x'], edge_attrs=['edge_attr'])

    for i in G.nodes:
        x_attr = G.nodes[i]['x']
        atomic_num = int(x_attr.item()) if hasattr(x_attr, 'item') else int(x_attr)
        G.nodes[i]['atomic_num'] = atomic_num
        del G.nodes[i]['x']

    for i, j in G.edges:
        edge_attr = G.edges[i, j]['edge_attr']
        bond_type = edge_attr.item() if hasattr(edge_attr, 'item') else edge_attr
        G.edges[i, j]['bond_type'] = bond_type
        del G.edges[i, j]['edge_attr']

    return G

In [5]:
from torch_geometric.data import Batch

def custom_collate(batch):
    sg_data_list = [item[0] for item in batch]
    g_data_list = [item[1] for item in batch]
    terminal_nodes_info_list = [item[2] for item in batch]
    id_map_list = [item[3] for item in batch]
    sg_data_batch = Batch.from_data_list(sg_data_list)
    g_data_batch = Batch.from_data_list(g_data_list)
    return sg_data_batch, g_data_batch, terminal_nodes_info_list, id_map_list

In [6]:
from torch.utils.data import DataLoader

from torch_geometric.data import Batch

In [7]:
datapath = Path('..') / 'DataPipeline/data/preprocessed_graph_no_I_Br_P.pt'
dataset = ZincSubgraphDatasetStep(datapath)

Dataset encoded with size 7


In [8]:
dataloader = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=0, collate_fn=custom_collate_GNN1)

In [9]:
for batch in dataloader:
    print(batch[0].edge_index)

tensor([[    6,     3,     3,  ..., 13727, 13718, 13717],
        [    3,     6,     7,  ..., 13725, 13717, 13718]])
tensor([[   15,    13,    13,  ..., 13681, 13678, 13680],
        [   13,    15,    12,  ..., 13678, 13680, 13678]])
tensor([[   19,    18,    18,  ..., 13997, 14000, 14004],
        [   18,    19,    17,  ..., 14000, 14004, 14000]])
tensor([[    3,     1,     1,  ..., 13693, 13701, 13698],
        [    1,     3,     0,  ..., 13696, 13698, 13701]])
tensor([[   26,    24,    24,  ..., 13901, 13901, 13897],
        [   24,    26,    25,  ..., 13904, 13897, 13901]])
tensor([[    8,     5,     5,  ..., 13740, 13758, 13757],
        [    5,     8,     4,  ..., 13743, 13757, 13758]])
tensor([[    2,     0,     0,  ..., 13445, 13447, 13442],
        [    0,     2,     1,  ..., 13441, 13442, 13447]])
tensor([[   11,     8,     8,  ..., 13923, 13923, 13925],
        [    8,    11,    10,  ..., 13922, 13925, 13923]])
tensor([[   10,     8,     8,  ..., 13536, 13529, 13527],
      

KeyboardInterrupt: 

In [None]:
import networkx as nx
from tqdm import tqdm

In [None]:
y_list = []

for batch in tqdm(dataloader):

    y_list.append(batch[1])
y_list = torch.stack(y_list)


sum = torch.sum(y_list, dim=0)

100%|██████████| 231/231 [04:25<00:00,  1.15s/it]


RuntimeError: stack expects each tensor to be equal size, but got [1024, 7] at entry 0 and [597, 7] at entry 230

In [None]:
concat = torch.cat(y_list, dim=0)
sum = torch.sum(concat, dim=0)


In [None]:
sum

tensor([ 82792.5859,  13852.4150,  11517.1660,   1634.7500,   2047.0001,
           904.0834, 123369.0000])

In [None]:
print(sum/len(dataset))

tensor([0.3506, 0.0587, 0.0488, 0.0069, 0.0087, 0.0038, 0.5225])


In [None]:
def compute_class_weights(class_proportions):

    # Compute the inverse of the class proportions
    class_weights = 1 / class_proportions

    # Normalize the class weights so they sum up to 1
    class_weights_normalized = class_weights / class_weights.sum()

    return class_weights_normalized

In [None]:
i = 0

for batch in dataloader:
    i += 1
    data_list = Batch.to_data_list(batch[0])
    g_data_list = Batch.to_data_list(batch[1])
    network_subgraph = torch_geometric_to_networkx(data_list[0])
    network_graph = torch_geometric_to_networkx(g_data_list[0])
    id_map = batch[3][0]

    plot_graph(network_subgraph, atom_conversion_type='onehot', encoding_type='reduced')
    plot_graph(network_graph, atom_conversion_type='onehot', encoding_type='reduced', id_map=id_map)
    print(batch[2])
    print(id_map)

    if i == 3:
        break