In [None]:
!pip install sentence_transformers
!pip install torch_geometric

Collecting sentence_transformers
  Downloading sentence_transformers-3.2.1-py3-none-any.whl.metadata (10 kB)
Downloading sentence_transformers-3.2.1-py3-none-any.whl (255 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m255.8/255.8 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sentence_transformers
Successfully installed sentence_transformers-3.2.1
Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
# Mount Google Drive at /content/drive so the files stored there can be read
# from this runtime. `google.colab` is imported here, so this cell only runs on Colab.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
cd /content/drive/MyDrive/CS5284

/content/drive/MyDrive/CS5284


### Test training GNN
1. Concatenate the question embedding with each node embedding. (GCN)

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

from functions import *
from torch.utils.data import DataLoader

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class GCN(torch.nn.Module):
    """Two-layer GCN that scores every node of a question-conditioned subgraph.

    Each node's feature vector is concatenated with the question embedding of
    the subgraph it belongs to before the first graph convolution.

    Args:
        input_dim: width of the raw node features (`x`).
        question_dim: width of the per-subgraph question embedding (`qn`).
        hidden_dim: output width of the first convolution.
        output_dim: number of classes predicted per node.
    """

    def __init__(self, input_dim, question_dim, hidden_dim, output_dim):
        super().__init__()
        # The first layer sees node features and the question embedding side by side.
        self.conv1 = GCNConv(input_dim + question_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, batched_subgraphs):
        """Return per-node log-probabilities for every node in the batch.

        Args:
            batched_subgraphs: batch object exposing `to_data_list()` and
                `edge_index`; every subgraph it yields must carry node
                features `x` (2-D) and a 1-D question embedding `qn`.

        Returns:
            Tensor with one row per node (in batch order) and one column per
            class, holding log-softmax scores.

        The input batch is left untouched: the question-augmented features are
        built in a local tensor, so the same batch can be passed through the
        model more than once.
        """
        # Concatenate each subgraph's question embedding onto all of its nodes
        # along the feature dimension.
        augmented_features = []
        for subgraph in batched_subgraphs.to_data_list():
            num_nodes = subgraph.x.size(0)
            question_per_node = subgraph.qn.unsqueeze(0).expand(num_nodes, -1)
            augmented_features.append(torch.cat((subgraph.x, question_per_node), dim=1))

        # Re-stack in the same order as the batch so rows still line up with
        # `edge_index` and `batch`.
        x = torch.cat(augmented_features, dim=0)
        edge_index = batched_subgraphs.edge_index

        # GCN
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)  # For node classification

# Two output classes per node: answer candidate vs. not an answer candidate.
# NOTE(review): 64 and 384 presumably match the node-embedding and
# question-embedding widths produced by the dataset — confirm.
model = GCN(
    input_dim=64,
    question_dim=384,
    hidden_dim=64,
    output_dim=2,
)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01, weight_decay=5e-4)

In [None]:
# Training loop
def train(dataloader):
    """
    Run one training epoch and return the average loss per batch.

    Uses the module-level `model`, `optimizer` and `device`.

    Args:
        dataloader: iterable yielding `(batched_subgraphs, labels)` pairs,
            where `labels` is a sequence with one label tensor per subgraph,
            in the same order as the subgraph ids in `batched_subgraphs.batch`.

    Returns:
        Sum of the batch losses divided by the number of batches, or 0.0 when
        the dataloader yields no batches. Each batch loss is the sum (not the
        mean) of the per-subgraph NLL losses.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    # loop batches from dataloader
    for batched_subgraphs, labels in dataloader:

        optimizer.zero_grad()

        batched_subgraphs = batched_subgraphs.to(device)
        labels = [label.to(device) for label in labels]

        # forward pass
        out = model(batched_subgraphs)

        # calculate loss: one NLL term per subgraph, selected via the batch vector
        batch_loss = 0
        for i, label in enumerate(labels):
            node_mask = (batched_subgraphs.batch == i)
            batch_loss += F.nll_loss(out[node_mask], label)

        # backward pass and optimization step
        batch_loss.backward()
        optimizer.step()

        # .item() already returns a detached Python float; no extra host copies needed
        loss_value = batch_loss.item()
        print('Batch loss is', loss_value)

        total_loss += loss_value
        num_batches += 1

    # Counting batches (instead of len(dataloader)) avoids a ZeroDivisionError on
    # an empty loader and also works for iterables without __len__.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

In [None]:
# Evaluation
def evaluate(dataloader):
    """
    Compute node-level classification accuracy over `dataloader`.

    Uses the module-level `model` and `device`. Runs under `torch.no_grad()`
    with the model in eval mode.

    Args:
        dataloader: iterable yielding `(batched_subgraphs, labels)` pairs,
            where `labels` holds one label tensor per subgraph, ordered by the
            subgraph ids in `batched_subgraphs.batch`.

    Returns:
        Fraction of nodes whose arg-max prediction equals its label, pooled
        over all nodes of all subgraphs; 0.0 when there are no nodes to score.

    NOTE(review): because every node counts equally, this number can look very
    high if one class dominates the labels — confirm against the label balance
    before reading it as answer-retrieval quality.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batched_subgraphs, labels in dataloader:
            batched_subgraphs = batched_subgraphs.to(device)

            out = model(batched_subgraphs)

            # Move predictions and the node->subgraph assignment to the CPU once
            # per batch rather than once per subgraph.
            preds_cpu = out.cpu().argmax(dim=1)
            graph_ids_cpu = batched_subgraphs.batch.cpu()

            # calculate accuracy for each subgraph
            for i, label in enumerate(labels):
                label = label.cpu()  # no-op when the label already lives on the CPU
                preds = preds_cpu[graph_ids_cpu == i]
                correct += (preds == label).sum().item()
                total += label.size(0)

    # An empty dataloader would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct / total

In [None]:
path_to_node_embed = 'ud_node2vec_embeddings.txt'
path_to_idxes = 'idxes.json'
path_to_qa = 'qa_train.txt'

# Sizes of the small slices used for this quick experiment.
TRAIN_SIZE = 1280
VAL_SIZE = 320

data = KGQADataset(path_to_node_embed, path_to_idxes, path_to_qa)

# Training slice: the first TRAIN_SIZE examples of the training file.
sub_data1 = torch.utils.data.Subset(data, list(range(TRAIN_SIZE)))
dataloader_train = DataLoader(sub_data1, batch_size=64, collate_fn=collate_fn, shuffle=True)

# Validation slice: the next VAL_SIZE examples, disjoint from the training
# slice. Taking range(320) here would validate on examples the model was
# trained on. Requires the dataset to hold at least TRAIN_SIZE + VAL_SIZE examples.
sub_data2 = torch.utils.data.Subset(data, list(range(TRAIN_SIZE, TRAIN_SIZE + VAL_SIZE)))
dataloader_val = DataLoader(sub_data2, batch_size=32, collate_fn=collate_fn, shuffle=True)

In [None]:
# Held-out questions come from their own file; only the first 320 are used here.
test = KGQADataset(path_to_node_embed, path_to_idxes, 'qa_test.txt')
sub_data3 = torch.utils.data.Subset(dataset=test, indices=list(range(320)))
dataloader_test = DataLoader(
    dataset=sub_data3,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# One pass over the training slice, then score the validation and test loaders
# (in that order) and report all three numbers on a single line.
for epoch in range(1):
    loss = train(dataloader_train)
    val_accuracy, test_accuracy = evaluate(dataloader_val), evaluate(dataloader_test)
    print(f'Epoch {epoch}, Train Loss: {loss}, Validation Accuracy: {val_accuracy:.4f}, Test Accuracy: {test_accuracy:.4f}')

Batch loss is 47.83632278442383
Batch loss is 19.901161193847656
Batch loss is 7.567558765411377
Batch loss is 3.4319286346435547
Batch loss is 1.6989784240722656
Batch loss is 1.5260528326034546
Batch loss is 2.3081700801849365
Batch loss is 1.495195746421814
Batch loss is 2.4808621406555176
Batch loss is 3.392915964126587
Batch loss is 2.7652828693389893
Batch loss is 3.8401472568511963
Batch loss is 2.264253854751587
Batch loss is 2.511430501937866
Batch loss is 2.5441229343414307
Batch loss is 2.496483564376831
Batch loss is 2.9182167053222656
Batch loss is 3.2042768001556396
Batch loss is 4.352218151092529
Batch loss is 4.143481254577637
Epoch 0, Train Loss: 6.133953022956848, Validation Accuracy: 0.9989, Test Accuracy: 0.9989
