# MVE Axiom Prediction

In [1]:
%env CUDA_VISIBLE_DEVICES=""

env: CUDA_VISIBLE_DEVICES=""


# TODO




* Try to get a reasonable score?
* TODO add batchnorm
* TODO add residual connections
* TODO add other directions? [done]

# Imports

In [2]:
import torch
from torch.nn import Embedding
import torch.nn.functional as F

from torch_geometric.nn import GCNConv, Linear
import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch_geometric.transforms import to_undirected, ToUndirected

  return torch._C._cuda_getDeviceCount() > 0


In [3]:
import sys, os
from pathlib import Path

sys.path.append(str(Path(os.path.abspath("")).parent))

import config
from dataset import get_data_loader, BenchmarkType

## CONSTANTS

In [4]:
# Problem-id files, given relative to this notebook's directory.
# NOTE(review): TRAIN_ID points at the validation ids while the real training
# file is commented out, so training and validation below use the same
# problems. Presumably a smaller file for quick iteration - restore train.txt
# before interpreting the reported accuracies.
# TRAIN_ID = '../id_files/train.txt'
TRAIN_ID = "../id_files/validation.txt"

VAL_ID = "../id_files/validation.txt"
# Benchmark variant passed to get_data_loader.
BENCHMARK_TYPE = BenchmarkType("deepmath")

In [5]:
torch.manual_seed(1234567)

<torch._C.Generator at 0x7f4b14098570>

## Load dataset

In [6]:
# dataset_params = {'transform': ToUndirected()}
dataset_params = {"transform": None}

# transform = None

In [7]:
train_data = get_data_loader(TRAIN_ID, BENCHMARK_TYPE, **dataset_params)
val_data = get_data_loader(VAL_ID, BENCHMARK_TYPE, **dataset_params)

Dataset: TorchMemoryDataset(2465)
Dataset: TorchMemoryDataset(2465)


## Data point check

In [8]:
next(iter(train_data))

DataBatch(x=[36912], edge_index=[2, 60406], premise_index=[842], conjecture_index=[64], name=[64], y=[842], batch=[36912], ptr=[65])

In [9]:
# Take the first graph of the first training batch and print its basic
# properties as a sanity check of the loaded data.
data = next(iter(train_data))[0]
print(data)
print(data.keys)
print(data.num_nodes)
print(data.num_edges)
print(data.num_node_features)
print(data.has_isolated_nodes())
print(data.is_directed())

Data(x=[164], edge_index=[2, 258], premise_index=[4], conjecture_index=[1], name='t104_zfmisc_1', y=[4])
['name', 'x', 'edge_index', 'premise_index', 'conjecture_index', 'y']
164
258
1
True
True


## Model

In [10]:
def get_dense_output_network(hidden_dim, task, dropout_rate=0.0):
    """Build the dense head that maps node embeddings to task outputs.

    Args:
        hidden_dim: Size of the incoming embedding and of the hidden layer.
        task: Name of the prediction task. Only ``"premise"`` is supported.
        dropout_rate: Dropout probability applied between the two linear layers.

    Returns:
        An ``nn.Sequential`` (Linear -> Dropout -> Linear) that produces a
        single value per input row.

    Raises:
        ValueError: If ``task`` is not a supported task name.
    """
    if task == "premise":
        # Two layer dense output network
        return nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, 1),
        )

    # Falling through used to return None, which only surfaced later as an
    # opaque "'NoneType' object is not callable" when the head was applied.
    raise ValueError(f"Unsupported task {task!r}; expected 'premise'")

In [11]:
nn.BatchNorm1d

torch.nn.modules.batchnorm.BatchNorm1d

In [12]:
GCN_NORMALISATION = {"batch": nn.BatchNorm1d, "layer": nn.LayerNorm}

In [13]:
class GCNDirectional(torch.nn.Module):
    """Stack of GCN layers that pass messages in a single edge direction.

    Every layer applies a graph convolution (optionally added to its input as
    a skip connection), followed by ReLU and dropout. When ``normalisation``
    is set, a normalisation layer follows every layer except the last.
    """

    def __init__(
        self, hidden_dim, num_convolutional_layers, dropout_rate, normalisation, skip_connection
    ):
        """Create the convolution and normalisation layers.

        Args:
            hidden_dim: Feature size used as input and output of every layer.
            num_convolutional_layers: Number of graph convolutions to stack.
            dropout_rate: Dropout probability applied after each ReLU.
            normalisation: Key into ``GCN_NORMALISATION``, or ``None`` to
                disable normalisation between layers.
            skip_connection: If true, each convolution's output is added to
                its input instead of replacing it.
        """
        super(GCNDirectional, self).__init__()

        self.flow = "target_to_source"  # Sets direction to bottom up
        # self.flow = 'source_to_target' # Not sensible, premise nodes remains unchanged

        # Set variables
        self.hidden_dim = hidden_dim
        self.num_convolutional_layers = num_convolutional_layers
        self.dropout_rate = dropout_rate
        self.skip_connection = skip_connection

        # Add convolutional layers
        self.convs = nn.ModuleList()
        for _ in range(self.num_convolutional_layers):
            self.convs.append(self.build_conv_model(hidden_dim, hidden_dim))

        # Add normalisation layers used in between graph convolutions
        if normalisation is None:
            self.lns = None
        else:
            self.normaliser = GCN_NORMALISATION[normalisation]
            self.lns = nn.ModuleList()
            # One fewer than the convolutions: the last layer is not normalised.
            for _ in range(self.num_convolutional_layers - 1):
                self.lns.append(self.normaliser(hidden_dim))

    def build_conv_model(self, input_dim, hidden_dim):
        """Return a single ``GCNConv`` layer using this module's message flow."""
        return pyg_nn.GCNConv(input_dim, hidden_dim, flow=self.flow)

    def forward(self, x, edge_index):
        """Run the node features through every convolutional layer.

        Args:
            x: Node feature tensor fed to the first convolution.
            edge_index: Edge index tensor passed to each convolution.

        Returns:
            A tuple ``(emb, x)``: ``emb`` is the output of the last layer
            before ReLU and dropout, ``x`` is that output after them. With
            zero layers both are the unmodified input.
        """
        # Start from the input so `emb` is defined even when the stack has no
        # layers; previously that configuration raised UnboundLocalError.
        emb = x
        last_layer = self.num_convolutional_layers - 1

        # Iterate over each convolutional sequence
        for i in range(self.num_convolutional_layers):
            conv_out = self.convs[i](x, edge_index)
            # Check if applying skip connection
            x = x + conv_out if self.skip_connection else conv_out

            emb = x  # Pre-activation embedding of the current layer
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

            # Apply normalisation between layers, but not after the final one
            if self.lns is not None and i != last_layer:
                x = self.lns[i](x)

        return emb, x


# Smoke check: build a small directional GCN with batch normalisation and
# display its layer layout as the cell output.
sub = GCNDirectional(
    hidden_dim=32, num_convolutional_layers=3, dropout_rate=0.25, normalisation="batch", skip_connection=False
)
sub

GCNDirectional(
  (convs): ModuleList(
    (0): GCNConv(32, 32)
    (1): GCNConv(32, 32)
    (2): GCNConv(32, 32)
  )
  (lns): ModuleList(
    (0): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [14]:
# TODO should variables not used in this very module actually be saved here?
class GNNStack(torch.nn.Module):
    """Premise model: node-type embedding -> directional GCN -> dense head."""

    def __init__(
        self,
        hidden_dim,
        num_convolutional_layers,
        dropout_rate=0.0,
        task="premise",
        normalisation="layer",
        skip_connection=True,
    ):
        """Assemble the embedding, the GCN stack and the output network.

        Args:
            hidden_dim: Embedding size and width of every GCN/dense layer.
            num_convolutional_layers: Number of graph convolutions in the GCN.
            dropout_rate: Dropout probability used in the GCN and dense head.
            task: Task name forwarded to ``get_dense_output_network``.
            normalisation: Normalisation key forwarded to ``GCNDirectional``.
            skip_connection: Whether the GCN layers use skip connections.
        """
        super(GNNStack, self).__init__()

        # Set variables
        self.task = task
        self.dropout_rate = dropout_rate
        self.hidden_dim = hidden_dim

        # Add embedding layer (one vector per entry of config.NODE_TYPE)
        self.node_embedding = Embedding(len(config.NODE_TYPE), hidden_dim)

        # Add GCN layer
        self.gcn = GCNDirectional(
            hidden_dim=self.hidden_dim,
            num_convolutional_layers=num_convolutional_layers,
            dropout_rate=self.dropout_rate,
            normalisation=normalisation,
            skip_connection=skip_connection,
        )

        # Post-message-passing
        self.post_mp = get_dense_output_network(hidden_dim, task=self.task, dropout_rate=self.dropout_rate)

    def forward(self, data):
        """Score every premise node of a (batched) graph.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``premise_index``.
                ``x`` is looked up in the node embedding, so it presumably
                holds integer node-type ids - confirm against the dataset.

        Returns:
            A tuple ``(emb, x)``: ``emb`` is the embedding returned by the
            GCN for all nodes, ``x`` holds one raw logit per entry of
            ``premise_index``.
        """
        x, edge_index, premise_index = data.x, data.edge_index, data.premise_index

        x = self.node_embedding(x)

        emb, x = self.gcn(x, edge_index)

        # if self.task == 'graph':
        #    x = pyg_nn.global_mean_pool(x, batch)

        # TODO should this be combined?
        x = x[premise_index]  # Keep only the nodes that represent premises
        x = self.post_mp(x)
        x = x.squeeze(-1)  # Drop the trailing singleton dimension

        return emb, x

    def loss(self, pred, label):
        """Binary cross-entropy between raw premise logits and 0/1 labels.

        ``forward`` emits one unnormalised logit per premise, so the loss has
        to be the logit form of binary cross-entropy (mean-reduced); the
        previous ``F.nll_loss`` expects per-class log-probabilities and fails
        on this output shape.

        Args:
            pred: 1-D tensor of logits as returned by ``forward``.
            label: Tensor of the same shape with 0/1 targets.

        Returns:
            Scalar tensor with the mean loss.
        """
        # Cast so integer labels are accepted as well as float ones.
        return F.binary_cross_entropy_with_logits(pred, label.to(pred.dtype))


model = GNNStack(hidden_dim=32, num_convolutional_layers=3, dropout_rate=0.25, task="premise")

In [15]:
print(model)

GNNStack(
  (node_embedding): Embedding(15, 32)
  (gcn): GCNDirectional(
    (convs): ModuleList(
      (0): GCNConv(32, 32)
      (1): GCNConv(32, 32)
      (2): GCNConv(32, 32)
    )
    (lns): ModuleList(
      (0): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (post_mp): Sequential(
    (0): Linear(in_features=32, out_features=32, bias=True)
    (1): Dropout(p=0.25, inplace=False)
    (2): Linear(in_features=32, out_features=1, bias=True)
  )
)


## Train and Test

In [16]:
# Adam over all model parameters, with weight decay as regularisation.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# criterion = torch.nn.CrossEntropyLoss()
# Binary cross-entropy applied to raw logits, averaged over the batch output.
criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")

In [17]:
def train():
    """Run one optimisation epoch over ``train_data``.

    Relies on the notebook-level ``model``, ``train_data``, ``criterion`` and
    ``optimizer``.

    Returns:
        Mean per-batch loss of the epoch as a float, or ``nan`` when the
        loader yields no batches.
    """
    model.train()

    total_loss = 0.0
    num_batches = 0
    for data in train_data:  # Iterate in batches over the training dataset.
        # Clear gradients before the pass rather than after the step, so that
        # gradients left behind by another cell cannot leak into this update.
        optimizer.zero_grad()

        _, out = model(data)  # Perform a single forward pass. TODO change this

        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else float("nan")


def test(loader):
    """Compute the premise classification accuracy of ``model`` on a loader.

    Relies on the notebook-level ``model``. A premise counts as predicted
    positive when the sigmoid of its logit rounds to 1.

    Args:
        loader: Iterable of batches exposing ``y`` and accepted by ``model``.

    Returns:
        Fraction of premises whose predicted label equals ``data.y``.

    Raises:
        ZeroDivisionError: If the loader yields no premises at all.
    """
    model.eval()

    correct = 0
    total_samples = 0
    # Evaluation needs no gradients; disabling autograd avoids building a
    # graph for every batch and does not change the predictions.
    with torch.no_grad():
        for data in loader:  # Iterate in batches over the training/test dataset.
            _, out = model(data)
            pred = torch.sigmoid(out).round().long()

            correct += data.y.eq(pred).sum().item()
            total_samples += len(pred)

    return correct / total_samples  # Derive ratio of correct predictions.

In [18]:
# Train for two epochs, measuring accuracy on both loaders after each one.
# NOTE(review): the value printed as "Test Acc" is computed on val_data.
for epoch in range(1, 3):
    train()
    train_acc = test(train_data)
    test_acc = test(val_data)
    print(f"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}")

Epoch: 001, Train Acc: 0.5941, Test Acc: 0.5941
Epoch: 002, Train Acc: 0.6389, Test Acc: 0.6389


In [19]:
"""
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.node_embedding = Embedding(len(config.NODE_TYPE), hidden_channels)


        self.conv1 = GCNConv(hidden_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        
        self.linear = Linear(hidden_channels, 1)
        

    def forward(self, input_batch):
        
        x = input_batch.x
        edge_index = input_batch.edge_index
        premise_index = input_batch.premise_index    

        x = self.node_embedding(x)

        x = self.conv1(x, edge_index)
        x = x.relu()
        
        x = F.dropout(x, p=0.5, training=self.training) # TODO add dropout parameter?
        
        x = self.conv2(x, edge_index)
        x = x.relu()
        
        x = x[premise_index]
        x = self.linear(x)
        
        # Remove inner axis
        x = x.squeeze(-1)
        
    
        return x

    
model = GCN(hidden_channels=16)
print(model)
"""

'\nclass GCN(torch.nn.Module):\n    def __init__(self, hidden_channels):\n        super().__init__()\n        self.node_embedding = Embedding(len(config.NODE_TYPE), hidden_channels)\n\n\n        self.conv1 = GCNConv(hidden_channels, hidden_channels)\n        self.conv2 = GCNConv(hidden_channels, hidden_channels)\n        \n        self.linear = Linear(hidden_channels, 1)\n        \n\n    def forward(self, input_batch):\n        \n        x = input_batch.x\n        edge_index = input_batch.edge_index\n        premise_index = input_batch.premise_index    \n\n        x = self.node_embedding(x)\n\n        x = self.conv1(x, edge_index)\n        x = x.relu()\n        \n        x = F.dropout(x, p=0.5, training=self.training) # TODO add dropout parameter?\n        \n        x = self.conv2(x, edge_index)\n        x = x.relu()\n        \n        x = x[premise_index]\n        x = self.linear(x)\n        \n        # Remove inner axis\n        x = x.squeeze(-1)\n        \n    \n        return x\n\n