In [1]:
import numpy as np
import pandas as pd
from river import evaluate
from sail.models.river.forest import AdaptiveRandomForestRegressor
from sail.pipeline import SAILPipeline
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sail.transformers.river.preprocessing import StandardScaler
from sail.models.torch.rnn import RNNRegressor

#### Load Data


In [2]:
X = pd.read_csv("../datasets/HDWF2.csv")

y = X["power"]
X.drop(["power", "time"], axis=1, inplace=True)

#### Model Definition


In [3]:
learner_gru = RNNRegressor(
    input_units=12,
    output_units=1,
    hidden_units=100,
    n_hidden_layers=3,
    lr=0.001,
    cell_type="GRU",
    verbose=0
)

#### Create SAIL Pipeline


In [4]:
steps = [
    ("Imputer", SimpleImputer(missing_values=np.nan, strategy="mean")),
    ("standard_scalar", StandardScaler()),
    ("regressor", learner_gru),
]
sail_pipeline = SAILPipeline(steps=steps, scoring="R2", verbosity_level=1, verbosity_interval=100)

#### Train Test Split


In [5]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30, random_state=42)

#### Start Incremental Training


In [6]:
y_preds = []
y_true = []
batch_size = 10

for start in range(0, 5000, batch_size):

    end = start + batch_size

    X = X_train.iloc[start:end]
    y = y_train.iloc[start:end]

    sail_pipeline.partial_fit(X, y)


>>> Epoch: 1 | Samples Seen: 0 -------------------------------------------------------------------------------------





>>> Epoch: 100 | Samples Seen: 990 -------------------------------------------------------------------------------------





>>> Epoch: 200 | Samples Seen: 1990 -------------------------------------------------------------------------------------





>>> Epoch: 300 | Samples Seen: 2990 -------------------------------------------------------------------------------------





>>> Epoch: 400 | Samples Seen: 3990 -------------------------------------------------------------------------------------





>>> Epoch: 500 | Samples Seen: 4990 -------------------------------------------------------------------------------------




#### Save SAIL Pipeline


In [7]:
sail_pipeline.save(".")

[2023-09-25 19:41:25:136] - INFO - SAIL (PyTorch) - Model saved successfully.


'./sail_pipeline'

#### Load SAIL Pipeline


In [8]:
sail_new_pipeline = SAILPipeline.load(".")



In [9]:
y_preds = []
y_true = []
batch_size = 10

for start in range(5001, 10768, batch_size):

    end = start + batch_size

    X = X_train.iloc[start:end]
    y = y_train.iloc[start:end]

    sail_new_pipeline.partial_fit(X, y)

  epoch    train_loss     dur
-------  ------------  ------
    501       [36m70.9320[0m  0.0012
    502      586.7608  0.0012
    503      130.7252  0.0013
    504      117.8795  0.0012
    505      244.0257  0.0012
    506      159.6151  0.0015
    507      129.4393  0.0015
    508      459.9520  0.0013
    509      115.4775  0.0014
    510      183.6156  0.0018
    511      328.2495  0.0013
    512      352.2633  0.0013
    513      170.5602  0.0011
    514      167.8503  0.0014
    515      185.4431  0.0011
    516      315.9188  0.0018
    517      268.6629  0.0014
    518      260.9680  0.0013
    519      257.5792  0.0013
    520      271.9582  0.0013
    521      615.4316  0.0012
    522      252.1417  0.0013
    523      206.1322  0.0015
    524      396.0060  0.0011
    525      153.1999  0.0013
    526      470.2961  0.0012
    527      232.6331  0.0013
    528       87.4382  0.0012
    529      207.4537  0.0014
    530       [36m40.7682[0m  0.0015
    531      153.7080 



    600      443.7640  0.0026




    601      348.9725  0.0017
    602      150.3206  0.0011
    603      317.2322  0.0011
    604      223.5169  0.0013
    605      502.6994  0.0013
    606      138.5976  0.0011
    607      454.4296  0.0011
    608      388.8708  0.0013
    609      268.1564  0.0012
    610       90.3049  0.0014
    611      468.2427  0.0010
    612      534.9092  0.0010
    613       58.8251  0.0012
    614      274.3161  0.0011
    615      357.8184  0.0010
    616      349.5341  0.0012
    617      787.9042  0.0011
    618      148.8415  0.0012
    619      268.3844  0.0012
    620      664.3602  0.0011
    621      369.7575  0.0011
    622      251.5016  0.0011
    623      277.8998  0.0011
    624      108.7697  0.0012
    625      234.2333  0.0011
    626      303.1529  0.0014
    627      597.9280  0.0012
    628      259.2596  0.0013
    629       80.1692  0.0011
    630       79.5434  0.0012
    631      250.0045  0.0010
    632      280.6478  0.0011
    633      498.7519  0.0013
    634   




    637      257.5048  0.0016
    638      301.8109  0.0014
    639      218.6095  0.0011
    640      337.5108  0.0013
    641      199.8844  0.0013
    642      304.7297  0.0013
    643      217.9769  0.0012
    644      184.2070  0.0012
    645      183.1130  0.0015
    646       91.2832  0.0016
    647      400.7409  0.0013
    648      290.3233  0.0012
    649      209.1805  0.0010
    650      677.3544  0.0012
    651      255.7775  0.0013
    652      235.4156  0.0013
    653      153.5880  0.0015
    654       50.3079  0.0014
    655      668.4048  0.0011
    656      118.8406  0.0019
    657      108.9019  0.0016
    658      167.5697  0.0013
    659      141.1816  0.0012
    660      177.0083  0.0013
    661      419.6114  0.0013
    662      682.1870  0.0013
    663       97.0674  0.0062
    664      149.3628  0.0020
    665      610.4645  0.0012
    666      498.9125  0.0016
    667      355.9485  0.0018
    668      390.2417  0.0015
    669      237.9058  0.0017
    670   



    700       46.4959  0.0027




    701      100.5438  0.0018
    702      357.0038  0.0012
    703      597.3926  0.0013
    704      261.8702  0.0011
    705      425.2609  0.0011
    706      589.3289  0.0013
    707      286.1112  0.0014
    708      407.6514  0.0012
    709      196.1075  0.0014
    710      203.4310  0.0012
    711       67.4666  0.0012
    712      523.2233  0.0054
    713       97.5931  0.0021
    714      174.8418  0.0015
    715      278.0437  0.0013
    716      244.4631  0.0014
    717      228.7190  0.0015
    718      425.0188  0.0011
    719      179.7062  0.0014
    720      274.2137  0.0013
    721      546.3796  0.0012
    722      125.2840  0.0013
    723      434.5984  0.0013
    724      177.3108  0.0013
    725      175.7055  0.0013
    726      297.7449  0.0013
    727       59.4800  0.0013
    728      326.1380  0.0012
    729      530.6202  0.0014
    730      210.7135  0.0011





    731      101.3587  0.0016
    732       84.1113  0.0013
    733      458.8974  0.0010
    734      706.1378  0.0013
    735      114.6350  0.0011
    736      227.9302  0.0010
    737      197.9959  0.0013
    738      437.9813  0.0011
    739      291.1913  0.0011
    740      191.1304  0.0012
    741      432.9121  0.0012
    742      313.9677  0.0011
    743      632.6880  0.0011
    744      452.5744  0.0013
    745      104.3309  0.0012
    746      255.5489  0.0014
    747       70.6338  0.0012
    748      439.4312  0.0031
    749      231.8243  0.0027
    750      332.3038  0.0018
    751      452.7425  0.0013
    752      198.5415  0.0012
    753       99.7136  0.0013
    754      360.0357  0.0011
    755       79.6067  0.0012
    756      282.8529  0.0011
    757      303.9403  0.0012
    758      425.5388  0.0011
    759      313.6306  0.0011
    760      265.1704  0.0012
    761      380.7385  0.0011
    762      101.3082  0.0011
    763      184.7406  0.0011
    764   



    800      644.1229  0.0018




    801      302.8304  0.0017
    802      138.0804  0.0013
    803      210.5020  0.0011
    804      120.7412  0.0012
    805      144.8648  0.0014
    806      184.6453  0.0012
    807      283.4814  0.0011
    808      504.2775  0.0012
    809      463.8257  0.0010
    810      553.6019  0.0013
    811      220.4515  0.0013
    812      467.4425  0.0011
    813      567.6649  0.0013
    814      142.1114  0.0013
    815       71.8354  0.0014
    816      119.0166  0.0014
    817      209.8362  0.0014
    818      107.0396  0.0011
    819      268.5142  0.0013
    820     1061.1974  0.0013
    821      238.1390  0.0012
    822      299.5010  0.0012
    823      159.5830  0.0011
    824      460.8300  0.0011
    825       84.1445  0.0012
    826      396.1856  0.0011
    827      158.7469  0.0011
    828      163.0041  0.0011
    829      217.3508  0.0010
    830      150.1985  0.0014
    831       88.2353  0.0011
    832       86.7213  0.0012
    833      516.4671  0.0015
    834   




    837      172.9320  0.0015
    838      627.5628  0.0025
    839      134.3978  0.0014
    840      175.4955  0.0014
    841      142.2049  0.0013
    842      250.5656  0.0011
    843      226.7442  0.0013
    844      277.3780  0.0014
    845      145.9057  0.0010
    846      610.7829  0.0012
    847      486.2230  0.0010
    848      459.7150  0.0014
    849      656.1043  0.0015
    850      245.9995  0.0010
    851      120.3720  0.0011
    852      111.8147  0.0010
    853      241.5232  0.0012
    854      353.9413  0.0013
    855      135.5202  0.0012
    856      238.3205  0.0012
    857       86.8420  0.0012
    858      255.1384  0.0014
    859      386.4369  0.0011
    860      505.2766  0.0011
    861      209.3084  0.0014
    862      896.3034  0.0011
    863      164.4809  0.0014
    864      208.1432  0.0010
    865      490.3300  0.0013
    866      414.3300  0.0011
    867      236.3271  0.0010
    868      138.2134  0.0013
    869      212.8629  0.0010
    870   



    900      100.7197  0.0022




    901      287.5206  0.0016
    902       80.2746  0.0025
    903      203.3729  0.0014
    904      277.7468  0.0016
    905      273.3478  0.0012
    906      104.5830  0.0012
    907      273.8094  0.0010
    908      126.7740  0.0011
    909      153.7816  0.0011
    910      243.4235  0.0011
    911      177.2638  0.0011
    912      457.6922  0.0011
    913      264.3716  0.0012
    914      209.0754  0.0014
    915      260.8610  0.0013
    916      159.2984  0.0010
    917      295.6903  0.0012
    918      524.4028  0.0012
    919      160.6889  0.0011
    920      229.9141  0.0013
    921       74.5556  0.0011
    922      128.4685  0.0011
    923      691.3962  0.0011
    924      404.1835  0.0013
    925      383.0308  0.0024
    926      114.5426  0.0015
    927      220.7202  0.0017
    928      141.0471  0.0011
    929      101.3788  0.0021
    930       77.0520  0.0021





    931      641.7318  0.0019
    932      123.1069  0.0014
    933      127.9581  0.0013
    934      350.4210  0.0010
    935      101.3175  0.0013
    936      538.1301  0.0013
    937      184.9104  0.0010
    938      244.3264  0.0012
    939      332.4195  0.0011
    940      433.2874  0.0011
    941      136.2409  0.0011
    942      226.9537  0.0012
    943      144.9733  0.0014
    944      287.9662  0.0012
    945      205.5960  0.0013
    946       65.0842  0.0013
    947      444.3576  0.0010
    948      263.4344  0.0012
    949       66.0185  0.0012
    950      159.7831  0.0012
    951      352.3156  0.0012
    952      376.9499  0.0010
    953      220.4881  0.0012
    954      223.7531  0.0010
    955      218.6310  0.0011
    956      351.3353  0.0012
    957      104.1450  0.0012
    958      486.8969  0.0012
    959      171.9056  0.0010
    960      233.3692  0.0013
    961      318.0825  0.0014
    962      163.1332  0.0010
    963      214.9843  0.0011
    964   



   1000      424.4919  0.0026




   1001      331.5819  0.0013
   1002      209.0488  0.0012
   1003      355.5238  0.0011
   1004      161.6272  0.0013
   1005       [36m33.9827[0m  0.0011
   1006      129.6656  0.0013
   1007      327.0358  0.0010
   1008       97.8086  0.0012
   1009      155.4717  0.0012
   1010      662.3262  0.0011
   1011      169.8103  0.0012
   1012      177.8635  0.0011
   1013      121.1232  0.0011
   1014      112.3965  0.0013
   1015      650.4752  0.0009
   1016      106.8949  0.0012
   1017      285.3104  0.0012
   1018      250.4707  0.0012
   1019      174.1011  0.0013
   1020      311.9372  0.0013
   1021      221.0600  0.0012
   1022      280.2750  0.0012
   1023      237.4093  0.0013
   1024      198.7411  0.0011
   1025      255.5565  0.0011
   1026       72.1644  0.0011
   1027      179.8704  0.0011
   1028      643.7239  0.0014
   1029      362.3438  0.0011
   1030      191.2144  0.0013
   1031       86.5478  0.0014
   1032      729.7314  0.0011
   1033      317.8845  0.0014
 




   1037      306.2672  0.0015
   1038      594.6248  0.0015
   1039      242.4502  0.0013
   1040      506.4570  0.0012
   1041      217.6476  0.0036
   1042      130.7790  0.0022
   1043      142.6466  0.0018
   1044      500.3757  0.0013
   1045      244.4154  0.0011
   1046       41.6832  0.0013
   1047      413.6937  0.0011
   1048       [36m25.6242[0m  0.0014
   1049      140.9920  0.0012
   1050      153.2394  0.0012
   1051      183.5710  0.0012
   1052      147.4456  0.0012
   1053      280.1154  0.0011
   1054      351.7817  0.0016
   1055      107.6402  0.0015
   1056      252.4988  0.0014
   1057      279.8991  0.0014
   1058      306.8587  0.0012
   1059      156.1654  0.0011
   1060      503.7505  0.0013
   1061      165.3148  0.0017
   1062      205.8524  0.0011
   1063       71.7136  0.0013
   1064      149.1614  0.0012
   1065      297.7599  0.0012
   1066      329.0806  0.0014
   1067      402.5633  0.0013
   1068      333.3676  0.0014
   1069      451.2728  0.0013
 

#### Make Prediction on Hold out set


In [10]:
y_preds = []
y_true = []
batch_size = 100

for start in range(0, X_test.shape[0], batch_size):

    end = start + batch_size

    X = X_test.iloc[start:end]
    y = y_test.iloc[start:end]

    preds = sail_new_pipeline.predict(X)
    y_preds.extend(list(preds))
    y_true.extend(list(y))

#### Final Score


In [11]:
sail_new_pipeline.get_progressive_score

0.6285566617151723

In [12]:
sail_new_pipeline._scorer.score(y_true, y_preds)



0.7169338953897136

In [13]:
import plotly.express as px

df = pd.DataFrame({"y_true": y_true, "y_preds": y_preds}).head(100)
fig = px.line(df, y=["y_true", "y_preds"], title='')
fig.show()