In [1]:
#General imports
from typing import Any
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import copy
import os
from collections import defaultdict 



# Lightning imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.optim import Adam, AdamW
from torch.utils.data import Dataset, DataLoader

#Lightning imports
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from lightning.pytorch.loggers import TensorBoardLogger
from torchmetrics import MetricCollection
from torchmetrics.regression import R2Score, MeanSquaredError, MeanAbsoluteError
from torchmetrics.classification import BinaryF1Score, BinaryAveragePrecision, BinaryAUROC, BinaryAccuracy
from pytorch_lightning.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping


#sklearn imports
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.metrics import roc_auc_score, precision_score, f1_score, accuracy_score, log_loss
from sklearn.model_selection import train_test_split, cross_val_score

#RF plus imports
import imodels
from imodels.tree.rf_plus.rf_plus.rf_plus_models import _RandomForestPlus, RandomForestPlusRegressor, RandomForestPlusClassifier
from imodels.tree.rf_plus.rf_plus_prediction_models.aloocv import AloGLM
from imodels.tree.rf_plus.rf_plus_prediction_models.aloocv_regression import AloElasticNetRegressorCV, AloLOL2Regressor
from imodels.tree.rf_plus.rf_plus_prediction_models.aloocv_classification import AloGLMClassifier, AloLogisticElasticNetClassifierCV, AloSVCRidgeClassifier
from imodels.tree.rf_plus.rf_plus.MOE.moe_utils import TabularDataset, TreePlusExpert, AloTreePlusExpert, GatingNetwork

#Testing imports
import openml
import time

#
from xgboost import XGBRegressor


In [2]:
from imodels.tree.rf_plus.rf_plus.MOE.rfplus_MOE import RandomForestPlusMOE

In [None]:
#Load Data 
suite_id = 353
benchmark_suite = openml.study.get_suite(suite_id)
task_ids = benchmark_suite.tasks
task_id =  361235
random_state = 0
task = "regression"
seed_everything(random_state, workers=True)
print(f"Task ID: {task_id}")
task = openml.tasks.get_task(task_id)
dataset_id = task.dataset_id
dataset = openml.datasets.get_dataset(dataset_id)

# Split data into train, validation, and test sets
max_train = 500
X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute,dataset_format="array")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
X_train, y_train = copy.deepcopy(X_train)[:max_train], copy.deepcopy(y_train)[:max_train]
X_train_torch, X_val_torch, y_train_torch, y_val_torch = train_test_split(copy.deepcopy(X_train),copy.deepcopy(y_train), test_size=0.25)


#Get datasets and dataloaders
train_dataset = TabularDataset(torch.tensor(X_train_torch), torch.tensor(y_train_torch))
train_dataloader = DataLoader(train_dataset, batch_size=X_train_torch.shape[0])

val_dataset = TabularDataset(torch.tensor(X_val_torch), torch.tensor(y_val_torch))
val_dataloader = DataLoader(val_dataset, batch_size=X_val_torch.shape[0])

test_dataset = TabularDataset(torch.tensor(X_test), torch.tensor(y_test))
test_dataloader = DataLoader(test_dataset, batch_size=X_test.shape[0])

#fit RF plus model
if task == "classification":
    n_estimators = 256
    min_samples_leaf = 2
    max_epochs = 50
    max_features = "sqrt"
else:
    n_estimators = 256
    min_samples_leaf = 5
    max_epochs = 50
    max_features = 0.33

# rf_model = RandomForestClassifier(n_estimators=n_estimators, min_samples_leaf=min_samples_leaf, max_features=max_features,random_state=random_state)
# rf_model.fit(X_train, y_train)
# rfplus_model = RandomForestPlusClassifier(rf_model = rf_model,fit_on = "all")
# rfplus_model.fit(X_train,y_train,n_jobs=-1)

rf_model = RandomForestRegressor(n_estimators=n_estimators, min_samples_leaf=min_samples_leaf, max_features=max_features,random_state=random_state)
rf_model.fit(X_train, y_train)
rfplus_model = RandomForestPlusRegressor(rf_model = rf_model,fit_on = "all")
rfplus_model.fit(X_train,y_train,n_jobs=-1)

xgb_model = XGBRegressor(n_estimators=n_estimators, min_samples_leaf=min_samples_leaf, max_features=max_features,random_state=random_state)
xgb_model.fit(X_train, y_train)

# # RFplus_MOEmodel = RandomForestPlusRegressor(rf_model=rf_model,fit_on = "all")  
# # RFplus_MOEmodel.fit(X_train,y_train,n_jobs=-1)

#Define the ModelCheckpoint callback
checkpoint_callback = ModelCheckpoint(dirpath='checkpoints',filename='best_model',monitor='val_loss',mode='min',save_top_k=1,save_last=True,verbose=True)
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=5, verbose=False, mode="min")


RFplus_MOE = RandomForestPlusMOE(rfplus_model=rfplus_model, input_dim=X.shape[1], criterion= nn.MSELoss(), use_loo = False, train_experts=  True) #BinaryF1ScoreBinaryF1Score
logger = TensorBoardLogger(f'RFMOE_task_{task_id}', name='RFMOE')
trainer = Trainer(accelerator="cpu",max_epochs=max_epochs,callbacks=[checkpoint_callback],logger=logger)
trainer.fit(RFplus_MOE, train_dataloader, val_dataloader)
test = trainer.test(dataloaders=test_dataloader)


if rfplus_model._task == "classification":
    class_metrics = [accuracy_score, f1_score, precision_score]
    prob_metrics = [log_loss,roc_auc_score]
    for m in class_metrics:
        print(m.__name__)
        print("RF model: ",m(y_test,rf_model.predict(X_test)))
        print("XGB model: ",m(y_test,xgb_model.predict(X_test)))
        print("RF+ Model without MOE: ",m(y_test,rfplus_model.predict(X_test)))
        print("\n")
    for m in prob_metrics:
        print(m.__name__)
        print("RF model: ",m(y_test,rf_model.predict_proba(X_test)[:,1]))
        print("XGB model: ",m(y_test,xgb_model.predict_proba(X_test)[:,1]))
        print("RF+ Model without MOE: ",m(y_test,rfplus_model.predict_proba(X_test)[:,1]))
        print("\n")
else:
    metrics = [mean_absolute_error,mean_squared_error, r2_score]
    for m in metrics:
        print(m.__name__)
        print("RF model: ",m(y_test,rf_model.predict(X_test)))
        print("XGB model: ",m(y_test,xgb_model.predict(X_test)))
        print("RF+ Model without MOE: ",m(y_test,rfplus_model.predict(X_test)))
        print("\n")





Seed set to 0
  exec(code_obj, self.user_global_ns, self.user_ns)
  dataset = get_dataset(task.dataset_id, *dataset_args, **get_dataset_kwargs)
  dataset = openml.datasets.get_dataset(dataset_id)
  X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute,dataset_format="array")


Task ID: 361235


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:   10.3s
[Parallel(n_jobs=-1)]: Done 184 tasks      | elapsed:   26.1s
[Parallel(n_jobs=-1)]: Done 256 out of 256 | elapsed:   33.5s finished
Parameters: { "max_features", "min_samples_leaf" } are not used.

/scratch/users/zachrewolinski/conda/envs/mdi/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/users/zachrewolinski/conda/envs/mdi/lib/pyt ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
2025-03-04 12:59:12.436186: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has alre

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/scratch/users/zachrewolinski/conda/envs/mdi/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


x.shape: torch.Size([125, 5])


/scratch/users/zachrewolinski/conda/envs/mdi/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/scratch/users/zachrewolinski/conda/envs/mdi/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

  batch_torch_indices = torch.tensor(index) #training indices of elements in batch


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 0, global step 1: 'val_loss' reached 11.04171 (best 11.04171), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 1, global step 2: 'val_loss' reached 10.33861 (best 10.33861), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 2, global step 3: 'val_loss' reached 9.72729 (best 9.72729), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 3, global step 4: 'val_loss' reached 9.28729 (best 9.28729), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 4, global step 5: 'val_loss' reached 8.69362 (best 8.69362), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 5, global step 6: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 6, global step 7: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 7, global step 8: 'val_loss' reached 8.47607 (best 8.47607), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 8, global step 9: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 9, global step 10: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 10, global step 11: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 11, global step 12: 'val_loss' reached 7.73541 (best 7.73541), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 12, global step 13: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 13, global step 14: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 14, global step 15: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 15, global step 16: 'val_loss' reached 7.68358 (best 7.68358), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 16, global step 17: 'val_loss' reached 7.46242 (best 7.46242), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 17, global step 18: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 18, global step 19: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 19, global step 20: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 20, global step 21: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 21, global step 22: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 22, global step 23: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 23, global step 24: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 24, global step 25: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 25, global step 26: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 26, global step 27: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 27, global step 28: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 28, global step 29: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 29, global step 30: 'val_loss' reached 7.32401 (best 7.32401), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 30, global step 31: 'val_loss' reached 7.15650 (best 7.15650), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 31, global step 32: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 32, global step 33: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 33, global step 34: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 34, global step 35: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 35, global step 36: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 36, global step 37: 'val_loss' reached 7.13914 (best 7.13914), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 37, global step 38: 'val_loss' reached 7.06656 (best 7.06656), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 38, global step 39: 'val_loss' reached 6.98581 (best 6.98581), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 39, global step 40: 'val_loss' reached 6.88977 (best 6.88977), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 40, global step 41: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 41, global step 42: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 42, global step 43: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 43, global step 44: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 44, global step 45: 'val_loss' reached 6.67944 (best 6.67944), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 45, global step 46: 'val_loss' reached 6.66504 (best 6.66504), saving model to '/accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt' as top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 46, global step 47: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 47, global step 48: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 48, global step 49: 'val_loss' was not in top 1


x.shape: torch.Size([375, 5])


Validation: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([125, 5])


Epoch 49, global step 50: 'val_loss' was not in top 1
`Trainer.fit` stopped: `max_epochs=50` reached.
Restoring states from the checkpoint path at /accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt
Loaded model weights from the checkpoint at /accounts/grad/zachrewolinski/research/imodels-experiments/feature_importance/moe/checkpoints/best_model.ckpt
/scratch/users/zachrewolinski/conda/envs/mdi/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

x.shape: torch.Size([301, 5])


mean_absolute_error
RF model:  3.3847517866559436
XGB model:  1.6686230276114125
RF+ Model without MOE:  2.785857318691402


mean_squared_error
RF model:  17.687629293494645
XGB model:  5.630796183102364
RF+ Model without MOE:  12.230324268182796


r2_score
RF model:  0.6239795370723867
XGB model:  0.8802951739722457
RF+ Model without MOE:  0.7399961228965639




In [None]:
# get the gating scores for each data point
gating_scores = RFplus_MOE.get_gating_scores(X_test)
