# MVE Axiom Prediction

In [1]:
%env CUDA_VISIBLE_DEVICES=""

env: CUDA_VISIBLE_DEVICES=""


# TODO



* Add model parameters
* Add other model definition from script
* Try to get a reasonable score?
* TODO add batchnorm
* TODO add dropout parameter
* TODO add residual connections
* TODO add other directions?

# Imports

In [68]:
import torch
from torch.nn import Embedding
import torch.nn.functional as F

from torch_geometric.nn import GCNConv, Linear
import torch.nn as nn
import torch_geometric.nn as pyg_nn


In [4]:
import sys, os
from pathlib import Path
sys.path.append(str(Path(os.path.abspath('')).parent))

import config
from dataset import get_data_loader, BenchmarkType

## CONSTANTS

In [5]:
#TRAIN_ID = '../id_files/train.txt'
TRAIN_ID = '../id_files/validation.txt'

VAL_ID = '../id_files/validation.txt'
BENCHMARK_TYPE = BenchmarkType('deepmath')

In [6]:
torch.manual_seed(1234567)

<torch._C.Generator at 0x7f0904037610>

## Load dataset

In [7]:
train_data = get_data_loader(TRAIN_ID, BENCHMARK_TYPE)
val_data = get_data_loader(VAL_ID, BENCHMARK_TYPE)

Dataset: TorchMemoryDataset(2465)
Dataset: TorchMemoryDataset(2465)


### Small check on a datapoint

In [52]:
data = next(iter(train_data))[0]
print(data)
print(data.keys)
print(data.num_nodes)
print(data.num_edges)
print(data.num_node_features)
print(data.has_isolated_nodes())
print(data.is_directed())


Data(x=[170], edge_index=[2, 268], premise_index=[2], conjecture_index=[1], name='t65_topalg_1', y=[2])
['x', 'name', 'y', 'premise_index', 'edge_index', 'conjecture_index']
170
268
1
True
True


## Model

In [96]:
class GNNStack(torch.nn.Module):
    def __init__(self, hidden_dim, output_dim, task='node'):
        
        super(GNNStack, self).__init__()
        self.task = task
        
        self.dropout = 0.25
        self.num_layers = 3
        
        self.node_embedding = Embedding(len(config.NODE_TYPE), hidden_dim)

        
        self.convs = nn.ModuleList()
        #self.convs.append(self.build_conv_model(hidden_dim, hidden_dim))
        for _ in range(self.num_layers): 
            self.convs.append(self.build_conv_model(hidden_dim, hidden_dim))
        
        self.lns = nn.ModuleList()
        self.lns.append(nn.LayerNorm(hidden_dim))
        self.lns.append(nn.LayerNorm(hidden_dim))
        
        


        # post-message-passing
        # TODO make a seprate function?
        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.Dropout(self.dropout), 
            nn.Linear(hidden_dim, output_dim))
        
        
        #if not (self.task == 'node' or self.task == 'graph'):
        #    raise RuntimeError('Unknown task.')

        
    def build_conv_model(self, input_dim, hidden_dim):
        # refer to pytorch geometric nn module for different implementation of GNNs.
        if self.task == 'node':
            return pyg_nn.GCNConv(input_dim, hidden_dim)
        else:
            return pyg_nn.GINConv(nn.Sequential(nn.Linear(input_dim, hidden_dim),
                                  nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)))

    def forward(self, data):
        x, edge_index, premise_index = data.x, data.edge_index, data.premise_index    

        
        x = self.node_embedding(x)
        
        #if data.num_node_features == 0:
        #  x = torch.ones(data.num_nodes, 1)

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            emb = x
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            if not i == self.num_layers - 1:
                x = self.lns[i](x)

        #if self.task == 'graph':
        #    x = pyg_nn.global_mean_pool(x, batch)

        # TODO should this be combined?
        x = x[premise_index]
        x = self.post_mp(x)
        x = x.squeeze(-1)
        return x

        #
        # return emb, F.log_softmax(x, dim=1) FIXME

    def loss(self, pred, label):
        return F.nll_loss(pred, label)
    
    
model = GNNStack(hidden_dim=32, output_dim=1, task='graph')

In [90]:
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.node_embedding = Embedding(len(config.NODE_TYPE), hidden_channels)


        self.conv1 = GCNConv(hidden_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        
        self.linear = Linear(hidden_channels, 1)
        

    def forward(self, input_batch):
        
        x = input_batch.x
        edge_index = input_batch.edge_index
        premise_index = input_batch.premise_index    

        x = self.node_embedding(x)

        x = self.conv1(x, edge_index)
        x = x.relu()
        
        x = F.dropout(x, p=0.5, training=self.training) # TODO add dropout parameter?
        
        x = self.conv2(x, edge_index)
        x = x.relu()
        
        x = x[premise_index]
        x = self.linear(x)
        
        # Remove inner axis
        x = x.squeeze(-1)
        
    
        return x

    
model = GCN(hidden_channels=16)
print(model)

GCN(
  (node_embedding): Embedding(15, 16)
  (conv1): GCNConv(16, 16)
  (conv2): GCNConv(16, 16)
  (linear): Linear(16, 1, bias=True)
)


## Train and Test

In [97]:

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
#criterion = torch.nn.CrossEntropyLoss()
criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')

In [98]:
def train():
    
    model.train()
    
    for data in train_data:  # Iterate in batches over the training dataset.
        out = model(data)  # Perform a single forward pass. TODO change this

        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.

        
        
def test(loader):
    
    model.eval()

    correct = 0
    total_samples = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        out = model(data)
        pred = torch.sigmoid(out).round().long()
        
        correct += data.y.eq(pred).sum().item()
        
        total_samples += len(pred)

    return correct / total_samples  # Derive ratio of correct predictions.



In [99]:

for epoch in range(1, 5):
    train()
    train_acc = test(train_data)
    test_acc = test(val_data)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

Epoch: 001, Train Acc: 0.5000, Test Acc: 0.5000
Epoch: 002, Train Acc: 0.5000, Test Acc: 0.5000
Epoch: 003, Train Acc: 0.5000, Test Acc: 0.5000
Epoch: 004, Train Acc: 0.5000, Test Acc: 0.5000
