In [None]:
import utilities as u

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Embedding
from torch.nn import Sequential as Seq, Linear, ReLU
from torch.utils.tensorboard import SummaryWriter

from torch_geometric.data import Data, Dataset, InMemoryDataset, DataLoader
from torch_geometric.nn import MessagePassing, TopKPooling, GCNConv, BatchNorm
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.utils import remove_self_loops, add_self_loops

import numpy as np
# import seaborn as sns
import pandas as pd

from math import floor
import os
import pickle
import random
from tqdm import tqdm
import wandb

In [None]:
# Single seed shared by every random number generator the notebook touches.
seed = 42

# NOTE(review): PYTHONHASHSEED is presumably only honoured by interpreters
# started after this point (e.g. worker processes) -- confirm.
os.environ['PYTHONHASHSEED'] = str(seed)

# Python and NumPy generators.
random.seed(seed)
np.random.seed(seed)

# PyTorch generators: CPU, the current GPU, and every GPU (multi-GPU runs).
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Ask cuDNN for deterministic kernels and switch off its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
def save_model(model, PATH):
    """Serialise the parameters of ``model`` (its ``state_dict``) to ``PATH``."""
    weights = model.state_dict()
    torch.save(weights, PATH)
    
def load_model(model_type, PATH, **kwargs):
    """
    Build a fresh ``model_type(**kwargs)`` and fill it with the weights stored at ``PATH``.

    Args:
        model_type: model class (or factory) to instantiate.
        PATH: file written by ``save_model`` (a pickled ``state_dict``).
        **kwargs: forwarded unchanged to ``model_type``.

    Returns:
        The newly constructed model with the saved parameters loaded.
    """
    model = model_type(**kwargs)
    # Deserialize onto the CPU: the training cells save from 'cuda:1', and
    # without `map_location` the checkpoint could not be read on a machine
    # lacking that device. `load_state_dict` copies the values into the
    # model's own parameters, so the model's device is unaffected.
    state = torch.load(PATH, map_location='cpu')
    model.load_state_dict(state)
    return model

In [None]:
# Manual smoke test of the `utilities` pipeline on one hand-written theorem.
# Disabled by default; flip to True to print every intermediate result.
dummy_test = False

if dummy_test:
    test_thm = '(fun (a A B) (a A (a A B)))'
    print(test_thm)
    thm = u.process_theorem(test_thm)
    print(thm)
    # NOTE(review): the dataset classes call thm_to_tree(thm, to_merge) and
    # graph_to_data(tree, normalized_features); the one-argument calls below
    # presumably rely on defaults in `utilities` -- confirm they still work.
    thm_tree, _ = u.thm_to_tree(thm)
    print(len(thm_tree))
    print(thm_tree.subtrees[0].parents[0])
    thm_tree = u.merge_subexpressions(thm_tree)
    x = u.graph_to_data(thm_tree)
    print(x)

    #print([t.root for t in thm_tree.subtrees[0].subtrees])


    # Walk the merged tree: the root, its children, then the grandchildren.
    print(thm_tree.root)
    print([t.root for t in thm_tree.subtrees])
    # Unpacking assumes the root has exactly two subtrees for this theorem.
    t_0, t_1 = thm_tree.subtrees
    print([t.root for t in t_0.subtrees])
    print([t.root for t in t_1.subtrees])
    print(t_1.subtrees[0].subtree_str)
    print(len(thm_tree))

# Define Dataset

In [None]:
class TopLevelProofDataset(InMemoryDataset):
    """
    In-memory graph dataset built from the module-level ``data`` list.

    ``process`` reads three notebook globals, which must be set before the
    dataset is instantiated: ``data`` (an iterable of ``(theorem, label)``
    pairs), ``to_merge`` (whether sub-expressions are merged) and
    ``dataset_name`` (used for the cache file name). The collated result is
    saved to ``processed_paths[0]`` and loaded from there by ``__init__``.
    """
    def __init__(self, root='', transform=None, pre_transform=None):
        super(TopLevelProofDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Nothing to download: the raw examples come from the global `data`.
        return []

    @property
    def processed_file_names(self):
        return [f'../datasets/{dataset_name}.dataset']

    def download(self):
        pass

    def process(self):
        """Convert every ``(theorem, label)`` pair in ``data`` to a graph and cache the collation."""
        all_features = set()
        trees = []

        # Pass 1: parse every theorem and collect the node vocabulary.
        for thm, y in tqdm(data):
            thm = u.process_theorem(thm)
            tree, distinct_features = u.thm_to_tree(thm, to_merge)
            all_features = all_features | distinct_features
            trees.append((tree, y))

        # Give each feature an integer id. Sorting makes the assignment
        # reproducible; iterating the raw set would depend on hash order and
        # could differ from one interpreter run to the next.
        normalized_features = {k: [i] for i, k in enumerate(sorted(all_features, key=str))}

        # Pass 2: turn each tree into a PyG `Data` object.
        data_list = []
        for tree, y in tqdm(trees):
            merged_tree = u.merge_subexpressions(tree) if to_merge else tree
            # Fix: the original passed `tree` here, so the merged tree was
            # computed and then ignored.
            x, (edge_index_up, edge_index_down), (edge_features_up, edge_features_down) = u.graph_to_data(
                merged_tree, normalized_features)
            # Up-edges are stored first, down-edges second (same for attributes).
            datum = Data(x=x,
                         y=y,
                         edge_index=torch.cat((edge_index_up, edge_index_down), dim=1),
                         edge_attr=torch.cat((edge_features_up, edge_features_down)),
                         )
            data_list.append(datum)

        all_data, slices = self.collate(data_list)
        torch.save((all_data, slices), self.processed_paths[0])

In [None]:
# Number of distinct theorems ProofDataset collects; it writes one processed
# file per theorem. (A later cell rebinds this name to 30.)
n_graphs = 10
# Upper bound on how many source files ProofDataset.process scans.
n_files = 600

class ProofDataset(Dataset):
    """
    On-disk dataset: collects up to ``n_graphs`` distinct theorems from the
    first ``n_files`` source files and stores each graph in its own file,
    ``<processed_dir>/{dataset_name}/data_{i}.pt``.

    Reads the notebook globals ``dataset_name``, ``binary``, ``to_merge``,
    ``n_graphs`` and ``n_files``. ``process`` also rebinds the global ``data``
    to the collected ``(theorem, label)`` pairs.
    """
    def __init__(self, root='', transform=None, pre_transform=None):
        super(ProofDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return [f'{dataset_name}/data_{i}.pt' for i in range(n_graphs)]

    def download(self):
        pass

    def _graph_path(self, idx):
        """Path of the processed file holding graph ``idx``."""
        # Built from `processed_dir` so reads and writes use the same
        # directory as `processed_file_names`, whatever `root` is.
        return os.path.join(self.processed_dir, dataset_name, f'data_{idx}.pt')

    def process(self):
        """Collect theorems, convert each to a graph and save one file per graph."""
        global data
        print(f'{dataset_name}: binary={bool(binary)}, merge={to_merge}')

        os.makedirs(os.path.join(self.processed_dir, dataset_name), exist_ok=True)

        # theorem -> smallest label seen for it; stop once n_graphs are collected.
        data_dict = dict()
        for i in range(n_files):
            if len(data_dict) >= n_graphs:
                break

            for thm, y in u.get_data_from_file(i, binary, only_top=False):
                if thm in data_dict:
                    data_dict[thm] = min(data_dict[thm], y)
                    continue

                data_dict[thm] = y
                count = len(data_dict)
                # Fix: report the number actually collected (the original
                # printed count + 1).
                if count % 100 == 0:
                    print(f'{count}/{n_graphs} : {(count/n_graphs) *100:.2f}%', end='\r', flush=True)
                if count >= n_graphs:
                    break

        # Fix: the original comprehension used the stale loop variable `thm`,
        # so every entry held the last theorem read instead of its own key.
        data = list(data_dict.items())

        # Pass 1: parse every theorem and collect the node vocabulary.
        all_features = set()
        trees = []
        for thm, y in tqdm(data):
            thm = u.process_theorem(thm)
            tree, distinct_features = u.thm_to_tree(thm, to_merge)
            all_features = all_features | distinct_features
            trees.append((tree, y))

        print()
        # Sorted so that feature ids do not depend on set iteration order.
        normalized_features = {feature: [i] for i, feature in enumerate(sorted(all_features, key=str))}

        # Pass 2: one `Data` object, and one file, per theorem.
        for idx, (tree, y) in tqdm(enumerate(trees)):
            merged_tree = u.merge_subexpressions(tree) if to_merge else tree
            # Fix: pass the merged tree (the original passed `tree`).
            x, (edge_index_up, edge_index_down), (edge_features_up, edge_features_down) = u.graph_to_data(
                merged_tree, normalized_features)
            datum = Data(x=x,
                         y=y,
                         edge_index=torch.cat((edge_index_up, edge_index_down), dim=1),
                         edge_attr=torch.cat((edge_features_up, edge_features_down)),
                         )
            torch.save(datum, self._graph_path(idx))

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        datum = torch.load(self._graph_path(idx))
        return datum

# Model 2 (Subgraph Pooling Paper)

In [None]:
class MLP(torch.nn.Module):
    """
    Three-layer ELU perceptron (``in_channels`` -> 64 -> 32 -> 32) with
    dropout (p=0.3) after the first two layers.

    Args:
        in_channels: size of the input features.
        out_channels: accepted for interface compatibility; the output width
            is fixed at 32 regardless of this value.
        dim: forwarded to the first ``Linear`` layer as its ``bias`` flag
            (the third positional argument of ``nn.Linear``); the default 0
            therefore builds that layer without a bias term.
    """
    def __init__(self, in_channels, out_channels, dim=0):
        super(MLP, self).__init__()
        # Same behaviour as the original positional call, spelled out so the
        # bias switch is visible.
        self.lin1 = Linear(in_channels, 64, bias=bool(dim))
        self.hidden = Linear(64, 32)
        self.lin2 = Linear(32, 32)

    def forward(self, x):
        x = F.elu(self.lin1(x))
        # Fix: F.dropout defaults to training=True, so without this flag the
        # dropout stayed active after model.eval().
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.elu(self.hidden(x))
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.elu(self.lin2(x))
        return x

    
class PaliwalMP(MessagePassing):
    """
    One message-passing round of the Subgraph Pooling paper: messages travel
    separately along the "up" and "down" edge sets, are mean-aggregated per
    node, and are then combined with the node state through a residual MLP.
    """

    def __init__(self, in_channels, out_channels):
        # Mean aggregation, messages flowing from source to target.
        super(PaliwalMP, self).__init__(aggr='mean', flow='source_to_target')

        # Every MLP consumes three concatenated `in_channels`-wide vectors.
        width = 3*in_channels

        # Step 2: one edge MLP per direction (parents / children).
        self.MLP_edge = BuildingBlock(width, in_channels)
        self.MLP_edge_hat = BuildingBlock(width, in_channels)

        # Step 3: MLP applied to [x, parent messages, child messages].
        self.MLP_aggr = BuildingBlock(width, in_channels)

    def forward(self, x, edge_index_parents, edge_index_children, edge_attr_parents, edge_attr_children):
        # Per the original notes: x is [N, in_channels] and each edge index
        # is [2, E/2].
        from_parents = self.propagate(edge_index_parents,
                                      x=x,
                                      edge_attr=edge_attr_parents,
                                      direction='up',
                                      size=None)
        from_children = self.propagate(edge_index_children,
                                       x=x,
                                       edge_attr=edge_attr_children,
                                       direction='down',
                                       size=None)

        stacked = torch.cat([x, from_parents, from_children], dim=1)
        # Residual connection around the aggregation MLP.
        return self.MLP_aggr(stacked) + x

    def message(self, x_i, x_j, edge_attr, direction):
        # Message input: target state, source state and edge embedding.
        pair = torch.cat([x_i, x_j, edge_attr], dim=1)
        if direction == 'up':
            return self.MLP_edge(pair)
        if direction == 'down':
            return self.MLP_edge_hat(pair)
        # Any other direction leaves the concatenation untouched.
        return pair

    def update(self, aggr_out, x):
        # The aggregated messages are returned as-is; the residual update
        # happens in forward().
        return aggr_out


In [None]:
# Width of the learned node embeddings; also passed as the channel size of
# every PaliwalMP layer created by PaliwalNet.
embed_dim = 128

class BuildingBlock(torch.nn.Module):
    """
    Three-layer ELU perceptron (``in_channels`` -> 256 -> 128 -> 128) with
    dropout (p=0.3) after the first two layers; the MLP used throughout the
    Subgraph Pooling model.

    Args:
        in_channels: size of the input features.
        out_channels: accepted for interface compatibility; the output width
            is fixed at 128 regardless of this value.
    """
    def __init__(self, in_channels, out_channels):
        super(BuildingBlock, self).__init__()
        self.lin1 = Linear(in_channels, 256)
        self.hidden = Linear(256, 128)
        self.lin2 = Linear(128, 128)

    def forward(self, x):
        x = F.elu(self.lin1(x))
        # Fix: F.dropout defaults to training=True, so without this flag the
        # dropout stayed active during evaluation.
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.elu(self.hidden(x))
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.elu(self.lin2(x))
        return x
    

class PaliwalNet(torch.nn.Module):
    """
    GNN from the Subgraph Pooling model: embeds the nodes of arbitrary-size
    graphs, runs ``t`` rounds of up/down message passing, max-pools per graph
    and feeds the pooled vector to a prediction head.

    Args:
        t: number of message-passing rounds (0 disables message passing).
        no_upsample: stored on the instance; not used by this class.
        sigmoid: apply a sigmoid to the single-output head (takes precedence
            over ``softmax`` when both are true).
        softmax: build a 4-way head and apply a softmax over the classes.

    Reads the notebook globals ``distinct_features``, ``dataset_name`` and
    ``embed_dim`` at construction time.

    NOTE(review): with ``softmax=True`` the output is already a probability
    vector; ``nn.CrossEntropyLoss`` presumably expects unnormalised logits --
    confirm this pairing is intended.
    """
    def __init__(self, t, no_upsample=False, sigmoid=True, softmax=False):
        super(PaliwalNet, self).__init__()

        self.no_upsample = no_upsample
        self.using_softmax = softmax

        # One extra row beyond the number of distinct node features.
        self.embedding = Embedding(num_embeddings=distinct_features[dataset_name]+1, embedding_dim=embed_dim)

        self.MLP_V = BuildingBlock(128, 128)
        if t > 0:
            # The edge MLP is only needed when messages are passed.
            self.MLP_E = BuildingBlock(1, 128)

        self.message_passing_steps = nn.ModuleList()
        for _ in range(t):
            self.message_passing_steps.append(PaliwalMP(embed_dim, embed_dim))

        # Kernel-size-1 convolutions: per-node expansion 128 -> 512 -> 1024.
        self.conv1 = nn.Conv1d(128, 512, 1)
        self.conv2 = nn.Conv1d(512, 1024, 1)

        # TODO: Try removing some layers or adding batch norm
        self.lin1 = Linear(1024, 512)
        self.lin2 = Linear(512, 256)
        self.lin3 = Linear(256, 128)
        self.lin4 = Linear(128, 128)
        if self.using_softmax:
            self.lin5 = Linear(128, 4)
        else:
            self.lin5 = Linear(128, 1)

        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(128)
        if sigmoid:
            self.sigmoid = nn.Sigmoid()
            self.using_sigmoid = True
        else:
            self.using_sigmoid = False
        if softmax:
            # Fix: normalise over the class dimension. The original dim=0
            # normalised each class column across the graphs of the batch.
            self.softmax = nn.Softmax(dim=1)

    @staticmethod
    def _split_up_down(edge_index, edge_attr, batch):
        """
        Separate the "up" edges from the "down" edges of a (batched) graph.

        Relies on the layout produced by the dataset classes in this notebook:
        each graph stores its up-edges first and its down-edges second, and a
        mini-batch concatenates the graphs one after another. The first half
        of every graph's own edge range is therefore "up". Halving the whole
        batched tensor, as the original code did, is only correct for a
        single graph.

        Returns:
            ``(edge_index_up, edge_index_down, edge_attr_up, edge_attr_down)``.
        """
        n_edges = edge_index.shape[1]
        position = torch.arange(n_edges, device=edge_index.device)
        if batch is None or n_edges == 0:
            # Single un-batched graph: plain first half / second half.
            is_up = position < n_edges // 2
        else:
            edge_graph = batch[edge_index[0]]  # graph id of every edge
            per_graph = torch.bincount(edge_graph, minlength=int(batch.max()) + 1)
            first_edge = torch.cumsum(per_graph, dim=0) - per_graph
            within_graph = position - first_edge[edge_graph]
            # Integer form of: within_graph < per_graph // 2
            is_up = 2 * within_graph + 1 < per_graph[edge_graph]
        is_down = ~is_up
        return edge_index[:, is_up], edge_index[:, is_down], edge_attr[is_up], edge_attr[is_down]

    def forward(self, data):
        """Return one prediction row per graph in ``data`` (1 or 4 columns, depending on the head)."""
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # Generate learnable embeddings for node features
        x = x.squeeze(-1)
        x = self.embedding(x)
        x = self.MLP_V(x)

        if len(self.message_passing_steps) > 0:
            # Edges are only needed, and only split, when messages are passed.
            edge_index_u, edge_index_d, edge_attr_u, edge_attr_d = self._split_up_down(
                edge_index, edge_attr, batch)

            # Embed edge features into high dimensional space
            edge_attr_u = self.MLP_E(edge_attr_u.float())
            edge_attr_d = self.MLP_E(edge_attr_d.float())

            for message_passing_step in self.message_passing_steps:
                x = message_passing_step(x, edge_index_u, edge_index_d, edge_attr_u, edge_attr_d)

        # Conv1d wants [batch, channels, length]: treat all nodes as one sequence.
        x = x.transpose(0,1).unsqueeze(0)
        x = F.elu(self.conv1(x))
        x = F.elu(self.conv2(x))
        x = x.squeeze(0).transpose(0,1)

        # Final prediction network: max-pool the nodes of each graph.
        x = gmp(x, batch)

        x = F.elu(self.lin1(x))
        x = self.bn1(x)
        x = F.elu(self.lin2(x))
        x = self.bn2(x)
        x = F.elu(self.lin3(x))
        x = self.bn3(x)
        x = F.elu(self.lin4(x))
        x = self.bn4(x)
        if self.using_sigmoid:
            x = self.sigmoid(self.lin5(x))
        elif self.using_softmax:
            x = self.softmax(self.lin5(x))
        else:
            x = self.lin5(x)

        return x

# Data inspections

In [None]:
def get_data_distribution(data):
    """
    Share of each target value in ``data``.

    ``data`` is a sized collection of ``(example, target)`` pairs. Returns a
    list of ``(target, percentage)`` tuples ordered by ascending target.
    """
    tally = {}
    for _, target in data:
        tally[target] = tally.get(target, 0) + 1

    by_target = sorted(tally.items(), key=lambda pair: pair[0])
    return [(target, hits/len(data)*100) for target, hits in by_target]

In [None]:
def get_num_distinct_features(data):
    """
    Count the distinct node values over every theorem in ``data``.

    Each theorem is parsed (without merging sub-expressions) and the feature
    sets reported by ``u.thm_to_tree`` are unioned; the size of that union is
    returned.
    """
    found = set()
    for raw_thm, _ in tqdm(data):
        parsed = u.process_theorem(raw_thm)
        _tree, features = u.thm_to_tree(parsed, to_merge=False)
        found.update(features)
    return len(found)

# New Train Function

In [None]:
def train(model, data_loader, epoch, crit, optimizer, device, len_dataset):
    """
    Run one training epoch and log the mean loss to wandb.

    Args:
        model: network to optimise (switched to training mode here).
        data_loader: iterable of graph batches exposing ``.y`` and ``.num_graphs``.
        epoch: zero-based epoch index; the loss is logged at step ``epoch + 1``.
        crit: loss callable invoked as ``crit(output, label)``.
        optimizer: optimiser stepping the model's parameters.
        device: device every batch is moved to.
        len_dataset: number of graphs in the dataset, used to average the loss.

    Returns:
        ``(mean_loss, 0)``; the second slot is an accuracy placeholder.
    """
    model.train()
    loss_all = 0

    for data in data_loader:
        optimizer.zero_grad()
        model.zero_grad()
        data = data.to(device)

        output = model(data)
        # Labels are cast to long, i.e. treated as class indices.
        label = data.y.to(device).long()

        loss = crit(output, label)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the number of graphs in the batch.
        loss_all += data.num_graphs * loss.item()

    # Fix: average with the `len_dataset` argument. The original returned
    # loss_all / len(train_dataset), silently reading a notebook global.
    mean_loss = loss_all / len_dataset
    wandb.log({"Train Loss": mean_loss},
              step=epoch+1)

    return mean_loss, 0

# New Evaluation Function

In [None]:
def test(model, data_loader, epoch, crit, device, len_dataset):
    """
    Evaluate ``model`` on ``data_loader`` without gradients and log the mean loss to wandb.

    Args:
        model: network to evaluate (switched to eval mode here).
        data_loader: iterable of graph batches exposing ``.y`` and ``.num_graphs``.
        epoch: zero-based epoch index; the loss is logged at step ``epoch + 1``.
        crit: loss callable invoked as ``crit(output, label)``.
        device: device every batch is moved to.
        len_dataset: number of graphs in the dataset, used to average the loss.

    Returns:
        ``(mean_loss, 0, predictions, labels)`` where the last two are NumPy
        arrays and ``predictions`` holds the arg-max over output dimension 1.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    loss_all = 0
    batch_preds, batch_labels = [], []

    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)
            output = model(data)
            label = data.y.to(device).long()

            loss = crit(output, label)
            loss_all += data.num_graphs * loss.item()

            batch_labels.append(label.cpu())
            # Index of the highest-scoring output for every graph.
            _, predicted = torch.max(output, dim=1)
            batch_preds.append(predicted.cpu())

    if not batch_labels:
        raise ValueError('test() received an empty data_loader')

    # Concatenate once instead of growing a tensor inside the loop.
    all_labels = torch.cat(batch_labels).numpy()
    all_preds = torch.cat(batch_preds).numpy()

    # Fix: average with the `len_dataset` argument. The original returned
    # loss_all / len(valid_dataset), silently reading a notebook global.
    mean_loss = loss_all / len_dataset
    wandb.log({"Test Loss": mean_loss},
              step=epoch+1)

    return mean_loss, 0, all_preds, all_labels

# New general 'train_model' function which handles full training

In [None]:
def train_model(train_dataset,
                valid_dataset,
                crit,
                experiment_label):
    """
    Train and validate a model end to end, driven by ``wandb.config``.

    ``wandb.init`` must have been called first: model type, optimiser, batch
    size, device, epoch count and so on are all read from ``wandb.config``.

    Args:
        train_dataset: dataset used for optimisation.
        valid_dataset: dataset evaluated before training and every 5th epoch.
        crit: loss callable passed to ``train`` and ``test``.
        experiment_label: name used for every output file.

    Side effects:
        Writes per-evaluation prediction CSVs under ``plotting/<label>/``, the
        final weights under ``models/`` and the loss histories under ``stats/``.

    Raises:
        ValueError: if ``config.model`` or ``config.optimizer`` is not supported.
    """
    config = wandb.config

    train_loader = DataLoader(train_dataset,
                              pin_memory=True,
                              batch_size=config.batch_size,
                              shuffle=config.shuffle_data,
                              num_workers=config.n_workers
                              )

    valid_loader = DataLoader(valid_dataset,
                              pin_memory=True,
                              batch_size=config.batch_size,
                              shuffle=config.shuffle_data,
                              num_workers=config.n_workers,
                              )

    if not os.path.exists(f'plotting/{experiment_label}'):
        os.makedirs(f'plotting/{experiment_label}')
        print('Made it')
    # Fix: create the other output directories up front, so a missing folder
    # cannot fail the run after training has finished.
    for out_dir in ('models', 'stats'):
        os.makedirs(out_dir, exist_ok=True)

    if config.model == 'PaliwalNet':
        if 'M' in config.dataset:
            # Multiclass datasets get the 4-way softmax head.
            model = PaliwalNet(t=config.message_passing_steps, sigmoid=False, softmax=True)
        else:
            model = PaliwalNet(t=config.message_passing_steps)
    else:
        raise ValueError(f'Unsupported model: {config.model!r}')
    model = model.to(config.device)

    if config.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum)
    elif config.optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    else:
        raise ValueError(f'Unsupported optimizer: {config.optimizer!r}')

    wandb.watch(model, log="all")

    # Run initial pass through validation loop (untrained baseline)
    valid_loss, valid_acc, preds, labels = test(model, valid_loader, 0, crit, config.device, len(valid_dataset))
    best_loss = valid_loss
    wandb.log({"Best Loss": best_loss},
              step=1)

    df = pd.DataFrame(data={'Predictions': preds, 'Labels': labels})
    df['epoch'] = 0
    df.to_csv(f'plotting/{experiment_label}/{experiment_label}_epoch_0')

    valid_losses, valid_accuracies = [valid_loss], [valid_acc]
    train_losses, train_accuracies = [], []

    # Train for n_epochs
    for epoch in tqdm(range(config.n_epochs)):
        epoch_loss, epoch_acc = train(model, train_loader, epoch, crit, optimizer, config.device, len(train_dataset))
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)

        # Every 5 epochs, run through a validation loop
        if epoch % 5 == 4:
            valid_loss, valid_acc, preds, labels = test(model, valid_loader, epoch, crit, config.device, len(valid_dataset))
            valid_losses.append(valid_loss)
            valid_accuracies.append(valid_acc)

            # Record the lowest observed validation loss
            if valid_loss < best_loss:
                best_loss = valid_loss

            wandb.log({"Best Loss": best_loss},
                      step=epoch+1)

            df = pd.DataFrame(data={'Predictions': preds, 'Labels': labels})
            df['epoch'] = epoch+1
            df.to_csv(f'plotting/{experiment_label}/{experiment_label}_epoch_{epoch+1}')

    # Save the final weights and upload them to wandb
    save_model(model, f'models/{experiment_label}_final_{config.n_epochs}')
    wandb.save(f'models/{experiment_label}_final_{config.n_epochs}')

    # Output loss/acc stats to csv (row 0: losses, row 1: accuracies)
    validation_stats = np.array([valid_losses, valid_accuracies])
    training_stats = np.array([train_losses, train_accuracies])
    np.savetxt(f'stats/{experiment_label}_validation_stats.csv', validation_stats, delimiter=',')
    np.savetxt(f'stats/{experiment_label}_training_stats.csv', training_stats, delimiter=',')
    

In [None]:
# Datasets: Binary/Multiclass (B/M), OnlyTop/Subtheorems (O/S), Merged/Unmerged (m/u)
dataset_features = {'n_classes': ['M', 'B'],
                   'theorems_used': ['S', 'O'],
                   'subexpression_sharing': ['u', 'm']}

# {dataset name: number of distinct node features}; PaliwalNet sizes its
# embedding table from this. Loaded from a local pickle -- only unpickle
# files you trust. The context manager closes the handle that the original
# `pickle.load(open(...))` leaked.
with open('distinct_features.p', 'rb') as features_file:
    distinct_features = pickle.load(features_file)

# Notebook-wide state read by the dataset classes and by create_datasets().
data = None
dataset_name = None
to_merge = None
binary = None
only_top = None
n_graphs = 30

seen = None


# def create_datasets():
#     global data
#     global dataset_name
#     global to_merge
#     global binary
    
#     for binary, x in enumerate(dataset_features['n_classes']):
#         for merged, z in enumerate(dataset_features['subexpression_sharing']):
#             dataset_name = x + z
# #                 data = u.make_data(binary=bool(binary), only_top=bool(only_top))
#             to_merge = bool(merged)
#             dataset = ProofDataset()
#             distinct_features[dataset_name] = get_num_distinct_features(data)
#             data_distribution = get_data_distribution(data)
#             print(f'{dataset_name}({len(data)}): ', data_distribution)
            

def create_datasets():
    """
    Build the class-balanced multiclass datasets 'M2u' and 'M2m'.

    For each sub-expression-sharing mode the function loads the raw
    ``(theorem, label)`` pairs via ``u.make_data``, removes duplicates (keeping
    the smallest label per theorem), truncates every class to the size of the
    smallest one, records the number of distinct node features in
    ``distinct_features`` and instantiates ``TopLevelProofDataset``.

    Communicates with the dataset class through the globals ``data``,
    ``dataset_name``, ``to_merge`` and ``binary``, all of which it rebinds.
    """
    global data
    global dataset_name
    global to_merge
    global binary
    
    for binary, x in enumerate(dataset_features['n_classes']):
        # The loop variables are overridden right away and the loop ends with
        # `break`, so only one pass runs, always with binary=0 and prefix 'M2'.
        binary = 0
        x = 'M2'
        for merged, z in enumerate(dataset_features['subexpression_sharing']):
            dataset_name = x + z
            data = u.make_data(binary=bool(binary), only_top=False)
            # NOTE(review): set() discards ordering, so which examples survive
            # the truncation below may differ between runs -- confirm acceptable.
            data = list(set(data))
            # Collapse repeated theorems, keeping the smallest label seen.
            data_dict = dict()
            for t,y in data:
                if t in data_dict.keys():
                    data_dict[t] = min(data_dict[t], y)
                else:
                    data_dict[t] = y
            
            data = [(t,y) for t,y in data_dict.items()]
            # Not reached while `binary` is forced to 0 above.
            if binary == 1:
                data_0 = [(t,y) for t,y in data if y == 0]
                data_1 = [(t,y) for t,y in data if y == 1]
                min_len = min(len(data_0), len(data_1))
                data_0 = [(t,y) for t,y in data_0[:min_len]]
                data_1 = [(t,y) for t,y in data_1[:min_len]]
                data = data_0 + data_1
            else:
                # Balance the four classes (labels 0..3) by truncating each to
                # the size of the smallest class.
                temp_data = []
                for i in range(4):
                    temp_data.append([(t,y) for t,y in data if y == i])
                min_len = min([len(temp_data[j]) for j in range(4)])
                for i in range(4):
                    temp_data[i] = temp_data[i][:min_len]
                data = temp_data[0] + temp_data[1] + temp_data[2] + temp_data[3]
            
            data_distribution = get_data_distribution(data)
            # `to_merge` must be set before TopLevelProofDataset reads it.
            to_merge = bool(merged)
            distinct_features[dataset_name] = get_num_distinct_features(data)
            print(data_distribution)
            # Instantiated for its side effect; presumably this triggers
            # processing when no cached file exists -- confirm.
            dataset = TopLevelProofDataset()

            
            print(f'{dataset_name}({len(data)}): ', data_distribution)
        break
        

# create_datasets()
# pickle.dump(distinct_features, open( 'distinct_features.p', 'wb' ))
# Reload the cached {dataset name: distinct-feature count} table. The context
# manager closes the handle that `pickle.load(open(...))` left open. Only
# unpickle files you trust.
with open('distinct_features.p', 'rb') as features_file:
    distinct_features = pickle.load(features_file)

In [None]:
# Experiment M2u0_ce: unmerged dataset, no message passing, cross-entropy loss.
dataset_name = 'M2u'
experiment_label = 'M2u0_ce'

config = dict(
    model='PaliwalNet',
    message_passing_steps=0,
    batch_size=16,
    n_workers=4,
    n_epochs=101,
    optimizer='SGD',
    lr=0.01,
    momentum=0.8,
    weight_decay=0,
    SWA=False,
    device='cuda:1',
    seed=42,
    dataset=dataset_name,
    shuffle_data=True,
)

wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Shuffle once, then train on roughly the first two thirds and validate on the rest.
dataset = TopLevelProofDataset().shuffle()
split_at = 2*floor(len(dataset)/3)
train_dataset, valid_dataset = dataset[:split_at], dataset[split_at:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

train_model(train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            crit=nn.CrossEntropyLoss(),
            experiment_label=experiment_label)

In [None]:
# Experiment M2u4_ce: unmerged dataset, 4 message-passing steps, cross-entropy loss.
dataset_name = 'M2u'
experiment_label = 'M2u4_ce'

config = dict(
    model='PaliwalNet',
    message_passing_steps=4,
    batch_size=16,
    n_workers=4,
    n_epochs=101,
    optimizer='SGD',
    lr=0.01,
    momentum=0.8,
    weight_decay=0,
    SWA=False,
    device='cuda:1',
    seed=42,
    dataset=dataset_name,
    shuffle_data=True,
)

wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Shuffle once, then train on roughly the first two thirds and validate on the rest.
dataset = TopLevelProofDataset().shuffle()
split_at = 2*floor(len(dataset)/3)
train_dataset, valid_dataset = dataset[:split_at], dataset[split_at:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

train_model(train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            crit=nn.CrossEntropyLoss(),
            experiment_label=experiment_label)

In [None]:
# Experiment M2m0: merged dataset, no message passing, MSE loss.
# NOTE(review): train()/test() cast labels to long and, for 'M' datasets, the
# model emits 4 scores per graph -- confirm F.mse_loss accepts that pairing.
dataset_name = 'M2m'
experiment_label = 'M2m0'

config = dict(
    model='PaliwalNet',
    message_passing_steps=0,
    batch_size=16,
    n_workers=4,
    n_epochs=101,
    optimizer='SGD',
    lr=0.01,
    momentum=0.8,
    weight_decay=0,
    SWA=False,
    device='cuda:1',
    seed=42,
    dataset=dataset_name,
    shuffle_data=True,
)

wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Shuffle once, then train on roughly the first two thirds and validate on the rest.
dataset = TopLevelProofDataset().shuffle()
split_at = 2*floor(len(dataset)/3)
train_dataset, valid_dataset = dataset[:split_at], dataset[split_at:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

train_model(train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            crit=F.mse_loss,
            experiment_label=experiment_label)

In [None]:
# Always fails: presumably a deliberate stop so that "Run All" does not
# continue into the experiment cells below -- confirm before removing.
assert 1 == 2

In [None]:
# Experiment M2u1: unmerged dataset, 1 message-passing step, MSE loss.
# NOTE(review): train()/test() cast labels to long and, for 'M' datasets, the
# model emits 4 scores per graph -- confirm F.mse_loss accepts that pairing.
dataset_name = 'M2u'
experiment_label = 'M2u1'

config = dict(
    model='PaliwalNet',
    message_passing_steps=1,
    batch_size=16,
    n_workers=4,
    n_epochs=101,
    optimizer='SGD',
    lr=0.01,
    momentum=0.8,
    weight_decay=0,
    SWA=False,
    device='cuda:1',
    seed=42,
    dataset=dataset_name,
    shuffle_data=True,
)

wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Shuffle once, then train on roughly the first two thirds and validate on the rest.
dataset = TopLevelProofDataset().shuffle()
split_at = 2*floor(len(dataset)/3)
train_dataset, valid_dataset = dataset[:split_at], dataset[split_at:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

train_model(train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            crit=F.mse_loss,
            experiment_label=experiment_label)

In [None]:
# Experiment M2u2: unmerged dataset, 2 message-passing steps, MSE loss.
# NOTE(review): train()/test() cast labels to long and, for 'M' datasets, the
# model emits 4 scores per graph -- confirm F.mse_loss accepts that pairing.
dataset_name = 'M2u'
experiment_label = 'M2u2'

config = dict(
    model='PaliwalNet',
    message_passing_steps=2,
    batch_size=16,
    n_workers=4,
    n_epochs=101,
    optimizer='SGD',
    lr=0.01,
    momentum=0.8,
    weight_decay=0,
    SWA=False,
    device='cuda:1',
    seed=42,
    dataset=dataset_name,
    shuffle_data=True,
)

wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Shuffle once, then train on roughly the first two thirds and validate on the rest.
dataset = TopLevelProofDataset().shuffle()
split_at = 2*floor(len(dataset)/3)
train_dataset, valid_dataset = dataset[:split_at], dataset[split_at:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

train_model(train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            crit=F.mse_loss,
            experiment_label=experiment_label)

In [None]:
# Experiment M2u3: unmerged dataset, 3 message-passing steps, MSE loss.
# NOTE(review): train()/test() cast labels to long and, for 'M' datasets, the
# model emits 4 scores per graph -- confirm F.mse_loss accepts that pairing.
dataset_name = 'M2u'
experiment_label = 'M2u3'

config = dict(
    model='PaliwalNet',
    message_passing_steps=3,
    batch_size=16,
    n_workers=4,
    n_epochs=101,
    optimizer='SGD',
    lr=0.01,
    momentum=0.8,
    weight_decay=0,
    SWA=False,
    device='cuda:1',
    seed=42,
    dataset=dataset_name,
    shuffle_data=True,
)

wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Shuffle once, then train on roughly the first two thirds and validate on the rest.
dataset = TopLevelProofDataset().shuffle()
split_at = 2*floor(len(dataset)/3)
train_dataset, valid_dataset = dataset[:split_at], dataset[split_at:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

train_model(train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            crit=F.mse_loss,
            experiment_label=experiment_label)

In [None]:
# Experiment M2u4: unmerged dataset, 4 message-passing steps, MSE loss.
# NOTE(review): train()/test() cast labels to long and, for 'M' datasets, the
# model emits 4 scores per graph -- confirm F.mse_loss accepts that pairing.
dataset_name = 'M2u'
experiment_label = 'M2u4'

config = dict(
    model='PaliwalNet',
    message_passing_steps=4,
    batch_size=16,
    n_workers=4,
    n_epochs=101,
    optimizer='SGD',
    lr=0.01,
    momentum=0.8,
    weight_decay=0,
    SWA=False,
    device='cuda:1',
    seed=42,
    dataset=dataset_name,
    shuffle_data=True,
)

wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Shuffle once, then train on roughly the first two thirds and validate on the rest.
dataset = TopLevelProofDataset().shuffle()
split_at = 2*floor(len(dataset)/3)
train_dataset, valid_dataset = dataset[:split_at], dataset[split_at:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

train_model(train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            crit=F.mse_loss,
            experiment_label=experiment_label)

In [None]:
# Always fails: presumably a deliberate stop so that "Run All" does not
# continue into the merged-dataset experiment cells below -- confirm before removing.
assert 1 == 2

In [None]:
# Experiment M2m1: merged dataset, 1 message-passing step, MSE loss.
# NOTE(review): train()/test() cast labels to long and, for 'M' datasets, the
# model emits 4 scores per graph -- confirm F.mse_loss accepts that pairing.
dataset_name = 'M2m'
experiment_label = 'M2m1'

config = dict(
    model='PaliwalNet',
    message_passing_steps=1,
    batch_size=16,
    n_workers=4,
    n_epochs=101,
    optimizer='SGD',
    lr=0.01,
    momentum=0.8,
    weight_decay=0,
    SWA=False,
    device='cuda:1',
    seed=42,
    dataset=dataset_name,
    shuffle_data=True,
)

wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Shuffle once, then train on roughly the first two thirds and validate on the rest.
dataset = TopLevelProofDataset().shuffle()
split_at = 2*floor(len(dataset)/3)
train_dataset, valid_dataset = dataset[:split_at], dataset[split_at:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

train_model(train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            crit=F.mse_loss,
            experiment_label=experiment_label)

In [None]:
# Experiment M2m2: merged dataset, 2 message-passing steps, MSE loss.
# NOTE(review): train()/test() cast labels to long and, for 'M' datasets, the
# model emits 4 scores per graph -- confirm F.mse_loss accepts that pairing.
dataset_name = 'M2m'
experiment_label = 'M2m2'

config = dict(
    model='PaliwalNet',
    message_passing_steps=2,
    batch_size=16,
    n_workers=4,
    n_epochs=101,
    optimizer='SGD',
    lr=0.01,
    momentum=0.8,
    weight_decay=0,
    SWA=False,
    device='cuda:1',
    seed=42,
    dataset=dataset_name,
    shuffle_data=True,
)

wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Shuffle once, then train on roughly the first two thirds and validate on the rest.
dataset = TopLevelProofDataset().shuffle()
split_at = 2*floor(len(dataset)/3)
train_dataset, valid_dataset = dataset[:split_at], dataset[split_at:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

train_model(train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            crit=F.mse_loss,
            experiment_label=experiment_label)

In [None]:
# Experiment 'M2m3': PaliwalNet with 3 message-passing steps, SGD optimizer.
# NOTE(review): TopLevelProofDataset() and train_model() receive neither
# `dataset_name` nor `config`; presumably they read these notebook-level
# variables — confirm before renaming them.
dataset_name = 'M2m'
experiment_label = 'M2m3'

# Hyperparameters of this run; passed to Weights & Biases as the run config.
config = {
    'model': 'PaliwalNet',
    'message_passing_steps': 3,
    'batch_size': 16,
    'n_workers': 4,
    'n_epochs': 101,
    'optimizer': 'SGD',
    'lr': 0.01,
    'momentum': 0.8,
    'weight_decay': 0,
    'SWA': False,
    'device': 'cuda:1',
    'seed': 42,
    'dataset': dataset_name,
    'shuffle_data': True,
}

wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Re-apply the configured seed right before shuffling. The shuffle draws from
# the global torch RNG, so without this the train/validation split would depend
# on which cells were executed earlier in the kernel.
torch.manual_seed(config['seed'])
dataset = TopLevelProofDataset().shuffle()

# First two thirds (rounded down) for training, the remainder for validation.
n_train_examples = 2 * (len(dataset) // 3)
train_dataset = dataset[:n_train_examples]
valid_dataset = dataset[n_train_examples:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

# Train with mean-squared-error as the criterion.
train_model(train_dataset=train_dataset, valid_dataset=valid_dataset, crit=F.mse_loss, experiment_label=experiment_label)

In [None]:
# Experiment 'M2m4': PaliwalNet with 4 message-passing steps, SGD optimizer.
# NOTE(review): TopLevelProofDataset() and train_model() receive neither
# `dataset_name` nor `config`; presumably they read these notebook-level
# variables — confirm before renaming them.
dataset_name = 'M2m'
experiment_label = 'M2m4'

# Hyperparameters of this run; passed to Weights & Biases as the run config.
config = {
    'model': 'PaliwalNet',
    'message_passing_steps': 4,
    'batch_size': 16,
    'n_workers': 4,
    'n_epochs': 101,
    'optimizer': 'SGD',
    'lr': 0.01,
    'momentum': 0.8,
    'weight_decay': 0,
    'SWA': False,
    'device': 'cuda:1',
    'seed': 42,
    'dataset': dataset_name,
    'shuffle_data': True,
}

wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Re-apply the configured seed right before shuffling. The shuffle draws from
# the global torch RNG, so without this the train/validation split would depend
# on which cells were executed earlier in the kernel.
torch.manual_seed(config['seed'])
dataset = TopLevelProofDataset().shuffle()

# First two thirds (rounded down) for training, the remainder for validation.
n_train_examples = 2 * (len(dataset) // 3)
train_dataset = dataset[:n_train_examples]
valid_dataset = dataset[n_train_examples:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

# Train with mean-squared-error as the criterion.
train_model(train_dataset=train_dataset, valid_dataset=valid_dataset, crit=F.mse_loss, experiment_label=experiment_label)

In [None]:
# Always-false assertion: raises AssertionError whenever this cell executes,
# so a "Run All" stops here and the cells below are not run.
# NOTE(review): presumably an intentional barrier between experiment groups — confirm.
assert 1 == 2

In [None]:
# Experiment 'Bu12_2': PaliwalNet with 12 message-passing steps, SGD optimizer.
# W&B logging and training are commented out below, so running this cell only
# builds the train/validation split and prints its sizes.
# NOTE(review): TopLevelProofDataset() receives neither `dataset_name` nor
# `config`; presumably it reads these notebook-level variables — confirm.
dataset_name = 'Bu'
experiment_label = 'Bu12_2'

# Hyperparameters of this run.
config = {
    'model': 'PaliwalNet',
    'message_passing_steps': 12,
    'batch_size': 16,
    'n_workers': 4,
    'n_epochs': 250,
    'optimizer': 'SGD',
    'lr': 0.005,
    'momentum': 0.8,
    # NOTE(review): None here, whereas the other experiment configs use 0.
    # A None weight decay would likely fail if forwarded unchanged to
    # torch.optim.SGD — confirm how train_model consumes it before re-enabling.
    'weight_decay': None,
    'SWA': False,
    'device': 'cuda:1',
    'seed': 42,
    'dataset': dataset_name,
    'shuffle_data': True,
}

# wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Re-apply the configured seed right before shuffling. The shuffle draws from
# the global torch RNG, so without this the train/validation split would depend
# on which cells were executed earlier in the kernel.
torch.manual_seed(config['seed'])
dataset = TopLevelProofDataset().shuffle()

# First two thirds (rounded down) for training, the remainder for validation.
n_train_examples = 2 * (len(dataset) // 3)
train_dataset = dataset[:n_train_examples]
valid_dataset = dataset[n_train_examples:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

# train_model(train_dataset=train_dataset, valid_dataset=valid_dataset, crit=F.mse_loss, experiment_label=experiment_label)

In [None]:
# Experiment 'Bm1_3': PaliwalNet with 1 message-passing step, Adam optimizer.
# NOTE(review): TopLevelProofDataset() and train_model() receive neither
# `dataset_name` nor `config`; presumably they read these notebook-level
# variables — confirm before renaming them.
dataset_name = 'Bm'
experiment_label = 'Bm1_3'

# Hyperparameters of this run; passed to Weights & Biases as the run config.
config = {
    'model': 'PaliwalNet',
    'message_passing_steps': 1,
    'batch_size': 16,
    'n_workers': 4,
    'n_epochs': 101,
    'optimizer': 'Adam',
    'lr': 0.001,
    'momentum': None,
    'weight_decay': 0,
    'SWA': False,
    'device': 'cuda:1',
    'seed': 42,
    'dataset': dataset_name,
    'shuffle_data': True,
}

wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Re-apply the configured seed right before shuffling. The shuffle draws from
# the global torch RNG, so without this the train/validation split would depend
# on which cells were executed earlier in the kernel.
torch.manual_seed(config['seed'])
dataset = TopLevelProofDataset().shuffle()

# First two thirds (rounded down) for training, the remainder for validation.
n_train_examples = 2 * (len(dataset) // 3)
train_dataset = dataset[:n_train_examples]
valid_dataset = dataset[n_train_examples:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

# Train with mean-squared-error as the criterion.
train_model(train_dataset=train_dataset, valid_dataset=valid_dataset, crit=F.mse_loss, experiment_label=experiment_label)

In [None]:
# Experiment 'Bm4_3': PaliwalNet with 4 message-passing steps, Adam optimizer.
# NOTE(review): TopLevelProofDataset() and train_model() receive neither
# `dataset_name` nor `config`; presumably they read these notebook-level
# variables — confirm before renaming them.
dataset_name = 'Bm'
experiment_label = 'Bm4_3'

# Hyperparameters of this run; passed to Weights & Biases as the run config.
config = {
    'model': 'PaliwalNet',
    'message_passing_steps': 4,
    'batch_size': 16,
    'n_workers': 4,
    'n_epochs': 101,
    'optimizer': 'Adam',
    'lr': 0.001,
    'momentum': None,
    'weight_decay': 0,
    'SWA': False,
    'device': 'cuda:1',
    'seed': 42,
    'dataset': dataset_name,
    'shuffle_data': True,
}

wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Re-apply the configured seed right before shuffling. The shuffle draws from
# the global torch RNG, so without this the train/validation split would depend
# on which cells were executed earlier in the kernel.
torch.manual_seed(config['seed'])
dataset = TopLevelProofDataset().shuffle()

# First two thirds (rounded down) for training, the remainder for validation.
n_train_examples = 2 * (len(dataset) // 3)
train_dataset = dataset[:n_train_examples]
valid_dataset = dataset[n_train_examples:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

# Train with mean-squared-error as the criterion.
train_model(train_dataset=train_dataset, valid_dataset=valid_dataset, crit=F.mse_loss, experiment_label=experiment_label)

In [None]:
# Experiment 'Bm12_2': PaliwalNet with 12 message-passing steps, SGD optimizer.
# W&B logging and training are commented out below, so running this cell only
# builds the train/validation split and prints its sizes.
# NOTE(review): TopLevelProofDataset() receives neither `dataset_name` nor
# `config`; presumably it reads these notebook-level variables — confirm.
dataset_name = 'Bm'
experiment_label = 'Bm12_2'

# Hyperparameters of this run.
config = {
    'model': 'PaliwalNet',
    'message_passing_steps': 12,
    'batch_size': 16,
    'n_workers': 4,
    'n_epochs': 250,
    'optimizer': 'SGD',
    'lr': 0.005,
    'momentum': 0.8,
    # NOTE(review): None here, whereas the other experiment configs use 0.
    # A None weight decay would likely fail if forwarded unchanged to
    # torch.optim.SGD — confirm how train_model consumes it before re-enabling.
    'weight_decay': None,
    'SWA': False,
    'device': 'cuda:1',
    'seed': 42,
    'dataset': dataset_name,
    'shuffle_data': True,
}

# wandb.init(name=experiment_label, project="complexity-prediction", config=config)

# Re-apply the configured seed right before shuffling. The shuffle draws from
# the global torch RNG, so without this the train/validation split would depend
# on which cells were executed earlier in the kernel.
torch.manual_seed(config['seed'])
dataset = TopLevelProofDataset().shuffle()

# First two thirds (rounded down) for training, the remainder for validation.
n_train_examples = 2 * (len(dataset) // 3)
train_dataset = dataset[:n_train_examples]
valid_dataset = dataset[n_train_examples:]
print(f'Training Data: {len(train_dataset)}, Validation Data: {len(valid_dataset)}')

# train_model(train_dataset=train_dataset, valid_dataset=valid_dataset, crit=F.mse_loss, experiment_label=experiment_label)