forked from kirnhans/15418-project
/
DecisionTree.h
65 lines (51 loc) · 1.27 KB
/
DecisionTree.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include <stdio.h>
#include <stdlib.h>
#ifndef DECISIONTREE_H
#define DECISIONTREE_H
/*
 * One node of a binary classification tree.
 *
 * NOTE(review): keep the field order stable (append new fields at the end);
 * other translation units may brace-initialize this struct or depend on its
 * layout.
 */
struct node {
    // Splitting rule. NOTE(review): presumably the index of the feature used
    // for the split and the index of the chosen split value; confirm in grow().
    int split_var_idx;
    int split_val_idx;

    // NOTE(review): presumably the number of training rows reaching this node.
    int size;

    // Integer used as a flag; presumably nonzero when this node is a leaf.
    int is_terminal;

    // Class probability estimates at this node:
    // P(class = 0 | data) and P(class = 1 | data).
    double p_0, p_1;

    // Gini impurity of the node (the default criterion in both sklearn and
    // R's randomForest, which is why it was chosen).
    double impurity;

    // Per-node working lists. NOTE(review): presumably attribute values,
    // class labels and row ids of the samples at this node; dimensions and
    // ownership are not visible in this header -- confirm in the
    // implementation file.
    double **attribute_value_list;
    int **class_label_list;
    int **rid_list;

    // Child subtrees. NOTE(review): presumably null for terminal nodes.
    struct node *left, *right;
};
/*
 * Binary classification tree trained on an in-memory dataset.
 *
 * NOTE(review): the object holds raw pointers (root, train_data, train_y,
 * device_data, device_labels) and declares a destructor but no copy
 * constructor / copy assignment. If the destructor frees the tree, copying an
 * instance would share and double-free `root` -- avoid copies until copy
 * control is added in step with the implementation file.
 */
class DecisionTree {
public:
    // Build a tree over n training rows with p features each.
    // NOTE(review): train_data presumably holds n*p doubles (row- vs
    // column-major is not visible here) and train_y holds n class labels;
    // confirm in the implementation.
    DecisionTree(double* train_data, int* train_y, int n, int p);

    // As above, with explicit tuning parameters. NOTE(review): the names
    // mirror R randomForest (mtry = features tried per split, nodesize =
    // minimum node size, maxnodes = cap on node count); confirm the values
    // the 4-argument overload uses.
    DecisionTree(double* train_data, int* train_y,
                 int n, int p,
                 int mtry, int nodesize, int maxnodes);

    // NOTE(review): presumably releases the tree (see deleteTree()).
    ~DecisionTree();

    // Evaluate the tree on one observation. NOTE(review): new_data
    // presumably points at p feature values; what the returned double means
    // (probability vs. class) is defined in the implementation -- confirm.
    double eval(double* new_data);

    // NOTE(review): presumably the depth of the fitted tree (see count_help()).
    int count_levels();

    // Fit the tree. NOTE(review): presumably grows it from root via grow().
    void train();

private:
    // Training inputs and their dimensions (presumably as given to the
    // constructor).
    double *train_data;
    int *train_y;
    int n, p;

    // Tuning parameters (see the 7-argument constructor).
    int mtry, nodesize, maxnodes;

    // Root of the fitted tree.
    struct node *root;

    // NOTE(review): names suggest copies of the training data and labels
    // held in accelerator/device memory; not verifiable from this header.
    double *device_data;
    int *device_labels;

    // Recursive helpers operating on the subtree rooted at t.
    void grow(struct node *t);
    void deleteTree(struct node *t);
    int count_help(struct node *t);
};
#endif