In [1]:
import utilities as u

import torch
import torch.nn as nn
from torch_geometric.data import Data, Dataset, InMemoryDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.nn import Embedding

import numpy as np
# import seaborn as sns
import pandas as pd

from math import floor
import pickle
import random
from tqdm import tqdm

SEED = 42
# Seed every RNG library the notebook imports (Python's random, NumPy, PyTorch)
# so data shuffling and weight initialisation are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f81f40493d0>

In [2]:
# Experiment name plus the data-selection flags (binary labels, top-level theorems only).
EXPNAME, binary, only_top = 'bug_fix', True, True

In [3]:
def save_model(model, PATH):
    """Save ``model``'s parameters (its ``state_dict``) to ``PATH``."""
    torch.save(model.state_dict(), PATH)

def load_model(model_type, PATH, *args, map_location=None, **kwargs):
    """Instantiate ``model_type`` and load a state dict written by ``save_model``.

    Args:
        model_type: model class (or factory) to instantiate.
        PATH: checkpoint file produced by ``save_model``.
        *args, **kwargs: constructor arguments for ``model_type``, e.g. ``t=4``
            for ``PaliwalNet``, which cannot be built with no arguments.
        map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` to load a
            checkpoint that was saved from a GPU that is not available now.

    Returns:
        The constructed model with the saved parameters loaded.
    """
    model = model_type(*args, **kwargs)
    model.load_state_dict(torch.load(PATH, map_location=map_location))
    return model

# Define Dataset

In [4]:
class TopLevelProofDataset(InMemoryDataset):
    """In-memory PyG dataset of theorem graphs built from the module-level ``data``.

    Processing reads these module-level globals:
    ``data`` (iterable of ``(theorem, label)`` pairs), ``dataset_name`` (selects
    the processed file) and ``to_merge`` (whether shared subexpressions are merged).
    """
    def __init__(self, root='', transform=None, pre_transform=None):
        super(TopLevelProofDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return [f'../datasets/{dataset_name}.dataset']

    def download(self):
        pass

    def process(self):
        """Convert every theorem in ``data`` to a ``Data`` graph and save the collated set."""
        data_list = []
        all_features = set()
        trees = []

        for thm, y in tqdm(data):
            thm = u.process_theorem(thm)
            tree, distinct_features = u.thm_to_tree(thm, to_merge)
            all_features = all_features | distinct_features
            trees.append((tree, y))

        # Sort before assigning indices so a feature gets the same index in every run.
        # Set iteration order is not stable across interpreter processes (e.g. for str
        # features under hash randomisation; features assumed to be strings — confirm).
        normalized_features = {k: [i] for i, k in enumerate(sorted(all_features, key=str))}

        for tree, y in tqdm(trees):
            merged_tree = u.merge_subexpressions(tree) if to_merge else tree
            # Build the graph from merged_tree. Passing the unmerged `tree` here made
            # to_merge have no effect on the produced graphs.
            x, (edge_index_up, edge_index_down), (edge_features_up, edge_features_down) = u.graph_to_data(
                merged_tree, normalized_features)
            # Up edges come first, then down edges; downstream models rely on this order.
            datum = Data(x=x,
                         y=y,
                         edge_index=torch.cat((edge_index_up, edge_index_down), dim=1),
                         edge_attr=torch.cat((edge_features_up, edge_features_down)),
                         )
            data_list.append(datum)

        all_data, slices = self.collate(data_list)
        torch.save((all_data, slices), self.processed_paths[0])

In [5]:
n_graphs = 10

class ProofDataset(Dataset):
    """Work-in-progress multi-file variant of ``TopLevelProofDataset``.

    NOTE(review): this class is not usable as-is. ``num_files`` and
    ``get_data_from_file`` are not defined in this notebook. ``self.collate`` is
    provided by ``InMemoryDataset``, not ``Dataset``. ``Dataset`` subclasses
    normally also implement ``len``/``get``. ``processed_file_names`` lists
    ``n_graphs`` files, but ``process`` writes only the first one.
    """
    def __init__(self, root='', transform=None, pre_transform=None):
        # Must name this class. super(TopLevelProofDataset, self) raised TypeError
        # because ProofDataset does not inherit from TopLevelProofDataset.
        super(ProofDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return [f'../datasets/{i}.dataset' for i in range(n_graphs)]

    def download(self):
        pass

    def process(self):
        """Parse every source file into graphs and save the collated result."""
        data_list = []
        all_features = set()
        trees = []

        # Pass 1: parse all files first, so the feature vocabulary is complete before
        # any index is assigned. Indices used to be re-derived per file, so a feature
        # could be encoded differently in different files.
        for i in range(num_files):
            file_data = get_data_from_file(i, binary, only_top)
            for thm, y in tqdm(file_data):
                thm = u.process_theorem(thm)
                tree, distinct_features = u.thm_to_tree(thm, to_merge)
                all_features = all_features | distinct_features
                trees.append((tree, y))

        # Sorted so indices are reproducible across runs (see TopLevelProofDataset).
        normalized_features = {k: [i] for i, k in enumerate(sorted(all_features, key=str))}

        # Pass 2: convert each tree exactly once. Trees were previously
        # re-enumerated per file after being set to None, which crashed on unpacking.
        for idx, (tree, y) in tqdm(enumerate(trees)):
            merged_tree = u.merge_subexpressions(tree) if to_merge else tree
            x, (edge_index_up, edge_index_down), (edge_features_up, edge_features_down) = u.graph_to_data(
                merged_tree, normalized_features)
            datum = Data(x=x,
                         y=y,
                         edge_index=torch.cat((edge_index_up, edge_index_down), dim=1),
                         edge_attr=torch.cat((edge_features_up, edge_features_down)),
                         )
            data_list.append(datum)
            trees[idx] = None  # release the parsed tree once it has been converted

        collated, slices = self.collate(data_list)
        torch.save((collated, slices), self.processed_paths[0])

# SAGEConv Layer

In [6]:
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops

class SAGEConv(MessagePassing):
    """GraphSAGE-style convolution with mean aggregation.

    Each neighbour message is ``ReLU(Linear(x_j))`` (self-loops included). The
    mean of these messages is concatenated with the node's own features and
    projected through a bias-free linear layer and a ReLU.

    Args:
        in_channels: width of each input node feature vector.
        out_channels: width of each output node feature vector.
    """
    def __init__(self, in_channels, out_channels):
        super(SAGEConv, self).__init__(aggr='mean')  # "Mean" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
        # Project [aggregated message || own features] to out_channels. It previously
        # projected back to in_channels, so the output width ignored out_channels
        # (Net's pool1 expects 128 = embed_dim outputs).
        self.update_lin = torch.nn.Linear(in_channels + out_channels, out_channels, bias=False)
        self.update_act = torch.nn.ReLU()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Ensure every node has exactly one self-loop, so its own features are part of the mean.
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        # x_j has shape [E, in_channels]

        x_j = self.lin(x_j)
        x_j = self.act(x_j)

        return x_j

    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels]; result has shape [N, out_channels].

        new_embedding = torch.cat([aggr_out, x], dim=1)

        new_embedding = self.update_lin(new_embedding)
        new_embedding = self.update_act(new_embedding)

        return new_embedding

# GNN definition

In [7]:
embed_dim = 128
from torch_geometric.nn import TopKPooling, GCNConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import torch.nn.functional as F
class Net(torch.nn.Module):
    """Hierarchical GNN regressor with three conv + TopK-pooling stages.

    After each stage, a [max || mean] readout over each graph is taken. The three
    readouts are summed and passed through an MLP to produce one ReLU-clamped
    output per graph. ``lin3``, ``bn1`` and ``bn2`` are registered but unused.
    """
    def __init__(self):
        super(Net, self).__init__()
        hidden = 128

        # Registration order is kept fixed so parameter init and state_dict keys are stable.
        self.conv1 = SAGEConv(dataset.num_features, embed_dim)
        self.pool1 = TopKPooling(hidden, ratio=0.8)
        self.conv2 = GCNConv(hidden, hidden)
        self.pool2 = TopKPooling(hidden, ratio=0.8)
        self.conv3 = GCNConv(hidden, hidden)
        self.pool3 = TopKPooling(hidden, ratio=0.8)
        self.lin1 = torch.nn.Linear(2 * hidden, hidden)
        self.lin2 = torch.nn.Linear(hidden, 64)
        self.lin3 = torch.nn.Linear(64, 11)
        self.lin4 = torch.nn.Linear(64, 1)
        self.bn1 = torch.nn.BatchNorm1d(hidden)
        self.bn2 = torch.nn.BatchNorm1d(64)
        self.act1 = torch.nn.ReLU()
        self.act2 = torch.nn.ReLU()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        readout = None
        for conv, pool in stages:
            x = F.relu(conv(x, edge_index))
            x, edge_index, _, batch, _, _ = pool(x, edge_index, None, batch)
            summary = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
            # Jumping-knowledge style sum of the per-stage graph readouts.
            readout = summary if readout is None else readout + summary

        hidden_out = self.act2(self.lin2(self.act1(self.lin1(readout))))
        hidden_out = F.dropout(hidden_out, p=0.5, training=self.training)
        return F.relu(self.lin4(hidden_out))

# Model 2 (Subgraph Pooling Paper)

In [8]:
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops

class PaliwalMP(MessagePassing):
    """One bidirectional message-passing step (Paliwal et al. style).

    Messages along parent edges ('up') and child edges ('down') are built by
    separate MLPs from ``[x_i || x_j || edge_attr]`` and mean-aggregated.
    The node state is then updated as ``MLP_aggr([x || up || down]) + x``.

    NOTE(review): ``in_channels`` / ``out_channels`` are accepted but unused;
    all widths are hard-coded to 128.
    """
    def __init__(self, in_channels, out_channels):
        super(PaliwalMP, self).__init__(aggr='mean', flow='target_to_source')  # "Mean" aggregation.

        # Step 2: one edge MLP per direction (parents / children).
        self.MLP_edge = BuildingBlock(3*128, 128)
        self.MLP_edge_hat = BuildingBlock(3*128, 128)

        # Step 3: MLP over the node state concatenated with both aggregated messages.
        self.MLP_aggr = BuildingBlock(3*128, 128)

    def forward(self, x, edge_index_parents, edge_index_children, edge_attr_parents, edge_attr_children):
        # Parents first, then children: the concatenation order below depends on it.
        directed_inputs = (
            (edge_index_parents, edge_attr_parents, 'up'),
            (edge_index_children, edge_attr_children, 'down'),
        )
        messages = [
            self.propagate(index, x=x, edge_attr=attr, direction=tag, size=None)
            for index, attr, tag in directed_inputs
        ]
        fused = self.MLP_aggr(torch.cat([x] + messages, dim=1))
        return fused + x  # residual connection

    def message(self, x_i, x_j, edge_attr, direction):
        s_ij = torch.cat([x_i, x_j, edge_attr], dim=1)
        edge_mlp = {'up': self.MLP_edge, 'down': self.MLP_edge_hat}.get(direction)
        # An unknown direction passes the raw concatenation through unchanged.
        return s_ij if edge_mlp is None else edge_mlp(s_ij)

    def update(self, aggr_out, x):
        # Identity update: the combination with x happens in forward().
        return aggr_out


In [9]:
from torch_geometric.nn import global_max_pool as gmp
from torch_geometric.nn import BatchNorm

embed_dim = 128

# TODO: Apply dropout

class BuildingBlock(torch.nn.Module):
    """Three-layer ReLU MLP: in_channels -> 256 -> 128 -> out_channels.

    Dropout with p=0.5 is applied after the hidden layer, in training mode only.

    Args:
        in_channels: input feature width.
        out_channels: output feature width (every caller in this notebook passes 128).
        dim: unused; kept for backward compatibility. It used to be passed
            positionally as ``Linear``'s ``bias`` argument, which silently built
            ``lin1`` without a bias. Checkpoints saved before this fix therefore
            lack ``lin1.bias``.
    """
    def __init__(self, in_channels, out_channels, dim=0):
        super(BuildingBlock, self).__init__()
        self.lin1 = Linear(in_channels, 256)
        self.hidden = Linear(256, 128)
        self.lin2 = Linear(128, out_channels)

    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = F.relu(self.hidden(x))
        # F.dropout defaults to training=True; tie it to the module's mode so that
        # model.eval() actually disables dropout during validation.
        x = F.dropout(x, 0.5, training=self.training)
        x = F.relu(self.lin2(x))

        return x
    

class PaliwalNet(torch.nn.Module):
    """Graph-level regressor built from ``t`` ``PaliwalMP`` message-passing steps.

    Steps:
    1. Embed integer node ids and lift node and edge features with MLPs.
    2. Run ``t`` rounds of message passing.
    3. Expand node states with two 1x1 convolutions.
    4. Max-pool per graph.
    5. Predict one ReLU-clamped value per graph with a six-layer MLP.

    Args:
        t: number of message-passing steps.

    The node-embedding table is sized from the module-level globals
    ``distinct_features`` and ``dataset_name``.
    """
    def __init__(self, t):
        super(PaliwalNet, self).__init__()

        self.embedding = Embedding(num_embeddings=distinct_features[dataset_name]+1, embedding_dim=embed_dim)

        self.MLP_V = BuildingBlock(embed_dim, 128)
        self.MLP_E = BuildingBlock(1, 128)

        self.message_passing_steps = nn.ModuleList()
        for i in range(t):
            self.message_passing_steps.append(PaliwalMP(embed_dim, embed_dim))

        self.conv1 = nn.Conv1d(128, 512, 1)
        self.conv2 = nn.Conv1d(512, 1024, 1)

        # Registered but currently unused (their calls in forward are commented out).
        self.bn1 = BatchNorm(512)
        self.bn2 = BatchNorm(1024)

        # FCNN for final prediction
        self.lin1 = Linear(1024, 512)
        self.lin2 = Linear(512, 512)
        self.lin3 = Linear(512, 256)
        self.lin4 = Linear(256, 256)
        self.lin5 = Linear(256, 128)
        self.lin6 = Linear(128, 1)

    @staticmethod
    def _split_edge_directions(edge_index, edge_attr, batch):
        """Split a (possibly batched) edge list into up and down edges, per graph.

        Each graph stores its edges as ``cat(up, down)``, so the first half of
        every graph's edges are "up" edges. A mini-batch concatenates graphs one
        after another, so splitting the whole batch in half (the previous
        approach) mixed directions whenever more than one graph was batched.

        Returns:
            ``(edge_index_up, edge_index_down, edge_attr_up, edge_attr_down)``.
        """
        num_edges = edge_index.size(1)
        if batch is None:
            edge_graph = edge_index.new_zeros(num_edges)
        else:
            # Graph id of each edge, taken from its source node.
            edge_graph = batch[edge_index[0]]
        edges_per_graph = torch.bincount(edge_graph)
        graph_start = torch.cumsum(edges_per_graph, dim=0) - edges_per_graph
        # Position of each edge within its own graph's edge list.
        position = torch.arange(num_edges, device=edge_index.device) - graph_start[edge_graph]
        is_up = position < (edges_per_graph // 2)[edge_graph]
        return edge_index[:, is_up], edge_index[:, ~is_up], edge_attr[is_up], edge_attr[~is_up]

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        edge_index_u, edge_index_d, edge_attr_u, edge_attr_d = self._split_edge_directions(
            edge_index, edge_attr, batch)

        # Generate learnable embeddings for node features
        x = x.squeeze(-1)
        x = self.embedding(x)

        # Embed node and edge features into high dimensional space
        x = self.MLP_V(x)
        edge_attr_u = self.MLP_E(edge_attr_u.float())
        edge_attr_d = self.MLP_E(edge_attr_d.float())

        for message_passing_step in self.message_passing_steps:
            x = message_passing_step(x, edge_index_u, edge_index_d, edge_attr_u, edge_attr_d)

        # Treat nodes as the sequence axis so the 1x1 convs act per node: [N, C] -> [1, C, N].
        x = x.transpose(0, 1).unsqueeze(0)
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.squeeze(0).transpose(0, 1)

        # Final prediction network
        x = x.squeeze(-1)
        g = gmp(x, batch)

        g = F.relu(self.lin1(g))
        g = F.relu(self.lin2(g))
        g = F.relu(self.lin3(g))
        g = F.relu(self.lin4(g))
        g = F.relu(self.lin5(g))
        g = F.relu(self.lin6(g))

        return g

# Data inspection

In [10]:
def get_data_distribution(data):
    """Return ``[(label, percentage), ...]`` sorted by label, for ``(item, label)`` pairs.

    Percentages are in the range 0-100 and computed relative to ``len(data)``.
    """
    tally = {}
    for _, label in data:
        tally[label] = tally.get(label, 0) + 1
    total = len(data)
    ordered = sorted(tally.items(), key=lambda item: item[0])
    return [(label, count / total * 100) for label, count in ordered]

In [11]:
# new_data = True
# binary = True
# only_top = False

# if new_data == True:
#     data = u.make_data(binary=binary, only_top=only_top)
#     with open('testing_dataset', 'wb') as outfile:
#         pickle.dump(data, outfile)
# else:
#     with open(EXPNAME, 'rb') as infile:
#         data = pickle.load(infile)
    
        
        
# # dataset_name = 'BSu'
# to_merge = True

100%|██████████| 50/50 [00:35<00:00,  1.41it/s]


In [12]:
def get_num_distinct_features(data):
    """Count the distinct node features over every theorem in ``data``.

    ``data`` is an iterable of ``(theorem, label)`` pairs. Trees are built with
    ``to_merge=False``, and only the feature sets returned by ``u.thm_to_tree`` are used.
    """
    vocabulary = set()
    for raw_thm, _label in tqdm(data):
        processed = u.process_theorem(raw_thm)
        _tree, thm_features = u.thm_to_tree(processed, to_merge=False)
        vocabulary.update(thm_features)
    return len(vocabulary)

# distinct_features = len(distinct_features)

In [13]:
dummy_test = False

if dummy_test:
    test_thm = '(fun (a A B) (a A (a A B)))'
    print(test_thm)
    thm = u.process_theorem(test_thm)
    print(thm)
    thm_tree, _ = u.thm_to_tree(thm)
    print(len(thm_tree))
    print(thm_tree.subtrees[0].parents[0])
    thm_tree = u.merge_subexpressions(thm_tree)
    x = u.graph_to_data(thm_tree)
    print(x)

    #print([t.root for t in thm_tree.subtrees[0].subtrees])


    print(thm_tree.root)
    print([t.root for t in thm_tree.subtrees])
    t_0, t_1 = thm_tree.subtrees
    print([t.root for t in t_0.subtrees])
    print([t.root for t in t_1.subtrees])
    print(t_1.subtrees[0].subtree_str)
    print(len(thm_tree))

# New Train Function

In [14]:
def train(model, data_loader, epoch, crit, optimizer, device, writer, len_dataset):
    """Run one training epoch and log the mean per-graph loss to ``writer``.

    Args:
        model: model to train; it is put in training mode.
        data_loader: iterable of batches exposing ``.to``, ``.y`` and ``.num_graphs``.
        epoch: step value used for the ``'training loss'`` scalar.
        crit: loss function ``crit(output, label)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        writer: object with ``add_scalar(tag, value, step)`` (e.g. a SummaryWriter).
        len_dataset: number of graphs in the dataset; used to normalise the loss.

    Returns:
        ``(mean_loss, 0, output)``. The 0 is a placeholder accuracy. ``output`` is
        the last batch's model output, or None if the loader yielded no batches.
    """
    model.train()
    loss_all = 0
    output = None

    for data in data_loader:
        optimizer.zero_grad()
        model.zero_grad()
        data = data.to(device)

        output = model(data)
        label = torch.unsqueeze(data.y, 1).float()

        loss = crit(output, label)
        loss.backward()
        # Weight by graph count so the epoch mean is per graph, not per batch.
        loss_all += data.num_graphs * loss.item()

        optimizer.step()

    # Normalise by this call's len_dataset. It used to read the global train_dataset.
    mean_loss = loss_all / len_dataset if len_dataset else 0.0
    writer.add_scalar('training loss', mean_loss, epoch)

    return mean_loss, 0, output

# New Evaluation Function

In [15]:
def test(model, data_loader, epoch, crit, device, writer, len_dataset):
    model.eval()
    loss_all = 0
    all_preds = None
    all_labels = None
    
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            data = data.to(device)
            output = model(data)
            label = torch.unsqueeze(data.y.to(device), 1).float()
            if all_labels is not None:
                all_labels = torch.cat([all_labels, label])
            else:
                all_labels = label

            loss = crit(output, label)
            loss_all += data.num_graphs * loss.item()
            loss.detach()
    #         preds = output.data.max(1, keepdim=True)[1]

            if all_preds is not None:
                all_preds = torch.cat([all_preds, output])
            else:
                all_preds = output
    
    writer.add_scalar('validation loss',
                     loss_all / len_dataset,
                     epoch)
    
    all_preds, all_labels = all_preds.to('cpu'), all_labels.to('cpu')
    
    return loss_all / len(valid_dataset), 0, output, all_preds.detach().numpy(), all_labels.detach().numpy()

# New general `train_model` function that runs the full training loop

In [16]:
def train_model(model, 
                train_dataset,
                valid_dataset,
                crit, 
                optimizer, 
                lr, 
                experiment_label,
                momentum=0.8,
                device='cpu', 
                n_epochs=100, 
                batch_size=8, 
                n_workers=4,
                shuffle_data=True,
                valid_every=9):
    """Train ``model`` for ``n_epochs``, validating periodically and saving stats.

    Args:
        model: the model to train (moved to ``device``).
        train_dataset, valid_dataset: PyG datasets wrapped in DataLoaders.
        crit: loss function ``crit(output, label)``.
        optimizer: optimizer class, called as ``optimizer(params, lr=lr, momentum=momentum)``.
        lr, momentum: optimizer hyper-parameters.
        experiment_label: prefix for the TensorBoard run, checkpoints and stats files.
        device: training device.
        n_epochs: number of training epochs.
        batch_size, n_workers, shuffle_data: DataLoader settings (used for both loaders).
        valid_every: run a validation pass when ``epoch % valid_every == 0``
            (positive int; the default 9 matches the previous hard-coded value).

    Side effects:
        - Logs to ``runs/<label>``.
        - Writes checkpoints to ``models/`` and loss/accuracy CSVs to ``stats/``,
          creating both directories if missing.
    """
    import os

    writer = SummaryWriter(f'runs/{experiment_label}')
    try:
        model = model.to(device)
        optimizer = optimizer(model.parameters(), lr=lr, momentum=momentum)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_data, num_workers=n_workers)
        valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=shuffle_data, num_workers=n_workers)

        # Initial validation pass provides a pre-training baseline.
        valid_loss, valid_acc, _, _, _ = test(model, valid_loader, 0, crit, device, writer, len(valid_dataset))
        best_acc = valid_acc
        valid_losses, valid_accuracies = [valid_loss], [valid_acc]
        train_losses, train_accuracies = [], []

        for epoch in tqdm(range(n_epochs)):
            epoch_loss, epoch_acc, _ = train(model, train_loader, epoch, crit, optimizer, device, writer, len(train_dataset))
            train_losses.append(epoch_loss)
            train_accuracies.append(epoch_acc)

            # Every `valid_every` epochs (starting at epoch 0), run a validation pass.
            if epoch % valid_every == 0:
                valid_loss, valid_acc, _, _, _ = test(model, valid_loader, epoch, crit, device, writer, len(valid_dataset))
                valid_losses.append(valid_loss)
                valid_accuracies.append(valid_acc)

                # Checkpoint only on a genuine improvement, and remember the new best.
                # NOTE(review): train/test currently return a constant 0 accuracy, so
                # this never fires; consider selecting on validation loss instead.
                if valid_acc > best_acc:
                    best_acc = valid_acc
                    os.makedirs('models', exist_ok=True)
                    save_model(model, f'models/{experiment_label}_best_valid_acc_epoch_{epoch}')
    finally:
        # Flush TensorBoard events even if training is interrupted.
        writer.close()

    # Output loss/acc stats to csv
    os.makedirs('stats', exist_ok=True)
    validation_stats = np.array([valid_losses, valid_accuracies])
    training_stats = np.array([train_losses, train_accuracies])
    np.savetxt(f'stats/{experiment_label}_validation_stats.csv', validation_stats, delimiter=',')
    np.savetxt(f'stats/{experiment_label}_training_stats.csv', training_stats, delimiter=',')
    

In [17]:
# dataset_name = 'test_dataset'
# dataset = TopLevelProofDataset()
# dataset.shuffle()

# train_dataset = dataset[:2*floor(len(dataset)/3)]
# valid_dataset = dataset[2*floor(len(dataset)/3):]
# print(len(valid_dataset))

# train_model(model=PaliwalNet(t=4),
#            train_dataset=train_dataset,
#            valid_dataset=valid_dataset,
#            crit=F.mse_loss,
#            optimizer=torch.optim.SGD,
#            lr=0.005,
#            momentum=0.8,
#            experiment_label='testing_new_functional_training',
#            device='cuda:1',
#            n_epochs=100,
#            batch_size=2,
#            n_workers=1,
#            shuffle_data=False)

In [18]:
# Datasets: Binary/Multiclass (B/M), OnlyTop/Subtheorems (O/S), Merged/Unmerged (m/u)
dataset_features = {'n_classes': ['M', 'B'],
                   'theorems_used': ['S', 'O'],
                   'subexpression_sharing': ['u', 'm']}

distinct_features = dict()

data = None
dataset_name = None
to_merge = None


def create_datasets():
    global data
    global dataset_name
    global to_merge
    
    for binary, x in enumerate(dataset_features['n_classes']):
        for only_top, y in enumerate(dataset_features['theorems_used']):
            for merged, z in enumerate(dataset_features['subexpression_sharing']):
                dataset_name = x + y + z
                data = u.make_data(binary=bool(binary), only_top=bool(only_top))
                data = list(set(data))
                data_distribution = get_data_distribution(data)
                to_merge = bool(merged)
                distinct_features[dataset_name] = get_num_distinct_features(data)
#                 dataset = TopLevelProofDataset()
                print(f'{dataset_name}({len(data)}): ', data_distribution)

# create_datasets()

In [19]:
# print(distinct_features)
# with open('distinct_features.p', 'wb') as outfile:
#     pickle.dump(distinct_features, outfile)
# Local cache written by this notebook. pickle.load can execute arbitrary code,
# so only load trusted files.
with open('distinct_features.p', 'rb') as infile:
    distinct_features = pickle.load(infile)
print(distinct_features)

{'MSu': 615, 'MSm': 615, 'MOu': 527, 'MOm': 527, 'BSu': 615, 'BSm': 615, 'BOu': 527, 'BOm': 527}


In [20]:
dataset_name = 'BSu'
dataset = TopLevelProofDataset()
# Dataset.shuffle() returns a shuffled copy; it does not shuffle in place.
dataset = dataset.shuffle()

# Two thirds for training, the rest for validation.
split_idx = 2*floor(len(dataset)/3)
train_dataset = dataset[:split_idx]
valid_dataset = dataset[split_idx:]
print(len(valid_dataset))

train_model(model=PaliwalNet(t=4),
           train_dataset=train_dataset,
           valid_dataset=valid_dataset,
           crit=F.mse_loss,
           optimizer=torch.optim.SGD,
           lr=0.005,
           momentum=0.8,
           experiment_label='BSu_4',
           device='cuda:1',
           n_epochs=100,
           batch_size=16,
           n_workers=8,
           shuffle_data=True)

8831


  1%|          | 1/100 [03:47<6:15:47, 227.75s/it]


KeyboardInterrupt: 

In [None]:
dataset_name = 'BSm'
dataset = TopLevelProofDataset()
# Dataset.shuffle() returns a shuffled copy; it does not shuffle in place.
dataset = dataset.shuffle()

# Two thirds for training, the rest for validation.
split_idx = 2*floor(len(dataset)/3)
train_dataset = dataset[:split_idx]
valid_dataset = dataset[split_idx:]
print(len(valid_dataset))

train_model(model=PaliwalNet(t=4),
           train_dataset=train_dataset,
           valid_dataset=valid_dataset,
           crit=F.mse_loss,
           optimizer=torch.optim.SGD,
           lr=0.005,
           momentum=0.8,
           experiment_label='BSm_4',
           device='cuda:1',
           n_epochs=100,
           batch_size=16,
           n_workers=8,
           shuffle_data=True)