In [1]:
%load_ext autoreload
%autoreload 2

In [2]:

import pickle as pkl
import random

import numpy as np
import torch

from expand_subgraph import ExpandSubgraph
from load_data import DataLoader
from model import GNN_auto
from reasoning_module import ReasoningModule

## Load knowledge-graph artifacts (ID maps and entity names)

In [3]:
# All FB15k-237 artifacts live under one directory; keep the path in one place.
KG_DIR = "knowledge_graph/KG_data/FB15k-237-betae"


def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    NOTE: ``pickle.load`` can execute arbitrary code; only use it on trusted,
    locally generated artifacts such as the ones below.
    """
    with open(path, "rb") as handle:
        return pkl.load(handle)


id2ent = _load_pickle(f"{KG_DIR}/id2ent.pkl")
id2rel = _load_pickle(f"{KG_DIR}/id2rel.pkl")
ent2id = _load_pickle(f"{KG_DIR}/ent2id.pkl")
# The training queries are loaded once, in the "Load queries" cell.

# Map Freebase MIDs to readable names. Each line is "<mid>\t<name>"; split only on
# the first tab so names that contain tabs stay intact.
ent2name = {}
with open(f"{KG_DIR}/FB15k_mid2name.txt", "r", encoding="utf-8") as f:
    for line_no, line in enumerate(f, start=1):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines instead of failing on tuple unpacking
        fields = line.split("\t", 1)
        if len(fields) != 2:
            raise ValueError(f"malformed line {line_no} in FB15k_mid2name.txt: {line!r}")
        mid, name = fields
        ent2name[mid] = name

### Load queries

In [4]:
# Training queries; entries look like the dict displayed below
# (query_type, raw_query, named_query, transformed_query, answers_id, answers).
query_file_path = "knowledge_graph/queries/train_all_id.pkl"
with open(query_file_path, "rb") as query_file:
    queries = pkl.load(query_file)

# Index of the single query explored in the rest of this notebook.
id_query = 1
queries[id_query]

{'query_type': ('e', ('r', 'r', 'r')),
 'raw_query': (8552, (122, 191, 10)),
 'named_query': ('Marlborough_College',
  ('+/education/educational_institution/colors',
   '-/sports/sports_team/colors',
   '+/soccer/football_team/current_roster./soccer/football_roster_position/position')),
 'transformed_query': ['What colors are associated with Marlborough College and what positions do players in its soccer team hold?',
  'What are the soccer team colors for Marlborough College along with the positions of the players in the team?',
  'Can you tell me the colors of Marlborough College and the various positions played by its soccer team members?'],
 'answers_id': [9, 11, 117, 399],
 'answers': ['Midfielder', 'Forward', 'Goalkeeper', 'Defender']}

## Configure the data loader and subgraph sampler

In [5]:
class Config:
    """Run configuration passed as ``args`` to DataLoader and ExpandSubgraph.

    NOTE(review): most fields are interpreted inside those project modules;
    the comments below only record what this notebook shows.
    """
    data_path = 'knowledge_graph/KG_data/FB15k-237-betae'
    seed = 1234
    k = 9  # beams
    depth = 8  # max depth of subgraph
    cands_lim = 128
    gpu = 0
    fact_ratio = 0.6
    val_num = -1  # how many triples are used as the validation set
    epoch = 200
    layer = 6  # NOTE(review): differs from n_layer=8 used for the GNN checkpoint; confirm
    batchsize = 16
    cpu = 1
    weight = ''
    add_manual_edges = False
    remove_1hop_edges = True
    only_eval = False
    not_shuffle_train = False
    # Derived from `gpu` so the two device settings cannot drift apart.
    device = f"cuda:{gpu}"

In [6]:
args = Config()

# Seed every RNG before the loader is built and shuffled so that any randomness
# in data preparation is reproducible. NOTE(review): DataLoader may also seed
# internally from args.seed; seeding here is harmless either way.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

loader = DataLoader(args, mode='train')
loader.shuffle_train()
train_graph = loader.train_graph

# Relation-free view of the graph: unique (head, tail) pairs from the triples.
train_graph_homo = list({(h, t) for (h, r, t) in train_graph})

args.n_ent = loader.n_ent
args.n_rel = loader.n_rel  # NOTE: overridden to 237 in the model section below

==> removing 1-hop links...
==> done
==> removing 1-hop links...
==> done


### Initialize the subgraph sampler

In [7]:
# Sampler over both the relation-free (head, tail) pairs and the full
# (head, relation, tail) triples; args.n_rel here is still loader.n_rel.
sampler = ExpandSubgraph(
    args.n_ent,
    args.n_rel,
    train_graph_homo,
    train_graph,
    args,
)


In [8]:
# Point the sampler at the selected query, then sample a subgraph for it.
# Order matters: sampleSubgraph() relies on the query set by assign_query().
# NOTE(review): presumably the subgraph is grown from the query's anchor entity — confirm in ExpandSubgraph.
sampler.assign_query(queries[id_query])
subgraph_data = sampler.sampleSubgraph()

### Initialize the GNN model from a checkpoint

In [9]:
# The GNN checkpoint below is built with 237 relation types (FB15k-237's relation
# count), overriding the value taken from loader.n_rel. NOTE(review): ExpandSubgraph
# was already constructed with loader.n_rel; confirm the two are meant to differ.
CHECKPOINT_N_REL = 237
args.n_rel = CHECKPOINT_N_REL

In [10]:
modelPath = "weights/topk_0.1_layer_8_ValMRR_0.437.pt"
# Load tensors onto the configured device rather than a second hard-coded "cuda:0",
# so Config.device is the single place that selects the GPU.
checkpoint = torch.load(modelPath, map_location=torch.device(args.device))

# Hyperparameters the checkpoint was trained with ('n_layer' matches "layer_8" in the
# filename). NOTE(review): presumably copied from the training run — confirm.
params = {'lr': 0.0003, 'hidden_dim': 64, 'attn_dim': 4, 'n_layer': 8, 'act': 'relu', 'initializer': 'binary', 'concatHidden': True, 'shortcut': False, 'readout': 'linear', 'decay_rate': 0.9429713470775948, 'lamb': 0.000946516892415447, 'dropout': 0.19456805575101324}

# Attribute-style params object for GNN_auto, built from `params` so the values live
# in exactly one place instead of being duplicated by hand.
Model_Params = type(
    "Model_Params",
    (),
    {"n_ent": args.n_ent, "n_rel": args.n_rel, **params},
)
model_params = Model_Params()

gnn_model = GNN_auto(model_params, loader)
# The model is only consumed by ReasoningModule below (assumed inference-only — TODO
# confirm); eval() disables dropout so scores are deterministic.
gnn_model.eval()
gnn_model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [11]:
reasoning_module = ReasoningModule(
    queries[id_query],  # query dict displayed in the "Load queries" cell
    subgraph_data,      # result of sampler.sampleSubgraph()
    gnn_model,          # GNN with checkpoint weights loaded
    "gpt-4o-mini",      # model name handed to ReasoningModule
)

In [None]:
# Run ReasoningModule on the query, sampled subgraph, and GNN passed in above.
# NOTE(review): what this returns or prints is defined in reasoning_module — not visible here.
reasoning_module.reasoning()