---
title: "Linear Model Transfer"
author: "Nina Zumel"
date: "2024-07-08"
output: github_document
---
## Quantile Regression
In this notebook, we'll fit quantile regressions with the `quantreg` package,
then transfer each fitted model to another linear framework (in this example, `lm`). A similar procedure can be used
to transfer an R `quantreg` model to a framework in another language, for example `scikit-learn` in Python or `linfa` in Rust.
```{r}
library(quantreg)
library(ggplot2)
```
We'll use the Mammals dataset from the `quantreg` package for our example, and explore the relationship between mammal weight and mammal speed.
```{r}
data(Mammals)
summary(Mammals)
# let's compare log weight to log speed
Mammals = within(Mammals,
                 {
                   logwt = log(weight)
                   logspeed = log(speed)
                 })
# hold out a few rows for didactic reasons
mammals = Mammals[1:100, ]
newmammals = Mammals[101:107, ]
```
We'll use `rq()` to fit models for the 10th and 90th percentiles of log speed as a function of log weight, and of whether the animal is a hopper (like a kangaroo) or "special": animals with "lifestyles in which speed does not figure as an important factor". Examples of "special" animals include the porcupine, the skunk, and the raccoon.
Note that, as discussed [here](http://www.econ.uiuc.edu/~roger/research/rq/FAQ), the calculated percentile predictions from the model aren't necessarily unique, but they are correct. Hence you may get a "non-uniqueness" warning when running `rq`.
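To see where the non-uniqueness comes from, here is a minimal base-R sketch (not part of the original analysis) of the check, or "pinball", loss that quantile regression minimizes. The helper name `pinball_loss` and the toy data are illustrative assumptions. With an even number of points and `tau = 0.5`, every value between the two middle points attains the same minimal loss, so the minimizer is not unique.

```{r}
# Check ("pinball") loss for a constant prediction m at quantile level tau.
# pinball_loss is a hypothetical helper written for this illustration.
pinball_loss = function(m, y, tau) {
  r = y - m
  sum(ifelse(r >= 0, tau * r, (tau - 1) * r))
}

y = c(1, 2, 3, 4)                      # toy data, even number of points
candidates = c(1.9, 2, 2.5, 3, 3.1)    # candidate values for the median
losses = sapply(candidates, pinball_loss, y = y, tau = 0.5)

# the candidates inside [2, 3] share the same, smallest loss
data.frame(candidate = candidates, loss = losses)
```

Any of the tied candidates is a valid median, which is the sense in which a quantile fit can be correct without being unique.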
```{r}
# 90th percentile
model09 = rq(logspeed ~ logwt + hoppers + specials, data=mammals, tau=0.9)
summary(model09)
coef(model09)
# 10th percentile
model01 = rq(logspeed ~ logwt + hoppers + specials, data=mammals, tau=0.1)
summary(model01)
coef(model01)
```
For comparison, we'll fit a linear model for expected log speed as a function of log weight, etc. This is mostly just to hammer in that the quantile models we fit above with `rq` are not the same as an ordinary least squares fit of the mean.
```{r}
lmmodel = lm(logspeed ~ logwt + hoppers + specials, data=mammals)
summary(lmmodel)
coef(lmmodel)
```
Now let's compare actual log speeds to the linear and percentile predictions. We'll plot outcomes and predictions as a function of weight, with the data partitioned by the values of `specials` and `hoppers`.
```{r}
mammals$predmean = predict(lmmodel, newdata=mammals)
mammals$pred01 = predict(model01, newdata=mammals)
mammals$pred09 = predict(model09, newdata=mammals)
ggplot(mammals, aes(x=logwt, y=logspeed)) +
  geom_point(color='darkgray') +
  geom_line(aes(y=predmean), linetype="dashed") +
  geom_line(aes(y=pred01), color='darkblue') +
  geom_line(aes(y=pred09), color='darkblue') +
  facet_grid(hoppers ~ specials, labeller=label_both) +
  ggtitle("Mammal speed as a function of weight, with 10th and 90th percentiles")
```
## Transfer the quantile model to the `lm` framework
Now let's transfer the coefficients from the `rq` models to `lm` models. As discussed in our main article, we can
do this by evaluating the quantile models on a full rank set of rows. Since all the variables here are numeric or logical (0/1), this only requires 4 independent rows: one for each of the three variables, plus one for the intercept.
Note that in the representation below, the intercept column is implicit, so you can think of `evalframe` as having an invisible intercept column `c(1, 1, 1, 1)`, making `evalframe` a 4 by 4 full rank matrix.
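The reason a full rank set of rows is enough can be sketched with a little linear algebra. A linear model's predictions on the evaluation rows are `X %*% beta`, where `X` is the evaluation matrix including the intercept column; when `X` is square and full rank, solving that system recovers `beta`. The base-R sketch below uses made-up coefficients (`beta_true`) and an arbitrary value standing in for the mean log weight, so it illustrates the idea rather than the notebook's actual models.

```{r}
# Hypothetical coefficients standing in for a fitted linear (or quantile) model
beta_true = c(intercept = 1.5, logwt = 0.25, hoppers = 0.8, specials = -0.6)

# Evaluation matrix: explicit intercept column plus one row per coefficient
X = rbind(
  c(1, 0,   0, 0),   # isolates the intercept
  c(1, 3.2, 0, 0),   # 3.2 is an arbitrary stand-in for mean(logwt)
  c(1, 0,   1, 0),
  c(1, 0,   0, 1)
)
colnames(X) = names(beta_true)

# full rank: the 4 rows pin down all 4 coefficients
stopifnot(qr(X)$rank == ncol(X))

# "predictions" of the original model on the evaluation rows
y_eval = as.vector(X %*% beta_true)

# solving the linear system recovers the coefficients
beta_recovered = solve(X, y_eval)
all.equal(unname(beta_true), unname(beta_recovered))
```

Fitting `lm` to the evaluation frame performs the same solve, with zero residual degrees of freedom.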
First, transfer the 90th percentile model.
```{r}
meanlogwt = mean(mammals$logwt) # not strictly necessary, but nice to keep the value in range
# you can build this frame any old way, this just happens to be readable
evalframe = wrapr::build_frame(
  'logwt', 'hoppers', 'specials' |
  0, 0, 0 | # for intercept
  meanlogwt, 0, 0 |
  0, 1, 0 |
  0, 0, 1
)
# hoppers and specials need to be logicals
evalframe$hoppers = as.logical(evalframe$hoppers)
evalframe$specials = as.logical(evalframe$specials)
# now add the predictions from the 90th percentile model
evalframe$logspeed = predict(model09, newdata=evalframe)
print(evalframe)
```
In the code above, I used `wrapr::build_frame` to specify the dataframe in a row-wise fashion. This is of course completely optional; I just didn't want to transpose the frame in my head to write it down in the default column-wise manner.
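For readers without `wrapr`, a column-wise equivalent in base R might look like the sketch below. The value `3.2` is an arbitrary placeholder for `meanlogwt`, and the names `logwt_ref`, `evalframe_base`, and `design` are chosen here for illustration so that the notebook's own `evalframe` is not overwritten.

```{r}
# Column-wise construction of the same 4-row evaluation frame, base R only.
# logwt_ref is a placeholder; in the notebook it would be meanlogwt.
logwt_ref = 3.2
evalframe_base = data.frame(
  logwt    = c(0, logwt_ref, 0, 0),
  hoppers  = c(FALSE, FALSE, TRUE, FALSE),   # logicals from the start
  specials = c(FALSE, FALSE, FALSE, TRUE)
)

# model.matrix() makes the implicit intercept column explicit;
# the rank should equal the number of coefficients to transfer
design = model.matrix(~ logwt + hoppers + specials, data = evalframe_base)
qr(design)$rank
```

Building the logical columns directly also removes the need for the `as.logical()` conversions.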
Now, let's fit an `lm` model to the evaluation frame:
```{r}
model09lm = lm(logspeed ~ logwt + hoppers + specials, data=evalframe)
# compare the coefficients
data.frame(
  ported_model09 = coef(model09lm),
  original_model09 = coef(model09)
)
```
The coefficients of the new `lm` model are the same as those of the original 90th percentile model!
It's easy to check that the predicted values also agree (up to floating-point error), even on new data.
```{r}
Mammals$pred09 = predict(model09, newdata=Mammals)
Mammals$pred09x = predict(model09lm, newdata=Mammals)
# mark the training and holdout data
Mammals$datatype = c(rep('training', 100), rep('holdout', 7))
palette = c(training='darkgray', holdout='#d95f02')
ggplot(Mammals, aes(x=pred09, y=pred09x, color=datatype)) +
  geom_point() +
  geom_abline(color='gray') +
  scale_color_manual(values=palette) +
  ggtitle("90th percentile: lm representation vs original model")
```
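The plot gives a visual check; a numeric check is a useful complement, since agreement here means equality up to floating-point error. Below is a self-contained sketch on simulated data, using a Gaussian `glm` as a stand-in source model so that it runs without `quantreg`. The helper `port_to_lm`, the `toy` data, and the simulated coefficients are all hypothetical, for illustration only.

```{r}
# port_to_lm: hypothetical helper. Evaluate a linear model that has a
# predict() method on a full rank evaluation frame, then refit with lm.
port_to_lm = function(model, evalframe, formula) {
  outcome = all.vars(formula)[1]   # left-hand side of the formula
  evalframe[[outcome]] = as.numeric(predict(model, newdata = evalframe))
  lm(formula, data = evalframe)
}

set.seed(2024)
n = 50
toy = data.frame(
  logwt    = rnorm(n, mean = 3),
  hoppers  = rep(c(TRUE, FALSE), length.out = n),
  specials = rep(c(FALSE, FALSE, TRUE), length.out = n)
)
toy$logspeed = 2 + 0.3 * toy$logwt + 0.5 * toy$hoppers -
  0.7 * toy$specials + rnorm(n, sd = 0.1)

fmla = logspeed ~ logwt + hoppers + specials
source_model = glm(fmla, data = toy)   # stand-in for an rq model

toy_evalframe = data.frame(
  logwt    = c(0, mean(toy$logwt), 0, 0),
  hoppers  = c(FALSE, FALSE, TRUE, FALSE),
  specials = c(FALSE, FALSE, FALSE, TRUE)
)
ported = port_to_lm(source_model, toy_evalframe, fmla)

# compare predictions numerically, up to floating-point tolerance
p_source = as.numeric(predict(source_model, newdata = toy))
p_ported = as.numeric(predict(ported, newdata = toy))
all.equal(p_source, p_ported)
max(abs(p_source - p_ported))
```

With the notebook's objects, the analogous check would compare `predict(model09, newdata=Mammals)` against `predict(model09lm, newdata=Mammals)`.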
We'll finish off by also re-fitting the 10th percentile model, and replicating the mean and percentile plot we created above.
```{r}
# get the 10th percentile predictions
evalframe$logspeed = predict(model01, newdata=evalframe)
model01lm = lm(logspeed ~ logwt + hoppers + specials, data=evalframe)
# compare the coefficients
data.frame(
  ported_model01 = coef(model01lm),
  original_model01 = coef(model01)
)
```
```{r}
# recreate the plot (including the holdout data)
Mammals$predmean = predict(lmmodel, newdata=Mammals)
Mammals$pred01 = predict(model01lm, newdata=Mammals)
Mammals$pred09 = predict(model09lm, newdata=Mammals)
ggplot(Mammals, aes(x=logwt, y=logspeed)) +
  geom_point(aes(color=datatype)) +
  geom_line(aes(y=predmean), linetype="dashed") +
  geom_line(aes(y=pred01), color='darkblue') +
  geom_line(aes(y=pred09), color='darkblue') +
  scale_color_manual(values=palette) +
  facet_grid(hoppers ~ specials, labeller=label_both) +
  ggtitle("Mammal speed as a function of weight, with 10th and 90th percentiles",
          subtitle="Using models ported to lm from rq")
```