Skip to content
This repository
tree: 059bd826b5
Fetching contributors…

Cannot retrieve contributors at this time

file 52 lines (41 sloc) 1.699 kb
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
/*
Copyright (c) by respective owners including Yahoo!, Microsoft, and
individual contributors. All rights reserved. Released under a BSD
license as described in the file LICENSE.
*/
#ifndef CB_H
#define CB_H

// Integer tags selecting the contextual-bandit method variant.
// NOTE(review): the abbreviations presumably stand for Doubly Robust (DR),
// Direct Method (DM) and Inverse Propensity Scoring (IPS) -- confirm against
// the code that consumes these values.
#define CB_TYPE_DR 0
#define CB_TYPE_DM 1
#define CB_TYPE_IPS 2

#include "global_data.h"
#include "parser.h"

//Contextual Bandit module to deal with incomplete cost-sensitive data
//Currently implemented as a reduction to cost-sensitive learning, using the methods discussed in the paper 'Doubly Robust Policy Evaluation and Learning'.

//CB currently works with CSOAA or WAP as the base cost-sensitive (cs) learner
//TODO: extend to handle CSOAA_LDF and WAP_LDF

namespace CB {

  struct cb_class {
    float x; // the cost of this class
    uint32_t weight_index; // the index of this class
    float prob_action; //new for bandit setting, specifies the probability the data collection policy chose this class for importance weighting
    bool operator==(cb_class j){return weight_index == j.weight_index;}
  };

  // Label type of this module; its size is the last entry of cb_label_parser.
  struct label {
    v_array<cb_class> costs; // the cb_class entries (cost, class index, probability) making up this label
  };

  // Declared here, defined elsewhere.
  // NOTE(review): presumably sets up the CB module on `all` from the
  // command-line options in `vm` and the stored options in `vm_file`; the
  // unnamed string vector is not described by this header -- confirm against
  // the definition.
  void parse_flags(vw& all, std::vector<std::string>&, po::variables_map& vm, po::variables_map& vm_file);

  // Declared here, defined elsewhere.
  // NOTE(review): presumably reports the outcome for one processed example --
  // confirm against the definition.
  void output_example(vw& all, example* ec);

  // Label callbacks; all but output_example are collected in cb_label_parser.
  // Each receives the label through an untyped void*.
  // NOTE(review): presumably `v` always points to a CB::label -- confirm
  // against the definitions.
  size_t read_cached_label(shared_data* sd, void* v, io_buf& cache);
  void cache_label(void* v, io_buf& cache);
  void default_label(void* v);
  void parse_label(shared_data* sd, void* v, v_array<substring>& words);
  void delete_label(void* v);
  float weight(void* v);
  float initial(void* v);
  // Callback table for CB labels: the label functions of this namespace
  // followed by the size of the label type. label_parser is declared
  // elsewhere; the initializers are positional, so their order has to match
  // its member order.
  // As a namespace-scope const object defined in a header, every translation
  // unit including this file gets its own copy.
  const label_parser cb_label_parser = {default_label, parse_label,
cache_label, read_cached_label,
delete_label, weight, initial,
sizeof(label)};

}

#endif
Something went wrong with that request. Please try again.