In [161]:
from typing import Tuple

In [162]:
import torch
from torch_geometric.data import InMemoryDataset, Data, Batch
from torch.utils.data import DataLoader, Subset
from sklearn.calibration import LabelEncoder
from generate_dataset import generate_dataset
import json 

Constants

In [163]:
GENERATED_DATASET_SIZE = 130
BATCH_SIZE = 32
TRAIN_SAMPLES = 50
TRAIN_RATIO = 0.6
VAL_RATIO = 0.2

In [164]:
def dict_to_geometric_representation(in_graph_dict: dict, encoder) -> Data:
    node_list = []
    edge_mappings = []
    def traverse_graph(graph = in_graph_dict):
        nonlocal node_list
        nonlocal edge_mappings
        curr_node_index = len(node_list)
        encoded_data = encoder(graph["val"])
        node_list.append(encoded_data)
        if "children" in graph.keys():
            for child in graph["children"]:
                edge_mappings.append((curr_node_index, traverse_graph(child)) ) #I'm retarded.
        return curr_node_index
    traverse_graph()
    nodes = torch.tensor(node_list,dtype=torch.float32)
    edges = torch.tensor([[x[0] for x in edge_mappings], [x[1] for x in edge_mappings]], dtype=torch.long) # Probably slow and mentally degenerated
    geom_data = Data(x=nodes, edge_index=edges)
    return geom_data

In [165]:
OPERATIONS = ["ADD", "MUL", "FUNC", "POW"]
FUNCTIONS = ["SIN", "COS", "TAN", "EXP", "LOG", "f", "g", "h"]
ATOMICS = ["LITERAL", "VARIABLE"]
VARIABLE_ALPHABET = [chr(x) for x in range(ord("a"), ord("z")+1) if chr(x) not in ["f", "g", "h"]]

In [166]:
def make_node_attribute_encoder(label_encoder:LabelEncoder, rep = 3):
    def node_attr_encoder(attr):
        if isinstance(attr, str) and attr in OPERATIONS+FUNCTIONS+ATOMICS+VARIABLE_ALPHABET:
            res = label_encoder.transform([attr])
            return [res[0]]*(rep + 1)
        else:
            return [0] + [float(attr)]*rep
            
    return node_attr_encoder

In [167]:
def create_dataset_class(expression):
    # Will it be the same for both datasets ? 
    le = LabelEncoder()
    le.fit(OPERATIONS+FUNCTIONS+ATOMICS+VARIABLE_ALPHABET)
    class MathExpressionDataset(InMemoryDataset):
        def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
            super().__init__(root, transform, pre_transform, pre_filter, force_reload=True)
            self.load(self.processed_paths[0])
            
        @property
        def raw_file_names(self):
            return ['math_datagen.json']

        @property
        def processed_file_names(self):
            return ['data.pt']
        

        def process(self):
            # Read data into huge `Data` list.
            data_list = []
            for file in self.raw_file_names:
                with open(file) as file_handle:
                    object_data = json.load(file_handle)
                    for comparison in object_data:
                        expr = comparison[expression]
                        score = comparison["score"]
                        geometric_expr = dict_to_geometric_representation(expr, make_node_attribute_encoder(le))
                        geometric_expr.y = score #torch.tensor([score],dtype=torch.float32)
                        data_list.append(geometric_expr)
                        
            if self.pre_filter is not None:
                data_list = [data for data in data_list if self.pre_filter(data)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(data) for data in data_list]
            self.save(data_list, self.processed_paths[0])
    return MathExpressionDataset
    

In [168]:
class ExpressionPairDataset(torch.utils.data.Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__()
        self.dataset_l = create_dataset_class("expr_l")(root+"_l",transform=None, pre_transform=None, pre_filter=None)
        self.dataset_r = create_dataset_class("expr_r")(root+"_r",transform=None, pre_transform=None, pre_filter=None)
        
    @property 
    def num_features(self):
        return self.dataset_l.num_features
    
    def __len__(self):
        return len(self.dataset_l)
        
    
    
    def __getitem__(self, idx):
        return self.dataset_l[idx], self.dataset_r[idx]

In [169]:
generate_dataset(GENERATED_DATASET_SIZE,"math_datagen.json")
dataset = ExpressionPairDataset(root="/dataset")

Processing...
Done!
Processing...
Done!


In [170]:
from torch import nn
from torch.nn import Linear, ReLU
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

In [171]:
class FormulaNet(nn.Module):
    def __init__(self, hidden_channels: int, embedding_space: int):
        super(FormulaNet, self).__init__()
        self.dense_1 = Linear(dataset.num_features, dataset.num_features) 
        self.relu_1 = ReLU()
        self.gconv_1 = GCNConv(dataset.num_features, hidden_channels)
        self.gconv_2 = GCNConv(hidden_channels, hidden_channels)
        self.gconv_3 = GCNConv(hidden_channels, hidden_channels)
        self.dense_2 = Linear(hidden_channels, embedding_space)
    
    def forward(self, x, edge_index, batch):
        x = self.dense_1(x)
        x = self.relu_1(x)
        x = self.gconv_1(x, edge_index)
        x = self.relu_1(x)
        x = self.gconv_2(x, edge_index)
        x = self.relu_1(x)
        x = self.gconv_3(x, edge_index)
        x = self.relu_1(x)
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.3,training=self.training)
        x = self.dense_2(x)
        return x
    

In [172]:
class SiameseFormulaNet(nn.Module):
    def __init__(self, hidden_channels, embedding_space):
        super(SiameseFormulaNet, self).__init__()
        self.formulanet = FormulaNet(hidden_channels, embedding_space)
        # self.fc = nn.Sequential(
        #     Linear(embedding_space*2, embedding_space),
        #     ReLU(inplace=True),
        #     Linear(embedding_space, 1)
        # )
        # self.sigmoid = nn.Sigmoid() # TODO: Only used it for testing purposes, everything is subject to change Okay
    

    def forward(self, expr_l, expr_r):
        # print(expr_l)
        # print("X = ", expr_l.x)
        # print("X shape = ", expr_l.x.shape)
        # print("Batch = ",expr_l.batch)
        # print("Batch shape = ",expr_l.batch.shape)
        # print("Y =", expr_l.y)
        # print("Edge index = ",expr_l.edge_index)
        # print("Edge index shape = ",expr_l.edge_index.shape)
        
        embed_l = self.formulanet(expr_l.x,expr_l.edge_index, expr_l.batch)
        embed_l = embed_l.view(embed_l.size()[0], -1)
        embed_r = self.formulanet(expr_r.x,expr_r.edge_index, expr_r.batch)
        embed_r = embed_r.view(embed_r.size()[0], -1)
        
        # output = torch.cat((embed_l, embed_r), 1)
        
        # output = self.fc(output)
        # output = self.sigmoid(output)
        # return output
        return embed_l, embed_r
        
        
        

In [173]:
# train_dataset = Subset(dataset, list(range(TRAIN_SAMPLES)))
# validation_dataset = Subset(dataset, list(range()))
# test_dataset = Subset(dataset, list(range(TRAIN_SAMPLES, -1)))

In [174]:
from torch.utils.data import random_split

In [175]:
total_len = len(dataset.dataset_l)
train_len = int(TRAIN_RATIO * total_len)
val_len = int(VAL_RATIO * total_len)
test_len = total_len - train_len - val_len
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_len, val_len, test_len])

In [176]:

def collate(data_list):
    batchA = Batch.from_data_list([data[0] for data in data_list])
    batchB = Batch.from_data_list([data[1] for data in data_list])
    return batchA, batchB
# NOTE: Type ignore only for collate_fn_t ... make sure it doesn't get in the way of correct typing for the dataset
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate) # type: ignore
validation_laoder = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle= False, collate_fn = collate)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate) # type: ignore


In [177]:
device = torch.device("cpu") #torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [178]:
model = SiameseFormulaNet(32,64).to(device)

In [179]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [180]:
from tqdm import tqdm

In [181]:
def contrastive_loss(embed_l, embed_r, labels, margin=1.0, threshold=0.5):
    euclidean_distance = F.pairwise_distance(embed_l, embed_r)
    labels = (labels > threshold).float()  # Convert labels to 0 or 1
    loss_contrastive = torch.mean((1 - labels) * torch.pow(euclidean_distance, 2) +
                                  labels * torch.pow(torch.clamp(margin - euclidean_distance, min=0.0), 2))
    return loss_contrastive

In [182]:
def train(epoch):
    model.train()
    epoch_loss = 0
    for batch_l, batch_r in tqdm(train_loader, desc=f'Epoch {epoch}'):
        batch_l, batch_r = batch_l.to(device), batch_r.to(device)
        optimizer.zero_grad()
        embed_l, embed_r = model(batch_l, batch_r)
        loss = contrastive_loss(embed_l, embed_r, batch_l.y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f'Epoch {epoch}, Loss: {epoch_loss / len(train_loader)}')

In [183]:
def validate():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_l, batch_r in validation_laoder:
            batch_l, batch_r = batch_l.to(device), batch_r.to(device)
            embed_l, embed_r = model(batch_l, batch_r)
            euclidean_distance = F.pairwise_distance(embed_l, embed_r)
            pred = (euclidean_distance < 0.5).float()  # Adjust the threshold as needed
            correct += (pred == batch_l.y).sum().item()
            test_loss += contrastive_loss(embed_l, embed_r, batch_l.y).item()
    test_loss /= len(test_loader)
    acc = correct / len(test_dataset)
    print(f'Test Loss: {test_loss}, Accuracy: {acc}')

In [184]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [185]:
num_epochs = 10
for epoch in range(num_epochs):
    train(epoch)
    validate()

Epoch 0: 100%|██████████| 8/8 [00:00<00:00, 28.13it/s]


Epoch 0, Loss: 0.5591804850846529
Test Loss: 0.6791805823644003, Accuracy: 0.0


Epoch 1:  50%|█████     | 4/8 [00:00<00:00, 39.81it/s]

Epoch 1: 100%|██████████| 8/8 [00:00<00:00, 35.07it/s]


Epoch 1, Loss: 0.26508706994354725
Test Loss: 0.6795454025268555, Accuracy: 0.0


Epoch 2: 100%|██████████| 8/8 [00:00<00:00, 35.13it/s]


Epoch 2, Loss: 0.2624411229044199
Test Loss: 0.6878628730773926, Accuracy: 0.0


Epoch 3: 100%|██████████| 8/8 [00:00<00:00, 29.90it/s]


Epoch 3, Loss: 0.26856590062379837
Test Loss: 0.6832131743431091, Accuracy: 0.0


Epoch 4: 100%|██████████| 8/8 [00:00<00:00, 38.31it/s]


Epoch 4, Loss: 0.2476294282823801
Test Loss: 0.6734213431676229, Accuracy: 0.0


Epoch 5: 100%|██████████| 8/8 [00:00<00:00, 41.69it/s]


Epoch 5, Loss: 0.2756047993898392
Test Loss: 0.6813262303670248, Accuracy: 0.0


Epoch 6: 100%|██████████| 8/8 [00:00<00:00, 42.09it/s]


Epoch 6, Loss: 0.2569238431751728
Test Loss: 0.681311825911204, Accuracy: 0.0


Epoch 7: 100%|██████████| 8/8 [00:00<00:00, 47.37it/s]


Epoch 7, Loss: 0.24857244081795216
Test Loss: 0.6861916581789652, Accuracy: 0.0


Epoch 8: 100%|██████████| 8/8 [00:00<00:00, 48.10it/s]


Epoch 8, Loss: 0.25428937934339046
Test Loss: 0.6842256387074789, Accuracy: 0.0


Epoch 9: 100%|██████████| 8/8 [00:00<00:00, 46.81it/s]


Epoch 9, Loss: 0.23729438334703445
Test Loss: 0.6837652126948038, Accuracy: 0.0


## Evaluation

In [186]:
le = LabelEncoder()
le.fit(OPERATIONS+FUNCTIONS+ATOMICS+VARIABLE_ALPHABET)
eval_node_attr_encoder = make_node_attribute_encoder(le)

In [187]:
expr_left = {
            "val": "COS",
            "id": 1,
            "children": [
                {
                    "val": "ADD",
                    "id": 3,
                    "children": [
                        {
                            "val": "j",
                            "id": 7
                        },
                        {
                            "val": "POW",
                            "id": 8,
                            "children": [
                                {
                                    "val": "-9.5",
                                    "id": 17
                                },
                                {
                                    "val": "q",
                                    "id": 18
                                }
                            ]
                        }
                    ]
                }
            ]
        }
expr_right = {
            "val": "COS",
            "id": 1,
            "children": [
                {
                    "val": "ADD",
                    "id": 3,
                    "children": [
                        {
                            "val": "MUL",
                            "id": 7,
                            "children": [
                                {
                                    "val": "j",
                                    "id": 15
                                },
                                {
                                    "val": "-1",
                                    "id": 16
                                }
                            ]
                        },
                        {
                            "val": "POW",
                            "id": 8,
                            "children": [
                                {
                                    "val": "9.5",
                                    "id": 17
                                },
                                {
                                    "val": "MUL",
                                    "id": 18,
                                    "children": [
                                        {
                                            "val": "q",
                                            "id": 37
                                        },
                                        {
                                            "val": "-1",
                                            "id": 38
                                        }
                                    ]
                                }
                            ]
                        }
                    ]
                }
            ]
        } 

expr_right_constant_error = {
            "val": "COS",
            "id": 1,
            "children": [
                {
                    "val": "ADD",
                    "id": 3,
                    "children": [
                        {
                            "val": "MUL",
                            "id": 7,
                            "children": [
                                {
                                    "val": "j",
                                    "id": 15
                                },
                                {
                                    "val": "-1",
                                    "id": 16
                                }
                            ]
                        },
                        {
                            "val": "POW",
                            "id": 8,
                            "children": [
                                {
                                    "val": "9.5",
                                    "id": 17
                                },
                                {
                                    "val": "MUL",
                                    "id": 18,
                                    "children": [
                                        {
                                            "val": "a",
                                            "id": 37
                                        },
                                        {
                                            "val": "-1",
                                            "id": 38
                                        }
                                    ]
                                }
                            ]
                        }
                    ]
                }
            ]
        }

In [188]:
expr_l_obj = dict_to_geometric_representation(expr_left, eval_node_attr_encoder)
expr_r_obj = dict_to_geometric_representation(expr_right, eval_node_attr_encoder)
expr_r_c_obj = dict_to_geometric_representation(expr_right_constant_error, eval_node_attr_encoder)

In [189]:
y_1 = model(expr_l_obj, expr_r_obj)
y_2 = model(expr_l_obj, expr_l_obj)
y_3 = model(expr_r_obj, expr_r_obj)
y_4 = model(expr_r_obj, expr_r_c_obj)

In [190]:
print(F.pairwise_distance(y_1[0], y_1[1]))
print(F.pairwise_distance(y_2[0], y_2[1]))
print(F.pairwise_distance(y_3[0], y_3[1]))
print(F.pairwise_distance(y_4[0], y_4[1]))


tensor([0.0666], grad_fn=<NormBackward1>)
tensor([8.0000e-06], grad_fn=<NormBackward1>)
tensor([8.0000e-06], grad_fn=<NormBackward1>)
tensor([0.0078], grad_fn=<NormBackward1>)


# Misc

In [191]:
# from torchviz import make_dot

In [192]:
# make_dot(y_f.mean(), params=dict(formulanet.named_parameters()))