/
tabular_model.R
128 lines (104 loc) · 2.66 KB
/
tabular_model.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#' @title Emb_sz_rule
#'
#' @description Rule of thumb to pick an embedding size for a categorical
#' variable with `n_cat` categories. The computation is delegated to
#' `tabular()$emb_sz_rule`.
#'
#' @param n_cat number of categories of the categorical variable
#' @return the embedding size returned by `tabular()$emb_sz_rule`
#' @export
emb_sz_rule <- function(n_cat) {
  size_rule <- tabular()$emb_sz_rule
  size_rule(n_cat = n_cat)
}
#' @title Get_emb_sz
#'
#' @description Get default embedding sizes from the `TabularPreprocessor`
#' `proc` (passed as `to`), or use the sizes given in `sz_dict`. The work is
#' delegated to `tabular()$get_emb_sz`.
#'
#' @param to tabular object / preprocessor to derive the embedding sizes from
#' @param sz_dict optional mapping of embedding sizes that take precedence
#' over the defaults; `NULL` (the default) means no overrides
#' @return the embedding sizes returned by `tabular()$get_emb_sz`
#' @export
get_emb_sz <- function(to, sz_dict = NULL) {
  delegate <- tabular()$get_emb_sz
  delegate(to = to, sz_dict = sz_dict)
}
#' @title Tabular_config
#'
#' @description Convenience function to easily create a config for
#' `TabularModel`. All arguments are forwarded to `tabular()$tabular_config`;
#' `ps` and `y_range` are left out entirely when `NULL`, so the callee's own
#' defaults apply.
#'
#' @param ps ps (presumably per-layer dropout probabilities); omitted when
#' `NULL`
#' @param embed_p embedding proportion (presumably dropout on the embeddings)
#' @param y_range range for the target; omitted when `NULL`
#' @param use_bn use batch normalization
#' @param bn_final use batch normalization on the final layer
#' @param bn_cont use batch normalization on the continuous variables
#' @param act_cls activation module, by default `nn()$ReLU(inplace = TRUE)`
#' @return the config object returned by `tabular()$tabular_config`
#' @export
tabular_config <- function(ps = NULL, embed_p = 0.0, y_range = NULL,
                           use_bn = TRUE, bn_final = FALSE,
                           bn_cont = TRUE, act_cls = nn()$ReLU(inplace = TRUE)) {
  config_args <- list(
    ps = ps, embed_p = embed_p, y_range = y_range,
    use_bn = use_bn, bn_final = bn_final, bn_cont = bn_cont,
    act_cls = act_cls
  )
  # Assigning NULL removes the element, so unset optionals are not passed at
  # all and the defaults of `tabular()$tabular_config` are used instead.
  for (optional in c("ps", "y_range")) {
    if (is.null(config_args[[optional]])) {
      config_args[[optional]] <- NULL
    }
  }
  do.call(tabular()$tabular_config, config_args)
}
#' @title TabularModel
#'
#' @description Basic model for tabular data. When called with none of
#' `emb_szs`, `n_cont`, `out_sz` and `layers`, it returns
#' `tabular()$TabularModel` itself (invisibly) instead of building a model.
#'
#' @param emb_szs embedding sizes for the categorical variables
#' @param n_cont number of continuous variables
#' @param out_sz output size
#' @param layers layers (presumably the hidden layer sizes), forwarded unchanged
#' @param ps ps (presumably per-layer dropout probabilities); omitted when
#' `NULL` so the default of `tabular()$TabularModel` applies
#' @param embed_p embedding proportion (presumably dropout on the embeddings)
#' @param y_range range for the target; omitted when `NULL` so the default of
#' `tabular()$TabularModel` applies
#' @param use_bn use batch normalization
#' @param bn_final use batch normalization on the final layer
#' @param bn_cont use batch normalization on the continuous variables
#' @param act_cls activation module, by default `nn()$ReLU(inplace = TRUE)`
#' @return the model created by `tabular()$TabularModel`, or
#' `tabular()$TabularModel` itself (invisibly) when called without any of the
#' required arguments
#' @export
TabularModel <- function(emb_szs, n_cont, out_sz, layers, ps = NULL,
                         embed_p = 0.0, y_range = NULL, use_bn = TRUE,
                         bn_final = FALSE, bn_cont = TRUE,
                         act_cls = nn()$ReLU(inplace = TRUE)) {
  # Scalar `&&` over `missing()` only: testing `layers` itself would force it
  # (erroring when absent) and give a length > 1 `if` condition for a
  # multi-element `layers`, which is an error since R 4.2.
  no_required_args <- missing(emb_szs) && missing(n_cont) &&
    missing(out_sz) && missing(layers)
  if (no_required_args) {
    return(invisible(tabular()$TabularModel))
  }
  args <- list(
    emb_szs = emb_szs,
    n_cont = n_cont,
    out_sz = out_sz,
    layers = layers,
    ps = ps,
    embed_p = embed_p,
    y_range = y_range,
    use_bn = use_bn,
    bn_final = bn_final,
    bn_cont = bn_cont,
    act_cls = act_cls
  )
  # Assigning NULL removes the element, so unset optionals are not passed at
  # all and the defaults of `tabular()$TabularModel` are used instead.
  if (is.null(args[["ps"]])) {
    args[["ps"]] <- NULL
  }
  if (is.null(args[["y_range"]])) {
    args[["y_range"]] <- NULL
  }
  do.call(tabular()$TabularModel, args)
}