-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_test_split.R
167 lines (147 loc) · 6.16 KB
/
train_test_split.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
library(tidyverse)
train_test_split <- function(df, y_cols, id_cols, feats_lst, test_size = .3,
                             alpha = .5, target_alpha = .9, validate = TRUE,
                             max_resamples = Inf) {
  # Splits df into train/test sets and input/target (X/y) sets, resampling
  # until the split passes a representativeness check (see validate_split).
  # (Must have id_col, but can be "dummy" since it's discarded for index.)
  # Parameters:
  #   df: (data.frame) Full data set, including target variable(s).
  #   y_cols: (c(character)) Target column(s).
  #   id_cols: (c(character)) Id column(s) to drop, because df maintains index.
  #   feats_lst: (c(character)) Features to check when validating the split.
  #   test_size: (numeric) Proportion of rows to use for test set.
  #     (Does not validate.)
  #   alpha: (numeric) Probability of incorrectly rejecting the null hypothesis.
  #     H0 = feature n of train and of test do not represent different sets.
  #       (i.e. representative split)
  #     H1 = feature n of train and of test represent different supersets.
  #   target_alpha: (numeric) Alpha to use if feature is target feature (i.e.
  #     if feature is in y_cols).
  #   validate: (bool) Should set split be validated?
  #   max_resamples: (numeric) Upper bound on resampling attempts. If reached,
  #     the last (unvalidated) split is returned with a warning. The default
  #     Inf preserves the original retry-until-valid behavior.
  # Return:
  #   split_lst: (list(data.frame)) (train_X, train_y, test_X, test_y)
  #     train_X (data.frame) Input features in training subset.
  #     train_y (data.frame) Target variable in training subset.
  #     test_X (data.frame) Input features in testing subset.
  #     test_y (data.frame) Target variable in testing subset.
  split_lst <- list(
    'train_X' = data.frame(),
    'train_y' = data.frame(),
    'test_X' = data.frame(),
    'test_y' = data.frame()
  )
  full_set_len <- nrow(df)
  test_set_len <- as.integer(test_size * full_set_len)
  ###
  ### TO DO: Add a parameter and logic to choose whether to track this. ###
  ###
  # To track average p-values of features:
  feats_p_av_lst <- vector(mode = 'list', length = length(feats_lst))
  names(feats_p_av_lst) <- feats_lst
  # Split and validate until valid (or until max_resamples is exhausted).
  valid_split <- FALSE
  attempts <- 0
  while (!valid_split) {
    attempts <- attempts + 1
    # Split randomly.
    test_idx <- sample(x = full_set_len, size = test_set_len)
    split_lst$train_X <- select(df[-test_idx, ], -all_of(y_cols))
    split_lst$train_y <- select(df[-test_idx, ], all_of(y_cols))
    split_lst$train_y[id_cols] <- split_lst$train_X[id_cols]
    split_lst$test_X <- select(df[test_idx, ], -all_of(y_cols))
    split_lst$test_y <- select(df[test_idx, ], all_of(y_cols))
    split_lst$test_y[id_cols] <- split_lst$test_X[id_cols]
    # Validate the split.
    if (validate) {
      # Randomize test order to "cost-average" compute.
      feats_lst <- sample(feats_lst)
      # Test X and y separately to avoid the join compute and data copies.
      X_validation_results <- validate_split(
        train = split_lst$train_X,
        test = split_lst$test_X,
        feats_lst = feats_lst,
        y_cols = y_cols,
        feats_p_val_lst = feats_p_av_lst,
        alpha = alpha,
        target_alpha = target_alpha
      )
      feats_p_av_lst <- X_validation_results$p_vals
      if (X_validation_results$valid) {
        y_validation_results <- validate_split(
          train = split_lst$train_y,
          test = split_lst$test_y,
          feats_lst = feats_lst,
          y_cols = y_cols,
          feats_p_val_lst = feats_p_av_lst,
          alpha = alpha,
          target_alpha = target_alpha
        )
        feats_p_av_lst <- y_validation_results$p_vals
        valid_split <- y_validation_results$valid
      }
    } else {
      valid_split <- TRUE
    }
    # Guard against looping forever on a data set that never validates.
    if (!valid_split && attempts >= max_resamples) {
      warning('max_resamples reached; returning last unvalidated split.',
              call. = FALSE)
      break
    }
  }
  if (validate) {
    # Collapse each feature's accumulated p-values to its mean across
    # resampling attempts.
    for (feat in names(feats_p_av_lst)) {
      feats_p_av_lst[[feat]] <- mean(feats_p_av_lst[[feat]])
    }
    print('Average p-values:')
    print(feats_p_av_lst)
  }
  return(split_lst)
}
validate_split <- function(train, test, feats_lst, y_cols, feats_p_val_lst,
                           alpha = .5, target_alpha = .9) {
  # Conducts Wilcoxon rank sum test column by column to test if train and test
  # represent a similar superset. (i.e., is the split stratified on every
  # feature?) Both train and test should have the same features. There should
  # be at least one numeric (i.e. continuous) feature, as the test will only
  # be performed on these columns -- this does limit the test.
  # Parameters:
  #   train: (data.frame) A subset of original set to compare to the other
  #     subset, test.
  #   test: (data.frame) A subset of original set to compare to the other
  #     subset, train.
  #   feats_lst: (list(character)) List of features to test.
  #   y_cols: (c(character)) Vector of target features.
  #   feats_p_val_lst: (list(character:list(double)) Dictionary of p-values
  #     to track which features are hardest to stratify.
  #   alpha: (numeric) Probability of incorrectly rejecting the null hypothesis.
  #     H0 = feature n of train and test does not represent different sets.
  #       (i.e. representative split)
  #     H1 = feature n of train and test represents a different superset.
  #   target_alpha: (numeric) Alpha to use if feature is target feature (i.e.
  #     if feature is in y_cols).
  # Return:
  #   list(valid: (bool), p_vals: (list(character:list(double)))
  #     valid: (bool) Are the sets representative of the same superset?
  #     p_vals: (list(character:list(double)) feats_p_val_lst updated
  valid <- TRUE
  for (feat in feats_lst) {
    # Skip features absent from either subset (e.g. target columns when
    # validating the X frames, or id columns when validating the y frames).
    if (!(feat %in% colnames(train)) || !(feat %in% colnames(test))) {
      next
    }
    # Targets get a stricter (higher) alpha than ordinary input features.
    this_alpha <- if (feat %in% y_cols) target_alpha else alpha
    results <- wilcox.test(
      x = as.double(train[[feat]]),
      y = as.double(test[[feat]])
    )
    # Record the p-value (including the failing feature's) so callers can
    # track which features are hardest to stratify.
    feats_p_val_lst[[feat]] <- c(feats_p_val_lst[[feat]], results$p.value)
    if (!(results$p.value > this_alpha)) {
      # Low p-value: cannot rule out that train and test differ on this
      # feature, so the split is rejected; remaining features are skipped.
      valid <- FALSE
      break
    }
  }
  return(list('valid' = valid, 'p_vals' = feats_p_val_lst))
}
write_sets <- function(set_lst, prefix, file_path, row.names = FALSE) {
  # Writes each data.frame in set_lst to its own CSV file named
  # <file_path><prefix><set_name>.csv (file_path must therefore include any
  # trailing path separator).
  # Parameters:
  #   set_lst: (list(data.frame)) Named list of sets (e.g. the output of
  #     train_test_split); each element is written to a separate file.
  #   prefix: (character) Filename prefix prepended to each set name.
  #   file_path: (character) Path fragment prepended to each filename.
  #   row.names: (bool) Passed through to write.csv; whether to write row
  #     names as a leading column.
  # Return:
  #   NULL, invisibly. Called for its file-writing side effect.
  for (set_name in names(set_lst)) {
    write.csv(
      set_lst[[set_name]],
      paste0(file_path, prefix, set_name, '.csv'),
      row.names = row.names
    )
  }
  invisible(NULL)
}