# IF2 Lost Edge Classification
Using the data from the SSI IF2 project, classify which edges need to be connected and which need to be disconnected. This is a binary classification task in which a 1 indicates that an edge should be connected and a 0 indicates that it should not be connected. In this case, it is best to start with a fully connected graph so that the model can score every edge appropriately and does not assume that any particular edge is inherently disconnected.

In [None]:
# Install the ssi_if package directly from its GitHub repository.
!pip install git+https://github.com/drinkingkazu/ssi_if
# Fetch the "graph" challenge dataset in both its train and test flavors.
# NOTE(review): download_if_dataset.py is presumably a script installed by the
# package above -- confirm it is on PATH after the pip install.
! download_if_dataset.py --challenge=graph --flavor=train
! download_if_dataset.py --challenge=graph --flavor=test

In [None]:
import torch

# The multiprocessing start method can only be chosen once per process, so a
# plain call raises RuntimeError when this cell is executed a second time in
# the same kernel. force=True makes the cell safe to re-run.
torch.multiprocessing.set_start_method('spawn', force=True)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

import numpy as np

# Fixed seed so that NumPy and PyTorch random draws are repeatable.
SEED = 42
np.random.seed(SEED)
# manual_seed returns a Generator; bind it to `_` so the notebook does not
# echo it as the cell output.
_ = torch.manual_seed(SEED)

## Shower Geometric Features `Dataset` and `DataLoader`
The dataset returns a PyTorch Geometric object with the following attributes:
* `x`: A `(C, 16)` tensor of node features;
* `edge_index`: a `(2, E)` edge incidence matrix;
* `edge_attr`: a `(E, 19)` tensor of edge features;
* `y`: a `(C)` vector of node labels (primary IDs: 1 if primary, 0 if not);
* `edge_label`: a `(E)` vector of edge labels (1 if connects two nodes in the same group, 0 otherwise);
* `index`: a scalar representing the entry indices;
* `C`: The number of "clusters" (nodes);
* `E`: The number of edges.

In [None]:
# ShowerFeatures is the dataset class used below; it was referenced without
# ever being imported. NOTE(review): it is taken here from the `iftool` module
# of the ssi_if package installed at the top of this notebook -- confirm the
# module path against the installed package.
from iftool.gnn_challenge import ShowerFeatures
from torch_geometric.loader import DataLoader as GraphDataLoader

# Training file of the graph challenge.
datapath = 'if-graph-train.h5'
train_data = ShowerFeatures(file_path=datapath)

# Batches of 64 graphs, reshuffled every epoch and loaded by 4 workers.
train_loader = GraphDataLoader(
    train_data,
    shuffle=True,
    num_workers=4,
    batch_size=64,
)

# Model Choice
The model design makes use of both `MessagePassing` and iterative "pruning" to build a new graph that is sparser and approaches the true connections. Each `MP` block aims to update the edge features and to score whether each edge should be connected or disconnected. The iterative pruning then uses these scores to connect edges in order from highest to lowest score while maintaining a non-increasing cross-entropy loss.

In [None]:
from types import SimpleNamespace

import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import BatchNorm, MessagePassing

class EdgeConv(MessagePassing):
    """Message-passing layer that runs an MLP on concatenated endpoint features.

    For each edge the features of its two endpoints are concatenated and fed
    through a two-layer perceptron; the resulting messages are combined per
    node with max aggregation.
    """

    def __init__(self, in_channels, out_channels):
        # Messages arriving at a node are reduced with an element-wise max.
        super().__init__(aggr='max')
        layers = (
            Linear(2 * in_channels, out_channels),
            ReLU(),
            Linear(out_channels, out_channels),
        )
        self.mlp = Sequential(*layers)

    def forward(self, x, edge_index):
        """Propagate node features along the edges.

        x: node features of shape [N, in_channels].
        edge_index: edge list of shape [2, E].
        """
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        """Build one message per edge.

        x_i and x_j are the endpoint features of every edge, each of shape
        [E, in_channels].
        """
        # Concatenate along the feature axis -> [E, 2 * in_channels].
        endpoint_pair = torch.cat((x_i, x_j), dim=1)
        return self.mlp(endpoint_pair)

In [None]:
class Node2Edge(nn.Module):
    """Convert per-node values into per-edge values by differencing endpoints.

    For every column ``(a, b)`` of ``edge_index`` the output row is
    ``x[a] - x[b]``, so the result has one row per edge.
    """

    def forward(self, data, edge_index=None):
        """Return ``x[edge_index[0]] - x[edge_index[1]]``.

        Args:
            data: either an object exposing ``.x`` and ``.edge_index``
                (used when ``edge_index`` is omitted), or the node tensor
                itself when ``edge_index`` is passed explicitly.
            edge_index: optional ``[2, E]`` index tensor. Supplying it lets
                callers difference a node tensor that is not stored on a
                graph object, e.g. the output of an earlier layer.

        Returns:
            A tensor with one row per edge.
        """
        if edge_index is None:
            x, edge_index = data.x, data.edge_index
        else:
            x = data
        first, second = edge_index[0], edge_index[1]
        return x[first] - x[second]

In [None]:
class LostEdgePruning(nn.Module):
    """Score every edge of a graph with a value in (0, 1).

    Pipeline: batch-normalise the node features, run one ``EdgeConv`` layer,
    project each node to a single value, difference that value across the two
    endpoints of every edge and squash the result with a sigmoid.
    """

    def __init__(self, num_node_features):
        """
        Args:
            num_node_features: width of the input node feature vectors.

        Raises:
            ValueError: if ``num_node_features`` is smaller than 2, because
                the hidden width (half the input width) would be zero.
        """
        super().__init__()
        if num_node_features < 2:
            raise ValueError("num_node_features must be at least 2")
        # Layer sizes must be integers; floor division keeps them so.
        hidden = num_node_features // 2
        self.batch_norm = BatchNorm(num_node_features)
        self.node2edge = Node2Edge()
        self.conv1 = EdgeConv(num_node_features, hidden)
        self.lin1 = Sequential(Linear(hidden, 1))

    def forward(self, data):
        """Return one score per edge of ``data``.

        Args:
            data: graph object exposing ``.x`` (node features) and
                ``.edge_index`` (``[2, E]`` edge list).

        Returns:
            A 1-D tensor holding a sigmoid score for each edge.
        """
        x, edge_index = data.x, data.edge_index
        x = self.batch_norm(x)
        x = self.conv1(x, edge_index)
        x = self.lin1(x)
        # Node2Edge reads `.x` and `.edge_index` from its argument, so bundle
        # the updated node values with the edge list instead of mutating the
        # caller's `data`.
        per_edge = self.node2edge(SimpleNamespace(x=x, edge_index=edge_index))
        # lin1 leaves a trailing dimension of size 1; drop it so the output
        # is a flat vector with one entry per edge.
        return torch.sigmoid(per_edge).squeeze(-1)

# Training
Start the training.

In [None]:
import torch.optim as optim
from torch.nn import BCELoss, CrossEntropyLoss

# The model ends in a sigmoid and emits a single score per edge, so the
# matching criterion is binary cross entropy on probabilities.
# CrossEntropyLoss expects one logit per class and integer class targets.
loss_fn = BCELoss()
num_node_feats = 16
# The batches are moved to `device` inside train(), so the model has to live
# on the same device.
gnn = LostEdgePruning(num_node_feats).to(device)
opt = optim.Adam(gnn.parameters(), lr=0.001)


def train(model, num_epochs=5, optimizer=None):
    """Train ``model`` on ``train_loader`` and print the mean loss per epoch.

    Args:
        model: module mapping a batched graph to one score per edge.
        num_epochs: number of passes over ``train_loader``.
        optimizer: optimizer stepping ``model``'s parameters. Defaults to the
            module-level ``opt``, which is bound to ``gnn``; pass an explicit
            optimizer when training any other model, otherwise its weights
            would never be updated.
    """
    if optimizer is None:
        optimizer = opt
    model.train(True)
    # Guard against an empty loader so the epoch average never divides by 0.
    tot_iter = max(len(train_loader), 1)

    for ep in range(num_epochs):
        running_loss = 0.0
        for data in train_loader:
            inputs = data.to(device)
            # BCELoss needs float targets shaped like the predictions.
            labels = inputs.edge_label.float().view(-1)

            optimizer.zero_grad()
            outputs = model(inputs).view(-1)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        epoch_loss = running_loss / tot_iter
        print(epoch_loss)