In [1]:
import math
import torch
from torch_geometric.data import InMemoryDataset, Data, Batch
from torch.utils.data import DataLoader, Subset, random_split
from sklearn.calibration import LabelEncoder
import json 

  from .autonotebook import tqdm as notebook_tqdm


In [461]:
from torch import nn
from torch.nn import Linear, ReLU, LeakyReLU, init
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, GATv2Conv, TransformerConv, LayerNorm
import numpy as np
import torch.optim.lr_scheduler as lr_scheduler


In [3]:
import torch
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [556]:
DATASET_FILEPATH = "math_datagen_triplet_40k.json"
BATCH_SIZE = 256
TRAIN_RATIO = 0.9
VAL_RATIO = 0.1
DISTANCE_FUNC = F.pairwise_distance
margin = 0.8
rep = 16

In [234]:
TYPES = ["FUNC", "OPERATION", "POW", "LITERAL", "VARIABLE", "CONSTANT_LITERAL"]
OPERATIONS = ["ADD", "MUL", "FUNC", "POW"]
FUNCTIONS = ["SIN", "COS", "TAN", "EXP", "LOG", "SINH", "COSH", "TANH"]
VARIABLE_ALPHABET = [chr(x) for x in range(ord('a'), ord('z')+1) if chr(x) not in ["e", "i"]]
CLASSIC_CONSTANTS = ["PI", "I", "g","e", "zoo", "E"]

In [235]:

def dict_to_geometric_representation(in_graph_dict: dict, encoder) -> Data:
    node_list = []
    edge_mappings = []
    def traverse_graph(graph = in_graph_dict):
        nonlocal node_list
        nonlocal edge_mappings
        curr_node_index = len(node_list)
        encoded_data = encoder({"type": graph["type"], "subtype": graph["subtype"], "value": graph["value"]})
        # print("Encoded data: ", encoded_data)
        node_list.append(encoded_data)
        if "children" in graph.keys():
            for child in graph["children"]:
                edge_mappings.append((curr_node_index, traverse_graph(child)) ) #I'm retarded.
        return curr_node_index
    traverse_graph()
    nodes = torch.tensor(node_list,dtype=torch.float32)
    if not torch.all(torch.isfinite(nodes)):
        print("NODES CONTAIN INFINITE VALUE ---> \n")
        print("X = ", nodes)
    nodes[nodes == float("inf")] = 0
    nodes[nodes == -float("inf")] = 0
    if not torch.all(torch.isfinite(nodes)):
        print("MODIFIED NODES CONTAIN INFINITE VALUE ---> \n")
        print("X = ", nodes)
    edges = torch.tensor([[x[0] for x in edge_mappings], [x[1] for x in edge_mappings]], dtype=torch.long) # Probably slow and mentally degenerated
    geom_data = Data(x=nodes, edge_index=edges)
    return geom_data

#### Encoder

In [236]:
label_encoder = LabelEncoder()
label_encoder.fit(OPERATIONS+FUNCTIONS+CLASSIC_CONSTANTS+VARIABLE_ALPHABET+TYPES)
# --- misc
nan_count = 0 
sample_count = 0

In [237]:
def sanitize_value(value):
    if value in ["-Infinity","Infinity", float("inf"), float("-inf"), "inf", "-inf"]:
        return None
    return value

In [238]:
def node_attr_encoder(attr):
    global nan_count
    global sample_count
    
    def safe_label_encode(value):
        return label_encoder.transform([value])[0] if value else -1

    type_encoding = [safe_label_encode(attr.get("type"))] * rep
    subtype_encoding = [safe_label_encode(attr.get("subtype"))] * rep
    
    value_encoding_vec = [-1] * rep  # Default to all -1 if no valid value
    
    if sanitize_value(attr.get("value")) is not None:
        if isinstance(attr["value"], str):
            value_encoding = safe_label_encode(attr["value"])
            value_encoding_vec = [(0 if i % 2 == 0 else value_encoding) for i in range(rep)]
        else:
            try:
                val = float(attr["value"])
                if math.isfinite(val):
                    expon = math.log10(abs(val))
                    val = val // 10**int(expon)
                    value_encoding_vec = [(expon if i % 2 == 1 else val) for i in range(rep)]
                else:
                    raise ValueError("Non-finite value")
            except (ValueError, TypeError):
                nan_count += 1
                value_encoding_vec = [(-1 if i % 2 == 1 else 1) for i in range(rep)]
            finally:
                sample_count += 1

    result = type_encoding + subtype_encoding + value_encoding_vec
    return result

#### Loading in the dataset

In [239]:
triplet_dataset = []

In [240]:
# print('"value": Infinity,'.find("Infinity"))

In [241]:
with open(DATASET_FILEPATH) as file_handle:
    object_data = json.load(file_handle)
    for triplet in object_data:
        ## TODO : check if contains infinity (triplet to str and then search for Infinity)
        # print(triplet)
        # print(json.dumps(triplet).find("inf"))
        
        if json.dumps(triplet).find("inf") == -1:
            expr_true = triplet["expr_true"]
            expr_false = triplet["expr_false"]
            expr_anchor = triplet["expr_anchor"]
            geom_anchor = dict_to_geometric_representation(expr_anchor,node_attr_encoder )
            geom_true = dict_to_geometric_representation(expr_true,node_attr_encoder )
            geom_false = dict_to_geometric_representation(expr_false,node_attr_encoder )
            triplet_dataset.append((geom_anchor, geom_true, geom_false))
        
        

In [242]:
# with open("test","w") as f:
#     f.write(str(triplet_dataset))

#### Neural network training

In [294]:
num_features = triplet_dataset[0][0].num_features

In [504]:
class FormulaNet(nn.Module):
    def __init__(self, hidden_channels: int, embedding_space: int):
        super(FormulaNet, self).__init__()
        self.dense_1 = Linear(num_features, num_features*2) 
        self.dense_2 = Linear(num_features*2, hidden_channels) 
        
        self.relu_1 = LeakyReLU(0.2)
        self.gconv_1 = GCNConv(num_features*2, hidden_channels)
        self.gconv_2 = GCNConv(hidden_channels, hidden_channels)
        self.trans1 = TransformerConv(hidden_channels, hidden_channels, heads=1)
        self.trans2 = TransformerConv(hidden_channels, hidden_channels, heads=1)
        self.layer_norm = LayerNorm(in_channels=hidden_channels)
        
        self.dense_3 = Linear(hidden_channels, embedding_space)
        gain = torch.nn.init.calculate_gain("leaky_relu", 0.2)
        # Initialize dense_1 weights
        init.xavier_uniform_(self.dense_1.weight, gain=gain)
        init.xavier_uniform_(self.dense_2.weight, gain=gain)
        # # Initialize gconv_1 weights
        init.xavier_uniform_(self.gconv_1.lin.weight.data.T, gain=gain)
        # # Initialize gconv_2 weights
        init.xavier_uniform_(self.gconv_2.lin.weight.data.T, gain=gain)
        #Initialize dense_3 weights
        init.xavier_uniform_(self.dense_3.weight, gain=gain)
    
    def forward(self, x, edge_index, batch):
        if not torch.all(torch.isfinite(x)):
            print("INPUT CONTAINS INFINITE VALUE")
            #print("x = ", x)
            


        nan_mask = torch.isnan(x)
        if nan_mask.any():
            print("Some values of the input are nan: ", x[nan_mask])
            # Replace NaNs with 0s
            x[nan_mask] = -11
            
        
        xd = self.dense_1(x)
        xd = self.relu_1(xd)
        
        xd = self.dense_2(xd)
        xd = self.relu_1(xd)
        xc = xd
        # xc = self.gconv_1(xd, edge_index)
        # xc = self.relu_1(xc)
        # xc = self.gconv_2(xc, edge_index)
        # xc = self.relu_1(xc)
        xt = self.trans1(xc, edge_index)
        xt = self.layer_norm(xt)
        xt = self.relu_1(xt) + xc
        
        xt2 = self.trans2(xt, edge_index)
        xt2 = self.layer_norm(xt2)
        xt2 = self.relu_1(xt2) + xt
        
        # xc = self.gconv_3(xc, edge_index)
        # xc = self.relu_1(xc)
        xp = global_mean_pool(xt, batch)
        xp = F.dropout(xp, p=0.4,training=self.training)
        xp = self.dense_3(xp)
        xp = self.relu_1(xp)
        # nan_mask = torch.isnan(x)
        # if nan_mask.any():
        #     print("Some values of the input are nan: ", x[nan_mask])
            
        return xp
    

In [505]:
class TripletFormulaNet(nn.Module):
    def __init__(self, hidden_channels, embedding_space):
        super(TripletFormulaNet, self).__init__()
        self.formulanet = FormulaNet(hidden_channels, embedding_space)

    def forward(self, anchor, positive, negative):
        # Compute the embeddings for the anchor, positive, and negative inputs
        embed_anchor = self.formulanet(anchor.x, anchor.edge_index, anchor.batch)
        embed_positive = self.formulanet(positive.x, positive.edge_index, positive.batch)
        embed_negative = self.formulanet(negative.x, negative.edge_index, negative.batch)
        
        return embed_anchor, embed_positive, embed_negative

#### Training

In [506]:
# ## CHECK DATASET:
# with open("test", "w") as f:
#     for a, b,c in triplet_dataset:
#         f.write(str(a.x) + "\n" + str(b.x) + "\n" + str(c.x) + "\n")

In [557]:
def triplet_loss(anchor, positive, negative, margin=1.0):
    """
    Triplet loss function.
    
    Args:
        anchor (torch.Tensor): Embedding of the anchor data point.
        positive (torch.Tensor): Embedding of the positive data point (same class as anchor).
        negative (torch.Tensor): Embedding of the negative data point (different class from anchor).
        margin (float, optional): Margin value for triplet loss. Defaults to 1.0.
        
    Returns:
        torch.Tensor: Triplet loss value.
    """
    
    # Calculate the Euclidean distances between the embeddings
    distance_positive = DISTANCE_FUNC(anchor, positive)
    distance_negative = DISTANCE_FUNC(anchor, negative)
    
    # Calculate the triplet loss
    losses = F.relu(distance_positive - distance_negative + margin)
    
    return losses.mean()

In [508]:
total_len = len(triplet_dataset)
train_len = int(TRAIN_RATIO * total_len)
val_len = int(VAL_RATIO * total_len)
test_len = total_len - train_len - val_len

train_dataset, val_dataset, test_dataset = random_split(triplet_dataset, [train_len, val_len, test_len])

In [509]:
def collate(data_list):
    batchA = Batch.from_data_list([data[0] for data in data_list])
    batchB = Batch.from_data_list([data[1] for data in data_list])
    batchC = Batch.from_data_list([data[2] for data in data_list])
    return batchA, batchB, batchC

In [510]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)
validation_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle= False, collate_fn=collate)

In [511]:
device = torch.device("cpu") #torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [512]:
model = TripletFormulaNet(32,32).to(device)

In [513]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.08)
scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=20)

In [514]:
from tqdm import tqdm

In [515]:
training_loss = []  # Store training loss per epoch
validation_loss = []  # Store validation loss per epoch

In [558]:
def train(epoch):
    model.train()
    epoch_loss = 0
    for batch_anchor, batch_positive, batch_negative in tqdm(train_loader, desc=f'Epoch {epoch}'):

        batch_anchor, batch_positive, batch_negative = batch_anchor.to(device), batch_positive.to(device), batch_negative.to(device)
        optimizer.zero_grad()
        embed_anchor, embed_positive, embed_negative = model(batch_anchor, batch_positive, batch_negative)
        loss = triplet_loss(embed_anchor, embed_positive, embed_negative, margin=margin)
        loss.backward()
        writer.add_scalar("Loss/train", loss, epoch)
        
        # Log gradients
        for name, param in model.formulanet.named_parameters():
            if param.grad is not None and param.grad.numel() > 0:
                writer.add_histogram(f'{name}.grad', param.grad, epoch)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        epoch_loss += loss.item()
    scheduler.step(epoch)
    training_loss.append(epoch_loss / len(train_loader))  # Store epoch average loss
    print(f'Epoch {epoch}, Loss: {epoch_loss / len(train_loader)}')

In [665]:
def validate(model, validation_loader, device):
    model.eval()
    triplet_loss_tot = 0
    correct_triplets = 0
    total_triplets = 0
    total_dist_pos = 0
    total_dist_neg = 0
    

    with torch.no_grad():
        for batch_anchor, batch_positive, batch_negative in validation_loader:
            batch_anchor, batch_positive, batch_negative = batch_anchor.to(device), batch_positive.to(device), batch_negative.to(device)
            embed_anchor, embed_positive, embed_negative = model(batch_anchor, batch_positive, batch_negative)

            # Compute triplet loss
            loss = triplet_loss(embed_anchor, embed_positive, embed_negative, margin=margin)
            triplet_loss_tot += loss.item() * batch_anchor.size(0)

            # Compute triplet accuracy
            distance_positive = DISTANCE_FUNC(embed_anchor, embed_positive)
            total_dist_pos += distance_positive.sum().item()
            distance_negative = DISTANCE_FUNC(embed_anchor, embed_negative)
            total_dist_neg += distance_negative.sum().item()
            correct_triplets += (distance_positive < distance_negative).sum().item() ## WARNING: FLIP THIS FOR PAIRWISE DISTANCE
            total_triplets += distance_positive.size(0)

    triplet_loss_tot /= len(validation_loader.dataset)
    triplet_accuracy = correct_triplets / total_triplets

    print(f'Validation Triplet Loss: {triplet_loss_tot}, Total Positive: {total_dist_pos}, Total Negative: {total_dist_neg}, Triplet Accuracy: {triplet_accuracy}')

    return triplet_loss_tot, triplet_accuracy

In [666]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [667]:
num_epochs = 180

for epoch in range(num_epochs):
    train(epoch)
    validate(model, validation_loader, device)

Epoch 0: 100%|██████████| 242/242 [00:21<00:00, 11.51it/s]


Epoch 0, Loss: 30.063697422092613
Validation Triplet Loss: 0.0, Total Positive: 7992883.0, Total Negative: 344426187.0, Triplet Accuracy: 1.0


Epoch 1: 100%|██████████| 242/242 [00:28<00:00,  8.59it/s]


Epoch 1, Loss: 481.9965731122277
Validation Triplet Loss: 230.98377247522043, Total Positive: 27894855.46875, Total Negative: 276834152.0, Triplet Accuracy: 0.9988372093023256


Epoch 2: 100%|██████████| 242/242 [00:25<00:00,  9.56it/s]


Epoch 2, Loss: 161.93854469922948
Validation Triplet Loss: 13.006577201152957, Total Positive: 31894754.03125, Total Negative: 488497330.0, Triplet Accuracy: 0.9995639534883721


Epoch 3: 100%|██████████| 242/242 [00:23<00:00, 10.49it/s]


Epoch 3, Loss: 71.02616464469797
Validation Triplet Loss: 39.67728822952093, Total Positive: 24611526.40625, Total Negative: 383518867.0, Triplet Accuracy: 0.9995639534883721


Epoch 4: 100%|██████████| 242/242 [00:23<00:00, 10.51it/s]


Epoch 4, Loss: 52.593781448531246
Validation Triplet Loss: 163.06359451427016, Total Positive: 27953039.25, Total Negative: 422779087.0, Triplet Accuracy: 0.9991279069767441


Epoch 5: 100%|██████████| 242/242 [00:21<00:00, 11.10it/s]


Epoch 5, Loss: 35.514980834624
Validation Triplet Loss: 0.06969922392000985, Total Positive: 23358191.1875, Total Negative: 623367506.0, Triplet Accuracy: 0.9998546511627907


Epoch 6: 100%|██████████| 242/242 [00:22<00:00, 10.99it/s]


Epoch 6, Loss: 23.22179131985696
Validation Triplet Loss: 0.0, Total Positive: 28204556.0, Total Negative: 512559648.0, Triplet Accuracy: 1.0


Epoch 7: 100%|██████████| 242/242 [00:21<00:00, 11.31it/s]


Epoch 7, Loss: 17.389588415314837
Validation Triplet Loss: 0.0, Total Positive: 24106878.875, Total Negative: 508957399.0, Triplet Accuracy: 1.0


Epoch 8: 100%|██████████| 242/242 [00:21<00:00, 11.37it/s]


Epoch 8, Loss: 217.73738372030337
Validation Triplet Loss: 0.0, Total Positive: 31487323.25, Total Negative: 600619147.0, Triplet Accuracy: 1.0


Epoch 9: 100%|██████████| 242/242 [00:21<00:00, 11.08it/s]


Epoch 9, Loss: 15.38154704142208
Validation Triplet Loss: 2.1916814277338426, Total Positive: 22137353.1875, Total Negative: 475578692.0, Triplet Accuracy: 0.9997093023255814


Epoch 10: 100%|██████████| 242/242 [00:21<00:00, 11.08it/s]


Epoch 10, Loss: 12.777739079530575
Validation Triplet Loss: 0.0, Total Positive: 29297259.1875, Total Negative: 538301194.0, Triplet Accuracy: 1.0


Epoch 11: 100%|██████████| 242/242 [00:20<00:00, 11.62it/s]


Epoch 11, Loss: 9.441539081660183
Validation Triplet Loss: 0.0, Total Positive: 20829352.40625, Total Negative: 570778992.0, Triplet Accuracy: 1.0


Epoch 12: 100%|██████████| 242/242 [00:24<00:00, 10.06it/s]


Epoch 12, Loss: 15.538283782557022
Validation Triplet Loss: 18.902538000151168, Total Positive: 24916791.5625, Total Negative: 493266826.0, Triplet Accuracy: 0.9997093023255814


Epoch 13: 100%|██████████| 242/242 [00:25<00:00,  9.42it/s]


Epoch 13, Loss: 6.530799108842188
Validation Triplet Loss: 0.0, Total Positive: 20598192.53125, Total Negative: 555266644.0, Triplet Accuracy: 1.0


Epoch 14: 100%|██████████| 242/242 [00:25<00:00,  9.51it/s]


Epoch 14, Loss: 0.5559566112096644
Validation Triplet Loss: 589.6643986435823, Total Positive: 17791293.375, Total Negative: 540516440.0, Triplet Accuracy: 0.9991279069767441


Epoch 15: 100%|██████████| 242/242 [00:24<00:00,  9.90it/s]


Epoch 15, Loss: 2.282290394641151
Validation Triplet Loss: 0.0, Total Positive: 17950666.75, Total Negative: 580590590.0, Triplet Accuracy: 1.0


Epoch 16: 100%|██████████| 242/242 [00:25<00:00,  9.57it/s]


Epoch 16, Loss: 0.9761088911659461
Validation Triplet Loss: 0.0, Total Positive: 15244991.5625, Total Negative: 527902710.0, Triplet Accuracy: 1.0


Epoch 17: 100%|██████████| 242/242 [00:24<00:00, 10.03it/s]


Epoch 17, Loss: 2.217487446779062
Validation Triplet Loss: 0.0, Total Positive: 15809339.5, Total Negative: 633774654.0, Triplet Accuracy: 1.0


Epoch 18: 100%|██████████| 242/242 [00:25<00:00,  9.60it/s]


Epoch 18, Loss: 0.8506611033904651
Validation Triplet Loss: 0.0, Total Positive: 14238225.46875, Total Negative: 575638500.0, Triplet Accuracy: 1.0


Epoch 19: 100%|██████████| 242/242 [00:25<00:00,  9.47it/s]


Epoch 19, Loss: 0.0963055338741334
Validation Triplet Loss: 0.0, Total Positive: 13006113.09375, Total Negative: 511959101.0, Triplet Accuracy: 1.0


Epoch 20: 100%|██████████| 242/242 [00:24<00:00,  9.71it/s]


Epoch 20, Loss: 123.38641229739859
Validation Triplet Loss: 0.0, Total Positive: 12836765.984375, Total Negative: 550946918.0, Triplet Accuracy: 1.0


Epoch 21: 100%|██████████| 242/242 [00:25<00:00,  9.50it/s]


Epoch 21, Loss: 0.305904793541509
Validation Triplet Loss: 0.0, Total Positive: 12849269.046875, Total Negative: 563387512.0, Triplet Accuracy: 1.0


Epoch 22: 100%|██████████| 242/242 [00:25<00:00,  9.64it/s]


Epoch 22, Loss: 0.0
Validation Triplet Loss: 0.0, Total Positive: 12867070.953125, Total Negative: 564021614.0, Triplet Accuracy: 1.0


Epoch 23: 100%|██████████| 242/242 [00:25<00:00,  9.45it/s]


Epoch 23, Loss: 0.21064087970197692
Validation Triplet Loss: 0.0, Total Positive: 12985252.421875, Total Negative: 569424738.0, Triplet Accuracy: 1.0


Epoch 24: 100%|██████████| 242/242 [00:25<00:00,  9.64it/s]


Epoch 24, Loss: 0.17763658988574318
Validation Triplet Loss: 0.0, Total Positive: 12982321.609375, Total Negative: 573304578.0, Triplet Accuracy: 1.0


Epoch 25: 100%|██████████| 242/242 [00:27<00:00,  8.73it/s]


Epoch 25, Loss: 0.11784408112202795
Validation Triplet Loss: 0.0, Total Positive: 12887673.296875, Total Negative: 569723678.0, Triplet Accuracy: 1.0


Epoch 26: 100%|██████████| 242/242 [00:20<00:00, 11.90it/s]


Epoch 26, Loss: 0.12685945112843158
Validation Triplet Loss: 0.0, Total Positive: 12622971.796875, Total Negative: 566891288.0, Triplet Accuracy: 1.0


Epoch 27: 100%|██████████| 242/242 [00:20<00:00, 11.83it/s]


Epoch 27, Loss: 0.12632038967668519
Validation Triplet Loss: 0.0, Total Positive: 12650957.515625, Total Negative: 567275488.0, Triplet Accuracy: 1.0


Epoch 28: 100%|██████████| 242/242 [00:20<00:00, 11.97it/s]


Epoch 28, Loss: 0.07995320942776263
Validation Triplet Loss: 0.0, Total Positive: 12713020.515625, Total Negative: 562902050.0, Triplet Accuracy: 1.0


Epoch 29: 100%|██████████| 242/242 [00:20<00:00, 11.87it/s]


Epoch 29, Loss: 10.302037392766023
Validation Triplet Loss: 0.0, Total Positive: 12537609.390625, Total Negative: 567839508.0, Triplet Accuracy: 1.0


Epoch 30: 100%|██████████| 242/242 [00:20<00:00, 11.60it/s]


Epoch 30, Loss: 0.32270776240293647
Validation Triplet Loss: 0.0, Total Positive: 12331730.5625, Total Negative: 561376348.0, Triplet Accuracy: 1.0


Epoch 31: 100%|██████████| 242/242 [00:21<00:00, 11.00it/s]


Epoch 31, Loss: 0.5612136608312938
Validation Triplet Loss: 0.0, Total Positive: 11797181.890625, Total Negative: 544152732.0, Triplet Accuracy: 1.0


Epoch 32: 100%|██████████| 242/242 [00:21<00:00, 11.44it/s]


Epoch 32, Loss: 0.0
Validation Triplet Loss: 0.0, Total Positive: 11811612.21875, Total Negative: 548568328.0, Triplet Accuracy: 1.0


Epoch 33: 100%|██████████| 242/242 [00:19<00:00, 12.18it/s]


Epoch 33, Loss: 0.11422355312946414
Validation Triplet Loss: 0.0, Total Positive: 11800251.09375, Total Negative: 539813376.0, Triplet Accuracy: 1.0


Epoch 34: 100%|██████████| 242/242 [00:19<00:00, 12.22it/s]


Epoch 34, Loss: 0.6316942676039766
Validation Triplet Loss: 0.0, Total Positive: 11134501.84375, Total Negative: 545601266.0, Triplet Accuracy: 1.0


Epoch 35: 100%|██████████| 242/242 [00:19<00:00, 12.25it/s]


Epoch 35, Loss: 98.1524522375962
Validation Triplet Loss: 0.0, Total Positive: 10834984.125, Total Negative: 532480482.0, Triplet Accuracy: 1.0


Epoch 36: 100%|██████████| 242/242 [00:19<00:00, 12.35it/s]


Epoch 36, Loss: 0.05073478694789666
Validation Triplet Loss: 0.0, Total Positive: 10903902.0, Total Negative: 533587226.0, Triplet Accuracy: 1.0


Epoch 37: 100%|██████████| 242/242 [00:19<00:00, 12.39it/s]


Epoch 37, Loss: 0.031490868773342166
Validation Triplet Loss: 0.0, Total Positive: 10920936.0, Total Negative: 533448994.0, Triplet Accuracy: 1.0


Epoch 38: 100%|██████████| 242/242 [00:19<00:00, 12.36it/s]


Epoch 38, Loss: 0.09310223445419438
Validation Triplet Loss: 0.0, Total Positive: 10815254.78125, Total Negative: 530088638.0, Triplet Accuracy: 1.0


Epoch 39: 100%|██████████| 242/242 [00:19<00:00, 12.46it/s]


Epoch 39, Loss: 0.14928761395541104
Validation Triplet Loss: 0.0, Total Positive: 10762144.609375, Total Negative: 528344314.0, Triplet Accuracy: 1.0


Epoch 40: 100%|██████████| 242/242 [00:19<00:00, 12.44it/s]


Epoch 40, Loss: 0.017483013720551797
Validation Triplet Loss: 0.0, Total Positive: 10718001.609375, Total Negative: 525901648.0, Triplet Accuracy: 1.0


Epoch 41: 100%|██████████| 242/242 [00:19<00:00, 12.32it/s]


Epoch 41, Loss: 0.16914404423768856
Validation Triplet Loss: 0.0, Total Positive: 10562555.3125, Total Negative: 519711588.0, Triplet Accuracy: 1.0


Epoch 42: 100%|██████████| 242/242 [00:19<00:00, 12.25it/s]


Epoch 42, Loss: 188.15617702026998
Validation Triplet Loss: 0.0, Total Positive: 10298843.125, Total Negative: 476628623.0, Triplet Accuracy: 1.0


Epoch 43: 100%|██████████| 242/242 [00:19<00:00, 12.12it/s]


Epoch 43, Loss: 0.1415820496141418
Validation Triplet Loss: 0.0, Total Positive: 10383622.546875, Total Negative: 493107421.0, Triplet Accuracy: 1.0


Epoch 44: 100%|██████████| 242/242 [00:19<00:00, 12.28it/s]


Epoch 44, Loss: 0.05883367396583242
Validation Triplet Loss: 0.0, Total Positive: 10553664.109375, Total Negative: 487029086.0, Triplet Accuracy: 1.0


Epoch 45: 100%|██████████| 242/242 [00:19<00:00, 12.28it/s]


Epoch 45, Loss: 0.10099674970650475
Validation Triplet Loss: 0.0, Total Positive: 10401361.765625, Total Negative: 515963352.0, Triplet Accuracy: 1.0


Epoch 46: 100%|██████████| 242/242 [00:19<00:00, 12.42it/s]


Epoch 46, Loss: 0.02045152601131723
Validation Triplet Loss: 0.0, Total Positive: 10588242.375, Total Negative: 521460280.0, Triplet Accuracy: 1.0


Epoch 47: 100%|██████████| 242/242 [00:19<00:00, 12.24it/s]


Epoch 47, Loss: 187.00983363069778
Validation Triplet Loss: 0.0, Total Positive: 10074548.484375, Total Negative: 478343705.0, Triplet Accuracy: 1.0


Epoch 48: 100%|██████████| 242/242 [00:19<00:00, 12.14it/s]


Epoch 48, Loss: 0.17456277539907408
Validation Triplet Loss: 0.0, Total Positive: 9761718.75, Total Negative: 478482020.0, Triplet Accuracy: 1.0


Epoch 49: 100%|██████████| 242/242 [00:19<00:00, 12.31it/s]


Epoch 49, Loss: 0.3299542360069338
Validation Triplet Loss: 0.0, Total Positive: 9937380.1875, Total Negative: 489701593.0, Triplet Accuracy: 1.0


Epoch 50: 100%|██████████| 242/242 [00:19<00:00, 12.42it/s]


Epoch 50, Loss: 0.10998728649675353
Validation Triplet Loss: 0.0, Total Positive: 9772401.296875, Total Negative: 477137252.0, Triplet Accuracy: 1.0


Epoch 51: 100%|██████████| 242/242 [00:19<00:00, 12.49it/s]


Epoch 51, Loss: 0.216432752688069
Validation Triplet Loss: 0.0, Total Positive: 9656247.328125, Total Negative: 474152201.0, Triplet Accuracy: 1.0


Epoch 52: 100%|██████████| 242/242 [00:19<00:00, 12.29it/s]


Epoch 52, Loss: 68.07302901198057
Validation Triplet Loss: 0.0, Total Positive: 9624980.28125, Total Negative: 475716550.0, Triplet Accuracy: 1.0


Epoch 53: 100%|██████████| 242/242 [00:19<00:00, 12.38it/s]


Epoch 53, Loss: 0.25906530022621155
Validation Triplet Loss: 0.0, Total Positive: 9563944.53125, Total Negative: 468252311.0, Triplet Accuracy: 1.0


Epoch 54: 100%|██████████| 242/242 [00:19<00:00, 12.20it/s]


Epoch 54, Loss: 156.95252089362498
Validation Triplet Loss: 0.0, Total Positive: 9200224.359375, Total Negative: 433865363.0, Triplet Accuracy: 1.0


Epoch 55: 100%|██████████| 242/242 [00:19<00:00, 12.10it/s]


Epoch 55, Loss: 312.1778996030161
Validation Triplet Loss: 0.0, Total Positive: 8679486.515625, Total Negative: 393340036.0, Triplet Accuracy: 1.0


Epoch 56: 100%|██████████| 242/242 [00:19<00:00, 12.25it/s]


Epoch 56, Loss: 0.4492012991392908
Validation Triplet Loss: 0.0, Total Positive: 9134153.171875, Total Negative: 428739728.0, Triplet Accuracy: 1.0


Epoch 57: 100%|██████████| 242/242 [00:19<00:00, 12.32it/s]


Epoch 57, Loss: 0.40769687762930373
Validation Triplet Loss: 0.0, Total Positive: 9149413.65625, Total Negative: 442348357.0, Triplet Accuracy: 1.0


Epoch 58: 100%|██████████| 242/242 [00:19<00:00, 12.19it/s]


Epoch 58, Loss: 142.7561185576699
Validation Triplet Loss: 0.0, Total Positive: 8663516.234375, Total Negative: 396174833.0, Triplet Accuracy: 1.0


Epoch 59: 100%|██████████| 242/242 [00:19<00:00, 12.34it/s]


Epoch 59, Loss: 0.344208105782832
Validation Triplet Loss: 0.0, Total Positive: 8778934.15625, Total Negative: 404579783.0, Triplet Accuracy: 1.0


Epoch 60: 100%|██████████| 242/242 [00:19<00:00, 12.23it/s]


Epoch 60, Loss: 3.353623401034962
Validation Triplet Loss: 0.0, Total Positive: 8920316.359375, Total Negative: 421101874.0, Triplet Accuracy: 1.0


Epoch 61: 100%|██████████| 242/242 [00:19<00:00, 12.15it/s]


Epoch 61, Loss: 0.43152371024297287
Validation Triplet Loss: 0.0, Total Positive: 8838177.640625, Total Negative: 434052991.0, Triplet Accuracy: 1.0


Epoch 62: 100%|██████████| 242/242 [00:20<00:00, 12.00it/s]


Epoch 62, Loss: 0.2218354359145992
Validation Triplet Loss: 0.0, Total Positive: 8589375.109375, Total Negative: 412919661.0, Triplet Accuracy: 1.0


Epoch 63: 100%|██████████| 242/242 [00:19<00:00, 12.23it/s]


Epoch 63, Loss: 0.08682393165659313
Validation Triplet Loss: 0.0, Total Positive: 8642870.34375, Total Negative: 408148964.0, Triplet Accuracy: 1.0


Epoch 64: 100%|██████████| 242/242 [00:19<00:00, 12.33it/s]


Epoch 64, Loss: 0.1237217611517788
Validation Triplet Loss: 0.0, Total Positive: 8716623.859375, Total Negative: 428853808.0, Triplet Accuracy: 1.0


Epoch 65: 100%|██████████| 242/242 [00:19<00:00, 12.27it/s]


Epoch 65, Loss: 0.10503851185160235
Validation Triplet Loss: 0.0, Total Positive: 8740213.828125, Total Negative: 424552369.0, Triplet Accuracy: 1.0


Epoch 66: 100%|██████████| 242/242 [00:19<00:00, 12.14it/s]


Epoch 66, Loss: 0.4574916737138732
Validation Triplet Loss: 0.0, Total Positive: 8490254.734375, Total Negative: 393150539.0, Triplet Accuracy: 1.0


Epoch 67: 100%|██████████| 242/242 [00:19<00:00, 12.21it/s]


Epoch 67, Loss: 66.69498234346878
Validation Triplet Loss: 0.0, Total Positive: 8351273.0, Total Negative: 374761691.0, Triplet Accuracy: 1.0


Epoch 68: 100%|██████████| 242/242 [00:20<00:00, 12.02it/s]


Epoch 68, Loss: 49.00338014319909
Validation Triplet Loss: 0.0, Total Positive: 8398996.65625, Total Negative: 373999564.0, Triplet Accuracy: 1.0


Epoch 69: 100%|██████████| 242/242 [00:19<00:00, 12.13it/s]


Epoch 69, Loss: 132.1957866960813
Validation Triplet Loss: 0.0, Total Positive: 8413490.40625, Total Negative: 381585419.0, Triplet Accuracy: 1.0


Epoch 70: 100%|██████████| 242/242 [00:19<00:00, 12.25it/s]


Epoch 70, Loss: 0.13913355977082056
Validation Triplet Loss: 0.0, Total Positive: 8627046.65625, Total Negative: 421870422.0, Triplet Accuracy: 1.0


Epoch 71: 100%|██████████| 242/242 [00:19<00:00, 12.13it/s]


Epoch 71, Loss: 0.09367724415684535
Validation Triplet Loss: 0.0, Total Positive: 8465894.578125, Total Negative: 407259129.0, Triplet Accuracy: 1.0


Epoch 72: 100%|██████████| 242/242 [00:19<00:00, 12.12it/s]


Epoch 72, Loss: 138.0785608439406
Validation Triplet Loss: 0.0, Total Positive: 8174941.953125, Total Negative: 380511242.0, Triplet Accuracy: 1.0


Epoch 73: 100%|██████████| 242/242 [00:20<00:00, 12.06it/s]


Epoch 73, Loss: 0.07059924292170312
Validation Triplet Loss: 0.0, Total Positive: 8381549.921875, Total Negative: 409905260.0, Triplet Accuracy: 1.0


Epoch 74: 100%|██████████| 242/242 [00:19<00:00, 12.12it/s]


Epoch 74, Loss: 0.23156096915568203
Validation Triplet Loss: 0.0, Total Positive: 8824801.671875, Total Negative: 443636502.0, Triplet Accuracy: 1.0


Epoch 75: 100%|██████████| 242/242 [00:19<00:00, 12.19it/s]


Epoch 75, Loss: 0.03624686327847568
Validation Triplet Loss: 0.0, Total Positive: 8741415.96875, Total Negative: 437128636.0, Triplet Accuracy: 1.0


Epoch 76: 100%|██████████| 242/242 [00:19<00:00, 12.15it/s]


Epoch 76, Loss: 139.5673304894739
Validation Triplet Loss: 0.0, Total Positive: 8429485.984375, Total Negative: 415422393.0, Triplet Accuracy: 1.0


Epoch 77: 100%|██████████| 242/242 [00:19<00:00, 12.11it/s]


Epoch 77, Loss: 2.009577277524412
Validation Triplet Loss: 0.0, Total Positive: 8488319.21875, Total Negative: 430906506.0, Triplet Accuracy: 1.0


Epoch 78: 100%|██████████| 242/242 [00:19<00:00, 12.24it/s]


Epoch 78, Loss: 0.06778237051215054
Validation Triplet Loss: 0.0, Total Positive: 8451754.09375, Total Negative: 428132044.0, Triplet Accuracy: 1.0


Epoch 79: 100%|██████████| 242/242 [00:19<00:00, 12.37it/s]


Epoch 79, Loss: 49.94875620514893
Validation Triplet Loss: 0.0, Total Positive: 8468994.09375, Total Negative: 428748382.0, Triplet Accuracy: 1.0


Epoch 80: 100%|██████████| 242/242 [00:19<00:00, 12.25it/s]


Epoch 80, Loss: 150.65626709741994
Validation Triplet Loss: 0.0, Total Positive: 8357694.765625, Total Negative: 409525848.0, Triplet Accuracy: 1.0


Epoch 81: 100%|██████████| 242/242 [00:19<00:00, 12.24it/s]


Epoch 81, Loss: 0.19784007092152744
Validation Triplet Loss: 0.0, Total Positive: 8492979.3125, Total Negative: 407923597.0, Triplet Accuracy: 1.0


Epoch 82: 100%|██████████| 242/242 [00:19<00:00, 12.35it/s]


Epoch 82, Loss: 0.0
Validation Triplet Loss: 0.0, Total Positive: 8492981.96875, Total Negative: 407923704.0, Triplet Accuracy: 1.0


Epoch 83: 100%|██████████| 242/242 [00:19<00:00, 12.21it/s]


Epoch 83, Loss: 0.07972901714734795
Validation Triplet Loss: 0.0, Total Positive: 8271916.578125, Total Negative: 389935737.0, Triplet Accuracy: 1.0


Epoch 84: 100%|██████████| 242/242 [00:19<00:00, 12.42it/s]


Epoch 84, Loss: 0.0
Validation Triplet Loss: 0.0, Total Positive: 8271882.171875, Total Negative: 389931166.0, Triplet Accuracy: 1.0


Epoch 85: 100%|██████████| 242/242 [00:20<00:00, 11.67it/s]


Epoch 85, Loss: 0.022066949812833928
Validation Triplet Loss: 0.0, Total Positive: 8105545.3125, Total Negative: 376466510.0, Triplet Accuracy: 1.0


Epoch 86: 100%|██████████| 242/242 [00:21<00:00, 11.13it/s]


Epoch 86, Loss: 0.19091516880949666
Validation Triplet Loss: 0.0, Total Positive: 7867671.953125, Total Negative: 356695558.0, Triplet Accuracy: 1.0


Epoch 87: 100%|██████████| 242/242 [00:19<00:00, 12.32it/s]


Epoch 87, Loss: 0.12207499811471986
Validation Triplet Loss: 0.0, Total Positive: 8186802.1875, Total Negative: 390463867.0, Triplet Accuracy: 1.0


Epoch 88: 100%|██████████| 242/242 [00:19<00:00, 12.33it/s]


Epoch 88, Loss: 97.81510685692149
Validation Triplet Loss: 0.0, Total Positive: 7773039.5625, Total Negative: 355209714.0, Triplet Accuracy: 1.0


Epoch 89: 100%|██████████| 242/242 [00:19<00:00, 12.11it/s]


Epoch 89, Loss: 85.94098244028643
Validation Triplet Loss: 0.0, Total Positive: 7959421.671875, Total Negative: 390880818.0, Triplet Accuracy: 1.0


Epoch 90: 100%|██████████| 242/242 [00:19<00:00, 12.15it/s]


Epoch 90, Loss: 0.10013107819990678
Validation Triplet Loss: 0.0, Total Positive: 7726174.0, Total Negative: 370251448.0, Triplet Accuracy: 1.0


Epoch 91: 100%|██████████| 242/242 [00:20<00:00, 12.01it/s]


Epoch 91, Loss: 102.81228122139765
Validation Triplet Loss: 0.0, Total Positive: 7605325.09375, Total Negative: 363379844.0, Triplet Accuracy: 1.0


Epoch 92: 100%|██████████| 242/242 [00:20<00:00, 12.05it/s]


Epoch 92, Loss: 95.45267589624262
Validation Triplet Loss: 0.0, Total Positive: 7328344.65625, Total Negative: 329236458.0, Triplet Accuracy: 1.0


Epoch 93: 100%|██████████| 242/242 [00:19<00:00, 12.10it/s]


Epoch 93, Loss: 0.4402388592889486
Validation Triplet Loss: 0.0, Total Positive: 7769259.140625, Total Negative: 362695988.0, Triplet Accuracy: 1.0


Epoch 94: 100%|██████████| 242/242 [00:19<00:00, 12.21it/s]


Epoch 94, Loss: 0.06873913343287696
Validation Triplet Loss: 0.0, Total Positive: 7904603.75, Total Negative: 380053861.0, Triplet Accuracy: 1.0


Epoch 95: 100%|██████████| 242/242 [00:19<00:00, 12.17it/s]


Epoch 95, Loss: 41.577652278271586
Validation Triplet Loss: 0.0, Total Positive: 7729533.578125, Total Negative: 356714454.0, Triplet Accuracy: 1.0


Epoch 96: 100%|██████████| 242/242 [00:19<00:00, 12.22it/s]


Epoch 96, Loss: 0.20242705448599888
Validation Triplet Loss: 0.0, Total Positive: 7851305.203125, Total Negative: 363697983.0, Triplet Accuracy: 1.0


Epoch 97: 100%|██████████| 242/242 [00:20<00:00, 12.06it/s]


Epoch 97, Loss: 96.83655119558011
Validation Triplet Loss: 0.0, Total Positive: 7916944.203125, Total Negative: 376870741.0, Triplet Accuracy: 1.0


Epoch 98: 100%|██████████| 242/242 [00:19<00:00, 12.28it/s]


Epoch 98, Loss: 0.07282782290592667
Validation Triplet Loss: 0.0, Total Positive: 7949535.984375, Total Negative: 373671093.0, Triplet Accuracy: 1.0


Epoch 99: 100%|██████████| 242/242 [00:19<00:00, 12.17it/s]


Epoch 99, Loss: 0.36831243373145744
Validation Triplet Loss: 0.0, Total Positive: 8182136.5625, Total Negative: 392595738.0, Triplet Accuracy: 1.0


Epoch 100: 100%|██████████| 242/242 [00:20<00:00, 12.03it/s]


Epoch 100, Loss: 0.04842498775356072
Validation Triplet Loss: 0.0, Total Positive: 8269331.515625, Total Negative: 392311106.0, Triplet Accuracy: 1.0


Epoch 101: 100%|██████████| 242/242 [00:19<00:00, 12.24it/s]


Epoch 101, Loss: 0.0833943257154512
Validation Triplet Loss: 0.0, Total Positive: 8177993.40625, Total Negative: 359335258.0, Triplet Accuracy: 1.0


Epoch 102: 100%|██████████| 242/242 [00:19<00:00, 12.27it/s]


Epoch 102, Loss: 0.07339866890394983
Validation Triplet Loss: 0.0, Total Positive: 8191884.703125, Total Negative: 367224848.0, Triplet Accuracy: 1.0


Epoch 103: 100%|██████████| 242/242 [00:20<00:00, 12.10it/s]


Epoch 103, Loss: 0.0026998945996781027
Validation Triplet Loss: 0.0, Total Positive: 8177294.21875, Total Negative: 364286348.0, Triplet Accuracy: 1.0


Epoch 104: 100%|██████████| 242/242 [00:19<00:00, 12.19it/s]


Epoch 104, Loss: 86.32136634404748
Validation Triplet Loss: 0.0, Total Positive: 7869520.59375, Total Negative: 368078166.0, Triplet Accuracy: 1.0


Epoch 105: 100%|██████████| 242/242 [00:20<00:00, 12.10it/s]


Epoch 105, Loss: 10.55268792795741
Validation Triplet Loss: 0.0, Total Positive: 7818492.75, Total Negative: 363143839.0, Triplet Accuracy: 1.0


Epoch 106: 100%|██████████| 242/242 [00:19<00:00, 12.12it/s]


Epoch 106, Loss: 12.464728800718449
Validation Triplet Loss: 0.0, Total Positive: 7833424.46875, Total Negative: 374235213.0, Triplet Accuracy: 1.0


Epoch 107: 100%|██████████| 242/242 [00:20<00:00, 12.06it/s]


Epoch 107, Loss: 0.1732208060854969
Validation Triplet Loss: 0.0, Total Positive: 7986013.75, Total Negative: 390774801.0, Triplet Accuracy: 1.0


Epoch 108: 100%|██████████| 242/242 [00:20<00:00, 11.95it/s]


Epoch 108, Loss: 81.09481133606808
Validation Triplet Loss: 0.0, Total Positive: 7699870.4375, Total Negative: 366361806.0, Triplet Accuracy: 1.0


Epoch 109: 100%|██████████| 242/242 [00:22<00:00, 10.99it/s]


Epoch 109, Loss: 0.06863889024277364
Validation Triplet Loss: 0.0, Total Positive: 7765618.546875, Total Negative: 383455684.0, Triplet Accuracy: 1.0


Epoch 110: 100%|██████████| 242/242 [00:23<00:00, 10.42it/s]


Epoch 110, Loss: 0.002158798708403406
Validation Triplet Loss: 0.0, Total Positive: 7787292.21875, Total Negative: 382881189.0, Triplet Accuracy: 1.0


Epoch 111:  59%|█████▊    | 142/242 [00:16<00:11,  8.58it/s]


KeyboardInterrupt: 

In [668]:
torch.save(model.state_dict(), "./full_model_state_v6")

In [669]:
torch.save(model, "./full_model_v6")

## Testing

In [775]:
from generate_dataset import node_to_dict, sympy_to_node
import sympy as sp
from latex2sympy2 import latex2sympy

In [776]:
def latex_to_graph(tex):
    return node_to_dict(sympy_to_node(latex2sympy(tex)))

In [851]:
expr_l = latex_to_graph(r"x + y")
expr_r = latex_to_graph(r"x + y ")

In [852]:
expr_l

{'type': 'OPERATION',
 'value': None,
 'subtype': 'ADD',
 'children': [{'type': 'VARIABLE',
   'value': 'x',
   'subtype': None,
   'children': []},
  {'type': 'VARIABLE', 'value': 'y', 'subtype': None, 'children': []}]}

In [853]:
expr_l

{'type': 'OPERATION',
 'value': None,
 'subtype': 'ADD',
 'children': [{'type': 'VARIABLE',
   'value': 'x',
   'subtype': None,
   'children': []},
  {'type': 'VARIABLE', 'value': 'y', 'subtype': None, 'children': []}]}

In [854]:
expr_l_obj = dict_to_geometric_representation(expr_l, node_attr_encoder)
expr_r_obj = dict_to_geometric_representation(expr_r, node_attr_encoder)

In [858]:
expr_l_obj.x

tensor([[11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,
         11., 11.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1., -1., -1.],
        [18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18.,
         18., 18., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1.,  0., 41.,  0., 41.,  0., 41.,  0., 41.,  0., 41.,
          0., 41.,  0., 41.,  0., 41.],
        [18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18.,
         18., 18., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1.,  0., 42.,  0., 42.,  0., 42.,  0., 42.,  0., 42.,
          0., 42.,  0., 42.,  0., 42.]])

In [856]:
emb_l, emb_r, emb_r_c = model(expr_l_obj, expr_r_obj, expr_r_obj)

In [850]:
print(F.pairwise_distance(emb_l, emb_r))

tensor([96952.6172], grad_fn=<NormBackward1>)


In [767]:
expr_left = {
      "type": "POW",
      "value": None,
      "subtype": None,
      "children": [
        {
          "type": "POW",
          "value": None,
          "subtype": None,
          "children": [
            {
              "type": "LITERAL",
              "value": -3.5,
              "subtype": None,
              "children": []
            },
            {
              "type": "LITERAL",
              "value": 3.0,
              "subtype": None,
              "children": []
            }
          ]
        },
        {
          "type": "LITERAL",
          "value": 4.5,
          "subtype": None,
          "children": []
        }
      ]
    }
expr_right = {
      "type": "POW",
      "value": None,
      "subtype": None,
      "children": [
        {
          "type": "OPERATION",
          "value": None,
          "subtype": "ADD",
          "children": [
            {
              "type": "LITERAL",
              "value": -5.5,
              "subtype": None,
              "children": []
            },
            {
              "type": "LITERAL",
              "value": -10.0,
              "subtype": None,
              "children": []
            }
          ]
        },
        {
          "type": "LITERAL",
          "value": 4.0,
          "subtype": None,
          "children": []
        }
      ]
    }

expr_right_constant_error ={
      "type": "POW",
      "value": None,
      "subtype": None,
      "children": [
        {
          "type": "OPERATION",
          "value": None,
          "subtype": "ADD",
          "children": [
            {
              "type": "LITERAL",
              "value": -5.5,
              "subtype": None,
              "children": []
            },
            {
              "type": "LITERAL",
              "value": -10.0,
              "subtype": None,
              "children": []
            }
          ]
        },
        {
          "type": "LITERAL",
          "value": 4.0,
          "subtype": None,
          "children": []
        }
      ]
    }

In [302]:
expr_l_obj = dict_to_geometric_representation(expr_left, node_attr_encoder)
expr_r_obj = dict_to_geometric_representation(expr_right, node_attr_encoder)
expr_r_c_obj = dict_to_geometric_representation(expr_right_constant_error, node_attr_encoder)

In [303]:
emb_l, emb_r, emb_r_c = model(expr_l_obj, expr_r_obj, expr_r_c_obj)

In [304]:
print(DISTANCE_FUNC(emb_l, emb_r))
print(DISTANCE_FUNC(emb_l, emb_r_c))
print(DISTANCE_FUNC(emb_r, emb_r_c))
print(DISTANCE_FUNC(emb_r, emb_r))



tensor([1.0000], grad_fn=<SumBackward1>)
tensor([1.0000], grad_fn=<SumBackward1>)
tensor([1.0000], grad_fn=<SumBackward1>)
tensor([1.0000], grad_fn=<SumBackward1>)


In [774]:
model.info()

AttributeError: 'TripletFormulaNet' object has no attribute 'info'

In [None]:
    {
        "expr_anchor": {"type": "OPERATION", "value": null, "subtype": "ADD", "children": [{"type": "FUNC", "value": null, "subtype": "COS", "children": [{"type": "POW", "value": null, "subtype": null, "children": [{"type": "VARIABLE", "value": "x", "subtype": null, "children": []}, {"type": "LITERAL", "value": 2.0, "subtype": null, "children": []}]}]}, {"type": "FUNC", "value": null, "subtype": "SIN", "children": [{"type": "POW", "value": null, "subtype": null, "children": [{"type": "VARIABLE", "value": "x", "subtype": null, "children": []}, {"type": "LITERAL", "value": 2.0, "subtype": null, "children": []}]}]}]},

        "expr_true": {"type": "LITERAL", "value": 1.0, "subtype": null, "children": []},
        "expr_false": {"type": "FUNC", "value": null, "subtype": "COS", "children": [{"type": "VARIABLE", "value": "x", "subtype": null, "children": []}]}

    }