In [1]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing, global_mean_pool, GCNConv
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.loader import DataLoader

from running_functions import training_loop
# To compute distances
# from scipy.spatial.distance import cdist

# Construct the GCN for regression 

In [2]:
# First of all we have to define the dataset
# Prepare dataset
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG', use_node_attr=True)

torch.manual_seed(12345)
dataset = dataset.shuffle()

train_dataset = dataset[:100]
val_dataset = dataset[100:]

In [3]:
# Define the dataloader to take as input graph pairs
from torch_geometric.data import Data
class PairData(Data):
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'edge_index_1':
            return self.x_1.size(0)
        if key == 'edge_index_2':
            return self.x_2.size(0)
        return super().__inc__(key, value, *args, **kwargs)

In [4]:
# Construct all pairs of graphs using PairData object
# Distance for now is initialized at random in a value between (-1, 1) like a cosine distance would
import random
train_data_list = []
for ind1, graph1 in enumerate(train_dataset):
    for ind2, graph2 in enumerate(train_dataset[ind1+1:]):
        # ind2 += (ind1 + 1)
        train_data_list.append(PairData(x_1=graph1.x, edge_index_1=graph1.edge_index,
                            x_2=graph2.x, edge_index_2=graph2.edge_index,
                            distance = random.uniform(-1,1)))   

val_data_list = []
for ind1, graph1 in enumerate(val_dataset):
    for ind2, graph2 in enumerate(val_dataset[ind1+1:]):
        # ind2 += (ind1 + 1)
        val_data_list.append(PairData(x_1=graph1.x, edge_index_1=graph1.edge_index,
                            x_2=graph2.x, edge_index_2=graph2.edge_index,
                            distance = random.uniform(-1,1)))   


In [5]:
len(train_data_list) #binom(n,2) = n(n-1)/2

4950

In [6]:
class GCN_pairs(torch.nn.Module):
    """
    Takes as input a pair of graphs which are both fed through 3 convolutional layers each followed by an activation function
    3 graph convolutional layers (Welling) that share parameters
    """
    def __init__(self, input_features, hidden_channels, output_embeddings):
        super(GCN_pairs, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(input_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, output_embeddings)

    def forward(self, x1, edge_index1, batch1, x2, edge_index2, batch2): # Need a way to extract these from dataloader
        # 1. Obtain node embeddings for graph 1
        x1 = self.conv1(x1, edge_index1)
        x1 = x1.relu()
        x1 = self.conv2(x1, edge_index1)
        x1 = x1.relu()
        x1 = self.conv3(x1, edge_index1)
        # 2. Readout layer
        x1 = global_mean_pool(x1, batch1)  # [batch_size, hidden_channels]
        # 3. Apply a final linear transformation on the aggregated embedding
        x1 = torch.nn.functional.dropout(x1, p=0.5, training=self.training)
        x1 = self.lin(x1)

        # 1. Obtain node embeddings for graph 2
        x2 = self.conv1(x2, edge_index2)
        x2 = x2.relu()
        x2 = self.conv2(x2, edge_index2)
        x2 = x2.relu()
        x2 = self.conv3(x2, edge_index2)
        # 2. Readout layer
        x2 = global_mean_pool(x2, batch2)  # [batch_size, hidden_channels]
        # 3. Apply a final linear transformation on the aggregated embedding
        x2 = torch.nn.functional.dropout(x2, p=0.5, training=self.training)
        x2 = self.lin(x2)

        #print(f"x1 has shape: {x1.shape} and is:\n\t{x1}")
        #print(f"x2 has shape: {x2.shape} and is:\n\t{x2}")
        return torch.nn.functional.cosine_similarity(x1, x2)

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GCN_pairs(input_features = dataset.num_node_features, hidden_channels=64, output_embeddings=300).to(device)
print(model)

GCN_pairs(
  (conv1): GCNConv(7, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=300, bias=True)
)


In [8]:
batch_size = 32
train_loader = DataLoader(train_data_list, batch_size=batch_size, follow_batch=['x_1', 'x_2'])
val_loader = DataLoader(val_data_list, batch_size=batch_size, follow_batch=['x_1', 'x_2'])
for step, data in enumerate(train_loader):
    print(f'Batch {step + 1}:')
    data = data.to(device)
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print(data.distance)
    print()

Batch 1:
Number of graphs in the current batch: 32
PairDataBatch(x_1=[544, 7], x_1_batch=[544], x_1_ptr=[33], edge_index_1=[2, 1216], x_2=[571, 7], x_2_batch=[571], x_2_ptr=[33], edge_index_2=[2, 1268], distance=[32])
tensor([ 0.5963,  0.4198,  0.1550,  0.0571, -0.1782, -0.8633,  0.6130, -0.9416,
        -0.7611, -0.8061, -0.1256,  0.7038,  0.4941,  0.3455, -0.9816, -0.4816,
         0.6428,  0.2381, -0.8480, -0.0257,  0.6622, -0.4694, -0.3306,  0.9496,
        -0.4415,  0.6610,  0.0727,  0.0040, -0.6391,  0.0270,  0.2656,  0.3718])

Batch 2:
Number of graphs in the current batch: 32
PairDataBatch(x_1=[544, 7], x_1_batch=[544], x_1_ptr=[33], edge_index_1=[2, 1216], x_2=[578, 7], x_2_batch=[578], x_2_ptr=[33], edge_index_2=[2, 1286], distance=[32])
tensor([-0.3214,  0.6405,  0.2957,  0.8971, -0.9204, -0.2868,  0.0827,  0.1358,
         0.2368,  0.7892, -0.2213,  0.9217, -0.1788,  0.6594, -0.2092, -0.8053,
        -0.5792, -0.6184, -0.1740,  0.3791,  0.4816, -0.7517,  0.5434, -0.5554,
  

In [9]:
model = GCN_pairs(input_features=dataset.num_node_features, hidden_channels=64, output_embeddings=300)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()


In [10]:
training_loop(model, train_loader, optimizer, criterion, val_loader, epoch_number=20)

Epoch: 10 | Epoch Time: 0m 2s
	Train Loss: 0.369
	 Val. Loss: 1.314
Epoch: 20 | Epoch Time: 0m 3s
	Train Loss: 0.363
	 Val. Loss: 1.314


## TO DO:
- Do it with the actual homomorphism counts distances
- For now the output is obtained using cosine similarity -> Try Euclideean and Manhattan maybe?