In [1]:
# Fetch the MPNN-Ptr sources into ./mpnn_ptr and make that checkout the working
# directory. Later cells import `models.mpnn_ptr` / `utils.utils` and open
# `data/data_64.pt` by relative path — presumably all resolved from this checkout.
# NOTE(review): not idempotent — the working directory persists across cell runs,
# so executing this cell twice in one session would clone a second copy inside
# the first and cd into it. Restart the runtime before re-running.
!git clone https://github.com/arunsammit/MPNN-Ptr mpnn_ptr
%cd mpnn_ptr

Cloning into 'mpnn_ptr'...
remote: Enumerating objects: 212, done.
remote: Counting objects: 100% (212/212), done.
remote: Compressing objects: 100% (141/141), done.
remote: Total 212 (delta 100), reused 163 (delta 62), pack-reused 0
Receiving objects: 100% (212/212), 16.77 MiB | 3.28 MiB/s, done.
Resolving deltas: 100% (100/100), done.
/content/mpnn_ptr


In [3]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.10.0+cu111.html

Looking in links: https://data.pyg.org/whl/torch-1.10.0+cu111.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (7.9 MB)
[K     |████████████████████████████████| 7.9 MB 5.4 MB/s 
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_sparse-0.6.12-cp37-cp37m-linux_x86_64.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 24.4 MB/s 
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_cluster-1.5.9-cp37-cp37m-linux_x86_64.whl (2.3 MB)
[K     |████████████████████████████████| 2.3 MB 40.9 MB/s 
[?25hCollecting torch-spline-conv
  Downloading https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl (747 kB)
[K     |████████████████████████████████| 747 kB 57.6 MB/s 
[?25hCollecting torch-geometric
  Downloading torch_geometric-2.0.2.tar.gz (325 kB)
[K     |███

In [4]:
from models.mpnn_ptr import MpnnPtr
from torch import nn
import torch
from utils.utils import communication_cost, calculate_baseline

In [5]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [6]:
def init_weights(m, low=-0.08, high=0.08):
    """Overwrite every parameter of ``m`` in place with draws from U(low, high).

    Args:
        m: Module whose parameters (including those of all its submodules)
            are re-initialised.
        low: Lower bound of the uniform distribution. Defaults to -0.08.
        high: Upper bound of the uniform distribution. Defaults to 0.08.

    Returns:
        None. The parameters are modified in place.

    Note:
        The function walks the parameters recursively. When it is passed to
        ``Module.apply`` (which itself visits every submodule), nested
        parameters are therefore re-drawn once per enclosing module; the
        resulting distribution is the same, only extra random draws are spent.
    """
    # nn.init.uniform_ already performs the fill without tracking gradients,
    # so the parameter is passed directly instead of through ``.data``.
    for param in m.parameters():
        nn.init.uniform_(param, low, high)

In [None]:
# Load the pre-built training data. The stored object unpacks into two parts:
# `dataloader` (iterated for graph batches in the training loop) and
# `distance_matrices` (iterated in lockstep with it, one entry per batch).
# NOTE(review): torch.load deserialises with pickle, which can execute arbitrary
# code — only load this file from a source you trust (presumably it ships with
# the cloned repository; confirm).
dataloader, distance_matrices = torch.load('data/data_64.pt')

In [None]:
# Display the batch size configured on the loaded `dataloader`
# (128 in the recorded run).
dataloader.batch_size

128

In [None]:
# Input dimension handed to the model.
# NOTE(review): presumably the largest graph size in data_64.pt — confirm.
max_graph_size = 64

# Build the model, place it on the selected device, then re-initialise all of
# its parameters with init_weights (order kept: move first, initialise second).
mpnn_ptr = MpnnPtr(
    input_dim=max_graph_size,
    embedding_dim=75,
    hidden_dim=81,
    K=3,
    n_layers=4,
    p_dropout=0,
    logit_clipping=True,
    device=device,
)
mpnn_ptr.to(device)
mpnn_ptr.apply(init_weights)

# Adam over every model parameter with a fixed learning rate.
optim = torch.optim.Adam(mpnn_ptr.parameters(), lr=1e-4)

# Training length and logging buffers: one penalty slot per batch of an epoch,
# plus the running list of per-epoch totals.
num_epochs = 3500
epoch_penalty = torch.zeros(len(dataloader))
loss_list_pre = list()

In [None]:
# Switch the model to training mode before the optimisation loop starts.
mpnn_ptr.train()

In [None]:
# Policy-gradient training loop: for every batch, score the model's predicted
# mapping, compare it with a baseline computed from the sampled mappings, and
# use the difference to weight the log-likelihood term that is back-propagated.

# Number of mappings sampled per graph; loop-invariant, so it is set once here
# instead of being re-assigned on every batch.
num_samples = 16
for epoch in range(num_epochs):
    # Clear the per-batch penalty totals collected during the previous epoch.
    epoch_penalty[:] = 0
    # NOTE(review): zip() pairs the i-th batch with the i-th distance matrix and
    # stops at the shorter of the two — this assumes the dataloader yields its
    # batches in a fixed order (no shuffling); confirm.
    for i, (data, distance_matrix) in enumerate(zip(dataloader, distance_matrices)):
        samples, predicted_mappings, log_likelihoods_sum = mpnn_ptr(data, num_samples)
        # samples shape: (batch_size, num_samples, max_graph_size_in_batch)
        # predicted_mappings shape: (batch_size, max_graph_size_in_batch)
        # log_likelihoods_sum shape: (batch_size,)
        penalty = communication_cost(data.edge_index, data.edge_attr, data.batch, data.num_graphs, distance_matrix, predicted_mappings)
        # Detach before storing: epoch_penalty is only a logging buffer, and the
        # loss below already treats `penalty` as possibly graph-attached. Keeping
        # an attached tensor here would tie the buffer to the autograd graph.
        epoch_penalty[i] = penalty.detach().sum()
        penalty_baseline = calculate_baseline(data.edge_index, data.edge_attr, data.batch, data.num_graphs, distance_matrix, samples)
        # Both cost terms are detached, so gradients flow only through
        # log_likelihoods_sum; (penalty - baseline) acts as a per-graph weight.
        loss = torch.mean((penalty.detach() - penalty_baseline.detach()) * log_likelihoods_sum)
        optim.zero_grad()
        loss.backward()
        # Clip the overall L2 gradient norm to 1 before the optimiser step.
        nn.utils.clip_grad_norm_(mpnn_ptr.parameters(), max_norm=1, norm_type=2)
        optim.step()
    # Despite the name and the "Loss" label, the value logged per epoch is the
    # penalty of the predicted mappings summed over all batches, not `loss`.
    batch_loss = epoch_penalty.sum().item()
    loss_list_pre.append(batch_loss)
    print('Epoch: {}/{}, Loss: {}'.format(epoch + 1, num_epochs, batch_loss))