<a href="https://colab.research.google.com/github/anonymscientist/CQD-SHAP/blob/main/example_usage.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Prerequisites

### Dataset

The datasets used in this project can be downloaded from Google Drive ([Link to Datasets](https://drive.google.com/file/d/1yoZFUAY7DLOj4fC78pIU32SUSAEWRmLw/view?usp=drive_link)) using the following command:

In [1]:
!gdown 1yoZFUAY7DLOj4fC78pIU32SUSAEWRmLw

Downloading...
From (original): https://drive.google.com/uc?id=1yoZFUAY7DLOj4fC78pIU32SUSAEWRmLw
From (redirected): https://drive.google.com/uc?id=1yoZFUAY7DLOj4fC78pIU32SUSAEWRmLw&confirm=t&uuid=6d8ed812-8cce-415d-9120-3053103aa328
To: /content/data.zip
100% 152M/152M [00:02<00:00, 67.8MB/s]


The extracted files will be saved in the `data/` directory.

In [2]:
!unzip -qq data.zip

### Pre-trained Models

The pre-trained models can be downloaded from Google Drive ([Link to Models](https://drive.google.com/file/d/1ot3CuVk4DorVu3JiHKzdumzGNaTREAU3/view?usp=drive_link)) using the following command:

In [3]:
!gdown 1ot3CuVk4DorVu3JiHKzdumzGNaTREAU3

Downloading...
From (original): https://drive.google.com/uc?id=1ot3CuVk4DorVu3JiHKzdumzGNaTREAU3
From (redirected): https://drive.google.com/uc?id=1ot3CuVk4DorVu3JiHKzdumzGNaTREAU3&confirm=t&uuid=a7056398-f991-428a-9d25-cfc9fec90c9e
To: /content/models.zip
100% 1.16G/1.16G [00:11<00:00, 102MB/s]


The torch model files will be saved in the `models/` directory.

In [4]:
!unzip -qq models.zip
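
Before moving on, it can help to confirm that both archives unpacked where the rest of this notebook expects them. The sketch below uses a hypothetical helper, `missing_paths`, which is not part of CQD-SHAP; the listed file names are the ones this notebook reads for the FB15k-237 example.

```python
from pathlib import Path


def missing_paths(root, relative_paths):
    """Return the entries of relative_paths that do not exist under root."""
    root = Path(root)
    return [p for p in relative_paths if not (root / p).exists()]


# Paths read by the FB15k-237 example in this notebook.
expected = [
    "data/FB15k-237/ind2ent.pkl",
    "data/FB15k-237/ind2rel.pkl",
    "data/FB15k-237/train.txt",
    "data/FB15k-237/valid.txt",
    "data/FB15k-237/test.txt",
    "models/FB15k-237-model-rank-1000-epoch-100-1602508358.pt",
]

missing = missing_paths(".", expected)
print("All expected files found." if not missing else f"Missing: {missing}")
```

An empty result means the downloads and extraction steps completed as expected.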

### CQD-SHAP Implementation

Now, we can clone the CQD-SHAP repository:


In [5]:
!git clone https://github.com/anonymscientist/CQD-SHAP

Cloning into 'CQD-SHAP'...
remote: Enumerating objects: 219, done.
remote: Counting objects: 100% (48/48), done.
remote: Compressing objects: 100% (36/36), done.
remote: Total 219 (delta 23), reused 30 (delta 12), pack-reused 171 (from 3)
Receiving objects: 100% (219/219), 119.93 MiB | 23.73 MiB/s, done.
Resolving deltas: 100% (47/47), done.


We move all the files in `CQD-SHAP/` to the current directory:

In [6]:
!mv CQD-SHAP/* .

In [7]:
from symbolic_torch import SymbolicReasoning
from xcqa_torch import XCQA
from utils import (
    get_num_atoms, get_query_file_paths, setup_dataset_and_graphs,
    load_query_datasets, load_all_queries, check_missing_link, compute_rank,
)
from tqdm import tqdm
from shapley import shapley_value
import random
from query import human_readable
random.seed(42)

## Usage

The following example demonstrates how to use the CQD-SHAP implementation for a sample query in the FB15k-237 dataset.

To use the NELL995 dataset instead, change the following variables in the code:
- `data_dir` = `"data/NELL"`
- `model_path` = `"models/NELL-model-rank-1000-epoch-100-1602499096.pt"`

You can also set your desired value for the hyperparameter `k` in the code.

In [8]:
data_dir = "data/FB15k-237"
k = 10
t_norm, t_conorm = "prod", "prod"
model_path = "models/FB15k-237-model-rank-1000-epoch-100-1602508358.pt"
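
For orientation, `"prod"` presumably selects the product t-norm and its dual t-conorm. The sketch below shows the textbook definitions on truth values in [0, 1]. The function names are illustrative and not part of CQD-SHAP, and the library may normalize or combine raw model scores differently.

```python
def prod_t_norm(a, b):
    """Product t-norm: fuzzy AND of two truth values in [0, 1]."""
    return a * b


def prod_t_conorm(a, b):
    """Probabilistic sum, the t-conorm dual to the product t-norm (fuzzy OR)."""
    return a + b - a * b


print(prod_t_norm(0.8, 0.5))    # conjunction of two atom scores
print(prod_t_conorm(0.8, 0.5))  # disjunction of two atom scores
```

The conjunction is never larger than either input, while the disjunction is never smaller, which is the behaviour one expects from fuzzy AND and OR.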

Next, load the dataset along with the training and validation graphs.

In [9]:
dataset, graph_train, graph_valid = setup_dataset_and_graphs(data_dir)

Loaded 14505 nodes from data/FB15k-237/ind2ent.pkl.
Loaded 474 relations from data/FB15k-237/ind2rel.pkl.
Loaded 14951 node titles from data/FB15k-237/extra/entity2text.txt.
Loaded 544230 edges from data/FB15k-237/train.txt, skipped 0 edges due to missing nodes or relations.
Loaded 579300 edges from data/FB15k-237/valid.txt, skipped 0 edges due to missing nodes or relations.


To build the test graph, we need to load the test triples from the dataset and append them to the validation graph.

In [10]:
from graph import Dataset, Graph

graph_test = Graph(dataset)
for edge in graph_valid.get_edges():
    graph_test.add_edge(edge.get_head().get_name(), edge.get_name(), edge.get_tail().get_name(), skip_missing=False, add_reverse=False)
graph_test.load_triples(f'{data_dir}/test.txt', skip_missing=False, add_reverse=True)

Loaded 620232 edges from data/FB15k-237/test.txt, skipped 0 edges due to missing nodes or relations.
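
Conceptually, the test graph is the validation graph plus the held-out test triples. The toy sketch below uses plain Python sets and rests on two assumptions: edges are `(head, relation, tail)` triples, and a reverse edge is stored under a hypothetical `relation + "_reverse"` name. The repository's `Graph` class may encode reverse edges differently.

```python
def with_reverse(triples):
    """Return the triples plus one reversed copy of each (hypothetical naming)."""
    out = set(triples)
    for head, rel, tail in triples:
        out.add((tail, rel + "_reverse", head))
    return out


# Toy validation graph that already contains its reverse edges.
valid_edges = with_reverse({("a", "r", "b")})
# Toy held-out facts, stored without reverse edges.
test_triples = {("b", "s", "c")}

# Copy the validation edges unchanged, then add test triples with their reverses.
test_edges = set(valid_edges) | with_reverse(test_triples)

print(len(valid_edges), len(test_edges))
```

Under these assumptions, copying the existing edges without adding reverses again avoids duplicating them, which is presumably why the cell above passes `add_reverse=False` for the validation edges and `add_reverse=True` for the test triples.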


Our main symbolic reasoner is based on the validation graph. We also define a second symbolic reasoner based on the test graph, in case we want to check answers against the complete graph.

In [11]:
# when evaluating on test queries, we use the validation graph for symbolic reasoning
reasoner = SymbolicReasoning(graph_valid, logging=False)
reasoner_test = SymbolicReasoning(graph_test, logging=False)
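
To illustrate why two reasoners are useful, the toy sketch below answers a two-hop (`2p`) query by exact traversal on an observed graph and on a complete graph. Answers that only appear on the complete graph need at least one missing edge, which matches the usual notion of "hard" answers. All names here (`build_index`, `answer_2p`, the toy triples) are hypothetical and do not reflect the `SymbolicReasoning` API.

```python
from collections import defaultdict


def build_index(triples):
    """Map (head, relation) to the set of tails reachable by one edge."""
    index = defaultdict(set)
    for head, rel, tail in triples:
        index[(head, rel)].add(tail)
    return index


def answer_2p(index, anchor, rel1, rel2):
    """Answers of anchor --rel1--> V --rel2--> ? by exact graph traversal."""
    answers = set()
    for v in index.get((anchor, rel1), set()):
        answers |= index.get((v, rel2), set())
    return answers


# Toy graphs: the complete graph has one extra, held-out edge.
observed = {("anchor", "r1", "v1"), ("v1", "r2", "x")}
complete = observed | {("v1", "r2", "y")}

easy = answer_2p(build_index(observed), "anchor", "r1", "r2")
everything = answer_2p(build_index(complete), "anchor", "r1", "r2")
hard = everything - easy

print(sorted(easy), sorted(hard))
```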

We can load the queries using the following code (here we load all the `test` queries):

In [12]:
query_dataset, query_dataset_hard = load_all_queries(dataset, data_dir, "test")

The `XCQA` class executes a partial query and returns the answers. We create an instance by passing the reasoner to use for the symbolic parts, the dataset, and the path to the neural model.

In [13]:
xcqa = XCQA(symbolic=reasoner, dataset=dataset, logging=False, model_path=model_path)

ComplEx(
  (embeddings): ModuleList(
    (0): Embedding(14505, 2000, sparse=True)
    (1): Embedding(474, 2000, sparse=True)
  )
)


As an example, we pick one query of type `2p`. You can switch to any other query type (`3p`, `2i`, `3i`, `2u`, `up` for $2u1p$, `ip` for $2i1p$, or `pi` for $1p2i$) and any other index.

In [14]:
query_type = "2p"
num_atoms = get_num_atoms(query_type)
queries = query_dataset.get_queries(query_type)
queries_hard = query_dataset_hard.get_queries(query_type)
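
As a rough guide to what `num_atoms` holds for each query type, the lookup below assumes one atom per relation projection in the query. Only the `2p` entry is confirmed by this notebook (its coalition is `[1, 1]`); the other entries and the `num_atoms_sketch` helper are assumptions, and the repository's `get_num_atoms` may be implemented differently.

```python
# Hypothetical lookup: one atom per relation projection in the query.
ATOMS_PER_QUERY_TYPE = {
    "2p": 2,
    "3p": 3,
    "2i": 2,
    "3i": 3,
    "2u": 2,
    "up": 3,  # 2u1p
    "ip": 3,  # 2i1p
    "pi": 3,  # 1p2i
}


def num_atoms_sketch(query_type):
    """Return the assumed atom count, or raise ValueError for unknown types."""
    try:
        return ATOMS_PER_QUERY_TYPE[query_type]
    except KeyError:
        raise ValueError(f"Unknown query type: {query_type!r}") from None


print(num_atoms_sketch("2p"))
```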

In [15]:
idx = 2893
query_complete = queries[idx]
query_hard = queries_hard[idx]
hard_answers = query_hard.get_answer()
all_answers = set(query_complete.get_answer())
easy_answers = [ans for ans in all_answers if ans not in hard_answers]
query_hard

Query(type=2p, query=((14349, (104, 40)),), answer=[2592, 9566, 3779, 10596, 1000, 9866, 7157, 4406, 3896, 6073, 3128, 5118])

In [16]:
human_readable(query_hard, dataset)

Query:
Dixieland	--/music/genre/parent_genre-->	V
V	--/music/genre/artists-->	?

Answer Set (?): 
['Miles Davis', 'Joss Stone', 'Chris Thile', 'Bill Evans', 'Corinne Bailey Rae', 'Toni Braxton', 'Ray Manzarek', 'Natalie Cole', 'Ringo Sheena', 'Bill Wyman', 'Amanda Lear', 'Tom Waits']


In [17]:
human_readable(query_complete, dataset)

Query:
Dixieland	--/music/genre/parent_genre-->	V
V	--/music/genre/artists-->	?

Answer Set (?): 
['Randy Jackson', 'Herb Alpert', 'John Williams', 'Lyle Lovett', 'Miles Davis', 'Thelonious Monk', 'Jerry Garcia', 'Kenny Rogers', 'Billie Holiday', 'Charlie Parker', 'Prince', 'Pat Metheny', 'Roy Haynes', 'Jamie Cullum', 'Freddie Hubbard', 'John Coltrane', 'Amanda Lear', 'Jaco Pastorius', 'James Brown', 'Beastie Boys', 'David Sylvian', 'Mike Patton', 'Steve Winwood', 'Steve Jordan', 'Linda Ronstadt', 'George Michael', 'Dave Brubeck', 'Jon Lord', 'Sonny Rollins', 'Oingo Boingo', 'Henry Mancini', 'Walter Becker', 'Joshua Redman', 'Ray Charles', 'Norah Jones', 'Diana Ross', 'Toni Braxton', 'Keith Jarrett', 'Sting', 'Earth, Wind & Fire', 'Big band', 'Marcus Miller', 'Bruce Hornsby', 'Johnny Mandel', 'Jill Scott', 'Christina Aguilera', 'Curtis Mayfield', 'Ennio Morricone', 'Adele', 'Mamoru Miyano', 'Chris Botti', 'Humberto Gatica', 'The Mothers of Invention', 'Lalo Schifrin', 'Chris Thile', 'P

In [18]:
anchor = query_hard.get_query()[0][0]
relation1 = query_hard.get_query()[0][1][0]
relation2 = query_hard.get_query()[0][1][1]
target1 = 2592 # Miles Davis
target2 = 10596 # Bill Evans

We first execute the query using only the neural model (`coalition=[1, 1]`).

In [19]:
grand_results = xcqa.query_execution(query_hard, k=k, coalition=num_atoms*[1], t_norm=t_norm, t_conorm=t_conorm)
grand_results['title'] = grand_results.index.map(dataset.get_title_by_id)
grand_results['is_easy_answer'] = grand_results.index.isin(easy_answers)
grand_results['is_hard_answer'] = grand_results.index.isin(query_hard.get_answer())
grand_results['variable_title'] = grand_results['variable_0'].map(dataset.get_title_by_id)
grand_results

| entity ID (index) | scores_0 | scores_1 | variable_0 | final_score | title | is_easy_answer | is_hard_answer | variable_title |
|---|---|---|---|---|---|---|---|---|
| 9086 | 5.882746 | 9.785547 | 14349 | 57.565891 | Louis Armstrong | True | False | Dixieland |
| 8522 | 5.882746 | 9.508564 | 14349 | 55.936470 | Bing Crosby | True | False | Dixieland |
| 2653 | 6.066570 | 8.999807 | 6709 | 54.597965 | London Symphony Orchestra | False | False | Popular music |
| 8053 | 6.066570 | 8.891864 | 6709 | 53.943115 | Harry Warren | False | False | Popular music |
| 8071 | 6.066570 | 8.600492 | 6709 | 52.175488 | Barry Manilow | False | False | Popular music |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 14216 | 4.570889 | -1.153239 | 529 | -5.271328 | African popular music | False | False | Pop music |
| 14462 | 4.587668 | -1.153908 | 5629 | -5.293747 | Vocal jazz | False | False | Vocal jazz |
| 11624 | 4.570889 | -1.178043 | 529 | -5.384703 | Announcer | False | False | Pop music |
| 14504 | 4.570889 | -1.182587 | 529 | -5.405474 | Modern architecture | False | False | Pop music |
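
In the rows shown, `final_score` matches the product of `scores_0` and `scores_1`, which is consistent with the `"prod"` t-norm chosen earlier. The quick check below uses three rows copied from the output above. The displayed values are rounded, so the comparison looks at the largest absolute difference rather than exact equality.

```python
# (scores_0, scores_1, final_score) copied from the displayed rows.
rows = [
    (5.882746, 9.785547, 57.565891),
    (5.882746, 9.508564, 55.936470),
    (4.570889, -1.153239, -5.271328),
]

# Largest gap between the recomputed product and the reported final score.
max_error = max(abs(s0 * s1 - final) for s0, s1, final in rows)
print(f"max |scores_0 * scores_1 - final_score| = {max_error:.2e}")
```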


We can compute the rank of an answer using the `compute_rank` function.

In [20]:
rankings = {}
for answer in hard_answers:
    rank = compute_rank(grand_results, all_answers, answer)
    rankings[answer] = rank
rankings = dict(sorted(rankings.items(), key=lambda item: item[1]))
for answer, rank in rankings.items():
    answer_title = dataset.get_title_by_id(answer)
    variable_title = dataset.get_title_by_id(grand_results.loc[answer]['variable_0'])
    print(f"Rank {rank}: {answer} ({answer_title} | {variable_title})")

Rank 16: 6073 (Bill Wyman | Swing music)
Rank 25: 2592 (Miles Davis | Jazz)
Rank 33: 3896 (Ringo Sheena | Big band)
Rank 34: 4406 (Natalie Cole | Traditional pop music)
Rank 44: 10596 (Bill Evans | Jazz)
Rank 92: 3779 (Chris Thile | Jazz)
Rank 148: 9566 (Joss Stone | Pop music)
Rank 163: 3128 (Amanda Lear | Pop music)
Rank 176: 9866 (Toni Braxton | Pop music)
Rank 202: 1000 (Corinne Bailey Rae | Jazz)
Rank 308: 7157 (Ray Manzarek | Jazz)
Rank 367: 5118 (Tom Waits | Jazz)
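
The call above passes `all_answers` alongside the target, which suggests a filtered rank: other known answers are removed from the candidate list before the target's position is counted. The sketch below shows that idea on a plain dictionary of scores. `filtered_rank` is a hypothetical stand-in, and the actual `compute_rank` may handle ties or missing entities differently.

```python
def filtered_rank(scores, known_answers, target):
    """1-based rank of target by descending score, ignoring other known answers."""
    target_score = scores[target]
    better = sum(
        1
        for entity, score in scores.items()
        if entity != target
        and entity not in known_answers
        and score > target_score
    )
    return better + 1


scores = {"a": 0.9, "b": 0.8, "c": 0.7, "d": 0.1}
# "a" is another known answer, so it does not count against "c".
rank_c = filtered_rank(scores, known_answers={"a", "c"}, target="c")
print(rank_c)
```

Filtering keeps a correct answer from being penalized merely because other correct answers score higher.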


To compute the Shapley value of each atom with respect to a target answer, we can use the `shapley_value` function as follows:

In [21]:
target = target1
print(f"Computing shapley values for entity {target} ({dataset.get_title_by_id(target)})")
for atom in range(num_atoms):
    filtered_exclude = all_answers - {target}
    sv = shapley_value(xcqa, query_hard, atom, filtered_exclude, target, "rank", k, t_norm, t_conorm)
    print(f"Shapley value of atom {atom}: {sv}")

Computing shapley values for entity 2592 (Miles Davis)
Shapley value of atom 0: -14.0
Shapley value of atom 1: 7595.0


In [22]:
target = target2
print(f"Computing shapley values for entity {target} ({dataset.get_title_by_id(target)})")
for atom in range(num_atoms):
    filtered_exclude = all_answers - {target}
    sv = shapley_value(xcqa, query_hard, atom, filtered_exclude, target, "rank", k, t_norm, t_conorm)
    print(f"Shapley value of atom {atom}: {sv}")

Computing shapley values for entity 10596 (Bill Evans)
Shapley value of atom 0: -95.5
Shapley value of atom 1: 973.5
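
For intuition about what a per-atom Shapley value measures, the sketch below computes exact Shapley values by enumerating coalitions in a toy two-atom game. Everything here is illustrative: `exact_shapley` and the `toy_value` table are hypothetical, the payoffs are made-up numbers, and the repository's `shapley_value` function defines its own value function (selected by `"rank"` above) over neural and symbolic coalitions.

```python
from itertools import combinations
from math import factorial


def exact_shapley(num_players, value):
    """Exact Shapley values; `value` maps a frozenset of players to a number."""
    players = range(num_players)
    result = []
    for i in players:
        others = [p for p in players if p != i]
        phi = 0.0
        for size in range(len(others) + 1):
            # Standard Shapley weight |S|! (n - |S| - 1)! / n!
            weight = (
                factorial(size)
                * factorial(num_players - size - 1)
                / factorial(num_players)
            )
            for subset in combinations(others, size):
                s = frozenset(subset)
                # Marginal contribution of player i to coalition s.
                phi += weight * (value(s | {i}) - value(s))
        result.append(phi)
    return result


# Toy payoff for each coalition of atoms (made-up numbers).
toy_value = {
    frozenset(): 0.0,
    frozenset({0}): 2.0,
    frozenset({1}): 6.0,
    frozenset({0, 1}): 10.0,
}

phi = exact_shapley(2, toy_value.__getitem__)
print(phi)
```

In this toy game the two values sum to the payoff of the full coalition minus that of the empty one, which is the efficiency property of Shapley values.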
