Set up the environment (module search paths and working directory)

In [58]:
import os
import sys

# Extend the module search path with the user-local site-packages directory,
# the user-local bin directory and the project checkout, in that order.
for extra_path in (
    '/home/svu/e0315913/.local/lib/python3.8/site-packages',
    '/home/svu/e0315913/.local/bin',
    "/hpctmp/e0315913/CS5284_Project/GNN-cluster",
):
    sys.path.append(extra_path)

# Work from the project checkout so that relative paths (e.g. the config file
# path used further down) are resolved against it.
os.chdir('/hpctmp/e0315913/CS5284_Project/GNN-cluster')

Import libraries

In [77]:
import random, torch
import numpy as np
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

In [None]:
from src.utils.config import load_config, validate_config
from src.my_datasets.kgqa_dataset import KGQADataset
from src.my_datasets.data_utils import collate_fn
from src.models.rgcn_model import RGCNModel

# Only needed in this notebook: helpers to extract, save and reload subgraphs / question embeddings
from src.RAG.kgqa_extractor import extract_subgraph_qemb, load_all_metadata, load_subgraph_data

Set config and device

In [None]:
# Path to the YAML config, relative to the current working directory.
CONFIG_PATH = "config/demo_config.yaml"

config = load_config(CONFIG_PATH)

# Every top-level config key this notebook reads. Listing all of them here
# makes an incomplete config fail at validation time instead of part-way
# through dataset construction, checkpoint loading or extraction.
# NOTE(review): assumes validate_config only checks that each key is present
# (not that its value is truthy) -- confirm against src/utils/config.py.
required_keys = [
    # originally validated keys
    'model', 'train', 'node_embed', 'idxes',
    'train_qa_data', 'test_qa_data', 'num_hops',
    # dataset construction
    'raw_kb', 'entity_sbert', 'from_paths_activate_test8',
    # extraction threshold
    'threshold_value',
    # Test 2 checkpoint and output paths
    'model_path_test2', 'save_all_path_test2', 'save_emb_path_test2',
    # Test 8 checkpoint and output paths
    'model_path_test8', 'save_all_path_test8', 'save_emb_path_test8',
]
validate_config(config, required_keys)

In [None]:
# Seed PyTorch, Python's `random` and NumPy with one shared value so that
# anything random in the notebook (e.g. DataLoader shuffling) is repeatable
# on a fresh kernel.
SEED = 2024
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

device(type='cuda')

Test 2

In [None]:
# Dataset for the Test 2 run: built from the training QA file with
# `from_paths_activate` switched off.
train_dataset = KGQADataset(
    path_to_node_embed=config['node_embed'],
    path_to_idxes=config['idxes'],
    path_to_qa=config['train_qa_data'],
    path_to_kb=config['raw_kb'],
    from_paths_activate=False,
    entity_sbert=config['entity_sbert'],
    k=config['num_hops'],
)

# Taken from the full dataset (not the subset below) so the model is built
# with the relation count of the entire graph.
num_relations = train_dataset.num_relations

# Keep only the samples whose index lies in [start_idx, end_idx).
train_cfg = config['train']
subset_indices = list(range(train_cfg['start_idx'], train_cfg['end_idx']))
sub_train_dataset = Subset(train_dataset, subset_indices)

# NOTE(review): shuffle=True randomises sample order for what is an
# extraction pass -- confirm the saved outputs do not need dataset order.
train_loader = DataLoader(
    sub_train_dataset,
    batch_size=train_cfg['batch_size'],
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# Build the RGCN model for Test 2 from the `model` section of the config.
# The question dimension is read from the dataset's question embeddings and
# the relation count from the full graph (`num_relations`).
model_test2 = RGCNModel(
    node_dim=config['model']['in_channels'],
    question_dim=train_dataset.q_embeddings.size(-1),
    hidden_dim=config['model']['hidden_channels'],
    num_relations=num_relations,
    output_dim=config['model']['out_channels'],
    num_rgcn=config['model']['num_layers'],
    reduced_qn_dim=config['model']['reduced_qn_dim'],
    reduced_node_dim=config['model']['reduced_node_dim'],
    output_embedding=config['model']['output_embedding'],
    use_residuals=config['model']['use_residuals'],
)

# map_location remaps the stored tensors onto the device selected above.
# Without it, a checkpoint saved from a GPU run cannot be loaded when this
# notebook falls back to the CPU.
checkpoint = torch.load(config['model_path_test2'], map_location=device)
model_test2.load_state_dict(checkpoint['model_state_dict'])
model_test2 = model_test2.to(device)

In [None]:
# Options read from the `train` section of the config, plus the top-level
# threshold that is later handed to extract_subgraph_qemb.
train_cfg = config['train']
equal_subgraph_weighting = train_cfg['equal_subgraph_weighting']
threshold_value = config['threshold_value']
hits_at_k = train_cfg['hits_at_k']

# Output locations for the Test 2 extraction run.
save_all_path_test2, save_emb_path_test2 = (
    config['save_all_path_test2'],
    config['save_emb_path_test2'],
)

In [None]:
# Run the Test 2 model over the training loader on `device`, passing the
# threshold and the two Test 2 save paths (presumably where the results are
# written -- see src/RAG/kgqa_extractor.py).
extract_subgraph_qemb(
    train_loader,
    model_test2,
    device,
    threshold_value,
    save_all_path_test2,
    save_emb_path_test2,
)

Test 8

In [None]:
# Dataset for the Test 8 run: same inputs as before, but `from_paths_activate`
# is taken from the config. This rebinds `train_dataset`, `num_relations`,
# `sub_train_dataset` and `train_loader` from the Test 2 section.
train_dataset = KGQADataset(
    path_to_node_embed=config['node_embed'],
    path_to_idxes=config['idxes'],
    path_to_qa=config['train_qa_data'],
    path_to_kb=config['raw_kb'],
    from_paths_activate=config['from_paths_activate_test8'],
    entity_sbert=config['entity_sbert'],
    k=config['num_hops'],
)

# Taken from the full dataset (not the subset below) so the model is built
# with the relation count of the entire graph.
num_relations = train_dataset.num_relations

# Keep only the samples whose index lies in [start_idx, end_idx).
train_cfg = config['train']
subset_indices = list(range(train_cfg['start_idx'], train_cfg['end_idx']))
sub_train_dataset = Subset(train_dataset, subset_indices)

# NOTE(review): shuffle=True randomises sample order for what is an
# extraction pass -- confirm the saved outputs do not need dataset order.
train_loader = DataLoader(
    sub_train_dataset,
    batch_size=train_cfg['batch_size'],
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# Build the RGCN model for Test 8 from the `model` section of the config.
# The question dimension is read from the dataset's question embeddings and
# the relation count from the full graph (`num_relations`).
model_test8 = RGCNModel(
    node_dim=config['model']['in_channels'],
    question_dim=train_dataset.q_embeddings.size(-1),
    hidden_dim=config['model']['hidden_channels'],
    num_relations=num_relations,
    output_dim=config['model']['out_channels'],
    num_rgcn=config['model']['num_layers'],
    reduced_qn_dim=config['model']['reduced_qn_dim'],
    reduced_node_dim=config['model']['reduced_node_dim'],
    output_embedding=config['model']['output_embedding'],
    use_residuals=config['model']['use_residuals'],
)

# map_location remaps the stored tensors onto the device selected above.
# Without it, a checkpoint saved from a GPU run cannot be loaded when this
# notebook falls back to the CPU.
checkpoint = torch.load(config['model_path_test8'], map_location=device)
model_test8.load_state_dict(checkpoint['model_state_dict'])
model_test8 = model_test8.to(device)

In [None]:
# Options read from the `train` section of the config, plus the top-level
# threshold that is later handed to extract_subgraph_qemb.
train_cfg = config['train']
equal_subgraph_weighting = train_cfg['equal_subgraph_weighting']
threshold_value = config['threshold_value']
hits_at_k = train_cfg['hits_at_k']

# Output locations for the Test 8 extraction run.
save_all_path_test8, save_emb_path_test8 = (
    config['save_all_path_test8'],
    config['save_emb_path_test8'],
)


In [None]:
# Run the Test 8 model over the training loader on `device`, passing the
# threshold and the two Test 8 save paths (presumably where the results are
# written -- see src/RAG/kgqa_extractor.py).
extract_subgraph_qemb(
    train_loader,
    model_test8,
    device,
    threshold_value,
    save_all_path_test8,
    save_emb_path_test8,
)

Loading Saved Embeddings

In [None]:
# Test 2: reload everything stored under the "save all" path.
(
    batched_subgraphs,
    original_graph_embeddings,
    question_embeddings,
    candidates_masks,
    similarity_scores,
    node_maps,
    labels,
) = load_all_metadata(save_all_path_test2)

# Alternative: reload only the subgraphs and the two embedding sets from the
# "save emb" path. As written, both loads run when the cell is executed, and
# this second one rebinds the first three names assigned above.
batched_subgraphs, original_graph_embeddings, question_embeddings = load_subgraph_data(
    save_emb_path_test2
)