
# IF2 Lost Edge Classification
Using the data from the SSI IF2 project, classify which edges should be connected and which should be disconnected. This is a binary classification task: a label of 1 means the two nodes of an edge should be connected, and a label of 0 means they should not. It is best to start from a fully connected graph, so that the model scores every possible edge and does not assume from the outset that any particular edge is disconnected.
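As a minimal sketch of what "fully connected" means here, the snippet below lists every ordered pair of distinct nodes, so a graph with C nodes has C(C-1) candidate edges. It assumes both directions of each pair are stored; the actual dataset may store each pair only once. `complete_edge_index` is a hypothetical helper, not part of `iftool`.

```python
def complete_edge_index(num_nodes):
    """Return [sources, targets] for every ordered pair (i, j) with i != j.

    Hypothetical helper: assumes each undirected pair appears in both directions.
    """
    src, dst = [], []
    for i in range(num_nodes):
        for j in range(num_nodes):
            if i != j:  # No self-loops
                src.append(i)
                dst.append(j)
    return [src, dst]


edge_index = complete_edge_index(4)
print(len(edge_index[0]))          # 12, i.e. 4 * 3 candidate edges
print(list(zip(*edge_index))[:3])  # [(0, 1), (0, 2), (0, 3)]
```

Every one of these edges then receives a score, and the classifier decides which ones to keep.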

In [None]:
!pip install git+https://github.com/drinkingkazu/ssi_if
! download_if_dataset.py --challenge=graph --flavor=train
! download_if_dataset.py --challenge=graph --flavor=test

Collecting git+https://github.com/drinkingkazu/ssi_if
  Cloning https://github.com/drinkingkazu/ssi_if to /tmp/pip-req-build-4csb965y
  Running command git clone --filter=blob:none --quiet https://github.com/drinkingkazu/ssi_if /tmp/pip-req-build-4csb965y
  Resolved https://github.com/drinkingkazu/ssi_if to commit af38e2ce0730ec5a3091a849bee9e8e53d58042d
  Preparing metadata (setup.py) ... done
Collecting fire (from iftools==0.1)
  Downloading fire-0.5.0.tar.gz (88 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 88.3/88.3 kB 2.9 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting torch_geometric (from iftools==0.1)
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 661.6/661.6 kB 21.2 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done

In [None]:
import torch
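# force=True avoids a RuntimeError when the start method is already set (e.g. when re-running this cell)
torch.multiprocessing.set_start_method('spawn', force=True)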
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

import numpy as np
SEED=42
_ = np.random.seed(SEED)
_ = torch.manual_seed(SEED)

## Shower Geometric Features `Dataset` and `DataLoader`
The dataset returns a PyTorch Geometric data object with the following attributes:
* `x`: a `(C, 16)` tensor of node features;
* `edge_index`: a `(2, E)` edge index (the source and target node of each edge);
* `edge_attr`: a `(E, 19)` tensor of edge features;
* `y`: a `(C)` vector of node labels (primary IDs: 1 if primary, 0 if not);
* `edge_label`: a `(E)` vector of edge labels (1 if the edge connects two nodes in the same group, 0 otherwise);
* `index`: a scalar holding the entry index;
* `C`: the number of "clusters" (nodes);
* `E`: the number of edges.
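To make the `edge_label` rule concrete, here is a plain-Python sketch on a toy graph. The per-node `group` IDs are made up for illustration and are not one of the attributes listed above; the dataset supplies `edge_label` directly, so this only shows how such labels relate to a grouping of the nodes.

```python
def edge_labels(edge_index, group):
    """1 if an edge joins two nodes in the same group, 0 otherwise.

    `group` is a hypothetical per-node group ID used only for illustration.
    """
    src, dst = edge_index
    return [int(group[i] == group[j]) for i, j in zip(src, dst)]


# Toy graph: 3 nodes, edges 0->1, 0->2, 1->2; nodes 0 and 1 share group 7.
edge_index = [[0, 0, 1], [1, 2, 2]]
group = [7, 7, 3]
labels = edge_labels(edge_index, group)
print(labels)  # [1, 0, 0]
assert len(labels) == len(edge_index[0])  # One label per edge, shape (E,)
```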

In [None]:
from iftool.gnn_challenge import ShowerFeatures
datapath = 'if-graph-train.h5'
train_data = ShowerFeatures(file_path = datapath)

from torch_geometric.loader import DataLoader as GraphDataLoader
train_loader = GraphDataLoader(train_data,
                               shuffle = True,
                               num_workers = 4,
                               batch_size = 64
                              )
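PyTorch Geometric batches graphs by stacking them into one large disconnected graph: node features are concatenated, each graph's `edge_index` is shifted by the number of nodes that precede it, and a `batch` vector records which graph each node came from. The plain-Python sketch below mimics only that `edge_index` bookkeeping; `collate_edge_indices` is a hypothetical helper, not the library's implementation.

```python
def collate_edge_indices(graphs):
    """Merge (num_nodes, (sources, targets)) pairs into one disconnected graph.

    Returns the shifted edge index and a batch vector mapping node -> graph.
    """
    src, dst, batch = [], [], []
    offset = 0
    for graph_id, (num_nodes, (g_src, g_dst)) in enumerate(graphs):
        src.extend(s + offset for s in g_src)
        dst.extend(d + offset for d in g_dst)
        batch.extend([graph_id] * num_nodes)
        offset += num_nodes  # The next graph's node 0 follows this graph's last node
    return [src, dst], batch


graphs = [(2, ([0, 1], [1, 0])), (3, ([0, 2], [2, 1]))]
edge_index, batch = collate_edge_indices(graphs)
print(edge_index)  # [[0, 1, 2, 4], [1, 0, 4, 3]]
print(batch)       # [0, 0, 1, 1, 1]
```

Because no edges run between graphs, a batch of 64 events behaves like 64 independent graphs for any step that only follows `edge_index`.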



# Model Choice
The model combines `MessagePassing` layers with iterative "pruning" to build a sparser graph that approaches the true connections. Each `MessagePassing` block updates the node features, from which every edge is scored as connected or disconnected. The iterative pruning then uses these scores to connect edges in order from highest to lowest score, stopping once connecting the next edge would no longer lower the partition cross entropy loss.
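Written out, the partition cross entropy can be read as follows (this is our interpretation of the loss computed inside `IterPrune`, not a formula stated by the project). Given edge scores $s_e \in (0, 1)$ and a partition $P$ of the nodes, let $\delta_P(i,j) = 1$ when nodes $i$ and $j$ are in the same subgraph and $0$ otherwise. Merging two subgraphs $A$ and $B$ only flips the edges $E(A,B)$ running between them from "different" to "same", which gives the change $\Delta\mathcal{L}$:

```latex
\begin{aligned}
\mathcal{L}(P) &= -\sum_{e=(i,j)\in E}\Big[\delta_P(i,j)\,\log s_e + \big(1-\delta_P(i,j)\big)\log(1-s_e)\Big],\\
\Delta\mathcal{L}(A,B) &= \sum_{e\in E(A,B)}\big[\log(1-s_e) - \log s_e\big]
  = -\sum_{e\in E(A,B)}\operatorname{logit}(s_e).
\end{aligned}
```

So a merge lowers the loss exactly when the log-odds of the edges between the two subgraphs sum to a positive number. Starting from all singletons, $\mathcal{L} = -\sum_e \log(1 - s_e)$.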

In [None]:
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import MessagePassing, BatchNorm

class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max') # Max aggregation.
        self.mlp = Sequential(Linear(2 * in_channels, out_channels),
                              ReLU(),
                              Linear(out_channels, out_channels))

    def forward(self, x, edge_index):
        """
        x has shape [N, in_channels]
        edge_index has shape [2, E]
        Returns updated node features of shape [N, out_channels].
        """
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        """
        x_i (target nodes) and x_j (source nodes) both have shape [E, in_channels].
        """
        tmp = torch.cat([x_i, x_j], dim=1) # tmp has shape [E, 2 * in_channels]
        return self.mlp(tmp)
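For intuition about what `propagate` does with `aggr='max'`: under PyTorch Geometric's default `source_to_target` flow, an edge in `edge_index` carries a message from its source `j` (`x_j`) to its target `i` (`x_i`), and each node keeps the element-wise maximum of its incoming messages. The plain-Python sketch below uses scalar features and a hypothetical `mlp` stand-in instead of the learned network; giving nodes with no incoming edges a 0 is an assumption of this sketch.

```python
def max_message_passing(x, edge_index, mlp):
    """Toy max-aggregation message passing with scalar node features.

    edge_index = [sources, targets]; each edge j -> i sends mlp(x[i], x[j])
    to node i, and node i keeps the largest message it receives.
    """
    out = [None] * len(x)
    for j, i in zip(*edge_index):
        msg = mlp(x[i], x[j])
        out[i] = msg if out[i] is None else max(out[i], msg)
    # Nodes without incoming edges get 0 (an assumption of this sketch).
    return [0.0 if m is None else m for m in out]


def mlp(x_i, x_j):
    # Hypothetical stand-in for Linear -> ReLU -> Linear on cat([x_i, x_j]).
    return max(0.0, x_i - x_j) + 0.5 * x_j


x = [1.0, 2.0, 4.0]
edge_index = [[0, 1, 2], [1, 2, 1]]  # Edges 0->1, 1->2, 2->1
print(max_message_passing(x, edge_index, mlp))  # [0.0, 2.0, 3.0]
```

Node 1 receives 1.5 from node 0 and 2.0 from node 2 and keeps the maximum, 2.0.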

In [None]:
class UnionFind():
    """Disjoint-set forest with path halving and union by size."""
    def __init__(self, num_nodes):
        # Plain Python lists: element-wise indexing of torch tensors is slow
        self.parent = list(range(num_nodes))
        self.size = [1] * num_nodes

    def find(self, i):
        while self.parent[i] != i:
            self.parent[i] = self.parent[self.parent[i]]  # Path halving
            i = self.parent[i]
        return i

    def union(self, i, j):
        # Use the roots for nodes i and j
        i = self.find(i)
        j = self.find(j)

        # If the same, nothing to unify.
        if i == j:
            return

        # Keep it so that root i always has the larger size
        if self.size[i] < self.size[j]:
            i, j = j, i

        self.parent[j] = i
        self.size[i] += self.size[j]

class IterPrune(torch.nn.Module):
    def __init__(self, eps=1e-7):
        super(IterPrune, self).__init__()
        self.eps = eps  # Keeps log(s) and log(1 - s) finite

    def _roots(self, g, num_nodes):
        """Partition label (root) of every node, as a tensor."""
        return torch.tensor([g.find(n) for n in range(num_nodes)])

    def forward(self, scores, edge_index, num_nodes):
        """
        Takes the edge scores (shape [E] or [E, 1]) for the edges in edge_index
        (shape [2, E]) and connects edges from highest to lowest score, stopping
        as soon as the next merge would not strictly lower the partition cross entropy.

        Returns the hard edge scores (1 if both endpoints end up in the same
        partition, 0 otherwise; no gradient) and the edge index of the kept edges.
        """
        s = scores.detach().reshape(-1).cpu().clamp(self.eps, 1 - self.eps)
        src, dst = edge_index[0].cpu(), edge_index[1].cpu()
        # Change in loss when an edge flips from "different" to "same" partition
        gain = torch.log(1 - s) - torch.log(s)

        g = UnionFind(num_nodes)  # Used to check that two nodes are in the same subgraph
        kept = []
        for e in torch.argsort(s, descending=True).tolist():
            ri, rj = g.find(int(src[e])), g.find(int(dst[e]))
            if ri == rj:
                continue  # Already in the same subgraph: merging changes nothing
            # Merging flips every edge between the two subgraphs to "same"
            roots = self._roots(g, num_nodes)
            rs, rd = roots[src], roots[dst]
            between = ((rs == ri) & (rd == rj)) | ((rs == rj) & (rd == ri))
            if gain[between].sum() >= 0:
                break  # Loss would not decrease: stop connecting edges
            g.union(ri, rj)
            kept.append(e)

        # Score each edge by whether its endpoints share a subgraph
        roots = self._roots(g, num_nodes)
        same = (roots[src] == roots[dst]).to(scores.dtype)
        hard = same.reshape(scores.shape).to(scores.device)
        kept = torch.tensor(kept, dtype=torch.long, device=edge_index.device)
        return hard, edge_index[:, kept]

In [None]:
class Node2Edge(torch.nn.Module):
    """Edge feature = difference of the features of an edge's two endpoints."""
    def __init__(self):
        super(Node2Edge, self).__init__()

    def forward(self, x, edge_index):
        # x: [N, F] node features -> [E, F] edge features x_src - x_dst
        return x[edge_index[0]] - x[edge_index[1]]

In [None]:
class LostEdgePruning(torch.nn.Module):
    def __init__(self, num_node_features):
        super(LostEdgePruning, self).__init__()
        self.batch_norm = BatchNorm(num_node_features)
        self.node2edge = Node2Edge()
        self.conv1 = EdgeConv(num_node_features, num_node_features//2)
        self.conv2 = EdgeConv(num_node_features, num_node_features//2)
        self.lin1 = Linear(num_node_features//2, 1)
        self.lin2 = Linear(num_node_features//2, 1)
        self.iter_prune = IterPrune()

    def forward(self, data):
        x0, edge_index0 = data.x, data.edge_index
        x0 = self.batch_norm(x0)

        # Stage 1: score every edge of the fully connected input graph, then prune it.
        x = self.conv1(x0, edge_index0)       # [N, F//2]
        x = self.lin1(x)                      # [N, 1]
        x = self.node2edge(x, edge_index0)    # [E, 1]
        x = torch.sigmoid(x)
        _, edge_index1 = self.iter_prune(x, edge_index0, data.num_nodes)

        # Stage 2: pass messages on the pruned graph, but score every original edge
        # so that the output lines up with data.edge_label, shape [E, 1].
        x = self.conv2(x0, edge_index1)
        x = self.lin2(x)
        x = self.node2edge(x, edge_index0)
        # Return soft probabilities: IterPrune's hard 0/1 scores carry no gradient,
        # so a loss on them cannot train the network. Note that conv1/lin1 only
        # affect the output through the pruning, so this loss does not update them.
        return torch.sigmoid(x)
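A property of `Node2Edge` worth keeping in mind: it turns one scalar per node (the output of `lin1`/`lin2`) into the edge logit `a_i - a_j`, so after the sigmoid the two directions of an edge always sum to one, and two nodes with equal outputs always get 0.5. If both directions of each pair are stored, at most one of them can score above 0.5. The check below is plain Python with made-up node scores.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def node2edge_scores(a, edge_index):
    """Edge probabilities sigmoid(a[i] - a[j]) for each edge (i, j), as in Node2Edge."""
    return [sigmoid(a[i] - a[j]) for i, j in zip(*edge_index)]


a = [0.3, -1.2, 2.0]                       # Made-up per-node outputs of lin1/lin2
edge_index = [[0, 1, 0, 2], [1, 0, 2, 0]]  # Both directions of (0, 1) and (0, 2)
s = node2edge_scores(a, edge_index)
print(round(s[0] + s[1], 12), round(s[2] + s[3], 12))  # 1.0 1.0
```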

# Training
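Train the model with the binary cross entropy (`BCELoss`) between the predicted edge probabilities and `edge_label`, using the Adam optimizer.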

In [None]:
import torch.optim as optim
from torch.nn import BCELoss

loss_fn = BCELoss()   # Expects probabilities in [0, 1], i.e. after the sigmoid
num_node_feats = 16   # Matches the (C, 16) node feature tensor x
gnn = LostEdgePruning(num_node_feats)
gnn.to(device)
opt = optim.Adam(gnn.parameters(), lr = 0.001)

def train(model, num_epochs = 5):
    model.train(True)
    tot_iter = len(train_loader)

    for ep in range(num_epochs):
        running_loss = 0
        for data_idx, data in enumerate(train_loader):
            data = data.to(device)
            labels = data.edge_label.reshape(-1,1).float()  # (E, 1), same shape as the outputs

            opt.zero_grad()
            outputs = model(data)
            loss = loss_fn(outputs, labels)
            loss.backward()
            opt.step()

            running_loss += loss.item()
            if (data_idx % 50 == 0):
                print(data_idx, loss.item())

        epoch_loss = running_loss / tot_iter
        print(f'Epoch {ep}: mean loss {epoch_loss:.4f}')

In [None]:
train(gnn)



0 0.6937949061393738
50 0.6932360529899597
100 0.6931782960891724


KeyboardInterrupt: ignored
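The logged losses sit almost exactly at ln 2 ≈ 0.6931, which is the binary cross entropy obtained by predicting 0.5 for every edge regardless of its label; this suggests that, over these first batches, the edge probabilities stayed close to 0.5. The sketch below is plain Python; `bce` mimics `torch.nn.BCELoss` with its default mean reduction, including PyTorch's documented clamping of each log term at -100.

```python
import math


def clamped_log(v):
    # BCELoss clamps log outputs to be >= -100, so log(0) becomes -100.
    return max(math.log(v), -100.0) if v > 0 else -100.0


def bce(probs, labels):
    """Mean binary cross entropy over (probability, label) pairs."""
    total = 0.0
    for p, y in zip(probs, labels):
        total += -(y * clamped_log(p) + (1 - y) * clamped_log(1 - p))
    return total / len(probs)


print(math.log(2))                   # 0.6931471805599453
print(bce([0.5] * 4, [1, 0, 0, 1]))  # ≈ 0.6931: a 0.5 guess costs ln 2 per edge
print(bce([0.9, 0.1], [1, 0]))       # ≈ 0.1054: confident, correct scores cost less
```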