In [1]:
import numpy as np
import pandas as pd
from river import evaluate
from sail.models.river.forest import AdaptiveRandomForestRegressor
from sail.pipeline import SAILPipeline
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sail.transformers.river.preprocessing import StandardScaler
from sail.models.torch.rnn import RNNRegressor

#### Load Data


In [2]:
X = pd.read_csv("../datasets/HDWF2.csv")

y = X["power"]
X.drop(["power", "time"], axis=1, inplace=True)

#### Model Definition


In [3]:
learner_gru = RNNRegressor(
    input_units=12,
    output_units=1,
    hidden_units=100,
    n_hidden_layers=3,
    lr=0.001,
    cell_type="GRU",
    verbose=0
)

#### Create SAIL Pipeline


In [4]:
steps = [
    ("Imputer", SimpleImputer(missing_values=np.nan, strategy="mean")),
    ("standard_scalar", StandardScaler()),
    ("regressor", learner_gru),
]
sail_pipeline = SAILPipeline(steps=steps, scoring="R2", verbosity_level=1, verbosity_interval=100)

#### Train Test Split


In [5]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30, random_state=42)

#### Start Incremental Training


In [6]:
y_preds = []
y_true = []
batch_size = 10

for start in range(0, 5000, batch_size):

    end = start + batch_size

    X = X_train.iloc[start:end]
    y = y_train.iloc[start:end]

    sail_pipeline.partial_fit(X, y)


    
>> Epoch: 1 | Samples Seen: 0 -------------------------------------------------------------------------------------

    
>> Epoch: 100 | Samples Seen: 990 -------------------------------------------------------------------------------------

    
>> Epoch: 200 | Samples Seen: 1990 -------------------------------------------------------------------------------------

    
>> Epoch: 300 | Samples Seen: 2990 -------------------------------------------------------------------------------------

    
>> Epoch: 400 | Samples Seen: 3990 -------------------------------------------------------------------------------------

    
>> Epoch: 500 | Samples Seen: 4990 -------------------------------------------------------------------------------------


#### Save SAIL Pipeline


In [7]:
sail_pipeline.save(".")

[2023-10-02 00:50:33:111] - INFO - SAIL (PyTorch) - Model saved successfully.


'./sail_pipeline'

#### Load SAIL Pipeline


In [8]:
sail_new_pipeline = SAILPipeline.load(".")



In [9]:
y_preds = []
y_true = []
batch_size = 10

for start in range(5001, 10768, batch_size):

    end = start + batch_size

    X = X_train.iloc[start:end]
    y = y_train.iloc[start:end]

    sail_new_pipeline.partial_fit(X, y)

  epoch    train_loss     dur
-------  ------------  ------
    501       [36m72.9539[0m  0.0011
    502      566.4787  0.0011
    503      132.8276  0.0011
    504      115.9440  0.0013
    505      245.5162  0.0012
    506      170.1505  0.0013
    507      136.3387  0.0011
    508      441.2717  0.0012
    509      118.4934  0.0013
    510      180.9164  0.0010
    511      327.4562  0.0012
    512      367.7494  0.0010
    513      180.3380  0.0011
    514      165.7492  0.0010
    515      192.4044  0.0011
    516      302.7406  0.0013
    517      273.7800  0.0010
    518      251.5001  0.0010
    519      248.4970  0.0010
    520      276.3235  0.0012
    521      599.9944  0.0010
    522      254.7390  0.0012
    523      208.6125  0.0011
    524      398.7694  0.0011
    525      144.0548  0.0013
    526      478.4481  0.0013
    527      229.7571  0.0014
    528      100.2355  0.0011
    529      203.9711  0.0011
    530       [36m39.9040[0m  0.0010
    531      139.7032 


    
>> Epoch: 600 | Samples Seen: 5990 -------------------------------------------------------------------------------------

    600      430.7782  0.0042




    601      336.8963  0.0023
    602      142.4220  0.0013
    603      260.9496  0.0013
    604      211.0201  0.0012
    605      482.9416  0.0010
    606      151.5855  0.0012
    607      447.2303  0.0012
    608      396.9992  0.0011
    609      278.3117  0.0013
    610       88.2425  0.0011
    611      451.3175  0.0011
    612      522.4875  0.0013
    613       54.5248  0.0012
    614      265.1602  0.0012
    615      338.1915  0.0012
    616      363.1442  0.0009
    617      800.0979  0.0012
    618      153.4140  0.0011
    619      273.0910  0.0010
    620      630.5610  0.0012
    621      376.4488  0.0014
    622      248.1106  0.0014
    623      274.0689  0.0010
    624      111.2421  0.0011
    625      232.5437  0.0011
    626      294.4535  0.0011
    627      604.3905  0.0011
    628      248.6763  0.0012
    629       82.2511  0.0011
    630       80.4971  0.0010
    631      266.3600  0.0011
    632      281.4179  0.0010
    633      458.5069  0.0010
    634   




    636      228.6299  0.0010
    637      260.6375  0.0011
    638      285.9547  0.0011
    639      215.5296  0.0011
    640      331.5695  0.0012
    641      181.8684  0.0010
    642      296.2509  0.0009
    643      225.8499  0.0012
    644      185.2162  0.0028
    645      181.9296  0.0139
    646      104.3466  0.0015
    647      411.1776  0.0015
    648      289.3484  0.0010
    649      217.6430  0.0012
    650      682.6847  0.0011
    651      246.1540  0.0010
    652      235.8898  0.0011
    653      160.3121  0.0012
    654       48.5780  0.0011
    655      668.5575  0.0009
    656      105.2803  0.0011
    657      111.5759  0.0011
    658      168.1333  0.0011
    659      141.1505  0.0011
    660      183.1449  0.0011
    661      422.4329  0.0010
    662      680.1699  0.0011
    663       97.1034  0.0012
    664      137.3518  0.0012
    665      590.4204  0.0012
    666      484.4846  0.0010
    667      401.9491  0.0012
    668      383.9173  0.0013
    669   


    
>> Epoch: 700 | Samples Seen: 6990 -------------------------------------------------------------------------------------

    700       51.5927  0.0023




    701       94.7694  0.0015
    702      346.2664  0.0010
    703      611.2337  0.0011
    704      259.5265  0.0011
    705      408.0291  0.0013
    706      599.0249  0.0011
    707      270.8023  0.0013
    708      401.4578  0.0010
    709      176.3421  0.0013
    710      180.8723  0.0010
    711       69.3829  0.0011
    712      526.5890  0.0011
    713      108.9454  0.0010
    714      176.0854  0.0010
    715      262.8309  0.0012
    716      226.5933  0.0012
    717      244.9335  0.0012
    718      432.9828  0.0010
    719      197.5834  0.0010
    720      260.1613  0.0012
    721      517.3994  0.0011
    722      118.3884  0.0010
    723      447.7459  0.0010
    724      168.6021  0.0012
    725      165.9494  0.0014
    726      307.0113  0.0010
    727       58.7620  0.0010
    728      338.6602  0.0012
    729      534.6696  0.0012
    730      199.4066  0.0013
    731      108.5594  0.0012
    732       94.0160  0.0011





    733      459.7415  0.0015
    734      696.5851  0.0013
    735      117.2498  0.0013
    736      244.2136  0.0011
    737      191.4494  0.0012
    738      465.8113  0.0011
    739      282.7079  0.0014
    740      207.5585  0.0011
    741      429.1466  0.0011
    742      288.4987  0.0013
    743      630.6783  0.0012
    744      460.5118  0.0013
    745      101.2465  0.0014
    746      265.9203  0.0010
    747       74.7211  0.0010
    748      375.6874  0.0011
    749      231.3814  0.0010
    750      345.0389  0.0009
    751      440.7422  0.0011
    752      195.2692  0.0012
    753       93.9874  0.0011
    754      385.5423  0.0010
    755       81.9983  0.0011
    756      259.4724  0.0012
    757      272.3039  0.0012
    758      433.9697  0.0010
    759      321.1351  0.0013
    760      246.1304  0.0012
    761      385.7003  0.0010
    762      105.0182  0.0012
    763      179.9996  0.0011
    764       77.2933  0.0012
    765      315.8700  0.0010
    766  


    
>> Epoch: 800 | Samples Seen: 7990 -------------------------------------------------------------------------------------

    800      619.3714  0.0024




    801      318.8587  0.0016
    802      139.9602  0.0010
    803      200.1201  0.0009
    804      121.2439  0.0010
    805      154.8347  0.0012
    806      183.4516  0.0012
    807      289.4830  0.0013
    808      465.3514  0.0011
    809      464.9254  0.0013
    810      536.6431  0.0013
    811      230.0855  0.0011
    812      439.1803  0.0010
    813      565.8610  0.0010
    814      149.3311  0.0010
    815       74.4979  0.0012
    816      118.0696  0.0011
    817      210.7413  0.0011
    818      115.1063  0.0010
    819      274.1724  0.0011
    820     1062.6962  0.0011
    821      243.4249  0.0011
    822      310.4833  0.0022
    823      140.9797  0.0059
    824      464.9281  0.0043
    825       80.2534  0.0013
    826      399.0772  0.0042
    827      167.3619  0.0010
    828      158.6154  0.0011
    829      202.4902  0.0010
    830      140.7419  0.0009
    831       94.7608  0.0010
    832       86.6501  0.0012
    833      515.3177  0.0011
    834   




    835      223.6438  0.0011
    836      213.8431  0.0011
    837      180.2852  0.0012
    838      613.5580  0.0010
    839      140.7603  0.0010
    840      180.6964  0.0012
    841      137.8001  0.0010
    842      258.0959  0.0011
    843      229.5401  0.0012
    844      263.4631  0.0012
    845      147.5313  0.0011
    846      600.3036  0.0012
    847      464.9890  0.0012
    848      488.6269  0.0011
    849      665.9220  0.0012
    850      253.7405  0.0011
    851      120.5912  0.0011
    852      108.8511  0.0010
    853      242.0821  0.0012
    854      346.2993  0.0010
    855      140.9605  0.0010
    856      229.4375  0.0011
    857       80.6625  0.0010
    858      260.4985  0.0011
    859      376.0689  0.0010
    860      497.8981  0.0010
    861      211.7604  0.0010
    862      905.5911  0.0010
    863      161.7211  0.0011
    864      195.9790  0.0010
    865      471.2103  0.0010
    866      404.7133  0.0012
    867      230.4187  0.0010
    868   


    
>> Epoch: 900 | Samples Seen: 8990 -------------------------------------------------------------------------------------

    900      100.9520  0.0023




    901      284.9699  0.0015
    902       70.0352  0.0010
    903      206.0066  0.0011
    904      284.6319  0.0011
    905      276.6690  0.0010
    906      101.3370  0.0010
    907      251.8748  0.0011
    908      125.8174  0.0019
    909      140.8578  0.0029
    910      220.3954  0.0056
    911      183.5663  0.0012
    912      443.7536  0.0011
    913      260.6442  0.0011
    914      222.4535  0.0011
    915      245.0365  0.0014
    916      169.0226  0.0012
    917      289.9501  0.0011
    918      533.3428  0.0012
    919      148.2187  0.0011
    920      224.9592  0.0011
    921       73.7206  0.0012
    922      124.1651  0.0013
    923      656.4856  0.0011
    924      404.0936  0.0012
    925      373.0768  0.0011
    926      116.7580  0.0010
    927      220.1784  0.0014
    928      127.0661  0.0011
    929      106.1712  0.0011
    930       79.4085  0.0012
    931      642.7683  0.0010
    932      123.4561  0.0013
    933      137.0844  0.0013





    934      392.2577  0.0015
    935      114.5403  0.0012
    936      538.0921  0.0012
    937      186.8890  0.0011
    938      243.4598  0.0010
    939      341.0057  0.0012
    940      409.9100  0.0010
    941      131.0300  0.0011
    942      214.0565  0.0010
    943      129.0679  0.0009
    944      293.6115  0.0013
    945      212.8603  0.0010
    946       65.6198  0.0010
    947      448.0512  0.0012
    948      255.5079  0.0011
    949       63.8905  0.0012
    950      160.9198  0.0011
    951      360.1900  0.0010
    952      368.5373  0.0012
    953      213.8929  0.0012
    954      220.0767  0.0013
    955      225.3640  0.0010
    956      333.8830  0.0011
    957       99.0400  0.0012
    958      487.6265  0.0011
    959      176.7005  0.0013
    960      217.6396  0.0012
    961      321.4305  0.0012
    962      172.1451  0.0010
    963      218.4517  0.0010
    964      543.8469  0.0010
    965       92.5721  0.0010
    966      124.9017  0.0009
    967   


    
>> Epoch: 1000 | Samples Seen: 9990 -------------------------------------------------------------------------------------

   1000      381.6263  0.0027




   1001      312.1396  0.0053
   1002      217.7592  0.0036
   1003      361.2621  0.0014
   1004      162.3637  0.0012
   1005       40.6156  0.0011
   1006      119.1513  0.0013
   1007      302.9351  0.0012
   1008      113.5313  0.0011
   1009      146.1936  0.0012
   1010      632.6141  0.0010
   1011      176.1833  0.0011
   1012      173.5949  0.0011
   1013      113.5121  0.0012
   1014       96.1610  0.0012
   1015      614.6171  0.0012
   1016      106.5000  0.0011
   1017      307.6331  0.0010
   1018      255.4276  0.0010
   1019      177.5762  0.0010
   1020      298.8627  0.0010
   1021      233.9860  0.0010
   1022      292.8723  0.0012
   1023      245.9397  0.0012
   1024      196.9597  0.0013
   1025      258.2346  0.0011
   1026       71.7609  0.0011
   1027      196.3288  0.0012
   1028      587.0775  0.0010
   1029      319.2612  0.0011
   1030      185.8290  0.0010
   1031       80.0009  0.0010
   1032      699.5176  0.0010
   1033      333.9524  0.0010
   1034   




   1035      385.1044  0.0012
   1036      221.8911  0.0012
   1037      311.7576  0.0011
   1038      599.7544  0.0011
   1039      225.6103  0.0011
   1040      460.1644  0.0010
   1041      215.6809  0.0010
   1042      120.1493  0.0012
   1043      164.9901  0.0011
   1044      508.0969  0.0009
   1045      263.8515  0.0012
   1046       40.1944  0.0011
   1047      413.7855  0.0012
   1048       [36m21.7896[0m  0.0011
   1049      134.2711  0.0010
   1050      126.0769  0.0011
   1051      183.9503  0.0011
   1052      139.6402  0.0010
   1053      307.5391  0.0011
   1054      362.5526  0.0011
   1055      110.3634  0.0010
   1056      259.9938  0.0011
   1057      268.5780  0.0010
   1058      322.8316  0.0012
   1059      152.8616  0.0010
   1060      509.4525  0.0010
   1061      143.2292  0.0011
   1062      228.3846  0.0012
   1063       79.0228  0.0011
   1064      152.2517  0.0010
   1065      253.1306  0.0009
   1066      272.1347  0.0010
   1067      385.6739  0.0009
 

#### Make Prediction on Hold out set


In [10]:
y_preds = []
y_true = []
batch_size = 100

for start in range(0, X_test.shape[0], batch_size):

    end = start + batch_size

    X = X_test.iloc[start:end]
    y = y_test.iloc[start:end]

    preds = sail_new_pipeline.predict(X)
    y_preds.extend(list(preds))
    y_true.extend(list(y))

#### Final Score


In [11]:
sail_new_pipeline.get_progressive_score

0.635659857151349

In [13]:
import plotly.express as px

df = pd.DataFrame({"y_true": y_true, "y_preds": y_preds}).head(100)
fig = px.line(df, y=["y_true", "y_preds"], title='')
fig.show()