In [None]:
!git clone https://github.com/arunsammit/MPNN-Ptr mpnn_ptr
%cd mpnn_ptr

In [None]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.10.0+cu111.html

In [18]:
from models.mpnn_ptr import *
from models.mpnn import *
from models.seqToseq import *
from torch import nn
import torch
from utils.utils import communication_cost, calculate_baseline

In [9]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [14]:
def init_weights(m, low=-0.08, high=0.08):
    """Re-draw every parameter of ``m`` from the uniform distribution U(low, high).

    Args:
        m: Module whose parameters are overwritten in place. Parameters of all
            sub-modules are included, because ``parameters()`` recurses by
            default.
        low: Lower bound of the uniform distribution. Defaults to -0.08.
        high: Upper bound of the uniform distribution. Defaults to 0.08.

    Note:
        A single ``init_weights(model)`` call already reaches every parameter.
        Passing the function to ``model.apply`` also works, but nested
        parameters are then re-drawn once per enclosing module, which is
        redundant work.
    """
    for param in m.parameters():
        # nn.init.* runs under no_grad, so the Parameter can be passed
        # directly instead of going through ``.data``.
        nn.init.uniform_(param, low, high)

In [5]:
# Debug check: list the contents of the current working directory.
!ls

sample_data


In [10]:
# Colab-only: mount Google Drive so the dataset and checkpoints can be read/written.
# The mount point is relative to the current working directory, and so are the
# 'drive/MyDrive/...' paths used by the load/save cells.
from google.colab import drive
drive.mount('./drive')

Mounted at ./drive


In [11]:
# Saved training set: a batched loader together with a collection of distance
# matrices (the training loop zips the two, so presumably one matrix per batch).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
dataset_path = 'drive/MyDrive/data_MTP/data_64.pt'
dataloader, distance_matrices = torch.load(dataset_path)

In [None]:
# Inspect the batch size the saved loader was built with.
dataloader.batch_size

128

In [15]:
# --- Model / optimiser configuration ---
max_graph_size = 64  # passed to the model as input_dim
mpnn_ptr = MpnnPtr(
    input_dim=max_graph_size,
    embedding_dim=75,
    hidden_dim=81,
    K=3,
    n_layers=4,
    p_dropout=0,
    logit_clipping=True,
    device=device,
)
mpnn_ptr.to(device)
# init_weights already iterates recursively over every parameter of the module
# it receives, so one direct call initialises the whole model. Routing it
# through Module.apply would repeat the same draw once per enclosing sub-module.
init_weights(mpnn_ptr)
optim = torch.optim.Adam(mpnn_ptr.parameters(), lr=0.0001)
num_epochs = 3500
# One slot per batch; the training loop stores each batch's summed penalty here.
epoch_penalty = torch.zeros(len(dataloader))
# Per-epoch totals, plotted after training.
loss_list_pre = []

In [None]:
# Switch the model to training mode.
# NOTE: the checkpoint-loading cell below rebinds `mpnn_ptr`, so this call only
# affects the model object that exists at this point.
mpnn_ptr.train()

In [19]:
# Resume from the checkpoint written by the save cell further down.
# NOTE(review): this unpickles a whole model object — only load trusted files.
# map_location lets a checkpoint saved on a GPU runtime load on a CPU-only one.
mpnn_ptr = torch.load('./drive/MyDrive/data_MTP/model_64.pt', map_location=device)
# The loaded object replaces the model that .train() was called on earlier, so
# put the new instance into training mode as well.
mpnn_ptr.train()
# The optimiser built in the setup cell still holds the parameters of the model
# object that was just replaced; stepping it would never update the loaded
# weights. Rebuild it over the loaded model, keeping the configured learning rate.
optim = torch.optim.Adam(mpnn_ptr.parameters(), lr=optim.defaults["lr"])

In [None]:
# Policy-gradient training (REINFORCE-with-baseline form): the communication
# cost of the model's predicted mapping, minus a baseline computed from the
# sampled mappings, weights the summed log-likelihoods.
num_samples = 16  # passed to the model on every call; constant, so set once
for epoch in range(num_epochs):
    epoch_penalty[:] = 0
    for i, (data, distance_matrix) in enumerate(zip(dataloader, distance_matrices)):
        samples, predicted_mappings, log_likelihoods_sum = mpnn_ptr(data, num_samples)
        # samples shape: (batch_size, num_samples, max_graph_size_in_batch)
        # predicted_mappings shape: (batch_size, max_graph_size_in_batch)
        # log_likelihoods_sum shape: (batch_size,)
        penalty = communication_cost(data.edge_index, data.edge_attr, data.batch, data.num_graphs, distance_matrix, predicted_mappings)
        # Detach before accumulating: epoch_penalty lives across batches and
        # epochs, so writing a graph-attached value into it would tie the
        # accumulator to the autograd graph of every batch.
        epoch_penalty[i] = penalty.detach().sum()
        penalty_baseline = calculate_baseline(data.edge_index, data.edge_attr, data.batch, data.num_graphs, distance_matrix, samples)
        # Only log_likelihoods_sum carries gradient; the advantage term is a constant weight.
        loss = torch.mean((penalty.detach() - penalty_baseline.detach()) * log_likelihoods_sum)
        optim.zero_grad()
        loss.backward()
        # Clip the global L2 gradient norm to 1 before the update.
        nn.utils.clip_grad_norm_(mpnn_ptr.parameters(), max_norm=1, norm_type=2)
        optim.step()
    # Despite its name this is the epoch total: the sum of all per-batch penalties.
    batch_loss = epoch_penalty.sum().item()
    loss_list_pre.append(batch_loss)
    print('Epoch: {}/{}, Loss: {}'.format(epoch + 1, num_epochs, batch_loss))

In [None]:
import matplotlib.pyplot as plt

In [21]:
# Persist the whole model object (not just a state_dict) to Drive; the
# checkpoint-loading cell reads this same path back with torch.load.
checkpoint_path = './drive/MyDrive/data_MTP/model_64.pt'
torch.save(mpnn_ptr, checkpoint_path)

In [None]:
# Training curve: the per-epoch totals collected in loss_list_pre.
fig_pre, ax_pre = plt.subplots()
ax_pre.plot(loss_list_pre)
ax_pre.set(xlabel='number of epochs', ylabel='communication cost')
# The title call stays last so the cell's displayed value is unchanged.
ax_pre.set_title("communication cost v/s number of epochs")