Skip to content

Commit

Permalink
Merge 0d6cc4d into 538e9bb
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnLangford committed Jun 5, 2019
2 parents 538e9bb + 0d6cc4d commit d12a82b
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions vowpalwabbit/memory_tree.cc
Expand Up @@ -356,7 +356,7 @@ namespace memory_tree_ns

//train the node with id cn, using the statistics stored in the node to
//formulate a binary classificaiton example.
float train_node(memory_tree& b, single_learner& base, example& ec, const uint32_t cn)
float train_node(memory_tree& b, single_learner& base, example& ec, const uint64_t cn)
{
//predict, learn and predict
//note: here we first train the router and then predict.
Expand Down Expand Up @@ -406,7 +406,7 @@ namespace memory_tree_ns

//turn a leaf into an internal node, and create two children
//when the number of examples is too big
void split_leaf(memory_tree& b, single_learner& base, const uint32_t cn)
void split_leaf(memory_tree& b, single_learner& base, const uint64_t cn)
{
//create two children
b.nodes[cn].internal = 1; //swith to internal node.
Expand Down Expand Up @@ -526,7 +526,7 @@ namespace memory_tree_ns
return (uint32_t)(array_1.size() + array_2.size() - 2*overlap);
}

void collect_labels_from_leaf(memory_tree& b, const uint32_t cn, v_array<uint32_t>& leaf_labs){
void collect_labels_from_leaf(memory_tree& b, const uint64_t cn, v_array<uint32_t>& leaf_labs){
if (b.nodes[cn].internal != -1)
cout<<"something is wrong, it should be a leaf node"<<endl;

Expand All @@ -540,7 +540,7 @@ namespace memory_tree_ns
}
}

inline void train_one_against_some_at_leaf(memory_tree& b, single_learner& base, const uint32_t cn, example& ec){
inline void train_one_against_some_at_leaf(memory_tree& b, single_learner& base, const uint64_t cn, example& ec){
v_array<uint32_t> leaf_labs = v_init<uint32_t>();
collect_labels_from_leaf(b, cn, leaf_labs); //unique labels from the leaf.
MULTILABEL::labels multilabels = ec.l.multilabels;
Expand All @@ -557,7 +557,7 @@ namespace memory_tree_ns
}

inline uint32_t compute_hamming_loss_via_oas(memory_tree& b, single_learner& base,
const uint32_t cn, example& ec, v_array<uint32_t>& selected_labs)
const uint64_t cn, example& ec, v_array<uint32_t>& selected_labs)
{
selected_labs.delete_v();
v_array<uint32_t> leaf_labs = v_init<uint32_t>();
Expand All @@ -579,7 +579,7 @@ namespace memory_tree_ns


//pick up the "closest" example in the leaf using the score function.
int64_t pick_nearest(memory_tree& b, single_learner& base, const uint32_t cn, example& ec)
int64_t pick_nearest(memory_tree& b, single_learner& base, const uint64_t cn, example& ec)
{
if (b.nodes[cn].examples_index.size() > 0)
{
Expand Down Expand Up @@ -769,7 +769,7 @@ namespace memory_tree_ns
return;
}

void route_to_leaf(memory_tree& b, single_learner& base, const uint32_t & ec_array_index, uint64_t cn, v_array<uint32_t>& path, bool insertion){
void route_to_leaf(memory_tree& b, single_learner& base, const uint32_t & ec_array_index, uint64_t cn, v_array<uint64_t>& path, bool insertion){
example& ec = *b.examples[ec_array_index];

MULTICLASS::label_t mc;
Expand Down Expand Up @@ -817,13 +817,13 @@ namespace memory_tree_ns

//we roll in, then stop at a random step, do exploration. //no real insertion happens in the function.
void single_query_and_learn(memory_tree& b, single_learner& base, const uint32_t& ec_array_index, example& ec){
v_array<uint32_t> path_to_leaf = v_init<uint32_t>();
v_array<uint64_t> path_to_leaf = v_init<uint64_t>();
route_to_leaf(b, base, ec_array_index, 0, path_to_leaf, false); //no insertion happens here.

if (path_to_leaf.size() > 1){
//uint32_t random_pos = merand48(b.all->random_state)*(path_to_leaf.size()-1);
uint32_t random_pos = (uint32_t)(merand48(b.all->random_state)*(path_to_leaf.size())); //include leaf
uint32_t cn = path_to_leaf[random_pos];
uint64_t cn = path_to_leaf[random_pos];

if (b.nodes[cn].internal != -1){ //if it's an internal node:'
float objective = 0.f;
Expand Down Expand Up @@ -927,7 +927,7 @@ namespace memory_tree_ns
insert_example(b, base, ec_id); //unsupervised learning
else{
if (b.dream_at_update == false){
v_array<uint32_t> tmp_path = v_init<uint32_t>();
v_array<uint64_t> tmp_path = v_init<uint64_t>();
route_to_leaf(b, base, ec_id, 0, tmp_path, true);
tmp_path.delete_v();
}
Expand Down

0 comments on commit d12a82b

Please sign in to comment.