In [13]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing, global_mean_pool, GCNConv
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.loader import DataLoader

# To compute distances
# from scipy.spatial.distance import cdist

# Construct the GCN for regression 

In [14]:
# First of all we have to define the dataset
# Prepare dataset
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG', use_node_attr=True)

torch.manual_seed(12345)
dataset = dataset.shuffle()

train_dataset = dataset[:10]
test_dataset = dataset[10:20]

In [15]:
# Define the dataloader to take as input graph pairs
from torch_geometric.data import Data
class PairData(Data):
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'edge_index_1':
            return self.x_1.size(0)
        if key == 'edge_index_2':
            return self.x_2.size(0)
        return super().__inc__(key, value, *args, **kwargs)

In [16]:
# Construct all pairs of graphs using PairData object
# Distance for now is initialized at random in a value between (-1, 1) like a cosine distance would
import random
train_data_list = []
for ind1, graph1 in enumerate(train_dataset):
    for ind2, graph2 in enumerate(train_dataset[ind1+1:]):
        # ind2 += (ind1 + 1)
        train_data_list.append(PairData(x_1=graph1.x, edge_index_1=graph1.edge_index,
                            x_2=graph2.x, edge_index_2=graph2.edge_index,
                            distance = random.uniform(-1,1)))   

test_data_list = []
for ind1, graph1 in enumerate(test_dataset):
    for ind2, graph2 in enumerate(test_dataset[ind1+1:]):
        # ind2 += (ind1 + 1)
        test_data_list.append(PairData(x_1=graph1.x, edge_index_1=graph1.edge_index,
                            x_2=graph2.x, edge_index_2=graph2.edge_index,
                            distance = random.uniform(-1,1)))   


In [17]:
len(train_data_list) #binom(n,2) = n(n-1)/2

45

In [18]:
class GCN_pairs(torch.nn.Module):
    """
    Takes as input a pair of graphs which are both fed through 3 convolutional layers each followed by an activation function
    3 graph convolutional layers (Welling) that share parameters
    """
    def __init__(self, input_features, hidden_channels, output_embeddings):
        super(GCN_pairs, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(input_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, output_embeddings)

    def forward(self, x1, edge_index1, batch1, x2, edge_index2, batch2): # Need a way to extract these from dataloader
        # 1. Obtain node embeddings for graph 1
        x1 = self.conv1(x1, edge_index1)
        x1 = x1.relu()
        x1 = self.conv2(x1, edge_index1)
        x1 = x1.relu()
        x1 = self.conv3(x1, edge_index1)
        # 2. Readout layer
        x1 = global_mean_pool(x1, batch1)  # [batch_size, hidden_channels]
        # 3. Apply a final linear transformation on the aggregated embedding
        x1 = torch.nn.functional.dropout(x1, p=0.5, training=self.training)
        x1 = self.lin(x1)

        # 1. Obtain node embeddings for graph 2
        x2 = self.conv1(x2, edge_index2)
        x2 = x2.relu()
        x2 = self.conv2(x2, edge_index2)
        x2 = x2.relu()
        x2 = self.conv3(x2, edge_index2)
        # 2. Readout layer
        x2 = global_mean_pool(x2, batch2)  # [batch_size, hidden_channels]
        # 3. Apply a final linear transformation on the aggregated embedding
        x2 = torch.nn.functional.dropout(x2, p=0.5, training=self.training)
        x2 = self.lin(x2)

        #print(f"x1 has shape: {x1.shape} and is:\n\t{x1}")
        #print(f"x2 has shape: {x2.shape} and is:\n\t{x2}")
        return (x1 - x2).pow(2).sum(1) # This is the euclidean distance between the two returns [dist of pair1, dist of pair2, ..]

In [19]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GCN_pairs(input_features = dataset.num_node_features, hidden_channels=64, output_embeddings=300).to(device)
print(model)

GCN_pairs(
  (conv1): GCNConv(7, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=300, bias=True)
)


In [20]:
train_loader = DataLoader(train_data_list, batch_size=16, follow_batch=['x_1', 'x_2'])
test_loader = DataLoader(test_data_list, batch_size=16, follow_batch=['x_1', 'x_2'])
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    data = data.to(device)
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print(data.distance)
    print()

Step 1:
Number of graphs in the current batch: 16
PairDataBatch(x_1=[223, 7], x_1_batch=[223], x_1_ptr=[17], edge_index_1=[2, 482], x_2=[245, 7], x_2_batch=[245], x_2_ptr=[17], edge_index_2=[2, 528], distance=[16])
tensor([ 0.1234, -0.3736,  0.7904,  0.6997, -0.4735, -0.0471, -0.6616,  0.5039,
         0.3501,  0.7905,  0.0886, -0.0820, -0.8334, -0.1336,  0.9230, -0.4689])

Step 2:
Number of graphs in the current batch: 16
PairDataBatch(x_1=[243, 7], x_1_batch=[243], x_1_ptr=[17], edge_index_1=[2, 508], x_2=[248, 7], x_2_batch=[248], x_2_ptr=[17], edge_index_2=[2, 544], distance=[16])
tensor([-0.9048,  0.8163, -0.4775,  0.4399,  0.7116,  0.1004,  0.1144,  0.8433,
         0.5529, -0.7445,  0.0731,  0.3546, -0.5416, -0.6670, -0.6884, -0.4920])

Step 3:
Number of graphs in the current batch: 13
PairDataBatch(x_1=[220, 7], x_1_batch=[220], x_1_ptr=[14], edge_index_1=[2, 490], x_2=[180, 7], x_2_batch=[180], x_2_ptr=[14], edge_index_2=[2, 382], distance=[13])
tensor([ 0.1419, -0.9760, -0.49

In [21]:
def train():
    model.train()
    for data in train_loader:  # Iterate in batches over the training dataset.
        out = model(data.x_1, data.edge_index_1, data.x_1_batch, 
                    data.x_2, data.edge_index_2, data.x_2_batch)  # Perform a single forward pass.
        # print(f"out is:\n\t{out}\nwhile dist is {data.distance}")
        loss = criterion(out, data.distance)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.

def test():
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            out = model(data.x_1, data.edge_index_1, data.x_1_batch, 
                        data.x_2, data.edge_index_2, data.x_2_batch)  # Perform a single forward pass.
            loss = criterion(out, data.distance)


In [22]:
model = GCN_pairs(input_features=dataset.num_node_features, hidden_channels=64, output_embeddings=300)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

train_loader = DataLoader(train_data_list, batch_size=16, follow_batch=['x_1', 'x_2'])


epochs=150
for epoch in range(epochs):
    train()
    test()
    #train_acc = test(train_loader)
    #test_acc = test(test_loader)
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Train Acc: train_acc:.4f, Test Acc: test_acc:.4f')
        print(f'Epoch: {epoch:03d},  train_acc:.4f, Test Acc: test_acc:.4f')


Epoch: 000, Train Acc: train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 000,  train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 010, Train Acc: train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 010,  train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 020, Train Acc: train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 020,  train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 030, Train Acc: train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 030,  train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 040, Train Acc: train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 040,  train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 050, Train Acc: train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 050,  train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 060, Train Acc: train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 060,  train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 070, Train Acc: train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 070,  train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 080, Train Acc: train_acc:.4f, Test Acc: test_acc:.4f
Epoch: 080,  train_acc:.4f, Test Acc: test_

## TO DO:
- Add the test during the training procedure and print some metric like MSE at every iteration in training and test set for example.
- Do it with the actual homomorphism counts distances