Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a new experimental feature: QueryGraphs #56

Merged
merged 3 commits into from
Apr 15, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions CMakeLists.txt
Expand Up @@ -5,6 +5,12 @@ project(
LANGUAGES CXX C
)

option(EXPERIMENTAL_QUERY_GRAPHS "Enable experimental feature: query graphs" OFF)

if(EXPERIMENTAL_QUERY_GRAPHS)
add_definitions(-DEXPERIMENTAL_QUERY_GRAPHS)
endif()

list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
include(ClangFormat)
include(CheckIPOSupported)
Expand Down
2 changes: 2 additions & 0 deletions libursa/CMakeLists.txt
Expand Up @@ -38,6 +38,8 @@ add_library(
QueryGraph.h
QueryParser.cpp
QueryParser.h
QueryResult.cpp
QueryResult.h
RawFile.cpp
RawFile.h
Responses.cpp
Expand Down
1 change: 1 addition & 0 deletions libursa/Command.h
@@ -1,5 +1,6 @@
#pragma once

#include <set>
#include <variant>

#include "Query.h"
Expand Down
29 changes: 29 additions & 0 deletions libursa/Core.cpp
Expand Up @@ -28,3 +28,32 @@ std::vector<uint8_t> QToken::possible_values() const {
}
return options;
}

std::string get_index_type_name(IndexType type) {
switch (type) {
case IndexType::GRAM3:
return "gram3";
case IndexType::TEXT4:
return "text4";
case IndexType::HASH4:
return "hash4";
case IndexType::WIDE8:
return "wide8";
}

throw std::runtime_error("unhandled index type");
}

std::optional<IndexType> index_type_from_string(const std::string &type) {
if (type == "gram3") {
return IndexType::GRAM3;
} else if (type == "text4") {
return IndexType::TEXT4;
} else if (type == "hash4") {
return IndexType::HASH4;
} else if (type == "wide8") {
return IndexType::WIDE8;
} else {
return std::nullopt;
}
}
5 changes: 5 additions & 0 deletions libursa/Core.h
Expand Up @@ -2,6 +2,7 @@

#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using FileId = uint32_t;
Expand Down Expand Up @@ -39,6 +40,10 @@ constexpr bool is_valid_index_type(uint32_t type) {
return false;
}

std::string get_index_type_name(IndexType type);

std::optional<IndexType> index_type_from_string(const std::string &type);

enum class BuilderType { FLAT = 1, BITMAP = 2 };

enum class QTokenType {
Expand Down
11 changes: 11 additions & 0 deletions libursa/FeatureFlags.h
@@ -0,0 +1,11 @@
#pragma once

namespace feature {

#ifdef EXPERIMENTAL_QUERY_GRAPHS
constexpr bool query_graphs = true;
#else
constexpr bool query_graphs = false;
#endif

} // namespace feature
10 changes: 5 additions & 5 deletions libursa/OnDiskDataset.cpp
Expand Up @@ -87,9 +87,9 @@ std::vector<FileId> internal_pick_common(
}

// fix on that particular value selected in previous step and count
// number of repetitions among heads note that it's
// implementation-defined that std::vector<FileId> is always sorted and
// we use this fact here
// number of repetitions among heads.
// Note that it's implementation-defined that std::vector<FileId>
// is always sorted and we use this fact here.
int repeat_count = 0;
for (int i = min_index; i < static_cast<int>(heads.size()); i++) {
if (*heads[i].first == min_id) {
Expand Down Expand Up @@ -138,13 +138,13 @@ QueryResult OnDiskDataset::pick_common(
return QueryResult::everything();
}
} else {
sources_storage.push_back(result);
sources_storage.push_back(std::move(result));
}
}

// Special case optimization for cutoff==1 and a single source.
if (cutoff == 1 && sources_storage.size() == 1) {
return sources_storage[0];
return std::move(sources_storage[0]);
}

std::vector<const std::vector<FileId> *> sources;
Expand Down
1 change: 1 addition & 0 deletions libursa/OnDiskDataset.h
Expand Up @@ -8,6 +8,7 @@
#include "OnDiskFileIndex.h"
#include "OnDiskIndex.h"
#include "Query.h"
#include "QueryResult.h"
#include "ResultWriter.h"
#include "Task.h"

Expand Down
22 changes: 19 additions & 3 deletions libursa/OnDiskIndex.cpp
Expand Up @@ -5,8 +5,10 @@
#include <fstream>
#include <iostream>

#include "FeatureFlags.h"
#include "Query.h"
#include "Utils.h"
#include "spdlog/spdlog.h"

#pragma pack(1)
struct OnDiskIndexHeader {
Expand Down Expand Up @@ -148,11 +150,25 @@ QueryResult OnDiskIndex::query_str(const QString &str) const {
case IndexType::WIDE8:
input_len = 8;
break;
}
if (input_len == 0) {
throw std::runtime_error("unhandled index type");
default:
throw std::runtime_error("unhandled index type");
}

if (::feature::query_graphs && input_len <= 4) {
spdlog::info("Experimental graph query for {}",
get_index_type_name(index_type()));
QueryGraph graph{QueryGraph::from_qstring(str)};
for (int i = 0; i < input_len - 1; i++) {
spdlog::info("Computing dual graph ({} nodes)", graph.size());
graph = graph.dual();
}
spdlog::info("Final graph has {} nodes", graph.size());
QueryFunc oracle = [this](uint32_t raw_gram) {
uint32_t gram = convert_gram(index_type(), raw_gram);
return query_primitive(gram);
};
return graph.run(oracle);
}
return expand_wildcards(str, input_len, generator);
}

Expand Down
2 changes: 2 additions & 0 deletions libursa/OnDiskIndex.h
Expand Up @@ -9,8 +9,10 @@
#include "Query.h"
#include "RawFile.h"
#include "Task.h"
#include "Utils.h"

struct IndexMergeHelper;
class QueryResult;

class OnDiskIndex {
uint64_t index_size;
Expand Down
29 changes: 2 additions & 27 deletions libursa/Query.cpp
@@ -1,32 +1,5 @@
#include "Query.h"

void QueryResult::do_or(const QueryResult &&other) {
if (this->is_everything() || other.is_everything()) {
has_everything = true;
return;
}

std::vector<FileId> new_results;
std::set_union(other.results.begin(), other.results.end(), results.begin(),
results.end(), std::back_inserter(new_results));
std::swap(new_results, results);
}

void QueryResult::do_and(const QueryResult &&other) {
if (other.is_everything()) {
return;
}
if (this->is_everything()) {
*this = QueryResult(other);
return;
}

auto new_end =
std::set_intersection(other.results.begin(), other.results.end(),
results.begin(), results.end(), results.begin());
results.erase(new_end, results.end());
}

const std::vector<Query> &Query::as_queries() const {
if (type != QueryType::AND && type != QueryType::OR &&
type != QueryType::MIN_OF) {
Expand Down Expand Up @@ -75,6 +48,8 @@ std::ostream &operator<<(std::ostream &os, const Query &query) {
os << ")";
} else if (type == QueryType::PRIMITIVE) {
os << "'" << query.as_string_repr() << "'";
} else {
throw std::runtime_error("Unknown query type.");
}
return os;
}
Expand Down
29 changes: 1 addition & 28 deletions libursa/Query.h
Expand Up @@ -6,37 +6,10 @@
#include <string>
#include <vector>

#include "Core.h"
#include "Utils.h"
#include "QueryGraph.h"

enum QueryType { PRIMITIVE = 1, AND = 2, OR = 3, MIN_OF = 4 };

class QueryResult {
private:
std::vector<FileId> results;
bool has_everything;

QueryResult() : has_everything(true) {}

public:
QueryResult(std::vector<FileId> results)
: results(results), has_everything(false) {}

static QueryResult empty() { return QueryResult(std::vector<FileId>()); }

static QueryResult everything() { return QueryResult(); }

void do_or(const QueryResult &&other);

void do_and(const QueryResult &&other);

// If true, means that QueryResults represents special
// "uninitialized" value, "set of all FileIds in DataSet".
bool is_everything() const { return has_everything; }

const std::vector<FileId> &vector() const { return results; }
};

class Query {
public:
explicit Query(const QString &qstr);
Expand Down
90 changes: 90 additions & 0 deletions libursa/QueryGraph.cpp
@@ -1,6 +1,9 @@
#include "QueryGraph.h"

#include <map>
#include <set>

#include "spdlog/spdlog.h"

QueryGraph QueryGraph::dual() const {
QueryGraph result;
Expand All @@ -13,6 +16,11 @@ QueryGraph QueryGraph::dual() const {
newnodes.emplace(Edge{source, target}, result.make_node(gram));
}
}
for (NodeId source : sources_) {
for (NodeId target : get(source).edges()) {
result.sources_.push_back(newnodes.at(Edge{source, target}));
}
}
for (const auto &[edge, node] : newnodes) {
auto &[from, to] = edge;
for (const auto &target : get(to).edges()) {
Expand All @@ -23,6 +31,84 @@ QueryGraph QueryGraph::dual() const {
return result;
}

class NodeState {
std::vector<NodeId> ready_predecessors_;
uint32_t total_predecessors_;
QueryResult state_;

public:
NodeState() : state_(QueryResult::everything()), total_predecessors_(0) {}

const QueryResult &state() const { return state_; }

const std::vector<NodeId> &ready_predecessors() const {
return ready_predecessors_;
}

void set(QueryResult &&state) { state_ = std::move(state); }

void add_predecessor() { total_predecessors_++; }

void add_ready_predecessor(NodeId ready) {
ready_predecessors_.push_back(ready);
}

bool ready() const {
return ready_predecessors_.size() >= total_predecessors_;
}
};

QueryResult masked_or(std::vector<const QueryResult *> *to_or,
QueryResult &&mask) {
if (to_or->empty()) {
return std::move(mask);
}
QueryResult result{QueryResult::empty()};
for (auto query : *to_or) {
// TODO(msm): we should do everything in parallel here.
QueryResult alternative{query->vector()};
alternative.do_and(mask);
result.do_or(std::move(alternative));
}
return result;
}

QueryResult QueryGraph::run(const QueryFunc &oracle) const {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it not a strange fate that we should suffer so much fear and doubt for so small a thing?

if (sources_.empty()) {
return QueryResult::everything();
}
std::vector<NodeId> ready = sources_;
QueryResult result{QueryResult::empty()};
std::vector<NodeState> states(nodes_.size());
for (size_t ndx = 0; ndx < nodes_.size(); ndx++) {
for (NodeId target : get(NodeId(ndx)).edges()) {
states.at(target.get()).add_predecessor();
}
}
while (ready.size()) {
NodeId nextid = ready.back();
ready.pop_back();
NodeState &next = states[nextid.get()];
std::vector<const QueryResult *> pred_states;
pred_states.reserve(next.ready_predecessors().size());
for (const auto &pred : next.ready_predecessors()) {
pred_states.push_back(&states[pred.get()].state());
}
QueryResult next_state{oracle(get(nextid).gram())};
next.set(std::move(masked_or(&pred_states, std::move(next_state))));
for (const auto &succ : get(nextid).edges()) {
states[succ.get()].add_ready_predecessor(nextid);
if (states[succ.get()].ready()) {
ready.push_back(succ);
}
}
if (get(nextid).edges().size() == 0) {
result.do_or(next.state());
}
}
return result;
}

QueryGraph QueryGraph::from_qstring(const QString &qstr) {
QueryGraph result;

Expand All @@ -36,6 +122,10 @@ QueryGraph QueryGraph::from_qstring(const QString &qstr) {
}
new_sinks.push_back(node);
}
if (result.sources_.empty()) {
spdlog::info("Setting up sources with {} nodes", new_sinks.size());
result.sources_ = new_sinks;
}
sinks = std::move(new_sinks);
}

Expand Down