In [53]:
from collections import Counter

import torch
from torch_geometric.data import Data, InMemoryDataset, DataLoader

import utilities as u

# Define Dataset

In [54]:
class TopLevelProofDataset(InMemoryDataset):
    """In-memory PyG dataset of theorem graphs labelled with a target ``y``.

    When the processed file is missing, ``InMemoryDataset.__init__`` calls
    ``process``. That method turns ``(theorem, label)`` pairs into ``Data``
    graphs, collates them and ``torch.save``s the result. Every construction
    then ``torch.load``s that cached file.

    Args:
        root: Root directory handed to ``InMemoryDataset``.
        transform: Optional transform applied on access.
        pre_transform: Forwarded to ``InMemoryDataset``. ``process`` below does
            not apply it.
        raw_data: Optional iterable of ``(theorem, label)`` pairs for
            ``process``. When ``None``, ``process`` falls back to the
            notebook-level global ``data``.
    """

    def __init__(self, root='', transform=None, pre_transform=None, raw_data=None):
        # Must be stored before super().__init__, which may invoke process().
        self._source_pairs = raw_data
        super(TopLevelProofDataset, self).__init__(root, transform, pre_transform)
        # NOTE(review): newer torch versions default torch.load to weights_only=True,
        # which may refuse pickled PyG objects. Confirm against the installed versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # No raw files: the input pairs come from memory, so download() is never needed.
        return []

    @property
    def processed_file_names(self):
        # The '..' places the cache one level above the processed/ directory.
        # It is kept as-is so an existing cache file is still found.
        return ['../top_level_proofs.dataset']

    def download(self):
        pass

    def process(self):
        """Convert (theorem, label) pairs into graphs, collate them and cache to disk."""
        # The pair source and the per-graph object get distinct names. Reusing
        # ``data`` for both made it a local everywhere in this method, so the
        # loop header raised UnboundLocalError before reading anything.
        source = self._source_pairs if self._source_pairs is not None else data

        data_list = []
        for thm, y in source:
            tokens = u.process_theorem(thm)
            tree, distinct_features = u.thm_to_tree(tokens)
            # NOTE(review): this is the theorem's own feature set, not the
            # notebook-wide vocabulary. Confirm u.graph_to_data expects that.
            x, edge_index = u.graph_to_data(tree, list(distinct_features))
            data_list.append(Data(x=x, edge_index=edge_index, y=y))

        collated, slices = self.collate(data_list)
        torch.save((collated, slices), self.processed_paths[0])

# SAGEConv Layer

In [55]:
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops

class SAGEConv(MessagePassing):
    """GraphSAGE-style convolution (max aggregation by default).

    The layer works in three steps:

    * Each incoming neighbour feature (self loops included) goes through
      ``Linear(in_channels, out_channels)`` and a ReLU.
    * The results are aggregated per target node.
    * The aggregate is concatenated with the node's own input feature, then
      projected by ``Linear(in_channels + out_channels, out_channels)`` and a ReLU.

    Args:
        in_channels: Size of each input node feature.
        out_channels: Size of each output node feature.
        aggr: Aggregation scheme passed to ``MessagePassing`` (default ``'max'``).
    """

    def __init__(self, in_channels, out_channels, aggr='max'):
        super(SAGEConv, self).__init__(aggr=aggr)
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
        # Project to out_channels. Projecting back to in_channels only gave the
        # advertised output width when in_channels == out_channels.
        self.update_lin = torch.nn.Linear(in_channels + out_channels, out_channels, bias=False)
        self.update_act = torch.nn.ReLU()

    def forward(self, x, edge_index):
        """Propagate over ``edge_index``; x is [N, in_channels], edge_index is [2, E]."""
        # Normalise to exactly one self loop per node so each node also sees itself.
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        num_nodes = x.size(0)
        return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x)

    def message(self, x_j):
        """Transform each source-node feature: [E, in_channels] -> [E, out_channels]."""
        return self.act(self.lin(x_j))

    def update(self, aggr_out, x):
        """Combine the aggregate [N, out_channels] with the node's own feature [N, in_channels]."""
        new_embedding = torch.cat([aggr_out, x], dim=1)
        return self.update_act(self.update_lin(new_embedding))

# GNN definition

In [56]:
# Width of the token embeddings; Net's embedding layer and first GCNConv both use it.
embed_dim = 128
from torch_geometric.nn import TopKPooling, GCNConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import torch.nn.functional as F
class Net(torch.nn.Module):
    """Graph-level binary classifier producing one probability per graph.

    The pipeline is:

    * Embed node tokens and apply a GCNConv.
    * Run three stages of (conv, TopKPooling, max+mean readout); the three
      readouts are summed.
    * Pass the sum through a Linear/ReLU MLP with dropout and a final sigmoid.

    Args:
        num_embeddings: Vocabulary size of the node embedding. Defaults to
            ``len(distinct_features) + 1``, read from the notebook global at
            construction time.
        embedding_dim: Embedding width. Defaults to the global ``embed_dim``.
        hidden_dim: Width of the graph convolutions and pooling layers (default 128).
        pool_ratio: Fraction of nodes kept by each TopKPooling (default 0.8).
    """

    def __init__(self, num_embeddings=None, embedding_dim=None, hidden_dim=128, pool_ratio=0.8):
        super(Net, self).__init__()
        # Resolve defaults lazily. Net() still reads the notebook globals, but
        # callers can pass explicit sizes and avoid depending on kernel state.
        if embedding_dim is None:
            embedding_dim = embed_dim
        if num_embeddings is None:
            num_embeddings = len(distinct_features) + 1

        # Layer creation order is unchanged, so a given seed yields the same init.
        self.conv1 = GCNConv(embedding_dim, hidden_dim)
        self.embedding = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        self.pool1 = TopKPooling(hidden_dim, ratio=pool_ratio)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.pool2 = TopKPooling(hidden_dim, ratio=pool_ratio)
        self.conv3 = SAGEConv(hidden_dim, hidden_dim)
        self.pool3 = TopKPooling(hidden_dim, ratio=pool_ratio)
        # Each readout concatenates global max- and mean-pooling -> 2 * hidden_dim.
        self.lin1 = torch.nn.Linear(2 * hidden_dim, 128)
        self.lin2 = torch.nn.Linear(128, 64)
        self.lin3 = torch.nn.Linear(64, 1)
        # NOTE: bn1/bn2 are not used in forward(); they are kept so the
        # state_dict keys stay unchanged.
        self.bn1 = torch.nn.BatchNorm1d(128)
        self.bn2 = torch.nn.BatchNorm1d(64)
        self.act1 = torch.nn.ReLU()
        self.act2 = torch.nn.ReLU()

    def forward(self, data):
        """Return a 1-D tensor with one sigmoid probability per graph in ``data``."""
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # Assumes data.x holds integer token ids shaped [N, 1]: the embedding
        # gives [N, 1, D], which squeeze(1) turns into [N, D]. TODO confirm.
        x = self.embedding(x)
        x = x.squeeze(1)

        x = F.relu(self.conv1(x, edge_index))

        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv2(x, edge_index))

        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv3(x, edge_index))

        x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # Sum the per-stage graph readouts (all shaped [num_graphs, 2 * hidden_dim]).
        x = x1 + x2 + x3

        x = self.act1(self.lin1(x))
        x = self.act2(self.lin2(x))
        x = F.dropout(x, p=0.5, training=self.training)

        # [num_graphs, 1] -> [num_graphs]
        return torch.sigmoid(self.lin3(x)).squeeze(1)

# Data inspections

In [3]:
# Load (theorem, label) pairs. N_SAMPLES caps how many are kept for quick
# iteration; set it to None to use all of them. TopLevelProofDataset.process()
# iterates over ``data``, so a small cap also shrinks a freshly rebuilt dataset.
N_SAMPLES = 2
proof_pairs_raw = u.make_data()
data = proof_pairs_raw[:N_SAMPLES] if N_SAMPLES is not None else proof_pairs_raw

0


FileNotFoundError: [Errno 2] No such file or directory: '../deephol-data/deepmath/deephol/proofs/human/train/prooflogs-00000-of-00600.pbtxt'

In [51]:
# Seed the feature vocabulary with the ids 0..877.
# NOTE(review): the meaning of 878 is not visible here; confirm its origin.
distinct_features = set(range(878))

In [None]:


# Grow the feature vocabulary with every feature produced by the loaded theorems.
for idx, (thm, _) in enumerate(data):
    if idx % 1000 == 0:
        print(f'{idx} / {len(data)}')
    thm = u.process_theorem(thm)
    _, features = u.thm_to_tree(thm)
    # Update in place: set.union builds a full copy of the growing set on
    # every iteration.
    distinct_features.update(features)

In [35]:
# Sanity-check the parsing pipeline on a small hand-written theorem.
test_thm = '(fun (a A B) (a A (a A B)))'
print(test_thm)
test_tokens = u.process_theorem(test_thm)
print(test_tokens)
# thm_to_tree returns a (tree, features) pair; unpack it as the other cells do.
test_tree, test_features = u.thm_to_tree(test_tokens)
print(test_tree)
print(test_features)

(fun (a A B) (a A (a A B)))
['(', 'fun', '(', 'a', 'A', 'B', ')', '(', 'a', 'A', '(', 'a', 'A', 'B', ')', ')', ')']
(<utilities.data_structures.Tree object at 0x7f41e66b6c50>, {'fun', 'a', 'A', 'B'})


In [36]:
# Label distribution: (label, % of examples), most frequent label first.
# most_common() sorts stably by count, so ties keep first-seen order.
counter = Counter(y for _, y in data).most_common()
# Normalise by the number of loaded examples (``datapoints`` was never defined).
percentages = [(label, count / len(data) * 100) for label, count in counter]
percentages

NameError: name 'data' is not defined

# Create Dataset

In [58]:
# Build (or load the cached) processed dataset and a shuffled single-graph loader.
dataset = TopLevelProofDataset()
loader = DataLoader(dataset=dataset, shuffle=True, batch_size=1)
print(dataset)

TopLevelProofDataset(13677)


# Train

In [61]:
# Training hyperparameters: graphs per batch, passes over the data.
batch_size, num_epochs = 1, 10

def train():
    """Run one training epoch over the global ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``crit`` and ``device``.

    Returns:
        Mean per-graph loss for the epoch. Each batch loss is weighted by the
        batch's ``num_graphs``, and the sum is divided by the size of the
        dataset the loader iterates.
    """
    model.train()

    loss_sum = 0.0
    for step, batch in enumerate(train_loader, start=1):
        batch = batch.to(device)
        optimizer.zero_grad()
        output = model(batch)
        # BCELoss needs float targets; y already lives on ``device`` after batch.to().
        label = batch.y.float()
        loss = crit(output, label)
        loss.backward()
        optimizer.step()
        loss_sum += batch.num_graphs * loss.item()

        if step % 100 == 0:
            print(step)
    # Normalise by what train_loader actually iterated rather than the unrelated
    # global ``dataset``, so the mean stays right for subsets/splits.
    return loss_sum / len(train_loader.dataset)
    
# Seed before building the model so weight init and shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

device = torch.device('cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
crit = torch.nn.BCELoss()
# Shuffle so every epoch visits the graphs in a fresh order.
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
    a = train()
    print(f'epoch {epoch + 1}/{num_epochs}: mean loss {a:.4f}')

100
200
300


KeyboardInterrupt: 