Load environments

In [58]:
import sys
import os

# Make the user-level site-packages, the user-level bin directory, and the
# project root visible to the import system (appended, so they are searched
# after the default entries).
sys.path.extend([
    '/home/svu/e0315913/.local/lib/python3.8/site-packages',
    '/home/svu/e0315913/.local/bin',
    "/hpctmp/e0315913/CS5284_Project/GNN-cluster",
])

# Work from the project root so that relative paths used by later cells
# (for example the config file path) resolve against it.
os.chdir('/hpctmp/e0315913/CS5284_Project/GNN-cluster')

Import libraries

In [77]:
import random, torch
import numpy as np
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

In [85]:
from src.utils.config import load_config, validate_config
from src.utils.evaluation import evaluate
from src.models.alpha import FullOutput, Metrics, threshold_based_candidates, calculate_avg_metrics
from src.my_datasets.kgqa_dataset import KGQADataset
from src.my_datasets.data_utils import collate_fn
from src.models.rgcn_model import RGCNModel

Set config and device

In [61]:
CONFIG_PATH = "config/demo_config.yaml"

config = load_config(CONFIG_PATH)

# List every top-level key that later cells index unconditionally, so that a
# missing entry can be caught by validate_config up front instead of surfacing
# as a KeyError halfway through the notebook.
# NOTE(review): presumably validate_config raises on a missing key -- confirm.
required_keys = [
    'model', 'train', 'node_embed', 'idxes',
    'train_qa_data', 'test_qa_data', 'num_hops',
    'model_path',
    # Read when the dataset is constructed.
    'raw_kb', 'from_paths_activate', 'entity_sbert',
    # Read when the extraction settings and output paths are set.
    'threshold_value', 'save_all_path', 'save_emb_path',
]
validate_config(config, required_keys)

In [62]:
# One seed shared by every random number generator seeded in this notebook.
SEED = 2024
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

Load Data

In [63]:
# Build the full training dataset; every path and option comes from the config.
train_dataset = KGQADataset(
    path_to_node_embed=config['node_embed'],
    path_to_idxes=config['idxes'],
    path_to_qa=config['train_qa_data'],
    path_to_kb=config['raw_kb'],
    from_paths_activate=config['from_paths_activate'],
    entity_sbert=config['entity_sbert'],
    k=config['num_hops'],
)

# The relation count is taken from the full dataset (the entire graph),
# not from the subset selected below.
num_relations = train_dataset.num_relations

# Restrict the run to the half-open index range [start_idx, end_idx).
train_cfg = config['train']
subset_indices = list(range(train_cfg['start_idx'], train_cfg['end_idx']))
sub_train_dataset = Subset(train_dataset, subset_indices)

train_loader = DataLoader(
    sub_train_dataset,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=train_cfg['batch_size'],
)

In [70]:
# Inspect the batch size that train_loader was configured with.
config['train']['batch_size']

5

In [71]:
# Display the Subset wrapper that train_loader iterates over.
sub_train_dataset

<torch.utils.data.dataset.Subset at 0x2b92a1f63cd0>

In [64]:
# Rebuild the architecture from the config so that the checkpoint's state dict
# can be loaded into it. The question dimension and relation count are read
# from the dataset rather than from the config.
model = RGCNModel(
    node_dim=config['model']['in_channels'],
    question_dim=train_dataset.q_embeddings.size(-1),
    hidden_dim=config['model']['hidden_channels'],
    num_relations=num_relations,
    output_dim=config['model']['out_channels'],
    num_rgcn=config['model']['num_layers'],
    reduced_qn_dim=config['model']['reduced_qn_dim'],
    reduced_node_dim=config['model']['reduced_node_dim'],
    output_embedding=config['model']['output_embedding'],
    use_residuals=config['model']['use_residuals'],
)

# map_location remaps the saved tensor storages onto the device chosen for this
# session. Without it, a checkpoint written on a GPU cannot be deserialized on
# a CPU-only machine even though `device` falls back to "cpu".
checkpoint = torch.load(config['model_path'], map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

model = model.to(device)

  checkpoint = torch.load(config['model_path'])


In [65]:
# Settings forwarded to the extraction step.
train_cfg = config['train']
equal_subgraph_weighting = train_cfg['equal_subgraph_weighting']
threshold_value = config['threshold_value']
hits_at_k = train_cfg['hits_at_k']

# Output locations: the first receives the full extraction dump, the second
# only the subgraph and question embeddings.
save_all_path = config['save_all_path']
save_emb_path = config['save_emb_path']

In [None]:
def extract_subgraph_qemb(dataloader, model, device, equal_subgraph_weighting, threshold_value, save_all_path, save_emb_path):
    """Run ``model`` over ``dataloader`` and write the collected results to disk.

    For every batch the model is evaluated without gradient tracking, candidate
    nodes are selected with ``threshold_based_candidates``, and the per-batch
    tensors are moved to the CPU and accumulated. After the loop the
    accumulated tensors are concatenated along dimension 0 and written to two
    files.

    Args:
        dataloader: Iterable yielding 5-tuples
            ``(batched_subgraphs, question_embeddings, stacked_labels, node_maps, labels)``.
        model: Callable as ``model(batched_subgraphs, question_embeddings)``;
            switched to eval mode here.
        device: Device the batch tensors are moved to before the forward pass.
        equal_subgraph_weighting: Accepted for interface compatibility; not
            used by this function.
        threshold_value: Fallback threshold, used when the model output has no
            ``threshold`` attribute.
        save_all_path: Destination passed to ``save_all_to_file``.
        save_emb_path: Destination passed to ``save_subg_qemb_file``.

    Returns:
        None. The results are only written to the two files.
    """
    model.eval()

    node_feature_chunks, question_chunks = [], []
    mask_chunks, score_chunks = [], []
    collected_node_maps, collected_labels = [], []

    with torch.no_grad():
        progress = tqdm(dataloader, desc="Extracting subgraph", leave=True)
        for graphs, q_emb, stacked_labels, node_maps, labels in progress:
            # Put the batch on the target device. stacked_labels is moved as
            # well, although nothing below reads it.
            graphs = graphs.to(device)
            q_emb = q_emb.to(device)
            stacked_labels = stacked_labels.to(device)

            # The model may return a structured object (with `output` and
            # optionally `threshold`) or the raw output itself.
            full_output = model(graphs, q_emb)
            output = getattr(full_output, 'output', full_output)
            threshold = getattr(full_output, 'threshold', threshold_value)

            # Select candidate nodes using the similarity threshold.
            candidates_mask, similarity_scores = threshold_based_candidates(output, threshold=threshold)

            # Keep CPU copies only. `.x` is presumably the node-feature matrix
            # of the batched subgraph -- TODO confirm against the dataset.
            node_feature_chunks.append(graphs.x.detach().cpu())
            question_chunks.append(q_emb.detach().cpu())
            mask_chunks.append(candidates_mask.detach().cpu())
            collected_node_maps.extend(node_maps)
            collected_labels.extend(labels)
            if similarity_scores is not None:
                score_chunks.append(similarity_scores.detach().cpu())

    # Stack the per-batch pieces along dimension 0.
    merged_node_features = torch.cat(node_feature_chunks, dim=0)
    merged_questions = torch.cat(question_chunks, dim=0)
    merged_masks = torch.cat(mask_chunks, dim=0)
    if score_chunks:
        merged_scores = torch.cat(score_chunks, dim=0)
    else:
        merged_scores = None

    # Write the embeddings-only file first, then the full dump.
    save_subg_qemb_file(merged_node_features, merged_questions, file_path=save_emb_path)
    save_all_to_file(
        merged_node_features,
        merged_questions,
        merged_masks,
        merged_scores,
        collected_node_maps,
        collected_labels,
        file_path=save_all_path,
    )


In [None]:
def save_all_to_file(batched_subgraphs, question_embeddings, candidates_mask, similarity_scores, node_map, labels, file_path):
    """Write the full extraction result to ``file_path`` with ``torch.save``.

    The saved object is a dict with the keys ``batched_subgraphs``,
    ``question_embeddings``, ``candidates_mask``, ``similarity_scores``,
    ``node_maps`` and ``labels``.

    Args:
        batched_subgraphs: Stored unchanged.
        question_embeddings: Stored unchanged.
        candidates_mask: Object with a ``tolist()`` method; stored as a list.
        similarity_scores: Object with a ``tolist()`` method, or ``None``;
            stored as a list or as ``None``.
        node_map: Stored unchanged under the key ``node_maps``.
        labels: Stored unchanged.
        file_path: Destination handed to ``torch.save``.
    """
    # The mask and the scores are converted to plain Python lists before saving.
    mask_as_list = candidates_mask.tolist()
    if similarity_scores is None:
        scores_as_list = None
    else:
        scores_as_list = similarity_scores.tolist()

    payload = dict(
        batched_subgraphs=batched_subgraphs,
        question_embeddings=question_embeddings,
        candidates_mask=mask_as_list,
        similarity_scores=scores_as_list,
        node_maps=node_map,
        labels=labels,
    )
    torch.save(payload, file_path)


In [103]:
def save_subg_qemb_file(batched_subgraphs, question_embeddings, file_path):
    """Write only the subgraph and question embeddings to ``file_path``.

    The saved object is a dict with the keys ``batched_subgraphs`` and
    ``question_embeddings``; both values are stored unchanged.

    Args:
        batched_subgraphs: Value stored under ``batched_subgraphs``.
        question_embeddings: Value stored under ``question_embeddings``.
        file_path: Destination handed to ``torch.save``.
    """
    field_names = ("batched_subgraphs", "question_embeddings")
    payload = dict(zip(field_names, (batched_subgraphs, question_embeddings)))
    torch.save(payload, file_path)

In [104]:
# Run the extraction over the training subset; the results are written to
# save_all_path and save_emb_path rather than returned.
extract_subgraph_qemb(train_loader, model, device, equal_subgraph_weighting, threshold_value, save_all_path, save_emb_path)

Extracting subgraph: 100%|██████████| 4/4 [00:00<00:00,  7.74it/s]


Load the saved node and question embeddings

In [None]:
import torch

# from GNN-cluster/data/demo/candidate_metadata.pt
def load_candidate_metadata(file_path):
    """Load the full extraction dump, i.e. a file written by ``save_all_to_file``.

    Args:
        file_path: Path of the saved file.

    Returns:
        A 6-tuple ``(batched_subgraphs, question_embeddings, candidates_masks,
        similarity_scores, node_maps, labels)``. ``similarity_scores`` is
        ``None`` when the file holds no scores.

    Raises:
        KeyError: If a required entry is missing from the file.
    """
    saved_data = torch.load(file_path)

    # save_all_to_file stores the mask under "candidates_mask" (singular).
    # The plural spelling this loader originally asked for is still accepted
    # as a fallback.
    if "candidates_mask" in saved_data:
        candidates_masks = saved_data["candidates_mask"]
    else:
        candidates_masks = saved_data["candidates_masks"]

    return (
        saved_data["batched_subgraphs"],
        saved_data["question_embeddings"],
        candidates_masks,
        saved_data.get("similarity_scores", None),  # absent or None -> None
        saved_data["node_maps"],
        saved_data["labels"],
    )


# from GNN-cluster/data/demo/subgraph_qembedding.pt
def load_subgraph_qemb(file_path):
    """Load the embeddings-only file, i.e. a file written by ``save_subg_qemb_file``.

    Args:
        file_path: Path of the saved file.

    Returns:
        A 2-tuple ``(batched_subgraphs, question_embeddings)``.

    Raises:
        KeyError: If either entry is missing from the file.
    """
    saved_data = torch.load(file_path)
    return saved_data["batched_subgraphs"], saved_data["question_embeddings"]


# Backward-compatible name. This cell used to define load_saved_data twice, so
# the second (two-value) definition silently replaced the first; the name keeps
# pointing at that two-value loader.
load_saved_data = load_subgraph_qemb

  subgraph_qemb = torch.load(save_emb_path)
