In [1]:
import torch
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph
import json
import re

# Create PyTorch Geometric Data object

In [2]:
def create_data_object(edge_index, relations, entity_to_idx):
    """Build an undirected PyTorch Geometric ``Data`` graph from an edge list.

    Args:
        edge_index: Sequence of ``[source, target]`` node-index pairs, one per
            edge.
        relations: Relation label of each edge, in the same order as
            ``edge_index``.
        entity_to_idx: Mapping whose size is used as ``num_nodes``.

    Returns:
        A ``Data`` object whose ``edge_index`` (shape ``[2, 2 * E]``) holds
        every edge followed by its reverse, and whose ``edge_attr`` holds the
        integer relation id of each of those edges.

    Raises:
        ValueError: If ``edge_index`` and ``relations`` differ in length.
    """
    if len(edge_index) != len(relations):
        raise ValueError(
            f"edge_index has {len(edge_index)} edges but relations has "
            f"{len(relations)} labels"
        )

    # Number relations in order of first appearance. Iterating a set of strings
    # is not stable between interpreter runs, which would make the ids in
    # edge_attr change from one kernel session to the next.
    relation_mapping = {
        relation: index
        for index, relation in enumerate(dict.fromkeys(relations))
    }

    # reshape(-1, 2) keeps the result two-dimensional even for an empty edge
    # list, so the concatenation along dim=1 below still works.
    forward_edges = (
        torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )
    # Make the graph undirected by adding reverse edges
    # (use forward_edges alone if a directed graph is wanted).
    undirected_edge_index = torch.cat([forward_edges, forward_edges.flip(0)], dim=1)

    edge_attr = torch.tensor(
        [relation_mapping[rel] for rel in relations], dtype=torch.long
    )
    # Each reverse edge carries the same relation id as its forward edge
    # (use edge_attr alone if a directed graph is wanted).
    undirected_edge_attr = torch.cat([edge_attr, edge_attr], dim=0)

    return Data(
        edge_index=undirected_edge_index,
        edge_attr=undirected_edge_attr,
        num_nodes=len(entity_to_idx),
    )

def load_data_json(filename='../Datasets/MetaQA_dataset/processed/idxes.json'):
    """Load the preprocessed graph description from a JSON file.

    Args:
        filename: Path to a JSON document containing the keys
            ``entity_to_idx``, ``edge_index`` and ``relations``.

    Returns:
        A tuple ``(entity_to_idx, edge_index, relations)`` with the values
        stored under those keys.

    Raises:
        KeyError: If one of the three keys is missing from the document.
    """
    # JSON text is UTF-8; decoding it explicitly keeps non-ASCII entity names
    # identical on every platform instead of depending on the locale default.
    with open(filename, 'r', encoding='utf-8') as f:
        payload = json.load(f)
    return payload['entity_to_idx'], payload['edge_index'], payload['relations']

In [3]:
# Read the preprocessed graph from the default JSON path, then assemble the
# PyTorch Geometric graph from its pieces.
loaded_entity_to_idx, loaded_edge_index, loaded_relations = load_data_json()
data = create_data_object(
    edge_index=loaded_edge_index,
    relations=loaded_relations,
    entity_to_idx=loaded_entity_to_idx,
)
print(data)

Data(edge_index=[2, 267164], edge_attr=[267164], num_nodes=43234)


# Convert QA entities and answers to indices

In [4]:
def extract_qa_entities_and_answers(file_path):
    """Read a tab-separated QA file and collect question entities and answers.

    Each line is expected to look like ``question<TAB>answer1|answer2|...``,
    with entities written in square brackets inside the question.

    Lines with no tab-separated answer field, and lines whose question has no
    bracketed entity, are reported on stdout and skipped entirely, so a
    question without an entity no longer contributes an orphan answer list.

    Args:
        file_path: Path to the QA text file.

    Returns:
        A tuple ``(entities, answers)``. ``entities`` is a flat list of every
        bracketed entity found; ``answers`` has one list of answer strings per
        kept line. The two lists are index-aligned only when every kept
        question contains exactly one bracketed entity.
    """
    extracted_entities = []
    extracted_answers = []

    # Decode explicitly so entity names do not depend on the platform's
    # default encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        for line_number, line in enumerate(file, start=1):
            # Split the line into question and answers
            parts = line.strip().split('\t')
            if len(parts) < 2:
                print(f"Line {line_number}: Not enough parts found.")
                continue

            question, answers = parts[0], parts[1]

            # Non-greedy match so each [...] group yields its own entity
            matches = re.findall(r'\[(.*?)\]', question)
            if not matches:
                print(f"Line {line_number}: No entities found.")
                # Skip the answers too; keeping them would shift every later
                # answer list relative to its entity.
                continue

            extracted_entities.extend(matches)
            # Answers are separated by '|'
            extracted_answers.append(answers.split('|'))

    return extracted_entities, extracted_answers

In [5]:
# Question/answer file to parse (the path points at the 3-hop training split).
qa_file = '../Datasets/MetaQA_dataset/vanilla 3-hop/qa_train.txt'
qa_entities, qa_answers = extract_qa_entities_and_answers(file_path=qa_file)

In [6]:
# Look up the node index of every question entity; names that are absent from
# the entity vocabulary are skipped.
# NOTE(review): skipping shifts later positions, so entries here may stop
# lining up with qa_answer_tensor -- confirm whether that alignment matters.
qa_entity_indices = [
    loaded_entity_to_idx[name]
    for name in qa_entities
    if name in loaded_entity_to_idx
]

# Same lookup for the answers, keeping one (possibly empty) list per question.
qa_answer_indices = [
    [loaded_entity_to_idx[name] for name in answers if name in loaded_entity_to_idx]
    for answers in qa_answers
]

# Wrap the indices as int64 tensors: a single flat tensor for the entities and
# one tensor per question for the answers.
qa_entity_tensor = torch.tensor(qa_entity_indices, dtype=torch.long)
qa_answer_tensor = [
    torch.tensor(indices, dtype=torch.long) for indices in qa_answer_indices
]

# Generate subgraphs on the fly

In [7]:
# Function to get the k-hop subgraph
def get_khop_subgraph(data, entity_index, k):
    """Return the k-hop neighbourhood of ``entity_index`` in ``data``.

    Thin wrapper around ``torch_geometric.utils.k_hop_subgraph``.

    Args:
        data: Graph object exposing ``edge_index`` (and normally ``num_nodes``).
        entity_index: Seed node index (or indices) to expand from.
        k: Number of hops to expand.

    Returns:
        The tuple ``(subset, sub_edge_index, mapping, edge_mask)`` exactly as
        produced by ``k_hop_subgraph``.
    """
    subset, sub_edge_index, mapping, edge_mask = k_hop_subgraph(
        node_idx=entity_index,
        # The library's parameter is named ``num_hops``; passing ``k=`` raises
        # TypeError (unexpected keyword argument).
        num_hops=k,
        edge_index=data.edge_index,
        # Pass the graph's own node count when it has one, rather than letting
        # it be inferred from the edges; None keeps the library default.
        num_nodes=getattr(data, 'num_nodes', None),
    )
    return subset, sub_edge_index, mapping, edge_mask