# Graph Contrastive Learning for Query Trees

In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, DataLoader, Dataset
from torch_geometric.transforms import RandomNodeSplit
import random
from db import Database
import config
import pandas as pd
from qep import Graph
import embeddings as emb


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
queries = pd.read_csv('traindataset/queries_tpch_train.csv', header = 0, names = ['queries'], usecols= [2])
queries

Unnamed: 0,queries
0,"select l_returnflag, l_linestatus, sum(l_quant..."
1,select sum(l_extendedprice* (1 - l_discount)) ...
2,"select s_name, s_address from supplier, nation..."
3,"select s_name, count(*) as numwait from suppli..."
4,"select cntrycode, count(*) as numcust, sum(c_a..."
...,...
3767,"select p_brand, p_type, p_size, count(distinct..."
3768,select sum(l_extendedprice) / 7.0 as avg_yearl...
3769,select sum(l_extendedprice* (1 - l_discount)) ...
3770,"select s_name, s_address from supplier, nation..."


In [3]:

db = Database(user=config.USER, dbname=config.DBASE)  # Database connection
db.connect()  # Connect to the database
qtrees = []
count = 0
for query in queries['queries']:
    qtree, _, _, _, error = db.getQep(query)
    #filtering out invalid queries from the query set
    if error:
        continue
    G = Graph()
    G.parseQep(qtree)
    # Create graph features and edge indices
    x = torch.tensor(emb.createGraphFeatures(nodes=G.nodes), dtype=torch.float)
    edge_index = torch.tensor(G.edges, dtype=torch.long)
    qtrees.append(Data(x=x, edge_index=edge_index))
    count += 1
    if count == 200:
        break


    


Error executing query: syntax error at or near "("
LINE 1: ...hipdate <= date '1998-12-01' - interval '112' day (3) group ...
                                                             ^

rolling back transaction
transaction rolled back
Error executing query: syntax error at or near "("
LINE 1: ...shipdate <= date '1998-12-01' - interval '97' day (3) group ...
                                                             ^

rolling back transaction
transaction rolled back
Error executing query: syntax error at or near "("
LINE 1: ...shipdate <= date '1998-12-01' - interval '79' day (3) group ...
                                                             ^

rolling back transaction
transaction rolled back
Error executing query: syntax error at or near "("
LINE 1: ...hipdate <= date '1998-12-01' - interval '101' day (3) group ...
                                                             ^

rolling back transaction
transaction rolled back
Error executing query: syntax error at or ne

In [None]:
print(qtrees[0])

[{'Plan': {'Node Type': 'Aggregate', 'Strategy': 'Plain', 'Partial Mode': 'Finalize', 'Parallel Aware': False, 'Async Capable': False, 'Startup Cost': 41040.96, 'Total Cost': 41040.97, 'Plan Rows': 1, 'Plan Width': 32, 'Plans': [{'Node Type': 'Gather', 'Parent Relationship': 'Outer', 'Parallel Aware': False, 'Async Capable': False, 'Startup Cost': 41040.73, 'Total Cost': 41040.94, 'Plan Rows': 2, 'Plan Width': 32, 'Workers Planned': 2, 'Single Copy': False, 'Plans': [{'Node Type': 'Aggregate', 'Strategy': 'Plain', 'Partial Mode': 'Partial', 'Parent Relationship': 'Outer', 'Parallel Aware': False, 'Async Capable': False, 'Startup Cost': 40040.73, 'Total Cost': 40040.74, 'Plan Rows': 1, 'Plan Width': 32, 'Plans': [{'Node Type': 'Hash Join', 'Parent Relationship': 'Outer', 'Parallel Aware': False, 'Async Capable': False, 'Join Type': 'Inner', 'Startup Cost': 2521.19, 'Total Cost': 40040.66, 'Plan Rows': 10, 'Plan Width': 12, 'Inner Unique': True, 'Hash Cond': '(lineitem.l_partkey = part.p

In [4]:
class QueryTreeDataset(Dataset):
    def __init__(self, qtrees, root='', transform=None, pre_transform=None):
        """
        A custom PyTorch Geometric Dataset for storing QEP-based graph data.

        Args:
            queries (dict): A dictionary containing query strings under `queries['queries']`.
            root (str): Root directory for saving/loading the dataset (not used here since it's dynamic).
            transform (callable, optional): A function/transform to apply to each graph.
            pre_transform (callable, optional): A function/transform to apply before saving the dataset.
        """
        super().__init__(root, transform, pre_transform)
        self.qtrees = qtrees

    def len(self):
        """
        Returns the number of queries in the dataset.
        """
        return len(self.qtrees)

    def get(self, idx):
        """
        Fetches and parses the QEP for a specific query to generate a graph.

        Args:
            idx (int): Index of the query.

        Returns:
            torch_geometric.data.Data: A graph object representing the QEP.
        """

        return self.qtrees[idx]
    



In [5]:
# Define the GNN model
class GNNEncoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNNEncoder, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        # Global mean pooling to get a graph-level representation
        x = global_mean_pool(x, batch)
        return x

### NCE Loss

For a given sample \(i\), the loss is defined as:

$$
\mathcal{L}_i = -\log \frac{\exp(\text{sim}(z_1[i], z_2[i]) / \tau)}{\sum_{j=1}^{N} \exp(\text{sim}(z_1[i], z_2[j]) / \tau)}
$$

Where:
- $( \text{sim}(z_1[i], z_2[j])) $: Similarity between embeddings $( z_1[i] )$ and $( z_2[j] )$ (e.g., cosine similarity or dot product).
- $( \tau )$: Temperature parameter controlling the sharpness of the distribution.
- $( N )$: Number of samples in the batch.

The total loss across the batch is averaged:

$$
\mathcal{L} = \frac{1}{N} \sum_{i=1}^N \mathcal{L}_i
$$


In [None]:
def contrastive_loss(z1, z2, temperature=0.5):
    z1 = F.normalize(z1, p=2, dim=1)
    z2 = F.normalize(z2, p=2, dim=1)
    similarity_matrix = torch.mm(z1, z2.t()) / temperature
    positive_sim = torch.diag(similarity_matrix)
    negative_sim = torch.exp(similarity_matrix).sum(dim=1) - torch.exp(positive_sim)
    loss = -torch.log(torch.exp(positive_sim) / negative_sim + 1e-8).mean()
    return loss

In [12]:
def mask_node_features(data, device, mask_prob=0.3):
    """
    Masks random features in the node feature matrix by setting them to zero.

    Args:
        data (torch_geometric.data.Data): A graph data object with `data.x` as the node feature matrix.
        mask_prob (float): The probability of masking each feature (default is 0.1).

    Returns:
        torch_geometric.data.Data: A graph data object with masked node features.
    """
    x = data.x  # Node feature matrix (shape: [num_nodes, num_features])
    # Create a random mask with the same shape as the feature matrix
    mask = torch.rand(x.shape) > mask_prob
    mask = mask.to(device)
    # Apply the mask (features with a mask value of False are set to 0)
    x = x * mask.float()
    # Update the graph's feature matrix with the masked features
    data.x = x
    return data

In [13]:
# Training loop
def train(model, loader, optimizer, device):
    model.train()
    total_loss = 0
    for data in loader:
        data = data.to(device)
        print(data.x.shape)
        # Generate two augmented views
        view1 = mask_node_features(data, device)
        view2 = mask_node_features(data, device)
        view1, view2 = view1.to(device), view2.to(device)
        # Get embeddings
        z1 = model(view1.x, view1.edge_index, view1.batch)
        z2 = model(view2.x, view2.edge_index, view2.batch)
        # Compute contrastive loss
        loss = contrastive_loss(z1, z2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

In [None]:
# Main
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

input_dim = 2307  # Adjust based on your graph node features
hidden_dim = 64
output_dim = 32


query_data = QueryTreeDataset(qtrees)

loader = DataLoader(query_data, batch_size=32, shuffle=True)
# Model and optimizer
model = GNNEncoder(input_dim, hidden_dim, output_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# Training
epochs = 50
counter = 0
for epoch in range(epochs):
    loss_new = train(model, loader, optimizer, device)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss_new:.4f}")

torch.save(model.state_dict(), 'models/GNNQueryEncoder.pt')

torch.Size([370, 2307])
torch.Size([386, 2307])
torch.Size([465, 2307])
torch.Size([387, 2307])
torch.Size([367, 2307])
torch.Size([390, 2307])
torch.Size([119, 2307])
Epoch 1/50, Loss: 3.2071
torch.Size([421, 2307])
torch.Size([392, 2307])
torch.Size([353, 2307])
torch.Size([424, 2307])
torch.Size([366, 2307])
torch.Size([425, 2307])
torch.Size([103, 2307])
Epoch 2/50, Loss: 3.2212
torch.Size([346, 2307])
torch.Size([398, 2307])
torch.Size([427, 2307])
torch.Size([371, 2307])
torch.Size([429, 2307])
torch.Size([417, 2307])
torch.Size([96, 2307])
Epoch 3/50, Loss: 3.2208
torch.Size([392, 2307])
torch.Size([346, 2307])
torch.Size([345, 2307])
torch.Size([441, 2307])
torch.Size([438, 2307])
torch.Size([426, 2307])
torch.Size([96, 2307])
Epoch 4/50, Loss: 3.2212
torch.Size([407, 2307])
torch.Size([450, 2307])
torch.Size([386, 2307])
torch.Size([420, 2307])
torch.Size([430, 2307])
torch.Size([312, 2307])
torch.Size([79, 2307])
Epoch 5/50, Loss: 3.2213
torch.Size([439, 2307])
torch.Size([38

KeyboardInterrupt: 