---
title: "Linear Model Transfer with Categorical Variables"
author: "Nina Zumel"
date: "2024-07-08"
output: github_document
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE)
```
In our main article, we showed an example of transferring the weights from one linear model to another without retraining on the original data. To do that, we fit a new model to a full-rank set of evaluation rows, using the original model's predictions on those rows as the outcome.

When all the variables are continuous, the number of rows we need is the number of variables plus one for the intercept; that is, the number of coefficients in the model. If $n_{var}$ is the number of variables, then an easy evaluation matrix to use is the $n_{var}$-dimensional identity matrix, with a row of all zeros appended to "represent" the intercept. This representation assumes an implicit all-ones column for the intercept; if you are transferring the model to a framework that needs the intercept column explicitly, you'd need to add that column of ones yourself.

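As a minimal sketch of the all-continuous case, the chunk below ports a model of sepal length on the three other numeric `iris` measurements. The variable names (`cont_model`, `cont_eval`, `cont_port`) are illustrative and not from the main article.

```{r}
# all-continuous model: three variables plus an intercept, so four coefficients
cont_vars = c("Sepal.Width", "Petal.Length", "Petal.Width")
cont_model = lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width, data = iris)

# identity matrix, with a row of all zeros appended for the intercept
cont_eval = as.data.frame(rbind(diag(length(cont_vars)), 0))
colnames(cont_eval) = cont_vars

# use the original model's predictions as the outcome, then refit
cont_eval$Sepal.Length = predict(cont_model, newdata = cont_eval)
cont_port = lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width, data = cont_eval)

data.frame(
  ported_model = coef(cont_port),
  original_model = coef(cont_model)
)
```

The refit has exactly as many rows as coefficients, so it reproduces the original predictions exactly and the two coefficient columns should agree up to floating-point error.
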
When some of the variables are categorical, the process is slightly more involved.
```{r}
library(Matrix)
```
## One categorical variable
For our first example, we'll use the `iris` dataset and fit sepal length as a function of sepal width, petal length, and species.
There are three species represented: *setosa*, *versicolor*, and *virginica*.
First let's fit a model.
```{r}
iris = iris  # local copy of the built-in dataset
summary(iris)
# fit on two continuous variables and one categorical variable (Species)
model = lm(Sepal.Length ~ Sepal.Width + Petal.Length + Species, data=iris)
coef(model)
```
Now we want to transfer the weights from this model to another one, without retraining on the original data. Since our model has 5 coefficients, we know we need 5 rows. But let's count them out explicitly.

* 1 row for the intercept
* 2 rows for the two continuous variables
* `nlevels(Species)` - 1 = 3 - 1 = 2 rows for the categorical variable `Species`

1 + 2 + 2 = 5.

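The same count can be checked in code. This is a small sketch; `n_rows` and `count_model` are names introduced here for illustration.

```{r}
# 1 intercept + 2 continuous variables + (number of Species levels - 1)
n_rows = 1 + 2 + (nlevels(iris$Species) - 1)
n_rows

# the count should match the number of coefficients in the fitted model
count_model = lm(Sepal.Length ~ Sepal.Width + Petal.Length + Species, data = iris)
length(coef(count_model))
```
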
In R (and in most other linear modeling frameworks), by default, one level of each categorical variable is folded into the intercept term as a "reference level": sort of the categorical equivalent of the origin. So rather than a reference row of all zeros to represent the intercept, we need a row with zeros for the continuous variables, *plus* the reference level of each categorical variable.

In R, you can either specify the categorical levels directly and let R build the model matrix, or build the model matrix yourself. In Python, you'd need to use the appropriate one-hot encoded representation, which (in R parlance) is building the model matrix directly.

The reference level of a categorical variable is generally the alphabetically first one. This is the default in R; you can use the `relevel()` command to change it. In any event, make sure to use the same reference levels as the original model.

In this case, *setosa* is the reference level for `Species`.

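A quick way to check the reference level, assuming R's default treatment contrasts, is to look at the first level of the factor. The sketch below also shows `relevel()` on a copy of the factor, so the data used elsewhere in this document is left unchanged; `species_levels` and `species_releveled` are illustrative names.

```{r}
# with default treatment contrasts, the first factor level is the reference level
species_levels = levels(iris$Species)
species_levels[1]

# relevel() returns a copy of the factor with the requested level moved to the front
species_releveled = relevel(iris$Species, ref = "versicolor")
levels(species_releveled)
```
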
```{r}
evalframe = wrapr::build_frame(
  'Sepal.Width', 'Petal.Length', 'Species' |
  0, 0, 'setosa' |      # intercept plus reference level
  1, 0, 'setosa' |      # Sepal.Width is 1, everything else is "0"
  0, 1, 'setosa' |      # Petal.Length is 1, everything else is "0"
  0, 0, 'versicolor' |  # Speciesversicolor is 1, everything else is "0"
  0, 0, 'virginica'     # Speciesvirginica is 1, everything else is "0"
)
# attach predictions from the original model
evalframe$Sepal.Length = predict(model, newdata=evalframe)
# look at the model matrix and confirm it's full rank.
m = model.matrix(Sepal.Length ~ Sepal.Width + Petal.Length + Species, evalframe)
m
# assert that the model matrix is full rank
stopifnot(rankMatrix(m)[1] == nrow(m))
```
We've confirmed that the model matrix for the evaluation frame is full rank. Now we fit another `lm` to the evaluation frame.
```{r}
portmod = lm(Sepal.Length ~ Sepal.Width + Petal.Length + Species, data=evalframe)
data.frame(
  ported_model = coef(portmod),
  original_model = coef(model)
)
```
The coefficients are the same as those of the original model fit to the iris data.

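Because the evaluation frame has exactly as many rows as coefficients and its model matrix is full rank, refitting amounts to solving a square linear system. The sketch below builds the model matrix by hand, roughly as you would when porting to a framework that expects explicit one-hot columns, and recovers the coefficients with `solve()`. All names here (`solve_model`, `solve_rows`, `X_explicit`, `ported_coefs`) are illustrative assumptions, not part of the original article.

```{r}
solve_model = lm(Sepal.Length ~ Sepal.Width + Petal.Length + Species, data = iris)

# evaluation rows: the reference row first, then one row per non-intercept coefficient
solve_rows = data.frame(
  Sepal.Width  = c(0, 1, 0, 0, 0),
  Petal.Length = c(0, 0, 1, 0, 0),
  Species      = c("setosa", "setosa", "setosa", "versicolor", "virginica")
)
preds = predict(solve_model, newdata = solve_rows)

# explicit model matrix: a column of ones for the intercept, then a row of
# zeros (the reference row) stacked on a 4x4 identity matrix
X_explicit = cbind(1, rbind(0, diag(4)))
colnames(X_explicit) = names(coef(solve_model))

# the ported coefficients are the solution of X b = predictions
ported_coefs = solve(X_explicit, preds)

data.frame(
  ported_model = unname(ported_coefs),
  original_model = coef(solve_model)
)
```

The hand-built matrix should match what `model.matrix()` produces for the same rows, which is one way to check that the column order and reference levels agree with the original model.
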
## More than one categorical variable
In this example, we'll predict the ascorbic acid (vitamin C) content of cabbages as a function of: their weight (continuous variable), their cultivar (categorical with two values: `c39` and `c52`) and their planting date (categorical with three values: `d16`, `d20`, `d21`).
```{r}
library(MASS) # to get the data
cabbages = cabbages
summary(cabbages)
model = lm(VitC ~ HeadWt + Date + Cult, data=cabbages)
coef(model)
```
To port the model, let's walk through the number of rows we need:

* 1 for the intercept
* 1 for the continuous variable `HeadWt`
* `nlevels(Date)` - 1 = 3 - 1 = 2 for the categorical variable `Date` (the reference level is `d16`)
* `nlevels(Cult)` - 1 = 2 - 1 = 1 for the categorical variable `Cult` (the reference level is `c39`)

1 + 1 + 2 + 1 = 5.

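If the training data (or at least its factor levels) is still at hand, one way to double-check this count is to look at the columns of the model matrix, since each column corresponds to one coefficient. This is a sketch; `mm_cab` is an illustrative name. Without the training data, `length(coef(model))` on the fitted model gives the same number.

```{r}
# each model matrix column corresponds to one coefficient, so the column
# count is the number of rows the evaluation frame needs
mm_cab = model.matrix(VitC ~ HeadWt + Date + Cult, data = MASS::cabbages)
colnames(mm_cab)
ncol(mm_cab)
```
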
```{r}
evalframe = wrapr::build_frame(
  'Cult', 'Date', 'HeadWt' |
  'c39', 'd16', 0 |  # intercept plus reference levels
  'c39', 'd16', 1 |  # HeadWt is 1, everything else is "0"
  'c52', 'd16', 0 |  # Cultc52 is 1, everything else is "0"
  'c39', 'd20', 0 |  # Dated20 is 1, everything else is "0"
  'c39', 'd21', 0    # Dated21 is 1, everything else is "0"
)
# attach predictions from the original model
evalframe$VitC = predict(model, newdata=evalframe)
# look at the model matrix and confirm it's full rank.
m = model.matrix(VitC ~ HeadWt + Date + Cult, evalframe)
m
# assert that the model matrix is full rank
stopifnot(rankMatrix(m)[1] == nrow(m))
```
The model matrix is full rank, so let's fit another `lm` to the evaluation frame.
```{r}
portmod = lm(VitC ~ HeadWt + Date + Cult, data=evalframe)
data.frame(
  ported_model = coef(portmod),
  original_model = coef(model)
)
```
The coefficients of the ported model are the same as those of the original model fit to the cabbages data.

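The row-building recipe used in both examples can be wrapped in a small function. The helper below, `build_eval_frame()`, is hypothetical and not part of the original article or of any package. It is a sketch that assumes a main-effects model (no interactions or transformed terms), default treatment contrasts, and that each categorical variable's levels are supplied with the reference level first.

```{r}
# Hypothetical helper: build the evaluation frame for a main-effects model.
# num_vars: names of the continuous variables
# cat_levels: named list of level vectors, reference level first
build_eval_frame = function(num_vars, cat_levels) {
  # reference row: zeros for continuous variables, reference levels otherwise
  ref = c(
    setNames(as.list(rep(0, length(num_vars))), num_vars),
    lapply(cat_levels, function(lv) lv[1])
  )
  rows = list(ref)
  # one row per continuous variable, with that variable set to 1
  for (v in num_vars) {
    r = ref
    r[[v]] = 1
    rows[[length(rows) + 1]] = r
  }
  # one row per non-reference level of each categorical variable
  for (v in names(cat_levels)) {
    for (lv in cat_levels[[v]][-1]) {
      r = ref
      r[[v]] = lv
      rows[[length(rows) + 1]] = r
    }
  }
  frame = do.call(rbind, lapply(rows, as.data.frame, stringsAsFactors = FALSE))
  # store categoricals as factors so the refit keeps the same reference levels
  for (v in names(cat_levels)) {
    frame[[v]] = factor(frame[[v]], levels = cat_levels[[v]])
  }
  frame
}

cab = MASS::cabbages
cab_model = lm(VitC ~ HeadWt + Date + Cult, data = cab)

cab_frame = build_eval_frame(
  num_vars = "HeadWt",
  cat_levels = list(Date = levels(cab$Date), Cult = levels(cab$Cult))
)
cab_frame$VitC = predict(cab_model, newdata = cab_frame)
cab_port = lm(VitC ~ HeadWt + Date + Cult, data = cab_frame)

data.frame(
  ported_model = coef(cab_port),
  original_model = coef(cab_model)
)
```

Passing the levels explicitly, rather than relying on alphabetical order, keeps the reference levels aligned with the original model even if it was fit after a `relevel()`.
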