Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Merge branch 'master' of github.com:JohnLangford/vowpal_wabbit

  • Loading branch information...
commit 38734e4ff665baaad3bf24527af4983f54626281 2 parents 7f00159 + 91d9285
John authored
18 Makefile.am
View
@@ -7,17 +7,17 @@ noinst_HEADERS = vowpalwabbit/accumulate.h vowpalwabbit/oaa.h \
vowpalwabbit/parse_example.h vowpalwabbit/cache.h \
vowpalwabbit/parse_primitives.h vowpalwabbit/comp_io.h \
vowpalwabbit/parse_regressor.h vowpalwabbit/constant.h \
- vowpalwabbit/parser.h vowpalwabbit/csoaa.h \
+ vowpalwabbit/parser.h vowpalwabbit/csoaa.h vowpalwabbit/beam.h \
vowpalwabbit/searn.h vowpalwabbit/ect.h \
vowpalwabbit/searn_sequencetask.h vowpalwabbit/example.h \
- vowpalwabbit/sender.h vowpalwabbit/gd.h \
- vowpalwabbit/sequence.h vowpalwabbit/gd_mf.h \
- vowpalwabbit/simple_label.h vowpalwabbit/global_data.h \
- vowpalwabbit/sparse_dense.h vowpalwabbit/hash.h \
- vowpalwabbit/unique_sort.h vowpalwabbit/io.h \
- vowpalwabbit/v_array.h vowpalwabbit/lda_core.h \
- vowpalwabbit/v_hashmap.h vowpalwabbit/loss_functions.h \
- vowpalwabbit/network.h vowpalwabbit/wap.h vowpalwabbit/noop.h
+ vowpalwabbit/sender.h vowpalwabbit/gd.h vowpalwabbit/sequence.h \
+ vowpalwabbit/gd_mf.h vowpalwabbit/simple_label.h \
+ vowpalwabbit/global_data.h vowpalwabbit/sparse_dense.h \
+ vowpalwabbit/hash.h vowpalwabbit/unique_sort.h \
+ vowpalwabbit/io.h vowpalwabbit/v_array.h \
+ vowpalwabbit/lda_core.h vowpalwabbit/v_hashmap.h \
+ vowpalwabbit/loss_functions.h vowpalwabbit/network.h \
+ vowpalwabbit/wap.h vowpalwabbit/noop.h
ACLOCAL_AMFLAGS = -I acinclude.d
10 README.hal3
View
@@ -0,0 +1,10 @@
+To get John's changes:
+ git pull git://github.com/JohnLangford/vowpal_wabbit.git master
+
+To enable use of valgrind, do:
+ ./configure --enable-profiling
+
+To do fancy learning, do:
+ --exact_adaptive_norm --power_t 1
+or
+ --exact_adaptive_norm --power_t 0.5
2  vowpalwabbit/Makefile.am
View
@@ -5,7 +5,7 @@ include_HEADERS = allreduce.h
bin_PROGRAMS = vw active_interactor
-libvw_la_SOURCES = hash.cc global_data.cc io.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc simple_label.cc oaa.cc ect.cc csoaa.cc v_hashmap.cc wap.cc searn.cc searn_sequencetask.cc sequence.cc parse_example.cc sparse_dense.cc network.cc parse_args.cc accumulate.cc gd.cc lda_core.cc gd_mf.cc bfgs.cc noop.cc example.cc parser.cc loss_functions.cc sender.cc
+libvw_la_SOURCES = hash.cc global_data.cc io.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc simple_label.cc oaa.cc ect.cc csoaa.cc v_hashmap.cc wap.cc beam.cc searn.cc searn_sequencetask.cc sequence.cc parse_example.cc sparse_dense.cc network.cc parse_args.cc accumulate.cc gd.cc lda_core.cc gd_mf.cc bfgs.cc noop.cc example.cc parser.cc vw.cc loss_functions.cc sender.cc
vw_SOURCES = vw.cc
232 vowpalwabbit/beam.cc
View
@@ -0,0 +1,232 @@
+#include <float.h>
+#include <math.h>
+#include <stdio.h>
+#include <algorithm>
+#include <iostream>
+#include <vector>
+#include "beam.h"
+#include "v_hashmap.h"
+#include "v_array.h"
+
+#define MULTIPLIER 5
+
+using namespace std;
+
+namespace Beam
+{
+ // qsort comparator for beam elements.
+ // Orders by hash (ascending) and, among equal hashes, by loss (ascending),
+ // so the cheapest candidate of each hash group comes first.
+ // The losses are compared with relational operators instead of returning
+ // their float difference: converting that difference to int truncates any
+ // value in (-1, 1) to 0, which would report distinct losses as equal.
+ int compare_elem(const void *va, const void *vb) {
+   const elem* a = (const elem*)va;
+   const elem* b = (const elem*)vb;
+   if (a->hash < b->hash) { return -1; }
+   if (a->hash > b->hash) { return 1; }
+   if (a->loss < b->loss) { return -1; } // if b is greater, it should go second
+   if (a->loss > b->loss) { return 1; }
+   return 0;
+ }
+
+ // Builds an empty beam.
+ //   eq            - decides whether two states are equivalent (may be NULL,
+ //                   in which case prune() performs no de-duplication)
+ //   hs            - hash function stored for the put() overload that hashes
+ //                   the state itself
+ //   max_beam_size - stored as max_size and used as the pruning width
+ beam::beam(bool (*eq)(state,state), size_t (*hs)(state), size_t max_beam_size) {
+   equivalent = eq;
+   hash = hs;
+   empty_bucket = new v_array<elem>();
+   // put_final() and get_best_output() dereference final_states, so it has
+   // to exist from construction on.
+   final_states = new v_array<elem>();
+   last_retrieved = NULL;
+   max_size = max_beam_size;
+   losses = (float*)calloc(max_size, sizeof(float));
+   dat = new v_hashmap<size_t,bucket>(8, empty_bucket, NULL);
+ }
+
+ // Releases the containers allocated by the constructor.
+ // The buckets stored inside dat and the states they reference are not
+ // released here.
+ beam::~beam() {
+   // TODO: really free the elements
+   delete dat;
+   free(empty_bucket->begin);
+   delete empty_bucket;
+   free(final_states->begin);
+   delete final_states;
+   free(losses);
+ }
+
+ // Maps a bucket id to the hash value used to look that bucket up in the map.
+ size_t hash_bucket(size_t id) {
+   const size_t shifted = 893901 + id;
+   return 1043221 * shifted;
+ }
+
+ // Adds state s (hash hs, reached through action act at cost loss) to bucket id.
+ // The new element records last_retrieved as its backpointer. A bucket that
+ // reaches max_size * MULTIPLIER elements is pruned right away.
+ void beam::put(size_t id, state s, size_t hs, size_t act, float loss) {
+   const size_t slot = hash_bucket(id);
+   elem fresh = { s, hs, loss, id, last_retrieved, act, true };
+
+   bucket existing = dat->get(id, slot);
+   if (existing->index() == 0) {
+     // nothing stored under this id yet: start a new bucket
+     bucket created = new v_array<elem>();
+     push(*created, fresh);
+     dat->put_after_get(id, slot, created);
+     return;
+   }
+
+   // the bucket is already there: append, and prune once it gets too large
+   push(*existing, fresh);
+   if (existing->index() >= max_size * MULTIPLIER)
+     prune(id);
+ }
+
+ // Records s as a finished state. The element is stored in final_states with
+ // hash 0 and bucket id 0, and with last_retrieved as its backpointer.
+ void beam::put_final(state s, size_t act, float loss) {
+   elem terminal = { s, 0, loss, 0, last_retrieved, act, true };
+   push(*final_states, terminal);
+ }
+
+ // Prunes bucket id and then calls f(this, id, state, loss, args) for every
+ // element of it that is still alive. Before each call last_retrieved is set
+ // to the element being visited, so anything f adds through put() records it
+ // as backpointer. Does nothing when the bucket is empty or missing.
+ void beam::iterate(size_t id, void (*f)(beam*,size_t,state,float,void*), void*args) {
+   if (dat->get(id, hash_bucket(id))->index() == 0) return;
+
+   cout << "before prune" << endl;
+   prune(id);
+   cout << "after prune" << endl;
+
+   // prune() may have stored a replacement bucket under this id, so the
+   // bucket is looked up only now; a pointer taken before pruning would
+   // still walk the unpruned elements.
+   bucket b = dat->get(id, hash_bucket(id));
+   for (elem*e=b->begin; e!=b->end; e++) {
+     cout << "element" << endl;
+     if (e->alive) {
+       last_retrieved = e;
+       f(this, id, e->s, e->loss, args);
+     }
+   }
+ }
+
+ // Returns the k-th smallest value (0-based) among arr[0..n-1].
+ // arr is reordered in place: afterwards arr[k] holds the result, nothing
+ // before it is larger and nothing after it is smaller.
+ // An empty array yields FLT_MAX, and k is clamped to the last index, so
+ // out-of-range requests never read outside the array.
+ float quickselect(float *arr, size_t n, size_t k) {
+   if (n == 0) return FLT_MAX;
+   if (k >= n) k = n - 1;
+   std::nth_element(arr, arr + k, arr + n);
+   return arr[k];
+ }
+
+
+ // Shrinks bucket id to at most max_size live elements.
+ //  1. If an equivalence function was supplied, the bucket is sorted with
+ //     compare_elem and, inside every run of equal hashes, each element that
+ //     is equivalent to an earlier live one is marked dead. Without an
+ //     equivalence function every element is live.
+ //  2. If more than max_size elements are live, the max_size cheapest ones are
+ //     copied into a fresh bucket which replaces the old one in the map.
+ // A max_size of 0 is treated as "no limit" so that the selection index below
+ // cannot wrap around.
+ void beam::prune(size_t id) {
+   bucket b = dat->get(id, hash_bucket(id));
+   const size_t n = b->index();
+   if (n == 0) return;
+
+   // first, sort on hash, backing off to loss
+   if (equivalent != NULL)
+     qsort(b->begin, n, sizeof(elem), compare_elem);
+
+   // Collected per call and sized to the bucket; a fixed buffer of max_size
+   // entries is too small whenever the bucket holds more live elements.
+   std::vector<float> live_losses;
+   live_losses.reserve(n);
+
+   // now, check actual equivalence (element 0 included)
+   size_t run_start = 0;
+   for (size_t i=0; i<n; i++) {
+     elem& cur = (*b)[i];
+     cur.alive = true;
+     if (equivalent != NULL) {
+       if (cur.hash != (*b)[run_start].hash)
+         run_start = i;
+       for (size_t j=run_start; j<i; j++) {
+         if ((*b)[j].alive && equivalent((*b)[j].s, cur.s)) {
+           cur.alive = false;
+           break;
+         }
+       }
+     }
+     if (cur.alive)
+       live_losses.push_back(cur.loss);
+   }
+
+   if (max_size == 0 || live_losses.size() <= max_size) return;
+
+   // The max_size-th smallest loss sits at 0-based index max_size-1.
+   const float cutoff = quickselect(&live_losses[0], live_losses.size(), max_size - 1);
+
+   // Everything strictly cheaper than the cutoff is kept; elements tied at
+   // the cutoff only fill the slots that remain, so ties cannot push a
+   // cheaper element out or grow the bucket past max_size.
+   size_t cheaper = 0;
+   for (elem*e=b->begin; e!=b->end; e++)
+     if (e->alive && e->loss < cutoff) cheaper++;
+   size_t tie_slots = max_size - cheaper;
+
+   bucket bnew = new v_array<elem>();
+   for (elem*e=b->begin; e!=b->end; e++) {
+     if (!e->alive) continue;
+     if (e->loss < cutoff) {
+       push(*bnew, *e);
+     } else if (e->loss == cutoff && tie_slots > 0) {
+       push(*bnew, *e);
+       tie_slots--;
+     }
+   }
+   // The old bucket is deliberately left allocated: iterate() hands out
+   // pointers into bucket storage as backpointers, and releasing the storage
+   // here could leave those dangling.
+   dat->put_after_get(id, hash_bucket(id), bnew);
+ }
+
+ // Returns the smallest occupied bucket id that is strictly greater than
+ // start, or 0 when there is none. 0 is unambiguous as "none" because every
+ // id greater than start is at least 1.
+ size_t beam::get_next_bucket(size_t start) {
+   size_t next_bucket = 0;
+   for (v_hashmap<size_t,bucket>::hash_elem* e=dat->dat.begin; e!=dat->dat.end_array; e++) {
+     if (!e->occupied) continue;
+     size_t bucket_id = e->key;
+     if (bucket_id <= start) continue;
+     // next_bucket == 0 means no candidate has been seen yet; comparing
+     // against that initial 0 alone would reject every id.
+     if (next_bucket == 0 || bucket_id < next_bucket)
+       next_bucket = bucket_id;
+   }
+   return next_bucket;
+ }
+
+ // Fills action_seq with the actions leading to the cheapest final state.
+ // action_seq is cleared first and stays empty when no final state has been
+ // recorded. The actions are gathered by following backpointers from the
+ // winner and inserting at the front, so they end up in forward order.
+ void beam::get_best_output(std::vector<size_t>* action_seq) {
+   action_seq->clear();
+   const size_t n_final = final_states->index();
+   if (n_final == 0) {
+     // TODO: error
+     return;
+   }
+
+   // the first entry wins ties: later ones must be strictly cheaper
+   elem* winner = &(*final_states)[0];
+   for (size_t k=1; k<n_final; k++) {
+     elem* candidate = &(*final_states)[k];
+     if (candidate->loss < winner->loss)
+       winner = candidate;
+   }
+
+   // chase backpointers
+   for (elem* cur=winner; cur!=NULL; cur=cur->backpointer)
+     action_seq->insert( action_seq->begin(), cur->last_action );
+ }
+
+
+ // Minimal state type for the self-test below: a state is just an id.
+ struct test_beam_state {
+   size_t id;
+ };
+ // Two test states are equivalent when their ids match.
+ bool state_eq(state a,state b) { return ((test_beam_state*)a)->id == ((test_beam_state*)b)->id; }
+ // Hash of a test state, derived from its id only.
+ size_t state_hash(state a) { return 381049*(3820+((test_beam_state*)a)->id); }
+ // iterate() callback of the self-test: derives one successor from old_state,
+ // charges 0.5 extra loss and stores it in bucket old_id+1. args is unused.
+ void expand_state(beam*b, size_t old_id, state old_state, float old_loss, void*args) {
+   test_beam_state* new_state = (test_beam_state*)calloc(1, sizeof(test_beam_state));
+   new_state->id = old_id + ((test_beam_state*)old_state)->id * 2;
+   float new_loss = old_loss + 0.5;
+   cout << "expand_state " << old_loss << " -> " << new_state->id << " , " << new_loss << endl;
+   b->put(old_id+1, new_state, 0, new_loss);
+ }
+ // Self-test: fills bucket 0 with 25 states (ids repeat in groups of three,
+ // loss decreasing with i) and expands the bucket once.
+ void test_beam() {
+   beam*b = new beam(&state_eq, &state_hash, 5);
+   for (size_t i=0; i<25; i++) {
+     test_beam_state* s = (test_beam_state*)calloc(1, sizeof(test_beam_state));
+     s->id = i / 3;
+     b->put(0, s, 0, 0. - (float)i);
+     cout << "added " << s->id << endl;
+   }
+   b->iterate(0, expand_state, NULL);
+   // the beam is only needed for the duration of the test
+   delete b;
+ }
+}
51 vowpalwabbit/beam.h
View
@@ -0,0 +1,51 @@
+#ifndef BEAM_H
+#define BEAM_H
+
+#include <vector>
+#include <stdio.h>
+#include "v_hashmap.h"
+#include "v_array.h"
+
+typedef void* state;
+
+namespace Beam
+{
+ struct elem { // one entry of a beam bucket
+ state s; // the search state itself (opaque pointer)
+ size_t hash; // hash of s; prune() only compares states whose hashes match
+ float loss; // accumulated loss of this entry; lower is preferred by get_best_output()
+ size_t bucket_id; // id of the bucket this entry was put into
+ elem* backpointer; // entry that was being visited by iterate() when this one was put, or NULL
+ size_t last_action; // action recorded with this entry; emitted by get_best_output()
+ bool alive; // cleared by prune() for entries equivalent to an earlier live one
+ };
+
+ typedef v_array<elem>* bucket; // a bucket is a pointer to an array of entries
+
+ // Beam of search states, grouped into buckets that are addressed by id.
+ // States are opaque pointers; the beam stores them but, as far as this
+ // header shows, never inspects them except through the two callbacks.
+ class beam {
+ private:
+   bool (*equivalent)(state, state); // state equivalence test, may be NULL
+   size_t (*hash)(state);            // state hash function, may be NULL
+
+   v_hashmap<size_t, bucket>* dat;   // bucket id -> bucket
+   bucket final_states;              // entries added through put_final()
+
+   bucket empty_bucket;              // default value handed to dat
+   elem* last_retrieved;             // entry currently visited by iterate(), or NULL
+   size_t max_size;                  // beam width used by prune()
+   float* losses;                    // scratch buffer of max_size floats
+
+   // The class owns raw pointers that its destructor releases, so copying
+   // an instance would release them twice. Copying is therefore disabled
+   // (declared, never defined).
+   beam(const beam&);
+   beam& operator=(const beam&);
+
+ public:
+   beam(bool (*eq)(state,state), size_t (*hs)(state), size_t max_beam_size);
+   ~beam();
+   // Adds s with an explicitly supplied hash hs.
+   void put(size_t id, state s, size_t hs, size_t act, float loss);
+   // Adds s, hashing it with the stored hash function. Without a hash
+   // function every state gets hash 0 instead of calling a null pointer.
+   void put(size_t id, state s, size_t act, float loss) { put(id, s, hash ? hash(s) : 0, act, loss); }
+   // Records s as a finished state.
+   void put_final(state s, size_t act, float loss);
+   // Prunes bucket id, then calls the callback for each live entry of it;
+   // the last argument is passed through to the callback unchanged.
+   void iterate(size_t id, void (*f)(beam*,size_t,state,float,void*), void*);
+   // Removes equivalent duplicates from bucket id and limits its size.
+   void prune(size_t id);
+   // Looks for the next occupied bucket id after start.
+   size_t get_next_bucket(size_t start);
+   // Writes the action sequence of the cheapest final state into the vector.
+   void get_best_output(std::vector<size_t>*);
+ };
+}
+
+#endif
4 vowpalwabbit/parse_example.cc
View
@@ -95,14 +95,14 @@ class TC_parser {
}
inline void maybeFeature(){
- if(*reading_head == ' ' || *reading_head == '|'|| reading_head == endLine ){
+ if(*reading_head == ' ' || *reading_head == '|'|| reading_head == endLine ){
// maybeFeature --> ø
}else if(*reading_head != ':'){
// maybeFeature --> 'String' FeatureValue
substring feature_name ;
feature_name.begin = reading_head;
v_array<char> feature_v;
- while( !(*reading_head == ' ' || *reading_head == ':' ||*reading_head == '|' ||reading_head == endLine )){
+ while( !(*reading_head == ' ' || *reading_head == ':' ||*reading_head == '|' ||reading_head == endLine )){
if(audit){
push(feature_v,*reading_head);
}
71 vowpalwabbit/parse_primitives.h
View
@@ -108,42 +108,49 @@ inline void print_substring(substring s)
// in charge of error detection.
inline float parseFloat(char * p, char **end)
{
- if (!*p || *p == '?')
- return 0;
- int s = 1;
- while (*p == ' ') p++;
-
+ char* start = p;
+
+ if (!*p)
+ return 0;
+ int s = 1;
+ while (*p == ' ') p++;
+
+ if (*p == '-') {
+ s = -1; p++;
+ }
+
+ double acc = 0;
+ while (*p >= '0' && *p <= '9')
+ acc = acc * 10 + *p++ - '0';
+
+ int num_dec = 0;
+ if (*p == '.') {
+ p++;
+ while (*p >= '0' && *p <= '9') {
+ acc = acc *10 + (*p++ - '0') ;
+ num_dec++;
+ }
+ }
+ int exp_acc = 0;
+ if(*p == 'e' || *p == 'E'){
+ p++;
+ int exp_s = 1;
if (*p == '-') {
- s = -1; p++;
+ exp_s = -1; p++;
}
-
- double acc = 0;
while (*p >= '0' && *p <= '9')
- acc = acc * 10 + *p++ - '0';
-
- int num_dec = 0;
- if (*p == '.') {
- p++;
- while (*p >= '0' && *p <= '9') {
- acc = acc *10 + (*p++ - '0') ;
- num_dec++;
- }
- }
- int exp_acc = 0;
- if(*p == 'e' || *p == 'E'){
- p++;
- int exp_s = 1;
- if (*p == '-') {
- exp_s = -1; p++;
- }
- while (*p >= '0' && *p <= '9')
- exp_acc = exp_acc * 10 + *p++ - '0';
- exp_acc *= exp_s;
-
+ exp_acc = exp_acc * 10 + *p++ - '0';
+ exp_acc *= exp_s;
+
+ }
+ if (*p == ' ')//easy case succeeded.
+ {
+ acc *= pow(10,exp_acc-num_dec);
+ *end = p;
+ return s * acc;
}
- acc *= pow(10,exp_acc-num_dec);
- *end = p;
- return s * acc;
+ else
+ return strtof(start,end);
}
inline float float_of_substring(substring s)
236 vowpalwabbit/searn.cc
View
@@ -10,6 +10,7 @@
#include "oaa.h"
#include "csoaa.h"
#include "v_hashmap.h"
+#include "beam.h"
// task-specific includes
#include "searn_sequencetask.h"
@@ -264,187 +265,6 @@ namespace SearnUtil
}
-namespace Beam
-{
- int compare_elem(const void *va, const void *vb) {
- // first sort on hash, then on loss
- elem* a = (elem*)va;
- elem* b = (elem*)vb;
- if (a->hash < b->hash) { return -1; }
- if (a->hash > b->hash) { return 1; }
- return b->loss - a->loss; // if b is greater, it should go second
- }
-
- beam::beam(bool (*eq)(state,state), size_t (*hs)(state), size_t max_beam_size) {
- equivalent = eq;
- hash = hs;
- empty_bucket = v_array<elem>();
- last_retrieved = NULL;
- max_size = max_beam_size;
- losses = (float*)calloc(max_size, sizeof(float));
- dat = new v_hashmap<size_t,bucket>(8, empty_bucket, NULL);
- }
-
- beam::~beam() {
- // TODO: really free the elements
- delete dat;
- free(empty_bucket.begin);
- }
-
- size_t hash_bucket(size_t id) { return 1043221*(893901 + id); }
-
- void beam::put(size_t id, state s, size_t hs, float loss) {
- elem e = { s, hs, loss, id, last_retrieved };
- // check to see if we have this bucket yet
- bucket b = dat->get(id, hash_bucket(id));
- if (b.index() > 0) { // this one exists: just add to it
- push(b, e);
- dat->put_after_get(id, hash_bucket(id), b);
- } else {
- bucket bnew = v_array<elem>();
- push(bnew, e);
- dat->put_after_get(id, hash_bucket(id), bnew);
- }
- }
-
- void beam::iterate(size_t id, void (*f)(beam*,size_t,state,float)) {
- bucket b = dat->get(id, hash_bucket(id));
- if (b.index() == 0) return;
-
- cout << "before prune" << endl;
- prune(id);
- cout << "after prune" << endl;
-
- for (elem*e=b.begin; e!=b.end; e++) {
- cout << "element" << endl;
- if (e->alive) {
- last_retrieved = e;
- f(this, id, e->s, e->loss);
- }
- }
- }
-
- #define SWAP(a,b) temp=(a);(a)=(b);(b)=temp;
- float quickselect(float *arr, size_t n, size_t k) {
- size_t i,ir,j,l,mid;
- float a,temp;
-
- l=0;
- ir=n-1;
- for(;;) {
- if (ir <= l+1) {
- if (ir == l+1 && arr[ir] < arr[l]) {
- SWAP(arr[l],arr[ir]);
- }
- return arr[k];
- }
- else {
- mid=(l+ir) >> 1;
- SWAP(arr[mid],arr[l+1]);
- if (arr[l] > arr[ir]) {
- SWAP(arr[l],arr[ir]);
- }
- if (arr[l+1] > arr[ir]) {
- SWAP(arr[l+1],arr[ir]);
- }
- if (arr[l] > arr[l+1]) {
- SWAP(arr[l],arr[l+1]);
- }
- i=l+1;
- j=ir;
- a=arr[l+1];
- for (;;) {
- do i++; while (arr[i] < a);
- do j--; while (arr[j] > a);
- if (j < i) break;
- SWAP(arr[i],arr[j]);
- }
- arr[l+1]=arr[j];
- arr[j]=a;
- if (j >= k) ir=j-1;
- if (j <= k) l=i;
- }
- }
- }
-
-
- void beam::prune(size_t id) {
- bucket b = dat->get(id, hash_bucket(id));
- if (b.index() == 0) return;
-
- size_t num_alive = 0;
- if (equivalent == NULL) {
- for (size_t i=1; i<b.index(); i++) {
- b[i].alive = true;
- }
- num_alive = b.index();
- } else {
- // first, sort on hash, backing off to loss
- qsort(b.begin, b.index(), sizeof(elem), compare_elem);
-
- // now, check actual equivalence
- size_t last_pos = 0;
- size_t last_hash = b[0].hash;
- for (size_t i=1; i<b.index(); i++) {
- b[i].alive = true;
- if (b[i].hash != last_hash) {
- last_pos = i;
- last_hash = b[i].hash;
- } else {
- for (size_t j=last_pos; j<i; j++) {
- if (b[j].alive && equivalent(b[j].s, b[i].s)) {
- b[i].alive = false;
- break;
- }
- }
- }
-
- if (b[i].alive) {
- losses[num_alive] = b[i].loss;
- num_alive++;
- }
- }
- }
-
- if (num_alive <= max_size) return;
-
- // sort the remaining items on loss
- float cutoff = quickselect(losses, num_alive, max_size);
- bucket bnew = v_array<elem>();
- for (elem*e=b.begin; e!=b.end; e++) {
- if (e->loss > cutoff) continue;
- push(bnew, *e);
- num_alive--;
- if (num_alive < 0) break;
- }
- dat->put_after_get(id, hash_bucket(id), bnew);
- }
-
-
- struct test_beam_state {
- size_t id;
- };
- bool state_eq(state a,state b) { return ((test_beam_state*)a)->id == ((test_beam_state*)b)->id; }
- size_t state_hash(state a) { return 381049*(3820+((test_beam_state*)a)->id); }
- void expand_state(beam*b, size_t old_id, state old_state, float old_loss) {
- test_beam_state* new_state = (test_beam_state*)calloc(1, sizeof(test_beam_state));
- new_state->id = old_id + ((test_beam_state*)old_state)->id * 2;
- float new_loss = old_loss + 0.5;
- cout << "expand_state " << old_loss << " -> " << new_state->id << " , " << new_loss << endl;
- b->put(old_id+1, new_state, new_loss);
- }
- void test_beam() {
- beam*b = new beam(&state_eq, &state_hash, 5);
- for (size_t i=0; i<25; i++) {
- test_beam_state* s = (test_beam_state*)calloc(1, sizeof(test_beam_state));
- s->id = i / 3;
- b->put(0, s, 0. - (float)i);
- cout << "added " << s->id << endl;
- }
- b->iterate(0, expand_state);
- }
-}
-
namespace Searn
{
// task stuff
@@ -892,8 +712,7 @@ namespace Searn
all.finish = finish;
}
-
- size_t searn_predict(vw&all, state s0, size_t step, bool allow_oracle, bool allow_current)
+ size_t searn_predict(vw&all, state s0, size_t step, bool allow_oracle, bool allow_current, v_array< pair<size_t,float> >* partial_predictions) // TODO: partial_predictions
{
int policy = SearnUtil::random_policy(has_hash ? task.hash(s0) : step, beta, allow_current, current_policy, allow_oracle);
if (PRINT_DEBUG_INFO) { cerr << "predicing with policy " << policy << " (allow_oracle=" << allow_oracle << ", allow_current=" << allow_current << "), current_policy=" << current_policy << endl; }
@@ -962,7 +781,7 @@ namespace Searn
return best_action;
}
}
-
+
void parallel_rollout(vw&all, state s0)
{
// first, make K copies of s0 and step them
@@ -997,7 +816,7 @@ namespace Searn
if (action == 0) { // this means we didn't find it or we're not recombining
if( !rollout_oracle )
- action = searn_predict(all, rollout[k-1].st, step, true, allow_current_policy);
+ action = searn_predict(all, rollout[k-1].st, step, true, allow_current_policy, NULL);
else
action = task.oracle(rollout[k-1].st);
@@ -1116,7 +935,7 @@ namespace Searn
{
int step = 1;
while (!task.final(s0)) {
- size_t action = searn_predict(all, s0, step, allow_oracle, allow_current);
+ size_t action = searn_predict(all, s0, step, allow_oracle, allow_current, NULL);
if (track_actions)
action_sequence->push_back(action);
@@ -1124,6 +943,49 @@ namespace Searn
step++;
}
}
+
+ struct beam_info_struct { // arguments handed to run_prediction_beam_iter through its void* parameter
+ vw&all; // forwarded to searn_predict
+ bool allow_oracle; // forwarded to searn_predict
+ bool allow_current; // forwarded to searn_predict
+ };
+
+ // Beam::beam::iterate callback: expands one non-final state.
+ // args must point to a beam_info_struct. For every (action, loss) pair that
+ // searn_predict leaves in partial_predictions, the state is copied, stepped
+ // with that action and put into the bucket task.bucket() assigns to the
+ // successor, with the pair's loss added to cur_loss.
+ void run_prediction_beam_iter(Beam::beam*b, size_t bucket_id, state s0, float cur_loss, void*args)
+ {
+   if (task.final(s0)) return;
+
+   beam_info_struct* info = (beam_info_struct*)args;
+
+   v_array< pair<size_t,float> > partial_predictions;
+   searn_predict(info->all, s0, bucket_id, info->allow_oracle, info->allow_current, &partial_predictions);
+
+   const size_t n_predictions = partial_predictions.index();
+   for (size_t k=0; k<n_predictions; k++) {
+     state successor = task.copy(s0);
+     const float successor_loss = cur_loss + partial_predictions[k].second;
+     const size_t chosen = partial_predictions[k].first;
+     task.step( successor, chosen );
+     b->put( task.bucket(successor), successor, chosen, successor_loss );
+   }
+ }
+
+ // Runs a beam search of width max_beam_size starting from s0.
+ // s0 is put into its bucket, then buckets are expanded in the order
+ // get_next_bucket() yields them until it returns 0. When track_actions is
+ // set and action_sequence is non-NULL, the best action sequence is written
+ // into it. The beam is destroyed before returning.
+ void run_prediction_beam(vw&all, size_t max_beam_size, state s0, bool allow_oracle, bool allow_current, bool track_actions, std::vector<action>* action_sequence)
+ {
+   beam_info_struct bi = { all, allow_oracle, allow_current };
+   Beam::beam *b = new Beam::beam(task.equivalent, task.hash, max_beam_size);
+
+   b->put(task.bucket(s0), s0, 0, 0.);
+
+   // walk the buckets in increasing id order; 0 signals that none is left
+   for (size_t current_bucket = b->get_next_bucket(0);
+        current_bucket != 0;
+        current_bucket = b->get_next_bucket(current_bucket))
+     b->iterate(current_bucket, run_prediction_beam_iter, &bi);
+
+   const bool want_actions = track_actions && (action_sequence != NULL);
+   if (want_actions)
+     b->get_best_output(action_sequence);
+
+   delete b;
+ }
// void hm_free_state_copies(state s, action a) {
// task.finish(s);
@@ -1193,7 +1055,7 @@ namespace Searn
// first, make a prediction (we don't want to bias ourselves if
// we're using the current policy to predict)
- size_t action = searn_predict(all, s0, step, true, allow_current_policy);
+ size_t action = searn_predict(all, s0, step, true, allow_current_policy, NULL);
// generate training example for the current state
generate_state_example(all, s0);
44 vowpalwabbit/searn.h
View
@@ -37,41 +37,7 @@ namespace SearnUtil
void add_history_to_example(vw&, history_info*, example*, history);
void remove_history_from_example(vw&, history_info *, example*);
-}
-
-namespace Beam
-{
- struct elem {
- state s;
- size_t hash;
- float loss;
- size_t bucket_id;
- elem* backpointer;
- bool alive;
- };
-
- typedef v_array<elem> bucket;
-
- class beam {
- public:
- bool (*equivalent)(state, state);
- size_t (*hash)(state);
-
- bucket empty_bucket;
- v_hashmap<size_t, bucket>* dat;
- elem* last_retrieved;
- size_t max_size;
- float* losses;
-
- beam(bool (*eq)(state,state), size_t (*hs)(state), size_t max_beam_size);
- ~beam();
- void put(size_t id, state s, size_t hs, float loss);
- void put(size_t id, state s, float loss) { put(id, s, hash(s), loss); }
- void iterate(size_t id, void (*f)(beam*,size_t,state,float));
- void prune(size_t id);
- };
-}
-
+}
namespace Searn
@@ -162,6 +128,14 @@ namespace Searn
// MUST provide the allowed() function, see below.
void (*cs_ldf_example)(vw&, state, action, example*&, bool);
+
+ /************************************************************************
+ ***************** FUNCTIONS REQUIRED FOR BEAM SEARCH *******************
+ ***********************************************************************/
+
+ // TODO: document
+ size_t (*bucket)(state);
+
/************************************************************************
********************* (MOSTLY) OPTIONAL FUNCTIONS **********************
***********************************************************************/
11 vowpalwabbit/sequence.cc
View
@@ -311,9 +311,14 @@ namespace Sequence {
if (ec_seq.begin != NULL)
free(ec_seq.begin);
- if (DEBUG_FORCE_BEAM_ONE || sequence_beam > 1)
- for (size_t i=0; i<sequence_k * sequence_beam; i++)
- free(hcache[i].total_predictions);
+ loss_vector.erase();
+ free(loss_vector.begin);
+
+ transition_prediction_costs.erase();
+ free(transition_prediction_costs.begin);
+
+ for (size_t i=0; i<sequence_k * sequence_beam; i++)
+ free(hcache[i].total_predictions);
free(pred_seq); pred_seq = NULL;
free(policy_seq); policy_seq = NULL;
21 vowpalwabbit/v_hashmap.h
View
@@ -9,7 +9,8 @@
template<class K, class V> class v_hashmap{
public:
- struct elem {
+
+ struct hash_elem {
bool occupied;
K key;
V val;
@@ -19,7 +20,7 @@ template<class K, class V> class v_hashmap{
bool (*equivalent)(K,K);
// size_t (*hash)(K);
V default_value;
- v_array<elem> dat;
+ v_array<hash_elem> dat;
size_t last_position;
size_t num_occupants;
@@ -29,7 +30,7 @@ template<class K, class V> class v_hashmap{
}
v_hashmap(size_t min_size, V def, bool (*eq)(K,K)) {
- dat = v_array<elem>();
+ dat = v_array<hash_elem>();
if (min_size < 1023) min_size = 1023;
reserve(dat, min_size); // reserve sets to 0 ==> occupied=false
@@ -49,15 +50,15 @@ template<class K, class V> class v_hashmap{
}
void clear() {
- memset(dat.begin, 0, base_size()*sizeof(elem));
+ memset(dat.begin, 0, base_size()*sizeof(hash_elem));
last_position = 0;
num_occupants = 0;
}
void iter(void (*func)(K,V)) {
//for (size_t lp=0; lp<base_size(); lp++) {
- for (elem* e=dat.begin; e!=dat.end_array; e++) {
- //elem* e = dat.begin+lp;
+ for (hash_elem* e=dat.begin; e!=dat.end_array; e++) {
+ //hash_elem* e = dat.begin+lp;
if (e->occupied) {
//printf(" [lp=%d\tocc=%d\thash=%zu]\n", lp, e->occupied, e->hash);
func(e->key, e->val);
@@ -77,18 +78,18 @@ template<class K, class V> class v_hashmap{
void double_size() {
// printf("doubling size!\n");
// remember the old occupants
- v_array<elem>tmp = v_array<elem>();
+ v_array<hash_elem>tmp = v_array<hash_elem>();
reserve(tmp, num_occupants+10);
- for (elem* e=dat.begin; e!=dat.end_array; e++)
+ for (hash_elem* e=dat.begin; e!=dat.end_array; e++)
if (e->occupied)
push(tmp, *e);
// double the size and clear
reserve(dat, base_size()*2);
- memset(dat.begin, 0, base_size()*sizeof(elem));
+ memset(dat.begin, 0, base_size()*sizeof(hash_elem));
// re-insert occupants
- for (elem* e=tmp.begin; e!=tmp.end; e++) {
+ for (hash_elem* e=tmp.begin; e!=tmp.end; e++) {
get(e->key, e->hash);
// std::cerr << "reinserting " << e->key << " at " << last_position << std::endl;
put_after_get_nogrow(e->key, e->hash, e->val);
3  vowpalwabbit/vw.cc
View
@@ -18,8 +18,7 @@ embodied in the content of this file are licensed under the BSD
#include "parse_args.h"
#include "accumulate.h"
#include "vw.h"
-
-#include "searn.cc"
+#include "searn.h"
using namespace std;
Please sign in to comment.
Something went wrong with that request. Please try again.