In [55]:
import math
import torch
from torch_geometric.data import InMemoryDataset, Data, Batch
from torch.utils.data import DataLoader, Subset, random_split
from sklearn.calibration import LabelEncoder
import json 

In [56]:
from torch import nn
from torch.nn import Linear, ReLU, LeakyReLU, init
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
import numpy as np
import torch.optim.lr_scheduler as lr_scheduler


In [57]:
import torch
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [58]:
DATASET_FILEPATH = "math_datagen_triplet_8k.json"
BATCH_SIZE = 256
TRAIN_RATIO = 0.6
VAL_RATIO = 0.2

In [59]:
TYPES = ["FUNC", "OPERATION", "POW", "LITERAL", "VARIABLE", "CONSTANT_LITERAL"]
OPERATIONS = ["ADD", "MUL", "FUNC", "POW"]
FUNCTIONS = ["SIN", "COS", "TAN", "EXP", "LOG", "SINH", "COSH", "TANH"]
VARIABLE_ALPHABET = [chr(x) for x in range(ord('a'), ord('z')+1) if chr(x) not in ["e", "i"]]
CLASSIC_CONSTANTS = ["PI", "I", "g","e", "zoo", "E"]

In [60]:

def dict_to_geometric_representation(in_graph_dict: dict, encoder) -> Data:
    node_list = []
    edge_mappings = []
    def traverse_graph(graph = in_graph_dict):
        nonlocal node_list
        nonlocal edge_mappings
        curr_node_index = len(node_list)
        encoded_data = encoder({"type": graph["type"], "subtype": graph["subtype"], "value": graph["value"]})
        # print("Encoded data: ", encoded_data)
        node_list.append(encoded_data)
        if "children" in graph.keys():
            for child in graph["children"]:
                edge_mappings.append((curr_node_index, traverse_graph(child)) ) #I'm retarded.
        return curr_node_index
    traverse_graph()
    nodes = torch.tensor(node_list,dtype=torch.float32)
    if not torch.all(torch.isfinite(nodes)):
        print("NODES CONTAIN INFINITE VALUE ---> \n")
        print("X = ", nodes)
    nodes[nodes == float("inf")] = 0
    nodes[nodes == -float("inf")] = 0
    if not torch.all(torch.isfinite(nodes)):
        print("MODIFIED NODES CONTAIN INFINITE VALUE ---> \n")
        print("X = ", nodes)
    edges = torch.tensor([[x[0] for x in edge_mappings], [x[1] for x in edge_mappings]], dtype=torch.long) # Probably slow and mentally degenerated
    geom_data = Data(x=nodes, edge_index=edges)
    return geom_data

#### Encoder

In [61]:
rep = 4

In [62]:
label_encoder = LabelEncoder()
label_encoder.fit(OPERATIONS+FUNCTIONS+CLASSIC_CONSTANTS+VARIABLE_ALPHABET+TYPES)
# --- misc
nan_count = 0 
sample_count = 0

In [63]:
def sanitize_value(value):
    if value in ["-Infinity","Infinity", float("inf"), float("-inf"), "inf", "-inf"]:
        return None
    return value

In [65]:
def node_attr_encoder(attr):
    global nan_count
    global sample_count
    
    def safe_label_encode(value):
        return label_encoder.transform([value])[0] if value else -1

    type_encoding = [safe_label_encode(attr.get("type"))] * rep
    subtype_encoding = [safe_label_encode(attr.get("subtype"))] * rep
    
    value_encoding_vec = [-1] * rep  # Default to all -1 if no valid value
    
    if sanitize_value(attr.get("value")) is not None:
        if isinstance(attr["value"], str):
            value_encoding = safe_label_encode(attr["value"])
            value_encoding_vec = [(0 if i % 2 == 0 else value_encoding) for i in range(rep)]
        else:
            try:
                val = float(attr["value"])
                if math.isfinite(val):
                    expon = math.log10(abs(val))
                    val = val // 10**int(expon)
                    value_encoding_vec = [(expon if i % 2 == 1 else val) for i in range(rep)]
                else:
                    raise ValueError("Non-finite value")
            except (ValueError, TypeError):
                nan_count += 1
                value_encoding_vec = [(-1 if i % 2 == 1 else 1) for i in range(rep)]
            finally:
                sample_count += 1

    result = type_encoding + subtype_encoding + value_encoding_vec
    return result

#### Loading in the dataset

In [66]:
triplet_dataset = []

In [67]:
# print('"value": Infinity,'.find("Infinity"))

In [68]:
with open(DATASET_FILEPATH) as file_handle:
    object_data = json.load(file_handle)
    for triplet in object_data:
        ## TODO : check if contains infinity (triplet to str and then search for Infinity)
        # print(triplet)
        # print(json.dumps(triplet).find("inf"))
        
        if json.dumps(triplet).find("inf") == -1:
            expr_true = triplet["expr_true"]
            expr_false = triplet["expr_false"]
            expr_anchor = triplet["expr_anchor"]
            geom_anchor = dict_to_geometric_representation(expr_anchor,node_attr_encoder )
            geom_true = dict_to_geometric_representation(expr_true,node_attr_encoder )
            geom_false = dict_to_geometric_representation(expr_false,node_attr_encoder )
            triplet_dataset.append((geom_anchor, geom_true, geom_false))
        
        

In [69]:
# with open("test","w") as f:
#     f.write(str(triplet_dataset))

#### Neural network training

In [70]:
num_features = triplet_dataset[0][0].num_features

In [71]:
class FormulaNet(nn.Module):
    def __init__(self, hidden_channels: int, embedding_space: int):
        super(FormulaNet, self).__init__()
        self.dense_1 = Linear(num_features, num_features*2) 
        self.dense_2 = Linear(num_features*2, num_features*2) 
        
        self.relu_1 = LeakyReLU(0.2)
        self.gconv_1 = GCNConv(num_features*2, hidden_channels)
        self.gconv_2 = GCNConv(hidden_channels, hidden_channels)
        self.dense_3 = Linear(hidden_channels, embedding_space)
        gain = torch.nn.init.calculate_gain("leaky_relu", 0.2)
        # Initialize dense_1 weights
        init.xavier_uniform_(self.dense_1.weight, gain=gain)
        init.xavier_uniform_(self.dense_2.weight, gain=gain)
        # # Initialize gconv_1 weights
        init.xavier_uniform_(self.gconv_1.lin.weight.data.T, gain=gain)
        # # Initialize gconv_2 weights
        init.xavier_uniform_(self.gconv_2.lin.weight.data.T, gain=gain)
        #Initialize dense_3 weights
        init.xavier_uniform_(self.dense_3.weight, gain=gain)
    
    def forward(self, x, edge_index, batch):
        if not torch.all(torch.isfinite(x)):
            print("INPUT CONTAINS INFINITE VALUE")
            #print("x = ", x)
            


        nan_mask = torch.isnan(x)
        if nan_mask.any():
            print("Some values of the input are nan: ", x[nan_mask])
            # Replace NaNs with 0s
            x[nan_mask] = -11
            
        
        # print("X = ", str(x))
        x = self.dense_1(x)
        # print("d1 X = ", str(x))
        x = self.relu_1(x)
        
        x = self.dense_2(x)
        x = self.relu_1(x)

        # #print("rd1 X = ", str(x))
        x = self.gconv_1(x, edge_index)
        # print("gcn1 X = ", str(x))
        x = self.relu_1(x)
        # print("rgcn1 X = ", str(x))
        # x = self.gconv_2(x, edge_index)
        # # print("gcn2 X = ", str(x))
        # x = self.relu_1(x)
        # # print("rgcn2 X = ", str(x))
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.4,training=self.training)
        x = self.dense_3(x)
        x = self.relu_1(x)
        # nan_mask = torch.isnan(x)
        # if nan_mask.any():
        #     print("Some values of the input are nan: ", x[nan_mask])
            
        return x
    

In [72]:
class TripletFormulaNet(nn.Module):
    def __init__(self, hidden_channels, embedding_space):
        super(TripletFormulaNet, self).__init__()
        self.formulanet = FormulaNet(hidden_channels, embedding_space)

    def forward(self, anchor, positive, negative):
        # Compute the embeddings for the anchor, positive, and negative inputs
        embed_anchor = self.formulanet(anchor.x, anchor.edge_index, anchor.batch)
        embed_positive = self.formulanet(positive.x, positive.edge_index, positive.batch)
        embed_negative = self.formulanet(negative.x, negative.edge_index, negative.batch)
        
        return embed_anchor, embed_positive, embed_negative

#### Training

In [73]:
## CHECK DATASET:
with open("test", "w") as f:
    for a, b,c in triplet_dataset:
        f.write(str(a.x) + "\n" + str(b.x) + "\n" + str(c.x) + "\n")

In [74]:
def triplet_loss(anchor, positive, negative, margin=1.0):
    """
    Triplet loss function.
    
    Args:
        anchor (torch.Tensor): Embedding of the anchor data point.
        positive (torch.Tensor): Embedding of the positive data point (same class as anchor).
        negative (torch.Tensor): Embedding of the negative data point (different class from anchor).
        margin (float, optional): Margin value for triplet loss. Defaults to 1.0.
        
    Returns:
        torch.Tensor: Triplet loss value.
    """
    
    # Calculate the Euclidean distances between the embeddings
    distance_positive = F.pairwise_distance(anchor, positive)
    distance_negative = F.pairwise_distance(anchor, negative)
    
    # Calculate the triplet loss
    losses = F.relu(distance_positive - distance_negative + margin)
    
    return losses.mean()

In [75]:
total_len = len(triplet_dataset)
train_len = int(TRAIN_RATIO * total_len)
val_len = int(VAL_RATIO * total_len)
test_len = total_len - train_len - val_len

train_dataset, val_dataset, test_dataset = random_split(triplet_dataset, [train_len, val_len, test_len])

In [76]:
def collate(data_list):
    batchA = Batch.from_data_list([data[0] for data in data_list])
    batchB = Batch.from_data_list([data[1] for data in data_list])
    batchC = Batch.from_data_list([data[2] for data in data_list])
    return batchA, batchB, batchC

In [97]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)
validation_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle= False, collate_fn=collate)

In [98]:
device = torch.device("cpu") #torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [99]:
model = TripletFormulaNet(32,32).to(device)

In [100]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
scheduler = lr_scheduler.LinearLR(optimizer, start_factor=0.001, end_factor=0.0005, total_iters=10)

In [101]:
from tqdm import tqdm

In [102]:
training_loss = []  # Store training loss per epoch
validation_loss = []  # Store validation loss per epoch

In [103]:
def train(epoch):
    model.train()
    epoch_loss = 0
    for batch_anchor, batch_positive, batch_negative in tqdm(train_loader, desc=f'Epoch {epoch}'):

        batch_anchor, batch_positive, batch_negative = batch_anchor.to(device), batch_positive.to(device), batch_negative.to(device)
        optimizer.zero_grad()
        embed_anchor, embed_positive, embed_negative = model(batch_anchor, batch_positive, batch_negative)
        loss = triplet_loss(embed_anchor, embed_positive, embed_negative, margin=1.0)
        loss.backward()
        writer.add_scalar("Loss/train", loss, epoch)
        
        # Log gradients
        # for name, param in model.formulanet.named_parameters():
        #     if param.grad is not None and param.grad.numel() > 0:
        #         writer.add_histogram(f'{name}.grad', param.grad, epoch)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        epoch_loss += loss.item()
    scheduler.step(epoch)
    training_loss.append(epoch_loss / len(train_loader))  # Store epoch average loss
    print(f'Epoch {epoch}, Loss: {epoch_loss / len(train_loader)}')

In [104]:
def validate(model, validation_loader, device):
    model.eval()
    triplet_loss_tot = 0
    correct_triplets = 0
    total_triplets = 0

    with torch.no_grad():
        for batch_anchor, batch_positive, batch_negative in validation_loader:
            batch_anchor, batch_positive, batch_negative = batch_anchor.to(device), batch_positive.to(device), batch_negative.to(device)
            embed_anchor, embed_positive, embed_negative = model(batch_anchor, batch_positive, batch_negative)

            # Compute triplet loss
            loss = triplet_loss(embed_anchor, embed_positive, embed_negative, margin=1.0)
            triplet_loss_tot += loss.item() * batch_anchor.size(0)

            # Compute triplet accuracy
            distance_positive = F.pairwise_distance(embed_anchor, embed_positive)
            distance_negative = F.pairwise_distance(embed_anchor, embed_negative)
            correct_triplets += (distance_positive < distance_negative).sum().item()
            total_triplets += distance_positive.size(0)

    triplet_loss_tot /= len(validation_loader.dataset)
    triplet_accuracy = correct_triplets / total_triplets

    print(f'Validation Triplet Loss: {triplet_loss_tot}, Triplet Accuracy: {triplet_accuracy}')

    return triplet_loss_tot, triplet_accuracy

In [105]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [106]:
num_epochs = 10

for epoch in range(num_epochs):
    train(epoch)
    validate(model, validation_loader, device)

Epoch 0: 100%|██████████| 150/150 [00:07<00:00, 21.38it/s]


Epoch 0, Loss: 5.936689825057983
Validation Triplet Loss: 37.549303402230144, Triplet Accuracy: 0.226484375


Epoch 1: 100%|██████████| 150/150 [00:06<00:00, 21.64it/s]


Epoch 1, Loss: 5.770085296630859
Validation Triplet Loss: 36.80406214598566, Triplet Accuracy: 0.22765625


Epoch 2: 100%|██████████| 150/150 [00:06<00:00, 21.69it/s]


Epoch 2, Loss: 5.627102230389913
Validation Triplet Loss: 36.10427328631282, Triplet Accuracy: 0.22671875


Epoch 3: 100%|██████████| 150/150 [00:07<00:00, 21.37it/s]


Epoch 3, Loss: 5.555323699315389
Validation Triplet Loss: 35.43232946626842, Triplet Accuracy: 0.22671875


Epoch 4: 100%|██████████| 150/150 [00:07<00:00, 19.92it/s]


Epoch 4, Loss: 5.438143167495728
Validation Triplet Loss: 34.83094189401716, Triplet Accuracy: 0.226953125


Epoch 5: 100%|██████████| 150/150 [00:06<00:00, 22.31it/s]


Epoch 5, Loss: 5.371730017662048
Validation Triplet Loss: 34.27489921368659, Triplet Accuracy: 0.226796875


Epoch 6: 100%|██████████| 150/150 [00:07<00:00, 19.83it/s]


Epoch 6, Loss: 5.27255446434021
Validation Triplet Loss: 33.741417535580695, Triplet Accuracy: 0.2278125


Epoch 7: 100%|██████████| 150/150 [00:08<00:00, 18.52it/s]


Epoch 7, Loss: 5.21879915078481
Validation Triplet Loss: 33.24720354553312, Triplet Accuracy: 0.2271875


Epoch 8: 100%|██████████| 150/150 [00:08<00:00, 17.26it/s]


Epoch 8, Loss: 5.14193672657013
Validation Triplet Loss: 32.797125057838855, Triplet Accuracy: 0.22734375


Epoch 9: 100%|██████████| 150/150 [00:07<00:00, 18.95it/s]


Epoch 9, Loss: 5.051755712827046
Validation Triplet Loss: 32.38576854676008, Triplet Accuracy: 0.227421875


## Testing

In [120]:
expr_left = {
      "type": "POW",
      "value": None,
      "subtype": None,
      "children": [
        {
          "type": "POW",
          "value": None,
          "subtype": None,
          "children": [
            {
              "type": "LITERAL",
              "value": -3.5,
              "subtype": None,
              "children": []
            },
            {
              "type": "LITERAL",
              "value": 3.0,
              "subtype": None,
              "children": []
            }
          ]
        },
        {
          "type": "LITERAL",
          "value": 4.5,
          "subtype": None,
          "children": []
        }
      ]
    }
expr_right = {
      "type": "POW",
      "value": None,
      "subtype": None,
      "children": [
        {
          "type": "OPERATION",
          "value": None,
          "subtype": "ADD",
          "children": [
            {
              "type": "LITERAL",
              "value": -5.5,
              "subtype": None,
              "children": []
            },
            {
              "type": "LITERAL",
              "value": -10.0,
              "subtype": None,
              "children": []
            }
          ]
        },
        {
          "type": "LITERAL",
          "value": 4.0,
          "subtype": None,
          "children": []
        }
      ]
    }

expr_right_constant_error ={
      "type": "POW",
      "value": None,
      "subtype": None,
      "children": [
        {
          "type": "OPERATION",
          "value": None,
          "subtype": "ADD",
          "children": [
            {
              "type": "LITERAL",
              "value": -5.5,
              "subtype": None,
              "children": []
            },
            {
              "type": "LITERAL",
              "value": -10.0,
              "subtype": None,
              "children": []
            }
          ]
        },
        {
          "type": "LITERAL",
          "value": 4.0,
          "subtype": None,
          "children": []
        }
      ]
    }

In [121]:
expr_l_obj = dict_to_geometric_representation(expr_left, node_attr_encoder)
expr_r_obj = dict_to_geometric_representation(expr_right, node_attr_encoder)
expr_r_c_obj = dict_to_geometric_representation(expr_right_constant_error, node_attr_encoder)

In [122]:
emb_l, emb_r, emb_r_c = model(expr_l_obj, expr_r_obj, expr_r_c_obj)

In [123]:
print(F.pairwise_distance(emb_l, emb_r))
print(F.pairwise_distance(emb_l, emb_r_c))
print(F.pairwise_distance(emb_r, emb_r_c))
print(F.pairwise_distance(emb_r, emb_r))



tensor([2.0355], grad_fn=<NormBackward1>)
tensor([1.9440], grad_fn=<NormBackward1>)
tensor([1.9969], grad_fn=<NormBackward1>)
tensor([5.6569e-06], grad_fn=<NormBackward1>)
