-
Notifications
You must be signed in to change notification settings - Fork 0
/
feature_importance.R
105 lines (84 loc) · 3.1 KB
/
feature_importance.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
library(tidyverse)
library(ggtext)
library(purrr)
library(readr)
library(dplyr)
# Collect the saved .Rds files from the runs directory.
# `pattern` is a regular expression, not a shell glob, so the extension is
# matched explicitly instead of with "*.Rds".
l2_files <- list.files(
  path = "results/seq_gl_100s/runs",
  pattern = "\\.Rds$",
  full.names = TRUE
)

# Fail with a readable message instead of passing NA to readRDS() below.
if (length(l2_files) == 0) {
  stop("No .Rds files found in results/seq_gl_100s/runs", call. = FALSE)
}

# Exploratory check on the first file: show the stored object, then the
# coefficient table that get_weights() builds for each file.
# The earlier `readRDS(test) %>% pluck("trained_model")` assignment was
# overwritten on the very next line, so only the plain read is kept, and the
# object is read once rather than three times.
test <- l2_files[1]
model <- readRDS(test)
model

# Coefficients at the tuned lambda, one row per feature.
coef(model$finalModel, model$bestTune$lambda) %>%
  as.matrix() %>%
  as_tibble(rownames = "feature") %>%
  rename(weight = s1)
#' Read one saved model and return its coefficients at the tuned lambda.
#'
#' @param file_name Path to an `.Rds` file. The object stored in it must
#'   provide `finalModel` and `bestTune$lambda`
#'   (NOTE(review): looks like a caret `train` object wrapping a glmnet fit;
#'   confirm against the code that writes these files).
#' @return A tibble with one row per coefficient: `feature` (row name of the
#'   coefficient matrix), `weight` (its `s1` column) and `seed` (character;
#'   `NA` when the file name contains no digits).
get_weights <- function(file_name) {
  model <- readRDS(file_name)

  # The previous pattern was anchored on "results/seq_gl_100s_", a prefix that
  # cannot occur in paths listed from "results/seq_gl_100s/runs", so `seed`
  # silently ended up holding the whole path. Take the last run of digits in
  # the file name instead; this still yields the same value for names that
  # matched the old "..._<digits>.Rds" form.
  # NOTE(review): assumes the seed is the last number in the file name.
  run_seed <- str_extract(basename(file_name), "\\d+(?=\\D*$)")

  coef(model$finalModel, model$bestTune$lambda) %>%
    as.matrix() %>%
    as_tibble(rownames = "feature") %>%
    rename(weight = s1) %>%
    mutate(seed = run_seed)
}
# Stack the coefficient tables from every file in l2_files into one tibble.
l2_weights <- map_dfr(l2_files, get_weights)
summary(l2_weights$weight)

# Median and quartiles of each feature's weight across files (intercept
# excluded); `n` is the number of rows contributing to each feature.
# `probs` is spelled out rather than relying on partial matching of `prob`.
weight_summary <- l2_weights %>%
  filter(feature != "(Intercept)") %>%
  group_by(feature) %>%
  summarize(median = median(weight),
            l_quartile = quantile(weight, probs = 0.25),
            n = n(),
            u_quartile = quantile(weight, probs = 0.75))
weight_summary

# The plotting pipeline had been commented out while the ggsave() call below
# stayed active, so ggsave() fell back to last_plot(): nothing (an error) in a
# fresh session, or an unrelated plot otherwise. The plot is rebuilt here and
# handed to ggsave() explicitly.
feature_weights_plot <- weight_summary %>%
  mutate(
    # Optional markdown label formatting, left disabled as in the original:
    # feature = str_replace(feature, "(.*)", "*\\1*"),
    # feature = str_replace(feature, "(.*)_unclassified\\*", "Unclassified \\1*"),
    # feature = str_replace(feature, "_(.*)\\*", "* \\1"),
    feature = fct_reorder(feature, median)
  ) %>%
  # Keep only features whose median weight is not effectively zero.
  filter(abs(median) > 0.0000008) %>%
  ggplot(aes(x = median, y = feature, xmin = l_quartile, xmax = u_quartile)) +
  geom_vline(xintercept = 0, color = "gray") +
  geom_point() +
  geom_linerange() +
  labs(x = "Weights", y = NULL,
       title = "Feature Weights") +
  theme_classic() +
  theme(axis.text.y = element_markdown())

# Make sure the output directory exists before writing the figure.
dir.create("figures/data_gl_lambda", recursive = TRUE, showWarnings = FALSE)
ggsave("figures/data_gl_lambda/feature_weights.jpg",
       plot = feature_weights_plot)
l2_files
#' Load the feature-importance columns from one CSV file.
#'
#' The CSV is read directly; nothing is plucked out of a larger object.
#'
#' @param file_name Path to a CSV file containing the columns `feat`,
#'   `perf_metric` and `perf_metric_diff`.
#' @return Invisibly, a tibble holding just those three columns.
get_feature_importance <- function(file_name) {
  raw_table <- as_tibble(read_csv(file_name))
  invisible(select(raw_table, feat, perf_metric, perf_metric_diff))
}
# Collect the feature-importance CSV files from the runs directory.
# `pattern` is a regular expression, not a shell glob, so the file-name suffix
# is matched explicitly instead of with "*feature-importance.csv".
l2_files <- list.files(
  path = "results/seq_gl_100s/runs",
  pattern = "feature-importance\\.csv$",
  full.names = TRUE
)

# Stack the per-file importance tables into one tibble.
l2_feature_importance <- map_dfr(l2_files, get_feature_importance)
summary(l2_feature_importance)

# Median and quartiles of perf_metric_diff per feature, plotted for features
# whose median exceeds the cutoff.
importance_plot <- l2_feature_importance %>%
  rename(feature = feat) %>%
  group_by(feature) %>%
  # `probs` is spelled out rather than relying on partial matching of `prob`.
  summarize(median = median(perf_metric_diff),
            l_quartile = quantile(perf_metric_diff, probs = 0.25),
            u_quartile = quantile(perf_metric_diff, probs = 0.75)) %>%
  # Build markdown axis labels (rendered by element_markdown() below):
  #   1. wrap the whole name in asterisks;
  #   2. "*X_unclassified*" becomes "Unclassified *X*";
  #   3. otherwise close the asterisks at the first underscore, so the text
  #      after it sits outside them.
  mutate(feature = str_replace(feature, "(.*)", "*\\1*"),
         feature = str_replace(feature, "(.*)_unclassified\\*", "Unclassified \\1*"),
         feature = str_replace(feature, "_(.*)\\*", "* \\1"),
         feature = fct_reorder(feature, median)) %>%
  filter(median > 0.000075) %>%
  ggplot(aes(x = median, y = feature, xmin = l_quartile, xmax = u_quartile)) +
  geom_point() +
  geom_linerange() +
  labs(x = "Change in AUC when removed", y = NULL,
       title = "Performance Metric Difference") +
  theme_classic() +
  theme(axis.text.y = element_markdown())

# Still display the plot, as the unassigned pipeline did.
importance_plot

# Make sure the output directory exists, then save this specific plot rather
# than whatever last_plot() happens to hold.
dir.create("figures/data_gl_lambda", recursive = TRUE, showWarnings = FALSE)
ggsave("figures/data_gl_lambda/l2_feature_importance.tiff",
       plot = importance_plot, width = 5, height = 5)