# Submission Script

Use this to submit the predictions. It produces a submission.parquet file after interacting with the inference server.

Don't forget to bundle the model into a pickle file named "model.pkl".

In [None]:
import os 

import pandas as pd
import polars as pl 
import numpy as np

import kaggle_evaluation.jane_street_inference_server

# Most recent non-None `lags` frame handed to predict(); stays None until one arrives.
lags_ : pl.DataFrame | None = None


import pickle
from sklearn.base import RegressorMixin

# ---------------- 
# BUNDLED UTILS 
# ----------------
def get_feature_cols():
    """Return the feature column names, 'feature_00' through 'feature_78'.

    The numeric suffix is zero-padded to two digits, and the names are
    returned in ascending order as a list of 79 strings.
    """
    feature_count = 79
    column_names = [f'feature_{index:02}' for index in range(feature_count)]
    return column_names

# ----------------
# LOAD THE MODEL 
# ----------------
# Deserialize the estimator from "model.pkl" (path is relative to the working
# directory) once, when this cell runs; predict() reads it via the module-level
# name `model`.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so only
# load a model.pkl that you produced yourself.
with open("model.pkl", "rb") as file: 
    model:RegressorMixin = pickle.load(file)

In [None]:
# ----------------
# PREDICT FUNCTION 
# ----------------

def predict(test: pl.DataFrame,
            lags: pl.DataFrame | None) -> pl.DataFrame | pd.DataFrame:
    """Predict ``responder_6`` for one batch of test rows.

    Args:
        test: Batch to score. It must contain a ``row_id`` column and every
            column named by ``get_feature_cols()``.
        lags: Lag data supplied with the batch, or ``None``. A non-``None``
            value is stored in the module-level ``lags_``; it is not used
            for the prediction itself.

    Returns:
        A polars DataFrame with exactly the columns ``row_id`` and
        ``responder_6`` and one row per row of ``test``.

    Raises:
        AssertionError: If the assembled frame does not have the expected
            columns or row count.
    """
    global lags_
    if lags is not None:
        lags_ = lags

    feature_cols = get_feature_cols()

    # Build the feature matrix. np.nan_to_num replaces NaN with 0.0 (and
    # +/-inf with the largest finite values) before the matrix reaches the model.
    X_test = np.nan_to_num(test[feature_cols].to_numpy())

    # Pass row_id as a Series (get_column), not a one-column DataFrame
    # (select), so the constructor receives an actual column instead of
    # relying on implicit DataFrame-to-Series coercion.
    predictions = pl.DataFrame({
        "row_id": test.get_column("row_id"),
        "responder_6": model.predict(X_test),
    })

    # Sanity-check the submission shape: exact column order, and as many
    # rows as the test data.
    assert predictions.columns == ['row_id', 'responder_6']
    assert len(predictions) == len(test)

    return predictions

In [39]:
# True when the KAGGLE_URL_BASE environment variable is set.
RUNS_ON_KAGGLE = os.getenv('KAGGLE_URL_BASE') is not None

# Wrap predict() in the competition's inference server.
inference_server = kaggle_evaluation.jane_street_inference_server.JSInferenceServer(predict)

if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    # KAGGLE_IS_COMPETITION_RERUN is set (non-empty): serve predictions.
    inference_server.serve()
else:
    # Otherwise run the local gateway, picking the input files by environment:
    # the competition dataset on Kaggle, or files in the working directory.
    if RUNS_ON_KAGGLE:
        gateway_files = (
            '/kaggle/input/jane-street-real-time-market-data-forecasting/test.parquet',
            '/kaggle/input/jane-street-real-time-market-data-forecasting/lags.parquet',
        )
    else:
        gateway_files = (
            'test.parquet',
            'lags.parquet',
        )
    inference_server.run_local_gateway(gateway_files)