In [416]:
ROOT_PATH = '../../../../'

import sys
import os
import inspect
from collections import OrderedDict
from functools import partial

sys.path.append(ROOT_PATH)

import pyro
import torch
import torchvision
import glob
import tqdm
import torch.nn.functional as F
import torch_geometric.nn as gnn
import torch_geometric.transforms as gnn_T
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from torch_cluster import grid_cluster
from torch_geometric.data import DataLoader
from torch_geometric.datasets import MNISTSuperpixels, CoraFull
from torch_geometric.data import Batch, DataLoader

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
%matplotlib inline

## Setup and Data Inspection

In [4]:
dataset_folder = './MNIST/graphs/'

graph_train_dataset = MNISTSuperpixels(root=dataset_folder, train=True)
graph_test_dataset = MNISTSuperpixels(root=dataset_folder, train=False)

In [5]:
preprocess = gnn_T.Compose([
    gnn_T.Cartesian(),
    gnn_T.ToSparseTensor(remove_edge_index=False),
    gnn_T.ToUndirected(),
])

## Classifier Model

In [6]:
class SplineCNN(torch.nn.Module):
    def __init__(self):
        super(SplineCNN, self).__init__()
        self.conv1 = gnn.SplineConv(1, 32, dim=2, kernel_size=5)
        self.conv2 = gnn.SplineConv(32, 64, dim=2, kernel_size=5)
        self.fc1 = torch.nn.Linear(64, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        
        x = F.elu(self.conv1(x, edge_index, edge_attr))
        x = F.dropout(x, p=0.5)
        cluster = gnn.graclus(edge_index, num_nodes=x.size(0))
        data = Batch(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)
        data = gnn.max_pool(cluster, data)
        
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        
        x = F.elu(self.conv2(x, edge_index, edge_attr))
        x = F.dropout(x, p=0.5)
        cluster = gnn.graclus(edge_index, num_nodes=x.size(0))
        data = Batch(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)
        data = gnn.max_pool(cluster, data)
        
        x, batch = data.x, data.batch
        
        x = gnn.global_mean_pool(x, batch)  
        x = F.elu(self.fc1(x))
        x = self.fc2(x)
        
        return F.log_softmax(x, dim=1)

In [7]:
# SplineCNN
model = SplineCNN().to(device)
model

SplineCNN(
  (conv1): SplineConv(1, 32, dim=2)
  (conv2): SplineConv(32, 64, dim=2)
  (fc1): Linear(in_features=64, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [8]:
lr = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

epochs = 0

train_frac = 0.8
train_idx = int(len(graph_train_dataset) * train_frac)
train_split = graph_train_dataset[:train_idx]
val_split = graph_train_dataset[train_idx:]

batch_size = 16
graph_train_loader = DataLoader(train_split, batch_size=batch_size, shuffle=False)
graph_val_loader = DataLoader(val_split, batch_size=batch_size, shuffle=False)

In [9]:
def train_model(model, train_dataset_loader, preprocessor, device):
    model.train()
    
    for batch in tqdm.tqdm(train_dataset_loader):
        optimizer.zero_grad()
        batch = preprocessor(batch.to(device))
        out = model(batch)
        loss = F.nll_loss(out, batch.y)
        loss.backward()
        optimizer.step()
    
    return model
    

def validate_model(model, val_dataset_loader, preprocessor, device):
    model.eval()
    
    val_acc = []
    
    for batch in val_dataset_loader:
        batch = preprocessor(batch.to(device))
        pred_logits = model(batch)
        _, pred = pred_logits.max(dim=1)
        acc = (pred.eq(batch.y).sum() / batch.y.size(0)).item()
        val_acc.append(acc)
    
    return sum(val_acc) / len(val_acc)

In [10]:
# for epoch in range(epochs):
    
#     print(f'Epoch: {epoch}')
        
#     train_model(model, graph_train_loader, preprocess, device)
#     val_acc = validate_model(model, graph_val_loader, preprocess, device)
#     train_acc = validate_model(model, graph_train_loader, preprocess, device)
    
#     print(f'Train Acc: {train_acc}, Val Acc: {val_acc}')
#     print()

In [11]:
# graph_test_loader = DataLoader(graph_test_dataset, batch_size=batch_size)
# test_acc = validate_model(model, graph_test_loader, preprocess, device)
# test_acc

## Generative Model

In [390]:
from sklearn.metrics import average_precision_score, accuracy_score

def train_gvae(model, train_dataset_loader, optimiser, preprocessor, device):
    model.train()
    
    for batch in tqdm.tqdm(train_dataset_loader):
        optimiser.zero_grad()
        batch = preprocessor(batch.to(device))
        edge_probs, z, mean, log_std = model(batch.x, batch.edge_index, batch.edge_attr)
        loss, _, _ = model.loss_function(mean, log_std, edge_probs)
        loss.backward()
        optimiser.step()

def validate_gvae(model, val_dataset_loader, preprocessor, device):
    model.eval()
    
    with torch.no_grad():

        val_elbos = []
        log_lik = []
        kl_losses = []
        average_precisions = []
        accuracies = []

        for batch in tqdm.tqdm(val_dataset_loader):
            # Forward Pass
            batch = preprocessor(batch.to(device))
            edge_probs, z, mean, log_std = model(batch.x, batch.edge_index, batch.edge_attr)
            elbo, log_prob, kl_loss = model.loss_function(mean, log_std, edge_probs)
            # Accumulate Metrics
            val_elbos.append(elbo)
            log_lik.append(log_prob)
            kl_losses.append(kl_loss)
            ones = torch.ones_like(edge_probs.cpu())
            ap = average_precision_score(ones, edge_probs.cpu())
            average_precisions.append(ap)
            mask = edge_probs > 0.5
            edge_probs[mask] = 1
            edge_probs[~mask] = 0
            acc = accuracy_score(ones, edge_probs.cpu())
            accuracies.append(acc)
            
    total_elbo = sum(val_elbos) / len(val_elbos)
    total_log_lik = sum(log_lik) / len(log_lik)
    total_kl = sum(kl_losses) / len(kl_losses)
    ap = sum(average_precisions) / len(average_precisions)
    acc = sum(accuracies) / len(accuracies)
    
    return total_elbo, total_log_lik, total_kl, ap, acc

In [443]:
graph_train_dataset = CoraFull(root=dataset_folder)  # , train=True)
graph_test_dataset = CoraFull(root=dataset_folder)  # , train=False)

# graph_train_dataset = MNISTSuperpixels(root=dataset_folder, train=True)
# graph_test_dataset = MNISTSuperpixels(root=dataset_folder, train=False)

In [444]:
epochs = 0

train_frac = 0.8
train_idx = int(len(graph_train_dataset) * train_frac)
train_split = graph_train_dataset[:train_idx]
val_split = graph_train_dataset[train_idx:]

batch_size = 512
graph_train_loader = DataLoader(train_split, batch_size=batch_size, shuffle=False)
graph_val_loader = DataLoader(val_split, batch_size=batch_size, shuffle=False)

In [445]:
from torch import Tensor
from torch_geometric.nn import DeepGCNLayer


LOG_CONST = 1e-15


class SplineConvUnit(torch.nn.Module):
    
    def __init__(self, input_dim, output_dim, edge_dim, dropout=0, latent_encoder=False):
        super(SplineConvUnit, self).__init__()
        self.edge_dim = edge_dim
        self.dropout = dropout
        self.conv = gnn.SplineConv(input_dim, output_dim, dim=edge_dim, kernel_size=5)
        self.latent_encoder = latent_encoder
        
    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> Tensor:
        x = self.conv(x, edge_index, edge_attr)
        
        if self.latent_encoder:
            return x
        
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout)
        
        return x
    

class GCNConvUnit(torch.nn.Module):
    
    def __init__(self, input_dim, output_dim, edge_dim, dropout=0, latent_encoder=False):
        super(GCNConvUnit, self).__init__()
        self.dropout = dropout
        self.conv = gnn.GCNConv(input_dim, output_dim)
        self.latent_encoder = latent_encoder
        
    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> Tensor:
        x = self.conv(x, edge_index)
        
        if self.latent_encoder:
            return x
        
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout)
        
        return x
    

class CoMAUnit(torch.nn.Module):
    
    def __init__(self, input_dim, output_dim, edge_dim, dropout=0, latent_encoder=False):
        super(CoMAUnit, self).__init__()
        self.dropout = dropout
        self.conv = gnn.GCNConv(input_dim, output_dim)
        self.latent_encoder = latent_encoder
        
    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> Tensor:
        x = self.conv(x, edge_index)
        
        if self.latent_encoder:
            return x
        
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout)
        
        return x


class Encoder(torch.nn.Module):
    
    def __init__(self, conv_class, input_dim, hidden_dim1, hidden_dim2, latent_dim, edge_dim=None, dropout=0):
        super(Encoder, self).__init__()
        self.conv1 = conv_class(input_dim, hidden_dim1, edge_dim, dropout)
        self.conv2 = conv_class(hidden_dim1, hidden_dim2, edge_dim, dropout)
        
        kwargs = {
            'input_dim': hidden_dim2,
            'output_dim': latent_dim,
            'edge_dim': edge_dim,
            'latent_encoder': True,
        }
        
        if conv_class == SplineConvUnit:
            kwargs['edge_dim'] = edge_dim
        
        self.mean = conv_class(**kwargs)
        self.log_std = conv_class(**kwargs)
        self.edge_dim = edge_dim
        self.dropout = dropout
    
    def get_output_dim(self):
        return self.output_dim
    
    def get_edge_dim(self):
        return self.edge_dim

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> [Tensor, Tensor]:
        x = self.conv1(x, edge_index, edge_attr)
        x = self.conv2(x, edge_index, edge_attr)
        mean = self.mean(x, edge_index, edge_attr)
        log_std = self.log_std(x, edge_index, edge_attr)
        return mean, log_std
    

class GraphVAE(torch.nn.Module):
    """
    https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py
    """
    
    def __init__(self, encoder, decoder):
        super(GraphVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def encode(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> [Tensor, Tensor]:
        return self.encoder(x, edge_index, edge_attr)
    
    def decode(self, z: Tensor, edge_index: Tensor) -> Tensor:
        return self.decoder(z, edge_index)
    
    def reparametrise(self, mean: Tensor, log_std: Tensor) -> Tensor:
        std = torch.exp(log_std)
        
        # Note: Just 1 MC particle
        eps = torch.randn_like(std)
        z = mean + eps * std
        
        return z
    
    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> [Tensor, Tensor, Tensor, Tensor]:
        res = self.encode(x, edge_index, edge_attr)
        mean, log_std = self.encode(x, edge_index, edge_attr)
        z = self.reparametrise(mean, log_std)
        return self.decode(z, edge_index), z, mean, log_std

    def generate(self, x: Tensor) -> Tensor:
        return self.forward(x)[0]
    
    def loss_function(self, mean: Tensor, log_std: Tensor, pos_edge_probs: Tensor) -> Tensor:
        pos_log_prob = -torch.log(pos_edge_probs + LOG_CONST).mean()
        # neg_log_prob = -torch.log(1 - neg_edge_probs + LOG_CONST).mean()
        log_prob = pos_log_prob  # + neg_log_prob
        
        kl_loss = torch.sum(1 + 2 * log_std - mean ** 2 - log_std.exp() ** 2, dim=1)
        kl_loss = -0.5 * torch.mean(kl_loss)
        
        loss = log_prob + kl_loss
        
        return loss, log_prob, kl_loss

In [446]:
preprocess = gnn_T.Compose([
    gnn_T.TargetIndegree()
])

In [447]:
from torch_geometric.nn import InnerProductDecoder

input_dim = 1
hidden1 = 64
hidden2 = 32
edge_dim = 2
latent_dim = 16
dropout = 0.5

conv_class = GCNConvUnit
encoder = Encoder(conv_class, input_dim, hidden1, hidden2, latent_dim, edge_dim, dropout).to(device)
decoder = InnerProductDecoder().to(device)
gvae = GraphVAE(encoder, decoder).to(device)

gvae

GraphVAE(
  (encoder): Encoder(
    (conv1): GCNConvUnit(
      (conv): GCNConv(1, 64)
    )
    (conv2): GCNConvUnit(
      (conv): GCNConv(64, 32)
    )
    (mean): GCNConvUnit(
      (conv): GCNConv(32, 16)
    )
    (log_std): GCNConvUnit(
      (conv): GCNConv(32, 16)
    )
  )
  (decoder): InnerProductDecoder()
)

In [None]:
Data shape
Batch(batch=[19793], edge_attr=[126842, 1], edge_index=[2, 126842], x=[19793, 8710], y=[19793])
torch.Size([19793, 8710]) torch.Size([2, 126842]) torch.Size([126842, 1])

In [451]:
graph_trial_loader = DataLoader(graph_train_dataset, batch_size=1, shuffle=False)

for x in graph_trial_loader:
    break

gvae.eval()
    
_data = preprocess(x.to(device))

print('Data shape')
print(_data)
print(_data.x.shape, _data.edge_index.shape, _data.edge_attr.shape)
print()

edge_probs, z, mean, log_std = gvae(_data.x, _data.edge_index, _data.edge_attr)
print('Output shape')
print(edge_probs.shape, z.shape, mean.shape, log_std.shape)
print()

print(gvae.loss_function(mean, log_std, edge_probs))
# Calculate accuracy + precision etc. using 1 as the target

Data shape
Batch(batch=[19793], edge_attr=[126842, 1], edge_index=[2, 126842], x=[19793, 8710], y=[19793])
torch.Size([19793, 8710]) torch.Size([2, 126842]) torch.Size([126842, 1])



RuntimeError: mat1 dim 1 must match mat2 dim 0

In [454]:
_data.batch

tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0')

In [413]:
_x, _batch = to_dense_batch(_data.x, _data.batch)
_x.shape, _batch.shape

(torch.Size([3, 75, 1]), torch.Size([3, 75]))

In [414]:
optimiser_gvae = torch.optim.Adam(gvae.parameters(), lr=0.001)

for epoch in range(5):
    
    print(f'Epoch: {epoch}')
    print('Training')
    train_gvae(gvae, graph_train_loader, optimiser_gvae, preprocess, device)
    print('Validation - Val set')
    mean_val_elbo, mean_val_log_lik, mean_val_total_kl, val_ap, val_acc = validate_gvae(gvae, graph_val_loader, preprocess, device)
    print('Validation - Train set')
    mean_train_elbo, mean_train_log_lik, mean_train_total_kl, train_ap, train_acc = validate_gvae(gvae, graph_train_loader, preprocess, device)
    
    print(f'Train Elbo: {mean_train_elbo}, Val Elbo: {mean_val_elbo}')
    print(f'Train Log Lik: {mean_train_log_lik}, Val Log Lik: {mean_val_log_lik}')
    print(f'Train KL: {mean_train_total_kl}, Val KL: {mean_val_total_kl}')
    print(f'Train AP: {train_ap}, Val AP: {val_ap}')
    print(f'Train Acc: {train_acc}, Val Acc: {val_acc}')
    print()

  2%|▏         | 2/94 [00:00<00:08, 10.83it/s]

Epoch: 0
Training


100%|██████████| 94/94 [00:08<00:00, 10.95it/s]
  0%|          | 0/24 [00:00<?, ?it/s]

Validation - Val set


100%|██████████| 24/24 [00:06<00:00,  3.47it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Validation - Train set


100%|██████████| 94/94 [00:26<00:00,  3.51it/s]
  2%|▏         | 2/94 [00:00<00:08, 11.40it/s]

Train Elbo: 1.6104369163513184, Val Elbo: 1.60963773727417
Train Log Lik: 1.5065380334854126, Val Log Lik: 1.5062659978866577
Train KL: 0.10389874875545502, Val KL: 0.1033717542886734
Train AP: 1.0, Val AP: 1.0
Train Acc: 0.5018537162577388, Val Acc: 0.5019539535957878

Epoch: 1
Training


100%|██████████| 94/94 [00:08<00:00, 11.17it/s]
  0%|          | 0/24 [00:00<?, ?it/s]

Validation - Val set


100%|██████████| 24/24 [00:06<00:00,  3.69it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Validation - Train set


100%|██████████| 94/94 [00:26<00:00,  3.56it/s]
  2%|▏         | 2/94 [00:00<00:08, 11.27it/s]

Train Elbo: 1.6102389097213745, Val Elbo: 1.609022617340088
Train Log Lik: 1.5049399137496948, Val Log Lik: 1.50433349609375
Train KL: 0.10529859364032745, Val KL: 0.1046890914440155
Train AP: 1.0, Val AP: 1.0
Train Acc: 0.5016965414825028, Val Acc: 0.5018609673086637

Epoch: 2
Training


100%|██████████| 94/94 [00:08<00:00, 11.11it/s]
  0%|          | 0/24 [00:00<?, ?it/s]

Validation - Val set


100%|██████████| 24/24 [00:06<00:00,  3.63it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Validation - Train set


100%|██████████| 94/94 [00:26<00:00,  3.58it/s]
  2%|▏         | 2/94 [00:00<00:08, 11.34it/s]

Train Elbo: 1.6090431213378906, Val Elbo: 1.6101363897323608
Train Log Lik: 1.5049176216125488, Val Log Lik: 1.5067613124847412
Train KL: 0.10412514209747314, Val KL: 0.10337506234645844
Train AP: 1.0, Val AP: 1.0
Train Acc: 0.501633845874357, Val Acc: 0.5016592835883632

Epoch: 3
Training


100%|██████████| 94/94 [00:08<00:00, 11.00it/s]
  0%|          | 0/24 [00:00<?, ?it/s]

Validation - Val set


100%|██████████| 24/24 [00:06<00:00,  3.68it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Validation - Train set


100%|██████████| 94/94 [00:26<00:00,  3.60it/s]
  2%|▏         | 2/94 [00:00<00:09, 10.13it/s]

Train Elbo: 1.6082805395126343, Val Elbo: 1.6081209182739258
Train Log Lik: 1.5043376684188843, Val Log Lik: 1.5050033330917358
Train KL: 0.10394272208213806, Val KL: 0.10311751067638397
Train AP: 1.0, Val AP: 1.0
Train Acc: 0.501660929165506, Val Acc: 0.5017431790399384

Epoch: 4
Training


100%|██████████| 94/94 [00:08<00:00, 10.96it/s]
  0%|          | 0/24 [00:00<?, ?it/s]

Validation - Val set


100%|██████████| 24/24 [00:06<00:00,  3.67it/s]
  0%|          | 0/94 [00:00<?, ?it/s]

Validation - Train set


100%|██████████| 94/94 [00:26<00:00,  3.53it/s]

Train Elbo: 1.6086026430130005, Val Elbo: 1.6081657409667969
Train Log Lik: 1.5013962984085083, Val Log Lik: 1.5016707181930542
Train KL: 0.10720643401145935, Val KL: 0.10649512708187103
Train AP: 1.0, Val AP: 1.0
Train Acc: 0.5013872670926388, Val Acc: 0.5014801099262081




