In [1]:
# Add package 
!pip install polars



In [2]:
# === Imports ===
import os
import warnings
import numpy as np
import pandas as pd
import polars as pl

from sklearn import *
import kaggle_evaluation.mitsui_inference_server

# === Ignore warnings for cleaner output ===
warnings.filterwarnings("ignore")

In [3]:
# Competition input directory (ends with "/" so file names can be appended directly)
DATA_PATH = '/kaggle/input/mitsui-commodity-prediction-challenge/'

# Load the three competition CSVs, in this order: train, train_labels, target_pairs
df_train, df_labels, df_targets = (
    pd.read_csv(DATA_PATH + file_name)
    for file_name in ("train.csv", "train_labels.csv", "target_pairs.csv")
)

# Label column names: target_0 ... target_423
target_columns = ["target_" + str(idx) for idx in range(424)]

# Missing labels are encoded as 0 (the local `score` cell maps 0 back to NaN).
# NOTE(review): a genuine target value of exactly 0.0 is then indistinguishable from a missing one.
df_labels[target_columns] = df_labels[target_columns].fillna(0)

In [4]:
def rank_correlation_sharpe_ratio(merged: pd.DataFrame) -> float:
    """Sharpe ratio (mean / population std) of per-row rank correlations.

    ``merged`` holds ``target_<k>`` columns (truth, NaN where missing) and the
    matching ``prediction_<k>`` columns. For every row, only the non-null targets
    are used: their ranks and the ranks of the paired predictions are correlated
    with ``np.corrcoef``. The per-row values are then summarised as
    mean / std(ddof=0).

    Raises:
        ValueError: a row has no non-null targets.
        ZeroDivisionError: a row's targets or predictions are constant, or the
            per-row correlations have exactly zero spread.
    """
    truth_columns = [name for name in merged.columns if name.startswith("target_")]

    def _row_correlation(row):
        # Only targets observed on this row take part in the correlation.
        observed = [name for name in truth_columns if pd.notnull(row[name])]
        paired = [name.replace("target_", "prediction_") for name in observed]

        if len(observed) == 0:
            raise ValueError("No valid targets on this row")

        # Ranking is meaningless without variation on either side.
        # `or` short-circuits: predictions are only inspected if the targets vary.
        if row[observed].std(ddof=0) == 0 or row[paired].std(ddof=0) == 0:
            raise ZeroDivisionError("Standard deviation is zero for ranking")

        pred_ranks = row[paired].rank()
        truth_ranks = row[observed].rank()
        return np.corrcoef(pred_ranks, truth_ranks)[0, 1]

    per_row = merged.apply(_row_correlation, axis=1)
    spread = per_row.std(ddof=0)
    # Only an exact zero is rejected: near-identical correlations (e.g. perfect
    # predictions) leave float noise in `spread` and yield an enormous ratio.
    if spread == 0:
        raise ZeroDivisionError("Cannot compute Sharpe ratio: std is 0")

    return float(per_row.mean() / spread)

In [5]:
def score(
    solution_df: pd.DataFrame,
    prediction_df: pd.DataFrame,
    null_filler: float = 0,
) -> float:
    """Local replica of the competition metric for a block of dates.

    Rows of ``solution_df`` and ``prediction_df`` are paired by position, so both
    frames must have the same number of rows and identical ``target_*`` columns in
    the same order. Their indexes do not need to match.

    Args:
        solution_df: True target values. Cells equal to ``null_filler`` are treated
            as missing (NaN) and excluded from that row's rank correlation.
        prediction_df: Predicted values with the same columns as ``solution_df``.
        null_filler: Sentinel that encodes a missing target in ``solution_df``.
            Defaults to 0, matching the ``fillna(0)`` applied when the labels are
            loaded. Note that a genuine target equal to this value is also dropped.

    Returns:
        The Sharpe ratio computed by ``rank_correlation_sharpe_ratio``.

    Raises:
        AssertionError: the column lists or row counts of the two frames differ.
        ValueError, ZeroDivisionError: propagated from ``rank_correlation_sharpe_ratio``.
    """
    assert list(solution_df.columns) == list(prediction_df.columns), "Column mismatch"
    assert len(solution_df) == len(prediction_df), "Row count mismatch"

    # Distinct column prefixes let truth and predictions live side by side in one frame.
    preds = prediction_df.rename(columns=lambda c: c.replace("target_", "prediction_"))
    truth = solution_df.replace(null_filler, np.nan)

    # pd.concat(axis=1) aligns on the index. Reset both indexes so rows pair by
    # position, instead of producing half-empty rows when the indexes differ.
    merged = pd.concat(
        [truth.reset_index(drop=True), preds.reset_index(drop=True)],
        axis=1,
    )
    return rank_correlation_sharpe_ratio(merged)

# Optional local sanity check: score the last 90 labelled days against themselves.
# With identical predictions every daily rank correlation is ~1, so the Sharpe
# denominator is only float noise and the displayed value is huge (~1e16).
# This exercises the metric plumbing, not model quality.
score(
    solution_df=df_labels[target_columns].tail(90),
    prediction_df=df_labels[target_columns].tail(90),
)

1.1418709159997808e+16

In [6]:
# Cache of every `test` batch passed to `predict` so far, in arrival order
rolling_test_df = pd.DataFrame()

def predict(test, lag1, lag2, lag3, lag4) -> pd.DataFrame:
    """Prediction callback registered with the inference server.

    This is a placeholder model. It returns the train-label row for the current
    ``date_id``. If that date has no row in ``df_labels`` (e.g. a date beyond the
    training labels), it falls back to the labels of the most recent labelled
    date, so the server never receives an empty frame.

    Args:
        test: Batch for the current step. It must provide ``to_pandas()``, and the
            resulting frame must have a ``date_id`` column.
        lag1, lag2, lag3, lag4: Extra batches supplied by the server (presumably
            lagged labels); unused by this placeholder.

    Returns:
        A pandas DataFrame with the ``target_columns`` for the current date.
    """
    global rolling_test_df

    test_df = test.to_pandas()

    # Accumulate every batch seen so far (module-level state shared across calls).
    if rolling_test_df.empty:
        rolling_test_df = test_df.copy()
    else:
        rolling_test_df = pd.concat([rolling_test_df, test_df])

    current_date = rolling_test_df["date_id"].iloc[-1]

    preds = df_labels.loc[df_labels["date_id"] == current_date, target_columns]
    if preds.empty:
        # No label for this date: reuse the latest labelled date rather than
        # returning zero rows.
        latest_row = df_labels["date_id"].idxmax()
        preds = df_labels.loc[[latest_row], target_columns]

    return preds

In [7]:
# Wire the `predict` callback into the competition inference server.
inference_server = kaggle_evaluation.mitsui_inference_server.MitsuiInferenceServer(predict)

if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    # Interactive session: run the local gateway against DATA_PATH, then display
    # the submission file (presumably written by the gateway run).
    inference_server.run_local_gateway((DATA_PATH,))
    display(pl.read_parquet('/kaggle/working/submission.parquet'))
else:
    # Competition rerun: hand control to the server.
    inference_server.serve()

date_id,target_0,target_1,target_2,target_3,target_4,target_5,target_6,target_7,target_8,target_9,target_10,target_11,target_12,target_13,target_14,target_15,target_16,target_17,target_18,target_19,target_20,target_21,target_22,target_23,target_24,target_25,target_26,target_27,target_28,target_29,target_30,target_31,target_32,target_33,target_34,target_35,…,target_387,target_388,target_389,target_390,target_391,target_392,target_393,target_394,target_395,target_396,target_397,target_398,target_399,target_400,target_401,target_402,target_403,target_404,target_405,target_406,target_407,target_408,target_409,target_410,target_411,target_412,target_413,target_414,target_415,target_416,target_417,target_418,target_419,target_420,target_421,target_422,target_423
i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
1827,0.0,0.0,0.017868,-0.000205,-0.016391,-0.013827,0.009972,0.0,0.0,0.007339,0.0,0.000648,-0.000852,-0.008283,-0.002739,0.019704,0.0,0.0,-0.017568,-0.004002,-0.014452,0.0,0.0,0.002912,0.0,0.0,0.002835,0.0,0.0,0.014195,0.0,0.0,0.0,0.0,0.002789,0.0,…,0.0,0.0,0.0,0.0,0.0,0.007543,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.013897,0.0,0.002007,0.0,0.0,0.025312,0.023136,-0.005955,-0.024995,-0.017375,0.026902,0.0,0.0,-0.001379,0.0,0.019701,0.0,-0.02703,0.043602,0.027982,0.0,0.0,0.002177,0.0
1828,0.00256,-0.004592,-0.001776,0.000271,-0.016696,-0.020025,0.002514,0.002204,-0.011962,0.008167,0.017102,0.005028,0.019163,-0.00197,0.010479,0.009287,-0.000343,-0.041377,-0.004934,0.007855,-0.006208,-0.004902,-0.041703,0.00895,0.024104,-0.002321,0.011276,-0.028035,-0.006816,0.001038,-0.055063,-0.009939,0.025668,0.004457,-0.009861,-0.013057,…,-0.026099,-0.034523,0.035679,0.121072,-0.008056,0.021389,-0.029272,-0.029545,0.00028,0.07288,-0.107708,-0.017056,-0.003319,0.017535,-0.037499,-0.012534,0.004439,0.00886,0.003493,0.008265,-0.016261,-0.000541,0.011231,0.020304,-0.000436,0.006479,-0.006886,0.012063,0.012081,-0.020068,0.002858,0.019154,0.019018,0.003875,-0.035202,0.011246,0.099241
1829,0.005346,-0.014539,0.019542,0.014626,-0.011631,-0.009223,-0.005199,-0.026092,-0.003865,0.00895,-0.008017,-0.001784,-0.011473,-0.010397,-0.009773,0.000909,0.02615,-0.007636,-0.003865,0.009393,-0.021462,0.006038,0.006117,0.000398,0.029943,-0.010855,-0.002802,0.005267,0.01606,0.004741,-0.028163,-0.017205,-0.003412,0.015509,0.010165,-0.003796,…,-0.011534,-0.05373,0.039137,0.066989,-0.005044,0.025555,0.007079,-0.012595,-0.010889,0.031488,-0.088161,0.001681,-0.022259,0.019903,-0.056522,0.001215,0.020855,0.000878,0.014683,0.013527,-0.000723,-0.01233,0.003398,0.025563,0.002089,0.03439,0.013036,-0.009004,0.016166,-0.028919,-0.007297,0.033262,0.023174,-0.028512,-0.0179,-0.002096,0.121451
1830,0.000082,-0.005226,0.011452,0.013346,0.008228,-0.014819,-0.011792,-0.007148,0.005712,0.009382,0.006053,-0.001469,0.034369,-0.008849,0.011732,-0.005689,0.006415,-0.001207,0.001055,-0.000658,-0.009977,0.00162,-0.010534,0.005017,-0.022753,0.001035,-0.000097,0.003259,0.004387,-0.031017,0.008543,-0.012897,0.011389,-0.005703,0.021406,0.006979,…,-0.030488,-0.02537,0.022239,0.066084,0.031417,0.017546,0.006995,-0.026773,-0.026172,0.042424,-0.09226,-0.002049,-0.019993,0.010716,-0.02674,-0.012214,0.026445,0.008532,-0.000179,-0.007162,-0.013116,0.001457,0.015902,0.002427,0.001141,0.000699,0.01658,-0.013857,-0.007742,-0.018436,0.004691,0.013311,0.000589,-0.0145,-0.046444,0.009058,0.109246
1831,-0.011469,0.016613,-0.023765,-0.018744,-0.011878,0.007257,0.019829,0.006618,-0.015351,-0.020452,-0.007495,0.022225,-0.024215,0.026028,0.002966,0.009503,-0.00203,-0.015345,0.004251,0.004336,0.014184,0.00157,-0.041702,0.0101,-0.115978,-0.012884,-0.006369,-0.043873,0.006041,0.037664,0.16847,0.014518,0.034673,0.021639,-0.027945,-0.013476,…,-0.007888,-0.020061,0.028008,0.043235,0.077982,0.009942,0.056984,-0.021971,-0.044638,0.040099,-0.058358,0.007707,-0.02602,-0.007722,-0.021644,-0.027072,0.052724,0.015722,-0.003946,-0.010448,-0.016842,-0.000208,0.019106,-0.004725,-0.011568,0.000674,0.007404,-0.015625,-0.01885,-0.025373,0.031197,0.005873,-0.00565,-0.022926,-0.02799,0.011267,0.091318
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1912,-0.012596,0.009309,0.007947,0.006477,-0.026216,-0.029487,-0.003287,0.011751,-0.02761,0.000665,0.026752,0.000976,0.02158,0.001321,0.025961,-0.001953,-0.011606,-0.055654,0.002567,0.030193,-0.005809,-0.012533,-0.041661,-0.003858,-0.015186,0.00545,-0.008017,-0.016237,-0.006058,0.003453,-0.002078,-0.013829,0.003731,0.014301,-0.025108,-0.028728,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.041135,0.0,0.0,0.0,0.0,-0.049973,0.0,0.0,0.0,0.0,0.0,-0.026769,0.0,0.0,0.0,0.042052,0.0,0.0,0.0,0.0,0.031006,0.0
1913,-0.000736,-0.001292,-0.005137,-0.007961,-0.008494,-0.011547,0.004263,0.00567,-0.004822,-0.000875,0.011004,0.002718,0.008976,-0.0039,0.018571,0.003,-0.005799,-0.002896,-0.00415,0.001491,0.001282,0.009566,-0.014855,0.004179,-0.002997,-0.005198,0.002046,-0.007307,-0.003663,0.00653,-0.010523,0.010571,0.009859,0.016003,-0.017071,-0.004511,…,-0.02153,-0.01163,0.002408,-0.038982,-0.02049,-0.003409,0.024459,0.010653,0.015806,0.006548,0.019286,0.012487,-0.011835,-0.001247,-0.010599,0.000446,-0.009257,-0.235108,-0.004405,-0.003146,-0.000069,0.001729,0.01176,-0.004931,-0.018399,0.024919,0.00675,-0.012992,-0.003349,0.010912,0.006198,-0.009216,-0.003038,-0.026082,-0.008057,-0.002069,-0.141053
1914,-0.002294,0.012898,0.009978,0.001567,0.002596,-0.007373,0.007554,0.002661,0.004083,-0.00622,0.006817,-0.017027,0.009083,0.009517,0.002916,-0.002976,-0.00231,0.00038,-0.007149,0.009475,-0.002829,-0.01084,-0.004764,-0.011159,0.100948,-0.00041,-0.010294,-0.004312,0.007715,0.005817,-0.104087,-0.020188,-0.032932,-0.010976,-0.004316,0.00374,…,-0.01604,-0.012072,-0.006242,-0.049473,-0.007303,-0.00595,0.042362,0.005606,0.008433,0.014709,0.0153,0.008254,-0.001277,0.004271,-0.0081,0.010568,-0.021193,-0.218786,-0.002684,-0.001232,0.014846,-0.002419,0.004828,-0.009354,-0.001657,0.003526,-0.008493,-0.00524,0.004044,0.004459,0.002619,0.001308,-0.006772,-0.019918,-0.013304,-0.005527,-0.127688
1915,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.006171,0.0,-0.008765,0.0,0.0,0.0,0.0,-0.001627,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.004378,0.0,…,-0.001893,-0.018356,0.011577,-0.024187,-0.029459,0.006396,0.016987,0.027237,0.000558,0.000394,0.006531,0.015775,-0.021438,0.011393,-0.016914,0.002455,-0.014087,-0.025906,0.007553,0.015503,0.001778,-0.00889,-0.019823,0.005968,0.000596,0.030605,0.000639,-0.00825,0.012842,0.009076,0.000932,0.011613,0.003825,0.02435,-0.006928,0.006805,-0.012187
