In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell execution, so edits to the local project modules imported below are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:

from expand_subgraph import ExpandSubgraph
from model import GNN_auto
import pickle as pkl
import numpy as np
import torch
from load_data import DataLoader
from reasoning_module import ReasoningModule

## Load graph artifacts:

In [3]:
# Load the pickled lookup tables shipped with the FB15k-237-betae graph.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load these artifacts from a trusted source.
with open("knowledge_graph/KG_data/FB15k-237-betae/id2ent.pkl", "rb") as f:
    id2ent = pkl.load(f)
with open("knowledge_graph/KG_data/FB15k-237-betae/id2rel.pkl", "rb") as f:
    id2rel = pkl.load(f)
with open("knowledge_graph/KG_data/FB15k-237-betae/ent2id.pkl", "rb") as f:
    ent2id = pkl.load(f)

# The query pickle is intentionally not read here: the "Load queries" cell
# right below loads it, and reading the same file twice only costs time.

# Build the mid -> name table from a tab-separated text file, one record per line.
ent2name = {}
with open("knowledge_graph/KG_data/FB15k-237-betae/FB15k_mid2name.txt", "r", encoding='utf-8') as f:
    for line in f:
        record = line.strip()
        if not record:
            # A blank (e.g. trailing) line used to crash the two-value unpacking.
            continue
        # maxsplit=1 keeps any further tabs inside the name.
        fields = record.split("\t", 1)
        mid = fields[0]
        # A record whose name is empty loses its tab to strip(); map it to "".
        name = fields[1] if len(fields) > 1 else ""
        ent2name[mid] = name

#### Load queries

In [4]:
# Key/index of the example query that the later cells work on.
id_query = 1

# Read the pickled training queries.
# NOTE(security): unpickling runs code stored in the file -- trusted input only.
with open("knowledge_graph/queries/train_all_id.pkl", mode="rb") as query_file:
    queries = pkl.load(query_file)

# Last expression of the cell, so the selected query is rendered as its output.
queries[id_query]

{'query_type': ('e', ('r', 'r', 'r')),
 'raw_query': (8552, (122, 191, 10)),
 'named_query': ('Marlborough_College',
  ('+/education/educational_institution/colors',
   '-/sports/sports_team/colors',
   '+/soccer/football_team/current_roster./soccer/football_roster_position/position')),
 'transformed_query': ['What colors are associated with Marlborough College and what positions do players in its soccer team hold?',
  'What are the soccer team colors for Marlborough College along with the positions of the players in the team?',
  'Can you tell me the colors of Marlborough College and the various positions played by its soccer team members?'],
 'answers_id': [9, 11, 117, 399],
 'answers': ['Midfielder', 'Forward', 'Goalkeeper', 'Defender']}

## Load sampler

In [5]:
class Config:
    """Plain settings holder; an instance is passed as ``args`` to the loader and the sampler.

    The grouping below is for readability only. The exact meaning of each
    switch is defined by the project modules that read it -- check there.
    """

    # Where the knowledge-graph files live.
    data_path = 'knowledge_graph/KG_data/FB15k-237-betae'

    # Data preparation switches (presumably read by the data loader -- confirm).
    fact_ratio = 0.6
    val_num = -1  # how many triples are used as the validate set
    add_manual_edges = False
    remove_1hop_edges = True
    not_shuffle_train = False

    # Subgraph expansion settings.
    k = 9  # beams
    depth = 8  # max depth of subgraph
    cands_lim = 64

    # Reproducibility and hardware.
    seed = 1234
    gpu = 0
    device = "cuda:0"

In [6]:
# Instantiate the notebook configuration and the training-mode data loader,
# then let the loader shuffle its training split.
args = Config()
loader = DataLoader(args, mode='train')
loader.shuffle_train()

# Keep the loader's triples untouched and derive a relation-free view:
# every (head, tail) pair, de-duplicated through a set, stored as a list.
train_graph = loader.train_graph
train_graph_homo = list(set((head, tail) for head, _rel, tail in train_graph))

# Copy the entity / relation counts reported by the loader onto the config.
args.n_ent, args.n_rel = loader.n_ent, loader.n_rel

==> removing 1-hop links...
==> done
==> removing 1-hop links...
==> done


### Init sampler

In [7]:
# Create the subgraph sampler from the entity/relation counts, both views of
# the training graph (relation-free pairs and full triples) and the config.
sampler = ExpandSubgraph(
    args.n_ent,
    args.n_rel,
    train_graph_homo,
    train_graph,
    args,
)


In [8]:
# Point the sampler at the selected query and draw its subgraph.
sampler.assign_query(queries[id_query])
subgraph_data = sampler.sampleSubgraph()

# A bare expression in the middle of a cell is evaluated and discarded, so the
# shape must be printed explicitly to show up in the cell output.
print("subgraph_data[0].shape:", subgraph_data[0].shape)

# Last expression of the cell: the first ten rows of the third element.
subgraph_data[2][:10]

tensor([[ 8552,   122,  7990],
        [ 8552,    44,  7973],
        [ 8552,   248,  8553],
        [ 8552,    31,  1124],
        [ 8552,    31,   382],
        [ 7990,   191,  4714],
        [ 7990,   191,  4882],
        [ 7990,   191, 10667],
        [ 7990,   191,  2967],
        [ 7990,   191,  7798]])

Check coverage of the subgraph: the fraction of the query's answer entities that appear among the sampled subgraph nodes.

In [9]:
# Coverage = share of the query's answer entities found among the subgraph nodes.
ans = torch.tensor(queries[id_query]['answers_id'])
subgraph_nodes = subgraph_data[0]

# Boolean mask over `ans`: True where the answer id occurs in `subgraph_nodes`.
mask = torch.isin(ans, subgraph_nodes)

# Keep the ratio in `coverage` (it previously stayed at its initial 0 and the
# value was only printed). A query without answers yields 0.0 instead of a
# ZeroDivisionError.
coverage = mask.sum().item() / ans.size(0) if ans.size(0) > 0 else 0.0
print("Coverage = ", coverage)


Coverage =  1.0


### Init model

In [10]:
# Replace the relation count copied from the loader with a fixed 237 before the
# model settings are built from `args`.
# NOTE(review): presumably the value the checkpoint was trained with -- confirm.
args.n_rel = 237

In [11]:
modelPath = "weights/topk_0.1_layer_8_ValMRR_0.437.pt"

# Take the device from the notebook config rather than repeating the literal,
# so checkpoint loading and model placement always agree with `args.device`.
device = torch.device(args.device)

# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(modelPath, map_location=device)

# Hyper-parameters for the model. This dict is the single source of truth:
# the values are copied onto Model_Params below instead of being typed twice.
params = {'lr': 0.0003, 'hidden_dim': 64, 'attn_dim': 4, 'n_layer': 8, 'act': 'relu', 'initializer': 'binary', 'concatHidden': True, 'shortcut': False, 'readout': 'linear', 'decay_rate': 0.9429713470775948, 'lamb': 0.000946516892415447, 'dropout': 0.19456805575101324}


class Model_Params:
    """Attribute-style settings object handed to ``GNN_auto``."""

    # Graph sizes come from the notebook config.
    n_ent = args.n_ent
    n_rel = args.n_rel


# Attach every hyper-parameter as a class attribute (same names and values as
# the dict), so the dict and the class cannot drift apart.
for hp_name, hp_value in params.items():
    setattr(Model_Params, hp_name, hp_value)

model_params = Model_Params()

# Build the network, restore the trained weights and move it to the device.
gnn_model = GNN_auto(model_params, loader)
gnn_model.load_state_dict(checkpoint['model_state_dict'])
# The notebook only runs inference, so switch to eval mode: otherwise dropout
# (configured non-zero above) stays active and scores vary between runs.
# Assumes GNN_auto is a torch.nn.Module, whose eval() returns the module itself.
gnn_model = gnn_model.to(device).eval()

In [12]:
# Bundle the selected query, the sampler, its sampled subgraph and the loaded
# GNN into the reasoning module; the final argument is the string "gpt-4o-mini".
reasoning_module = ReasoningModule(
    queries[id_query],
    sampler,
    subgraph_data,
    gnn_model,
    "gpt-4o-mini",
)

In [13]:
# Show the first ten rows of subgraph_data[2] again -- the same slice that was
# displayed right after sampling.
subgraph_data[2][:10]

tensor([[ 8552,   122,  7990],
        [ 8552,    44,  7973],
        [ 8552,   248,  8553],
        [ 8552,    31,  1124],
        [ 8552,    31,   382],
        [ 7990,   191,  4714],
        [ 7990,   191,  4882],
        [ 7990,   191, 10667],
        [ 7990,   191,  2967],
        [ 7990,   191,  7798]])

In [18]:
# Run the reasoning step of the module built above; the text recorded under
# this cell is what the call emitted in the saved run.
reasoning_module.reasoning()

+/education/educational_institution/colors 122
[8552  122 7990]
+/education/educational_institution/students_graduates./education/education/student 44
[8552   44 7973]
+/sports/sports_team/colors 190
+/soccer/football_team/current_roster./soccer/football_roster_position/position 10
{'entity_id': 8552, 'relation': '+/soccer/football_team/current_roster./soccer/football_roster_position/position', 'candidate_id': 11, 'score': 0.7687200308096069}
{'entity_id': 8552, 'relation': '+/soccer/football_team/current_roster./soccer/football_roster_position/position', 'candidate_id': 117, 'score': 0.6328135133039612}
{'entity_id': 8552, 'relation': '+/education/educational_institution/colors', 'candidate_id': 11, 'score': 0.5966756224982263}
{'entity_id': 8552, 'relation': '+/education/educational_institution/colors', 'candidate_id': 117, 'score': 0.5695405722014428}
{'entity_id': 8552, 'relation': '+/education/educational_institution/colors', 'candidate_id': 521, 'score': 0.48990665677709594}
{'en

In [None]:
# Peek at the first ten rows of the module's `subgraph_edges`.
# NOTE(review): in the saved output the first and third columns are 0 in every
# row -- confirm whether that is intended.
reasoning_module.subgraph_edges[:10]

tensor([[  0, 122,   0],
        [  0, 286,   0],
        [  0, 248,   0],
        [  0, 248,   0],
        [  0,  31,   0],
        [  0,  31,   0],
        [  0, 191,   0],
        [  0, 191,   0],
        [  0, 191,   0],
        [  0, 191,   0]])

In [None]:
# Report the container type first, then the Python type and dtype of each
# element it holds, one line per element.
print("Subgraph data types:", type(subgraph_data))
for element in subgraph_data:
    print(type(element), element.dtype)

Subgraph data types: <class 'tuple'>
<class 'torch.Tensor'> torch.int64
<class 'torch.Tensor'> torch.int64
<class 'torch.Tensor'> torch.int64


In [None]:
# Display the module's `adj_list`; per the saved output it is a dict mapping an
# integer id to a list of 2-tuples of integers.
# NOTE(review): this dumps the whole structure -- consider showing only a few
# entries to keep the notebook output readable.
reasoning_module.adj_list

{8552: [(286, 13781), (44, 7973), (248, 12350), (31, 13781), (31, 382)],
 382: [(188, 14211),
  (59, 4564),
  (59, 74),
  (59, 3337),
  (59, 4043),
  (59, 6286),
  (59, 7407),
  (59, 1162),
  (59, 2881)],
 4564: [(161, 10554),
  (161, 4302),
  (161, 10126),
  (161, 9649),
  (161, 9655),
  (161, 13631),
  (161, 13502),
  (161, 5146),
  (161, 9713)],
 10126: [(190, 4288)],
 4302: [(190, 521)],
 9649: [(190, 521), (190, 4288)],
 9713: [(190, 7990), (190, 5052), (190, 521)],
 13502: [(190, 926)],
 13631: [(190, 926)],
 926: [(191, 810),
  (191, 1882),
  (191, 9860),
  (191, 11224),
  (191, 10180),
  (191, 343),
  (191, 2675),
  (191, 9553),
  (191, 10106),
  (191, 11948),
  (191, 13450),
  (191, 7394),
  (191, 12977),
  (191, 5436),
  (191, 7328),
  (191, 11789)],
 4288: [(191, 5706),
  (191, 8397),
  (191, 1438),
  (191, 8080),
  (191, 8934),
  (191, 9998),
  (191, 7219),
  (191, 5798),
  (191, 13841),
  (191, 2172),
  (191, 11665)],
 5706: [(190, 521), (10, 399), (10, 117)],
 2675: [(190