In [1]:
import json
import numpy as np
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt

import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

from Utilities import score
from Utilities import plot_matrix_runs, plot_results
from Utilities import number_of_neighbours, PairData, prepare_dataloader_distance

from training import training_loop

from models import GCN_pairs_distance

## Exploration of extracted homomorphism counts

In [None]:
def load_homson(path):
    """Parse one `.homson` file and return the decoded JSON object.

    The file is read with ``json.load``, so its content must be valid JSON.
    UTF-8 is requested explicitly so decoding does not depend on the
    platform's locale encoding.
    """
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def counts_matrix(homson):
    """Collect the 'counts' entry of every record in ``homson['data']`` as a float array."""
    return np.array([element['counts'] for element in homson['data']], dtype='float')


# Both runs go through the same helpers so they cannot drift apart.
data_run1 = load_homson('data/homomorphism_counts/MUTAG_full_kernel_max_20_run1.homson')
data_run2 = load_homson('data/homomorphism_counts/MUTAG_full_kernel_max_20_run2.homson')

# To extract the homomorphism counts for each of the embeddings
hom_counts_list_run1 = counts_matrix(data_run1)
hom_counts_list_run2 = counts_matrix(data_run2)

In [None]:
# Pairwise L1 (cityblock) distances between the count vectors of each run,
# computed separately for run 1 and run 2.
L1_run1, L1_run2 = (
    cdist(counts, counts, metric='cityblock')
    for counts in (hom_counts_list_run1, hom_counts_list_run2)
)

# We can see some inconsistent scales among different runs
plot_matrix_runs(L1_run1, L1_run2, num_elements=30)

### Check whether a GNN picks up something interesting

In [None]:
# Use the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# MUTAG from the TU collection; `number_of_neighbours` is passed as the
# pre_transform and the dataset is stored under /tmp/MUTAG_transformed.
dataset = TUDataset(
    root='/tmp/MUTAG_transformed',
    name='MUTAG',
    pre_transform=number_of_neighbours,
)

In [None]:
# Build the train / validation / test loaders from the run-1 homomorphism
# counts and the MUTAG graphs, using the L1 distance setting.
(train_loader, val_loader, test_loader) = prepare_dataloader_distance(
    hom_counts_list_run1,
    dataset,
    batch_size=32,
    dist='L1',
    device=device,
)

In [None]:
# Pair-distance GCN configured for the L1 setting, moved onto the chosen device.
model = GCN_pairs_distance(
    input_features=dataset.num_node_features,
    hidden_channels=64,
    output_embeddings=300,
    name='GCN3_L1.2',
    dist='L1',
).to(device)

# Adam over all model parameters; the training objective is mean squared error.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

print(model)

In [None]:
# Train for 100 epochs on the training loader; the validation loader is
# passed alongside it.
training_loop(
    model,
    train_loader,
    optimizer,
    criterion,
    val_loader,
    epoch_number=100,
)

## Inference step

In [None]:
# Rebuild the same architecture as in the training cell so the saved
# state dict's keys and shapes line up.
model = GCN_pairs_distance(
    input_features=dataset.num_node_features,
    hidden_channels=64,
    output_embeddings=300,
    name='GCN3_L1.2',
    dist='L1',
).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was written on a CUDA machine still loads on a CPU-only one; without it
# torch.load tries to restore tensors to the device they were saved from.
# NOTE(review): presumably this checkpoint is written by training_loop — confirm.
# NOTE(review): the model is not switched to eval mode here; confirm score() does so.
model.load_state_dict(torch.load("models/GCN3_L1.2.pt", map_location=device))

In [None]:
# Run the reloaded model over the validation loader and keep the two
# returned values: the targets and the model's predictions.
# NOTE(review): this scores val_loader; test_loader is created earlier but
# is not used in the visible cells — confirm which split is intended.
(y, predictions) = score(model, val_loader)

In [None]:
# Plot targets against predictions, passing subset=50 to the plotting helper.
plot_results(
    y,
    predictions,
    subset=50,
)