# Inference on an upcoming dataset

In this part, we simulate a real deployment of the package and run inference on an upcoming dataset.

## Training models

As in the first example, we initialize a `Trainer`, add model bases, and train them.

In [1]:
import torch
from tabensemb.trainer import Trainer
from tabensemb.model import *
import tabensemb
import os

prefix = "../../../../"
tabensemb.setting["default_output_path"] = prefix + "output"
tabensemb.setting["default_config_path"] = prefix + "configs"
tabensemb.setting["default_data_path"] = prefix + "data"

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

trainer = Trainer(device=device)
trainer.load_config("sample")

from tabensemb.utils import Logging
log = Logging()
log.enter(os.path.join(trainer.project_root, "log.txt"))

trainer.load_data()

models = [
    PytorchTabular(trainer, model_subset=["Category Embedding"]),
    WideDeep(trainer, model_subset=["TabMlp"]),
    AutoGluon(trainer, model_subset=["Linear Regression"]),
]
trainer.add_modelbases(models)

trainer.train(verbose=False, stderr_to_stdout=True)
trainer.get_leaderboard()

Using cuda device
Project will be saved to ../../../../output/sample/2023-07-29-17-20-51-0_sample
Dataset size: 153 51 52
Data saved to ../../../../output/sample/2023-07-29-17-20-51-0_sample (data.csv and tabular_data.csv).
  rank_zero_deprecation(
Trainer saved. To load the trainer, run trainer = load_trainer(path='../../../../output/sample/2023-07-29-17-20-51-0_sample/trainer.pkl')
Trainer saved. To load the trainer, run trainer = load_trainer(path='../../../../output/sample/2023-07-29-17-20-51-0_sample/trainer.pkl')
Trainer saved. To load the trainer, run trainer = load_trainer(path='../../../../output/sample/2023-07-29-17-20-51-0_sample/trainer.pkl')
PytorchTabular metrics
Category Embedding 1/1
WideDeep metrics
TabMlp 1/1
AutoGluon metrics
Linear Regression 1/1
Trainer saved. To load the trainer, run trainer = load_trainer(path='../../../../output/sample/2023-07-29-17-20-51-0_sample/trainer.pkl')


| | Program | Model | Training RMSE | Training MSE | Training MAE | Training MAPE | Training R2 | Training RMSE_CONSERV | Testing RMSE | Testing MSE | Testing MAE | Testing MAPE | Testing R2 | Testing RMSE_CONSERV | Validation RMSE | Validation MSE | Validation MAE | Validation MAPE | Validation R2 | Validation RMSE_CONSERV |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | WideDeep | TabMlp | 117.866783 | 13892.578569 | 95.954615 | 1.767595 | 0.578265 | 12843.83249 | 136.008612 | 18498.342515 | 108.973634 | 3.452354 | 0.375839 | 11574.74935 | 103.435861 | 10698.977274 | 84.573629 | 1.402757 | 0.51681 | 12147.943417 |
| 1 | AutoGluon | Linear Regression | 114.065981 | 13011.04809 | 91.398514 | 2.686924 | 0.605025 | 12364.215702 | 139.269733 | 19396.058516 | 119.072766 | 4.078846 | 0.345548 | 11994.905029 | 110.253538 | 12155.842646 | 88.607594 | 1.54647 | 0.451015 | 12189.700803 |
| 2 | PytorchTabular | Category Embedding | 103.473606 | 10706.787126 | 85.594475 | 2.112976 | 0.674976 | 10549.942827 | 143.043941 | 20461.568975 | 121.271502 | 4.315138 | 0.309596 | 13106.784041 | 106.790754 | 11404.265107 | 83.61624 | 1.295624 | 0.484958 | 13734.326173 |


*Optional*: With the following line, we can run multiple random trials with different random seeds and average the metrics to evaluate the models.

```python
trainer.get_leaderboard(cross_validation=2, split_type="random", stderr_to_stdout=True)
```

**Remark**: `split_type` can also be `cv`, which performs k-fold cross-validation with k equal to `cross_validation`. With `split_type="random"`, the dataset is randomly re-split for each trial according to the `split_ratio` given in the configuration, using a different random seed each time.
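
To make the difference between the two split types concrete, the sketch below contrasts k-fold splitting with repeated random splitting using scikit-learn directly. It is an illustration only and is not how `tabensemb` implements its splits: the `KFold` and `ShuffleSplit` classes, the 20% held-out ratio, and all variable names are assumptions chosen for the demonstration.

```python
import numpy as np
from sklearn.model_selection import KFold, ShuffleSplit

n_samples = 256  # 153 + 51 + 52, the dataset sizes printed above
indices = np.arange(n_samples)

# k-fold: the held-out folds are disjoint and together cover every sample.
kfold = KFold(n_splits=2, shuffle=True, random_state=0)
kfold_test_sets = [test for _, test in kfold.split(indices)]

# Repeated random split: each trial draws an independent held-out set of a
# fixed ratio (assumed 20% here), so sets from different trials may overlap.
random_split = ShuffleSplit(n_splits=2, test_size=0.2, random_state=0)
random_test_sets = [test for _, test in random_split.split(indices)]

kfold_covered = np.unique(np.concatenate(kfold_test_sets))
print("samples held out by k-fold:", len(kfold_covered))
print("held-out sizes per random trial:", [len(t) for t in random_test_sets])
```

In both cases the metrics of the individual trials are averaged; the practical difference is whether each sample is guaranteed to be held out exactly once.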

## Selecting and storing a model

From the leaderboard, we can compare the models and select one for deployment. Suppose we choose `TabMlp` from `WideDeep` (`pytorch_widedeep`). We detach the model from the heavy `trainer`; the detached trainer is also stored locally in a separate directory.

In [2]:
trainer_of_one_model = trainer.detach_model(program="WideDeep", model_name="TabMlp")

Trainer saved. To load the trainer, run trainer = load_trainer(path='../../../../output/sample/2023-07-29-17-20-51-0_sample-I1/trainer.pkl')


The detached trainer now has only one model base.

In [3]:
# Model bases of the detached trainer
trainer_of_one_model.modelbases

[<tabensemb.model.widedeep.WideDeep at 0x7f63607f9880>]

In [4]:
# The model in the model base
trainer_of_one_model.get_modelbase("WideDeep_TabMlp").model["TabMlp"]

<pytorch_widedeep.training.trainer.Trainer at 0x7f627690f130>

## Loading the model

Now the `Trainer` containing a single model is stored in a separate directory. Suppose we want to load this trainer in a separate script for inference. In the following cell, the `path` argument of `load_trainer` is the path to `trainer.pkl`, which is printed when detaching the model or training the model bases. Here we simply use the directory of the detached trainer `trainer_of_one_model`.

**Remark**: You can move the directory to any other location (or to another device, provided the package version and the environment are consistent) and rename the folder. `tabensemb` automatically configures the path.

In [5]:
from tabensemb.trainer import load_trainer

trainer = load_trainer(path=os.path.join(trainer_of_one_model.project_root, "trainer.pkl"))

In [6]:
trainer.get_modelbase("WideDeep_TabMlp").model["TabMlp"]

<pytorch_widedeep.training.trainer.Trainer at 0x7f63608885e0>

## Inference

Assume that we have a new `DataFrame` representing an upcoming dataset. For demonstration, we use the testing set here.

In [7]:
df = trainer.df.loc[trainer.test_indices, :]
truth = trainer.df.loc[trainer.test_indices, trainer.label_name].values.flatten()
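
In a real deployment the upcoming `DataFrame` comes from elsewhere, so it can be worth checking that it carries the feature columns seen during training before calling the model. The helper below is a hypothetical sketch in plain pandas, not part of `tabensemb`; the function name `missing_columns` and the toy column names are made up for illustration.

```python
import pandas as pd


def missing_columns(train_df: pd.DataFrame, new_df: pd.DataFrame, label_name: list) -> list:
    """Return the non-label columns of train_df that are absent from new_df."""
    feature_cols = [c for c in train_df.columns if c not in label_name]
    return [c for c in feature_cols if c not in new_df.columns]


# Toy frames with made-up column names, standing in for real data.
train_df = pd.DataFrame({"x1": [0.1, 0.2], "x2": [1.0, 2.0], "target": [3.0, 4.0]})
upcoming_df = pd.DataFrame({"x1": [0.3], "x3": [5.0]})

print(missing_columns(train_df, upcoming_df, label_name=["target"]))  # ['x2']
```

With the loaded trainer, a call such as `missing_columns(trainer.df, df, trainer.label_name)` would play the same role, assuming `trainer.label_name` is a list of label column names, as its use in the cell above suggests.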

Use the model base to run inference. The RMSE on the "new" (testing) dataset matches the Testing RMSE of `TabMlp` in the leaderboard above (136.008612).

In [8]:
from tabensemb.utils import metric_sklearn

result = trainer.get_modelbase("WideDeep_TabMlp").predict(df, model_name="TabMlp")
metric_sklearn(truth, result, "rmse")

136.00861191638165
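
For reference, the sketch below computes an RMSE by hand with NumPy. It assumes that `metric_sklearn(..., "rmse")` is the standard root mean squared error (its implementation is not shown here), and the `rmse` helper and `*_demo` arrays are hypothetical. It also highlights a shape pitfall: `truth` is one-dimensional while `result` has shape `(n, 1)`, so subtracting them directly would broadcast to an `(n, n)` matrix instead of an element-wise difference.

```python
import numpy as np


def rmse(y_true, y_pred) -> float:
    """Root mean squared error after flattening both inputs to 1-D."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


truth_demo = np.array([1.0, 2.0, 3.0])        # shape (3,), like `truth`
pred_demo = np.array([[1.0], [2.0], [5.0]])   # shape (3, 1), like `result`

# Without flattening, (3,) minus (3, 1) broadcasts to a (3, 3) matrix.
print((truth_demo - pred_demo).shape)

# Flattening inside rmse() gives the intended element-wise comparison.
print(rmse(truth_demo, pred_demo))
```

Flattening both arrays first, as the helper does, avoids silently averaging over all pairwise differences.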

In [9]:
result

array([[ -74.43971  ],
       [ -81.44765  ],
       [ -31.800875 ],
       [ -75.2285   ],
       [ -53.677227 ],
       [  70.04292  ],
       [ -25.622992 ],
       [  36.68125  ],
       [-211.98978  ],
       [ 159.79549  ],
       [-102.89287  ],
       [-107.03044  ],
       [ -23.381763 ],
       [  26.2937   ],
       [ -56.9037   ],
       [  42.62137  ],
       [ -79.51051  ],
       [ -80.04006  ],
       [ 168.29333  ],
       [  10.588863 ],
       [-124.99167  ],
       [-175.97296  ],
       [ -52.490498 ],
       [ 149.38387  ],
       [ 153.97131  ],
       [  88.51623  ],
       [  16.419373 ],
       [-181.93864  ],
       [-143.57932  ],
       [-108.806404 ],
       [   2.2212856],
       [ 118.65542  ],
       [-147.20416  ],
       [-209.2984   ],
       [ 158.00734  ],
       [-126.64957  ],
       [-173.66629  ],
       [  -6.8193183],
       [  38.248684 ],
       [-183.28394  ],
       [  70.60948  ],
       [ 115.3542   ],
       [-161.44592  ],
       [ 13