In [20]:
import numpy as np
from models.vgae import VariationalEncoder, L1VGAE
import torch
from utils.dataset import split_dataset

np.random.seed(1)
torch.manual_seed(1)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_train, data_test, data_val, = split_dataset(True)

in_channels, out_channels, lr, n_epochs = data_train[0].num_features, 20, 0.001, 300


vae_layers, alpha, threshold = 2, 0.5, 0.65
vae = L1VGAE(VariationalEncoder(in_channels, out_channels, layers=vae_layers, molecular=True, transform=True), device)
vae.load_state_dict(torch.load(
    '../results/model_GCN/graph_split_/layers_' + str(
        vae_layers) + '/transform_'+str(True)+'/alpha_' + str(
        alpha) +'/model.pt'))

<All keys matched successfully>

In [21]:
from zinc_classifier import transform_zinc_dataset, transform_zinc_dataset_with_weights

data_train = transform_zinc_dataset_with_weights(vae, data_train, 0.8)
data_test = transform_zinc_dataset_with_weights(vae, data_test, 0.8)

In [22]:
data_train = data_train[0:10]
print(data_train)

[Data(x=[29, 1], edge_index=[2, 64], edge_attr=[64], y=[1], laplacian_eigenvector_pe=[29, 5], vr_edge_index=[2, 160], vr_edge_weight=[160]), Data(x=[26, 1], edge_index=[2, 56], edge_attr=[56], y=[1], laplacian_eigenvector_pe=[26, 5], vr_edge_index=[2, 146], vr_edge_weight=[146]), Data(x=[16, 1], edge_index=[2, 34], edge_attr=[34], y=[1], laplacian_eigenvector_pe=[16, 5], vr_edge_index=[2, 56], vr_edge_weight=[56]), Data(x=[27, 1], edge_index=[2, 60], edge_attr=[60], y=[1], laplacian_eigenvector_pe=[27, 5], vr_edge_index=[2, 173], vr_edge_weight=[173]), Data(x=[21, 1], edge_index=[2, 44], edge_attr=[44], y=[1], laplacian_eigenvector_pe=[21, 5], vr_edge_index=[2, 95], vr_edge_weight=[95]), Data(x=[28, 1], edge_index=[2, 60], edge_attr=[60], y=[1], laplacian_eigenvector_pe=[28, 5], vr_edge_index=[2, 214], vr_edge_weight=[214]), Data(x=[19, 1], edge_index=[2, 38], edge_attr=[38], y=[1], laplacian_eigenvector_pe=[19, 5], vr_edge_index=[2, 91], vr_edge_weight=[91]), Data(x=[16, 1], edge_inde

In [23]:
def remove_repeating_edges(data):
    data_copy = []

    for graph in data:
        graph.best_edge_index = graph.edge_index
        #remove edges from edge_index to vr_edge_index

        a = graph.edge_index.T
        b = graph.vr_edge_index.T

        # get the common element
        # get the common element
        common = set(tuple(x.tolist()) for x in a) & set(tuple(x.tolist()) for x in b)

        # remove common element from b
        b = torch.stack([x for x in b if tuple(x.tolist()) not in common and x[0] != x[1]])

        a = torch.cat([a,b], dim=0)

        graph.new_edge_index = a.T

        data_copy.append(graph)

    return data_copy

data_train = remove_repeating_edges(data_train)


In [17]:
from matplotlib import pyplot as plt
import networkx as nx
from torch_geometric.data import Data

for graph in data_train:


    G1 = Data(edge_index=graph.edge_index, num_nodes=graph.x.shape[0])
    G2 = Data(edge_index=graph.new_edge_index, num_nodes=graph.x.shape[0])

    fig, (ax1, ax2) = plt.subplots(1, 2)

    # plot the first graph on the first column
    nx.draw(G1, ax=ax1, with_labels=True)
    ax1.set_title("True Graph")

    # plot the second graph on the second column
    nx.draw(G2, ax=ax2, with_labels=True)
    ax2.set_title("Rewired Graph")

    plt.show()
    # plt.show()


TypeError: 'NoneType' object is not iterable