In [1]:
# ===================================================================
#  Library
# ===================================================================
import os
import random
import numpy as np
import pandas as pd
import polars as pl
from tqdm.auto import tqdm

from sklearn.metrics import mean_absolute_percentage_error
import warnings
warnings.simplefilter("ignore")

In [2]:
# ===================================================================
#  CFG
# ===================================================================
class CFG:
    """Notebook-wide settings: output name, input/output directories and RNG seed."""
    # Experiment tag; the export cells embed it in the names of the CSVs they write.
    filename = "exp071"
    # Directory that train.csv is read from.
    data_dir = "G:/マイドライブ/signate_StudentCup2023/data/"
    # Directory the per-experiment prediction CSVs are read from and written to.
    save_dir = "G:/マイドライブ/signate_StudentCup2023/exp/"
    # Value handed to seed_everything.
    seed = 42

In [3]:
# ===================================================================
#  Utils
# ===================================================================
def seed_everything(seed):
    """Seed the global RNGs of `random` and NumPy and export PYTHONHASHSEED.

    Args:
        seed: value passed to both seeders and stored, as a string, in the
            PYTHONHASHSEED environment variable.
    """
    # NOTE(review): CPython documents PYTHONHASHSEED as read at interpreter
    # start-up, so this presumably only influences child processes - confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
# Seed the global RNGs once with the configured seed before any other cell runs.
seed_everything(CFG.seed)
    

def get_score(y_true, y_pred):
    """Return the MAPE of `y_pred` against `y_true`, multiplied by 100.

    Args:
        y_true: ground-truth values, forwarded as the first argument of
            mean_absolute_percentage_error.
        y_pred: predicted values, forwarded as the second argument.

    Returns:
        The metric value scaled by 100.
    """
    return mean_absolute_percentage_error(y_true, y_pred) * 100

In [4]:
# ===================================================================
#  Data Loading
# ===================================================================
# OOF side: exp065 out-of-fold predictions ordered by id ...
df = pl.read_csv(CFG.save_dir + "oof_df_exp065.csv").sort("id")
# ... plus the target, of which only id and price are read from train.csv ...
train = pl.read_csv(CFG.data_dir + "train.csv", columns=["id", "price"])
# ... plus the kun_exp00052 out-of-fold columns. Both joins are left joins on id,
# so every row of the exp065 frame is kept.
df = (
    df
    .join(train, on="id", how="left")
    .join(pl.read_csv(CFG.save_dir + "kun_exp00052_oof_pred.csv"), on="id", how="left")
)

# Test side: the matching exp065 / kun_exp00052 test predictions, joined the same way.
test = (
    pl.read_csv(CFG.save_dir + "exp065.csv")
    .sort("id")
    .join(pl.read_csv(CFG.save_dir + "kun_exp00052.csv"), on="id", how="left")
)

df.head()

id,pred_0,pred_1,pred_2,pred_3,pred_4,pred_5,pred_6,pred_7,pred_8,pred_9,pred_10,pred_11,pred_12,pred_13,pred_14,pred_15,pred_16,pred_17,pred_18,pred_19,pred_20,pred_21,pred_22,pred_23,pred_24,pred_25,pred_26,pred_27,pred_28,pred_29,pred_30,pred_31,pred_32,pred_33,pred_34,pred_35,pred_36,pred_37,pred_38,pred_39,price,kun_pred_0,kun_pred_1,kun_pred_2,kun_pred_3,kun_pred_4,kun_pred_5,kun_pred_6,kun_pred_7,kun_pred_8,kun_pred_9
i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
0,6881.892301,6517.414858,7476.567746,7052.61059,6809.002174,6465.089852,7714.185588,7010.85559,7184.306903,6364.997182,7501.188408,7678.577536,8374.386963,8105.340176,8644.322072,7985.056344,8628.261509,6528.148787,7798.699012,8063.395359,8981.008775,8613.396141,10717.468387,9447.344769,7984.506685,8272.276158,5876.6578,6016.32209,7374.013382,8496.804984,7400.051675,7502.731399,9046.083334,9790.2077,7999.621081,9179.605316,7404.460231,10863.206534,7674.566326,10945.332692,27587,7118.9443,7627.1987,7084.0015,7043.095,7122.753,7335.231,7350.006,9836.875,8781.688,8747.975
1,3740.634027,3496.012582,3789.907986,3322.844722,3603.935425,3485.623398,3507.497452,3714.057283,3488.353022,3750.770625,3412.19359,3927.363047,3886.662697,3551.474485,3713.3473,3463.093834,3695.630189,3663.237361,3463.024267,3905.364004,3563.378845,4112.248835,3360.300264,3807.072052,3406.217022,3615.842705,4245.317355,4453.437876,4408.398403,3900.820201,3869.869927,3681.464095,3648.063325,4057.324098,3423.298315,4256.8008,3937.424257,4012.205191,3871.318876,4087.226759,4724,3660.8042,3526.6816,3732.6855,3593.2502,3596.9211,3562.2925,3922.8992,3761.7544,3852.5068,3476.0593
2,2954.247573,2735.110086,3288.647709,2893.84021,3253.350282,2885.690295,3202.10377,2799.365349,2989.931642,2920.912332,3059.180374,3198.706345,2764.92988,3237.239548,2756.230854,3233.761663,2927.997318,3258.873698,3235.904715,3139.762054,3016.016944,3167.730032,2929.953707,2957.126627,3618.763534,3022.052454,3913.586712,3705.560742,2921.102463,3116.785537,2846.925805,3310.14365,3429.863903,2884.781039,3109.509196,3182.319352,2966.910201,3300.768016,2975.34178,2996.449911,10931,2883.739,3149.7769,2995.776,3038.0789,2876.831,2982.5576,2787.5325,2911.18,3172.8408,3092.2769
3,8430.949224,8337.416095,9033.353098,8081.40405,8187.453186,8477.54243,8546.771462,8388.923127,8119.18331,8413.930266,8298.903056,8142.383515,8055.886032,8356.561411,8421.656285,8167.994718,8646.470884,8423.019535,8016.7337,8041.633697,8331.325796,7947.686753,8057.31244,8250.578545,8052.200549,8115.355025,8157.420044,7434.588221,8284.727211,7855.624335,8375.564192,8287.367768,8814.043198,8434.204563,8292.250116,8010.910553,8092.32479,8698.940654,8281.208121,7630.436524,16553,8650.627,8943.903,9243.175,8973.261,9158.718,8505.161,8713.231,8287.674,8649.64,8874.921
4,3972.418866,4254.790314,4089.549517,4272.046434,4446.949951,4022.828501,4374.413436,4369.679027,4150.7095,3941.880817,4227.904345,4420.621374,4676.941221,4534.857719,4604.858395,4588.033874,4023.362549,4455.22004,4574.171104,4520.151023,4408.383964,4200.673767,4535.020561,4420.531149,4421.192901,4254.960815,4031.139467,3972.143434,4027.789862,4720.410259,4455.75972,3900.863059,4312.082422,4518.990036,4405.510685,4745.297877,4280.133207,4148.129445,4364.547169,4530.330211,5158,4031.1462,4406.548,4107.919,4059.3975,4229.2666,4429.893,4209.0796,4026.7249,4228.1797,4164.242


In [5]:
# ===================================================================
#  simple greedy forward selection
# ===================================================================
# Score every prediction column on its own; the single model with the best CV
# becomes the starting point of the blend.
scores = dict()
for col in df.columns:
    if col in ("id", "price"):
        continue
    scores[col] = get_score(y_true=df["price"], y_pred=df[col])

# Candidate blend weights, shared by every step of the search.
weight_grid = np.arange(-0.5, 0.5, 0.001)

BEST_SCORE = np.inf
for seed in range(40, 50):
    # Each restart begins from the same best single model; only the order in
    # which the remaining columns are visited depends on the seed.
    selected_model = min(scores, key=scores.get)
    best_score = scores[selected_model]
    best_preds = df[selected_model]

    orders = [selected_model]
    stores = {selected_model: 1}    # weight kept for each column

    remaining = [col for col in df.columns if col not in ("id", "price", selected_model)]
    filenames = np.random.RandomState(seed).permutation(remaining)

    for exp in filenames:
        candidate = df[exp]
        # Weight 0 leaves the blend untouched; it is what gets stored when no
        # grid value improves on the current best score.
        best_weight = 0
        for w in weight_grid:
            preds = best_preds * (1-w) + candidate * w
            score = get_score(y_true=df["price"], y_pred=preds)
            if score < best_score:
                best_score, best_weight = score, w
        stores[exp] = best_weight
        orders.append(exp)
        # Fold the column in with its chosen weight before moving to the next one.
        best_preds = best_preds * (1-best_weight) + candidate * best_weight
    print(seed, best_score)

    # Keep the weights and visiting order of the best restart so far.
    if best_score < BEST_SCORE:
        BEST_SCORE = best_score
        BEST_STORE = stores.copy()
        BEST_ORDER = orders.copy()

40 43.59516025065753
41 43.59275278157505
42 43.5889753456137
43 43.58531693429042
44 43.582951599388245
45 43.59384664397003
46 43.587720375346635
47 43.58808360633436
48 43.59796299521994
49 43.59271354018121


In [6]:
# ===================================================================
#  Check
# ===================================================================
def get_preds(df: pl.DataFrame):
    """Replay the stored blend on `df`.

    Walks BEST_ORDER and, for each column name, mixes that column of `df` into
    the running blend with the weight recorded in BEST_STORE.

    Args:
        df: frame exposing every column named in BEST_ORDER.

    Returns:
        The blended predictions after the last column has been folded in.
    """
    # The search stores weight 1 for the first entry of its order, so this
    # numeric start value is replaced by that column on the first iteration.
    blended = 0
    for name in BEST_ORDER:
        weight = BEST_STORE[name]
        blended = blended * (1-weight) + df[name] * weight
    return blended
    
# Sanity check: score the blend rebuilt from BEST_ORDER / BEST_STORE on the OOF frame.
get_score(y_true=df["price"], y_pred=get_preds(df))

43.582951599388245

In [7]:
# ===================================================================
#  oof_df
# ===================================================================
# Build the blended OOF frame once and reuse it for both the CSV export and the
# preview, instead of recomputing the blend for each of them.
oof_pred = df.with_columns(
    pl.Series(get_preds(df)).alias("pred")
).select(["id", "pred"])

# The OOF file is written with a header row.
oof_pred.write_csv(CFG.save_dir+f"oof_df_{CFG.filename}.csv", has_header=True)

display(oof_pred.head())

id,pred
i64,f64
0,9196.765168
1,3885.20529
2,3060.125866
3,8007.789593
4,4537.125697


In [8]:
# ===================================================================
#  test
# ===================================================================
# Attach the blended prediction to the test frame as "pred".
test = test.with_columns(pl.Series(get_preds(test)).alias("pred"))

# Only id and pred are exported; this file is written without a header row.
submission = test.select(["id", "pred"])
submission.write_csv(CFG.save_dir+f"{CFG.filename}.csv", has_header=False)
submission.head()

id,pred
i64,f64
27532,9774.45317
27533,5762.007883
27534,5606.688175
27535,18256.398339
27536,4366.343344
