In [1]:
import utilities as u

import torch
from torch_geometric.data import Data, InMemoryDataset, DataLoader
import torch.nn as nn

import pickle
import random

torch.manual_seed(42)

<torch._C.Generator at 0x7f3dbc0063b0>

In [2]:
EXPNAME = 'test_new_feature_size_1output'
binary = True
only_top = True

# Tensorboard Plotting 

In [3]:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(f'runs/{EXPNAME}')

# Define Dataset

In [4]:
from tqdm import tqdm

class TopLevelProofDataset(InMemoryDataset):
    def __init__(self, root='', transform=None, pre_transform=None):
        super(TopLevelProofDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
        
    @property
    def raw_file_names(self):
        return []
    
    @property
    def processed_file_names(self):
        return [f'../test_new_feature_size.dataset']
    
    def download(self):
        pass
    
    def process(self):
        global data
        data_list = []
        all_features = set()
        trees = []
        
        for thm, y in tqdm(data):
            thm = u.process_theorem(thm)
            tree, distinct_features = u.thm_to_tree(thm)
            all_features = all_features | distinct_features
            trees.append((tree, y))
        
        normalized_features = {k: [random.random() for i in range(128)] for k in list(all_features)}
            
        for tree, y in tqdm(trees):
            merged_tree = u.merge_subexpressions(tree)
            x, edge_index = u.graph_to_data(tree, normalized_features)
            data = Data(x=x, edge_index=edge_index, y=y)
            data_list.append(data)
            
        
#         for thm, y in tqdm(data):
#             thm = u.process_theorem(thm)
#             tree, distinct_features = u.thm_to_tree(thm)
#             normalized_features = {k: [random.random() for i in range(16)] for k in list(distinct_features)}
#             tree = u.merge_subexpressions(tree)
            
# #             x, edge_index = u.graph_to_data(tree, list(distinct_features))
#             x, edge_index = u.graph_to_data(tree, normalized_features)
            
#             data = Data(x=x, edge_index=edge_index, y=y)
#             data_list.append(data)
        
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

# SAGEConv Layer

In [5]:
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops

class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SAGEConv, self).__init__(aggr='mean') #  "Max" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
        self.update_lin = torch.nn.Linear(in_channels + out_channels, in_channels, bias=False)
        self.update_act = torch.nn.ReLU()
        
    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        
        
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        # x_j has shape [E, in_channels]

        x_j = self.lin(x_j)
        x_j = self.act(x_j)
        
        return x_j

    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels]


        new_embedding = torch.cat([aggr_out, x], dim=1)
        
        new_embedding = self.update_lin(new_embedding)
        new_embedding = self.update_act(new_embedding)
        
        return new_embedding

# GNN definition

In [6]:
embed_dim = 128
from torch_geometric.nn import TopKPooling, GCNConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import torch.nn.functional as F
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        print(dataset.num_features)
        self.conv1 = GCNConv(dataset.num_features, embed_dim)
#         self.conv1 = GCNConv(embed_dim, 128)
        self.embedding = torch.nn.Embedding(num_embeddings=len(distinct_features)+1, embedding_dim=embed_dim)
        self.pool1 = TopKPooling(128, ratio=0.8)
        self.conv2 = SAGEConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=0.8)
        self.conv3 = SAGEConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=0.8)
        self.lin1 = torch.nn.Linear(256, 128)
        self.lin2 = torch.nn.Linear(128, 64)
        self.lin3 = torch.nn.Linear(64, 11)
        self.lin4 = torch.nn.Linear(64, 1)
        self.bn1 = torch.nn.BatchNorm1d(128)
        self.bn2 = torch.nn.BatchNorm1d(64)
        self.act1 = torch.nn.ReLU()
        self.act2 = torch.nn.ReLU()  
  
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
#         print(x.shape)
#         x = self.embedding(x)
#         print(x.shape)
#         x = x.squeeze(1)      

        x = self.conv1(x, edge_index)
        x = F.relu(x)

        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
#         x = x1

        x = F.relu(self.conv2(x, edge_index))
     
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv3(x, edge_index))

        x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = x1 + x2 + x3

        x = self.lin1(x)
        x = self.act1(x)
        x = self.lin2(x)
        x = self.act2(x)      
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.lin4(x))

#         x = F.log_softmax(self.lin3(x), dim=1).squeeze(1)


        return x

# Model 2 (Subgraph Pooling Paper)

In [7]:
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops

class PaliwalMP(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(PaliwalMP, self).__init__(aggr='mean', flow='target_to_source') #  "Max" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
        self.update_lin = torch.nn.Linear(in_channels + out_channels, in_channels, bias=False)
        self.update_act = torch.nn.ReLU()
        
    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        
        
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        # x_j has shape [E, in_channels]

        x_j = self.lin(x_j)
        x_j = self.act(x_j)
        
        return x_j

    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels]


        new_embedding = torch.cat([aggr_out, x], dim=1)
        
        new_embedding = self.update_lin(new_embedding)
        new_embedding = self.update_act(new_embedding)
        
        return new_embedding

In [8]:
embed_dim = 128
from torch_geometric.nn import TopKPooling, GCNConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import torch.nn.functional as F
class Net2(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        # Batch norm
        
        # Message passing
        
  
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # Apply batch norm and ReLU to each embedding
        x = F.relu(F.batch_norm(x))
        
        # Perform k steps of message-passing (k=4)

        

        x = x1 + x2 + x3

        x = self.lin1(x)
        x = self.act1(x)
        x = self.lin2(x)
        x = self.act2(x)      
        x = F.dropout(x, p=0.5, training=self.training)

        x = F.log_softmax(self.lin3(x), dim=1).squeeze(1)


        return x

# Data inspections

In [9]:
new_data = True

if new_data == True:
    data = u.make_data(binary=binary, only_top=only_top)
    with open(EXPNAME, 'wb') as outfile:
        pickle.dump(data, outfile)
else:
    with open(EXPNAME, 'rb') as infile:
        data = pickle.load(infile)
        
        
# data = data[0:2]

100%|██████████| 150/150 [02:38<00:00,  1.06s/it]


In [10]:
distinct_features = set()

for idx, (thm, _) in enumerate(data):
    if idx % 1000 == 0:
        print(f'{idx} / {len(data)}')
    thm = u.process_theorem(thm)
    thm_tree, features = u.thm_to_tree(thm)
    distinct_features = distinct_features.union(features)
len(distinct_features)

0 / 6289
1000 / 6289
2000 / 6289
3000 / 6289
4000 / 6289
5000 / 6289
6000 / 6289


608

In [11]:
distinct_features = set(i for i in range(len(distinct_features)))

In [12]:
# test_thm = '(fun (a A B) (a A (a A B)))'
# print(test_thm)
# thm = u.process_theorem(test_thm)
# print(thm)
# thm_tree, _ = u.thm_to_tree(thm)
# print(len(thm_tree))

# #print([t.root for t in thm_tree.subtrees[0].subtrees])
# thm_tree = u.merge_subexpressions(thm_tree)

# print(thm_tree.root)
# print([t.root for t in thm_tree.subtrees])
# t_0, t_1 = thm_tree.subtrees
# print([t.root for t in t_0.subtrees])
# print([t.root for t in t_1.subtrees])
# print(t_1.subtrees[0].subtree_str)
# print(len(thm_tree))

In [13]:
counter = dict()
for _, y in data:
    if y in counter:
        counter[y] += 1
    else:
        counter[y] = 1
counter = list(counter.items())
counter.sort(key=lambda x: x[0], reverse=False)
percentages = [(x, y/len(data)*100) for x,y in counter]
percentages

[(0, 51.073302591826995), (1, 48.926697408173)]

# Create Dataset

In [14]:
from math import floor

dataset = TopLevelProofDataset()
dataset.shuffle()

train_dataset = dataset[:floor(len(dataset)/2)]
# train_dataset = dataset[:2]
# for x in train_dataset:
#     print(x)
valid_dataset = dataset[floor(len(dataset)/2) : 3*floor(len(dataset)/2)]
test_dataset = dataset[3*floor(len(dataset)/2):]

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)
print(train_dataset)

TopLevelProofDataset(3144)


# Train

In [15]:
batch_size = 32
num_epochs = 100

def train():
    global epoch
    model.train()
    
    loss_all = 0
    correct = 0
    for i, data in enumerate(train_loader):
        x = data.x.squeeze(1)

        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        label = torch.unsqueeze(data.y.to(device), 1).float()
#         torch.unsqueeze(label, 1)
#         print(data.shape, output.shape, label.shape)
#         print(label, output)
#         loss = F.nll_loss(output, label)
        loss = F.mse_loss(output, label)
    
        if torch.isnan(loss):
            print(output, label)
        loss.backward()
        loss_all += data.num_graphs * loss.item()
        pred = output.data.max(1, keepdim=True)[1]
        correct_this_time = pred.eq(label.data.view_as(pred)).sum()
        correct += correct_this_time
#         if correct_this_time == 0:
#             print(label, output)
        

        optimizer.step()
#         print(correct.item())
    
    writer.add_scalar('training loss',
                     loss_all / len(train_dataset),
                     epoch)
    
    writer.add_scalar('training accuracy',
                     correct.item() / len(train_dataset),
                     epoch)
    #print(correct.item())
    
    return loss_all / len(train_dataset), correct.item() / len(train_dataset)


def test():
    global epoch
    model.eval()
    print('hi')
    
    loss_all = 0
    correct = 0
    for i, data in enumerate(valid_loader):
        x = data.x.squeeze(1)

        data = data.to(device)
        output = model(data)
        label = torch.unsqueeze(data.y.to(device), 1)
        
#         print(output.shape, label.shape)
#         loss = F.nll_loss(output, label)
        loss = F.mse_loss(output, label)
        loss_all += data.num_graphs * loss.item()
#         print(label, output)
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(label.data.view_as(pred)).sum()
    
    writer.add_scalar('validation loss',
                     loss_all / len(valid_dataset),
                     epoch)
    
    writer.add_scalar('validation accuracy',
                     correct.item() / len(valid_dataset),
                     epoch)
    
    return loss_all / len(valid_dataset), correct.item() / len(valid_dataset)
    
    
device = torch.device('cuda:1')

model = Net().to(device)
# clip = 5
# torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

# optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
optimizer = torch.optim.SGD(model.parameters(), lr=0.00005, momentum=0.8)
crit = torch.nn.CrossEntropyLoss()
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=8)


epoch = 0
valid_loss, valid_acc = test()
for epoch in tqdm(range(num_epochs)):
    epoch_loss, epoch_acc = train()
    #print(f'train - loss: {epoch_loss}, acc: {epoch_acc}')
    
    if epoch % 5 == 4:
        valid_loss, valid_acc = test()
#         print(f'valid - loss: {valid_loss}, acc: {valid_acc}')
        
    #print()
        
    

16
hi


  4%|▍         | 4/100 [02:08<51:00, 31.88s/it]

hi


  9%|▉         | 9/100 [05:20<53:14, 35.11s/it]  

hi


 14%|█▍        | 14/100 [08:19<44:23, 30.97s/it]  

hi


 19%|█▉        | 19/100 [11:25<45:58, 34.05s/it]

hi


 24%|██▍       | 24/100 [14:30<42:04, 33.21s/it]

hi


 29%|██▉       | 29/100 [17:47<41:42, 35.25s/it]

hi


 34%|███▍      | 34/100 [20:51<35:49, 32.57s/it]

hi


 39%|███▉      | 39/100 [23:53<33:22, 32.82s/it]

hi


 44%|████▍     | 44/100 [27:11<32:09, 34.45s/it]

hi


 49%|████▉     | 49/100 [30:24<29:51, 35.14s/it]

hi


 54%|█████▍    | 54/100 [33:35<26:26, 34.49s/it]

hi


 59%|█████▉    | 59/100 [36:51<24:13, 35.44s/it]

hi


 64%|██████▍   | 64/100 [39:55<20:42, 34.53s/it]

hi


 69%|██████▉   | 69/100 [43:09<17:55, 34.69s/it]

hi


 74%|███████▍  | 74/100 [46:24<15:26, 35.62s/it]

hi


 79%|███████▉  | 79/100 [49:38<12:28, 35.63s/it]

hi


 84%|████████▍ | 84/100 [52:56<09:31, 35.70s/it]

hi


 89%|████████▉ | 89/100 [56:09<06:26, 35.09s/it]

hi


 94%|█████████▍| 94/100 [59:21<03:27, 34.56s/it]

hi


 99%|█████████▉| 99/100 [1:02:36<00:34, 34.99s/it]

hi


100%|██████████| 100/100 [1:03:39<00:00, 38.19s/it]


In [16]:
torch.save(model, f'{EXPNAME}')

  "type " + obj.__name__ + ". It won't be checked "


PicklingError: Can't pickle <class 'torch._C._VariableFunctions'>: it's not the same object as torch._C._VariableFunctions

# Try creating embedding vectors for the node features. It may be an issue that the features are not between 0 and 1. So creating an embedding network which attempts to recreate the features by passing them through a short vector which is then normalized could fix this.