In [1]:
!pip install sentence_transformers
!pip install torch_geometric

Collecting sentence_transformers
  Downloading sentence_transformers-3.2.1-py3-none-any.whl.metadata (10 kB)
Downloading sentence_transformers-3.2.1-py3-none-any.whl (255 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m255.8/255.8 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sentence_transformers
Successfully installed sentence_transformers-3.2.1
Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [3]:
# Mount Google Drive (Colab) so the project folder is reachable under /content/drive.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Mounted at /content/drive


In [1]:
cd /content/drive/MyDrive/CS5284

/content/drive/MyDrive/CS5284


### Test training a GNN
1. GCN: concatenate the question embedding with each node embedding, then classify every node as an answer candidate or not.

In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

# adjust this import accordingly to how you call the script
from functions import *
from torch.utils.data import DataLoader

from sklearn.metrics import precision_score, recall_score, f1_score

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's RNG (all devices) so weight initialisation and DataLoader shuffling repeat
# across Restart & Run All.
# NOTE(review): if `functions` uses `random` or numpy RNGs, seed those as well.
SEED = 42
torch.manual_seed(SEED)

In [3]:
# model
class GCN(torch.nn.Module):
    """Four-layer GCN that scores every node of question-specific subgraphs.

    Each node's feature vector is concatenated with the embedding of the question its
    subgraph belongs to, then passed through four GCNConv layers with ReLU in between.
    The last layer emits ``output_dim`` values per node (one logit by default), returned
    without a sigmoid so they can be fed to ``BCEWithLogitsLoss``.

    Args:
        node_dim: number of columns of the node feature matrix ``x``.
        question_dim: length of each subgraph's question embedding ``qn``.
        hidden_dim: width of the three hidden GCN layers.
        output_dim: outputs per node; a trailing size-1 dim is squeezed away.
    """

    def __init__(self, node_dim, question_dim, hidden_dim, output_dim=1):
        super(GCN, self).__init__()
        # 4 layers; the first consumes [node features | question embedding]
        self.conv1 = GCNConv(node_dim + question_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.conv4 = GCNConv(hidden_dim, output_dim)

    def forward(self, batched_subgraphs):
        """Return per-node logits for a batch of subgraphs.

        Reads ``x``, ``edge_index``, the node-to-subgraph vector ``batch``, ``num_graphs``
        and ``qn`` (one 1-D question embedding per subgraph). Assumes PyG's default
        batching concatenated the per-subgraph ``qn`` vectors into one flat tensor;
        ``reshape`` also accepts an already-stacked (num_graphs, question_dim) tensor.
        The input batch is left unmodified, so the call is safe to repeat.
        """
        num_graphs = batched_subgraphs.num_graphs
        # One question row per subgraph, broadcast to that subgraph's nodes via `batch`.
        question_emb = batched_subgraphs.qn.reshape(num_graphs, -1)
        x = torch.cat((batched_subgraphs.x, question_emb[batched_subgraphs.batch]), dim=1)
        edge_index = batched_subgraphs.edge_index

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = self.conv4(x, edge_index)
        # Raw logits (no sigmoid): BCEWithLogitsLoss applies it internally.
        return x.squeeze(-1)

# Binary node classification (answer candidate or not): 64-d node features plus a 384-d
# question embedding go in, one logit per node comes out (output_dim defaults to 1).
model = GCN(node_dim=64, question_dim=384, hidden_dim=100)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01, weight_decay=5e-4)

In [4]:
def focal_loss(logits, labels, alpha=0.25, gamma=2.0):
    """Binary focal loss (Lin et al., 2017) computed from log-probabilities. NOT USED.

    Note: the GCN in this notebook outputs a single logit per node; this function instead
    expects per-class log-probabilities, e.g. ``F.log_softmax(z, dim=-1)`` of shape (N, 2).

    Args:
        logits: log-probabilities of shape (N, C). Despite the name, these must already be
            log-softmaxed because they are passed to ``F.nll_loss``.
        labels: integer (long) ground-truth labels in {0, 1}, shape (N,).
        alpha: weight of the positive class (label 1); label 0 is weighted by ``1 - alpha``.
        gamma: focusing parameter; larger values down-weight easy (high p_t) examples more.

    Returns:
        Scalar tensor: the mean focal loss over the N examples.
    """
    # Per-example cross-entropy, i.e. -log p_t.
    ce_loss = F.nll_loss(logits, labels, reduction='none')
    # Probability the model assigned to the true class.
    pt = torch.exp(-ce_loss)
    # Class-dependent alpha_t. A single scalar alpha on every example would only rescale
    # the loss and would not balance the classes.
    labels_f = labels.to(ce_loss.dtype)
    alpha_t = alpha * labels_f + (1.0 - alpha) * (1.0 - labels_f)
    per_example = alpha_t * (1 - pt) ** gamma * ce_loss
    return per_example.mean()

In [5]:
# Training loop
def train(dataloader, verbose=True):
    """Run one training epoch over ``dataloader`` and return the mean loss per subgraph.

    Uses the notebook-level ``model``, ``optimizer`` and ``device``. Each item from the
    loader is ``(batched_subgraphs, labels)``, where ``labels[i]`` holds the 0/1 targets of
    the nodes of subgraph ``i`` (selected from the model output via ``batch == i``).

    Args:
        dataloader: yields ``(batched_subgraphs, labels)`` pairs.
        verbose: print every batch's summed loss (default True, as before).

    Returns:
        Sum of all per-subgraph losses divided by the number of subgraphs seen. This
        does not depend on the batch size or on a shorter final batch. Returns 0.0 for
        an empty loader.
    """
    model.train()
    total_loss = 0.0
    num_subgraphs = 0

    for batched_subgraphs, labels in dataloader:
        optimizer.zero_grad()

        batched_subgraphs = batched_subgraphs.to(device)
        labels = [label.to(device) for label in labels]

        out = model(batched_subgraphs)

        # Summed weighted-BCE over the subgraphs of this batch, each with its own pos_weight.
        batch_loss = 0
        for i, label in enumerate(labels):
            logits = out[batched_subgraphs.batch == i]
            target = label.float()
            # Up-weight the rare positive (answer) nodes by num_nodes / num_positives.
            # (PyTorch's docs suggest num_neg / num_pos, which is this ratio minus 1.)
            # clamp(min=1): a subgraph without any positive node would otherwise yield
            # an infinite pos_weight, turning the loss and every gradient into NaN.
            pos_weight = (target.numel() / target.sum().clamp(min=1)).reshape(1)
            loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            batch_loss = batch_loss + loss_fn(logits, target)

        # backward pass and optimization step (on the sum, as originally trained)
        batch_loss.backward()
        optimizer.step()

        batch_loss_value = batch_loss.item()
        if verbose:
            print('Batch loss is', batch_loss_value)
        total_loss += batch_loss_value
        num_subgraphs += len(labels)

    torch.cuda.empty_cache()  # help clear cache taking up cuda space

    return total_loss / max(num_subgraphs, 1)

In [6]:
# Evaluation
def evaluate(dataloader):
    """Score the notebook-level ``model`` on ``dataloader`` using a 0.5 sigmoid threshold.

    Uses the notebook-level ``model`` and ``device``; labels are compared on the CPU.

    Returns:
        ``(accuracy, precision, recall, f1, all_preds, all_labels)``:
        - accuracy: fraction of correctly classified nodes over all subgraphs. It is
          dominated by the majority class 0.
        - precision / recall / f1: for the positive class. Each is 0.0 when undefined
          (e.g. no positive predictions).
        - all_preds / all_labels: flat per-node lists of predictions and labels.
    """
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batched_subgraphs, labels in dataloader:
            batched_subgraphs = batched_subgraphs.to(device)
            output_cpu = model(batched_subgraphs).cpu()
            # Move the node-to-subgraph map to the CPU once per batch, not once per subgraph.
            batch_vec = batched_subgraphs.batch.cpu()

            # calculate accuracy for each subgraph
            for i, label in enumerate(labels):
                preds = (torch.sigmoid(output_cpu[batch_vec == i]) > 0.5).int()

                all_preds.extend(preds.tolist())
                all_labels.extend(label.tolist())
                correct += (preds == label).sum().item()
                total += label.size(0)

    # zero_division=0 returns the same 0.0 as the default, without UndefinedMetricWarning.
    precision = precision_score(all_labels, all_preds, average='binary', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='binary', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='binary', zero_division=0)
    accuracy = correct / total  # biased towards 0

    torch.cuda.empty_cache()  # help clear cache taking up cuda space

    return accuracy, precision, recall, f1, all_preds, all_labels

In [7]:
# NOTE(review): the space in "node2vec _embeddings" looks accidental; confirm it matches the
# actual directory name on Drive.
path_to_node_embed = '../Datasets/MetaQA_dataset/processed/node2vec _embeddings/ud_node2vec_embeddings.txt'
path_to_idxes = '../Datasets/MetaQA_dataset/processed/idxes.json'
path_to_qa = '../Datasets/MetaQA_dataset/vanilla 3-hop/qa_train.txt'

# Subset sizes for a quick experiment; enlarge for a full run.
N_TRAIN = 2000    # first N_TRAIN questions of qa_train.txt are used for training
VAL_START = 5000  # validation questions start here in qa_train.txt ...
N_VAL = 400       # ... and this many are used
BATCH_SIZE = 64

assert VAL_START >= N_TRAIN, 'validation slice overlaps the training subset'

# train
data = KGQADataset(path_to_node_embed, path_to_idxes, path_to_qa)
sub_data1 = torch.utils.data.Subset(data, list(range(N_TRAIN)))
dataloader_train = DataLoader(sub_data1, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
# held-out questions from qa_train.txt (disjoint from sub_data1) for validation;
# evaluation metrics do not depend on order, so this loader is not shuffled
sub_data2 = torch.utils.data.Subset(data, list(range(VAL_START, VAL_START + N_VAL)))
dataloader_val = DataLoader(sub_data2, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False)

In [8]:
# some from test to evaluate: the first 400 questions of qa_test.txt.
# Evaluation metrics do not depend on order, so this loader is not shuffled.
test = KGQADataset(path_to_node_embed, path_to_idxes, '../Datasets/MetaQA_dataset/vanilla 3-hop/qa_test.txt')
sub_data3 = torch.utils.data.Subset(test, list(range(400)))
dataloader_test = DataLoader(sub_data3, batch_size=64, collate_fn=collate_fn, shuffle=False)

In [9]:
# Train for NUM_EPOCHS epochs, reporting validation and test metrics after each one.
NUM_EPOCHS = 1  # quick smoke run; increase for real training

for epoch in range(NUM_EPOCHS):
    loss = train(dataloader_train)
    val_accuracy, val_p, val_r, val_f1, _, _ = evaluate(dataloader_val)
    test_accuracy, test_p, test_r, test_f1, _, _ = evaluate(dataloader_test)
    print(f'Epoch {epoch}, Train Loss: {loss}')
    print(f'Validation Accuracy: {val_accuracy:.8f}, Validation P/R/F1: {val_p:.3f}/{val_r:.3f}/{val_f1:.3f}')
    print(f'Test Accuracy: {test_accuracy:.8f}, Test P/R/F1: {test_p:.3f}/{test_r:.3f}/{test_f1:.3f}')

Batch loss is 89.95256042480469
Batch loss is 93.21959686279297
Batch loss is 81.2720718383789
Batch loss is 82.74179077148438
Batch loss is 81.05941009521484
Batch loss is 79.3653335571289
Batch loss is 80.254150390625
Batch loss is 79.60807800292969
Batch loss is 79.48625183105469
Batch loss is 79.6988296508789
Batch loss is 81.76297760009766
Batch loss is 82.50495147705078
Batch loss is 78.03034973144531
Batch loss is 78.94815826416016
Batch loss is 80.22224426269531
Batch loss is 79.8222885131836
Batch loss is 77.80859375
Batch loss is 76.4232177734375
Batch loss is 75.96676635742188
Batch loss is 77.85187530517578
Batch loss is 77.0613021850586
Batch loss is 77.3415298461914
Batch loss is 75.35107421875
Batch loss is 77.83760070800781
Batch loss is 72.38226318359375
Batch loss is 75.3498764038086
Batch loss is 77.80155181884766
Batch loss is 79.95442199707031
Batch loss is 75.32730102539062
Batch loss is 75.8548355102539
Batch loss is 74.46247100830078
Batch loss is 19.73121643066