# XGBoost

https://www.kaggle.com/code/cdeotte/catboost-starter-lb-0-60?scriptVersionId=158772898

In [1]:
DATA_PREPARATION_VOTE_METHOD = "max_vote_window" # "max_vote_window" or "sum_and_normalize". Decides how to aggregate the predictions of the overlapping windows

In [3]:
import os
import sys
import warnings
import gc
import pathlib

if bool(os.environ.get("KAGGLE_URL_BASE", "")):
  import sys
  # running on kaggle
  sys.path.insert(0, "/kaggle/input/hsm-source-files")
else:
  # running locally
  sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..", "..", "..")))

import torch.nn as nn
import pandas as pd
import numpy as np
import xgboost as xgb

import torch

from src.utils.utils import get_raw_data_dir, get_processed_data_dir, get_submission_csv_path, set_seeds, get_models_save_path
from src.utils.constants import Constants
from src.datasets.eeg_processor import EEGDataProcessor
from src.utils.k_folds_creator import KFoldCreator

from tqdm import tqdm

set_seeds(42)

2025-10-07 20:02:54,634 :: root :: INFO :: Initialising Utils
2025-10-07 20:02:54,635 :: root :: INFO :: Initialising Datasets


In [4]:
DATA_PATH = get_raw_data_dir()

processor = EEGDataProcessor(raw_data_path=DATA_PATH, processed_data_path=get_processed_data_dir())
train_df = processor.process_data(vote_method=DATA_PREPARATION_VOTE_METHOD, skip_npy=True)

test_df = pd.read_csv(DATA_PATH / "test.csv")

kl_score = nn.KLDivLoss(reduction="batchmean")

Processor initialized.
Raw data path: '/home/david/git/aicomp/data'
Processed data path: '/home/david/git/aicomp/data/processed'
Starting EEG Data Processing Pipeline
Skipping NumPy file creation as requested.
Using 'max_vote_window' vote aggregation strategy.

Processed train data saved to '/home/david/git/aicomp/data/processed/train_processed.csv'.
Shape of the final dataframe: (17089, 15)

Pipeline finished successfully!


## Feature Engineering

We need features for the XGBoost model.
For this, we take the mean over time for all of the 400 spectrogram frequencies.
We take the middle 10 minutes of all spectrograms.
For each EEG ID, this produces 400 features.

## Load Spectrogram Files into Memory

In [4]:
spectrograms_dir = DATA_PATH / "train_spectrograms"
spectrogram_files = list(spectrograms_dir.glob("*.parquet"))
print(f"Found {len(spectrogram_files)} train spectrogram files to load into memory")

def get_spectrogram_content(spectrogram_file: pathlib.Path):
  spectrogram_id = int(spectrogram_file.stem.split("_")[-1])
  content = pd.read_parquet(file)
  content = content.drop(columns=["time"]).values
  return spectrogram_id, content

spectrograms = {}
for file in tqdm(spectrogram_files):
  spectrogram_id, content = get_spectrogram_content(file)
  spectrograms[spectrogram_id] = content

gc.collect()
print("Loaded all train spectrograms into memory")

Found 11138 spectrogram files to load into memory


100%|██████████| 11138/11138 [06:21<00:00, 29.17it/s]


Loaded all train spectrograms into memory


In [5]:
FEATURES = [f"spec_mean_freq_{x}" for x in range(400)]
data = np.zeros((len(train_df), len(FEATURES)))

def extract_train_spectrogram_features(row, all_spectrograms):
  spectrogram_id = int(row["spectrogram_id"])
  middle_offset = (row["min_offset"] + row["max_offset"]) // 2 # this the middle between the least spectrogram offset and greatest spectogram offset
  row_index = int(middle_offset // 2) # each spectrogram row corresponds to 2s, so we divide by 2 to get the row index
  average_frequencies = np.array(all_spectrograms[spectrogram_id][row_index:row_index+300,:] ).mean(axis=0) # average over 300 rows (10 minutes)
  return average_frequencies

for i in tqdm(range(len(train_df)), total=len(train_df)):
  row = train_df.iloc[i]
  data[i,:] = extract_train_spectrogram_features(row, spectrograms)

100%|██████████| 17089/17089 [00:02<00:00, 6249.52it/s]


In [6]:
warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)
train_df[FEATURES] = data

del data
del spectrograms
gc.collect()

train_df.head()

Unnamed: 0,eeg_id,spectrogram_id,min_offset,max_offset,patient_id,expert_consensus,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote,spec_mean_freq_0,spec_mean_freq_1,spec_mean_freq_2,spec_mean_freq_3,spec_mean_freq_4,spec_mean_freq_5,spec_mean_freq_6,spec_mean_freq_7,spec_mean_freq_8,spec_mean_freq_9,spec_mean_freq_10,spec_mean_freq_11,spec_mean_freq_12,spec_mean_freq_13,spec_mean_freq_14,spec_mean_freq_15,spec_mean_freq_16,spec_mean_freq_17,spec_mean_freq_18,spec_mean_freq_19,spec_mean_freq_20,spec_mean_freq_21,spec_mean_freq_22,spec_mean_freq_23,spec_mean_freq_24,spec_mean_freq_25,spec_mean_freq_26,spec_mean_freq_27,spec_mean_freq_28,spec_mean_freq_29,spec_mean_freq_30,spec_mean_freq_31,spec_mean_freq_32,spec_mean_freq_33,spec_mean_freq_34,spec_mean_freq_35,spec_mean_freq_36,spec_mean_freq_37,spec_mean_freq_38,spec_mean_freq_39,spec_mean_freq_40,spec_mean_freq_41,spec_mean_freq_42,spec_mean_freq_43,spec_mean_freq_44,spec_mean_freq_45,spec_mean_freq_46,spec_mean_freq_47,spec_mean_freq_48,spec_mean_freq_49,spec_mean_freq_50,spec_mean_freq_51,spec_mean_freq_52,spec_mean_freq_53,spec_mean_freq_54,spec_mean_freq_55,spec_mean_freq_56,spec_mean_freq_57,spec_mean_freq_58,spec_mean_freq_59,spec_mean_freq_60,spec_mean_freq_61,spec_mean_freq_62,spec_mean_freq_63,spec_mean_freq_64,spec_mean_freq_65,spec_mean_freq_66,spec_mean_freq_67,spec_mean_freq_68,spec_mean_freq_69,spec_mean_freq_70,spec_mean_freq_71,spec_mean_freq_72,spec_mean_freq_73,spec_mean_freq_74,spec_mean_freq_75,spec_mean_freq_76,spec_mean_freq_77,spec_mean_freq_78,spec_mean_freq_79,spec_mean_freq_80,spec_mean_freq_81,spec_mean_freq_82,spec_mean_freq_83,spec_mean_freq_84,spec_mean_freq_85,spec_mean_freq_86,spec_mean_freq_87,spec_mean_freq_88,spec_mean_freq_89,spec_mean_freq_90,spec_mean_freq_91,spec_mean_freq_92,spec_mean_freq_93,spec_mean_freq_94,spec_mean_freq_95,spec_mean_freq_96,spec_mean_freq_97,spec_mean_freq_98,spec_mean_freq_99,spec_mean_freq_100,spec_mean_freq_101,spec_mean_freq_102,spec_mean_freq_103,spec_mean_freq_104,spec_mean_freq_105,spec_mean_freq_106,spec_mean_freq_107,spec_mean_freq_108,spec_mean_freq_109,spec_mean_freq_110,spec_mean_freq_111,spec_mean_freq_112,spec_mean_freq_113,spec_mean_freq_114,spec_mean_freq_115,spec_mean_freq_116,spec_mean_freq_117,spec_mean_freq_118,spec_mean_freq_119,spec_mean_freq_120,spec_mean_freq_121,spec_mean_freq_122,spec_mean_freq_123,spec_mean_freq_124,spec_mean_freq_125,spec_mean_freq_126,spec_mean_freq_127,spec_mean_freq_128,spec_mean_freq_129,spec_mean_freq_130,spec_mean_freq_131,spec_mean_freq_132,spec_mean_freq_133,spec_mean_freq_134,spec_mean_freq_135,spec_mean_freq_136,spec_mean_freq_137,spec_mean_freq_138,spec_mean_freq_139,spec_mean_freq_140,spec_mean_freq_141,spec_mean_freq_142,spec_mean_freq_143,spec_mean_freq_144,spec_mean_freq_145,spec_mean_freq_146,spec_mean_freq_147,spec_mean_freq_148,spec_mean_freq_149,spec_mean_freq_150,spec_mean_freq_151,spec_mean_freq_152,spec_mean_freq_153,spec_mean_freq_154,spec_mean_freq_155,spec_mean_freq_156,spec_mean_freq_157,spec_mean_freq_158,spec_mean_freq_159,spec_mean_freq_160,spec_mean_freq_161,spec_mean_freq_162,spec_mean_freq_163,spec_mean_freq_164,spec_mean_freq_165,spec_mean_freq_166,spec_mean_freq_167,spec_mean_freq_168,spec_mean_freq_169,spec_mean_freq_170,spec_mean_freq_171,spec_mean_freq_172,spec_mean_freq_173,spec_mean_freq_174,spec_mean_freq_175,spec_mean_freq_176,spec_mean_freq_177,spec_mean_freq_178,spec_mean_freq_179,spec_mean_freq_180,spec_mean_freq_181,spec_mean_freq_182,spec_mean_freq_183,spec_mean_freq_184,spec_mean_freq_185,spec_mean_freq_186,spec_mean_freq_187,spec_mean_freq_188,spec_mean_freq_189,spec_mean_freq_190,spec_mean_freq_191,spec_mean_freq_192,spec_mean_freq_193,spec_mean_freq_194,spec_mean_freq_195,spec_mean_freq_196,spec_mean_freq_197,spec_mean_freq_198,spec_mean_freq_199,spec_mean_freq_200,spec_mean_freq_201,spec_mean_freq_202,spec_mean_freq_203,spec_mean_freq_204,spec_mean_freq_205,spec_mean_freq_206,spec_mean_freq_207,spec_mean_freq_208,spec_mean_freq_209,spec_mean_freq_210,spec_mean_freq_211,spec_mean_freq_212,spec_mean_freq_213,spec_mean_freq_214,spec_mean_freq_215,spec_mean_freq_216,spec_mean_freq_217,spec_mean_freq_218,spec_mean_freq_219,spec_mean_freq_220,spec_mean_freq_221,spec_mean_freq_222,spec_mean_freq_223,spec_mean_freq_224,spec_mean_freq_225,spec_mean_freq_226,spec_mean_freq_227,spec_mean_freq_228,spec_mean_freq_229,spec_mean_freq_230,spec_mean_freq_231,spec_mean_freq_232,spec_mean_freq_233,spec_mean_freq_234,spec_mean_freq_235,spec_mean_freq_236,spec_mean_freq_237,spec_mean_freq_238,spec_mean_freq_239,spec_mean_freq_240,spec_mean_freq_241,spec_mean_freq_242,spec_mean_freq_243,spec_mean_freq_244,spec_mean_freq_245,spec_mean_freq_246,spec_mean_freq_247,spec_mean_freq_248,spec_mean_freq_249,spec_mean_freq_250,spec_mean_freq_251,spec_mean_freq_252,spec_mean_freq_253,spec_mean_freq_254,spec_mean_freq_255,spec_mean_freq_256,spec_mean_freq_257,spec_mean_freq_258,spec_mean_freq_259,spec_mean_freq_260,spec_mean_freq_261,spec_mean_freq_262,spec_mean_freq_263,spec_mean_freq_264,spec_mean_freq_265,spec_mean_freq_266,spec_mean_freq_267,spec_mean_freq_268,spec_mean_freq_269,spec_mean_freq_270,spec_mean_freq_271,spec_mean_freq_272,spec_mean_freq_273,spec_mean_freq_274,spec_mean_freq_275,spec_mean_freq_276,spec_mean_freq_277,spec_mean_freq_278,spec_mean_freq_279,spec_mean_freq_280,spec_mean_freq_281,spec_mean_freq_282,spec_mean_freq_283,spec_mean_freq_284,spec_mean_freq_285,spec_mean_freq_286,spec_mean_freq_287,spec_mean_freq_288,spec_mean_freq_289,spec_mean_freq_290,spec_mean_freq_291,spec_mean_freq_292,spec_mean_freq_293,spec_mean_freq_294,spec_mean_freq_295,spec_mean_freq_296,spec_mean_freq_297,spec_mean_freq_298,spec_mean_freq_299,spec_mean_freq_300,spec_mean_freq_301,spec_mean_freq_302,spec_mean_freq_303,spec_mean_freq_304,spec_mean_freq_305,spec_mean_freq_306,spec_mean_freq_307,spec_mean_freq_308,spec_mean_freq_309,spec_mean_freq_310,spec_mean_freq_311,spec_mean_freq_312,spec_mean_freq_313,spec_mean_freq_314,spec_mean_freq_315,spec_mean_freq_316,spec_mean_freq_317,spec_mean_freq_318,spec_mean_freq_319,spec_mean_freq_320,spec_mean_freq_321,spec_mean_freq_322,spec_mean_freq_323,spec_mean_freq_324,spec_mean_freq_325,spec_mean_freq_326,spec_mean_freq_327,spec_mean_freq_328,spec_mean_freq_329,spec_mean_freq_330,spec_mean_freq_331,spec_mean_freq_332,spec_mean_freq_333,spec_mean_freq_334,spec_mean_freq_335,spec_mean_freq_336,spec_mean_freq_337,spec_mean_freq_338,spec_mean_freq_339,spec_mean_freq_340,spec_mean_freq_341,spec_mean_freq_342,spec_mean_freq_343,spec_mean_freq_344,spec_mean_freq_345,spec_mean_freq_346,spec_mean_freq_347,spec_mean_freq_348,spec_mean_freq_349,spec_mean_freq_350,spec_mean_freq_351,spec_mean_freq_352,spec_mean_freq_353,spec_mean_freq_354,spec_mean_freq_355,spec_mean_freq_356,spec_mean_freq_357,spec_mean_freq_358,spec_mean_freq_359,spec_mean_freq_360,spec_mean_freq_361,spec_mean_freq_362,spec_mean_freq_363,spec_mean_freq_364,spec_mean_freq_365,spec_mean_freq_366,spec_mean_freq_367,spec_mean_freq_368,spec_mean_freq_369,spec_mean_freq_370,spec_mean_freq_371,spec_mean_freq_372,spec_mean_freq_373,spec_mean_freq_374,spec_mean_freq_375,spec_mean_freq_376,spec_mean_freq_377,spec_mean_freq_378,spec_mean_freq_379,spec_mean_freq_380,spec_mean_freq_381,spec_mean_freq_382,spec_mean_freq_383,spec_mean_freq_384,spec_mean_freq_385,spec_mean_freq_386,spec_mean_freq_387,spec_mean_freq_388,spec_mean_freq_389,spec_mean_freq_390,spec_mean_freq_391,spec_mean_freq_392,spec_mean_freq_393,spec_mean_freq_394,spec_mean_freq_395,spec_mean_freq_396,spec_mean_freq_397,spec_mean_freq_398,spec_mean_freq_399
0,568657,789577333,0.0,16.0,20654,Other,0.0,0.0,0.25,0.0,0.166667,0.583333,320.005035,418.96521,468.401886,438.750214,364.099854,289.576935,232.589294,183.663315,150.056168,124.47747,103.283936,90.022369,76.535965,65.875931,58.641861,51.808231,47.896965,43.415665,41.003494,39.915268,39.718998,40.58913,40.526432,39.830898,38.853668,35.306831,31.8297,27.230099,23.928635,20.621132,18.748199,16.706667,14.828001,13.442267,12.0869,11.1227,10.3638,9.392666,8.663033,8.189067,7.724,7.532167,7.5734,7.397966,7.2455,6.9567,6.5013,6.105634,5.940067,5.5981,5.3696,4.9714,4.616134,4.3238,4.035567,3.7798,3.569933,3.339333,3.1815,3.033367,2.938834,2.8323,2.7205,2.660433,2.583133,2.503166,2.427533,2.328933,2.241,2.176033,2.050833,1.9733,1.892433,1.792267,1.757,1.683933,1.6247,1.577433,1.497033,1.4543,1.426067,1.388867,1.392033,1.345867,1.323567,1.2809,1.251433,1.193733,1.142867,1.068467,1.031933,0.991167,0.967567,0.950167,0.939,0.9287,0.916367,0.9096,0.895133,0.882933,680.0354,841.835754,895.141052,778.113464,627.641235,504.696777,402.019012,318.695679,258.098602,212.310913,180.588135,151.656387,129.533356,113.475914,100.058327,88.975227,80.968437,74.169533,67.934105,64.286026,61.467201,59.6022,59.294636,58.347134,56.807938,54.098034,48.8522,43.602768,37.723667,33.058067,30.049667,26.438334,23.954998,22.447266,20.446268,19.013166,17.877869,16.863934,15.861333,14.924664,14.006333,13.385468,12.6523,12.129167,11.657434,10.9526,10.458834,9.961433,9.442899,9.0409,8.614166,8.015367,7.646832,7.254299,6.831134,6.443533,6.151334,5.9513,5.7785,5.5607,5.4271,5.2568,5.0432,4.831733,4.640567,4.450033,4.344767,4.2102,4.051867,3.957633,3.7671,3.603534,3.480833,3.3581,3.2171,3.1078,3.0015,2.949433,2.885567,2.828733,2.773067,2.6879,2.6301,2.563534,2.475567,2.428534,2.355433,2.235233,2.1751,2.112333,2.0443,2.002767,1.944,1.9012,1.8855,1.8304,1.795933,1.767266,1.728867,1.676733,393.369263,526.307617,593.36615,570.079285,482.60083,383.061707,315.892853,254.552872,205.908249,170.527863,140.060776,114.557007,99.581169,88.618736,79.320732,73.315132,67.91626,62.000866,57.558567,54.857597,52.549599,53.275204,52.490002,51.550968,51.011398,48.540203,45.191002,40.842899,36.9697,32.273235,29.150534,26.615801,24.183033,21.6973,19.814932,18.438066,17.486799,17.061335,16.633499,16.251534,15.8144,16.019701,16.171734,17.068434,17.611168,17.322399,17.1029,16.089434,15.414535,14.851733,14.074466,12.619866,11.307734,10.079,8.858868,8.106333,7.460033,6.986067,6.469233,6.230967,5.941367,5.927967,5.973367,5.982934,5.877,5.7238,5.562167,5.3677,5.2802,5.162133,4.9265,4.743667,4.4714,4.2442,3.9898,3.642867,3.386467,3.240033,3.0443,2.990566,2.9868,2.969833,3.011867,3.050867,3.078333,3.007334,2.948233,2.807433,2.6416,2.497633,2.353267,2.239967,2.1295,2.023233,1.952267,1.892733,1.8505,1.7992,1.784033,1.799333,1657.101929,2067.21582,2170.995361,1851.59729,1463.577026,1141.660278,894.700012,717.732422,589.266785,490.598267,409.130371,348.696045,297.915924,254.62265,223.556747,200.17334,180.607498,166.524628,152.616928,142.229019,130.864761,123.617722,116.399101,110.616676,105.852898,100.248268,92.1744,84.402367,76.696861,69.904465,64.89547,59.807629,56.019199,52.364834,47.881966,44.136135,41.732201,39.72403,38.420498,36.908833,35.273163,34.388233,32.583771,32.1152,32.045334,30.8396,30.239368,28.975801,28.174166,27.13287,25.997435,24.479933,22.931902,21.848768,20.771097,19.950335,18.853867,18.131231,17.533899,16.615231,16.318867,15.876167,15.320799,14.947767,14.436033,13.897801,13.587167,13.159333,12.644934,12.344168,11.935701,11.418234,11.175067,11.016468,10.836399,10.5372,10.198867,9.888366,9.699466,9.4403,9.258067,9.0718,8.7082,8.387333,8.138767,7.904901,7.7524,7.359901,7.151068,6.967601,6.718432,6.4293,6.291767,6.212066,6.158633,6.0653,5.923433,5.790267,5.545767,5.392233
1,582999,1552638400,0.0,38.0,20230,LPD,0.0,0.857143,0.0,0.071429,0.0,0.071429,58.2211,64.707634,59.324497,39.724167,26.9942,17.692499,12.0471,9.288233,7.682966,6.277733,5.079633,3.926867,3.165367,2.672433,2.412967,2.1976,2.014133,1.8644,1.687667,1.485867,1.327667,1.224833,1.1206,1.018933,0.999133,0.920467,0.857133,0.817333,0.77,0.7378,0.719467,0.717167,0.752467,0.7735,0.791967,0.7811,0.7438,0.698733,0.6772,0.6651,0.6196,0.597233,0.566333,0.549333,0.531867,0.512433,0.4642,0.436367,0.394,0.373233,0.3496,0.332933,0.316133,0.293367,0.274367,0.2596,0.241667,0.230933,0.221233,0.208433,0.204967,0.1918,0.184767,0.174967,0.166267,0.1582,0.149967,0.144033,0.138333,0.130533,0.126933,0.123333,0.120967,0.122467,0.118967,0.117267,0.116567,0.110633,0.108967,0.1006,0.0984,0.0941,0.091567,0.094167,0.0934,0.091,0.0919,0.0902,0.088567,0.085367,0.0807,0.075033,0.074767,0.0732,0.0749,0.0754,0.0738,0.0726,0.074933,0.074133,30.08267,37.212364,38.370701,32.994499,23.889168,16.6798,12.416901,9.7872,8.236667,6.650334,5.283033,4.323166,3.637467,3.236633,2.905967,2.6351,2.394,2.175467,1.9783,1.767533,1.5979,1.438067,1.308733,1.186667,1.1043,1.050367,1.007767,0.979833,0.9493,0.934933,0.912567,0.901667,0.913733,0.928767,0.948067,0.9971,1.029433,1.0433,0.998967,0.957633,0.875633,0.8201,0.782867,0.7525,0.7067,0.664233,0.599,0.5514,0.507633,0.478433,0.441667,0.402967,0.370467,0.3422,0.325467,0.314833,0.2911,0.2673,0.246333,0.2265,0.2173,0.212833,0.2033,0.198633,0.1942,0.191067,0.176233,0.166467,0.159033,0.1557,0.155,0.1549,0.151067,0.145433,0.145767,0.1407,0.1393,0.1346,0.132633,0.132533,0.1361,0.1321,0.130133,0.1313,0.1328,0.1298,0.128567,0.1293,0.1268,0.131033,0.123933,0.1188,0.115067,0.106033,0.1073,0.110767,0.106733,0.108633,0.106633,0.100733,49.179234,57.232964,54.682968,41.020737,30.361334,21.689835,15.276033,11.623799,9.152866,7.2943,5.771233,4.5479,3.634267,2.989033,2.529633,2.260833,2.101167,1.980167,1.766433,1.586433,1.419767,1.299433,1.192933,1.0798,0.978867,0.8973,0.816867,0.756067,0.699033,0.670267,0.658867,0.6623,0.684133,0.679233,0.66,0.6226,0.590633,0.539333,0.5334,0.532867,0.509133,0.501833,0.4913,0.490367,0.4737,0.4689,0.4251,0.3968,0.358767,0.337933,0.320767,0.308,0.298167,0.276567,0.2583,0.2473,0.228033,0.219967,0.206833,0.190067,0.183833,0.169933,0.1621,0.155533,0.152367,0.144433,0.1366,0.130833,0.122567,0.114867,0.111967,0.107433,0.102167,0.104267,0.099833,0.1002,0.096633,0.090267,0.0872,0.080033,0.077667,0.076,0.0729,0.071667,0.070667,0.067733,0.0657,0.063733,0.060333,0.058833,0.056233,0.054767,0.053733,0.052433,0.051367,0.050267,0.047233,0.0466,0.0463,0.047133,32.290764,42.471668,45.512264,44.800598,37.439636,29.972969,26.328268,23.039068,22.6982,21.0916,18.8395,16.1294,12.681067,10.243466,8.782899,7.5166,6.6683,5.788733,4.668133,3.8018,3.1304,2.5537,2.1413,1.818566,1.603533,1.458333,1.3255,1.207933,1.134433,1.069467,1.0174,0.970733,0.9415,0.9073,0.900767,0.894133,0.892267,0.863333,0.827267,0.790133,0.722433,0.6799,0.636733,0.618967,0.583767,0.5694,0.519933,0.4772,0.436833,0.403033,0.374067,0.3563,0.331267,0.309167,0.2917,0.2716,0.248967,0.228533,0.211267,0.1943,0.187933,0.1823,0.174867,0.167667,0.157167,0.149267,0.140167,0.134333,0.129033,0.1304,0.129933,0.131033,0.129133,0.123767,0.1165,0.108367,0.1036,0.099867,0.098067,0.0947,0.089467,0.086367,0.0833,0.079967,0.080133,0.078533,0.075467,0.073267,0.068267,0.067733,0.0648,0.064067,0.065867,0.066167,0.0628,0.062467,0.059933,0.057433,0.055967,0.0566
2,642382,14960202,1008.0,1032.0,5955,Other,0.0,0.0,0.0,0.0,0.0,1.0,10.100633,12.456667,13.519967,11.7832,9.385166,6.897967,5.0934,3.298833,2.091267,1.6143,1.3008,1.212433,1.215733,1.298967,1.377067,1.317367,1.262167,1.128467,0.957733,0.859333,0.788433,0.7399,0.733,0.714,0.7472,0.771667,0.797467,0.8502,0.8999,0.9271,0.8953,0.8576,0.7629,0.677367,0.5662,0.475033,0.4043,0.337033,0.311567,0.272333,0.2524,0.2347,0.214733,0.204167,0.189867,0.174,0.159967,0.146467,0.138667,0.136533,0.130633,0.125133,0.1162,0.1079,0.1087,0.104433,0.102567,0.102233,0.093567,0.0942,0.090633,0.089167,0.0836,0.079033,0.080333,0.083167,0.0857,0.085667,0.0821,0.0787,0.075767,0.076033,0.0778,0.072667,0.071533,0.071933,0.0687,0.0729,0.078067,0.081433,0.085833,0.090967,0.091133,0.0936,0.093833,0.0865,0.083767,0.0809,0.082467,0.086867,0.0849,0.078633,0.078633,0.0796,0.085467,0.088633,0.089633,0.082633,0.086433,0.097467,13.647733,16.801931,18.401432,15.351367,12.2673,9.013834,6.516767,4.517867,3.047733,2.336567,1.91,1.7612,1.7984,1.941967,2.094234,2.115767,2.0559,1.8911,1.697067,1.595633,1.572033,1.540533,1.537667,1.5313,1.5309,1.578367,1.683,1.7688,1.856267,1.8609,1.716767,1.5976,1.418,1.211567,1.049433,0.886933,0.7345,0.6302,0.557533,0.4857,0.437567,0.390267,0.345567,0.316733,0.2884,0.2657,0.248567,0.227567,0.213733,0.206633,0.200167,0.197767,0.189167,0.1762,0.168067,0.160767,0.1615,0.167233,0.1703,0.173233,0.169867,0.167833,0.1603,0.1523,0.155233,0.156233,0.151533,0.1489,0.1367,0.1218,0.119133,0.1208,0.1228,0.126567,0.123,0.118133,0.119,0.1179,0.1152,0.1083,0.099167,0.091933,0.087533,0.086767,0.085667,0.081767,0.0813,0.077067,0.078567,0.080267,0.077733,0.075667,0.072967,0.072567,0.0716,0.071067,0.075267,0.0739,0.073833,0.0742,7.236099,9.6128,11.1102,10.753833,9.496599,7.1643,5.450267,3.5912,2.375433,1.893,1.522933,1.401467,1.458933,1.6609,1.781067,1.710333,1.580767,1.295133,1.022433,0.874367,0.804733,0.763833,0.7742,0.804833,0.871233,0.941433,0.984533,1.0542,1.110533,1.158367,1.189133,1.182133,1.092133,0.971833,0.804367,0.6441,0.5209,0.424533,0.3859,0.340933,0.324333,0.3069,0.286433,0.269967,0.255733,0.243,0.232233,0.222667,0.215667,0.211733,0.205733,0.2035,0.1997,0.194267,0.190933,0.185367,0.185767,0.181367,0.177367,0.1701,0.161667,0.153733,0.147333,0.148033,0.154267,0.157667,0.159467,0.1507,0.1409,0.138933,0.133933,0.1326,0.1351,0.134067,0.1294,0.128933,0.128533,0.125233,0.129133,0.128867,0.126867,0.132633,0.1348,0.152333,0.150867,0.147367,0.137833,0.126267,0.126167,0.128333,0.129833,0.1224,0.120133,0.1189,0.1334,0.134167,0.125733,0.1207,0.1157,0.1141,6812.349121,6632.060547,6151.301758,2865.360107,1290.640747,824.744507,684.944458,660.535095,613.916443,557.049561,434.950684,365.127411,257.868896,223.915207,158.622162,123.955505,106.561867,84.842339,82.349968,65.972099,46.856464,38.492371,30.714931,24.838566,31.52877,34.100098,35.930202,41.057533,38.416763,36.943764,32.672798,20.724234,20.371634,15.771167,14.617301,14.505267,13.153434,13.192734,12.680667,11.847234,10.480499,9.456767,8.151667,7.516434,8.615066,7.709367,7.6189,7.995667,6.743167,6.598666,6.340967,5.5551,4.590433,3.8642,3.279433,2.8476,2.848767,2.5704,2.0741,2.1419,1.9266,1.910767,1.989967,2.245333,2.271267,2.2185,2.2305,1.9224,1.678,1.642166,1.6496,1.427967,1.1662,1.121833,1.016733,0.873833,0.804067,0.8463,0.968033,1.075067,1.222233,1.250633,1.205433,1.1299,1.035967,0.8878,0.7744,0.71,0.6838,0.7538,0.725767,0.767067,0.8333,0.7996,0.745,0.6844,0.699433,0.681167,0.6115,0.5317
3,751790,618728447,908.0,908.0,38549,GPD,0.0,0.0,1.0,0.0,0.0,0.0,27.988499,37.592232,48.323536,54.410072,54.579601,55.704361,55.597767,54.303066,49.619965,48.152267,45.519432,40.649597,36.830364,34.139233,29.813499,26.261633,23.875767,20.964935,18.921501,17.116034,15.236499,13.245632,11.763733,10.691632,10.0176,9.123199,8.385734,7.413333,6.847067,6.147567,5.441033,5.039267,4.513233,3.953967,3.552734,3.243533,3.015367,2.709633,2.437367,2.164733,1.9861,1.8392,1.696833,1.5729,1.4263,1.3068,1.234133,1.112367,1.044767,0.975267,0.873733,0.829267,0.7702,0.717633,0.687433,0.648433,0.6149,0.577767,0.5274,0.500533,0.454467,0.434833,0.402433,0.373967,0.3486,0.327467,0.3121,0.290467,0.271333,0.269933,0.2416,0.234633,0.230167,0.213267,0.207467,0.201533,0.191733,0.187333,0.177833,0.1709,0.165667,0.1551,0.146733,0.140467,0.1334,0.1282,0.1236,0.117033,0.112433,0.1061,0.1024,0.099233,0.095633,0.0926,0.089833,0.084667,0.082433,0.079,0.075233,0.073967,14.939134,26.062866,35.750301,41.753963,45.532372,45.543499,43.630169,45.293301,42.012436,39.7537,37.462864,34.208702,32.685234,31.936666,29.344032,27.347832,24.515068,21.597666,19.154667,17.469534,15.3857,13.424733,12.0438,10.607533,9.445667,8.663934,7.6934,6.697999,6.0605,5.477366,4.9137,4.508367,4.058033,3.557367,3.207833,2.914567,2.721133,2.463967,2.297933,2.158333,1.929567,1.809633,1.647433,1.5297,1.4181,1.316867,1.232133,1.1389,1.036867,0.9773,0.9333,0.885933,0.839667,0.7737,0.715967,0.672033,0.621,0.5845,0.551167,0.5141,0.4897,0.470833,0.433767,0.410467,0.381033,0.3581,0.3471,0.327467,0.304433,0.297567,0.282333,0.265267,0.257433,0.2388,0.2285,0.216633,0.205967,0.200667,0.191033,0.1827,0.175467,0.166867,0.1602,0.1505,0.1449,0.1341,0.128133,0.124533,0.117933,0.115167,0.111167,0.1044,0.0986,0.0959,0.091333,0.088133,0.086,0.0816,0.078167,0.076033,19.709633,25.616833,33.305698,38.774002,39.518536,44.033836,45.122566,44.836037,43.092602,41.629936,39.767502,37.468002,34.176636,31.494667,26.90107,23.933733,21.170668,18.461735,17.158066,15.355133,13.645133,12.099566,10.749233,9.547933,8.521066,7.5486,6.7241,5.7735,5.134,4.507101,3.9346,3.5478,3.1771,2.7771,2.522467,2.287733,2.010133,1.784233,1.5916,1.414367,1.2958,1.199133,1.096233,1.007733,0.9072,0.834167,0.754533,0.6933,0.657033,0.609133,0.575367,0.5389,0.505133,0.465733,0.426,0.4126,0.398833,0.378567,0.358533,0.339367,0.3017,0.2863,0.269133,0.255167,0.2395,0.230533,0.2203,0.2047,0.193767,0.189833,0.168033,0.164233,0.161633,0.146533,0.1476,0.142567,0.131933,0.130633,0.122733,0.116233,0.114133,0.108667,0.100333,0.0948,0.089933,0.083967,0.0818,0.079067,0.074567,0.071733,0.0673,0.0637,0.0609,0.0591,0.0548,0.052833,0.050333,0.048367,0.0471,0.045833,14.398467,22.584566,32.795235,41.0798,45.080601,47.312599,43.921432,41.846035,39.081532,38.565533,39.337833,38.014103,37.105064,34.338535,30.528471,28.159199,25.103065,23.261801,21.2689,19.249434,17.139498,15.056933,13.823334,12.201401,11.011666,10.0604,8.7033,7.6577,6.965866,6.415667,5.6993,5.3653,4.629734,4.034534,3.662367,3.311433,3.0764,2.8044,2.5513,2.323,2.175233,2.0253,1.874767,1.7234,1.581033,1.4399,1.311433,1.179633,1.117,1.061567,1.006433,0.953867,0.892033,0.7987,0.7488,0.711567,0.679467,0.631467,0.587733,0.543767,0.510633,0.482167,0.456867,0.439633,0.395567,0.382833,0.3559,0.337533,0.3144,0.299167,0.279467,0.269,0.270167,0.2525,0.239567,0.2247,0.202567,0.198233,0.1901,0.178033,0.178233,0.1671,0.1548,0.147733,0.1384,0.128067,0.124667,0.121267,0.118167,0.113167,0.1093,0.1045,0.0971,0.095733,0.0896,0.085533,0.080433,0.0765,0.075867,0.074267
4,778705,52296320,0.0,0.0,40955,Other,0.0,0.0,0.0,0.0,0.0,1.0,55.1339,59.609035,54.085033,38.695969,29.842068,25.815701,21.814369,17.542032,14.638301,12.792633,11.2042,10.158667,9.086234,8.279499,7.420534,7.126501,7.119,7.2066,7.194533,7.2259,6.600934,5.863833,5.073333,4.362967,3.9219,3.478767,3.078966,2.692633,2.4285,2.261133,2.1459,1.9893,1.8481,1.6129,1.446333,1.248667,1.085733,0.944533,0.826,0.758733,0.690067,0.605933,0.5787,0.5241,0.483667,0.452433,0.4326,0.407,0.382767,0.3617,0.349767,0.330067,0.327533,0.314733,0.305,0.3033,0.3057,0.313733,0.3108,0.323133,0.3224,0.334533,0.3463,0.3509,0.361367,0.356767,0.365967,0.357333,0.359967,0.361633,0.373133,0.370267,0.370167,0.375933,0.396,0.418767,0.4376,0.4416,0.418033,0.395533,0.392467,0.394467,0.403667,0.416067,0.423667,0.4101,0.395533,0.3704,0.3839,0.364467,0.360167,0.353833,0.3597,0.373533,0.382433,0.3866,0.351933,0.3158,0.297867,0.282267,312.732941,300.969421,261.835327,161.932373,104.2957,76.425499,57.949364,43.021198,33.374268,26.196465,20.461666,17.178267,14.908,13.6251,12.727467,12.7265,12.698133,12.884534,12.517067,12.538667,11.815066,11.120433,10.270499,9.3974,8.412733,7.720767,6.7709,6.033233,5.432867,5.006933,4.743267,4.469267,4.214434,3.874833,3.509667,3.1074,2.614067,2.170133,1.8507,1.5863,1.370933,1.230933,1.1335,1.049567,0.997067,0.9412,0.9038,0.831033,0.796933,0.744767,0.687067,0.623667,0.5959,0.559067,0.5472,0.540033,0.538867,0.521367,0.506867,0.495133,0.4923,0.501733,0.508067,0.5149,0.509167,0.5045,0.498667,0.4803,0.464767,0.456533,0.460533,0.4528,0.442233,0.437,0.410467,0.3811,0.3552,0.338033,0.3221,0.316067,0.3032,0.293733,0.279567,0.276733,0.269567,0.269533,0.2613,0.253233,0.251567,0.2353,0.2257,0.220367,0.206967,0.2027,0.194867,0.186333,0.19,0.187467,0.1988,0.186733,57.562664,61.262104,56.695866,42.121803,32.127068,29.532097,26.201765,21.8652,19.208168,17.453665,15.851699,14.648,13.283466,12.1966,11.040667,10.785233,11.351266,12.093967,12.8135,12.925934,11.777834,10.170067,8.247934,6.913,5.992066,5.202433,4.586533,3.881333,3.336367,2.9377,2.709767,2.419767,2.273566,2.0592,1.897067,1.686633,1.4799,1.2977,1.132667,1.0225,0.941833,0.844733,0.797933,0.722333,0.662467,0.615733,0.5854,0.5588,0.5143,0.485933,0.4551,0.428667,0.434033,0.413367,0.418567,0.422633,0.428567,0.425967,0.4147,0.4097,0.398167,0.4087,0.417533,0.4224,0.438,0.4321,0.442167,0.4379,0.431767,0.4216,0.426767,0.419033,0.424133,0.4435,0.452667,0.4911,0.493133,0.4858,0.4615,0.415133,0.412167,0.4048,0.432167,0.437567,0.434233,0.4386,0.4146,0.391233,0.404333,0.391167,0.368933,0.3708,0.396367,0.409,0.4415,0.4767,0.452,0.4017,0.3896,0.368433,65846.085938,68842.40625,68851.75,55388.613281,39317.582031,25017.021484,14574.469727,8076.858398,3755.092041,2056.221191,1992.355225,2557.420654,3331.400146,3700.48291,3487.062744,2999.618652,2313.9104,1630.109985,966.860107,614.847961,466.956451,519.780029,767.063171,1103.037964,1383.853638,1686.205444,1908.256714,1926.521851,1873.067749,1770.876221,1544.171509,1250.646606,1009.769348,816.479614,679.017212,664.84906,760.155701,901.157837,1074.843384,1265.417236,1407.040283,1450.925903,1431.746338,1326.91394,1134.745972,895.890442,674.062866,473.032928,302.425873,209.791046,179.571457,193.978958,245.573181,310.760834,353.925934,368.684631,363.431549,331.339996,272.194336,217.258133,172.229752,133.577026,107.330925,86.534637,72.531067,59.404263,40.791031,27.450436,24.206299,34.50787,64.382835,117.526794,186.749908,261.17746,335.425507,401.727448,433.831757,440.820221,429.007965,387.76181,330.051788,279.052826,244.362076,225.534241,235.419327,272.091339,319.219421,369.110565,420.183319,451.925354,455.840363,443.026611,412.881775,363.57959,306.752686,256.296387,211.686829,175.735931,156.308807,151.012115


## Train XGBoost Model

In [7]:
N_SPLITS = 5

In [8]:
fold_creator = KFoldCreator(n_splits=N_SPLITS, seed=Constants.SEED)
train_folds_df = fold_creator.create_folds(
    df=train_df, stratify_col="expert_consensus", group_col="patient_id"
)

In [9]:
all_oof = []
all_true = []
targets_dict = {"Seizure":0, "LPD":1, "GPD":2, "LRDA":3, "GRDA":4, "Other":5}

models_save_path = get_models_save_path() / "xgboost" / "spectrogram_means" / DATA_PREPARATION_VOTE_METHOD
models_save_path.mkdir(parents=True, exist_ok=True)

for fold in range(N_SPLITS):
    fold_train_df = train_folds_df[train_folds_df["fold"] != fold].reset_index(drop=True)
    fold_valid_df = train_folds_df[train_folds_df["fold"] == fold].reset_index(drop=True)

    print("=" * 40)
    print(f"FOLD {fold}")
    print(f"Train size: {len(fold_train_df)}, Valid size: {len(fold_valid_df)}")
    print("=" * 30)

    X_train = fold_train_df[FEATURES]
    y_train = fold_train_df["expert_consensus"].map(targets_dict)
    
    X_valid = fold_valid_df[FEATURES]
    y_valid = fold_valid_df["expert_consensus"].map(targets_dict)

    dtrain = xgb.DMatrix(X_train, label=y_train)
    dvalid = xgb.DMatrix(X_valid, label=y_valid)

    params = {
        "objective": "multi:softprob",
        "num_class": len(Constants.TARGETS),
        "device": "cuda",
        "tree_method": "hist",
        "eval_metric": "mlogloss",
        "seed": Constants.SEED,
    }

    evals = [(dvalid, "eval")]
    model = xgb.train(
        params,
        dtrain,
        num_boost_round=300,
        evals=evals,
        verbose_eval=100,
        early_stopping_rounds=10,
    )
    
    model.save_model(models_save_path / f"fold_{fold}.json")

    oof = model.predict(dvalid)
    all_oof.extend(oof)

    all_true.extend(fold_valid_df[Constants.TARGETS].values)

    del X_train, y_train, X_valid, y_valid, dtrain, dvalid, oof
    gc.collect()

all_oof = np.array(all_oof)
all_true = np.array(all_true)

FOLD 0
Train size: 13755, Valid size: 3334
[0]	eval-mlogloss:1.63018
[26]	eval-mlogloss:1.35269
FOLD 1
Train size: 13151, Valid size: 3938
[0]	eval-mlogloss:1.60886
[27]	eval-mlogloss:1.32410
FOLD 2
Train size: 13422, Valid size: 3667
[0]	eval-mlogloss:1.62823
[22]	eval-mlogloss:1.32224
FOLD 3
Train size: 14356, Valid size: 2733
[0]	eval-mlogloss:1.61536
[32]	eval-mlogloss:1.23476
FOLD 4
Train size: 13672, Valid size: 3417
[0]	eval-mlogloss:1.64184
[19]	eval-mlogloss:1.41802


## CV Score

In [31]:
all_oof_tensor = torch.tensor(all_oof, dtype=torch.float32)
all_true_tensor = torch.tensor(all_true, dtype=torch.float32)

kl_score = nn.KLDivLoss(reduction="batchmean")
score = kl_score(all_oof_tensor.log(), all_true_tensor).item()

print(f"OOF KL Score: {score}")

OOF KL Score: 1.055870532989502


## Infer on Test and create Submission

In [11]:
del train_df
gc.collect()

0

In [13]:
test_spectrograms_dir = DATA_PATH / "test_spectrograms"
test_spectrogram_files = list(test_spectrograms_dir.glob("*.parquet"))
print(f"Found {len(test_spectrogram_files)} test spectrogram files to load into memory")

test_spectrograms = {}
for file in tqdm(test_spectrogram_files):
  spectrogram_id, content = get_spectrogram_content(file)
  test_spectrograms[spectrogram_id] = content

gc.collect()
print("Loaded all test spectrograms into memory")

Found 1 test spectrogram files to load into memory


100%|██████████| 1/1 [00:00<00:00, 28.70it/s]

Loaded all test spectrograms into memory





In [23]:
test_data = np.zeros((len(test_df), len(FEATURES)))

def extract_test_spectrogram_features(row, all_spectrograms):
  # this differs from train because all test spectrograms are exactly 10 minutes long, so we don't need to extract the center window
  spectrogram_id = int(row["spectrogram_id"])
  average_frequencies = np.array(all_spectrograms[spectrogram_id][:] ).mean(axis=0) # average over 300 rows (10 minutes)
  return average_frequencies

for i in tqdm(range(len(test_df)), total=len(test_df)):
  row = test_df.iloc[i]
  test_data[i,:] = extract_test_spectrogram_features(row, test_spectrograms)

100%|██████████| 1/1 [00:00<00:00, 1406.07it/s]


In [24]:
test_df[FEATURES] = test_data

del test_data
del test_spectrograms
gc.collect()

test_df.head()

Unnamed: 0,spectrogram_id,eeg_id,patient_id,spec_mean_freq_0,spec_mean_freq_1,spec_mean_freq_2,spec_mean_freq_3,spec_mean_freq_4,spec_mean_freq_5,spec_mean_freq_6,spec_mean_freq_7,spec_mean_freq_8,spec_mean_freq_9,spec_mean_freq_10,spec_mean_freq_11,spec_mean_freq_12,spec_mean_freq_13,spec_mean_freq_14,spec_mean_freq_15,spec_mean_freq_16,spec_mean_freq_17,spec_mean_freq_18,spec_mean_freq_19,spec_mean_freq_20,spec_mean_freq_21,spec_mean_freq_22,spec_mean_freq_23,spec_mean_freq_24,spec_mean_freq_25,spec_mean_freq_26,spec_mean_freq_27,spec_mean_freq_28,spec_mean_freq_29,spec_mean_freq_30,spec_mean_freq_31,spec_mean_freq_32,spec_mean_freq_33,spec_mean_freq_34,spec_mean_freq_35,spec_mean_freq_36,spec_mean_freq_37,spec_mean_freq_38,spec_mean_freq_39,spec_mean_freq_40,spec_mean_freq_41,spec_mean_freq_42,spec_mean_freq_43,spec_mean_freq_44,spec_mean_freq_45,spec_mean_freq_46,spec_mean_freq_47,spec_mean_freq_48,spec_mean_freq_49,spec_mean_freq_50,spec_mean_freq_51,spec_mean_freq_52,spec_mean_freq_53,spec_mean_freq_54,spec_mean_freq_55,spec_mean_freq_56,spec_mean_freq_57,spec_mean_freq_58,spec_mean_freq_59,spec_mean_freq_60,spec_mean_freq_61,spec_mean_freq_62,spec_mean_freq_63,spec_mean_freq_64,spec_mean_freq_65,spec_mean_freq_66,spec_mean_freq_67,spec_mean_freq_68,spec_mean_freq_69,spec_mean_freq_70,spec_mean_freq_71,spec_mean_freq_72,spec_mean_freq_73,spec_mean_freq_74,spec_mean_freq_75,spec_mean_freq_76,spec_mean_freq_77,spec_mean_freq_78,spec_mean_freq_79,spec_mean_freq_80,spec_mean_freq_81,spec_mean_freq_82,spec_mean_freq_83,spec_mean_freq_84,spec_mean_freq_85,spec_mean_freq_86,spec_mean_freq_87,spec_mean_freq_88,spec_mean_freq_89,spec_mean_freq_90,spec_mean_freq_91,spec_mean_freq_92,spec_mean_freq_93,spec_mean_freq_94,spec_mean_freq_95,spec_mean_freq_96,spec_mean_freq_97,spec_mean_freq_98,spec_mean_freq_99,spec_mean_freq_100,spec_mean_freq_101,spec_mean_freq_102,spec_mean_freq_103,spec_mean_freq_104,spec_mean_freq_105,spec_mean_freq_106,spec_mean_freq_107,spec_mean_freq_108,spec_mean_freq_109,spec_mean_freq_110,spec_mean_freq_111,spec_mean_freq_112,spec_mean_freq_113,spec_mean_freq_114,spec_mean_freq_115,spec_mean_freq_116,spec_mean_freq_117,spec_mean_freq_118,spec_mean_freq_119,spec_mean_freq_120,spec_mean_freq_121,spec_mean_freq_122,spec_mean_freq_123,spec_mean_freq_124,spec_mean_freq_125,spec_mean_freq_126,spec_mean_freq_127,spec_mean_freq_128,spec_mean_freq_129,spec_mean_freq_130,spec_mean_freq_131,spec_mean_freq_132,spec_mean_freq_133,spec_mean_freq_134,spec_mean_freq_135,spec_mean_freq_136,spec_mean_freq_137,spec_mean_freq_138,spec_mean_freq_139,spec_mean_freq_140,spec_mean_freq_141,spec_mean_freq_142,spec_mean_freq_143,spec_mean_freq_144,spec_mean_freq_145,spec_mean_freq_146,spec_mean_freq_147,spec_mean_freq_148,spec_mean_freq_149,spec_mean_freq_150,spec_mean_freq_151,spec_mean_freq_152,spec_mean_freq_153,spec_mean_freq_154,spec_mean_freq_155,spec_mean_freq_156,spec_mean_freq_157,spec_mean_freq_158,spec_mean_freq_159,spec_mean_freq_160,spec_mean_freq_161,spec_mean_freq_162,spec_mean_freq_163,spec_mean_freq_164,spec_mean_freq_165,spec_mean_freq_166,spec_mean_freq_167,spec_mean_freq_168,spec_mean_freq_169,spec_mean_freq_170,spec_mean_freq_171,spec_mean_freq_172,spec_mean_freq_173,spec_mean_freq_174,spec_mean_freq_175,spec_mean_freq_176,spec_mean_freq_177,spec_mean_freq_178,spec_mean_freq_179,spec_mean_freq_180,spec_mean_freq_181,spec_mean_freq_182,spec_mean_freq_183,spec_mean_freq_184,spec_mean_freq_185,spec_mean_freq_186,spec_mean_freq_187,spec_mean_freq_188,spec_mean_freq_189,spec_mean_freq_190,spec_mean_freq_191,spec_mean_freq_192,spec_mean_freq_193,spec_mean_freq_194,spec_mean_freq_195,spec_mean_freq_196,spec_mean_freq_197,spec_mean_freq_198,spec_mean_freq_199,spec_mean_freq_200,spec_mean_freq_201,spec_mean_freq_202,spec_mean_freq_203,spec_mean_freq_204,spec_mean_freq_205,spec_mean_freq_206,spec_mean_freq_207,spec_mean_freq_208,spec_mean_freq_209,spec_mean_freq_210,spec_mean_freq_211,spec_mean_freq_212,spec_mean_freq_213,spec_mean_freq_214,spec_mean_freq_215,spec_mean_freq_216,spec_mean_freq_217,spec_mean_freq_218,spec_mean_freq_219,spec_mean_freq_220,spec_mean_freq_221,spec_mean_freq_222,spec_mean_freq_223,spec_mean_freq_224,spec_mean_freq_225,spec_mean_freq_226,spec_mean_freq_227,spec_mean_freq_228,spec_mean_freq_229,spec_mean_freq_230,spec_mean_freq_231,spec_mean_freq_232,spec_mean_freq_233,spec_mean_freq_234,spec_mean_freq_235,spec_mean_freq_236,spec_mean_freq_237,spec_mean_freq_238,spec_mean_freq_239,spec_mean_freq_240,spec_mean_freq_241,spec_mean_freq_242,spec_mean_freq_243,spec_mean_freq_244,spec_mean_freq_245,spec_mean_freq_246,spec_mean_freq_247,spec_mean_freq_248,spec_mean_freq_249,spec_mean_freq_250,spec_mean_freq_251,spec_mean_freq_252,spec_mean_freq_253,spec_mean_freq_254,spec_mean_freq_255,spec_mean_freq_256,spec_mean_freq_257,spec_mean_freq_258,spec_mean_freq_259,spec_mean_freq_260,spec_mean_freq_261,spec_mean_freq_262,spec_mean_freq_263,spec_mean_freq_264,spec_mean_freq_265,spec_mean_freq_266,spec_mean_freq_267,spec_mean_freq_268,spec_mean_freq_269,spec_mean_freq_270,spec_mean_freq_271,spec_mean_freq_272,spec_mean_freq_273,spec_mean_freq_274,spec_mean_freq_275,spec_mean_freq_276,spec_mean_freq_277,spec_mean_freq_278,spec_mean_freq_279,spec_mean_freq_280,spec_mean_freq_281,spec_mean_freq_282,spec_mean_freq_283,spec_mean_freq_284,spec_mean_freq_285,spec_mean_freq_286,spec_mean_freq_287,spec_mean_freq_288,spec_mean_freq_289,spec_mean_freq_290,spec_mean_freq_291,spec_mean_freq_292,spec_mean_freq_293,spec_mean_freq_294,spec_mean_freq_295,spec_mean_freq_296,spec_mean_freq_297,spec_mean_freq_298,spec_mean_freq_299,spec_mean_freq_300,spec_mean_freq_301,spec_mean_freq_302,spec_mean_freq_303,spec_mean_freq_304,spec_mean_freq_305,spec_mean_freq_306,spec_mean_freq_307,spec_mean_freq_308,spec_mean_freq_309,spec_mean_freq_310,spec_mean_freq_311,spec_mean_freq_312,spec_mean_freq_313,spec_mean_freq_314,spec_mean_freq_315,spec_mean_freq_316,spec_mean_freq_317,spec_mean_freq_318,spec_mean_freq_319,spec_mean_freq_320,spec_mean_freq_321,spec_mean_freq_322,spec_mean_freq_323,spec_mean_freq_324,spec_mean_freq_325,spec_mean_freq_326,spec_mean_freq_327,spec_mean_freq_328,spec_mean_freq_329,spec_mean_freq_330,spec_mean_freq_331,spec_mean_freq_332,spec_mean_freq_333,spec_mean_freq_334,spec_mean_freq_335,spec_mean_freq_336,spec_mean_freq_337,spec_mean_freq_338,spec_mean_freq_339,spec_mean_freq_340,spec_mean_freq_341,spec_mean_freq_342,spec_mean_freq_343,spec_mean_freq_344,spec_mean_freq_345,spec_mean_freq_346,spec_mean_freq_347,spec_mean_freq_348,spec_mean_freq_349,spec_mean_freq_350,spec_mean_freq_351,spec_mean_freq_352,spec_mean_freq_353,spec_mean_freq_354,spec_mean_freq_355,spec_mean_freq_356,spec_mean_freq_357,spec_mean_freq_358,spec_mean_freq_359,spec_mean_freq_360,spec_mean_freq_361,spec_mean_freq_362,spec_mean_freq_363,spec_mean_freq_364,spec_mean_freq_365,spec_mean_freq_366,spec_mean_freq_367,spec_mean_freq_368,spec_mean_freq_369,spec_mean_freq_370,spec_mean_freq_371,spec_mean_freq_372,spec_mean_freq_373,spec_mean_freq_374,spec_mean_freq_375,spec_mean_freq_376,spec_mean_freq_377,spec_mean_freq_378,spec_mean_freq_379,spec_mean_freq_380,spec_mean_freq_381,spec_mean_freq_382,spec_mean_freq_383,spec_mean_freq_384,spec_mean_freq_385,spec_mean_freq_386,spec_mean_freq_387,spec_mean_freq_388,spec_mean_freq_389,spec_mean_freq_390,spec_mean_freq_391,spec_mean_freq_392,spec_mean_freq_393,spec_mean_freq_394,spec_mean_freq_395,spec_mean_freq_396,spec_mean_freq_397,spec_mean_freq_398,spec_mean_freq_399
0,853520,3911565283,6885,16.864132,19.120565,18.342468,13.408634,8.0575,4.890133,3.460633,2.449133,1.897233,1.4797,1.182633,1.006167,0.891467,0.753133,0.6846,0.6094,0.5587,0.514467,0.478933,0.444567,0.421967,0.3919,0.373467,0.361367,0.353,0.360733,0.372033,0.399633,0.424467,0.460167,0.486367,0.5182,0.5363,0.568733,0.577433,0.5539,0.523867,0.493867,0.461667,0.456867,0.4538,0.4455,0.437367,0.434867,0.4432,0.458433,0.477,0.502033,0.521667,0.544,0.5779,0.582733,0.5943,0.6222,0.654533,0.693967,0.738767,0.778133,0.829433,0.8237,0.873333,0.892333,0.916567,0.957367,1.0038,1.038933,1.106067,1.125733,1.2267,1.239733,1.249867,1.2768,1.2544,1.3177,1.3731,1.426867,1.460433,1.531267,1.478533,1.4778,1.491433,1.4029,1.446533,1.488467,1.472533,1.492567,1.5073,1.4574,1.444367,1.450667,1.4012,1.4325,1.379033,1.392333,1.378533,1.3269,1.285867,1.2649,1.240767,1.2392,79.439728,85.337959,75.281937,57.786068,43.450436,36.510532,30.038836,18.540333,11.741167,7.270267,6.426301,6.220433,6.2233,5.861366,4.8899,3.917433,2.818267,2.329467,1.937533,1.743933,1.682567,1.509467,1.315266,1.110167,0.9141,0.850367,0.8255,0.810367,0.7937,0.773767,0.706867,0.660933,0.6276,0.604733,0.5724,0.5296,0.4799,0.418367,0.3843,0.349733,0.3201,0.307733,0.2886,0.279933,0.277833,0.253833,0.247933,0.2276,0.2117,0.2054,0.1934,0.194933,0.1908,0.184567,0.174633,0.164867,0.156933,0.153067,0.157433,0.159833,0.161633,0.165133,0.167367,0.1609,0.1556,0.1529,0.151867,0.158267,0.156967,0.158333,0.156633,0.152233,0.148167,0.1425,0.141267,0.1428,0.1363,0.136833,0.1318,0.123833,0.118733,0.116867,0.1149,0.113933,0.114,0.1112,0.1055,0.1039,0.1041,0.106333,0.110533,0.111967,0.112,0.113967,0.1154,0.1168,0.1157,0.117067,0.119767,0.1203,42.766666,52.335632,52.112499,41.719269,30.291132,18.436733,7.219434,3.286067,2.154533,1.561333,1.236,0.9977,0.8392,0.717333,0.619133,0.532367,0.4957,0.4531,0.4211,0.403733,0.376833,0.339933,0.3227,0.312633,0.299867,0.292867,0.282133,0.264133,0.250433,0.246233,0.247433,0.255633,0.260967,0.261133,0.2571,0.253,0.240467,0.2341,0.223067,0.221833,0.215767,0.215867,0.212767,0.209067,0.2079,0.2078,0.2088,0.210233,0.214133,0.2139,0.221333,0.2282,0.230767,0.228567,0.229867,0.223733,0.218933,0.213333,0.2044,0.202733,0.197433,0.195633,0.196733,0.191,0.193,0.195867,0.194533,0.198667,0.205667,0.205733,0.2083,0.208167,0.205967,0.211167,0.2059,0.216367,0.220467,0.2254,0.232333,0.231467,0.233967,0.2277,0.226867,0.222333,0.221533,0.222,0.2213,0.2279,0.2398,0.2516,0.257433,0.265233,0.2549,0.2522,0.259667,0.2666,0.2662,0.267167,0.263967,0.2713,57.133801,62.297031,57.049767,37.8801,23.447701,15.482634,11.216233,7.917367,5.6927,3.981433,3.233567,2.7636,2.4356,2.151067,1.9043,1.6329,1.321467,1.1728,0.9932,0.902767,0.871033,0.8085,0.759867,0.7245,0.636633,0.608667,0.5847,0.560567,0.542933,0.522967,0.500667,0.465333,0.454933,0.435767,0.4129,0.3929,0.368,0.348967,0.3567,0.343367,0.325733,0.3194,0.292367,0.291933,0.284067,0.274633,0.272133,0.2598,0.2576,0.2532,0.2396,0.2304,0.2129,0.2039,0.2031,0.189267,0.185767,0.180533,0.168167,0.165,0.155367,0.1452,0.145867,0.142967,0.140467,0.134167,0.126567,0.1197,0.113567,0.112933,0.1128,0.1145,0.115133,0.111033,0.107233,0.1042,0.100233,0.099333,0.098533,0.097567,0.0945,0.0943,0.0912,0.088733,0.0892,0.0904,0.087733,0.088433,0.089033,0.084067,0.088567,0.086333,0.083633,0.084067,0.081933,0.082867,0.084267,0.082633,0.083967,0.081533


In [27]:
test_preds = []

for fold in range(N_SPLITS):
  print("=" * 40)
  print(f"Predicting fold {fold}")
  print("=" * 40)

  X_train = test_df[FEATURES]
  dtest = xgb.DMatrix(X_train)

  model = xgb.Booster()
  model.load_model(models_save_path / f"fold_{fold}.json")

  preds = model.predict(dtest)
  test_preds.append(preds)

test_preds = np.mean(test_preds, axis=0)
print(f"Test predictions shape: {test_preds.shape}")

Predicting fold 0
Predicting fold 1
Predicting fold 2
Predicting fold 3
Predicting fold 4
Test predictions shape: (1, 6)


In [29]:
# sanity check: all predictions should sum to 1
assert np.allclose(test_preds.sum(axis=1), 1.0)

In [30]:
submission = pd.DataFrame({"eeg_id": test_df["eeg_id"]})
submission[Constants.TARGETS] = test_preds

submission.to_csv(get_submission_csv_path(), index=False)