Skip to content

Commit

Permalink
add tests for distribution utility and fix loose ends
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolajBjorner committed Apr 13, 2023
1 parent 1a70ac7 commit 6249078
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ add_executable(test-z3
datalog_parser.cpp
ddnf.cpp
diff_logic.cpp
distribution.cpp
dl_context.cpp
dl_product_relation.cpp
dl_query.cpp
Expand Down
45 changes: 45 additions & 0 deletions src/test/distribution.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*++
Copyright (c) 2023 Microsoft Corporation
Module Name:
distribution.cpp
Abstract:
Test distribution
Author:
Nikolaj Bjorner (nbjorner) 2023-04-13
--*/
#include "util/distribution.h"
#include <iostream>

static void tst1() {
distribution dist(1);
dist.push(1, 3);
dist.push(2, 1);
dist.push(3, 1);
dist.push(4, 1);

unsigned counts[4] = { 0, 0, 0, 0 };
for (unsigned i = 0; i < 1000; ++i)
counts[dist.choose()-1]++;
for (unsigned i = 1; i <= 4; ++i)
std::cout << "count " << i << ": " << counts[i-1] << "\n";

for (unsigned i = 0; i < 5; ++i) {
std::cout << "enum ";
for (auto j : dist)
std::cout << j << " ";
std::cout << "\n";
}

}

void tst_distribution() {
tst1();
}
1 change: 1 addition & 0 deletions src/test/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,4 +264,5 @@ int main(int argc, char ** argv) {
//TST_ARGV(hs);
TST(finder);
TST(totalizer);
TST(distribution);
}
16 changes: 10 additions & 6 deletions src/util/distribution.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*++
Copyright (c) 2017 Microsoft Corporation
Copyright (c) 2023 Microsoft Corporation
Module Name:
Expand All @@ -18,6 +18,8 @@ Module Name:
Distribution class works by pushing identifiers with associated scores.
After they have been pushed, you can access a random element using choose
or you can enumerate the elements in random order, sorted by the score probability.
Only one iterator can be active at a time because the iterator reshuffles the registered elements.
The requirement is not checked or enforced.
--*/
#pragma once
Expand All @@ -32,10 +34,12 @@ class distribution {

unsigned choose(unsigned sum) {
unsigned s = m_random(sum);
unsigned idx = 0;
for (auto const& [j, score] : m_elems) {
if (s < score)
return j;
return idx;
s -= score;
++idx;
}
UNREACHABLE();
return 0;
Expand Down Expand Up @@ -76,9 +80,8 @@ class distribution {
unsigned m_sum = 0;
unsigned m_index = 0;
void next_index() {
if (0 == m_sz)
return;
m_index = d.choose(m_sum);
if (0 != m_sz)
m_index = d.choose(m_sum);
}
public:
iterator(distribution& d, bool start): d(d), m_sz(start?d.m_elems.size():0), m_sum(d.m_sum) {
Expand All @@ -88,8 +91,9 @@ class distribution {
iterator operator++() {
m_sum -= d.m_elems[m_index].second;
--m_sz;
std::swap(d.m_elems[m_index], d.m_elems[d.m_elems.size() - 1]);
std::swap(d.m_elems[m_index], d.m_elems[m_sz]);
next_index();
return *this;
}
bool operator==(iterator const& other) const { return m_sz == other.m_sz; }
bool operator!=(iterator const& other) const { return m_sz != other.m_sz; }
Expand Down

0 comments on commit 6249078

Please sign in to comment.