In [1]:
ROOT_PATH = '../../../../'

import sys
import os
import inspect
from collections import OrderedDict
from functools import partial

sys.path.append(ROOT_PATH)

import pyro
import torch
import torchvision
import glob
import tqdm
import torch.nn.functional as F
import torch_geometric.nn as gnn
import torch_geometric.transforms as gnn_T
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from torch_cluster import grid_cluster
from torch_geometric.data import DataLoader
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.data import Batch, DataLoader

In [2]:
# Run on the first visible CUDA device when available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
%matplotlib inline

## Setup and Data Inspection

In [4]:
# MNIST superpixel graphs, stored under (and read from) `dataset_folder`.
dataset_folder = './MNIST/graphs/'

graph_train_dataset, graph_test_dataset = (
    MNISTSuperpixels(root=dataset_folder, train=is_train)
    for is_train in (True, False)
)

In [5]:
# Transform applied to every classifier batch before the forward pass:
#   - Cartesian(): writes edge features into `edge_attr` (presumably relative
#     node offsets derived from `pos`, matching SplineConv(dim=2) in the
#     classifier — TODO confirm)
#   - ToSparseTensor(remove_edge_index=False): adds a sparse adjacency but keeps
#     `edge_index`, which SplineCNN.forward still reads
#   - ToUndirected(): makes the edge set undirected
# NOTE: `preprocess` is reassigned for the generative model further down.
preprocess = gnn_T.Compose([
    gnn_T.Cartesian(),
    gnn_T.ToSparseTensor(remove_edge_index=False),
    gnn_T.ToUndirected(),
])

## Classifier Model

In [6]:
class SplineCNN(torch.nn.Module):
    """Graph classifier for MNIST superpixel graphs.

    Two stages of SplineConv -> ELU -> dropout -> graclus clustering -> max-pool,
    followed by a global mean pool per graph and a two-layer MLP that outputs
    log-probabilities over 10 classes.

    Args:
        dropout: dropout probability applied after each convolution. It is only
            active in training mode (``model.train()``).
    """

    def __init__(self, dropout=0.5):
        super(SplineCNN, self).__init__()
        self.dropout = dropout
        self.conv1 = gnn.SplineConv(1, 32, dim=2, kernel_size=5)
        self.conv2 = gnn.SplineConv(32, 64, dim=2, kernel_size=5)
        self.fc1 = torch.nn.Linear(64, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def _conv_pool(self, conv, x, edge_index, edge_attr, batch):
        """Apply one conv -> ELU -> dropout -> graclus/max-pool stage.

        Returns the pooled ``(x, edge_index, edge_attr, batch)``.
        """
        x = F.elu(conv(x, edge_index, edge_attr))
        # training=self.training: without it F.dropout is also active in eval mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        cluster = gnn.graclus(edge_index, num_nodes=x.size(0))
        # NOTE(review): the pooled edge_attr is not recomputed from node positions;
        # SplineConv expects pseudo-coordinates in [0, 1] — confirm they stay in range.
        pooled = gnn.max_pool(
            cluster, Batch(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)
        )
        return pooled.x, pooled.edge_index, pooled.edge_attr, pooled.batch

    def forward(self, data):
        """Classify every graph in ``data``.

        Args:
            data: a batch exposing ``x``, ``edge_index``, ``edge_attr`` and ``batch``.

        Returns:
            Log-softmax class scores with one row per graph and 10 columns.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        x, edge_index, edge_attr, batch = self._conv_pool(self.conv1, x, edge_index, edge_attr, batch)
        x, edge_index, edge_attr, batch = self._conv_pool(self.conv2, x, edge_index, edge_attr, batch)

        x = gnn.global_mean_pool(x, batch)
        x = F.elu(self.fc1(x))
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)

In [7]:
# Instantiate the SplineCNN classifier on the selected device and display it.
model = SplineCNN()
model = model.to(device)
model

SplineCNN(
  (conv1): SplineConv(1, 32, dim=2)
  (conv2): SplineConv(32, 64, dim=2)
  (fc1): Linear(in_features=64, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [8]:
# Classifier training configuration.
lr = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

epochs = 0  # the classifier training loop below is currently commented out

# Hold out the last 20% of the official training set for validation.
train_frac = 0.8
train_idx = int(len(graph_train_dataset) * train_frac)
train_split = graph_train_dataset[:train_idx]
val_split = graph_train_dataset[train_idx:]

batch_size = 16
# Shuffle training batches so each epoch sees a different mini-batch order;
# the validation loader keeps a fixed order.
graph_train_loader = DataLoader(train_split, batch_size=batch_size, shuffle=True)
graph_val_loader = DataLoader(val_split, batch_size=batch_size, shuffle=False)

In [9]:
def train_model(model, train_dataset_loader, preprocessor, device, optimizer=None):
    """Run one training epoch of a graph classifier with NLL loss.

    Args:
        model: module mapping a preprocessed batch to log-probabilities.
        train_dataset_loader: iterable of batches, each exposing ``.to`` and ``.y``.
        preprocessor: transform applied to every batch after moving it to ``device``.
        device: device the batches are moved to.
        optimizer: optimiser updating ``model``'s parameters. When omitted, the
            notebook-level ``optimizer`` is used (the previous implicit behaviour).

    Returns:
        The same ``model`` instance, left in training mode.
    """
    if optimizer is None:
        # Backward-compatible fallback to the notebook-global optimiser.
        optimizer = globals()['optimizer']

    model.train()

    for batch in tqdm.tqdm(train_dataset_loader):
        optimizer.zero_grad()
        batch = preprocessor(batch.to(device))
        loss = F.nll_loss(model(batch), batch.y)
        loss.backward()
        optimizer.step()

    return model
    

def validate_model(model, val_dataset_loader, preprocessor, device):
    """Return the classification accuracy of ``model`` over a loader.

    Accuracy is ``correct / total`` over all samples, so a smaller final batch
    is not weighted like a full one. The model is put in eval mode, and
    gradients are disabled while predicting.

    Args:
        model: module mapping a preprocessed batch to per-class scores.
        val_dataset_loader: iterable of batches, each exposing ``.to`` and ``.y``.
        preprocessor: transform applied to every batch after moving it to ``device``.
        device: device the batches are moved to.

    Returns:
        Fraction of correctly classified samples, as a float.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()

    correct, total = 0, 0
    with torch.no_grad():
        for batch in val_dataset_loader:
            batch = preprocessor(batch.to(device))
            pred = model(batch).max(dim=1)[1]
            correct += pred.eq(batch.y).sum().item()
            total += batch.y.size(0)

    if total == 0:
        raise ValueError('validate_model: the loader yielded no samples')
    return correct / total

In [10]:
# for epoch in range(epochs):
    
#     print(f'Epoch: {epoch}')
        
#     train_model(model, graph_train_loader, preprocess, device)
#     val_acc = validate_model(model, graph_val_loader, preprocess, device)
#     train_acc = validate_model(model, graph_train_loader, preprocess, device)
    
#     print(f'Train Acc: {train_acc}, Val Acc: {val_acc}')
#     print()

In [11]:
# graph_test_loader = DataLoader(graph_test_dataset, batch_size=batch_size)
# test_acc = validate_model(model, graph_test_loader, preprocess, device)
# test_acc

## Generative Model

In [314]:
def train_gvae(model, train_dataset_loader, optimiser, preprocessor, device):
    """Run one training epoch of the graph VAE.

    For every batch, the batch is moved to ``device`` and preprocessed, the
    model's ``loss_function`` total (first element of its returned tuple) is
    computed, and one optimiser step is taken.

    Args:
        model: a GraphVAE-like module returning ``(edge_probs, z, mean, log_std)``.
        train_dataset_loader: iterable of graph batches.
        optimiser: optimiser over ``model``'s parameters.
        preprocessor: transform applied to each batch after the device move.
        device: device the batches are moved to.
    """
    model.train()

    for graphs in tqdm.tqdm(train_dataset_loader):
        optimiser.zero_grad()
        graphs = preprocessor(graphs.to(device))

        edge_probs, _, mean, log_std = model(graphs.x, graphs.edge_index, graphs.edge_attr)
        total_loss = model.loss_function(mean, log_std, edge_probs)[0]

        total_loss.backward()
        optimiser.step()

def validate_gvae(model, val_dataset_loader, preprocessor, device):
    """Average the graph-VAE loss terms over every batch of a loader.

    Runs in eval mode without gradients. Each batch contributes equally to
    the averages, whatever its size.

    Returns:
        ``(mean_loss, mean_log_prob, mean_kl)``: the per-batch averages of the
        three values returned by ``model.loss_function``.
    """
    model.eval()

    per_batch_terms = []
    with torch.no_grad():
        for graphs in tqdm.tqdm(val_dataset_loader):
            graphs = preprocessor(graphs.to(device))
            edge_probs, _, mean, log_std = model(graphs.x, graphs.edge_index, graphs.edge_attr)
            per_batch_terms.append(model.loss_function(mean, log_std, edge_probs))

    num_batches = len(per_batch_terms)
    elbos, log_liks, kls = zip(*per_batch_terms) if per_batch_terms else ((), (), ())
    return sum(elbos) / num_batches, sum(log_liks) / num_batches, sum(kls) / num_batches

In [321]:
# GVAE data configuration. This re-splits the training set and rebuilds
# graph_train_loader / graph_val_loader with larger batches, overwriting the
# classifier's loaders.
epochs = 0  # NOTE: the GVAE training loop uses its own epoch count, not this value.

# Hold out the last 20% of the official training set for validation.
train_frac = 0.8
train_idx = int(len(graph_train_dataset) * train_frac)
train_split = graph_train_dataset[:train_idx]
val_split = graph_train_dataset[train_idx:]

batch_size = 512
# Shuffle training batches so each epoch sees a different mini-batch order;
# the validation loader keeps a fixed order.
graph_train_loader = DataLoader(train_split, batch_size=batch_size, shuffle=True)
graph_val_loader = DataLoader(val_split, batch_size=batch_size, shuffle=False)

In [322]:
from torch import Tensor
from torch_geometric.nn import DeepGCNLayer


# Small additive constant keeping torch.log away from log(0) in GraphVAE.loss_function.
LOG_CONST = 1e-15


class SplineConvUnit(torch.nn.Module):
    """A SplineConv layer, optionally followed by ELU and dropout.

    Args:
        input_dim: number of input node features.
        output_dim: number of output node features.
        edge_dim: dimensionality of the edge pseudo-coordinates in ``edge_attr``
            (passed to SplineConv as ``dim``).
        dropout: dropout probability after the activation; only active in
            training mode.
        latent_encoder: if True, return the raw convolution output with no
            activation or dropout. The Encoder uses this for its mean/log-std heads.
    """

    def __init__(self, input_dim, output_dim, edge_dim, dropout=0, latent_encoder=False):
        super(SplineConvUnit, self).__init__()
        self.edge_dim = edge_dim
        self.dropout = dropout
        self.conv = gnn.SplineConv(input_dim, output_dim, dim=edge_dim, kernel_size=5)
        self.latent_encoder = latent_encoder

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> Tensor:
        x = self.conv(x, edge_index, edge_attr)

        if self.latent_encoder:
            return x

        x = F.elu(x)
        # training=self.training: F.dropout defaults to training=True and would
        # otherwise keep dropping units after model.eval().
        return F.dropout(x, p=self.dropout, training=self.training)
    

class GCNConvUnit(torch.nn.Module):
    """A GCNConv layer, optionally followed by ReLU and dropout.

    ``edge_dim`` and ``edge_attr`` are accepted for interface parity with
    SplineConvUnit but are unused: GCNConv is called without edge features.

    Args:
        input_dim: number of input node features.
        output_dim: number of output node features.
        edge_dim: ignored.
        dropout: dropout probability after the activation; only active in
            training mode.
        latent_encoder: if True, return the raw convolution output with no
            activation or dropout. The Encoder uses this for its mean/log-std heads.
    """

    def __init__(self, input_dim, output_dim, edge_dim, dropout=0, latent_encoder=False):
        super(GCNConvUnit, self).__init__()
        self.dropout = dropout
        self.conv = gnn.GCNConv(input_dim, output_dim)
        self.latent_encoder = latent_encoder

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> Tensor:
        x = self.conv(x, edge_index)

        if self.latent_encoder:
            return x

        x = F.relu(x)
        # training=self.training: F.dropout defaults to training=True and would
        # otherwise keep dropping units after model.eval().
        return F.dropout(x, p=self.dropout, training=self.training)


class Encoder(torch.nn.Module):
    """Graph encoder producing per-node Gaussian parameters for a VAE.

    Two hidden convolution units are followed by two parallel heads, ``mean``
    and ``log_std``. The heads are built with ``latent_encoder=True``, so they
    return raw convolution outputs.

    Args:
        conv_class: convolution-unit class, called positionally as
            ``conv_class(input_dim, output_dim, edge_dim, dropout)`` for the
            hidden layers and with keywords (including ``latent_encoder=True``)
            for the heads.
        input_dim: number of input node features.
        hidden_dim1: output size of the first hidden layer.
        hidden_dim2: output size of the second hidden layer.
        latent_dim: size of the latent mean / log-std per node.
        edge_dim: edge-feature dimensionality forwarded to every unit.
        dropout: dropout probability for the hidden layers.
    """

    def __init__(self, conv_class, input_dim, hidden_dim1, hidden_dim2, latent_dim, edge_dim=None, dropout=0):
        super(Encoder, self).__init__()
        self.conv1 = conv_class(input_dim, hidden_dim1, edge_dim, dropout)
        self.conv2 = conv_class(hidden_dim1, hidden_dim2, edge_dim, dropout)

        # Both heads share one configuration; edge_dim is forwarded for every conv_class.
        head_kwargs = {
            'input_dim': hidden_dim2,
            'output_dim': latent_dim,
            'edge_dim': edge_dim,
            'latent_encoder': True,
        }
        self.mean = conv_class(**head_kwargs)
        self.log_std = conv_class(**head_kwargs)

        self.edge_dim = edge_dim
        self.dropout = dropout
        # Stored so get_output_dim() can report the latent size.
        self.output_dim = latent_dim

    def get_output_dim(self):
        """Return the latent dimensionality per node."""
        return self.output_dim

    def get_edge_dim(self):
        """Return the edge-feature dimensionality passed at construction."""
        return self.edge_dim

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> [Tensor, Tensor]:
        """Return ``(mean, log_std)`` for every node."""
        x = self.conv1(x, edge_index, edge_attr)
        x = self.conv2(x, edge_index, edge_attr)
        mean = self.mean(x, edge_index, edge_attr)
        log_std = self.log_std(x, edge_index, edge_attr)
        return mean, log_std
    

class GraphVAE(torch.nn.Module):
    """Variational graph auto-encoder built from an encoder/decoder pair.

    Adapted from
    https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py

    The encoder maps ``(x, edge_index, edge_attr)`` to ``(mean, log_std)``.
    A latent ``z`` is drawn with the reparametrisation trick. The decoder maps
    ``(z, edge_index)`` to one score per queried edge, which ``loss_function``
    treats as an edge probability.
    """

    def __init__(self, encoder, decoder):
        super(GraphVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def encode(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> tuple:
        """Return ``(mean, log_std)`` from the encoder."""
        return self.encoder(x, edge_index, edge_attr)

    def decode(self, z: Tensor, edge_index: Tensor) -> Tensor:
        """Score the edges in ``edge_index`` from latent node embeddings ``z``."""
        return self.decoder(z, edge_index)

    def reparametrise(self, mean: Tensor, log_std: Tensor) -> Tensor:
        """Draw ``z = mean + eps * exp(log_std)`` with ``eps ~ N(0, I)``.

        Uses a single Monte-Carlo sample.
        """
        std = torch.exp(log_std)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> tuple:
        """Encode, sample and decode.

        Returns:
            ``(edge_probs, z, mean, log_std)``.
        """
        # Encode exactly once; a second, discarded encoder call only doubled the cost.
        mean, log_std = self.encode(x, edge_index, edge_attr)
        z = self.reparametrise(mean, log_std)
        return self.decode(z, edge_index), z, mean, log_std

    def generate(self, x: Tensor, edge_index: Tensor = None, edge_attr: Tensor = None) -> Tensor:
        """Return the decoded edge probabilities for the edges in ``edge_index``.

        ``forward`` needs the graph connectivity, so ``edge_index`` is required.
        Its ``None`` default only keeps the old ``generate(x)`` call shape
        valid; such a call gets a clear error instead of a TypeError.

        Raises:
            ValueError: if ``edge_index`` is not given.
        """
        if edge_index is None:
            raise ValueError('generate() needs the edge_index of the edges to score')
        return self.forward(x, edge_index, edge_attr)[0]

    def loss_function(self, mean: Tensor, log_std: Tensor, pos_edge_probs: Tensor,
                      neg_edge_probs: Tensor = None) -> tuple:
        """Negative ELBO: reconstruction NLL plus KL(q(z | x) || N(0, I)).

        Args:
            mean: latent means, one row per node.
            log_std: latent log standard deviations, same shape as ``mean``.
            pos_edge_probs: predicted probabilities for edges that exist.
            neg_edge_probs: optional predicted probabilities for sampled
                non-edges. When given, their ``-log(1 - p)`` term is added.

        Returns:
            ``(loss, log_prob, kl_loss)``. ``log_prob`` is the reconstruction
            negative log-likelihood; it is non-negative despite its name.
        """
        # NOTE(review): with only positive edges this term is minimised by
        # predicting every edge as present; pass neg_edge_probs to counter that.
        log_prob = -torch.log(pos_edge_probs + LOG_CONST).mean()
        if neg_edge_probs is not None:
            log_prob = log_prob - torch.log(1 - neg_edge_probs + LOG_CONST).mean()

        # Closed-form Gaussian KL, summed over latent dims and averaged over nodes.
        kl_loss = torch.sum(1 + 2 * log_std - mean ** 2 - log_std.exp() ** 2, dim=1)
        kl_loss = -0.5 * torch.mean(kl_loss)

        loss = log_prob + kl_loss

        return loss, log_prob, kl_loss

In [323]:
# Generative-model preprocessing (replaces the classifier's `preprocess`):
# a single TargetIndegree transform that sets the one-column `edge_attr`.
preprocess = gnn_T.Compose([gnn_T.TargetIndegree()])

In [324]:
from torch_geometric.nn import InnerProductDecoder

# GVAE hyper-parameters.
input_dim, hidden1, hidden2 = 1, 32, 16
edge_dim, latent_dim, dropout = 1, 10, 0

# GCNConvUnit ignores edge features, so edge_dim only matters for SplineConvUnit.
conv_class = GCNConvUnit
encoder = Encoder(
    conv_class, input_dim, hidden1, hidden2, latent_dim,
    edge_dim=edge_dim, dropout=dropout,
).to(device)
decoder = InnerProductDecoder().to(device)
gvae = GraphVAE(encoder=encoder, decoder=decoder).to(device)

gvae

GraphVAE(
  (encoder): Encoder(
    (conv1): GCNConvUnit(
      (conv): GCNConv(1, 32)
    )
    (conv2): GCNConvUnit(
      (conv): GCNConv(32, 16)
    )
    (mean): GCNConvUnit(
      (conv): GCNConv(16, 10)
    )
    (log_std): GCNConvUnit(
      (conv): GCNConv(16, 10)
    )
  )
  (decoder): InnerProductDecoder()
)

In [325]:
# Smoke test: push the first 3 training graphs through the untrained GVAE and
# print the input/output shapes and the loss terms.
graph_trial_loader = DataLoader(graph_train_dataset, batch_size=3, shuffle=False)
x = next(iter(graph_trial_loader))

gvae.eval()

_data = preprocess(x.to(device))

print('Data shape')
print(_data)
print(_data.x.shape, _data.edge_index.shape, _data.edge_attr.shape)
print()

edge_probs, z, mean, log_std = gvae(_data.x, _data.edge_index, _data.edge_attr)
print('Output shape')
print(edge_probs.shape, z.shape, mean.shape, log_std.shape)
print()

print(gvae.loss_function(mean, log_std, edge_probs))

Data shape
Batch(batch=[225], edge_attr=[4043, 1], edge_index=[2, 4043], pos=[225, 2], x=[225, 1], y=[3])
torch.Size([225, 1]) torch.Size([2, 4043]) torch.Size([4043, 1])

Output shape
torch.Size([4043]) torch.Size([225, 10]) torch.Size([225, 10]) torch.Size([225, 10])

(tensor(1.5489, device='cuda:0', grad_fn=<AddBackward0>), tensor(1.5436, device='cuda:0', grad_fn=<NegBackward>), tensor(0.0053, device='cuda:0', grad_fn=<MulBackward0>))


In [326]:
# Train the GVAE and, after every epoch, report the loss / reconstruction / KL
# averages on both the validation and the training split.
optimiser_gvae = torch.optim.Adam(gvae.parameters(), lr=0.0001)

for epoch in range(5):
    print(f'Epoch: {epoch}')

    print('Training')
    train_gvae(gvae, graph_train_loader, optimiser_gvae, preprocess, device)

    print('Validation - Val set')
    val_metrics = validate_gvae(gvae, graph_val_loader, preprocess, device)
    mean_val_elbo, mean_val_log_lik, mean_val_total_kl = val_metrics

    print('Validation - Train set')
    train_metrics = validate_gvae(gvae, graph_train_loader, preprocess, device)
    mean_train_elbo, mean_train_log_lik, mean_train_total_kl = train_metrics

    print(f'Train Elbo: {mean_train_elbo}, Val Elbo: {mean_val_elbo}')
    print(f'Train Log Lik: {mean_train_log_lik}, Val Log Lik: {mean_val_log_lik}')
    print(f'Train KL: {mean_train_total_kl}, Val KL: {mean_val_total_kl}')
    print()

  2%|▏         | 2/94 [00:00<00:07, 12.06it/s]

Epoch: 0
Training


100%|██████████| 94/94 [00:07<00:00, 13.24it/s]
  8%|▊         | 2/24 [00:00<00:01, 16.56it/s]

Validation - Val set


100%|██████████| 24/24 [00:01<00:00, 15.80it/s]
  2%|▏         | 2/94 [00:00<00:05, 17.97it/s]

Validation - Train set


100%|██████████| 94/94 [00:05<00:00, 17.76it/s]
  2%|▏         | 2/94 [00:00<00:06, 14.04it/s]

Train Elbo: 1.3891448974609375, Val Elbo: 1.3890862464904785
Train Log Lik: 1.3729982376098633, Val Log Lik: 1.3730144500732422
Train KL: 0.016146957874298096, Val KL: 0.016071707010269165

Epoch: 1
Training


100%|██████████| 94/94 [00:07<00:00, 13.24it/s]
  8%|▊         | 2/24 [00:00<00:01, 16.39it/s]

Validation - Val set


100%|██████████| 24/24 [00:01<00:00, 18.14it/s]
  2%|▏         | 2/94 [00:00<00:05, 18.20it/s]

Validation - Train set


100%|██████████| 94/94 [00:05<00:00, 18.09it/s]
  2%|▏         | 2/94 [00:00<00:06, 14.44it/s]

Train Elbo: 1.360724687576294, Val Elbo: 1.361077070236206
Train Log Lik: 1.327467441558838, Val Log Lik: 1.3279004096984863
Train KL: 0.03325744345784187, Val KL: 0.03317663446068764

Epoch: 2
Training


100%|██████████| 94/94 [00:07<00:00, 13.31it/s]
  8%|▊         | 2/24 [00:00<00:01, 16.28it/s]

Validation - Val set


100%|██████████| 24/24 [00:01<00:00, 18.10it/s]
  2%|▏         | 2/94 [00:00<00:04, 18.46it/s]

Validation - Train set


100%|██████████| 94/94 [00:05<00:00, 17.21it/s]
  2%|▏         | 2/94 [00:00<00:06, 14.54it/s]

Train Elbo: 1.3393208980560303, Val Elbo: 1.3388365507125854
Train Log Lik: 1.284861445426941, Val Log Lik: 1.2844533920288086
Train KL: 0.054458584636449814, Val KL: 0.05438306927680969

Epoch: 3
Training


100%|██████████| 94/94 [00:07<00:00, 13.27it/s]
  8%|▊         | 2/24 [00:00<00:01, 16.49it/s]

Validation - Val set


100%|██████████| 24/24 [00:01<00:00, 18.18it/s]
  2%|▏         | 2/94 [00:00<00:05, 17.96it/s]

Validation - Train set


100%|██████████| 94/94 [00:05<00:00, 18.04it/s]
  2%|▏         | 2/94 [00:00<00:06, 14.76it/s]

Train Elbo: 1.3318439722061157, Val Elbo: 1.3324756622314453
Train Log Lik: 1.2642377614974976, Val Log Lik: 1.2649564743041992
Train KL: 0.067606121301651, Val KL: 0.06751897931098938

Epoch: 4
Training


100%|██████████| 94/94 [00:07<00:00, 13.32it/s]
  8%|▊         | 2/24 [00:00<00:01, 16.25it/s]

Validation - Val set


100%|██████████| 24/24 [00:01<00:00, 18.05it/s]
  2%|▏         | 2/94 [00:00<00:05, 18.04it/s]

Validation - Train set


100%|██████████| 94/94 [00:05<00:00, 17.96it/s]

Train Elbo: 1.3269635438919067, Val Elbo: 1.3278532028198242
Train Log Lik: 1.2514976263046265, Val Log Lik: 1.2524890899658203
Train KL: 0.07546579092741013, Val KL: 0.07536416500806808




