In [1]:
ROOT_PATH = '../../../../'

import sys
import os
import inspect
from collections import OrderedDict
from functools import partial
from typing import Tuple

sys.path.append(ROOT_PATH)

import pyro
import torch
import torchvision
import glob
import tqdm
import torch.nn.functional as F
import torch_geometric.nn as gnn
import torch_geometric.transforms as gnn_T
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from torch_cluster import grid_cluster
from torch_geometric.data import DataLoader
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.data import Batch, DataLoader

In [2]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
%matplotlib inline

## Setup and Data Inspection

In [4]:
# Root directory handed to MNISTSuperpixels for both the train and test splits.
dataset_folder = './MNIST/graphs/'

graph_train_dataset, graph_test_dataset = (
    MNISTSuperpixels(root=dataset_folder, train=use_train_split)
    for use_train_split in (True, False)
)

In [5]:
# Per-batch preprocessing applied to every batch in the training/validation code.
# The order of the transforms matters:
#   1. ToUndirected first, so that any reverse edges it adds get their own
#      pseudo-coordinates in step 2 instead of inheriting (or being merged with)
#      the forward edge's attributes.
#   2. Cartesian then derives 2-D edge pseudo-coordinates from `pos` for every
#      edge; SplineConv(dim=2) consumes them as `edge_attr`.
#   3. ToSparseTensor last, so `adj_t` describes the final edge set; `edge_index`
#      is kept because the models read it directly.
# NOTE(review): because this runs on whole batches, any normalisation Cartesian
# applies is computed over the batch rather than per graph; consider passing the
# transform as the dataset's `transform=` instead - confirm before changing.
preprocess = gnn_T.Compose([
    gnn_T.ToUndirected(),
    gnn_T.Cartesian(),
    gnn_T.ToSparseTensor(remove_edge_index=False),
])

## Classifier Model

In [6]:
class SplineCNN(torch.nn.Module):
    """Graph classifier for MNIST superpixel graphs (10 classes).

    Two stages of SplineConv -> ELU -> dropout -> graclus max-pooling, then a
    global mean pool per graph and a two-layer MLP. ``forward`` returns
    log-probabilities, matching the ``F.nll_loss`` used for training.
    """

    def __init__(self):
        super(SplineCNN, self).__init__()
        self.conv1 = gnn.SplineConv(1, 32, dim=2, kernel_size=5)
        self.conv2 = gnn.SplineConv(32, 64, dim=2, kernel_size=5)
        self.fc1 = torch.nn.Linear(64, 128)
        self.fc2 = torch.nn.Linear(128, 10)
        # Pooling merges nodes into clusters, so the incoming edge
        # pseudo-coordinates no longer describe the pooled edges. Recompute
        # them from the pooled node positions. This is a transform, not a
        # submodule, so parameters and state_dict are unchanged.
        self.pool_transform = gnn_T.Cartesian(cat=False)

    def _conv_pool(self, conv, data):
        """Apply ``conv`` + ELU + dropout to ``data`` and graclus max-pool the result."""
        x = F.elu(conv(data.x, data.edge_index, data.edge_attr))
        # F.dropout defaults to training=True; tie it to the module mode so
        # that model.eval() actually disables dropout.
        x = F.dropout(x, p=0.5, training=self.training)
        cluster = gnn.graclus(data.edge_index, num_nodes=x.size(0))
        pos = getattr(data, 'pos', None)
        pooled = Batch(x=x, edge_index=data.edge_index, edge_attr=data.edge_attr,
                       batch=data.batch, pos=pos)
        # Without node positions the pseudo-coordinates cannot be recomputed,
        # so fall back to plain pooling in that case.
        transform = self.pool_transform if pos is not None else None
        return gnn.max_pool(cluster, pooled, transform=transform)

    def forward(self, data):
        """Classify every graph in ``data``.

        Args:
            data: batch with ``x`` (1 feature per node), ``edge_index``,
                ``edge_attr`` (2-D pseudo-coordinates), ``batch`` and,
                preferably, ``pos`` (used to refresh ``edge_attr`` after pooling).

        Returns:
            Tensor with one row of 10 log-probabilities per graph.
        """
        data = self._conv_pool(self.conv1, data)
        data = self._conv_pool(self.conv2, data)

        x = gnn.global_mean_pool(data.x, data.batch)
        x = F.elu(self.fc1(x))
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)

In [7]:
# Build the SplineCNN classifier on `device`; the bare name below displays its layers.
model = SplineCNN()
model = model.to(device)
model

SplineCNN(
  (conv1): SplineConv(1, 32, dim=2)
  (conv2): SplineConv(32, 64, dim=2)
  (fc1): Linear(in_features=64, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [8]:
lr = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Number of training epochs; with 0 the training loop does not run at all.
epochs = 0

# Hold out the last 20% of the official training set for validation.
train_frac = 0.8
train_idx = int(len(graph_train_dataset) * train_frac)
train_split = graph_train_dataset[:train_idx]
val_split = graph_train_dataset[train_idx:]

batch_size = 64
# Reshuffle training batches every epoch; keep validation order fixed so the
# reported validation accuracy is reproducible.
graph_train_loader = DataLoader(train_split, batch_size=batch_size, shuffle=True)
graph_val_loader = DataLoader(val_split, batch_size=batch_size, shuffle=False)

In [9]:
def train_model(model, train_dataset_loader, preprocessor, device, optimizer=None):
    """Run one training epoch of ``model`` over ``train_dataset_loader``.

    Each batch is moved to ``device``, passed through ``preprocessor`` and
    scored with ``F.nll_loss`` against ``batch.y`` (so ``model`` is expected to
    return log-probabilities).

    Args:
        model: module to train; it is switched to training mode.
        train_dataset_loader: iterable of batches supporting ``.to(device)``
            and exposing ``.y``.
        preprocessor: callable applied to each batch after moving it to ``device``.
        device: target device for the batches.
        optimizer: optimizer over ``model``'s parameters. Defaults to the
            notebook-level ``optimizer`` for backward compatibility; pass it
            explicitly whenever ``model`` has been rebound, otherwise the global
            optimizer may step a different model's parameters.

    Returns:
        The same ``model`` instance, trained in place.
    """
    if optimizer is None:
        # Backward-compatible fallback to the notebook-level optimizer.
        optimizer = globals()['optimizer']

    model.train()

    for batch in tqdm.tqdm(train_dataset_loader):
        optimizer.zero_grad()
        batch = preprocessor(batch.to(device))
        out = model(batch)
        loss = F.nll_loss(out, batch.y)
        loss.backward()
        optimizer.step()

    return model
    

def validate_model(model, val_dataset_loader, preprocessor, device):
    """Return the classification accuracy of ``model`` over a whole loader.

    Args:
        model: module returning per-class scores; it is switched to eval mode.
        val_dataset_loader: iterable of batches supporting ``.to(device)`` and
            exposing integer labels in ``.y``.
        preprocessor: callable applied to each batch after moving it to ``device``.
        device: target device for the batches.

    Returns:
        Fraction of correctly classified samples across all batches, or
        ``nan`` when the loader yields no samples.
    """
    model.eval()

    correct = 0
    total = 0
    # Gradients are not needed for evaluation; skipping autograd saves memory.
    with torch.no_grad():
        for batch in val_dataset_loader:
            batch = preprocessor(batch.to(device))
            pred = model(batch).argmax(dim=1)
            correct += int(pred.eq(batch.y).sum())
            total += batch.y.size(0)

    # Weight every sample equally; averaging per-batch accuracies would
    # over-weight a smaller final batch.
    return correct / total if total else float('nan')

In [10]:
# Train for `epochs` epochs, reporting training and validation accuracy after each.
for epoch in range(epochs):
    print(f'Epoch: {epoch}')

    train_model(model, graph_train_loader, preprocess, device)

    # The validation split is evaluated first, then the training split.
    val_acc, train_acc = (
        validate_model(model, eval_loader, preprocess, device)
        for eval_loader in (graph_val_loader, graph_train_loader)
    )

    print(f'Train Acc: {train_acc}, Val Acc: {val_acc}')
    print()

In [11]:
# graph_test_loader = DataLoader(graph_test_dataset, batch_size=batch_size)
# test_acc = validate_model(model, graph_test_loader, preprocess, device)
# test_acc

## Generative Model

In [130]:
class SplineCNNEncoder(torch.nn.Module):
    """Node-level SplineConv encoder used by ``GraphVAE``.

    Two SplineConv layers (1 -> 32 -> ``output_dim`` channels), each followed
    by ELU and dropout. No pooling is applied, so the returned graph keeps the
    input's nodes and edges.
    """

    def __init__(self, output_dim):
        super(SplineCNNEncoder, self).__init__()
        self.conv1 = gnn.SplineConv(1, 32, dim=2, kernel_size=5)
        self.conv2 = gnn.SplineConv(32, output_dim, dim=2, kernel_size=5)
        self.output_dim = output_dim

    def get_output_dim(self):
        """Number of channels per node in the encoder output."""
        return self.output_dim

    def forward(self, data):
        """Encode ``data`` into per-node embeddings.

        Args:
            data: batch with ``x`` (1 feature per node), ``edge_index``,
                ``edge_attr`` (2-D pseudo-coordinates) and ``batch``.

        Returns:
            ``Batch`` whose ``x`` holds ``output_dim`` features per node, with
            ``edge_index``/``edge_attr``/``batch`` passed through unchanged.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        for conv in (self.conv1, self.conv2):
            x = F.elu(conv(x, edge_index, edge_attr))
            # F.dropout defaults to training=True; follow the module mode so
            # that model.eval() disables it.
            x = F.dropout(x, p=0.5, training=self.training)

        return Batch(x=x, edge_index=edge_index, edge_attr=edge_attr, batch=batch)
    

class GraphVAE(torch.nn.Module):
    """Variational auto-encoder over graphs.

    Adapted from
    https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py

    ``encoder`` produces per-node embeddings as a ``Batch``. Two SplineConv
    heads map them to the mean and log standard deviation of a per-node
    Gaussian latent. One reparametrised sample is passed to ``decoder``.
    """

    def __init__(self, encoder, decoder, latent_dim):
        super(GraphVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.mean = gnn.SplineConv(encoder.get_output_dim(), latent_dim, dim=2, kernel_size=5)
        self.log_std = gnn.SplineConv(encoder.get_output_dim(), latent_dim, dim=2, kernel_size=5)
        self.latent_dim = latent_dim

    @staticmethod
    def _like(template: Batch, x: torch.Tensor) -> Batch:
        """Build a ``Batch`` carrying ``x`` on the graph structure of ``template``."""
        return Batch(x=x, edge_index=template.edge_index,
                     edge_attr=template.edge_attr, batch=template.batch)

    def encode(self, data: Batch) -> Tuple[Batch, Batch]:
        """Encode ``data`` into per-node Gaussian parameters.

        Returns:
            ``(mean, log_std)`` batches sharing the encoder output's graph
            structure, each with ``latent_dim`` values per node in ``x``.
        """
        embed = self.encoder(data)

        x, edge_index, edge_attr = embed.x, embed.edge_index, embed.edge_attr
        mean = self.mean(x, edge_index, edge_attr)
        log_std = self.log_std(x, edge_index, edge_attr)

        return self._like(embed, mean), self._like(embed, log_std)

    def decode(self, z: Batch) -> Batch:
        """Decode latent batch ``z`` with ``decoder(z.x, z.edge_index)``.

        NOTE(review): the decoder output is stored in ``x``. With a decoder that
        returns one value per edge (the displayed output of this notebook shows
        ``x`` with the same length as ``edge_index``), ``x`` is then edge-level,
        not node-level - consider a dedicated edge attribute.
        """
        recon = self.decoder(z.x, z.edge_index)
        return self._like(z, recon)

    def reparametrise(self, mean: Batch, log_std: Batch) -> Batch:
        """Sample ``z = mean + eps * exp(log_std)`` with ``eps ~ N(0, I)``.

        Only a single Monte Carlo sample is drawn.
        """
        std = torch.exp(log_std.x)
        eps = torch.randn_like(std)
        return self._like(mean, mean.x + eps * std)

    def forward(self, data: Batch) -> Tuple[Batch, Batch, Batch, Batch]:
        """Return ``(reconstruction, data, mean, log_std)`` for ``data``."""
        mean, log_std = self.encode(data)
        z = self.reparametrise(mean, log_std)
        return self.decode(z), data, mean, log_std

    def loss_function(self):
        """Not implemented yet; fail loudly instead of silently returning None."""
        raise NotImplementedError('GraphVAE.loss_function is not implemented yet')

In [131]:
from torch_geometric.nn import InnerProductDecoder

# Encoder output channels and per-node latent size of the VAE.
output_dim = 150
latent_dim = 10

encoder = SplineCNNEncoder(output_dim).to(device)
decoder = InnerProductDecoder().to(device)
# NOTE(review): this rebinds `model`; the `optimizer` created earlier for the
# SplineCNN classifier still holds the classifier's parameters, not the VAE's.
model = GraphVAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim).to(device)

In [135]:
# Smoke-test the untrained VAE on a single batch of two training graphs.
# Use a dedicated loader so the classifier's train-split loader
# (`graph_train_loader`) is not silently replaced by a full-dataset one.
sample_loader = DataLoader(graph_train_dataset, batch_size=2, shuffle=False)
x = next(iter(sample_loader))

_data = preprocess(x.to(device))
recon, _, mean, log_std = model(_data)

In [136]:
# Display the reconstruction Batch from the smoke test.
# NOTE(review): in the displayed output `x` has as many entries as `edge_index`
# has edges, not one per node - it holds the decoder's output, which appears to be edge-level.
recon

Batch(batch=[150], edge_attr=[2659, 2], edge_index=[2, 2659], x=[2659])