In [1]:
import pandas as pd
from lib.trees import get_tree, parse_edge_list, to_torch
from lib.training import TrainingLoop
from lib.training.callbacks import EarlyStopping
from lib.dataset import split_training_validation
import random
import matplotlib.pyplot as plt
from torch_geometric.nn import GCNConv
from tqdm import tqdm
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
import networkx as nx
import torch
from sklearn.manifold import TSNE

2025-05-03 20:55:02.676393: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746298502.690793  167569 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746298502.695000  167569 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746298502.705707  167569 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746298502.705731  167569 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746298502.705733  167569 computation_placer.cc:177] computation placer alr

In [2]:
sentences = pd.read_csv("../data/train.csv")
sentences["language"] = sentences["language"].astype("category")
sentences["edgelist"] = sentences["edgelist"].apply(parse_edge_list)
sentences["tree"] = sentences["edgelist"].apply(get_tree)
sentences.head()

Unnamed: 0,language,sentence,n,edgelist,root,tree
0,Japanese,2,23,"[(6, 4), (2, 6), (2, 23), (20, 2), (15, 20), (...",10,"(6, 4, 2, 23, 20, 15, 3, 5, 14, 8, 12, 9, 18, ..."
1,Japanese,5,18,"[(8, 9), (14, 8), (4, 14), (5, 4), (1, 2), (6,...",10,"(8, 9, 14, 4, 5, 1, 2, 6, 17, 12, 3, 7, 11, 16..."
2,Japanese,8,33,"[(2, 10), (2, 14), (4, 2), (16, 4), (6, 16), (...",3,"(2, 10, 14, 4, 16, 6, 12, 32, 26, 3, 29, 27, 2..."
3,Japanese,11,30,"[(30, 1), (14, 24), (21, 14), (3, 21), (7, 3),...",30,"(30, 1, 14, 24, 21, 3, 7, 12, 27, 16, 8, 5, 26..."
4,Japanese,12,19,"[(19, 13), (16, 19), (2, 16), (4, 10), (4, 15)...",11,"(19, 13, 16, 2, 4, 10, 15, 5, 14, 12, 3, 1, 8,..."


In [3]:
random.seed(42)
training, validation = split_training_validation(sentences, 0.2)

print("Training set size:", len(training))
print("Validation set size:", len(validation))

Training set size: 8400
Validation set size: 2100


Let's experiment with adding no features to the nodes or the edges, and just letting the model learn the structure of the graph.

In [4]:
def get_node_labels(row: pd.Series):
    return {n: [float(n == row["root"]), float(n != row["root"])] for n in row["tree"].nodes()}


def get_node_features(row: pd.Series):
    tree = row["tree"]
    degree_centrality = nx.degree_centrality(tree)
    harmonic_centrality = nx.harmonic_centrality(tree)
    betweenness_centrality = nx.betweenness_centrality(tree)
    pagerank = nx.pagerank(tree)

    return {
        n: [degree_centrality[n], harmonic_centrality[n], betweenness_centrality[n], pagerank[n]] for n in tree.nodes()
    }


trees = training.apply(
    lambda r: to_torch(r["tree"], node_features=get_node_features(r), node_labels=get_node_labels(r)), axis=1
).tolist()
train_dataset = DataLoader(trees, batch_size=32, shuffle=True)
# DataLoader groups sub-graphs into mini-batches. See https://pytorch-geometric.readthedocs.io/en/latest/advanced/batching.html
# for more details. Basically, it combines the smaller graphs into a single large
# disconnected graph without padding, using smart indexing and concatenation
# techniques.

validation_trees = validation.apply(
    lambda r: to_torch(r["tree"], node_features=get_node_features(r), node_labels=get_node_labels(r)), axis=1
).tolist()
validation_dataset = DataLoader(validation_trees, batch_size=32, shuffle=False)

In [5]:
class GCN(torch.nn.Module):
    def __init__(self, num_features: int, hidden_channels: int, num_classes: int, hidden_layers: int = 10):
        super(GCN, self).__init__()
        torch.manual_seed(42)
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.hidden_layers = [GCNConv(hidden_channels, hidden_channels) for _ in range(hidden_layers)]
        self.dense = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.1, training=self.training)
        for layer in self.hidden_layers:
            x = layer(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=0.1, training=self.training)
        x = self.dense(x)
        return F.softmax(x, dim=1)

In [6]:
model = GCN(num_features=trees[0].x.shape[1], hidden_channels=16, num_classes=2, hidden_layers=10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

class_weights = {0: 0.1, 1: 1.0}


def custom_mse_loss(y_pred, y_true):
    """
    MSE loss function with class weights.
    """
    y_pred = y_pred.view(-1)
    y_true = y_true.view(-1)
    weights = torch.tensor([class_weights[int(y)] for y in y_true], dtype=torch.float32).to(y_pred.device)
    loss = F.mse_loss(y_pred, y_true, reduction="none")
    return (loss * weights).mean()


TrainingLoop(
    model=model,
    optimizer=optimizer,
    loss_fn=custom_mse_loss,
    train_loader=train_dataset,
    val_loader=validation_dataset,
    epochs=100,
    callbacks=[EarlyStopping(patience=10)],
).train()

                                                                                                    

Epoch 1/100: loss=0.02797 val_loss=0.02903


                                                                                                    

Epoch 2/100: loss=0.02467 val_loss=0.02876


                                                                                                    

Epoch 3/100: loss=0.03156 val_loss=0.02851


                                                                                                    

Epoch 4/100: loss=0.02597 val_loss=0.02833


                                                                                                    

Epoch 5/100: loss=0.02813 val_loss=0.02822


                                                                                                    

Epoch 6/100: loss=0.02890 val_loss=0.02818


                                                                                                    

Epoch 7/100: loss=0.02725 val_loss=0.02818


                                                                                                    

Epoch 8/100: loss=0.02692 val_loss=0.02818


                                                                                                    

Epoch 9/100: loss=0.02647 val_loss=0.02819


                                                                                                    

Epoch 10/100: loss=0.02818 val_loss=0.02818


                                                                                                    

Epoch 11/100: loss=0.02809 val_loss=0.02818


                                                                                                    

Epoch 12/100: loss=0.02456 val_loss=0.02818


                                                                                                    

Epoch 13/100: loss=0.02899 val_loss=0.02817


                                                                                                    

Epoch 14/100: loss=0.02985 val_loss=0.02818


                                                                                                    

Epoch 15/100: loss=0.02637 val_loss=0.02817


                                                                                                    

Epoch 16/100: loss=0.02890 val_loss=0.02818


                                                                                                    

Epoch 17/100: loss=0.03036 val_loss=0.02818


                                                                                                    

Epoch 18/100: loss=0.02766 val_loss=0.02818


                                                                                                    

Epoch 19/100: loss=0.03234 val_loss=0.02817


                                                                                                    

Epoch 20/100: loss=0.02469 val_loss=0.02818


                                                                                                    

Epoch 21/100: loss=0.02677 val_loss=0.02819


                                                                                                    

Epoch 22/100: loss=0.02685 val_loss=0.02818


                                                                                                    

Epoch 23/100: loss=0.02599 val_loss=0.02817


                                                                                                    

Epoch 24/100: loss=0.02565 val_loss=0.02818


                                                                                                    

Epoch 25/100: loss=0.02609 val_loss=0.02818


                                                                                                    

Epoch 26/100: loss=0.02948 val_loss=0.02821


                                                                                                    

Epoch 27/100: loss=0.02751 val_loss=0.02819


                                                                                                    

Epoch 28/100: loss=0.02765 val_loss=0.02817


                                                                                                    

Epoch 29/100: loss=0.02602 val_loss=0.02819


                                                                                                    

Epoch 30/100: loss=0.02749 val_loss=0.02818


                                                                                                    

Epoch 31/100: loss=0.02684 val_loss=0.02817


                                                                                                    

Epoch 32/100: loss=0.02401 val_loss=0.02822


                                                                                                    

Epoch 33/100: loss=0.02701 val_loss=0.02818




In [7]:
correct = 0
for tree in tqdm(validation_trees, leave=False):
    with torch.no_grad():
        out = model(tree.x, tree.edge_index)
        expected = torch.argmax(tree.y, dim=0)[0]
        predicted = torch.argmax(out, dim=0)[0]
        if predicted == expected:
            correct += 1
accuracy = correct / len(validation_dataset.dataset)
print(f"Accuracy: {accuracy:.4f}")

                                                    

Accuracy: 0.5933


