# Linear models - scalability

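In this notebook, we briefly show the `partial_fit` method offered by some scikit-learn estimators. It lets a transformer or model be fitted incrementally, one mini-batch of data at a time, instead of on the whole dataset in a single call.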

In [1]:
# Load the numeric-only version of the adult census dataset and preview the
# first rows.
import pandas as pd

data = pd.read_csv(
    "../datasets/adult-census-numeric-all.csv",
)
data.head(n=5)

Unnamed: 0,age,education-num,capital-gain,capital-loss,hours-per-week,class
0,25,7,0,0,40,<=50K
1,38,9,0,0,50,<=50K
2,28,12,0,0,40,>50K
3,44,10,7688,0,40,>50K
4,18,10,0,0,30,<=50K


In [2]:
# Separate the input features from the column we want to predict.
target_name = "class"
X, y = data.drop(columns=target_name), data[target_name]

In [3]:
from sklearn.model_selection import train_test_split

# Hold out part of the data for evaluation; fixing `random_state` makes the
# split reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

In [4]:
from sklearn.preprocessing import StandardScaler

# A standard scaler with its default settings spelled out: center each
# feature on its mean and scale it to unit variance.
scaler = StandardScaler(copy=True, with_mean=True, with_std=True)

In [5]:
# Fit the scaler incrementally: each `partial_fit` call updates the running
# mean and variance with one mini-batch, so the training set never has to be
# passed in a single call.
batch_size = 100
for start in range(0, len(X_train), batch_size):
    # `.iloc` makes the positional row slicing explicit. The last batch may be
    # shorter than `batch_size`.
    scaler.partial_fit(X_train.iloc[start:start + batch_size])

In [6]:
# Per-feature mean and variance accumulated over all the mini-batches.
(scaler.mean_, scaler.var_)

(array([  38.68545767,   10.07327127, 1063.20692856,   86.77983129,
          40.42248369]),
 array([1.88511311e+02, 6.61175889e+00, 5.43824675e+07, 1.61245256e+05,
        1.54386985e+02]))

In [7]:
# Refit a scaler on the whole training set in one call, and check that the
# statistics accumulated incrementally above are the same as the batch ones
# instead of comparing the printed arrays by eye.
import numpy as np

incremental_mean, incremental_var = scaler.mean_, scaler.var_
scaler = StandardScaler().fit(X_train)
np.testing.assert_allclose(scaler.mean_, incremental_mean)
np.testing.assert_allclose(scaler.var_, incremental_var)
scaler.mean_, scaler.var_

(array([  38.68545767,   10.07327127, 1063.20692856,   86.77983129,
          40.42248369]),
 array([1.88511311e+02, 6.61175889e+00, 5.43824675e+07, 1.61245256e+05,
        1.54386985e+02]))

In [9]:
from sklearn.linear_model import SGDClassifier

# Linear classifier with hinge loss trained by stochastic gradient descent.
# It is trained below with `partial_fit`, which makes one pass over each batch:
# `max_iter` only affects `fit`, so it is not set here. `random_state` seeds
# the data shuffling so that the weights printed below are reproducible.
model = SGDClassifier(loss="hinge", alpha=0.01, random_state=0)

In [10]:
import numpy as np

# `partial_fit` needs the full set of classes on its first call, because a
# single mini-batch may not contain every class. Compute it once from the
# whole target.
classes = np.unique(y)

batch_size = 4_000
for iteration, start in enumerate(range(0, len(X_train), batch_size), start=1):
    stop = start + batch_size
    # Scale each batch with the already-fitted scaler; `.iloc` makes the
    # positional row slicing explicit.
    X_scaled = scaler.transform(X_train.iloc[start:stop])
    params = {"classes": classes} if start == 0 else {}
    model.partial_fit(X_scaled, y_train.iloc[start:stop], **params)
    print(
        f"Iteration #{iteration}: Weights:\n"
        f"{model.coef_}"
    )

Iteration #1: Weights:
[[0.01482068 0.25556641 1.30735298 0.38501694 0.20974326]]
Iteration #2: Weights:
[[0.25661497 0.23623037 1.07041435 0.33384739 0.16809809]]
Iteration #3: Weights:
[[0.15510484 0.20522907 1.27752695 0.22135532 0.13047763]]
Iteration #4: Weights:
[[0.15752789 0.21567381 1.19598888 0.23075887 0.08618274]]
Iteration #5: Weights:
[[0.13307296 0.1875683  1.07980003 0.2715925  0.07231146]]
Iteration #6: Weights:
[[0.13556772 0.14004452 1.09434847 0.2794567  0.08625648]]
Iteration #7: Weights:
[[0.08004633 0.14936342 1.14292221 0.30575399 0.07220483]]
Iteration #8: Weights:
[[0.07659844 0.1616876  1.18038861 0.15086681 0.04759882]]
Iteration #9: Weights:
[[0.10789898 0.2235596  1.21571193 0.21749484 0.15053511]]
Iteration #10: Weights:
[[0.12893306 0.17821288 1.21891103 0.14149952 0.10998177]]
