# Graph Contrastive Learning for Query Trees

In [48]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, DataLoader, Dataset
from torch_geometric.transforms import RandomNodeSplit
import random
from db import Database
import config
import pandas as pd
from qep import Graph
from embeddings import nodeEmbedder
import pickle
from collections import defaultdict


In [49]:
queries = pd.read_csv('traindataset/queries_tpch_train_labelled.csv', header=0, names=['queries', 'template'], usecols=[2, 3])
queries

Unnamed: 0,queries,template
0,"select l_returnflag, l_linestatus, sum(l_quant...",0
1,select sum(l_extendedprice* (1 - l_discount)) ...,1
2,"select s_name, s_address from supplier, nation...",2
3,"select s_name, count(*) as numwait from suppli...",3
4,"select cntrycode, count(*) as numcust, sum(c_a...",4
...,...,...
3767,"select p_brand, p_type, p_size, count(distinct...",12
3768,select sum(l_extendedprice) / 7.0 as avg_yearl...,13
3769,select sum(l_extendedprice* (1 - l_discount)) ...,1
3770,"select s_name, s_address from supplier, nation...",2


In [50]:
# Load the qtree embeddings
with open('traindataset/qtrees_embedding.pkl', 'rb') as file:
    embeddings = pickle.load(file)

embeddings

[array([[-5.34530836e+00,  1.34176389e+02,  5.53131029e+01,
          3.14679140e+02,  4.42498280e+02, -3.25956586e+02,
          3.99364452e+02, -4.20629466e+02, -2.68855425e+01,
          7.46615569e+01, -7.25084726e+02,  5.60728535e+02,
         -1.68665841e+02, -2.56369421e+01, -9.47779657e+00,
         -1.37886117e+00,  4.19347349e+01,  7.17416155e+01,
         -6.36660421e+01, -1.84538750e+02, -1.05231516e+01,
          1.35718335e+02, -2.74003484e+01,  4.84949632e+01,
          2.00906512e+02,  3.14749392e+02, -3.08232206e+02,
          7.27073055e+02,  2.46025402e+01, -7.32570352e+02,
          3.66160726e+02,  3.13033570e+02, -1.47618934e+02,
         -1.59239543e+02,  4.62692970e+02, -3.92523481e+02,
          1.79197765e+02,  1.34049228e+02,  1.48862997e+02,
          1.42268248e+02,  2.71062144e+02, -1.48590781e+02,
          5.27327305e+02,  2.54558167e+02,  6.52968234e+01,
         -2.09359082e+02,  2.10932220e+02, -9.16592403e+02],
        [-5.06442931e+00,  1.33055053e+

In [51]:
len(embeddings)

3732

In [72]:

db = Database(user=config.USER, dbname=config.DBASE)  # Database connection
db.connect()  # Connect to the database
qtrees = []
count = 0
label_count = defaultdict(int)
for i in range(len(queries)):
    query = queries['queries'][i]
    qtree, _, _, _, error = db.getQep(query)
    #filtering out invalid queries from the query set
    if error:
        continue
    label_count[queries['template'][i]] += 1
    G = Graph()
    G.parseQep(qtree)
    # Create graph features and edge indices
    #emb = nodeEmbedder(8, 'mps', False)
    #x = torch.tensor(emb.createGraphFeatures(nodes=G.nodes), dtype=torch.float)
    x = torch.tensor(embeddings[count], dtype=torch.float)
    edge_index = torch.tensor(G.edges, dtype=torch.long)
    qtrees.append(Data(x=x, edge_index=edge_index))
    count += 1
    # if count == 200:
    #     break
    


    


Error executing query: syntax error at or near "("
LINE 1: ...hipdate <= date '1998-12-01' - interval '112' day (3) group ...
                                                             ^

rolling back transaction
transaction rolled back
Error executing query: syntax error at or near "("
LINE 1: ...shipdate <= date '1998-12-01' - interval '97' day (3) group ...
                                                             ^

rolling back transaction
transaction rolled back
Error executing query: syntax error at or near "("
LINE 1: ...shipdate <= date '1998-12-01' - interval '79' day (3) group ...
                                                             ^

rolling back transaction
transaction rolled back
Error executing query: syntax error at or near "("
LINE 1: ...hipdate <= date '1998-12-01' - interval '101' day (3) group ...
                                                             ^

rolling back transaction
transaction rolled back
Error executing query: syntax error at or ne

In [73]:
print(qtrees)

[Data(x=[7, 48], edge_index=[2, 6]), Data(x=[12, 48], edge_index=[2, 11]), Data(x=[17, 48], edge_index=[2, 16]), Data(x=[8, 48], edge_index=[2, 7]), Data(x=[23, 48], edge_index=[2, 22]), Data(x=[20, 48], edge_index=[2, 19]), Data(x=[4, 48], edge_index=[2, 3]), Data(x=[19, 48], edge_index=[2, 18]), Data(x=[22, 48], edge_index=[2, 21]), Data(x=[18, 48], edge_index=[2, 17]), Data(x=[8, 48], edge_index=[2, 7]), Data(x=[8, 48], edge_index=[2, 7]), Data(x=[7, 48], edge_index=[2, 6]), Data(x=[7, 48], edge_index=[2, 6]), Data(x=[12, 48], edge_index=[2, 11]), Data(x=[8, 48], edge_index=[2, 7]), Data(x=[23, 48], edge_index=[2, 22]), Data(x=[11, 48], edge_index=[2, 10]), Data(x=[7, 48], edge_index=[2, 6]), Data(x=[20, 48], edge_index=[2, 19]), Data(x=[4, 48], edge_index=[2, 3]), Data(x=[19, 48], edge_index=[2, 18]), Data(x=[22, 48], edge_index=[2, 21]), Data(x=[18, 48], edge_index=[2, 17]), Data(x=[14, 48], edge_index=[2, 13]), Data(x=[8, 48], edge_index=[2, 7]), Data(x=[8, 48], edge_index=[2, 7]

In [54]:
label_count

defaultdict(int,
            {np.int64(1): 471,
             np.int64(2): 471,
             np.int64(3): 7,
             np.int64(4): 471,
             np.int64(5): 402,
             np.int64(6): 9,
             np.int64(7): 34,
             np.int64(8): 144,
             np.int64(9): 470,
             np.int64(10): 67,
             np.int64(11): 179,
             np.int64(12): 470,
             np.int64(13): 384,
             np.int64(14): 73,
             np.int64(15): 32,
             np.int64(16): 4,
             np.int64(17): 35,
             np.int64(18): 5,
             np.int64(19): 4})

In [74]:
class QueryTreeDataset(Dataset):
    def __init__(self, qtrees, root='', transform=None, pre_transform=None):
        """
        A custom PyTorch Geometric Dataset for storing QEP-based graph data.

        Args:
            queries (dict): A dictionary containing query strings under `queries['queries']`.
            root (str): Root directory for saving/loading the dataset (not used here since it's dynamic).
            transform (callable, optional): A function/transform to apply to each graph.
            pre_transform (callable, optional): A function/transform to apply before saving the dataset.
        """
        super().__init__(root, transform, pre_transform)
        self.qtrees = qtrees
        #self.templates = qtrees[1]

    def len(self):
        """
        Returns the number of queries in the dataset.
        """
        return len(self.qtrees)

    def get(self, idx):
        """
        Fetches and parses the QEP for a specific query to generate a graph.

        Args:
            idx (int): Index of the query.

        Returns:
            torch_geometric.data.Data: A graph object representing the QEP.
        """

        return self.qtrees[idx]
    



In [75]:
# Define the GNN model
class GNNEncoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNNEncoder, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        # Global mean pooling to get a graph-level representation
        x = global_mean_pool(x, batch)
        return x

### NCE Loss

For a given sample \(i\), the loss is defined as:

$$
\mathcal{L}_i = -\log \frac{\exp(\text{sim}(z_1[i], z_2[i]) / \tau)}{\sum_{j=1}^{N} \exp(\text{sim}(z_1[i], z_2[j]) / \tau)}
$$

Where:
- $( \text{sim}(z_1[i], z_2[j])) $: Similarity between embeddings $( z_1[i] )$ and $( z_2[j] )$ (e.g., cosine similarity or dot product).
- $( \tau )$: Temperature parameter controlling the sharpness of the distribution.
- $( N )$: Number of samples in the batch.

The total loss across the batch is averaged:

$$
\mathcal{L} = \frac{1}{N} \sum_{i=1}^N \mathcal{L}_i
$$


In [76]:
def contrastive_loss(z1, z2, temperature=0.5):
    z1 = F.normalize(z1, p=2, dim=1)
    z2 = F.normalize(z2, p=2, dim=1)
    similarity_matrix = torch.mm(z1, z2.t()) / temperature
    positive_sim = torch.diag(similarity_matrix)
    negative_sim = torch.exp(similarity_matrix).sum(dim=1) - torch.exp(positive_sim)
    loss = -torch.log(torch.exp(positive_sim) / (negative_sim + 1e-8)).mean()
    return loss

In [77]:
def mask_node_features(data, device, mask_prob=0.3):
    """
    Masks random features in the node feature matrix by setting them to zero.

    Args:
        data (torch_geometric.data.Data): A graph data object with `data.x` as the node feature matrix.
        mask_prob (float): The probability of masking each feature (default is 0.1).

    Returns:
        torch_geometric.data.Data: A graph data object with masked node features.
    """
    x = data.x  # Node feature matrix (shape: [num_nodes, num_features])
    # Create a random mask with the same shape as the feature matrix
    mask = torch.rand(x.shape) > mask_prob
    mask = mask.to(device)
    # Apply the mask (features with a mask value of False are set to 0)
    x = x * mask.float()
    # Update the graph's feature matrix with the masked features
    data.x = x
    return data

In [80]:
# Training loop
def train(model, loader, optimizer, device):
    model.train()
    total_loss = 0
    for data in loader:
        data = data.to(device)
        # print(data.x.shape)
        # Generate two augmented views
        view1 = mask_node_features(data, device)
        view2 = mask_node_features(data, device)
        view1, view2 = view1.to(device), view2.to(device)
        # Get embeddings
        z1 = model(view1.x, view1.edge_index, view1.batch)
        z2 = model(view2.x, view2.edge_index, view2.batch)
        # Compute contrastive loss
        loss = contrastive_loss(z1, z2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

In [81]:
# Main
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

input_dim = 48  # Adjust based on your graph node features
hidden_dim = 64
output_dim = 32


query_data = QueryTreeDataset(qtrees)

loader = DataLoader(query_data, batch_size=32, shuffle=True)
# Model and optimizer
model = GNNEncoder(input_dim, hidden_dim, output_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# Training with early stopping
epochs = 50
patience = 5
best_loss = float('inf')
counter = 0
for epoch in range(epochs):
    loss_new = train(model, loader, optimizer, device)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss_new:.4f}")
    if loss_new < best_loss:
        best_loss = loss_new
        counter = 0
    else:
        counter += 1
    if counter == patience:
        print(f"Early stopping at epoch {epoch+1}")
        break

torch.save(model.state_dict(), 'models/GNNQueryEncoder.pt')

Epoch 1/50, Loss: 2.1020
Epoch 2/50, Loss: 1.7604
Epoch 3/50, Loss: 1.6760
Epoch 4/50, Loss: 1.6435
Epoch 5/50, Loss: 1.6326
Epoch 6/50, Loss: 1.6159
Epoch 7/50, Loss: 1.6141
Epoch 8/50, Loss: 1.6203
Epoch 9/50, Loss: 1.6213
Epoch 10/50, Loss: 1.6217
Epoch 11/50, Loss: 1.6301
Epoch 12/50, Loss: 1.6047
Epoch 13/50, Loss: 1.6218
Epoch 14/50, Loss: 1.6225
Epoch 15/50, Loss: 1.5907
Epoch 16/50, Loss: 1.5786
Epoch 17/50, Loss: 1.5821
Epoch 18/50, Loss: 1.5859
Epoch 19/50, Loss: 1.5679
Epoch 20/50, Loss: 1.5612
Epoch 21/50, Loss: 1.5651
Epoch 22/50, Loss: 1.5877
Epoch 23/50, Loss: 1.5683
Epoch 24/50, Loss: 1.5584
Epoch 25/50, Loss: 1.5648
Epoch 26/50, Loss: 1.5777
Epoch 27/50, Loss: 1.5739
Epoch 28/50, Loss: 1.5801
Epoch 29/50, Loss: 1.5880
Early stopping at epoch 29


In [None]:
count = 0
for name, param in model.named_parameters():
    count += 1
    if param.grad is not None:
        print(f"{name} gradient norm: {param.grad.norm()}")

print(f'number of parameters: {count}')

conv1.bias gradient norm: 1.1809155608943911e-07
conv1.lin.weight gradient norm: 1.441623317077756e-05
conv2.bias gradient norm: 5.601187069714797e-08
conv2.lin.weight gradient norm: 1.925736796692945e-05
number of parameters: 4
