In [1]:
import pandas as pd
from lib.trees import get_tree, parse_edge_list, to_torch
from lib.dataset import split_training_validation
import random

2025-05-02 01:38:21.898189: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746142701.913669  168654 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746142701.918043  168654 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746142701.929180  168654 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746142701.929201  168654 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746142701.929202  168654 computation_placer.cc:177] computation placer alr

In [2]:
sentences = pd.read_csv("../data/train.csv")
sentences["language"] = sentences["language"].astype("category")
sentences["edgelist"] = sentences["edgelist"].apply(parse_edge_list)
sentences["tree"] = sentences["edgelist"].apply(get_tree)
sentences.head()

Unnamed: 0,language,sentence,n,edgelist,root,tree
0,Japanese,2,23,"[(6, 4), (2, 6), (2, 23), (20, 2), (15, 20), (...",10,"(6, 4, 2, 23, 20, 15, 3, 5, 14, 8, 12, 9, 18, ..."
1,Japanese,5,18,"[(8, 9), (14, 8), (4, 14), (5, 4), (1, 2), (6,...",10,"(8, 9, 14, 4, 5, 1, 2, 6, 17, 12, 3, 7, 11, 16..."
2,Japanese,8,33,"[(2, 10), (2, 14), (4, 2), (16, 4), (6, 16), (...",3,"(2, 10, 14, 4, 16, 6, 12, 32, 26, 3, 29, 27, 2..."
3,Japanese,11,30,"[(30, 1), (14, 24), (21, 14), (3, 21), (7, 3),...",30,"(30, 1, 14, 24, 21, 3, 7, 12, 27, 16, 8, 5, 26..."
4,Japanese,12,19,"[(19, 13), (16, 19), (2, 16), (4, 10), (4, 15)...",11,"(19, 13, 16, 2, 4, 10, 15, 5, 14, 12, 3, 1, 8,..."


In [3]:
random.seed(42)
training, validation = split_training_validation(sentences, 0.2)

print("Training set size:", len(training))
print("Validation set size:", len(validation))

Training set size: 8400
Validation set size: 2100


Let's experiment with adding no features to the nodes or the edges, and just letting the model learn the structure of the graph.

In [31]:
from torch_geometric.loader import DataLoader
import networkx as nx


def get_node_labels(row: pd.Series):
    return {n: [n == row["root"]] for n in row["tree"].nodes()}


def get_node_features(row: pd.Series):
    tree = row["tree"]
    degree_centrality = nx.degree_centrality(tree)
    harmonic_centrality = nx.harmonic_centrality(tree)
    betweenness_centrality = nx.betweenness_centrality(tree)
    pagerank = nx.pagerank(tree)

    return {
        n: [degree_centrality[n], harmonic_centrality[n], betweenness_centrality[n], pagerank[n]] for n in tree.nodes()
    }


trees = training.apply(
    lambda r: to_torch(r["tree"], node_features=get_node_features(r), node_labels=get_node_labels(r)), axis=1
).tolist()
train_dataset = DataLoader(trees, batch_size=32, shuffle=True)

validation_trees = validation.apply(
    lambda r: to_torch(r["tree"], node_features=get_node_features(r), node_labels=get_node_labels(r)), axis=1
).tolist()
validation_dataset = DataLoader(validation_trees, batch_size=32, shuffle=False)

In [5]:
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch


class GCN(torch.nn.Module):
    def __init__(self, num_features: int, hidden_channels: int, num_classes: int):
        super(GCN, self).__init__()
        torch.manual_seed(42)
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.1, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

In [6]:
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE


def visualize(h, color):
    z = TSNE(n_components=2, perplexity=min(30, h.shape[0] - 1)).fit_transform(h.detach().cpu().numpy())

    plt.figure(figsize=(10, 10))
    plt.xticks([])
    plt.yticks([])

    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.show()

In [7]:
model = GCN(num_features=trees[0].x.shape[1], hidden_channels=16, num_classes=1)
print(model)

GCN(
  (conv1): GCNConv(4, 16)
  (conv2): GCNConv(16, 1)
)


In [32]:
from tqdm import tqdm

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)


class_weights = {0: 0.1, 1: 1.0}


def custom_mse_loss(y_pred, y_true):
    """
    MSE loss function with class weights.
    """
    y_pred = y_pred.view(-1)
    y_true = y_true.view(-1)
    weights = torch.tensor([class_weights[int(y)] for y in y_true], dtype=torch.float32).to(y_pred.device)
    loss = F.mse_loss(y_pred, y_true, reduction="none")
    return (loss * weights).mean()


model.train()
max_epochs = 100
for epoch in range(1, max_epochs + 1):
    with tqdm(
        total=len(train_dataset), desc=f"Epoch {epoch}/{max_epochs}", leave=False, ncols=100, unit="batch", position=0
    ) as bar:
        for batch in train_dataset:
            out = model(batch.x, batch.edge_index)
            loss = custom_mse_loss(out, batch.y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            bar.set_postfix_str(f"loss={loss.item():.5f}")
            bar.update(1)

        val_loss = 0
        for batch in validation_dataset:
            with torch.no_grad():
                out = model(batch.x, batch.edge_index)
                val_loss += custom_mse_loss(out, batch.y).item()
        val_loss /= len(validation_dataset)
    print(f"Epoch {epoch}/{max_epochs}: loss={loss:.5f} val_loss={val_loss:.5f}")

                                                                                                    

Epoch 1/100: loss=0.05674 val_loss=0.05406


                                                                                                    

Epoch 2/100: loss=0.05536 val_loss=0.05406


                                                                                                    

Epoch 3/100: loss=0.05387 val_loss=0.05406


                                                                                                    

Epoch 4/100: loss=0.04893 val_loss=0.05406


                                                                                                    

Epoch 5/100: loss=0.05714 val_loss=0.05406


                                                                                                    

Epoch 6/100: loss=0.05063 val_loss=0.05406


                                                                                                    

Epoch 7/100: loss=0.05993 val_loss=0.05406


                                                                                                    

Epoch 8/100: loss=0.04969 val_loss=0.05406


                                                                                                    

Epoch 9/100: loss=0.06107 val_loss=0.05406


                                                                                                    

Epoch 10/100: loss=0.05556 val_loss=0.05406


                                                                                                    

Epoch 11/100: loss=0.05178 val_loss=0.05406


                                                                                                    

Epoch 12/100: loss=0.06478 val_loss=0.05406


                                                                                                    

KeyboardInterrupt: 

In [34]:
correct = 0
for batch in validation_dataset:
    with torch.no_grad():
        out = model(batch.x, batch.edge_index)
        print(out)
        expected = torch.argmax(batch.y, dim=0)
        predicted = torch.argmax(out, dim=0)
        if predicted == expected:
            correct += 1
accuracy = correct / len(validation_dataset.dataset)
print(f"Accuracy: {accuracy:.4f}")

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
      