In [17]:
%load_ext autoreload
%autoreload 2

In [18]:

from expand_subgraph import ExpandSubgraph
from model import GNN_auto
import pickle as pkl
import numpy as np
import torch
from load_data import DataLoader
from reasoning_module import ReasoningModule

## Load graph artifacts:

In [19]:
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load these artifacts from a trusted source.
with open("knowledge_graph/KG_data/FB15k-237-betae/id2ent.pkl", "rb") as f:
    id2ent = pkl.load(f)
with open("knowledge_graph/KG_data/FB15k-237-betae/id2rel.pkl", "rb") as f:
    id2rel = pkl.load(f)
with open("knowledge_graph/KG_data/FB15k-237-betae/ent2id.pkl", "rb") as f:
    ent2id = pkl.load(f)

# NOTE(review): `queries` is reassigned from CWQ_sim_queries.pkl in the
# "Load queries" cell, so this value is only visible until that cell runs.
with open("knowledge_graph/queries/train_all_id.pkl", "rb") as f:
    queries = pkl.load(f)

# Build the lookup from the first tab-separated field of each line to the rest
# of the line.
ent2name = {}
with open("knowledge_graph/KG_data/FB15k-237-betae/FB15k_mid2name.txt", "r", encoding='utf-8') as f:
    for line in f:
        # Split on the first tab only, in case the name itself contains tabs.
        mid, sep, name = line.strip().partition("\t")
        if not sep:
            # Blank or tab-less line: skip it instead of failing on unpacking.
            continue
        ent2name[mid] = name

### Load queries

In [71]:
# Load the CWQ simulation queries. This rebinds `queries`, replacing whatever
# an earlier cell stored under that name.
with open("knowledge_graph/queries/CWQ_sim_queries.pkl", "rb") as f:
    queries = pkl.load(f)
# Index of the single query inspected by the cells below.
id_query = 32
queries[id_query]

{'query_type': ('e', ('r',)),
 'raw_query': (4542, (24,)),
 'named_query': ('And_Starring_Pancho_Villa_as_Himself',
  ('+/award/award_winning_work/awards_won./award/award_honor/award_winner',)),
 'transformed_query': ['Who won awards for their work starring Pancho Villa as Himself?',
  "Which award winner is associated with the work 'And Starring Pancho Villa as Himself'?",
  "Can you tell me who received awards for 'And Starring Pancho Villa as Himself'?"],
 'answers_id': [3297],
 'answers': ['Larry_Gelbart'],
 'natural_query': 'Who won an award for their work on the film "And Starring Pancho Villa as Himself"?'}

## Load sampler

In [None]:
class Config:
    """Settings object handed to ``DataLoader`` and ``ExpandSubgraph``.

    The notebook later attaches ``n_ent`` and ``n_rel`` to the instance after
    the loader has been built.
    """
    data_path = 'knowledge_graph/KG_data/FB15k-237-betae'
    # NOTE(review): no visible cell applies this seed to an RNG -- confirm that
    # the loader/sampler consume it.
    seed = 1234
    k = 9 # beams
    depth = 2 # max depth of subgraph
    cands_lim = 1024
    gpu = 0
    fact_ratio = 1.0
    val_num = -1 # how many triples are used as the validate set
    add_manual_edges = False
    remove_1hop_edges = True
    not_shuffle_train = False
    # Use the first CUDA device when one is present; otherwise fall back to the
    # CPU so the configuration is still usable on a machine without a GPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [148]:
# Instantiate the settings and read the training split.
args = Config()
loader = DataLoader(args, mode='train')

# loader.shuffle_train()
train_graph = loader.train_graph
# Relation-free view of the graph: the distinct (head, tail) pairs.
train_graph_homo = list({(head, tail) for head, _rel, tail in train_graph})

# Expose the entity / relation counts reported by the loader on the settings.
args.n_ent, args.n_rel = loader.n_ent, loader.n_rel

==> removing 1-hop links...
==> done


In [149]:
# Number of entries in the training graph returned by the loader.
len(train_graph)

544230

### Init sampler

In [150]:
# Extra options forwarded to ExpandSubgraph for the GoG simulation.
# NOTE(review): the exact meaning of `drop_ratio` is defined inside
# ExpandSubgraph -- confirm there.
GoG_args = dict(drop_ratio=0.4)

# Subgraph sampler over the training graph, with GoG simulation enabled.
sampler = ExpandSubgraph(
    args.n_ent,
    args.n_rel,
    train_graph_homo,
    train_graph,
    args,
    GoG_simulation=True,
    GoG_args=GoG_args,
)

In [155]:
# Point the sampler at the selected query, then sample a subgraph for it.
# Later cells index the result: [0] is compared against answer ids as the
# subgraph's nodes and [2] is iterated as its edge rows.
sampler.assign_query(queries[id_query])
subgraph_data = sampler.sampleSubgraphBFS()

Simulating GoG by removing edges...


In [156]:
# Compare the sampler's two edge lists by length.
# NOTE(review): `orignal_edge_index` is the attribute name as accessed here; it
# looks like a misspelling of "original" inside ExpandSubgraph -- confirm
# before renaming it in both places.
print(len(sampler.orignal_edge_index), len(sampler.edge_index))

544230 544227


In [157]:
def count_path_to_ans(id_query, subgraph_data):
    """Print, per answer id of one query, how many subgraph edge rows touch it.

    Reads the notebook-level ``queries`` collection. A row of
    ``subgraph_data[2]`` is counted when the answer id occurs at position 0 or
    position 2 of that row (presumably head and tail -- confirm against the
    sampler). Prints ``<answer id> <count>`` for every answer; returns nothing.

    NOTE(review): a later cell defines a function with this same name, which
    replaces this one once that cell has run.
    """
    for answer_id in queries[id_query]['answers_id']:
        hits = sum(1 for row in subgraph_data[2] if answer_id in row[[0, 2]])
        print(answer_id, hits)

In [158]:
count_path_to_ans(id_query, subgraph_data)

3297 3


In [146]:
# Number of distinct ids found in columns 0 and 2 of the subgraph edge rows.
# NOTE(review): this cell's execution count (146) is lower than the cells
# around it, so the displayed output may come from an earlier `subgraph_data`.
np.unique(subgraph_data[2][:,[0,2]].flatten()).shape

(7213,)

In [159]:
# Alias for the edge rows of the sampled subgraph.
# NOTE(review): not referenced by any other cell visible in this notebook.
subgraph_edges = subgraph_data[2]

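Check coverage of the subgraph: the fraction of the query's ground-truth answer entities that appear among the sampled subgraph nodes.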

In [None]:
# Coverage = share of the query's answer ids that are present among the
# sampled subgraph nodes.
ans = torch.tensor(queries[id_query]['answers_id'])
subgraph_nodes = subgraph_data[0]
# mask[i] is True when answer i is one of the subgraph nodes.
mask = torch.isin(ans, subgraph_nodes)

# Keep the result in `coverage` (previously a placeholder that was never
# updated) and avoid dividing by zero when the query has no answers.
n_answers = ans.size(0)
coverage = mask.sum().item() / n_answers if n_answers else 0.0
print("Coverage = ", coverage)

Coverage =  1.0


### Init model

In [None]:
# Overwrite the relation count previously copied from `loader.n_rel`.
# NOTE(review): 237 is hard-coded; presumably it has to match the checkpoint
# loaded below -- confirm, and consider deriving it rather than mutating
# `args` in place.
args.n_rel=237

In [None]:
modelPath = "weights/topk_0.1_layer_8_ValMRR_0.437.pt"
# One device object, taken from the shared settings, is used both for mapping
# the checkpoint tensors and for placing the model.
device = torch.device(args.device)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(modelPath, map_location=device)
# Hyper-parameters of the checkpointed model; Model_Params below reads from
# this dict so the two cannot drift apart.
params = {'lr': 0.0003, 'hidden_dim': 64, 'attn_dim': 4, 'n_layer': 8, 'act': 'relu', 'initializer': 'binary', 'concatHidden': True, 'shortcut': False, 'readout': 'linear', 'decay_rate': 0.9429713470775948, 'lamb': 0.000946516892415447, 'dropout': 0.19456805575101324}
class Model_Params:
    """Attribute-style view of the model hyper-parameters passed to GNN_auto."""
    n_ent = args.n_ent
    n_rel = args.n_rel
    lr = params['lr']
    hidden_dim = params['hidden_dim']
    attn_dim = params['attn_dim']
    n_layer = params['n_layer']
    act = params['act']
    initializer = params['initializer']
    concatHidden = params['concatHidden']
    shortcut = params['shortcut']
    readout = params['readout']
    decay_rate = params['decay_rate']
    lamb = params['lamb']
    dropout = params['dropout']

model_params = Model_Params()

# Build the network, restore the saved weights, then move it to the device.
# NOTE(review): the model is left in whatever mode GNN_auto constructs it in;
# since a dropout rate is configured, confirm that inference code switches it
# to evaluation mode.
gnn_model = GNN_auto(model_params, loader)
gnn_model.load_state_dict(checkpoint['model_state_dict'])
gnn_model = gnn_model.to(device)

In [None]:
# Assemble the reasoning module from the selected query, the sampler, the
# sampled subgraph and the restored GNN.
# NOTE(review): "gpt-3.5" is presumably the name of the LLM backend -- confirm
# against ReasoningModule.
reasoning_module = ReasoningModule(queries[id_query], 
                                   sampler,
                                   subgraph_data,
                                   gnn_model,
                                   "gpt-3.5")

### Inference on your queries

In [None]:
def count_path_to_ans(id_query, subgraph_data):
    """Count the subgraph edge rows that touch each answer of one query.

    Reads the notebook-level ``queries``, ``id2ent`` and ``ent2name``. A row of
    ``subgraph_data[2]`` is counted for an answer when the answer id occurs at
    position 0 or position 2 of that row. One line is printed per answer in the
    form ``<name> || count: <n>``.

    Args:
        id_query: Key/index of the query inside ``queries``.
        subgraph_data: Indexable whose element 2 is the iterable of edge rows.

    Returns:
        The sum of the per-answer counts over all answers of the query.
    """
    sum_cnt = 0
    for ans in queries[id_query]['answers_id']:
        cnt = sum(1 for edges in subgraph_data[2] if ans in edges[[0, 2]])
        sum_cnt += cnt
        mid = id2ent[ans]
        # Fall back to the raw entity identifier when no readable name exists.
        print(ent2name.get(mid, mid), "|| count: " + str(cnt))
    # Return only after every answer has been processed; returning from inside
    # the loop would report the first answer alone.
    return sum_cnt

In [None]:
# Accumulators filled by the inference loop in the next cell.
# NOTE(review): they are reset only here, so re-running the loop cell without
# re-running this one appends duplicate entries.
predictions = []
ground_truths = []
cands = []

In [None]:
# For every query: record its ground-truth answers, echo the natural-language
# question and sample a subgraph for it.
# NOTE(review): the reasoning steps below are commented out, so `predictions`
# and `cands` remain empty while `ground_truths` grows; evaluating in that
# state compares against no predictions.
for query in queries[::]:
    ground_truths.append(query['answers'])
    print(query['natural_query'])
    # `subgraph_data` is rebound on every iteration; after the loop it holds
    # the subgraph of the last query that was processed.
    subgraph_data = sampler.sampleSubgraph(query)
    
    # reasoning_module.assign_query(query)
    # reasoning_module.assign_subgraph(subgraph_data)
    # reasoning_module.reasoning()
    # results = reasoning_module.result
    # predictions.append([results])
    # cands.append(reasoning_module.cands)

Which awards have the production companies of films featuring Katherine Heigl been nominated for?
Simulating GoG by removing edges...
decision:  Yes
___________________________
Use grounded relations.
_____________________
Processing relation: +/award/award_nominee/award_nominations./award/award_nomination/nominated_for
@@@@@@@@@@@@@@  
 Top candidates from GNN:
Killers | score:  11.432977676491602
Knocked_Up | score:  11.432977676491602
Grey's_Anatomy | score:  11.432977676491602
Life_as_We_Know_It | score:  11.432977676491602
The_Ugly_Truth | score:  11.432977676491602
The_Haunting | score:  1.4329771996544434
valid entity pruning
############
 Top candidates after LLM filtering:
The_Haunting | score:  0.7164885999272217
Killers | score:  1e-10
Knocked_Up | score:  1e-10
_____________________
Processing relation: +/award/award_nominee/award_nominations./award/award_nomination/award
@@@@@@@@@@@@@@  
 Top candidates from GNN:
Razzie_Award_for_Worst_Actress | score:  10.982972145180566


KeyboardInterrupt: 

In [None]:
# NOTE(review): `a` is not defined in any visible cell, so this raises
# NameError on a fresh kernel and halts "Run All" here. If that stop is
# intentional, make it explicit; otherwise delete this cell.
a

In [None]:
# Bundle the collected results and persist them with pickle.
# NOTE(review): the key "predctions" looks like a misspelling of "predictions".
# It is part of the stored data, so any reader of this file must use the same
# spelling -- fix both sides together.
dict_t = {"predctions": predictions, "cands": cands, "ground_truths": ground_truths}

with open("attempt_4_1_100_queries.pkl", "wb") as f:
    pkl.dump(dict_t, f)


In [None]:
# NOTE(review): this import sits mid-notebook; moving it into the import cell
# at the top would keep all dependencies visible in one place.
from evaluate_result import evaluate_predictions

# Score the collected predictions against the ground-truth answers.
evaluate_predictions(ground_truths, predictions)

{'accuracy': 0.0, 'f1': 0.0, 'hit@1': 0.0, 'hit@3': 0.0}

In [None]:
# Inspect the `entities` attribute currently held by the reasoning module.
reasoning_module.entities

[5320]

In [None]:
# Count edge rows touching the answers of query 1.
# NOTE(review): `subgraph_data` is whatever the most recently executed sampling
# cell left behind; confirm it was sampled for query 1 before reading the
# counts.
count_path_to_ans(1,  subgraph_data)

Razzie_Award_for_Worst_Picture || count: 0
Academy_Award_for_Best_Picture || count: 0
Tony_Award_for_Best_Musical || count: 0
Razzie_Award_for_Worst_Prequel,_Remake,_Rip-off_or_Sequel || count: 0
Satellite_Award_for_Best_Animated_or_Mixed_Media_Feature || count: 0


In [None]:
# Display the full record of query 1.
queries[1]

{'query_type': ('e', ('r', 'r', 'r')),
 'raw_query': (5320, (12, 172, 38)),
 'named_query': ('Katherine_Heigl',
  ('+/film/actor/film./film/performance/film',
   '+/film/film/production_companies',
   '+/award/award_nominee/award_nominations./award/award_nomination/award')),
 'transformed_query': ['Which awards have been associated with films that Katherine Heigl acted in?',
  'What are the film awards received or nominated for by movies featuring Katherine Heigl?',
  'Can you list the awards for which the films starring Katherine Heigl have been nominated?'],
 'answers_id': [929, 1121, 2859, 1808, 3646],
 'answers': ['Razzie_Award_for_Worst_Picture',
  'Academy_Award_for_Best_Picture',
  'Tony_Award_for_Best_Musical',
  'Razzie_Award_for_Worst_Prequel,_Remake,_Rip-off_or_Sequel',
  'Satellite_Award_for_Best_Animated_or_Mixed_Media_Feature'],
 'natural_query': 'Which awards have the production companies of films featuring Katherine Heigl been nominated for?'}