Adds topic network to topic_query
JSybrandt committed May 18, 2020
1 parent 5abfd85 commit d532bac
Showing 4 changed files with 214 additions and 55 deletions.
76 changes: 43 additions & 33 deletions agatha/topic_query/__main__.py
@@ -5,48 +5,51 @@
 from agatha.topic_query import (
   path_util,
   bow_util,
-  topic_query_result_pb2 as qpb,
-  topic_query_config_pb2 as cpb,
+  topic_query_result_pb2 as res_pb,
+  topic_query_config_pb2 as conf_pb,
+  aux_result_data,
 )
 from agatha.util import entity_types, proto_util
 from agatha.util.sqlite3_lookup import Sqlite3Graph, Sqlite3Bow
 
 
-def assert_conf_has_field(config:cpb.TopicQueryConfig, field:str)->None:
+def assert_conf_has_field(config:conf_pb.TopicQueryConfig, field:str)->None:
   if not config.HasField(field):
     raise ValueError(f"Must supply `{field}` term.")
 
 
 if __name__ == "__main__":
-  config = cpb.TopicQueryConfig()
+  config = conf_pb.TopicQueryConfig()
   proto_util.parse_args_to_config_proto(config)
   print("Running agatha query with the following custom parameters:")
   print(config)
 
   # Query specified
   assert_conf_has_field(config, "source")
   assert_conf_has_field(config, "target")
-  #assert_conf_has_field(config, "result_path")
+  print("Storing result to", config.result_path)
+  assert_conf_has_field(config, "result_path")
 
   # Double check the result path
   result_path = Path(config.result_path)
   if result_path.is_file():
     assert config.override
   else:
     assert not result_path.exists()
     assert result_path.parent.is_dir()
 
-  graph_index = Sqlite3Graph(config.graph_db)
-  bow_index = Sqlite3Bow(config.bow_db)
-
-  assert config.source in graph_index, "Failed to find source in graph_index."
-  assert config.target in graph_index, "Failed to find target in graph_index."
+  # Setup the database indices
+  graph_db = Sqlite3Graph(config.graph_db)
+  assert config.source in graph_db, "Failed to find source in graph_db."
+  assert config.target in graph_db, "Failed to find target in graph_db."
+
+  # Preload the graph
+  if config.preload_graph_db:
+    print("Loading the graph in memory")
+    graph_db.preload()
 
   # Get Path
   print("Finding shortest path")
   path, cached_graph = path_util.get_shortest_path(
-    graph_index=graph_index,
+    graph_index=graph_db,
     source=config.source,
     target=config.target,
     max_degree=config.max_degree,
@@ -55,14 +58,17 @@ def assert_conf_has_field(config:cpb.TopicQueryConfig, field:str)->None:
raise ValueError(f"Path is disconnected, {config.source}, {config.target}")
pprint(path)

##############
# Get text from selected sentences

print("Collecting Nearby Sentences")
sentence_ids = set()
for path_node in path:
print("\t-", path_node)
# Each node along the path is allowed to add some sentences
sentence_ids.update(
path_util.get_nearby_nodes(
graph_index=graph_index,
graph_index=graph_db,
source=path_node,
key_type=entity_types.SENTENCE_TYPE,
max_result_size=config.max_sentences_per_path_elem,
@@ -73,12 +79,12 @@ def assert_conf_has_field(config:cpb.TopicQueryConfig, field:str)->None:
   sentence_ids = list(sentence_ids)
 
   print("Downloading Sentence Text for all", len(sentence_ids), "sentences")
-  # List[List[str]]
+  bow_db = Sqlite3Bow(config.bow_db)
   text_corpus = [
-    bow[s] for s in sentence_ids if s in bow
+    bow_db[s] for s in sentence_ids if s in bow_db
   ]
 
-  print("Identifying potential query-specific stopwords")
+  print("Pruning low-support words")
   min_support = config.topic_model.min_support_count
   term2doc_freq = bow_util.get_document_frequencies(text_corpus)
   stopwords_under = {
@@ -95,20 +101,21 @@ def assert_conf_has_field(config:cpb.TopicQueryConfig, field:str)->None:
   assert len(sentence_ids) == len(text_corpus)
 
   print("Computing topics")
-  word_index = Dictionary(text_corpus)
-  int_corpus = [word_index.doc2bow(t) for t in text_corpus]
+  dictionary = Dictionary(text_corpus)
+  int_corpus = [dictionary.doc2bow(t) for t in text_corpus]
   topic_model = LdaMulticore(
     corpus=int_corpus,
-    id2word=word_index,
+    id2word=dictionary,
     num_topics=config.topic_model.num_topics,
     random_state=config.topic_model.random_seed,
     iterations=config.topic_model.iterations,
   )
 
+  #####################################################
   # Store results
 
   print("Interpreting")
-  result = qpb.TopicQueryResult()
+  result = res_pb.TopicQueryResult()
   result.source = config.source
   result.target = config.target

@@ -118,28 +125,31 @@ def assert_conf_has_field(config:cpb.TopicQueryConfig, field:str)->None:

   # Add documents from topic model
   print("\t- Topics per-document")
-  for key, bow, words in zip(sentence_ids, int_corpus, text_corpus):
+  for doc_id, bow, words in zip(sentence_ids, int_corpus, text_corpus):
     doc = result.documents.add()
-    doc.key = key
-    for word in words:
-      doc.terms.append(word)
+    doc.doc_id = doc_id
     for topic_idx, weight in topic_model[bow]:
-      topic_weight = doc.topic_weights.add()
-      topic_weight.topic = topic_idx
-      topic_weight.weight = weight
+      doc.topic2weight[topic_idx] = weight
 
   # Add topics from topic model
   print("\t- Words per-topic")
   for topic_idx in range(topic_model.num_topics):
     topic = result.topics.add()
-    topic.index = topic_idx
-    for word_index, weight in topic_model.get_topic_terms(
+    for word_idx, weight in topic_model.get_topic_terms(
       topic_idx,
       config.topic_model.truncate_size,
     ):
-      term_weight = topic.term_weights.add()
-      term_weight.term = topic_model.id2word[word_index]
-      term_weight.weight = weight
+      term = topic_model.id2word[word_idx]
+      topic.term2weight[term] = weight
+
+  print("\t- Adding Topical Network")
+  aux_result_data.add_topical_network(
+    result=result,
+    topic_model=topic_model,
+    dictionary=dictionary,
+    graph_db=graph_db,
+    bow_db=bow_db,
+  )
 
   with open(result_path, "wb") as proto_file:
     proto_file.write(result.SerializeToString())
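
Note: the flow above is the standard gensim pattern — build a Dictionary, convert each token list with doc2bow, then fit LdaMulticore. A minimal self-contained sketch (the toy corpus is invented for illustration; in the script itself the documents come from bow_db):

# Sketch only: a toy corpus standing in for the sentence bags-of-words.
from gensim.corpora import Dictionary
from gensim.models import LdaMulticore

text_corpus = [
  ["tobacco", "smoke", "cancer"],
  ["cancer", "cell", "growth"],
  ["tobacco", "leaf", "plant"],
]
dictionary = Dictionary(text_corpus)                        # token <-> integer id
int_corpus = [dictionary.doc2bow(t) for t in text_corpus]   # sparse (id, count) pairs
topic_model = LdaMulticore(
  corpus=int_corpus,
  id2word=dictionary,
  num_topics=2,
  random_state=0,
  iterations=50,
)
# topic_model[bow] yields the (topic_idx, weight) pairs stored in
# Document.topic2weight; get_topic_terms yields (word_id, weight) per topic.
for topic_idx in range(topic_model.num_topics):
  print(topic_model.get_topic_terms(topic_idx, 5))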
131 changes: 131 additions & 0 deletions agatha/topic_query/aux_result_data.py
@@ -0,0 +1,131 @@
"""
This module is responsible for adding auxiliary helper data to the result proto
"""

from agatha.topic_query import (
topic_query_result_pb2 as res_pb,
)
from gensim.models import LdaModel
from gensim.corpora import Dictionary
import numpy as np
from typing import Iterable, Tuple, Dict, Optional
from itertools import combinations
from agatha.util.sqlite3_lookup import Sqlite3Graph, Sqlite3Bow
from collections import defaultdict
from agatha.util.entity_types import is_sentence_type


def estimate_plaintext_from_graph_key(
    graph_key:str,
    graph_db:Sqlite3Graph,
    bow_db:Sqlite3Bow,
    max_checks:Optional[int]=100,
)->Optional[str]:
  """
  Given a graph key, get the most likely plaintext word associated with it.
  For instance, given "l:noun:cancer" or "m:d009369" we should get something
  like "cancer".
  If `max_checks` is set to `None`, we will check every sentence associated
  with the graph_key; however, we often only need to check a sample.
  """

  word2count = defaultdict(int)

  # Select a subset of sentences to lookup
  sentences_to_check = [graph_key] + graph_db[graph_key]
  sentences_to_check = list(filter(is_sentence_type, sentences_to_check))
  if max_checks is not None:
    sentences_to_check = sentences_to_check[:max_checks]

  # Count how often each word appears in the selected sentences
  for neighbor in sentences_to_check:
    if neighbor in bow_db:
      for word in bow_db[neighbor]:
        word2count[word] += 1

  # Return the most frequent word, or None if no sentences were found
  max_count = None
  res = None
  for word, count in word2count.items():
    if max_count is None or count > max_count:
      max_count = count
      res = word
  return res


def _weighted_jacc(a:np.ndarray, b:np.ndarray)->float:
  # Weighted Jaccard similarity: sum of elementwise minima over maxima
  return np.sum(np.minimum(a, b)) / np.sum(np.maximum(a, b))


def _all_pairs_jaccard_comparisons(
    graph_idx2vec:Dict[int, np.ndarray]
)->Iterable[Tuple[int, int, float]]:
  """
  Performs all-pairs comparisons within the set of vectors. If there is a
  nonzero similarity between vectors i and j, then we will generate both
  (i, j, sim(i,j)) and (j, i, sim(i,j)).
  Note: similarity is symmetric.
  """

  for i, j in combinations(graph_idx2vec.keys(), 2):
    sim = _weighted_jacc(graph_idx2vec[i], graph_idx2vec[j])
    if sim > 0:
      yield (i, j, sim)
      yield (j, i, sim)


def add_topical_network(
    result:res_pb.TopicQueryResult,
    topic_model:LdaModel,
    dictionary:Dictionary,
    graph_db:Sqlite3Graph,
    bow_db:Sqlite3Bow,
)->None:
  """
  Adds the topical_network field to the result proto.
  Creates this network by the weighted Jaccard similarity of topics.
  The source and target words are going to be assigned indices -1 and -2.
  """
  # Size: num_topics X vocab_size
  term_topic_mat = topic_model.get_topics()
  num_topics, vocab_size = term_topic_mat.shape

  source_word = estimate_plaintext_from_graph_key(
      graph_key=result.source,
      graph_db=graph_db,
      bow_db=bow_db,
  )
  source_word_idx = dictionary.token2id[source_word]
  source_graph_idx = -1
  source_vec = np.zeros(vocab_size)
  source_vec[source_word_idx] = 1

  target_word = estimate_plaintext_from_graph_key(
      graph_key=result.target,
      graph_db=graph_db,
      bow_db=bow_db,
  )
  target_word_idx = dictionary.token2id[target_word]
  target_graph_idx = -2
  target_vec = np.zeros(vocab_size)
  target_vec[target_word_idx] = 1

  # Each topic, plus the source and target one-hot vectors, becomes a node
  graph_idx2vec = {
      topic_idx: term_topic_mat[topic_idx, :]
      for topic_idx in range(num_topics)
  }
  graph_idx2vec[source_graph_idx] = source_vec
  graph_idx2vec[target_graph_idx] = target_vec

  # Set all node names
  for idx in range(num_topics):
    result.topical_network.nodes[idx].name = f"Topic: {idx}"
  result.topical_network.nodes[source_graph_idx].name = \
      f"Source: '{result.source}' -- '{source_word}'"
  result.topical_network.nodes[target_graph_idx].name = \
      f"Target: '{result.target}' -- '{target_word}'"

  # Set all edges
  for i, j, sim in _all_pairs_jaccard_comparisons(graph_idx2vec):
    result.topical_network.nodes[i].neighbors[j] = sim
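
For intuition (not part of the commit): _weighted_jacc implements the weighted Jaccard similarity, sum_i min(a_i, b_i) / sum_i max(a_i, b_i). A small worked example with invented vectors:

# Worked example of the weighted Jaccard edge weight used above.
import numpy as np

a = np.array([0.5, 0.3, 0.0, 0.2])  # e.g. one topic's term distribution
b = np.array([0.4, 0.0, 0.4, 0.2])  # another topic over the same vocabulary
# minima: [0.4, 0.0, 0.0, 0.2] -> 0.6 ; maxima: [0.5, 0.3, 0.4, 0.2] -> 1.4
print(np.sum(np.minimum(a, b)) / np.sum(np.maximum(a, b)))  # 0.6 / 1.4 ~= 0.43

Since each LDA topic distribution sums to one, comparing a topic against the one-hot source or target vector reduces to t_w / (2 - t_w), where t_w is the topic's weight on that single word; the edge weight is driven entirely by how much mass the topic puts on the query word.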


9 changes: 6 additions & 3 deletions agatha/topic_query/topic_query_config.proto
@@ -4,10 +4,10 @@ package agatha;
 // Config used to perform topic model queries
 message TopicQueryConfig {
 
-  // The first term of this query
+  // The first term of this query. This is a graph key
   optional string source = 1;
 
-  // The second term of this query
+  // The second term of this query. This is a graph key
   optional string target = 2;
 
   // The path to the graph sqlite database used to run this query
@@ -36,6 +36,9 @@ message TopicQueryConfig {
   // is allowed to have. High degree nodes will be downsampled to this rate. A
   // higher value indicates a more accurate shortest path, at a longer runtime.
   optional int32 max_degree = 10 [default=1000];
+
+  // If set, we will preload the graph database in memory before the query
+  optional bool preload_graph_db = 11 [default=false];
 }


@@ -54,7 +57,7 @@ message LdaConfig {
   optional int32 iterations = 3 [default=50];
 
   // Remove any word that does not occur at least X times
-  optional int32 min_support_count = 4 [default=2];
+  optional int32 min_support_count = 4 [default=0];
 
   // Take the top X words per-topic; only affects output
   optional int32 truncate_size = 7 [default=250];
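
A hypothetical usage sketch for the new field. Field names come from the proto above (proto_util.parse_args_to_config_proto fills this same message from the command line in __main__.py); the key and path values are invented examples:

# Sketch: building a TopicQueryConfig directly in Python.
from agatha.topic_query import topic_query_config_pb2 as conf_pb

config = conf_pb.TopicQueryConfig()
config.source = "l:noun:cancer"    # graph keys, per the docstring in aux_result_data.py
config.target = "m:d009369"
config.graph_db = "graph.sqlite3"  # example paths
config.bow_db = "bow.sqlite3"
config.result_path = "result.pb"
config.preload_graph_db = True     # new in this commit; defaults to false

assert config.HasField("source") and config.HasField("target")
print(config.max_degree)                     # 1000 (proto2 default)
print(config.topic_model.min_support_count)  # 0 (default changed from 2)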
53 changes: 34 additions & 19 deletions agatha/topic_query/topic_query_result.proto
@@ -1,29 +1,44 @@
syntax = "proto2";
package agatha;


// A result of the topic_query process
message TopicQueryResult {
// the first word searched
optional string source = 1;
// the second word searched
optional string target = 2;
// The path of graph connections that connections source and target
repeated string path = 3;
// The set of topics produced by this query
repeated Topic topics = 4;
// The set of sentences that were used to create this query
repeated Document documents = 5;
// Topical network information
optional TopicalNetwork topical_network = 6;
}

// Helper Messages

// A topic is a probability distribution over a vocabulary
message Topic {
message TermWeight {
optional string term = 1;
optional float weight = 2;
}
repeated TermWeight term_weights = 1;
optional int32 index = 2;
map<string, float> term2weight = 1;
}

// A document is a single sentence in our subcorpus
message Document {
message TopicWeight {
optional int32 topic = 1;
optional float weight = 2;
}
optional string key = 1;
repeated TopicWeight topic_weights = 2;
repeated string terms = 3;
// the identifier of this document
optional string doc_id = 1;
// The topic mixture of this document. Note that topic2weight[i]
// corresponds to TopicQueryResult.topics[i]
map<int32, float> topic2weight = 2;
}

message TopicQueryResult {
repeated string path = 1;
repeated Document documents = 2;
repeated Topic topics = 3;
optional string source = 4;
optional string target = 5;
// Stores the topic network information
message TopicalNetwork {
message TopicalNode {
map<int32, float> neighbors = 1;
optional string name = 2;
}
map<int32, TopicalNode> nodes = 1;
}
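
A hypothetical reader sketch for the reshaped result message (the file path is an invented example; protobuf map fields behave like Python dicts):

# Sketch: consuming a serialized TopicQueryResult.
from agatha.topic_query import topic_query_result_pb2 as res_pb

result = res_pb.TopicQueryResult()
with open("result.pb", "rb") as f:  # written by __main__.py above
  result.ParseFromString(f.read())

# Topics no longer carry an explicit index field; position in the list is the
# id, and term2weight replaces the repeated TermWeight submessage.
for topic_idx, topic in enumerate(result.topics):
  top5 = sorted(topic.term2weight.items(), key=lambda kv: -kv[1])[:5]
  print(f"Topic {topic_idx}:", top5)

# Document.topic2weight[i] refers to result.topics[i].
for doc in result.documents:
  print(doc.doc_id, dict(doc.topic2weight))

# Topical network: node ids -1 and -2 are the source and target words.
for node_id, node in result.topical_network.nodes.items():
  print(node_id, node.name, dict(node.neighbors))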
