Showing 4 changed files with 214 additions and 55 deletions.
@@ -0,0 +1,131 @@
"""
This module is responsible for adding auxiliary helper data to the result proto.
"""

from agatha.topic_query import (
  topic_query_result_pb2 as res_pb,
)
from gensim.models import LdaModel
from gensim.corpora import Dictionary
import numpy as np
from typing import Iterable, Tuple, Dict, Optional
from itertools import combinations
from agatha.util.sqlite3_lookup import Sqlite3Graph, Sqlite3Bow
from collections import defaultdict
from agatha.util.entity_types import is_sentence_type


def estimate_plaintext_from_graph_key(
  graph_key:str,
  graph_db:Sqlite3Graph,
  bow_db:Sqlite3Bow,
  max_checks:Optional[int]=100,
)->str:
  """
  Given a graph key, get the most likely plaintext word associated with it.
  For instance, given "l:noun:cancer" or "m:d009369", we should get something
  like "cancer".
  If `max_checks` is set to `None`, we will check every sentence associated
  with the graph_key; however, we often only need to check a sample.
  """

  word2count = defaultdict(int)

  # Select a subset of sentences to lookup
  sentences_to_check = [graph_key] + graph_db[graph_key]
  sentences_to_check = list(filter(is_sentence_type, sentences_to_check))
  if max_checks is not None:
    sentences_to_check = sentences_to_check[:max_checks]

  # Count how often each word appears across the selected sentences' bags of words
  for neighbor in sentences_to_check:
    if neighbor in bow_db:
      for word in bow_db[neighbor]:
        word2count[word] += 1

  # Return the most frequent word (None if no sentence was found)
  max_count = None
  res = None
  for word, count in word2count.items():
    if max_count is None or count > max_count:
      max_count = count
      res = word
  return res
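# A rough usage sketch (the database paths below are hypothetical placeholders,
# not part of this commit, and we assume the Sqlite3Graph / Sqlite3Bow
# constructors take a database path):
#
#   graph_db = Sqlite3Graph("data/graph.sqlite3")
#   bow_db = Sqlite3Bow("data/bow.sqlite3")
#   estimate_plaintext_from_graph_key(
#     graph_key="l:noun:cancer",
#     graph_db=graph_db,
#     bow_db=bow_db,
#   )
#   # expected to return something like "cancer"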


def _weighted_jacc(a:np.ndarray, b:np.ndarray)->float:
  """Weighted Jaccard similarity: sum of elementwise minima over maxima."""
  return np.sum(np.minimum(a, b)) / np.sum(np.maximum(a, b))
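# Worked example with illustrative values:
#   a = [0.2, 0.0, 0.5], b = [0.1, 0.3, 0.5]
#   sum of minima = 0.1 + 0.0 + 0.5 = 0.6, sum of maxima = 0.2 + 0.3 + 0.5 = 1.0
#   weighted Jaccard = 0.6 / 1.0 = 0.6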


def _all_pairs_jaccard_comparisons(
  graph_idx2vec:Dict[int, np.ndarray]
)->Iterable[Tuple[int, int, float]]:
  """
  Performs all-pairs comparisons within the set of vectors. If there is a
  nonzero similarity between vectors i and j, then we will generate both
  (i, j, sim(i,j)) and (j, i, sim(i,j)).
  Note: similarity is symmetric.
  """

  for i, j in combinations(graph_idx2vec.keys(), 2):
    sim = _weighted_jacc(graph_idx2vec[i], graph_idx2vec[j])
    if sim > 0:
      yield (i, j, sim)
      yield (j, i, sim)
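# For example (hypothetical values): if graph_idx2vec = {0: a, 1: b} and
# _weighted_jacc(a, b) == 0.6, this generator yields (0, 1, 0.6) and (1, 0, 0.6).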


def add_topical_network(
  result:res_pb.TopicQueryResult,
  topic_model:LdaModel,
  dictionary:Dictionary,
  graph_db:Sqlite3Graph,
  bow_db:Sqlite3Bow,
)->None:
  """
  Adds the topical_network field to the result proto.
  Creates this network using the weighted Jaccard similarity between topics.
  The source and target words are assigned indices -1 and -2.
  """
  # Size: n_topics X vocab_size
  term_topic_mat = topic_model.get_topics()
  num_topics, vocab_size = term_topic_mat.shape

  source_word = estimate_plaintext_from_graph_key(
    graph_key=result.source,
    graph_db=graph_db,
    bow_db=bow_db,
  )
  source_word_idx = dictionary.token2id[source_word]
  source_graph_idx = -1
  source_vec = np.zeros(vocab_size)
  source_vec[source_word_idx] = 1

  target_word = estimate_plaintext_from_graph_key(
    graph_key=result.target,
    graph_db=graph_db,
    bow_db=bow_db,
  )
  target_word_idx = dictionary.token2id[target_word]
  target_graph_idx = -2
  target_vec = np.zeros(vocab_size)
  target_vec[target_word_idx] = 1

  graph_idx2vec = {
    topic_idx: term_topic_mat[topic_idx, :]
    for topic_idx in range(num_topics)
  }
  graph_idx2vec[source_graph_idx] = source_vec
  graph_idx2vec[target_graph_idx] = target_vec

  # Set all node names
  for idx in range(num_topics):
    result.topical_network.nodes[idx].name = f"Topic: {idx}"
  result.topical_network.nodes[source_graph_idx].name = \
      f"Source: '{result.source}' -- '{source_word}'"
  result.topical_network.nodes[target_graph_idx].name = \
      f"Target: '{result.target}' -- '{target_word}'"

  # Set all edges:
  for i, j, sim in _all_pairs_jaccard_comparisons(graph_idx2vec):
    result.topical_network.nodes[i].neighbors[j] = sim
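A rough sketch of how add_topical_network might be invoked; the model, dictionary, and database paths here are hypothetical stand-ins for artifacts produced earlier in the Agatha topic-query pipeline, not values from this commit:

from gensim.corpora import Dictionary
from gensim.models import LdaModel
from agatha.topic_query import topic_query_result_pb2 as res_pb
from agatha.util.sqlite3_lookup import Sqlite3Graph, Sqlite3Bow

# Hypothetical inputs; in practice these come from the query pipeline.
result = res_pb.TopicQueryResult()
result.source = "l:noun:cancer"
result.target = "m:d009369"
topic_model = LdaModel.load("model/lda.gensim")
dictionary = Dictionary.load("model/dictionary.gensim")

add_topical_network(
  result=result,
  topic_model=topic_model,
  dictionary=dictionary,
  graph_db=Sqlite3Graph("data/graph.sqlite3"),
  bow_db=Sqlite3Bow("data/bow.sqlite3"),
)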
@@ -1,29 +1,44 @@
 syntax = "proto2";
 package agatha;
 
+
+// A result of the topic_query process
+message TopicQueryResult {
+  // the first word searched
+  optional string source = 1;
+  // the second word searched
+  optional string target = 2;
+  // The path of graph connections that connects source and target
+  repeated string path = 3;
+  // The set of topics produced by this query
+  repeated Topic topics = 4;
+  // The set of sentences that were used to create this query
+  repeated Document documents = 5;
+  // Topical network information
+  optional TopicalNetwork topical_network = 6;
+}
+
+// Helper Messages
+
+// A topic is a probability distribution over a vocabulary
 message Topic {
-  message TermWeight {
-    optional string term = 1;
-    optional float weight = 2;
-  }
-  repeated TermWeight term_weights = 1;
-  optional int32 index = 2;
+  map<string, float> term2weight = 1;
 }
 
+// A document is a single sentence in our subcorpus
 message Document {
-  message TopicWeight {
-    optional int32 topic = 1;
-    optional float weight = 2;
-  }
-  optional string key = 1;
-  repeated TopicWeight topic_weights = 2;
-  repeated string terms = 3;
+  // the identifier of this document
+  optional string doc_id = 1;
+  // The topic mixture of this document. Note that topic2weight[i]
+  // corresponds to TopicQueryResult.topics[i]
+  map<int32, float> topic2weight = 2;
 }
 
-message TopicQueryResult {
-  repeated string path = 1;
-  repeated Document documents = 2;
-  repeated Topic topics = 3;
-  optional string source = 4;
-  optional string target = 5;
+// Stores the topic network information
+message TopicalNetwork {
+  message TopicalNode {
+    map<int32, float> neighbors = 1;
+    optional string name = 2;
+  }
+  map<int32, TopicalNode> nodes = 1;
 }
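To make the new messages concrete, here is a minimal sketch of reading the topical network back out of a populated TopicQueryResult, assuming the standard Python protobuf map-field accessors; the result variable is assumed to come from the query pipeline:

from agatha.topic_query import topic_query_result_pb2 as res_pb

def print_topical_network(result: res_pb.TopicQueryResult) -> None:
  # Each node maps a neighbor index to its weighted Jaccard similarity.
  # Indices -1 and -2 correspond to the source and target words.
  for idx, node in result.topical_network.nodes.items():
    for neighbor_idx, similarity in node.neighbors.items():
      neighbor_name = result.topical_network.nodes[neighbor_idx].name
      print(f"{node.name} -> {neighbor_name}: {similarity:.3f}")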