Skip to content
This repository
Browse code

Merge branch 'master' of github.com:JohnLangford/vowpal_wabbit

  • Loading branch information...
commit 38734e4ff665baaad3bf24527af4983f54626281 2 parents 7f00159 + 91d9285
John authored
18 Makefile.am
@@ -7,17 +7,17 @@ noinst_HEADERS = vowpalwabbit/accumulate.h vowpalwabbit/oaa.h \
7 7 vowpalwabbit/parse_example.h vowpalwabbit/cache.h \
8 8 vowpalwabbit/parse_primitives.h vowpalwabbit/comp_io.h \
9 9 vowpalwabbit/parse_regressor.h vowpalwabbit/constant.h \
10   - vowpalwabbit/parser.h vowpalwabbit/csoaa.h \
  10 + vowpalwabbit/parser.h vowpalwabbit/csoaa.h vowpalwabbit/beam.h \
11 11 vowpalwabbit/searn.h vowpalwabbit/ect.h \
12 12 vowpalwabbit/searn_sequencetask.h vowpalwabbit/example.h \
13   - vowpalwabbit/sender.h vowpalwabbit/gd.h \
14   - vowpalwabbit/sequence.h vowpalwabbit/gd_mf.h \
15   - vowpalwabbit/simple_label.h vowpalwabbit/global_data.h \
16   - vowpalwabbit/sparse_dense.h vowpalwabbit/hash.h \
17   - vowpalwabbit/unique_sort.h vowpalwabbit/io.h \
18   - vowpalwabbit/v_array.h vowpalwabbit/lda_core.h \
19   - vowpalwabbit/v_hashmap.h vowpalwabbit/loss_functions.h \
20   - vowpalwabbit/network.h vowpalwabbit/wap.h vowpalwabbit/noop.h
  13 + vowpalwabbit/sender.h vowpalwabbit/gd.h vowpalwabbit/sequence.h \
  14 + vowpalwabbit/gd_mf.h vowpalwabbit/simple_label.h \
  15 + vowpalwabbit/global_data.h vowpalwabbit/sparse_dense.h \
  16 + vowpalwabbit/hash.h vowpalwabbit/unique_sort.h \
  17 + vowpalwabbit/io.h vowpalwabbit/v_array.h \
  18 + vowpalwabbit/lda_core.h vowpalwabbit/v_hashmap.h \
  19 + vowpalwabbit/loss_functions.h vowpalwabbit/network.h \
  20 + vowpalwabbit/wap.h vowpalwabbit/noop.h
21 21
22 22 ACLOCAL_AMFLAGS = -I acinclude.d
23 23
10 README.hal3
... ... @@ -0,0 +1,10 @@
  1 +To get John's changes:
  2 + git pull git://github.com/JohnLangford/vowpal_wabbit.git master
  3 +
  4 +To enable use of valgrind, do:
  5 + ./configure --enable-profiling
  6 +
  7 +To do fancy learning, do
  8 + --exact_adaptive_norm --power_t 1
  9 +or
  10 + --exact_adaptive_norm --power_t 0.5
2  vowpalwabbit/Makefile.am
@@ -5,7 +5,7 @@ include_HEADERS = allreduce.h
5 5
6 6 bin_PROGRAMS = vw active_interactor
7 7
8   -libvw_la_SOURCES = hash.cc global_data.cc io.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc simple_label.cc oaa.cc ect.cc csoaa.cc v_hashmap.cc wap.cc searn.cc searn_sequencetask.cc sequence.cc parse_example.cc sparse_dense.cc network.cc parse_args.cc accumulate.cc gd.cc lda_core.cc gd_mf.cc bfgs.cc noop.cc example.cc parser.cc loss_functions.cc sender.cc
  8 +libvw_la_SOURCES = hash.cc global_data.cc io.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc simple_label.cc oaa.cc ect.cc csoaa.cc v_hashmap.cc wap.cc beam.cc searn.cc searn_sequencetask.cc sequence.cc parse_example.cc sparse_dense.cc network.cc parse_args.cc accumulate.cc gd.cc lda_core.cc gd_mf.cc bfgs.cc noop.cc example.cc parser.cc vw.cc loss_functions.cc sender.cc
9 9
10 10 vw_SOURCES = vw.cc
11 11
232 vowpalwabbit/beam.cc
... ... @@ -0,0 +1,232 @@
  1 +#include <iostream>
  2 +#include <float.h>
  3 +#include <stdio.h>
  4 +#include <math.h>
  5 +#include "beam.h"
  6 +#include "v_hashmap.h"
  7 +#include "v_array.h"
  8 +
  9 +#define MULTIPLIER 5
  10 +
  11 +using namespace std;
  12 +
  13 +namespace Beam
  14 +{
  15 + int compare_elem(const void *va, const void *vb) {
  16 + // first sort on hash, then on loss
  17 + elem* a = (elem*)va;
  18 + elem* b = (elem*)vb;
  19 + if (a->hash < b->hash) { return -1; }
  20 + if (a->hash > b->hash) { return 1; }
  21 + return b->loss - a->loss; // if b is greater, it should go second
  22 + }
  23 +
  24 + beam::beam(bool (*eq)(state,state), size_t (*hs)(state), size_t max_beam_size) {
  25 + equivalent = eq;
  26 + hash = hs;
  27 + empty_bucket = new v_array<elem>();
  28 + last_retrieved = NULL;
  29 + max_size = max_beam_size;
  30 + losses = (float*)calloc(max_size, sizeof(float));
  31 + dat = new v_hashmap<size_t,bucket>(8, empty_bucket, NULL);
  32 + }
  33 +
  34 + beam::~beam() {
  35 + // TODO: really free the elements
  36 + delete dat;
  37 + free(empty_bucket->begin);
  38 + delete empty_bucket;
  39 + }
  40 +
  41 + size_t hash_bucket(size_t id) { return 1043221*(893901 + id); }
  42 +
  43 + void beam::put(size_t id, state s, size_t hs, size_t act, float loss) {
  44 + elem e = { s, hs, loss, id, last_retrieved, act, true };
  45 + // check to see if we have this bucket yet
  46 + bucket b = dat->get(id, hash_bucket(id));
  47 + if (b->index() > 0) { // this one exists: just add to it
  48 + push(*b, e);
  49 + //dat->put_after_get(id, hash_bucket(id), b);
  50 + if (b->index() >= max_size * MULTIPLIER)
  51 + prune(id);
  52 + } else {
  53 + bucket bnew = new v_array<elem>();
  54 + push(*bnew, e);
  55 + dat->put_after_get(id, hash_bucket(id), bnew);
  56 + }
  57 + }
  58 +
  59 + void beam::put_final(state s, size_t act, float loss) {
  60 + elem e = { s, 0, loss, 0, last_retrieved, act, true };
  61 + push(*final_states, e);
  62 + }
  63 +
  64 + void beam::iterate(size_t id, void (*f)(beam*,size_t,state,float,void*), void*args) {
  65 + bucket b = dat->get(id, hash_bucket(id));
  66 + if (b->index() == 0) return;
  67 +
  68 + cout << "before prune" << endl;
  69 + prune(id);
  70 + cout << "after prune" << endl;
  71 +
  72 + for (elem*e=b->begin; e!=b->end; e++) {
  73 + cout << "element" << endl;
  74 + if (e->alive) {
  75 + last_retrieved = e;
  76 + f(this, id, e->s, e->loss, args);
  77 + }
  78 + }
  79 + }
  80 +
  81 + #define SWAP(a,b) temp=(a);(a)=(b);(b)=temp;
  82 + float quickselect(float *arr, size_t n, size_t k) {
  83 + size_t i,ir,j,l,mid;
  84 + float a,temp;
  85 +
  86 + l=0;
  87 + ir=n-1;
  88 + for(;;) {
  89 + if (ir <= l+1) {
  90 + if (ir == l+1 && arr[ir] < arr[l]) {
  91 + SWAP(arr[l],arr[ir]);
  92 + }
  93 + return arr[k];
  94 + }
  95 + else {
  96 + mid=(l+ir) >> 1;
  97 + SWAP(arr[mid],arr[l+1]);
  98 + if (arr[l] > arr[ir]) {
  99 + SWAP(arr[l],arr[ir]);
  100 + }
  101 + if (arr[l+1] > arr[ir]) {
  102 + SWAP(arr[l+1],arr[ir]);
  103 + }
  104 + if (arr[l] > arr[l+1]) {
  105 + SWAP(arr[l],arr[l+1]);
  106 + }
  107 + i=l+1;
  108 + j=ir;
  109 + a=arr[l+1];
  110 + for (;;) {
  111 + do i++; while (arr[i] < a);
  112 + do j--; while (arr[j] > a);
  113 + if (j < i) break;
  114 + SWAP(arr[i],arr[j]);
  115 + }
  116 + arr[l+1]=arr[j];
  117 + arr[j]=a;
  118 + if (j >= k) ir=j-1;
  119 + if (j <= k) l=i;
  120 + }
  121 + }
  122 + }
  123 +
  124 +
  125 + void beam::prune(size_t id) {
  126 + bucket b = dat->get(id, hash_bucket(id));
  127 + if (b->index() == 0) return;
  128 +
  129 + size_t num_alive = 0;
  130 + if (equivalent == NULL) {
  131 + for (size_t i=1; i<b->index(); i++) {
  132 + (*b)[i].alive = true;
  133 + }
  134 + num_alive = b->index();
  135 + } else {
  136 + // first, sort on hash, backing off to loss
  137 + qsort(b->begin, b->index(), sizeof(elem), compare_elem);
  138 +
  139 + // now, check actual equivalence
  140 + size_t last_pos = 0;
  141 + size_t last_hash = (*b)[0].hash;
  142 + for (size_t i=1; i<b->index(); i++) {
  143 + (*b)[i].alive = true;
  144 + if ((*b)[i].hash != last_hash) {
  145 + last_pos = i;
  146 + last_hash = (*b)[i].hash;
  147 + } else {
  148 + for (size_t j=last_pos; j<i; j++) {
  149 + if ((*b)[j].alive && equivalent((*b)[j].s, (*b)[i].s)) {
  150 + (*b)[i].alive = false;
  151 + break;
  152 + }
  153 + }
  154 + }
  155 +
  156 + if ((*b)[i].alive) {
  157 + losses[num_alive] = (*b)[i].loss;
  158 + num_alive++;
  159 + }
  160 + }
  161 + }
  162 +
  163 + if (num_alive <= max_size) return;
  164 +
  165 + // sort the remaining items on loss
  166 + float cutoff = quickselect(losses, num_alive, max_size);
  167 + bucket bnew = new v_array<elem>();
  168 + for (elem*e=b->begin; e!=b->end; e++) {
  169 + if (e->loss > cutoff) continue;
  170 + push(*bnew, *e);
  171 + num_alive--;
  172 + if (num_alive < 0) break;
  173 + }
  174 + dat->put_after_get(id, hash_bucket(id), bnew);
  175 + }
  176 +
  177 + size_t beam::get_next_bucket(size_t start) {
  178 + size_t next_bucket = 0;
  179 + for (v_hashmap<size_t,bucket>::hash_elem* e=dat->dat.begin; e!=dat->dat.end_array; e++) {
  180 + if (e->occupied) {
  181 + size_t bucket_id = e->key;
  182 + if ((bucket_id > start) && (bucket_id < next_bucket))
  183 + next_bucket = bucket_id;
  184 + }
  185 + }
  186 + return next_bucket;
  187 + }
  188 +
  189 + void beam::get_best_output(std::vector<size_t>* action_seq) {
  190 + action_seq->clear();
  191 + if (final_states->index() == 0) {
  192 + // TODO: error
  193 + return;
  194 + } else {
  195 + elem *bestElem = NULL;
  196 + for (size_t i=0; i<final_states->index(); i++) {
  197 + if ((bestElem == NULL) || ((*final_states)[i].loss < bestElem->loss))
  198 + bestElem = &(*final_states)[i];
  199 + }
  200 + // chase backpointers
  201 + while (bestElem != NULL) {
  202 + std::vector<size_t>::iterator be = action_seq->begin();
  203 + action_seq->insert( be, bestElem->last_action );
  204 + bestElem = bestElem->backpointer;
  205 + }
  206 + }
  207 + }
  208 +
  209 +
  210 + struct test_beam_state {
  211 + size_t id;
  212 + };
  213 + bool state_eq(state a,state b) { return ((test_beam_state*)a)->id == ((test_beam_state*)b)->id; }
  214 + size_t state_hash(state a) { return 381049*(3820+((test_beam_state*)a)->id); }
  215 + void expand_state(beam*b, size_t old_id, state old_state, float old_loss, void*args) {
  216 + test_beam_state* new_state = (test_beam_state*)calloc(1, sizeof(test_beam_state));
  217 + new_state->id = old_id + ((test_beam_state*)old_state)->id * 2;
  218 + float new_loss = old_loss + 0.5;
  219 + cout << "expand_state " << old_loss << " -> " << new_state->id << " , " << new_loss << endl;
  220 + b->put(old_id+1, new_state, 0, new_loss);
  221 + }
  222 + void test_beam() {
  223 + beam*b = new beam(&state_eq, &state_hash, 5);
  224 + for (size_t i=0; i<25; i++) {
  225 + test_beam_state* s = (test_beam_state*)calloc(1, sizeof(test_beam_state));
  226 + s->id = i / 3;
  227 + b->put(0, s, 0, 0. - (float)i);
  228 + cout << "added " << s->id << endl;
  229 + }
  230 + b->iterate(0, expand_state, NULL);
  231 + }
  232 +}
51 vowpalwabbit/beam.h
... ... @@ -0,0 +1,51 @@
  1 +#ifndef BEAM_H
  2 +#define BEAM_H
  3 +
  4 +#include <vector>
  5 +#include <stdio.h>
  6 +#include "v_hashmap.h"
  7 +#include "v_array.h"
  8 +
  9 +typedef void* state;
  10 +
  11 +namespace Beam
  12 +{
  13 + struct elem {
  14 + state s;
  15 + size_t hash;
  16 + float loss;
  17 + size_t bucket_id;
  18 + elem* backpointer;
  19 + size_t last_action;
  20 + bool alive;
  21 + };
  22 +
  23 + typedef v_array<elem>* bucket;
  24 +
  25 + class beam {
  26 + private:
  27 + bool (*equivalent)(state, state);
  28 + size_t (*hash)(state);
  29 +
  30 + v_hashmap<size_t, bucket>* dat;
  31 + bucket final_states;
  32 +
  33 + bucket empty_bucket;
  34 + elem* last_retrieved;
  35 + size_t max_size;
  36 + float* losses;
  37 +
  38 + public:
  39 + beam(bool (*eq)(state,state), size_t (*hs)(state), size_t max_beam_size);
  40 + ~beam();
  41 + void put(size_t id, state s, size_t hs, size_t act, float loss);
  42 + void put(size_t id, state s, size_t act, float loss) { put(id, s, hash(s), act, loss); }
  43 + void put_final(state s, size_t act, float loss);
  44 + void iterate(size_t id, void (*f)(beam*,size_t,state,float,void*), void*);
  45 + void prune(size_t id);
  46 + size_t get_next_bucket(size_t start);
  47 + void get_best_output(std::vector<size_t>*);
  48 + };
  49 +}
  50 +
  51 +#endif
4 vowpalwabbit/parse_example.cc
@@ -95,14 +95,14 @@ class TC_parser {
95 95 }
96 96
97 97 inline void maybeFeature(){
98   - if(*reading_head == ' ' || *reading_head == '|'|| reading_head == endLine ){
  98 + if(*reading_head == ' ' || *reading_head == '|'|| reading_head == endLine ){
99 99 // maybeFeature --> ø
100 100 }else if(*reading_head != ':'){
101 101 // maybeFeature --> 'String' FeatureValue
102 102 substring feature_name ;
103 103 feature_name.begin = reading_head;
104 104 v_array<char> feature_v;
105   - while( !(*reading_head == ' ' || *reading_head == ':' ||*reading_head == '|' ||reading_head == endLine )){
  105 + while( !(*reading_head == ' ' || *reading_head == ':' ||*reading_head == '|' ||reading_head == endLine )){
106 106 if(audit){
107 107 push(feature_v,*reading_head);
108 108 }
71 vowpalwabbit/parse_primitives.h
@@ -108,42 +108,49 @@ inline void print_substring(substring s)
108 108 // in charge of error detection.
109 109 inline float parseFloat(char * p, char **end)
110 110 {
111   - if (!*p || *p == '?')
112   - return 0;
113   - int s = 1;
114   - while (*p == ' ') p++;
115   -
  111 + char* start = p;
  112 +
  113 + if (!*p)
  114 + return 0;
  115 + int s = 1;
  116 + while (*p == ' ') p++;
  117 +
  118 + if (*p == '-') {
  119 + s = -1; p++;
  120 + }
  121 +
  122 + double acc = 0;
  123 + while (*p >= '0' && *p <= '9')
  124 + acc = acc * 10 + *p++ - '0';
  125 +
  126 + int num_dec = 0;
  127 + if (*p == '.') {
  128 + p++;
  129 + while (*p >= '0' && *p <= '9') {
  130 + acc = acc *10 + (*p++ - '0') ;
  131 + num_dec++;
  132 + }
  133 + }
  134 + int exp_acc = 0;
  135 + if(*p == 'e' || *p == 'E'){
  136 + p++;
  137 + int exp_s = 1;
116 138 if (*p == '-') {
117   - s = -1; p++;
  139 + exp_s = -1; p++;
118 140 }
119   -
120   - double acc = 0;
121 141 while (*p >= '0' && *p <= '9')
122   - acc = acc * 10 + *p++ - '0';
123   -
124   - int num_dec = 0;
125   - if (*p == '.') {
126   - p++;
127   - while (*p >= '0' && *p <= '9') {
128   - acc = acc *10 + (*p++ - '0') ;
129   - num_dec++;
130   - }
131   - }
132   - int exp_acc = 0;
133   - if(*p == 'e' || *p == 'E'){
134   - p++;
135   - int exp_s = 1;
136   - if (*p == '-') {
137   - exp_s = -1; p++;
138   - }
139   - while (*p >= '0' && *p <= '9')
140   - exp_acc = exp_acc * 10 + *p++ - '0';
141   - exp_acc *= exp_s;
142   -
  142 + exp_acc = exp_acc * 10 + *p++ - '0';
  143 + exp_acc *= exp_s;
  144 +
  145 + }
  146 + if (*p == ' ')//easy case succeeded.
  147 + {
  148 + acc *= pow(10,exp_acc-num_dec);
  149 + *end = p;
  150 + return s * acc;
143 151 }
144   - acc *= pow(10,exp_acc-num_dec);
145   - *end = p;
146   - return s * acc;
  152 + else
  153 + return strtof(start,end);
147 154 }
148 155
149 156 inline float float_of_substring(substring s)
236 vowpalwabbit/searn.cc
@@ -10,6 +10,7 @@
10 10 #include "oaa.h"
11 11 #include "csoaa.h"
12 12 #include "v_hashmap.h"
  13 +#include "beam.h"
13 14
14 15 // task-specific includes
15 16 #include "searn_sequencetask.h"
@@ -264,187 +265,6 @@ namespace SearnUtil
264 265
265 266 }
266 267
267   -namespace Beam
268   -{
269   - int compare_elem(const void *va, const void *vb) {
270   - // first sort on hash, then on loss
271   - elem* a = (elem*)va;
272   - elem* b = (elem*)vb;
273   - if (a->hash < b->hash) { return -1; }
274   - if (a->hash > b->hash) { return 1; }
275   - return b->loss - a->loss; // if b is greater, it should go second
276   - }
277   -
278   - beam::beam(bool (*eq)(state,state), size_t (*hs)(state), size_t max_beam_size) {
279   - equivalent = eq;
280   - hash = hs;
281   - empty_bucket = v_array<elem>();
282   - last_retrieved = NULL;
283   - max_size = max_beam_size;
284   - losses = (float*)calloc(max_size, sizeof(float));
285   - dat = new v_hashmap<size_t,bucket>(8, empty_bucket, NULL);
286   - }
287   -
288   - beam::~beam() {
289   - // TODO: really free the elements
290   - delete dat;
291   - free(empty_bucket.begin);
292   - }
293   -
294   - size_t hash_bucket(size_t id) { return 1043221*(893901 + id); }
295   -
296   - void beam::put(size_t id, state s, size_t hs, float loss) {
297   - elem e = { s, hs, loss, id, last_retrieved };
298   - // check to see if we have this bucket yet
299   - bucket b = dat->get(id, hash_bucket(id));
300   - if (b.index() > 0) { // this one exists: just add to it
301   - push(b, e);
302   - dat->put_after_get(id, hash_bucket(id), b);
303   - } else {
304   - bucket bnew = v_array<elem>();
305   - push(bnew, e);
306   - dat->put_after_get(id, hash_bucket(id), bnew);
307   - }
308   - }
309   -
310   - void beam::iterate(size_t id, void (*f)(beam*,size_t,state,float)) {
311   - bucket b = dat->get(id, hash_bucket(id));
312   - if (b.index() == 0) return;
313   -
314   - cout << "before prune" << endl;
315   - prune(id);
316   - cout << "after prune" << endl;
317   -
318   - for (elem*e=b.begin; e!=b.end; e++) {
319   - cout << "element" << endl;
320   - if (e->alive) {
321   - last_retrieved = e;
322   - f(this, id, e->s, e->loss);
323   - }
324   - }
325   - }
326   -
327   - #define SWAP(a,b) temp=(a);(a)=(b);(b)=temp;
328   - float quickselect(float *arr, size_t n, size_t k) {
329   - size_t i,ir,j,l,mid;
330   - float a,temp;
331   -
332   - l=0;
333   - ir=n-1;
334   - for(;;) {
335   - if (ir <= l+1) {
336   - if (ir == l+1 && arr[ir] < arr[l]) {
337   - SWAP(arr[l],arr[ir]);
338   - }
339   - return arr[k];
340   - }
341   - else {
342   - mid=(l+ir) >> 1;
343   - SWAP(arr[mid],arr[l+1]);
344   - if (arr[l] > arr[ir]) {
345   - SWAP(arr[l],arr[ir]);
346   - }
347   - if (arr[l+1] > arr[ir]) {
348   - SWAP(arr[l+1],arr[ir]);
349   - }
350   - if (arr[l] > arr[l+1]) {
351   - SWAP(arr[l],arr[l+1]);
352   - }
353   - i=l+1;
354   - j=ir;
355   - a=arr[l+1];
356   - for (;;) {
357   - do i++; while (arr[i] < a);
358   - do j--; while (arr[j] > a);
359   - if (j < i) break;
360   - SWAP(arr[i],arr[j]);
361   - }
362   - arr[l+1]=arr[j];
363   - arr[j]=a;
364   - if (j >= k) ir=j-1;
365   - if (j <= k) l=i;
366   - }
367   - }
368   - }
369   -
370   -
371   - void beam::prune(size_t id) {
372   - bucket b = dat->get(id, hash_bucket(id));
373   - if (b.index() == 0) return;
374   -
375   - size_t num_alive = 0;
376   - if (equivalent == NULL) {
377   - for (size_t i=1; i<b.index(); i++) {
378   - b[i].alive = true;
379   - }
380   - num_alive = b.index();
381   - } else {
382   - // first, sort on hash, backing off to loss
383   - qsort(b.begin, b.index(), sizeof(elem), compare_elem);
384   -
385   - // now, check actual equivalence
386   - size_t last_pos = 0;
387   - size_t last_hash = b[0].hash;
388   - for (size_t i=1; i<b.index(); i++) {
389   - b[i].alive = true;
390   - if (b[i].hash != last_hash) {
391   - last_pos = i;
392   - last_hash = b[i].hash;
393   - } else {
394   - for (size_t j=last_pos; j<i; j++) {
395   - if (b[j].alive && equivalent(b[j].s, b[i].s)) {
396   - b[i].alive = false;
397   - break;
398   - }
399   - }
400   - }
401   -
402   - if (b[i].alive) {
403   - losses[num_alive] = b[i].loss;
404   - num_alive++;
405   - }
406   - }
407   - }
408   -
409   - if (num_alive <= max_size) return;
410   -
411   - // sort the remaining items on loss
412   - float cutoff = quickselect(losses, num_alive, max_size);
413   - bucket bnew = v_array<elem>();
414   - for (elem*e=b.begin; e!=b.end; e++) {
415   - if (e->loss > cutoff) continue;
416   - push(bnew, *e);
417   - num_alive--;
418   - if (num_alive < 0) break;
419   - }
420   - dat->put_after_get(id, hash_bucket(id), bnew);
421   - }
422   -
423   -
424   - struct test_beam_state {
425   - size_t id;
426   - };
427   - bool state_eq(state a,state b) { return ((test_beam_state*)a)->id == ((test_beam_state*)b)->id; }
428   - size_t state_hash(state a) { return 381049*(3820+((test_beam_state*)a)->id); }
429   - void expand_state(beam*b, size_t old_id, state old_state, float old_loss) {
430   - test_beam_state* new_state = (test_beam_state*)calloc(1, sizeof(test_beam_state));
431   - new_state->id = old_id + ((test_beam_state*)old_state)->id * 2;
432   - float new_loss = old_loss + 0.5;
433   - cout << "expand_state " << old_loss << " -> " << new_state->id << " , " << new_loss << endl;
434   - b->put(old_id+1, new_state, new_loss);
435   - }
436   - void test_beam() {
437   - beam*b = new beam(&state_eq, &state_hash, 5);
438   - for (size_t i=0; i<25; i++) {
439   - test_beam_state* s = (test_beam_state*)calloc(1, sizeof(test_beam_state));
440   - s->id = i / 3;
441   - b->put(0, s, 0. - (float)i);
442   - cout << "added " << s->id << endl;
443   - }
444   - b->iterate(0, expand_state);
445   - }
446   -}
447   -
448 268 namespace Searn
449 269 {
450 270 // task stuff
@@ -892,8 +712,7 @@ namespace Searn
892 712 all.finish = finish;
893 713 }
894 714
895   -
896   - size_t searn_predict(vw&all, state s0, size_t step, bool allow_oracle, bool allow_current)
  715 + size_t searn_predict(vw&all, state s0, size_t step, bool allow_oracle, bool allow_current, v_array< pair<size_t,float> >* partial_predictions) // TODO: partial_predictions
897 716 {
898 717 int policy = SearnUtil::random_policy(has_hash ? task.hash(s0) : step, beta, allow_current, current_policy, allow_oracle);
899 718 if (PRINT_DEBUG_INFO) { cerr << "predicing with policy " << policy << " (allow_oracle=" << allow_oracle << ", allow_current=" << allow_current << "), current_policy=" << current_policy << endl; }
@@ -962,7 +781,7 @@ namespace Searn
962 781 return best_action;
963 782 }
964 783 }
965   -
  784 +
966 785 void parallel_rollout(vw&all, state s0)
967 786 {
968 787 // first, make K copies of s0 and step them
@@ -997,7 +816,7 @@ namespace Searn
997 816
998 817 if (action == 0) { // this means we didn't find it or we're not recombining
999 818 if( !rollout_oracle )
1000   - action = searn_predict(all, rollout[k-1].st, step, true, allow_current_policy);
  819 + action = searn_predict(all, rollout[k-1].st, step, true, allow_current_policy, NULL);
1001 820 else
1002 821 action = task.oracle(rollout[k-1].st);
1003 822
@@ -1116,7 +935,7 @@ namespace Searn
1116 935 {
1117 936 int step = 1;
1118 937 while (!task.final(s0)) {
1119   - size_t action = searn_predict(all, s0, step, allow_oracle, allow_current);
  938 + size_t action = searn_predict(all, s0, step, allow_oracle, allow_current, NULL);
1120 939 if (track_actions)
1121 940 action_sequence->push_back(action);
1122 941
@@ -1124,6 +943,49 @@ namespace Searn
1124 943 step++;
1125 944 }
1126 945 }
  946 +
  947 + struct beam_info_struct {
  948 + vw&all;
  949 + bool allow_oracle;
  950 + bool allow_current;
  951 + };
  952 +
  953 + void run_prediction_beam_iter(Beam::beam*b, size_t bucket_id, state s0, float cur_loss, void*args)
  954 + {
  955 + beam_info_struct* bi = (beam_info_struct*)args;
  956 +
  957 + if (task.final(s0)) return;
  958 +
  959 + v_array< pair<size_t,float> > partial_predictions;
  960 + searn_predict(bi->all, s0, bucket_id, bi->allow_oracle, bi->allow_current, &partial_predictions);
  961 + for (size_t i=0; i<partial_predictions.index(); i++) {
  962 + state s1 = task.copy(s0);
  963 + float new_loss = cur_loss + partial_predictions[i].second;
  964 + size_t action = partial_predictions[i].first;
  965 + task.step( s1, action );
  966 + b->put( task.bucket(s1), s1, action, new_loss );
  967 + }
  968 + }
  969 +
  970 + void run_prediction_beam(vw&all, size_t max_beam_size, state s0, bool allow_oracle, bool allow_current, bool track_actions, std::vector<action>* action_sequence)
  971 + {
  972 + Beam::beam *b = new Beam::beam(task.equivalent, task.hash, max_beam_size);
  973 +
  974 + beam_info_struct bi = { all, allow_oracle, allow_current };
  975 +
  976 + b->put(task.bucket(s0), s0, 0, 0.);
  977 + size_t current_bucket = 0;
  978 + while (true) {
  979 + current_bucket = b->get_next_bucket(current_bucket);
  980 + if (current_bucket == 0) break;
  981 + b->iterate(current_bucket, run_prediction_beam_iter, &bi);
  982 + }
  983 +
  984 + if (track_actions && (action_sequence != NULL))
  985 + b->get_best_output(action_sequence);
  986 +
  987 + delete b;
  988 + }
1127 989
1128 990 // void hm_free_state_copies(state s, action a) {
1129 991 // task.finish(s);
@@ -1193,7 +1055,7 @@ namespace Searn
1193 1055
1194 1056 // first, make a prediction (we don't want to bias ourselves if
1195 1057 // we're using the current policy to predict)
1196   - size_t action = searn_predict(all, s0, step, true, allow_current_policy);
  1058 + size_t action = searn_predict(all, s0, step, true, allow_current_policy, NULL);
1197 1059
1198 1060 // generate training example for the current state
1199 1061 generate_state_example(all, s0);
44 vowpalwabbit/searn.h
@@ -37,41 +37,7 @@ namespace SearnUtil
37 37
38 38 void add_history_to_example(vw&, history_info*, example*, history);
39 39 void remove_history_from_example(vw&, history_info *, example*);
40   -}
41   -
42   -namespace Beam
43   -{
44   - struct elem {
45   - state s;
46   - size_t hash;
47   - float loss;
48   - size_t bucket_id;
49   - elem* backpointer;
50   - bool alive;
51   - };
52   -
53   - typedef v_array<elem> bucket;
54   -
55   - class beam {
56   - public:
57   - bool (*equivalent)(state, state);
58   - size_t (*hash)(state);
59   -
60   - bucket empty_bucket;
61   - v_hashmap<size_t, bucket>* dat;
62   - elem* last_retrieved;
63   - size_t max_size;
64   - float* losses;
65   -
66   - beam(bool (*eq)(state,state), size_t (*hs)(state), size_t max_beam_size);
67   - ~beam();
68   - void put(size_t id, state s, size_t hs, float loss);
69   - void put(size_t id, state s, float loss) { put(id, s, hash(s), loss); }
70   - void iterate(size_t id, void (*f)(beam*,size_t,state,float));
71   - void prune(size_t id);
72   - };
73   -}
74   -
  40 +}
75 41
76 42
77 43 namespace Searn
@@ -162,6 +128,14 @@ namespace Searn
162 128 // MUST provide the allowed() function, see below.
163 129 void (*cs_ldf_example)(vw&, state, action, example*&, bool);
164 130
  131 +
  132 + /************************************************************************
  133 + ***************** FUNCTIONS REQUIRED FOR BEAM SEARCH *******************
  134 + ***********************************************************************/
  135 +
  136 + // TODO: document
  137 + size_t (*bucket)(state);
  138 +
165 139 /************************************************************************
166 140 ********************* (MOSTLY) OPTIONAL FUNCTIONS **********************
167 141 ***********************************************************************/
11 vowpalwabbit/sequence.cc
@@ -311,9 +311,14 @@ namespace Sequence {
311 311 if (ec_seq.begin != NULL)
312 312 free(ec_seq.begin);
313 313
314   - if (DEBUG_FORCE_BEAM_ONE || sequence_beam > 1)
315   - for (size_t i=0; i<sequence_k * sequence_beam; i++)
316   - free(hcache[i].total_predictions);
  314 + loss_vector.erase();
  315 + free(loss_vector.begin);
  316 +
  317 + transition_prediction_costs.erase();
  318 + free(transition_prediction_costs.begin);
  319 +
  320 + for (size_t i=0; i<sequence_k * sequence_beam; i++)
  321 + free(hcache[i].total_predictions);
317 322
318 323 free(pred_seq); pred_seq = NULL;
319 324 free(policy_seq); policy_seq = NULL;
21 vowpalwabbit/v_hashmap.h
@@ -9,7 +9,8 @@
9 9
10 10 template<class K, class V> class v_hashmap{
11 11 public:
12   - struct elem {
  12 +
  13 + struct hash_elem {
13 14 bool occupied;
14 15 K key;
15 16 V val;
@@ -19,7 +20,7 @@ template<class K, class V> class v_hashmap{
19 20 bool (*equivalent)(K,K);
20 21 // size_t (*hash)(K);
21 22 V default_value;
22   - v_array<elem> dat;
  23 + v_array<hash_elem> dat;
23 24 size_t last_position;
24 25 size_t num_occupants;
25 26
@@ -29,7 +30,7 @@ template<class K, class V> class v_hashmap{
29 30 }
30 31
31 32 v_hashmap(size_t min_size, V def, bool (*eq)(K,K)) {
32   - dat = v_array<elem>();
  33 + dat = v_array<hash_elem>();
33 34 if (min_size < 1023) min_size = 1023;
34 35 reserve(dat, min_size); // reserve sets to 0 ==> occupied=false
35 36
@@ -49,15 +50,15 @@ template<class K, class V> class v_hashmap{
49 50 }
50 51
51 52 void clear() {
52   - memset(dat.begin, 0, base_size()*sizeof(elem));
  53 + memset(dat.begin, 0, base_size()*sizeof(hash_elem));
53 54 last_position = 0;
54 55 num_occupants = 0;
55 56 }
56 57
57 58 void iter(void (*func)(K,V)) {
58 59 //for (size_t lp=0; lp<base_size(); lp++) {
59   - for (elem* e=dat.begin; e!=dat.end_array; e++) {
60   - //elem* e = dat.begin+lp;
  60 + for (hash_elem* e=dat.begin; e!=dat.end_array; e++) {
  61 + //hash_elem* e = dat.begin+lp;
61 62 if (e->occupied) {
62 63 //printf(" [lp=%d\tocc=%d\thash=%zu]\n", lp, e->occupied, e->hash);
63 64 func(e->key, e->val);
@@ -77,18 +78,18 @@ template<class K, class V> class v_hashmap{
77 78 void double_size() {
78 79 // printf("doubling size!\n");
79 80 // remember the old occupants
80   - v_array<elem>tmp = v_array<elem>();
  81 + v_array<hash_elem>tmp = v_array<hash_elem>();
81 82 reserve(tmp, num_occupants+10);
82   - for (elem* e=dat.begin; e!=dat.end_array; e++)
  83 + for (hash_elem* e=dat.begin; e!=dat.end_array; e++)
83 84 if (e->occupied)
84 85 push(tmp, *e);
85 86
86 87 // double the size and clear
87 88 reserve(dat, base_size()*2);
88   - memset(dat.begin, 0, base_size()*sizeof(elem));
  89 + memset(dat.begin, 0, base_size()*sizeof(hash_elem));
89 90
90 91 // re-insert occupants
91   - for (elem* e=tmp.begin; e!=tmp.end; e++) {
  92 + for (hash_elem* e=tmp.begin; e!=tmp.end; e++) {
92 93 get(e->key, e->hash);
93 94 // std::cerr << "reinserting " << e->key << " at " << last_position << std::endl;
94 95 put_after_get_nogrow(e->key, e->hash, e->val);
3  vowpalwabbit/vw.cc
@@ -18,8 +18,7 @@ embodied in the content of this file are licensed under the BSD
18 18 #include "parse_args.h"
19 19 #include "accumulate.h"
20 20 #include "vw.h"
21   -
22   -#include "searn.cc"
  21 +#include "searn.h"
23 22
24 23 using namespace std;
25 24

0 comments on commit 38734e4

Please sign in to comment.
Something went wrong with that request. Please try again.