In [1]:
import numpy as np
import pandas as pd
from river import evaluate
from sail.models.river.forest import AdaptiveRandomForestRegressor
from sail.pipeline import SAILPipeline
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sail.transformers.river.preprocessing import StandardScaler
from sail.models.torch.rnn import RNNRegressor

#### Load Data


In [2]:
X = pd.read_csv("../../datasets/HDWF2.csv")

y = X["power"]
X.drop(["power", "time"], axis=1, inplace=True)

#### Model Definition


In [3]:
learner_gru = RNNRegressor(
    input_units=12,
    output_units=1,
    hidden_units=100,
    n_hidden_layers=3,
    lr=0.001,
    cell_type="GRU",
    verbose=0,
)

#### Create SAIL Pipeline


In [4]:
steps = [
    ("Imputer", SimpleImputer(missing_values=np.nan, strategy="mean")),
    ("standard_scalar", StandardScaler()),
    ("regressor", learner_gru),
]
sail_pipeline = SAILPipeline(
    steps=steps, scoring=["R2"], verbosity_level=1, verbosity_interval=100
)

#### Train Test Split


In [5]:
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.30, random_state=42
)

#### Start Incremental Training


In [6]:
y_preds = []
y_true = []
batch_size = 10

for start in range(0, 5000, batch_size):
    end = start + batch_size

    X = X_train.iloc[start:end]
    y = y_train.iloc[start:end]

    sail_pipeline.partial_fit(X, y)


    
>> Epoch: 1 | Samples Seen: 0 -------------------------------------------------------------------------------------

    
>> Epoch: 100 | Samples Seen: 990 -------------------------------------------------------------------------------------

    
>> Epoch: 200 | Samples Seen: 1990 -------------------------------------------------------------------------------------

    
>> Epoch: 300 | Samples Seen: 2990 -------------------------------------------------------------------------------------

    
>> Epoch: 400 | Samples Seen: 3990 -------------------------------------------------------------------------------------

    
>> Epoch: 500 | Samples Seen: 4990 -------------------------------------------------------------------------------------


#### Save SAIL Pipeline


In [7]:
sail_pipeline.save(".")

[2023-12-12 17:48:58:666] - INFO - SAIL (PyTorch) - Model saved successfully.


'./sail_pipeline'

#### Load SAIL Pipeline


In [8]:
sail_new_pipeline = SAILPipeline.load(".")



In [9]:
y_preds = []
y_true = []
batch_size = 10

for start in range(5001, 10768, batch_size):
    end = start + batch_size

    X = X_train.iloc[start:end]
    y = y_train.iloc[start:end]

    sail_new_pipeline.partial_fit(X, y)

  epoch    train_loss     dur
-------  ------------  ------
    501       [36m71.8888[0m  0.0018
    502      566.5175  0.0015
    503      131.7938  0.0012
    504      114.2689  0.0013
    505      248.9749  0.0013
    506      165.8501  0.0012
    507      131.9454  0.0014
    508      439.9035  0.0012
    509      119.8833  0.0012
    510      183.1458  0.0014
    511      341.1264  0.0011
    512      357.0931  0.0010
    513      175.2396  0.0011
    514      166.6846  0.0012
    515      185.7204  0.0010
    516      311.2947  0.0011
    517      263.8472  0.0012
    518      263.6564  0.0010
    519      250.8266  0.0011
    520      263.2386  0.0010
    521      611.0502  0.0012
    522      255.3256  0.0011
    523      201.8802  0.0013
    524      402.3539  0.0020
    525      145.4151  0.0011
    526      472.4382  0.0016
    527      230.3152  0.0013
    528       95.9053  0.0012
    529      201.5382  0.0013
    530       [36m40.2959[0m  0.0011
    531      150.1678 


    
>> Epoch: 600 | Samples Seen: 5990 -------------------------------------------------------------------------------------

    600      440.3103  0.0017




    601      349.0640  0.0040
    602      148.2638  0.0014
    603      418.7603  0.0011
    604      224.7227  0.0015
    605      493.5359  0.0010
    606      148.5124  0.0010
    607      443.9808  0.0013
    608      376.0975  0.0010
    609      286.0743  0.0011
    610       87.8358  0.0012
    611      458.2039  0.0012
    612      509.4558  0.0010
    613       55.7207  0.0011
    614      274.4852  0.0010
    615      374.0755  0.0010
    616      346.5542  0.0012
    617      791.0632  0.0012
    618      153.6690  0.0013
    619      272.5239  0.0010
    620      640.4315  0.0010
    621      360.7834  0.0010
    622      255.1215  0.0010
    623      287.2895  0.0010
    624      110.7721  0.0010
    625      228.7529  0.0010
    626      309.4091  0.0010
    627      585.6977  0.0010
    628      260.3463  0.0010
    629       75.5383  0.0013
    630       80.2034  0.0011
    631      264.2268  0.0010
    632      267.0144  0.0010
    633      461.2321  0.0011
    634   




    640      346.7634  0.0014
    641      205.3755  0.0010
    642      299.3793  0.0011
    643      216.1162  0.0011
    644      185.8291  0.0013
    645      179.3683  0.0009
    646      100.5648  0.0010
    647      399.3815  0.0011
    648      291.6077  0.0010
    649      214.4273  0.0010
    650      670.6654  0.0010
    651      244.8009  0.0011
    652      238.6045  0.0010
    653      157.2081  0.0010
    654       46.4317  0.0011
    655      656.1580  0.0012
    656      115.8816  0.0010
    657      110.9500  0.0010
    658      169.2008  0.0012
    659      137.0777  0.0010
    660      181.3766  0.0010
    661      420.1706  0.0009
    662      662.2025  0.0012
    663       97.8384  0.0012
    664      144.5750  0.0012
    665      602.7516  0.0012
    666      475.0326  0.0011
    667      360.8189  0.0010
    668      393.9659  0.0010
    669      240.7316  0.0011
    670      543.6960  0.0011
    671       61.9285  0.0010
    672      226.9124  0.0015
    673   


    
>> Epoch: 700 | Samples Seen: 6990 -------------------------------------------------------------------------------------

    700       43.7839  0.0018




    701       93.9216  0.0012
    702      356.7786  0.0014
    703      595.2950  0.0012
    704      252.1884  0.0012
    705      420.1998  0.0010
    706      593.7235  0.0013
    707      279.2716  0.0011
    708      398.4638  0.0010
    709      191.7598  0.0014
    710      191.1945  0.0011
    711       64.0211  0.0010
    712      525.0099  0.0011
    713      109.5452  0.0012
    714      167.2546  0.0010
    715      277.2362  0.0010
    716      227.8845  0.0011
    717      227.8096  0.0011
    718      424.6374  0.0010
    719      183.3428  0.0009
    720      275.0434  0.0012
    721      531.9860  0.0014
    722      132.1814  0.0010
    723      440.2501  0.0010
    724      181.9582  0.0011
    725      160.2368  0.0010
    726      302.0275  0.0011
    727       57.2387  0.0013
    728      324.3923  0.0013
    729      533.2961  0.0011
    730      208.3788  0.0011
    731      102.8284  0.0011
    732       94.8097  0.0015
    733      458.0962  0.0012
    734   




    739      296.2462  0.0013
    740      193.5154  0.0012
    741      436.1546  0.0010
    742      306.2261  0.0010
    743      631.2711  0.0010
    744      461.2880  0.0011
    745       99.6175  0.0010
    746      262.2674  0.0010
    747       77.2267  0.0011
    748      410.6066  0.0010
    749      224.6912  0.0011
    750      338.0707  0.0010
    751      443.6489  0.0011
    752      184.4331  0.0011
    753       98.7701  0.0011
    754      377.6933  0.0012
    755       78.2120  0.0011
    756      273.8278  0.0010
    757      284.1763  0.0011
    758      426.9717  0.0010
    759      318.6567  0.0009
    760      251.5077  0.0009
    761      372.7180  0.0011
    762      111.6412  0.0010
    763      179.5060  0.0010
    764       79.7558  0.0010
    765      301.2048  0.0011
    766      206.0454  0.0011
    767      311.8703  0.0010
    768      233.2793  0.0012
    769      216.7051  0.0013
    770      297.6706  0.0011
    771      141.9313  0.0010
    772   


    
>> Epoch: 800 | Samples Seen: 7990 -------------------------------------------------------------------------------------

    800      634.9795  0.0019




    801      311.1047  0.0012
    802      140.6671  0.0015
    803      205.1393  0.0021
    804      120.9071  0.0013
    805      148.9703  0.0013
    806      183.7708  0.0014
    807      288.2090  0.0011
    808      488.9022  0.0014
    809      474.5132  0.0013
    810      546.7972  0.0010
    811      229.5727  0.0010
    812      442.1812  0.0013
    813      575.5957  0.0012
    814      146.7220  0.0010
    815       74.1262  0.0010
    816      115.3224  0.0014
    817      210.6419  0.0011
    818      111.6284  0.0011
    819      267.1926  0.0015
    820     1024.3076  0.0011
    821      226.1991  0.0011
    822      299.2641  0.0011
    823      160.6245  0.0011
    824      464.4871  0.0010
    825       86.7547  0.0011
    826      397.3834  0.0012
    827      157.9718  0.0010
    828      157.3386  0.0010
    829      210.7821  0.0010
    830      143.9144  0.0015
    831       95.5230  0.0012
    832       81.3349  0.0010
    833      510.6271  0.0013
    834   




    837      187.3599  0.0017
    838      628.4269  0.0011
    839      140.8352  0.0011
    840      180.4830  0.0012
    841      142.1272  0.0010
    842      247.5705  0.0010
    843      244.1903  0.0011
    844      277.5549  0.0010
    845      144.3680  0.0010
    846      628.9091  0.0010
    847      460.6104  0.0012
    848      459.8906  0.0010
    849      650.7899  0.0011
    850      242.5114  0.0010
    851      125.2834  0.0012
    852      116.7427  0.0013
    853      244.9520  0.0011
    854      362.7706  0.0010
    855      132.3168  0.0013
    856      236.7959  0.0010
    857       79.9267  0.0011
    858      244.6731  0.0014
    859      364.3484  0.0013
    860      495.6691  0.0018
    861      201.7173  0.0012
    862      913.9410  0.0012
    863      170.9771  0.0013
    864      204.0663  0.0014
    865      476.0432  0.0014
    866      407.6344  0.0011
    867      235.9461  0.0014
    868      142.5303  0.0012
    869      211.9891  0.0011
    870   


    
>> Epoch: 900 | Samples Seen: 8990 -------------------------------------------------------------------------------------

    900      101.4082  0.0018




    901      286.9297  0.0018
    902       76.5076  0.0011
    903      209.5379  0.0010
    904      247.8264  0.0013
    905      271.0085  0.0012
    906      107.1674  0.0010
    907      256.2899  0.0011
    908      124.0077  0.0010
    909      152.4185  0.0010
    910      226.1337  0.0011
    911      189.8914  0.0013
    912      449.5698  0.0012
    913      253.4840  0.0011
    914      221.0208  0.0010
    915      259.5727  0.0013
    916      159.2638  0.0009
    917      287.8112  0.0009
    918      509.5189  0.0012
    919      155.2549  0.0011
    920      222.0626  0.0010
    921       61.9921  0.0010
    922      113.8377  0.0011
    923      672.6208  0.0011
    924      404.1887  0.0011
    925      380.6064  0.0010
    926      112.1247  0.0011
    927      214.8211  0.0010
    928      132.5059  0.0009
    929      110.3840  0.0012
    930       84.6983  0.0011
    931      663.6503  0.0010
    932      127.3226  0.0014
    933      131.6179  0.0028
    934   




    939      327.7510  0.0014
    940      444.9112  0.0012
    941      134.5066  0.0011
    942      221.5326  0.0011
    943      127.3707  0.0010
    944      291.4095  0.0010
    945      202.3688  0.0011
    946       63.9056  0.0010
    947      438.1408  0.0011
    948      259.5609  0.0012
    949       71.3134  0.0011
    950      153.4194  0.0010
    951      361.7773  0.0010
    952      373.4185  0.0011
    953      231.0882  0.0011
    954      217.1312  0.0010
    955      227.2780  0.0010
    956      351.4692  0.0011
    957      103.0241  0.0010
    958      490.7227  0.0009
    959      184.6233  0.0010
    960      223.3168  0.0012
    961      322.1734  0.0010
    962      158.7211  0.0011
    963      221.2236  0.0010
    964      554.4438  0.0011
    965      104.4600  0.0010
    966      123.5109  0.0010
    967       85.5661  0.0012
    968       47.3755  0.0012
    969      614.3285  0.0010
    970      203.6591  0.0010
    971      375.5453  0.0012
    972   


    
>> Epoch: 1000 | Samples Seen: 9990 -------------------------------------------------------------------------------------

   1000      401.9387  0.0020




   1001      315.9359  0.0013
   1002      216.0316  0.0010
   1003      361.9715  0.0011
   1004      156.3135  0.0016
   1005       [36m37.7104[0m  0.0035
   1006      125.2292  0.0011
   1007      286.9528  0.0012
   1008      119.6226  0.0012
   1009      143.3138  0.0011
   1010      663.3904  0.0011
   1011      180.2863  0.0012
   1012      184.1097  0.0011
   1013      112.3237  0.0010
   1014      106.8147  0.0010
   1015      614.5563  0.0011
   1016      107.2000  0.0010
   1017      315.1335  0.0010
   1018      255.0704  0.0011
   1019      174.6810  0.0012
   1020      312.7851  0.0010
   1021      227.2909  0.0010
   1022      300.5590  0.0013
   1023      231.7662  0.0012
   1024      192.6102  0.0011
   1025      252.9846  0.0010
   1026       69.8974  0.0013
   1027      194.1710  0.0011
   1028      611.2283  0.0010
   1029      309.7976  0.0013
   1030      198.2696  0.0010
   1031       77.0998  0.0011
   1032      689.9325  0.0011
   1033      324.5117  0.0015
 




   1040      468.0385  0.0016
   1041      217.8708  0.0011
   1042      115.3591  0.0010
   1043      142.9978  0.0011
   1044      512.0624  0.0011
   1045      260.4395  0.0011
   1046       [36m37.6605[0m  0.0010
   1047      392.6596  0.0015
   1048       [36m22.9526[0m  0.0011
   1049      155.6678  0.0011
   1050      129.3821  0.0011
   1051      181.1139  0.0011
   1052      138.3929  0.0012
   1053      286.3607  0.0011
   1054      353.3319  0.0012
   1055      112.5617  0.0015
   1056      231.1056  0.0011
   1057      265.6822  0.0010
   1058      282.5490  0.0012
   1059      152.9145  0.0013
   1060      493.5140  0.0010
   1061      155.3964  0.0013
   1062      213.4594  0.0011
   1063       83.6210  0.0011
   1064      157.3374  0.0010
   1065      245.5015  0.0011
   1066      302.1479  0.0010
   1067      372.0648  0.0014
   1068      315.1951  0.0011
   1069      443.5312  0.0012
   1070      384.0384  0.0011
   1071      265.9665  0.0012
   1072      205.3786 

#### Make Prediction on Hold out set


In [10]:
y_preds = []
y_true = []
batch_size = 100

for start in range(0, X_test.shape[0], batch_size):
    end = start + batch_size

    X = X_test.iloc[start:end]
    y = y_test.iloc[start:end]

    preds = sail_new_pipeline.predict(X)
    y_preds.extend(list(preds))
    y_true.extend(list(y))

#### Final Score


In [11]:
sail_new_pipeline.get_progressive_score

0.6283018631213579

In [12]:
import plotly.express as px

df = pd.DataFrame({"y_true": y_true, "y_preds": y_preds}).head(100)
fig = px.line(df, y=["y_true", "y_preds"], title="")
fig.show()