In [1]:
import numpy as np
import torch
import pickle as pkl
import ipycytoscape
import networkx as nx

def vis(G):
    cso = ipycytoscape.CytoscapeWidget()
    cso.graph.add_graph_from_networkx(G)
    cso.set_style([
                            {
                                'selector': 'node',
                                'css': {
                                    'background-color': 'red',
                                    'content': 'data(node_label)' #
                                }
                            },
                                                    {
                                'selector': 'edge',
                                'css': {
                                    'content': 'data(edge_label)' #
                                }
                            }
                
                ])

    for i in range(len(cso.graph.nodes)):
        id = int(cso.graph.nodes[i].data['id'])
        label = cso.graph.nodes[i].data['node_label']
        new_label = f"{id}: {label}"
        cso.graph.nodes[i].data['node_label'] = new_label


    # for i in range(len(cso.graph.edges)):
    #     label = cso.graph.edges[i].data['edge_label']
    #     new_label = f"{label}"
    #     cso.graph.edges[i].data['edge_label'] = new_label

    return cso
    
# Test it with output graph
import pickle
#with open('datasets/DD/data.pkl','rb') as f:
with open('../datasets/ZINC_TEST/data.pkl','rb') as f:
    data = pickle.load(f)
out = vis(data[3])
display(out)

CytoscapeWidget(cytoscape_layout={'name': 'cola'}, cytoscape_style=[{'selector': 'node', 'css': {'background-c…

In [28]:

class SparseGraph:
    # Convert a graph to a sparse representation (numpy matrices)
    def __init__(self, G, num_Edge_classes=3, num_Node_classes=23):
        # Convert a networkx graph (with edge and node labels) to a sparse graph format

        # Edge index Matrix
        idxs = np.array(G.edges).transpose() # (2,|E|) dim. array idxs[:,j] = [u,v]^T indicates endpoints of j'th edge e=u->v
        idxs = np.concatenate((idxs, idxs[[1,0]]), axis=1) # idxs[[1,0]] flips the two rows ie [u,v]^T -> [v,u]^T, so by concat now have (2, 2*|E|)
        self.idxs = torch.from_numpy(idxs) #.astype(np.float32))

        # Node features
        Xv = np.array([G.nodes[idx]['node_label'] for idx in G.nodes]).transpose() # Node feature matrix of dim (reshape: (|V|,) -> (|V|,1))
        #Xv = torch.from_numpy(Xv.astype(np.float32))
        self.Xv = torch.nn.functional.one_hot(torch.tensor(Xv, dtype=torch.int64), num_classes=23).to(torch.float32)

        # Edges features
        Xe = np.array([G.edges[idx]['edge_label'] for idx in G.edges]).transpose() # Edge feature matrix of dim (reshape: (|E|,) -> (|E|,1))
        Xe = np.concatenate((Xe,Xe), axis=0) - 1 # For some reason class labels are {1,2,3} and not {0,1,2}...
        self.Xe = torch.nn.functional.one_hot(torch.tensor(Xe, dtype=torch.int64), num_classes=3).to(torch.float32)

        # Get Graph features
        y = G.graph['label']
        self.y = torch.from_numpy(y.astype(np.float32))

        # Set Batch_idx (just here for compability)
        self.batch_idx = torch.zeros((Xv.shape[0]), dtype=torch.int64)

    def to_gpu(self):
        # Transfer all tensors from cpu to gpu/cuda
        self.y.to('cuda')
        self.idxs.to('cuda')
        self.Xe.to('cuda')
        self.Xv.to('cuda')
        


    def to_nx(self):
        # TODO Update to account for OHE encoding of vectors
        # Convert the sparse graph back to a networkx gaph g

        # Convert tensors to numpy
        idxs = self.idxs.numpy().astype('int')
        Xv = self.Xv.numpy()
        Xe = self.Xe.numpy()

        g = nx.Graph() # Empty nx graph

        # Add edges (nodes added automatically)
        for j in range(idxs.shape[1]):
            g.add_edge(idxs[0,j], idxs[1,j])
        
        # Set Node and Edge Weights
        nx.set_node_attributes(g, {idx: Xv[idx] for idx in range(Xv.shape[0])}, "node_label")
        nx.set_edge_attributes(g, {(idxs[0,idx], idxs[1,idx]): Xe[idx] for idx in range(int(Xe.shape[0]/2))}, "edge_label")

        # TODO: Convert graph label in networkx
        return g



class MyDataset(torch.utils.data.Dataset):
    def __init__(self, nx_graph_list):
        self.np_sparse_graphs = [SparseGraph(g) for g in nx_graph_list]

    def __len__(self):
        return len(self.np_sparse_graphs)
    
    def __getitem__(self, idx):
        return self.np_sparse_graphs[idx]
        #return torch.from_numpy(sg.idxs), torch.from_numpy(sg.Xv), torch.from_numpy(sg.Xe), torch.from_numpy(sg.Xe)


SG = SparseGraph(data[3])
# G1 = SG.to_nx()
# vis(G1)

def MyCollate(sparse_graph_list):
    #sparse_graph_list = [SparseGraph(data[0]), SparseGraph(data[1]), SparseGraph(data[2]) ]
    #sgl = sparse_graph_list

    # Create empty SparseGraph Object (avoid calling init, we will initialize here alreadt)
    output = SparseGraph.__new__(SparseGraph)

    # By joining graphs, the node indexes need to she shifted
    # Ie if the first graph has 10 nodes, then for the second graph the node indexes 0,1,2,... --> 10,11,12,...

    # compute batch_idx matrix, and a lookup table for how much to shift each graph's nodes indexes by
    node_idx_shift = [0] # Lookup table for the node index shift of each graph
    batch_idx = []
    tot_num_nodes = 0 # Total number of nodes
    for i,sg in enumerate(sparse_graph_list):
        num_nodes = sg.Xv.shape[0]
        tot_num_nodes += num_nodes
        node_idx_shift.append(tot_num_nodes)
        batch_idx += [i]*num_nodes

    # First shift all the node indexes in each graph, and concatenate them
    output.idxs = torch.cat([sg.idxs + torch.from_numpy(np.array([node_idx_shift[i], node_idx_shift[i]]).transpose().reshape(-1,1))  # idxs + [idx_shift, idx_shift]^T
                            for i, sg in enumerate(sparse_graph_list)],
                        dim = 1)

    # Change batch_idx type to tensor
    output.batch_idx = torch.tensor(np.array(batch_idx), dtype=torch.int64)

    # Concatenate Node and Edge feature vectors, and graph labels
    output.Xv = torch.cat([sg.Xv  for sg in sparse_graph_list])
    output.Xe = torch.cat([sg.Xe for sg in sparse_graph_list])
    output.y = torch.cat([sg.y for sg in sparse_graph_list])

    return output

sgl = [SparseGraph(data[0]), SparseGraph(data[1])] #SparseGraph(data[2]) ]
res = MyCollate(sgl)


In [34]:
import torch_scatter

from torch import nn
class GNN_U(torch.nn.Module):
    # TODO: Actually implement this! Just dummy so far (!!depth attribute!!)

    def __init__(self, in_features, out_features, depth):
        super(GNN_U, self).__init__()
        self.fc = nn.Linear( in_features, out_features)
        self.dropout = torch.nn.Dropout(p=0.2)
        self.relu = torch.nn.ReLU()
        
    def forward(self, x):
        x = self.fc(x)
        x = self.dropout(x)
        x = self.relu(x) 
        return x

class GNN_M(torch.nn.Module):
    # TODO: Actually implement this! Just dummy so far (!!depth attribute!!)
    def __init__(self, in_features, out_features, depth):
        super(GNN_M, self).__init__()
        self.fc = nn.Linear( in_features, out_features)
        self.dropout = torch.nn.Dropout(p=0.2)
        self.relu = torch.nn.ReLU() 
        
    def forward(self, x):
        x = self.fc(x)
        self.dropout(x)
        x = self.relu(x) 
        return x


class GNN_layer(torch.nn.Module):
    def __init__(self, in_features, out_features, Xe_width, scatter_func='SUM', U_depth=2, M_depth=2, M_width=2):
        super(GNN_layer, self).__init__()

        # Initialize Scatter function
        if type(scatter_func) == type('str'):
            if scatter_func.lower()=='sum':
                self.scatter_agg = torch_scatter.scatter_sum
            elif scatter_func.lower()=='max':
                self.scatter_agg = torch_scatter.scatter_max
            elif scatter_func.lower()=='mean':
                self.scatter_agg = torch_scatter.scatter_mean
            else:
                import warnings
                warnings.warn("scatter_function unknown! Defaulting to \"SUM\"")
                self.scatter_agg = torch_scatter.scatter_add
        else: 
            # Custom scatter function
            self.scatter_agg = scatter_func

        # Initialize M and U Neural Nets
        self.M = GNN_M(in_features + Xe_width, M_width, M_depth)
        self.U = GNN_U(in_features + M_width, out_features, U_depth)

        # Define parameter list (needed for optimizer)
        self.param_list = list(self.M.parameters()) + list( self.U.parameters())

    def forward(self, H, sparse_graph):
        Y = self.M.forward(torch.cat((H[sparse_graph.idxs[0,:]], sparse_graph.Xe), dim=1)) # (2|E|, in_features + Xe_width) -> (2|E|, M_width)
        # TODO: Special case for max
        Z = self.scatter_agg(Y, sparse_graph.idxs[1,:], dim=0) # (2|E|, M_width) -> (|V|, M_width)
        return self.U.forward(torch.cat((H,Z), dim=1)) # (|V|, H_width + M_width) -> (|V|, out_features)

class GNN_skip_layer(GNN_layer):
    # Wraps a GNN_layer with a skip connection (note out_features=in_features enforced, otherwise identical)

    def __init__(self, in_features, Xe_width, scatter_func='SUM', U_depth=2, M_depth=2, M_width=2):
        # Identical to GNN_layer, just that now out_features=in_features
        super(GNN_skip_layer, self).__init__(in_features, in_features, Xe_width, scatter_func='SUM', U_depth=2, M_depth=2, M_width=2)

    def forward(self, H, sparse_graph):
        return H + super(GNN_skip_layer, self).forward(H, sparse_graph)


class GNN_pool(torch.nn.Module):
    def __init__(self, scatter_func='sum'):
        super(GNN_pool, self).__init__()
        
        # Initialize Scatter function
        if type(scatter_func) == type('str'):
            if scatter_func.lower()=='sum':
                self.scatter_agg = torch_scatter.scatter_sum
            elif scatter_func.lower()=='max':
                self.scatter_agg = torch_scatter.scatter_max
            elif scatter_func.lower()=='mean':
                self.scatter_agg = torch_scatter.scatter_mean
            else:
                import warnings
                warnings.warn("scatter_function unknown! Defaulting to \"SUM\"")
                self.scatter_agg = torch_scatter.scatter_add
        else: 
            # Custom scatter function
            self.scatter_agg = scatter_func

        # Parameter list (empty, just here for compatabillity)
        self.param_list = []

    def forward(self, H, sparse_graph):
        return torch_scatter.scatter_sum(H, sparse_graph.batch_idx, dim=0)


class GNN_virtual_node(torch.nn.Module):
    # Virtual Node

    def __init__(self, in_features):
        super(GNN_virtual_node, self).__init__()
        self.fc = nn.Linear(in_features, in_features)
        self.relu = torch.nn.ReLU() 
        self.dropout = torch.nn.Dropout(p=0.2)
        self.relu = torch.nn.ReLU()
        self.param_list = self.parameters()

    def forward(self, H, sparse_graph):
        # Compute sum over all nodes for each graph, H_sum_graph[i] = sum of H over all nodes in the i'th subgrpah
        H_sum_graph = torch_scatter.scatter_sum(H, sparse_graph.batch_idx, dim=0)
        H_sum_graph = self.fc(H_sum_graph)
        H_sum_graph = self.dropout(H_sum_graph)
        H_sum_graph = self.relu(H_sum_graph)

        # Cast graph sum back to every node, and then add to H (skip connection)
        return H + H_sum_graph[sparse_graph.batch_idx]


class GNN_wrapper():
    # Wraps a module like nn.Linear which only applied to H (and not (H, sparse_graph)
    def __init__(self, module):
        self.module = module
        self.param_list = self.module.parameters() 

    def forward(self, H, sparse_graph):
        return self.module.forward(H)

class GNN(torch.nn.Module):
    def __init__(self, Xv_width, Xe_width):
        super(GNN, self).__init__()

        self.layers = []
        
        # First layer, (input dimension must match the initial dimension)
        self.layers.append(GNN_layer(in_features=Xv_width, out_features=10, Xe_width=Xe_width, scatter_func='mean', U_depth=2, M_depth=2, M_width=2))

        self.layers.append(GNN_virtual_node(10))
        # TODO: Add more layers (and actually figure out what needs to be done here)

        self.layers.append(GNN_skip_layer(in_features=10, Xe_width=Xe_width, scatter_func='mean', U_depth=2, M_depth=2, M_width=2))
        self.layers.append(GNN_virtual_node(10))
        
        # self.layers.append(GNN_skip_layer(in_features=10, Xe_width=Xe_width, scatter_func='mean', U_depth=2, M_depth=2, M_width=2))
        # self.layers.append(GNN_virtual_node(10))
        
        # self.layers.append(GNN_skip_layer(in_features=10, Xe_width=Xe_width, scatter_func='mean', U_depth=2, M_depth=2, M_width=2))
        # self.layers.append(GNN_virtual_node(10))
        
        # Final layer (For node level classication, should be exactly the size of the node output, for anything else, it should be different)
        self.layers.append(GNN_layer(in_features=10, out_features=10, Xe_width=Xe_width, scatter_func='mean', U_depth=2, M_depth=2, M_width=2))
        self.layers.append(GNN_wrapper(torch.nn.Linear(10,1)))
        self.layers.append(GNN_pool('mean'))

        # Build list of parameters (needed for optimizer)
        self.param_list = []
        for layer in self.layers:
            self.param_list += layer.param_list


    def forward(self, sparse_graph):
        # Initial Hidden node layers
        H = sparse_graph.Xv

        # Reshape if neccessary
        if len(H.shape)==1:
            H = H.reshape((-1,1))

        # Reshape Edge feature matrix if neccessarry
        if len(sparse_graph.Xe.shape)==1:
            sparse_graph.Xe = sparse_graph.Xe.reshape((-1,1))
            import warnings
            warnings.warn("Needed to reshape Xe!!")

        # Actual forward Pass of H through layers
        for layer in self.layers:
            H = layer.forward(H, sparse_graph)
        return H        

#Simple Example
gnn = GNN(Xv_width=23, Xe_width=3)
#sparse_graph = ZINC2sparse(data[0])
sparse_graph = MyCollate([SparseGraph(data[0]), SparseGraph(data[1]), SparseGraph(data[2])])
H = gnn.forward(sparse_graph)
H.shape

torch.Size([3, 1])

In [29]:
# Import Datasets and initliaze dataloaders
import pickle
with open('../datasets/ZINC_Train/data.pkl','rb') as f:
    data = pickle.load(f)
    train_loader = torch.utils.data.DataLoader(MyDataset(data), batch_size=15, collate_fn=MyCollate)

with open('../datasets/ZINC_Test/data.pkl','rb') as f:
    data = pickle.load(f)
    test_loader = torch.utils.data.DataLoader(MyDataset(data), batch_size=15, collate_fn=MyCollate)

with open('../datasets/ZINC_Val/data.pkl','rb') as f:
    data = pickle.load(f)
    validate_loader = torch.utils.data.DataLoader(MyDataset(data), batch_size=15, collate_fn=MyCollate)


In [35]:
def train_epoch(gnn_model, dataloader, optimizer, loss_fn):
    # Trains a gnn model for one epoch
    gnn_model.train()
    sum_loss = 0
    counter = 0
    for sparse_graph in dataloader:
        if use_gpu:
            sparse_graph.to_gpu()

        # Reset gradients
        optimizer.zero_grad()
        
        # do forward pass
        output = gnn_model.forward(sparse_graph).reshape((-1))
        
        # calculate loss
        loss = loss_fn(output, sparse_graph.y)
        sum_loss += loss.item()
        counter += 1
        
        # backpropagate loss and do parameter updates
        loss.backward()
        optimizer.step()
    return gnn_model, sum_loss / counter

def validate(gnn_model, dataloader, loss_fn):
    # Trains a gnn model for one epoch
    gnn_model.eval()
    
    sum_loss = 0
    counter = 0
    for sparse_graph in dataloader:
        output = gnn_model.forward(sparse_graph).reshape((-1))
        loss = loss_fn(output, sparse_graph.y)
        sum_loss += loss.item()
        counter += 1
    return sum_loss / counter

import torch.optim as optim
gnn = GNN(Xv_width=23, Xe_width=3)
optimizer = optim.Adam(gnn.param_list, lr=0.0001)
loss_fn = torch.nn.L1Loss()
gnn.train()

use_gpu = True
if use_gpu:
    gnn.to('cuda')

for epoch in range(50):
    gnn, train_loss = train_epoch(gnn, train_loader, optimizer, loss_fn)
    val_loss = validate(gnn, validate_loader, loss_fn)
    print(f"{epoch}: Train_loss = {train_loss}, Validation_loss = {val_loss}")


0: Train_loss = 31.878885597303352, Validation_loss = 9.103466055286464


In [25]:
torch.cuda.is_available()
torch.cuda.device_count()
torch.cuda.current_device()
torch.cuda.get_device_name(0)

'NVIDIA GeForce GTX 1050 Ti with Max-Q Design'