In [1]:
!pip install sentence_transformers
!pip install torch_geometric
!pip install opentsne

Collecting sentence_transformers
  Downloading sentence_transformers-3.2.1-py3-none-any.whl.metadata (10 kB)
Downloading sentence_transformers-3.2.1-py3-none-any.whl (255 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m255.8/255.8 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sentence_transformers
Successfully installed sentence_transformers-3.2.1
Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
cd /content/drive/MyDrive/CS5284

[Errno 2] No such file or directory: '/content/drive/MyDrive/CS5284'
/scratch/users/nus/e1329380/cs5284/QA_graph/training-graph-models


In [1]:
import sys
sys.path.append("../data_preparation")

In [2]:
import random, time
import numpy as np
from matplotlib import pyplot

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv

# adjust this import accordingly to how you call the script
from functions_modified import *
from torch.utils.data import DataLoader

# evaluation
from sklearn.metrics import precision_score, recall_score, f1_score

# visualise learnt intermediate embeddings
import matplotlib.pyplot as plt
from openTSNE import TSNE
# from sklearn.manifold import TSNE # very slow
tsne = TSNE(
    perplexity=30,
    metric="euclidean",
    n_jobs=8,
    random_state=42,
    verbose=True,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# visualise learnt output embeddings
def visualize(h, color):
    # z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())
    z = tsne.fit(h.detach().cpu().numpy())

    plt.figure(figsize=(10,10))
    plt.xticks([])
    plt.yticks([])

    print(f"Number of positive is {sum(color)}")
    print(f"Total number is {len(color)}")
    s = [0.5 if c==0 else 7 for c in color]

    plt.scatter(z[:, 0], z[:, 1], s=1, c=color, cmap=pyplot.jet())
    plt.show()

In [11]:
# new

# train code
def train(dataloader):
    """
    one epoch
    returns average loss for one epoch
    """
    model.train()
    total_loss = 0

    start = time.time()
    
    # loop batches from dataloader
    for d, (batched_subgraphs, question_embeddings, stacked_labels, node_maps, labels, answer_types) in enumerate(dataloader):

        optimizer.zero_grad()

        batched_subgraphs = batched_subgraphs.to(device)
        labels = [label.to(device) for label in labels]

        # forward pass
        out, x_inter = model(batched_subgraphs)

        # calculate loss
        batch_loss = 0
        # batch_valid_nodes = 0
        for i, (label, answer_type) in enumerate(zip(labels, answer_types)):
            # Create mask for nodes that belong to the ith subgraph
            node_mask = (batched_subgraphs.batch == i)
            logits = out[node_mask]
            target = label

            # Create mask that keeps node types same as the answer type
            type_mask = torch.tensor([nt == answer_type for nt in batched_subgraphs.node_types[i]], device=device)
            masked_logits = logits[type_mask]
            masked_labels = target[type_mask].float()

            # Skip subgraph if no valid masked nodes
            if masked_labels.numel() == 0 or masked_labels.sum() == 0:
                continue

            pos_weight = torch.tensor([len(masked_labels) / masked_labels.sum()], device=device)
            loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            # normalize by number of nodes predicted on
            # This ensures stability, even if subgraphs vary significantly in the number of valid nodes.
            batch_loss += loss_fn(masked_logits, masked_labels) / len(masked_labels)
            # batch_valid_nodes += len(masked_labels)

        # if batch_valid_nodes > 0:
        # normalize by number of nodes predicted on
        # This ensures stability, even if subgraphs vary significantly in the number of valid nodes.
        # batch_loss = batch_loss / batch_valid_nodes
        # backward pass and optimization step
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        optimizer.step()

        # shift back to cpu
        batch_loss = batch_loss.detach().cpu()
        output_cpu = out.detach().cpu()
        x_inter_cpu = x_inter.detach().cpu()
        labels_cpu = [label.detach().cpu() for label in labels]

        total_loss += batch_loss.item()

        # print batch loss every 5 steps
        if d%5 == 0:
            print('Batch loss is', batch_loss.item())
    
        # # visualise last batch in the epoch
        # if d == len(dataloader) - 1:
        #     print(f"Duration for one epoch is {(time.time() - start)/60} minutes")
        #     visualize(x_inter_cpu, color=torch.cat(labels_cpu, dim=0))
    
    torch.cuda.empty_cache() # help clear cache taking up cuda space

    return total_loss / len(dataloader)

# evaluation code
def evaluate(dataloader):
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batched_subgraphs, question_embeddings, stacked_labels, node_maps, labels, answer_types in dataloader:
            batched_subgraphs = batched_subgraphs.to(device)
            labels = [label for label in labels]

            out, _ = model(batched_subgraphs)
            output_cpu = out.detach().cpu()
            
            # calculate accuracy for each subgraph
            for i, (label, answer_type) in enumerate(zip(labels, answer_types)):
                # Create mask for nodes that belong to the ith subgraph
                node_mask = (batched_subgraphs.batch == i).detach().cpu()
                # Create mask that keeps node types same as the answer type
                type_mask = torch.tensor([nt == answer_type for nt in batched_subgraphs.node_types[i]], device="cpu")
                preds = (torch.sigmoid(output_cpu[node_mask][type_mask]) > 0.5).int()
                label = label[type_mask]
                
                all_preds.extend(preds.tolist())
                all_labels.extend(label.tolist())
                correct += (preds == label).sum().item()
                total += label.size(0)

    precision = precision_score(all_labels, all_preds, average='binary') # for positive class
    recall = recall_score(all_labels, all_preds, average='binary') # for positive class
    f1 = f1_score(all_labels, all_preds, average='binary') # for positive class
    accuracy = correct / total # biased towards 0

    torch.cuda.empty_cache() # help clear cache taking up cuda space

    return accuracy, precision, recall, f1, all_preds, all_labels

In [4]:
# # train code
# def train(dataloader):
#     """
#     one epoch
#     returns average loss for one epoch
#     """
#     model.train()
#     total_loss = 0

#     start = time.time()
    
#     # loop batches from dataloader
#     for d, (batched_subgraphs, labels) in enumerate(dataloader):

#         optimizer.zero_grad()

#         batched_subgraphs = batched_subgraphs.to(device)
#         labels = [label.to(device) for label in labels]

#         # forward pass
#         out, x_inter = model(batched_subgraphs)

#         # calculate loss
#         batch_loss = 0
#         for i, label in enumerate(labels):
#             node_mask = (batched_subgraphs.batch == i)
#             logits = out[node_mask]
#             target = label.float()

#             pos_weight = torch.tensor([len(target) / target.sum()], device=device)
#             loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
#             batch_loss += loss_fn(logits, target)

#         # backward pass and optimization step
#         batch_loss.backward()
#         optimizer.step()

#         # shift back to cpu
#         batch_loss = batch_loss.detach().cpu()
#         output_cpu = out.detach().cpu()
#         x_inter_cpu = x_inter.detach().cpu()
#         labels_cpu = [label.detach().cpu() for label in labels]

#         total_loss += batch_loss.item()

#         # print batch loss every 5 steps
#         if d%5 == 0:
#             print('Batch loss is', batch_loss.item())
        
#         # visualise last batch in the epoch
#         if d == len(dataloader) - 1:
#             print(f"Duration for one epoch is {(time.time() - start)/60} minutes")
#             visualize(x_inter_cpu, color=torch.cat(labels_cpu, dim=0))
    
#     torch.cuda.empty_cache() # help clear cache taking up cuda space

#     return total_loss / len(dataloader)

# # evaluation code
# def evaluate(dataloader):
#     model.eval()
#     correct = 0
#     total = 0
#     all_preds = []
#     all_labels = []

#     with torch.no_grad():
#         for batched_subgraphs, labels in dataloader:
#             batched_subgraphs = batched_subgraphs.to(device)
#             labels = [label for label in labels]

#             out, _ = model(batched_subgraphs)
#             output_cpu = out.detach().cpu()

#             # calculate accuracy for each subgraph
#             for i, label in enumerate(labels):
#                 node_mask = (batched_subgraphs.batch == i).detach().cpu()
#                 preds = (torch.sigmoid(output_cpu[node_mask]) > 0.5).int()
                
#                 all_preds.extend(preds.tolist())
#                 all_labels.extend(label.tolist())
#                 correct += (preds == label).sum().item()
#                 total += label.size(0)

#     precision = precision_score(all_labels, all_preds, average='binary') # for positive class
#     recall = recall_score(all_labels, all_preds, average='binary') # for positive class
#     f1 = f1_score(all_labels, all_preds, average='binary') # for positive class
#     accuracy = correct / total # biased towards 0

#     torch.cuda.empty_cache() # help clear cache taking up cuda space

#     return accuracy, precision, recall, f1, all_preds, all_labels

In [4]:
# not sure
def focal_loss(logits, labels, alpha=0.25, gamma=2.0):
    """
    NOT USED.
    logits: Predicted output from the model (after log-softmax).
    labels: Ground truth labels (0 or 1).
    alpha: Balancing factor for the minority class.
    gamma: Focusing parameter for adjusting the rate at which easy examples are down-weighted.
    """
    # Compute cross-entropy loss per example
    ce_loss = F.nll_loss(logits, labels, reduction='none')
    # Probabilities for each example
    pt = torch.exp(-ce_loss)
    # Apply the focal loss adjustment
    focal_loss = alpha * (1 - pt) ** gamma * ce_loss

    # Return the mean loss
    return focal_loss.mean()

In [5]:
# new

# load data
path_to_node_embed = '../Datasets/MetaQA_dataset/processed/node2vec _embeddings/ud_node2vec_embeddings.txt'
path_to_idxes = '../Datasets/MetaQA_dataset/processed/idxes.json'
path_to_qa = '../Datasets/MetaQA_dataset/vanilla 3-hop/qa_train.txt'
path_to_ans_types = '../Datasets/MetaQA_dataset/processed/train_ans_type.txt'

# train
data = KGQADataset(path_to_node_embed, path_to_idxes, path_to_qa, path_to_ans_types, train = True)
dataloader_train = DataLoader(data, batch_size=64, collate_fn=collate_fn, shuffle=True)

# evaluate
test = KGQADataset(path_to_node_embed, path_to_idxes, '../Datasets/MetaQA_dataset/vanilla 3-hop/qa_test.txt', '../Datasets/MetaQA_dataset/processed/dev_ans_type.txt', train = False)
dataloader_test = DataLoader(test, batch_size=64, collate_fn=collate_fn, shuffle=True)

In [5]:
# # load data
# path_to_node_embed = '../Datasets/MetaQA_dataset/processed/node2vec _embeddings/ud_node2vec_embeddings.txt'
# path_to_idxes = '../Datasets/MetaQA_dataset/processed/idxes.json'
# path_to_qa = '../Datasets/MetaQA_dataset/vanilla 3-hop/qa_train.txt'

# # train
# data = KGQADataset(path_to_node_embed, path_to_idxes, path_to_qa)
# sub_data1 = torch.utils.data.Subset(data, list(range(3000)))
# dataloader_train = DataLoader(sub_data1, batch_size=64, collate_fn=collate_fn, shuffle=True)
# # some from train to evaluate
# sub_data2 = torch.utils.data.Subset(data, list(range(5000, 5000+400)))
# dataloader_val = DataLoader(sub_data2, batch_size=32, collate_fn=collate_fn, shuffle=True)

# # some from test to evaluate
# test = KGQADataset(path_to_node_embed, path_to_idxes, '../Datasets/MetaQA_dataset/vanilla 3-hop/qa_test.txt')
# sub_data3 = torch.utils.data.Subset(test, list(range(400)))
# dataloader_test = DataLoader(sub_data3, batch_size=32, collate_fn=collate_fn, shuffle=True)

### Test training GNN
1. Concatenate the question embedding with each node embedding.

In [5]:
torch.manual_seed(2024)
random.seed(2024)
np.random.seed(2024)

In [6]:
# model
class GCN(torch.nn.Module):
    def __init__(self, node_dim, question_dim, hidden_dim, output_dim=1):
        super(GCN, self).__init__()
        # 4 layers, 100 hidden_dim
        self.conv1 = GCNConv(node_dim + question_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.conv4 = GCNConv(hidden_dim, output_dim)

    def forward(self, batched_subgraphs):
        # concatenate question embeddings with node features for each subgraph along feature dimension
        question_emb_expanded = []
        for subgraph in batched_subgraphs.to_data_list():
            subgraph.x = torch.cat((subgraph.x, subgraph.qn.unsqueeze(0).expand(subgraph.x.size(0), -1)), dim=1)
            question_emb_expanded.append(subgraph.x)

        batched_subgraphs.x = torch.cat(question_emb_expanded, dim=0)
        x, edge_index = batched_subgraphs.x, batched_subgraphs.edge_index

        # GCN
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x_inter = self.conv3(x, edge_index)
        x = F.relu(x_inter)
        x = self.conv4(x, edge_index)
        # Output logits directly for BCEWithLogitsLoss
        return x.squeeze(-1), x_inter

# Binary classification (answer candidate or not)
model = GCN(node_dim=64, question_dim=384, hidden_dim=100).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [None]:
# train 2 epoch
for epoch in range(2):
    loss = train(dataloader_train)
    val_accuracy, val_p, val_r, val_f1, _, _ = evaluate(dataloader_val)
    test_accuracy, test_p, test_r, test_f1, _, _ = evaluate(dataloader_test)
    print(f'Epoch {epoch}, Train Loss: {loss}')
    print(f'Validation Accuracy: {val_accuracy:.8f}, Validation P/R/F1: {val_p:.3f}/{val_r:.3f}/{val_f1:.3f}')
    print(f'Test Accuracy: {test_accuracy:.8f}, Test P/R/F1: {test_p:.3f}/{test_r:.3f}/{test_f1:.3f}')

### Test training GAT
1. Concatenate the question embedding with each node embedding.

In [6]:
torch.manual_seed(2024)
random.seed(2024)
np.random.seed(2024)

In [7]:
# model
class QuestionAwareGAT(torch.nn.Module):
    def __init__(self, node_dim, question_dim, hidden_dim, output_dim=1):
        super(QuestionAwareGAT, self).__init__()
        # 4 layers, 100 hidden_dim
        self.conv1 = GATConv(node_dim + question_dim, hidden_dim, heads=1, concat=True)
        self.conv2 = GATConv(hidden_dim, hidden_dim, heads=1, concat=True)
        self.conv3 = GATConv(hidden_dim, hidden_dim, heads=1, concat=True)
        self.conv4 = GATConv(hidden_dim, output_dim, heads=1, concat=True)

    def forward(self, batched_subgraphs):
        # concatenate question embeddings with node features for each subgraph along feature dimension
        question_emb_expanded = []
        for subgraph in batched_subgraphs.to_data_list():
            subgraph.x = torch.cat((subgraph.x, subgraph.qn.unsqueeze(0).expand(subgraph.x.size(0), -1)), dim=1)
            question_emb_expanded.append(subgraph.x)

        batched_subgraphs.x = torch.cat(question_emb_expanded, dim=0)
        x, edge_index = batched_subgraphs.x, batched_subgraphs.edge_index

        # GAT
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x_inter = self.conv3(x, edge_index)
        x = F.relu(x_inter)
        x = self.conv4(x, edge_index)
        # Output logits directly for BCEWithLogitsLoss
        return x.squeeze(-1), x_inter

# Binary classification (answer candidate or not)
model = QuestionAwareGAT(node_dim=64, question_dim=384, hidden_dim=100).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [8]:
# train 2 epoch
for epoch in range(2):
    loss = train(dataloader_train)
    val_accuracy, val_p, val_r, val_f1, _, _ = evaluate(dataloader_val)
    test_accuracy, test_p, test_r, test_f1, _, _ = evaluate(dataloader_test)
    print(f'Epoch {epoch}, Train Loss: {loss}')
    print(f'Validation Accuracy: {val_accuracy:.8f}, Validation P/R/F1: {val_p:.3f}/{val_r:.3f}/{val_f1:.3f}')
    print(f'Test Accuracy: {test_accuracy:.8f}, Test P/R/F1: {test_p:.3f}/{test_r:.3f}/{test_f1:.3f}')

Batch loss is 1.6537309885025024
Batch loss is 1.5430313348770142
Batch loss is 1.222909688949585
Batch loss is 1.6415810585021973
Batch loss is 1.2838495969772339
Batch loss is 1.3589661121368408
Batch loss is 1.6335830688476562
Batch loss is 1.3502583503723145
Batch loss is 1.2371432781219482
Batch loss is 1.453822135925293
Batch loss is 1.321313500404358
Batch loss is 1.4179054498672485
Batch loss is 1.1516602039337158
Batch loss is 1.0206462144851685
Batch loss is 1.3867830038070679
Batch loss is 1.1479699611663818


NameError: name 'dataloader_val' is not defined

In [12]:
test_accuracy, test_p, test_r, test_f1, _, _ = evaluate(dataloader_test)
print(f'Epoch {epoch}, Train Loss: {loss}')
print(f'Test Accuracy: {test_accuracy:.8f}, Test P/R/F1: {test_p:.3f}/{test_r:.3f}/{test_f1:.3f}')

Epoch 0, Train Loss: 1.4612611160625386
Test Accuracy: 0.38234432, Test P/R/F1: 0.003/0.596/0.005
