In [1]:
import glob, os

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from itertools import chain

from catboost import CatBoostRanker, CatBoostClassifier

In [2]:
pd.options.display.max_rows = 300
pd.options.display.max_columns = 300

In [3]:
BASE_DIR = "/Users/artemvopilov/Programming/yandex_cup_2023"

In [4]:
DATA_DIR = f"{BASE_DIR}/data"

TRAIN_DF_PATH = f"{DATA_DIR}/train.csv"
TEST_DF_PATH = f"{DATA_DIR}/test.csv"

In [10]:
preds_dfs = {}
for fp in os.listdir(f'{BASE_DIR}/notebooks'):
    if not os.path.isfile(fp):
        continue
    if not (fp.startswith('prediction') and fp.endswith('.csv') and 'final' not in fp):
        continue

    print(f'reading from {fp}')
    pred_name = fp.split('.')[0].split('prediction_')[1]
    preds = pd.read_csv(fp)
    preds['prediction'] = preds['prediction'].apply(lambda x: list(map(float, x.split(','))))

    preds_dfs[pred_name] = preds

reading from prediction_vae_als.csv
reading from prediction_knn_first_e_vae.csv
reading from prediction_lstm_normed_2.csv
reading from prediction_knn_normed.csv
reading from prediction_lstm_pca.csv
reading from prediction_lstm_normed.csv
reading from prediction_knn_pca.csv
reading from prediction_pca_als.csv
reading from prediction_pca_als_2.csv
reading from prediction_normed_lstm_dssm_dot.csv
reading from prediction_knn_vae.csv
reading from prediction_lstm_vae_2.csv
reading from prediction_vae_dssm_cos.csv
reading from prediction_knn_vae_last.csv
reading from prediction_knn_normed_lstm.csv
reading from prediction_normder_lstm_als.csv
reading from prediction_normed_lstm_dssm_cos.csv
reading from prediction_lstm_vae.csv
reading from prediction_vae_dssm_dot.csv


In [11]:
train_df = pd.read_csv(TRAIN_DF_PATH)
test_df = pd.read_csv(TEST_DF_PATH)

In [12]:
test_tracks_sorted = test_df['track'].values
test_tracks_sorted

array([17730, 32460, 11288, ...,  5257, 73095, 67472])

In [19]:
for f, preds_df in tqdm(preds_dfs.items()):
    print(f)

    t_to_preds = {}

    for t in tqdm(test_tracks_sorted):
        t_preds = np.exp(preds_df[preds_df['track'] == t]['prediction'].values[0])
        t_preds_sum = np.sum(t_preds)
        t_preds = t_preds / t_preds_sum
        t_to_preds[t] = ','.join(map(str, t_preds))
    
    predictions_df = pd.DataFrame([
        {'track': t, 'prediction': preds}
        for t, preds in t_to_preds.items()
    ])

    print(predictions_df.head())
    predictions_df.to_csv(f'prediction_final_fff222_{f}.csv', index=False)

  0%|          | 0/19 [00:00<?, ?it/s]

vae_als


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.0047307643818296635,0.004608197774893473,0.0...
1  32460  0.004232246555907819,0.005207444220727864,0.00...
2  11288  0.0046256180449810915,0.004619240064607455,0.0...
3  18523  0.005109836324573963,0.004476020244263724,0.00...
4  71342  0.0042640909139698355,0.0035570765473121684,0....
knn_first_e_vae


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.006658522389290557,0.006992375173325447,0.00...
1  32460  0.0046950107797955115,0.006326718275954842,0.0...
2  11288  0.0049347226072303995,0.005252516996070001,0.0...
3  18523  0.006343597239843235,0.0060402822087835635,0.0...
4  71342  0.005786774257456198,0.004155422733027662,0.00...
lstm_normed_2


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.0023040909502699994,0.005149702865532742,0.0...
1  32460  0.03342862016186549,0.09463480482914365,0.0919...
2  11288  0.02987211661242744,0.04921805809105226,0.0265...
3  18523  0.13275510369466853,0.029370835884543145,0.140...
4  71342  0.021021876196493924,0.0023080899456725834,0.0...
knn_normed


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.004714221900310374,0.005417092671431759,0.00...
1  32460  0.00440591008509585,0.004541052938336113,0.005...
2  11288  0.004244560850268295,0.006402947164765259,0.00...
3  18523  0.0057492570507301,0.004002288100627045,0.0052...
4  71342  0.003853362469239949,0.004066219880131073,0.00...
lstm_pca


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.11474643910028143,0.05697756060073047,0.0412...
1  32460  0.05005645262235391,0.17524512725203695,0.0181...
2  11288  0.11474645914522437,0.05697755915857169,0.0412...
3  18523  0.11474643935633659,0.056977560727875146,0.041...
4  71342  0.11474645802790782,0.05697755860376616,0.0412...
lstm_normed


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.05597115662344964,0.07442184376583777,0.0244...
1  32460  0.10547294380506377,0.13393717720667578,0.0454...
2  11288  0.040449312302402174,0.1690238183852639,0.0751...
3  18523  0.19228998074304907,0.05445426358221426,0.0561...
4  71342  0.1086142397601529,0.015921275874638156,0.0197...
knn_pca


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.00508200623540262,0.005319405469765026,0.004...
1  32460  0.005373990831422904,0.005423124762010974,0.00...
2  11288  0.004274709494658365,0.005996919102759026,0.00...
3  18523  0.005161915340788948,0.004665278930695206,0.00...
4  71342  0.004305160151707233,0.0040930823686539785,0.0...
pca_als


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.0038972099016021224,0.003532816918978879,0.0...
1  32460  0.003504065490326404,0.004017039927909007,0.00...
2  11288  0.0033549339594479066,0.00358714795769601,0.00...
3  18523  0.004175148916396576,0.003472924714868067,0.00...
4  71342  0.00336677284735894,0.0031380585844663746,0.00...
pca_als_2


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.003899538902029538,0.0035337201278227855,0.0...
1  32460  0.003506687553184447,0.004013944492708027,0.00...
2  11288  0.003359727485322296,0.0035889566019364555,0.0...
3  18523  0.004172665810902078,0.0034711555315893074,0.0...
4  71342  0.0033677758554176204,0.0031388589210157867,0....
normed_lstm_dssm_dot


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.00017013720404032942,0.00045601148432678667,...
1  32460  0.0003283597641103241,0.0007640459635659708,7....
2  11288  7.741924384529804e-05,0.00024328433889269244,4...
3  18523  6.649949607153463e-05,0.0002152574965246744,3....
4  71342  2.3466858667016104e-05,9.256893926647659e-05,5...
knn_vae


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.005892314134961234,0.009219337461630067,0.00...
1  32460  0.0045514653640295125,0.004629833844266934,0.0...
2  11288  0.0058737883704044445,0.005059102772730436,0.0...
3  18523  0.0046664540517186584,0.004503954003007108,0.0...
4  71342  0.004434558743269639,0.003992547951993532,0.00...
lstm_vae_2


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.03393857097939978,0.05789845235491907,0.0487...
1  32460  0.04087810335197326,0.13571717832956703,0.0461...
2  11288  0.0571017476847112,0.039959487310544425,0.0572...
3  18523  0.058355881278170235,0.0126064520047528,0.1439...
4  71342  0.013718956847391897,0.0025886649716979253,0.0...
vae_dssm_cos


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.003906931469764556,0.0039061514896953048,0.0...
1  32460  0.00390693820403257,0.0039061604100681054,0.00...
2  11288  0.003906929833246631,0.0039061600095069158,0.0...
3  18523  0.00390686182639265,0.0039061474829534867,0.00...
4  71342  0.003906879604763235,0.0039061453757428537,0.0...
knn_vae_last


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.004870480630634544,0.00637355973493592,0.004...
1  32460  0.004905313250343139,0.005428076183952914,0.00...
2  11288  0.006826388340903188,0.004717924549020209,0.00...
3  18523  0.004760095001835747,0.005637310512488045,0.00...
4  71342  0.005842831061017021,0.004780852511918962,0.00...
knn_normed_lstm


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.003847150537808986,0.004025301535391302,0.00...
1  32460  0.004913451409759925,0.0051120008917122555,0.0...
2  11288  0.004249468153033917,0.004652139306956722,0.00...
3  18523  0.007806214941620327,0.004029221866669217,0.00...
4  71342  0.004242994133697241,0.0038581348767982747,0.0...
normder_lstm_als


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.0037868509713402457,0.003798761887068822,0.0...
1  32460  0.004154015802497785,0.005980411687227117,0.00...
2  11288  0.004250337862449555,0.004679326301158979,0.00...
3  18523  0.00560437024331879,0.0043373652737321,0.00536...
4  71342  0.004000284245936667,0.003825799909492643,0.00...
normed_lstm_dssm_cos


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.004236567810665729,0.0042366928336265266,0.0...
1  32460  0.004236656804282052,0.0042366972221807575,0.0...
2  11288  0.004236608802710172,0.004236655236197104,0.00...
3  18523  0.00423665631625448,0.004236596156161927,0.000...
4  71342  0.004236641425821992,0.004236660914417373,0.00...
lstm_vae


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.1044317626175271,0.10232760383908551,0.04964...
1  32460  0.03074999007531406,0.16468009462462094,0.0225...
2  11288  0.04229141312508075,0.052203913640701746,0.100...
3  18523  0.11965380358824379,0.026864381546765844,0.113...
4  71342  0.029474111141628314,0.0038359939874546395,0.0...
vae_dssm_dot


  0%|          | 0/25580 [00:00<?, ?it/s]

   track                                         prediction
0  17730  0.008680967953305795,7.503312631862434e-93,5.3...
1  32460  0.007637902668046786,1.3040915854827534e-53,0....
2  11288  0.007117667149877595,4.496614305315884e-44,0.0...
3  18523  0.007433242047715463,1.2409303578405475e-49,0....
4  71342  0.002008060950390043,8.038259979838685e-305,8....
