In [1]:
import glob

import pandas as pd
import dask.dataframe as dd

In [2]:
def read_data(location, format, frac=0.02):
    files = glob.glob(f"{location}*.{format}")
    if format == "parquet":
        data = dd.read_parquet(location).sample(frac=0.02)
    # elif format == "h5":
    #     data = dd.concat([read_hdf5(file) for file in files]).sample(frac=0.02)
    data = data.rename(columns={"Distance to CIS": "Distance to CSL"})
    return data

In [3]:
# ddf = read_data("../data/aggregated_data/parquet/", "parquet", frac=1)
# ddf

In [4]:
pd.set_option('display.max_columns', None)
# ddf = dd.read_parquet("../data/aggregated_data/parquet/tickets_full_merge_2024.parquet")
ddf = dd.read_parquet([
    "../data/aggregated_data/parquet/tickets_full_merge_2023.parquet",
    "../data/aggregated_data/parquet/tickets_full_merge_2024.parquet",
])
ddf.dropna(subset=["tempmax", "tempmin", "temp",  "conditions", "humidity", "windspeed", "visibility"])
ddf.head(5)

Unnamed: 0,Summons Number,Plate ID,Registration State,Plate Type,Issue Date,Violation Code,Vehicle Body Type,Vehicle Make,Issuing Agency,Street Code1,Street Code2,Street Code3,Vehicle Expiration Date,Violation Location,Violation Precinct,Issuer Precinct,Issuer Code,Issuer Command,Issuer Squad,Violation Time,Violation County,Violation In Front Of Or Opposite,House Number,Street Name,Intersecting Street,Date First Observed,Law Section,Sub Division,Violation Legal Code,Days Parking In Effect,From Hours In Effect,To Hours In Effect,Vehicle Color,Vehicle Year,Feet From Curb,Violation Post Code,Violation Description,Latitude,Longitude,datetime,tempmax,tempmin,temp,conditions,humidity,windspeed,visibility,Closest Middle School,Distance to CMS,Closest High School,Distance to CHS,Closest Individual Landmark,Distance to CIL,Closest Scenic Landmark,Distance to CIS,Closest Business,Industry of CB,Distance to CB
0,1484699750,GCX5397,NY,PAS,1687132800000,63,SUBN,CHEVR,N,30640,13015,28540,20230103,1,1,1,161331,1,0,1102A,Manhattan,O,109.0,SOUTH ST,,0,408,F3,,BBBBBBB,ALL,ALL,WHT,2010,0,,,40.708622,-74.00387,1687132800000,25.7,19.3,21.8,Clear,64.5,21.5,15.9,Spruce Street School,0.303262,Urban Assembly School for Emergency Management...,0.403826,18 Fulton Street Building,0.20808,Bryant Park,5.313856,SEAPORT PARKING LLC,Garage,0.009434
1,1484721329,HEZ7860,NY,PAS,1656892800000,20,SUBN,SUBAR,K,39202,0,0,20240305,122,122,5,160548,5,0,0235P,Staten Island,,,MIDLAND BEACH,LOT 8,0,408,C4,,BBBBBBB,ALL,ALL,GL,2016,0,,,40.573161,-74.094586,1656892800000,29.4,19.9,25.0,Rain,42.2,14.0,15.7,I.S. R002 George L. Egbert,0.652695,New Dorp High School,1.187157,Ernest Flagg Estate Cottage: McCall's Demonstr...,1.653102,Coney Island (Riegelman) Boardwalk,9.709113,"BILTMORE GENERAL CONTRACTORS, INC.",Home Improvement Contractor,0.013309
2,1484721330,HDE5505,NY,PAS,1656892800000,20,SUBN,ME/BE,K,39202,0,0,20240303,122,122,5,160548,5,0,0156P,Staten Island,,,MIDLAND BEACH,LOT 8,0,408,D,,BBBBBBB,ALL,ALL,WHITE,2019,0,,,40.573161,-74.094586,1656892800000,29.4,19.9,25.0,Rain,42.2,14.0,15.7,I.S. R002 George L. Egbert,0.652695,New Dorp High School,1.187157,Ernest Flagg Estate Cottage: McCall's Demonstr...,1.653102,Coney Island (Riegelman) Boardwalk,9.709113,"BILTMORE GENERAL CONTRACTORS, INC.",Home Improvement Contractor,0.013309
3,1484724136,JEB6269,NY,PAS,1656892800000,68,SDN,CHEVR,N,22278,0,0,20221018,122,122,5,160690,5,0,0626P,Staten Island,,,CEDAR GROVE BEACH,PARKING LOT,0,408,F3,,BBBBBBB,ALL,ALL,SILVE,2015,0,,,40.558055,-74.101515,1656892800000,29.4,19.9,25.0,Rain,42.2,14.0,15.7,I.S. R002 George L. Egbert,2.077059,New Dorp High School,1.365887,Gustave A. Mayer House,2.571409,Coney Island (Riegelman) Boardwalk,10.43148,U AND P DELI & GROCERY INC,Electronic Cigarette Dealer,0.430214
4,1484725888,KET8159,NY,PAS,1657324800000,20,SDN,CHEVR,K,42850,33720,61830,20221029,84,84,4,160448,4,0,0841P,Brooklyn,F,60.0,FURMAN ST,,0,408,C4,,BBBBBBB,ALL,ALL,SILVE,2011,0,,,40.697978,-73.993521,1657324800000,29.3,22.8,25.1,"Rain, Partially cloudy",61.5,21.6,15.9,Dock Street School for STEAM Studies,0.576071,"Urban Assembly School for Law and Justice, The",0.537792,Brooklyn Trust Company Building,0.398727,Prospect Park,4.541993,SUN NEWS INC.,Tobacco Retail Dealer,0.058913


In [5]:
def chunk(s):
    # for the comments, assume only a single grouping column, the 
    # implementation can handle multiple group columns.
    #
    # s is a grouped series. value_counts creates a multi-series like 
    # (group, value): count
    return s.value_counts()


def agg(s):
    # s is a grouped multi-index series. In .apply the full sub-df will passed
    # multi-index and all. Group on the value level and sum the counts. The
    # result of the lambda function is a series. Therefore, the result of the 
    # apply is a multi-index series like (group, value): count
    return s.apply(lambda s: s.groupby(level=-1, sort=False).sum())

    # # faster version using pandas internals
    # s = s._selected_obj
    # return s.groupby(level=list(range(s.index.nlevels))).sum()


def finalize(s):
    # s is a multi-index series of the form (group, value): count. First
    # manually group on the group part of the index. The lambda will receive a
    # sub-series with multi index. Next, drop the group part from the index.
    # Finally, determine the index with the maximum value, i.e., the mode.
    level = list(range(s.index.nlevels - 1))
    return (
        s.groupby(level=level, sort=False)
        # .apply(lambda s: s.reset_index(level=level, drop=True).argmax())
        .apply(lambda s: s.reset_index(level=level, drop=True).idxmax())
    )


mode = dd.Aggregation('mode', chunk, agg, finalize)

In [6]:
daily_ddf = ddf.groupby("Issue Date").agg({
    "tempmax": "first", 
    "tempmin": "first", 
    "temp": "first", 
    "conditions": "first", 
    "humidity": "first", 
    "windspeed": "first", 
    "visibility": "first",
    "Distance to CMS": "mean",
    "Distance to CHS": "mean",
    "Distance to CIL": "mean",
    "Distance to CIS": "mean",
    "Distance to CB": "mean",
})

daily_ddf["count"] = ddf.groupby("Issue Date").size()
daily_ddf = daily_ddf.dropna()
daily_ddf.compute()

Unnamed: 0_level_0,tempmax,tempmin,temp,conditions,humidity,windspeed,visibility,Distance to CMS,Distance to CHS,Distance to CIL,Distance to CIS,Distance to CB,count
Issue Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
1687132800000,25.7,19.3,21.8,Clear,64.5,21.5,15.9,0.622759,0.706088,0.549993,4.631736,0.108002,12192
1656892800000,29.4,19.9,25.0,Rain,42.2,14.0,15.7,0.628087,0.738772,0.705272,5.421443,0.128032,3059
1657324800000,29.3,22.8,25.1,"Rain, Partially cloudy",61.5,21.6,15.9,0.616593,0.707777,0.602161,4.856251,0.106574,12824
1656806400000,31.0,24.3,26.8,Partially cloudy,52.4,21.5,16.0,0.643300,0.767375,0.632290,4.617807,0.132466,5147
1656720000000,31.1,24.3,27.3,"Rain, Partially cloudy",66.9,16.2,16.0,0.594962,0.685915,0.568028,4.771755,0.096920,17168
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1716854400000,26.7,18.9,22.5,Partially cloudy,64.1,35.7,15.9,0.600562,0.663457,0.534219,4.424932,0.110785,28672
1716336000000,27.9,18.0,23.1,Clear,63.2,14.9,16.0,0.622040,0.704169,0.578194,4.671199,0.111999,27157
1715904000000,21.7,15.6,18.2,"Rain, Partially cloudy",66.5,27.3,15.8,0.598306,0.665789,0.528128,4.282436,0.109453,29689
1716422400000,25.6,20.0,22.3,Rain,70.8,24.1,15.4,0.594586,0.686704,0.572924,4.572234,0.111724,28503


In [7]:
daily_mode_ddf = ddf.groupby("Issue Date", sort=False).agg({
    "Registration State": mode,
    "Plate Type": mode,
    "Violation Code": mode,
    "Vehicle Body Type": mode,
    "Vehicle Make": mode,
    "Issuing Agency": mode,
    "Violation County": mode,
})
daily_mode_ddf.compute()

  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)


  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)


  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)


Unnamed: 0_level_0,Registration State,Plate Type,Violation Code,Vehicle Body Type,Vehicle Make,Issuing Agency,Violation County
Issue Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1656720000000,NY,PAS,38,SUBN,HONDA,T,Manhattan
1656806400000,NY,PAS,14,SUBN,HONDA,T,Manhattan
1656892800000,NY,PAS,71,SUBN,HONDA,T,Manhattan
1656979200000,NY,PAS,21,SUBN,HONDA,T,Manhattan
1657065600000,NY,PAS,38,SUBN,FORD,T,Manhattan
...,...,...,...,...,...,...,...
1719273600000,NY,PAS,68,SUBN,FORD,U,Queens
1719360000000,NY,COM,98,,ISUZU,P,Brooklyn
1719446400000,NY,OMT,14,TAXI,TOYOT,M,Manhattan
1719532800000,NY,PAS,63,SUBN,FORD,K,Queens


In [8]:
daily_ddf = daily_ddf.merge(daily_mode_ddf, on="Issue Date")
daily_ddf.compute()

  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)


Unnamed: 0_level_0,tempmax,tempmin,temp,conditions,humidity,windspeed,visibility,Distance to CMS,Distance to CHS,Distance to CIL,Distance to CIS,Distance to CB,count,Registration State,Plate Type,Violation Code,Vehicle Body Type,Vehicle Make,Issuing Agency,Violation County
Issue Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
1687132800000,25.7,19.3,21.8,Clear,64.5,21.5,15.9,0.622759,0.706088,0.549993,4.631736,0.108002,12192,NY,PAS,38,SUBN,HONDA,T,Manhattan
1656892800000,29.4,19.9,25.0,Rain,42.2,14.0,15.7,0.628087,0.738772,0.705272,5.421443,0.128032,3059,NY,PAS,71,SUBN,HONDA,T,Manhattan
1657324800000,29.3,22.8,25.1,"Rain, Partially cloudy",61.5,21.6,15.9,0.616593,0.707777,0.602161,4.856251,0.106574,12824,NY,PAS,38,SUBN,HONDA,T,Manhattan
1656806400000,31.0,24.3,26.8,Partially cloudy,52.4,21.5,16.0,0.643300,0.767375,0.632290,4.617807,0.132466,5147,NY,PAS,14,SUBN,HONDA,T,Manhattan
1656720000000,31.1,24.3,27.3,"Rain, Partially cloudy",66.9,16.2,16.0,0.594962,0.685915,0.568028,4.771755,0.096920,17168,NY,PAS,38,SUBN,HONDA,T,Manhattan
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1716854400000,26.7,18.9,22.5,Partially cloudy,64.1,35.7,15.9,0.600562,0.663457,0.534219,4.424932,0.110785,28672,NY,PAS,21,SUBN,HONDA,T,Manhattan
1716336000000,27.9,18.0,23.1,Clear,63.2,14.9,16.0,0.622040,0.704169,0.578194,4.671199,0.111999,27157,NY,PAS,38,SUBN,FORD,T,Manhattan
1715904000000,21.7,15.6,18.2,"Rain, Partially cloudy",66.5,27.3,15.8,0.598306,0.665789,0.528128,4.282436,0.109453,29689,NY,PAS,21,SUBN,FORD,T,Manhattan
1716422400000,25.6,20.0,22.3,Rain,70.8,24.1,15.4,0.594586,0.686704,0.572924,4.572234,0.111724,28503,NY,PAS,21,SUBN,FORD,T,Manhattan


# ML

In [9]:
from dask_ml.model_selection import train_test_split
from dask_ml.preprocessing import Categorizer, OrdinalEncoder
from dask_ml.metrics import mean_squared_error

In [10]:
ce = Categorizer(columns=["conditions", "Registration State", "Plate Type", "Violation Code", "Vehicle Body Type", "Vehicle Make", "Issuing Agency", "Violation County"])
daily_ddf = ce.fit_transform(daily_ddf)

enc = OrdinalEncoder(columns=["conditions", "Registration State", "Plate Type", "Violation Code", "Vehicle Body Type", "Vehicle Make", "Issuing Agency", "Violation County"])
daily_ddf = enc.fit_transform(daily_ddf)

daily_ddf = daily_ddf.persist()
daily_ddf.compute()

  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)


  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)
  return df.__class__(result)


Unnamed: 0_level_0,tempmax,tempmin,temp,conditions,humidity,windspeed,visibility,Distance to CMS,Distance to CHS,Distance to CIL,Distance to CIS,Distance to CB,count,Registration State,Plate Type,Violation Code,Vehicle Body Type,Vehicle Make,Issuing Agency,Violation County
Issue Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
1687132800000,25.7,19.3,21.8,0,64.5,21.5,15.9,0.622759,0.706088,0.549993,4.631736,0.108002,12192,0,0,3,0,1,0,0
1656892800000,29.4,19.9,25.0,3,42.2,14.0,15.7,0.628087,0.738772,0.705272,5.421443,0.128032,3059,0,0,5,0,1,0,0
1657324800000,29.3,22.8,25.1,5,61.5,21.6,15.9,0.616593,0.707777,0.602161,4.856251,0.106574,12824,0,0,3,0,1,0,0
1656806400000,31.0,24.3,26.8,2,52.4,21.5,16.0,0.643300,0.767375,0.632290,4.617807,0.132466,5147,0,0,0,0,1,0,0
1656720000000,31.1,24.3,27.3,5,66.9,16.2,16.0,0.594962,0.685915,0.568028,4.771755,0.096920,17168,0,0,3,0,1,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1716854400000,26.7,18.9,22.5,2,64.1,35.7,15.9,0.600562,0.663457,0.534219,4.424932,0.110785,28672,0,0,1,0,1,0,0
1716336000000,27.9,18.0,23.1,0,63.2,14.9,16.0,0.622040,0.704169,0.578194,4.671199,0.111999,27157,0,0,3,0,0,0,0
1715904000000,21.7,15.6,18.2,5,66.5,27.3,15.8,0.598306,0.665789,0.528128,4.282436,0.109453,29689,0,0,1,0,0,0,0
1716422400000,25.6,20.0,22.3,3,70.8,24.1,15.4,0.594586,0.686704,0.572924,4.572234,0.111724,28503,0,0,1,0,0,0,0


In [11]:
X, y = daily_ddf.drop(columns=["count"]), daily_ddf["count"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True, random_state=42)

In [12]:
display(X_train.compute(), y_train.compute())

Unnamed: 0_level_0,tempmax,tempmin,temp,conditions,humidity,windspeed,visibility,Distance to CMS,Distance to CHS,Distance to CIL,Distance to CIS,Distance to CB,Registration State,Plate Type,Violation Code,Vehicle Body Type,Vehicle Make,Issuing Agency,Violation County
Issue Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
1704153600000,7.7,-0.1,3.5,0,58.7,14.6,15.9,0.624092,0.770406,0.694539,5.639707,0.116056,0,0,1,0,1,0,0
1681084800000,17.9,5.7,11.3,0,49.7,16.7,15.9,0.599122,0.674377,0.539057,4.413902,0.101889,0,0,1,0,1,0,0
1711152000000,11.0,3.8,6.4,4,86.5,34.2,10.9,0.618336,0.726699,0.613188,4.994734,0.109155,0,0,3,0,1,0,0
1700524800000,8.8,3.7,6.5,5,71.1,28.7,15.3,0.596087,0.683229,0.547075,4.488524,0.099647,0,0,1,0,0,0,0
1715212800000,22.2,16.2,18.7,0,54.0,14.4,16.0,0.616560,0.697509,0.577836,4.718427,0.108054,0,0,3,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1712793600000,16.5,10.0,12.5,4,86.8,23.3,14.1,0.621512,0.685835,0.558747,4.628527,0.108062,0,0,3,0,0,0,0
1693526400000,25.5,17.8,21.4,0,51.1,21.5,15.9,0.590814,0.680176,0.528206,4.291356,0.099738,0,0,1,0,0,0,0
1701648000000,11.2,6.7,8.7,0,70.5,32.7,16.0,0.600065,0.710330,0.588772,4.735504,0.104834,0,0,1,0,1,0,0
1701734400000,6.7,4.3,5.7,2,57.4,14.2,16.0,0.587513,0.691421,0.588377,4.807496,0.103586,0,0,1,0,0,0,0


Issue Date
1704153600000    20629
1681084800000    29127
1711152000000     7751
1700524800000    27503
1715212800000    25843
                 ...  
1712793600000    22919
1693526400000    26225
1701648000000    21375
1701734400000    25547
1693958400000    23647
Name: count, Length: 562, dtype: int64

In [13]:
X_train_, y_train_ = X_train.values.persist(), y_train.values.persist()
X_test_, y_test_ = X_test.values.persist(), y_test.values.persist()

In [14]:
X_train_.compute_chunk_sizes(), y_train_.compute_chunk_sizes(), X_test_.compute_chunk_sizes(), y_test_.compute_chunk_sizes()

(dask.array<values, shape=(562, 19), dtype=float64, chunksize=(562, 19), chunktype=numpy.ndarray>,
 dask.array<values, shape=(562,), dtype=int64, chunksize=(562,), chunktype=numpy.ndarray>,
 dask.array<values, shape=(133, 19), dtype=float64, chunksize=(133, 19), chunktype=numpy.ndarray>,
 dask.array<values, shape=(133,), dtype=int64, chunksize=(133,), chunktype=numpy.ndarray>)

## Dask ML

In [15]:
from dask_ml.linear_model import LinearRegression

In [16]:
lr = LinearRegression(solver_kwargs={"normalize":False})
lr.fit(X_train_, y_train_)

In [17]:
preds = lr.predict(X_test_)

In [18]:
y_test_.compute(), preds.compute()

(array([29984, 28031, 28573,  7152,  6687, 20362,  4646, 27157, 15005,
        23293, 31709, 28448, 30771, 26071, 28365,  6574, 27334, 26800,
        23120, 23442, 31224, 30509, 19697, 22579, 26172, 30645,  7054,
        14323, 24818, 14495, 31529,  6469, 23765, 20701, 29026,  9561,
        16706, 27964, 21810, 27418, 25474, 28562, 26214, 27950, 29126,
        23083,  6195,  5575, 24996, 21356,  5470,  7265, 28996,  6566,
         6309, 26573, 27196, 19962, 22338, 15819, 25414, 17467, 22606,
        21464,  6541, 28000, 28569, 26489, 30849,  5150,  4794, 23870,
        27926,  3002, 18989, 30254, 19779, 27943, 28676, 15833,  2997,
         7336, 16820, 26634, 27112, 30154, 16203,  8358, 31456,  5825,
        25508, 29818,  3113,  8385, 28703, 16756, 18653, 32954, 28251,
        24752,  7898,  6512,  5981,  5672, 24991, 21950, 31873,  8241,
        31703, 27955, 14993, 15683, 27539, 19723, 29702, 28672, 28040,
        16313, 12436, 30736, 31290,  8416, 13528, 29127, 12192, 30086,
      

In [19]:
mean_squared_error(y_test_, preds, squared=False)

np.float64(6172.891868975004)

## XGBoost

In [20]:
import xgboost as xgb

In [21]:
clf = xgb.XGBRegressor( 
    n_estimators=1000,
    max_depth=10,
    learning_rate=0.3,
    n_jobs=-1,
    random_state=42
)
clf.fit(X_train_, y_train_)

In [22]:
preds = clf.predict(X_test_)

In [23]:
y_test_.compute(), preds

(array([29984, 28031, 28573,  7152,  6687, 20362,  4646, 27157, 15005,
        23293, 31709, 28448, 30771, 26071, 28365,  6574, 27334, 26800,
        23120, 23442, 31224, 30509, 19697, 22579, 26172, 30645,  7054,
        14323, 24818, 14495, 31529,  6469, 23765, 20701, 29026,  9561,
        16706, 27964, 21810, 27418, 25474, 28562, 26214, 27950, 29126,
        23083,  6195,  5575, 24996, 21356,  5470,  7265, 28996,  6566,
         6309, 26573, 27196, 19962, 22338, 15819, 25414, 17467, 22606,
        21464,  6541, 28000, 28569, 26489, 30849,  5150,  4794, 23870,
        27926,  3002, 18989, 30254, 19779, 27943, 28676, 15833,  2997,
         7336, 16820, 26634, 27112, 30154, 16203,  8358, 31456,  5825,
        25508, 29818,  3113,  8385, 28703, 16756, 18653, 32954, 28251,
        24752,  7898,  6512,  5981,  5672, 24991, 21950, 31873,  8241,
        31703, 27955, 14993, 15683, 27539, 19723, 29702, 28672, 28040,
        16313, 12436, 30736, 31290,  8416, 13528, 29127, 12192, 30086,
      

In [24]:
mean_squared_error(y_test_, preds, squared=False)

np.float64(3385.4289228766124)

## Scikit

In [25]:
from sklearn.linear_model import LinearRegression, Lasso, Ridge, SGDRegressor
from sklearn.ensemble import RandomForestRegressor

In [26]:
batch_size = 10
X_batches = [X_train_[i:i+batch_size] for i in range(0, len(X_train_), batch_size)]
y_batches = [y_train_[i:i+batch_size] for i in range(0, len(y_train_), batch_size)]

In [27]:
clf = SGDRegressor(
    max_iter=1000,
    tol=1e-3,
    random_state=42
)
for X_batch, y_batch in zip(X_batches, y_batches):
    clf.partial_fit(X_batch, y_batch)

In [28]:
preds = clf.predict(X_test_)

In [29]:
y_test_.compute(), preds

(array([29984, 28031, 28573,  7152,  6687, 20362,  4646, 27157, 15005,
        23293, 31709, 28448, 30771, 26071, 28365,  6574, 27334, 26800,
        23120, 23442, 31224, 30509, 19697, 22579, 26172, 30645,  7054,
        14323, 24818, 14495, 31529,  6469, 23765, 20701, 29026,  9561,
        16706, 27964, 21810, 27418, 25474, 28562, 26214, 27950, 29126,
        23083,  6195,  5575, 24996, 21356,  5470,  7265, 28996,  6566,
         6309, 26573, 27196, 19962, 22338, 15819, 25414, 17467, 22606,
        21464,  6541, 28000, 28569, 26489, 30849,  5150,  4794, 23870,
        27926,  3002, 18989, 30254, 19779, 27943, 28676, 15833,  2997,
         7336, 16820, 26634, 27112, 30154, 16203,  8358, 31456,  5825,
        25508, 29818,  3113,  8385, 28703, 16756, 18653, 32954, 28251,
        24752,  7898,  6512,  5981,  5672, 24991, 21950, 31873,  8241,
        31703, 27955, 14993, 15683, 27539, 19723, 29702, 28672, 28040,
        16313, 12436, 30736, 31290,  8416, 13528, 29127, 12192, 30086,
      

In [30]:
mean_squared_error(y_test_, preds, squared=False)

np.float64(11562747078980.26)