In [1]:
import pandas as pd 
import numpy as np 
from scipy.stats import entropy
import matplotlib.pyplot as plt

from engine_hms_trainer import *
from engine_hms_model import CustomModel, JobConfig, ModelConfig

import torch
from torch import nn
import torch.nn.functional as F

  _torch_pytree._register_pytree_node(


In [2]:
seed_everything(JobConfig.SEED)

# Experiment overrides applied on top of the ModelConfig defaults.
ModelConfig.EPOCHS = 6
ModelConfig.MODEL_BACKBONE = 'tf_efficientnet_b2'
ModelConfig.MODEL_NAME = "ENet_b2_xymask"
ModelConfig.AUGMENT = True
# Input sources: each flag is assigned exactly once. Meaning is inferred from the
# flag names (Kaggle-provided spectrograms / spectrograms built from EEG) -- confirm in ModelConfig.
ModelConfig.USE_KAGGLE_SPECTROGRAMS = True
ModelConfig.USE_EEG_SPECTROGRAMS = True

ModelConfig.AUGMENTATIONS = ['xy_masking']

hms_predictor = HMSPredictor(JobConfig, ModelConfig)

****************************************************************************************************
Script Start: Sun Mar 10 00:46:23 2024
Initializing HMS Predictor...
Model Name: ENet_b2_xymask
Drop Rate: 0.15
Drop Path Rate: 0.25
Augment: True
Augmentations: ['xy_masking']
Enropy Split: 5.5
Device: cuda
Output Dir: ./outputs/
****************************************************************************************************


In [3]:
# Load both training splits plus the spectrogram / EEG stores, then run quick integrity checks.
train_easy, train_hard, all_specs, all_eegs = hms_predictor.load_train_data()

split_frames = (train_easy, train_hard)

# Shapes of the easy and hard splits, in that order.
for frame in split_frames:
    print(frame.shape)

# Total NaN count per split (expected to be 0).
for frame in split_frames:
    print(frame.isnull().sum().sum())

display(train_easy.head())
print(" ")
display(train_hard.head())

(12440, 14)
(5536, 14)
0
0


Unnamed: 0,eeg_id,spectrogram_id,min,max,patient_id,target,total_votes,entropy,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
0,642382,14960202,1008.0,1032.0,5955,Other,2,7.802343,0.0,0.0,0.0,0.0,0.0,1.0
1,751790,618728447,908.0,908.0,38549,GPD,1,7.802343,0.0,0.0,1.0,0.0,0.0,0.0
2,778705,52296320,0.0,0.0,40955,Other,2,7.68682,0.0,0.0,0.0,0.0,0.0,1.0
3,1629671,2036345030,0.0,160.0,37481,Seizure,51,7.619243,1.0,0.0,0.0,0.0,0.0,0.0
4,2061593,320962633,1450.0,1450.0,23828,Other,1,7.802343,0.0,0.0,0.0,0.0,0.0,1.0


 


Unnamed: 0,eeg_id,spectrogram_id,min,max,patient_id,target,total_votes,entropy,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
0,568657,789577333,0.0,16.0,20654,Other,48,3.341757,0.0,0.0,0.25,0.0,0.166667,0.583333
1,582999,1552638400,0.0,38.0,20230,LPD,154,3.550549,0.0,0.857143,0.0,0.071429,0.0,0.071429
2,1895581,128369999,1138.0,1138.0,47999,Other,13,3.565051,0.076923,0.0,0.0,0.0,0.076923,0.846154
3,2482631,978166025,1902.0,1944.0,20606,Other,105,1.431066,0.0,0.0,0.133333,0.066667,0.133333,0.666667
4,2521897,673742515,0.0,4.0,62117,Other,24,1.516203,0.0,0.0,0.083333,0.083333,0.333333,0.5


In [4]:
# Run cross-validated training. The log below shows, per fold, a "First Stage" and then a
# "Second Stage" that starts from the stage-1 checkpoint (e.g. ENet_b2_xymask_fold_0_stage_1.pth).
# NOTE(review): presumably this also writes oof_1.csv / oof_2.csv to JobConfig.OUTPUT_DIR,
# which the evaluation cells below read -- confirm in engine_hms_trainer.
hms_predictor.train_folds(train_easy, train_hard, all_specs, all_eegs)

Fold: 0 || Valid size 3596 
- First Stage 


Train:   0%|          | 0/622 [00:00<?, ?batch/s]

Epoch: [1][0/622]Elapsed 1.08s | Loss: 0.8220 Grad: 67731.3828 LR: 4.0000e-06
Epoch: [1][50/622]Elapsed 8.17s | Loss: 0.8346 Grad: 85699.5469 LR: 5.0647e-06
Epoch: [1][100/622]Elapsed 15.29s | Loss: 0.8281 Grad: 70401.9922 LR: 8.2116e-06
Epoch: [1][150/622]Elapsed 22.41s | Loss: 0.8233 Grad: 75559.6250 LR: 1.3301e-05
Epoch: [1][200/622]Elapsed 29.53s | Loss: 0.8153 Grad: 72208.4922 LR: 2.0107e-05
Epoch: [1][250/622]Elapsed 36.67s | Loss: 0.8045 Grad: 70698.4688 LR: 2.8328e-05
Epoch: [1][300/622]Elapsed 43.82s | Loss: 0.7909 Grad: 78792.6719 LR: 3.7599e-05
Epoch: [1][350/622]Elapsed 51.00s | Loss: 0.7736 Grad: 118346.3359 LR: 4.7509e-05
Epoch: [1][400/622]Elapsed 58.20s | Loss: 0.7534 Grad: 59210.1992 LR: 5.7619e-05
Epoch: [1][450/622]Elapsed 65.43s | Loss: 0.7316 Grad: 76243.6875 LR: 6.7479e-05
Epoch: [1][500/622]Elapsed 72.67s | Loss: 0.7124 Grad: 71753.1094 LR: 7.6652e-05
Epoch: [1][550/622]Elapsed 79.88s | Loss: 0.6960 Grad: 73817.9453 LR: 8.4732e-05
Epoch: [1][600/622]Elapsed 87.10

Valid:   0%|          | 0/225 [00:00<?, ?batch/s]

Epoch: [1][0/225]Elapsed 0.10s | Loss: 0.3521
Epoch: [1][50/225]Elapsed 4.94s | Loss: 0.5367
Epoch: [1][100/225]Elapsed 9.77s | Loss: 0.5370
Epoch: [1][150/225]Elapsed 14.60s | Loss: 0.5320
Epoch: [1][200/225]Elapsed 19.46s | Loss: 0.4840


----------------------------------------------------------------------------------------------------
Epoch 1 - Average Train Loss: 0.6699 | Average Valid Loss: 0.4670 | Time: 112.17s
Best model found in epoch 1 | valid loss: 0.4670


Train:   0%|          | 0/622 [00:00<?, ?batch/s]

Epoch: [2][0/622]Elapsed 0.10s | Loss: 0.3647 Grad: 149451.5312 LR: 9.3737e-05
Epoch: [2][50/622]Elapsed 7.37s | Loss: 0.4326 Grad: 62924.1523 LR: 9.7777e-05
Epoch: [2][100/622]Elapsed 14.72s | Loss: 0.4220 Grad: 90355.3750 LR: 9.9786e-05
Epoch: [2][150/622]Elapsed 22.11s | Loss: 0.4204 Grad: 49365.2891 LR: 9.9996e-05
Epoch: [2][200/622]Elapsed 29.49s | Loss: 0.4184 Grad: 37034.8984 LR: 9.9967e-05
Epoch: [2][250/622]Elapsed 36.86s | Loss: 0.4110 Grad: 31872.4395 LR: 9.9911e-05
Epoch: [2][300/622]Elapsed 44.22s | Loss: 0.4057 Grad: 60893.3438 LR: 9.9828e-05
Epoch: [2][350/622]Elapsed 51.61s | Loss: 0.4026 Grad: 48400.3359 LR: 9.9717e-05
Epoch: [2][400/622]Elapsed 58.95s | Loss: 0.3955 Grad: 23087.1836 LR: 9.9579e-05
Epoch: [2][450/622]Elapsed 66.30s | Loss: 0.3881 Grad: 44029.6172 LR: 9.9415e-05
Epoch: [2][500/622]Elapsed 73.65s | Loss: 0.3848 Grad: 27803.1016 LR: 9.9223e-05
Epoch: [2][550/622]Elapsed 81.00s | Loss: 0.3841 Grad: 50971.1719 LR: 9.9004e-05
Epoch: [2][600/622]Elapsed 88.35

Valid:   0%|          | 0/225 [00:00<?, ?batch/s]

Epoch: [2][0/225]Elapsed 0.10s | Loss: 0.3426
Epoch: [2][50/225]Elapsed 4.94s | Loss: 0.4441
Epoch: [2][100/225]Elapsed 9.77s | Loss: 0.4462
Epoch: [2][150/225]Elapsed 14.59s | Loss: 0.4484
Epoch: [2][200/225]Elapsed 19.44s | Loss: 0.4165


----------------------------------------------------------------------------------------------------
Epoch 2 - Average Train Loss: 0.3786 | Average Valid Loss: 0.4081 | Time: 113.43s
Best model found in epoch 2 | valid loss: 0.4081


Train:   0%|          | 0/622 [00:00<?, ?batch/s]

Epoch: [3][0/622]Elapsed 0.10s | Loss: 0.2989 Grad: 131582.7969 LR: 9.8642e-05
Epoch: [3][50/622]Elapsed 7.46s | Loss: 0.3011 Grad: 49325.7539 LR: 9.8358e-05
Epoch: [3][100/622]Elapsed 14.90s | Loss: 0.2977 Grad: 122732.9766 LR: 9.8048e-05
Epoch: [3][150/622]Elapsed 22.33s | Loss: 0.3050 Grad: 43385.8242 LR: 9.7711e-05
Epoch: [3][200/622]Elapsed 29.76s | Loss: 0.3069 Grad: 29812.1562 LR: 9.7349e-05
Epoch: [3][250/622]Elapsed 37.22s | Loss: 0.3058 Grad: 25272.7148 LR: 9.6960e-05
Epoch: [3][300/622]Elapsed 44.64s | Loss: 0.3033 Grad: 29568.8418 LR: 9.6546e-05
Epoch: [3][350/622]Elapsed 52.06s | Loss: 0.3062 Grad: 48146.9766 LR: 9.6106e-05
Epoch: [3][400/622]Elapsed 59.48s | Loss: 0.3002 Grad: 22185.4883 LR: 9.5642e-05
Epoch: [3][450/622]Elapsed 66.91s | Loss: 0.2960 Grad: 30812.9062 LR: 9.5152e-05
Epoch: [3][500/622]Elapsed 74.34s | Loss: 0.2942 Grad: 29210.2734 LR: 9.4638e-05
Epoch: [3][550/622]Elapsed 81.75s | Loss: 0.2951 Grad: 42554.5703 LR: 9.4099e-05
Epoch: [3][600/622]Elapsed 89.1

Valid:   0%|          | 0/225 [00:00<?, ?batch/s]

Epoch: [3][0/225]Elapsed 0.10s | Loss: 0.3156
Epoch: [3][50/225]Elapsed 4.93s | Loss: 0.4314
Epoch: [3][100/225]Elapsed 9.75s | Loss: 0.4339
Epoch: [3][150/225]Elapsed 14.53s | Loss: 0.4356
Epoch: [3][200/225]Elapsed 19.32s | Loss: 0.4217


----------------------------------------------------------------------------------------------------
Epoch 3 - Average Train Loss: 0.2932 | Average Valid Loss: 0.4177 | Time: 114.13s


Train:   0%|          | 0/622 [00:00<?, ?batch/s]

Epoch: [4][0/622]Elapsed 0.10s | Loss: 0.1849 Grad: 130271.0078 LR: 9.3281e-05
Epoch: [4][50/622]Elapsed 7.51s | Loss: 0.2529 Grad: 26867.1777 LR: 9.2683e-05
Epoch: [4][100/622]Elapsed 14.97s | Loss: 0.2470 Grad: 44417.0742 LR: 9.2063e-05
Epoch: [4][150/622]Elapsed 22.42s | Loss: 0.2523 Grad: 58387.7734 LR: 9.1420e-05
Epoch: [4][200/622]Elapsed 29.87s | Loss: 0.2543 Grad: 34293.4531 LR: 9.0754e-05
Epoch: [4][250/622]Elapsed 37.30s | Loss: 0.2529 Grad: 18504.4902 LR: 9.0065e-05
Epoch: [4][300/622]Elapsed 44.72s | Loss: 0.2527 Grad: 30702.2910 LR: 8.9355e-05
Epoch: [4][350/622]Elapsed 52.14s | Loss: 0.2541 Grad: 48462.7578 LR: 8.8624e-05
Epoch: [4][400/622]Elapsed 59.55s | Loss: 0.2486 Grad: 24215.0137 LR: 8.7871e-05
Epoch: [4][450/622]Elapsed 66.98s | Loss: 0.2434 Grad: 47173.6484 LR: 8.7097e-05
Epoch: [4][500/622]Elapsed 74.42s | Loss: 0.2423 Grad: 41595.6172 LR: 8.6303e-05
Epoch: [4][550/622]Elapsed 81.85s | Loss: 0.2441 Grad: 39099.7305 LR: 8.5490e-05
Epoch: [4][600/622]Elapsed 89.25

Valid:   0%|          | 0/225 [00:00<?, ?batch/s]

Epoch: [4][0/225]Elapsed 0.10s | Loss: 0.3225
Epoch: [4][50/225]Elapsed 4.88s | Loss: 0.4401
Epoch: [4][100/225]Elapsed 9.69s | Loss: 0.4433
Epoch: [4][150/225]Elapsed 14.50s | Loss: 0.4436
Epoch: [4][200/225]Elapsed 19.32s | Loss: 0.4325


----------------------------------------------------------------------------------------------------
Epoch 4 - Average Train Loss: 0.2416 | Average Valid Loss: 0.4282 | Time: 114.22s


Train:   0%|          | 0/622 [00:00<?, ?batch/s]

Epoch: [5][0/622]Elapsed 0.10s | Loss: 0.2198 Grad: 162993.0312 LR: 8.4284e-05
Epoch: [5][50/622]Elapsed 7.54s | Loss: 0.2117 Grad: 29758.4648 LR: 8.3424e-05
Epoch: [5][100/622]Elapsed 15.01s | Loss: 0.2048 Grad: 38943.4805 LR: 8.2546e-05
Epoch: [5][150/622]Elapsed 22.52s | Loss: 0.2131 Grad: 39602.3359 LR: 8.1650e-05
Epoch: [5][200/622]Elapsed 29.98s | Loss: 0.2162 Grad: 28816.3398 LR: 8.0736e-05
Epoch: [5][250/622]Elapsed 37.42s | Loss: 0.2138 Grad: 25072.4902 LR: 7.9806e-05
Epoch: [5][300/622]Elapsed 44.85s | Loss: 0.2114 Grad: 27906.9082 LR: 7.8859e-05
Epoch: [5][350/622]Elapsed 52.29s | Loss: 0.2114 Grad: 36871.9375 LR: 7.7897e-05
Epoch: [5][400/622]Elapsed 59.70s | Loss: 0.2077 Grad: 57041.8750 LR: 7.6920e-05
Epoch: [5][450/622]Elapsed 67.12s | Loss: 0.2052 Grad: 57072.8672 LR: 7.5927e-05
Epoch: [5][500/622]Elapsed 74.51s | Loss: 0.2049 Grad: 25914.4160 LR: 7.4921e-05
Epoch: [5][550/622]Elapsed 81.91s | Loss: 0.2063 Grad: 51203.0312 LR: 7.3901e-05
Epoch: [5][600/622]Elapsed 89.35

Valid:   0%|          | 0/225 [00:00<?, ?batch/s]

Epoch: [5][0/225]Elapsed 0.11s | Loss: 0.2984
Epoch: [5][50/225]Elapsed 4.93s | Loss: 0.4490
Epoch: [5][100/225]Elapsed 9.74s | Loss: 0.4495
Epoch: [5][150/225]Elapsed 14.54s | Loss: 0.4525
Epoch: [5][200/225]Elapsed 19.36s | Loss: 0.4454


----------------------------------------------------------------------------------------------------
Epoch 5 - Average Train Loss: 0.2042 | Average Valid Loss: 0.4425 | Time: 114.40s


Train:   0%|          | 0/622 [00:00<?, ?batch/s]

Epoch: [6][0/622]Elapsed 0.11s | Loss: 0.1971 Grad: 229036.6875 LR: 7.2409e-05
Epoch: [6][50/622]Elapsed 7.48s | Loss: 0.1824 Grad: 27124.2637 LR: 7.1358e-05
Epoch: [6][100/622]Elapsed 14.95s | Loss: 0.1768 Grad: 35226.4492 LR: 7.0296e-05
Epoch: [6][150/622]Elapsed 22.39s | Loss: 0.1812 Grad: 41538.5352 LR: 6.9222e-05
Epoch: [6][200/622]Elapsed 29.83s | Loss: 0.1840 Grad: 54086.1914 LR: 6.8138e-05
Epoch: [6][250/622]Elapsed 37.27s | Loss: 0.1838 Grad: 35163.1992 LR: 6.7044e-05
Epoch: [6][300/622]Elapsed 44.70s | Loss: 0.1813 Grad: 40802.1211 LR: 6.5940e-05
Epoch: [6][350/622]Elapsed 52.13s | Loss: 0.1803 Grad: 55207.8281 LR: 6.4828e-05
Epoch: [6][400/622]Elapsed 59.57s | Loss: 0.1754 Grad: 28926.9434 LR: 6.3708e-05
Epoch: [6][450/622]Elapsed 67.01s | Loss: 0.1727 Grad: 31264.4414 LR: 6.2581e-05
Epoch: [6][500/622]Elapsed 74.44s | Loss: 0.1712 Grad: 17770.4707 LR: 6.1446e-05
Epoch: [6][550/622]Elapsed 81.87s | Loss: 0.1712 Grad: 43745.3125 LR: 6.0305e-05
Epoch: [6][600/622]Elapsed 89.28

Valid:   0%|          | 0/225 [00:00<?, ?batch/s]

Epoch: [6][0/225]Elapsed 0.10s | Loss: 0.3582
Epoch: [6][50/225]Elapsed 4.90s | Loss: 0.4707
Epoch: [6][100/225]Elapsed 9.69s | Loss: 0.4787
Epoch: [6][150/225]Elapsed 14.48s | Loss: 0.4782
Epoch: [6][200/225]Elapsed 19.29s | Loss: 0.4763


----------------------------------------------------------------------------------------------------
Epoch 6 - Average Train Loss: 0.1697 | Average Valid Loss: 0.4742 | Time: 114.23s
Fold 0 Valid Loss: 
Easy: 0.8951 | Hard: 0.6388
Elapse: 11.40 min 
- Second Stage 
Use Checkpoint: ENet_b2_xymask_fold_0_stage_1.pth


Train:   0%|          | 0/276 [00:00<?, ?batch/s]

Epoch: [1][0/276]Elapsed 0.10s | Loss: 0.6521 Grad: nan LR: 4.0000e-06




Epoch: [1][50/276]Elapsed 7.44s | Loss: 0.4548 Grad: 94100.8594 LR: 9.3614e-06
Epoch: [1][100/276]Elapsed 14.91s | Loss: 0.4342 Grad: 79477.4922 LR: 2.4248e-05
Epoch: [1][150/276]Elapsed 22.38s | Loss: 0.4065 Grad: 46323.3906 LR: 4.5334e-05
Epoch: [1][200/276]Elapsed 29.85s | Loss: 0.3745 Grad: 29432.5996 LR: 6.7909e-05
Epoch: [1][250/276]Elapsed 37.30s | Loss: 0.3482 Grad: 25979.0859 LR: 8.6930e-05
Epoch: [1][275/276]Elapsed 41.11s | Loss: 0.3374 Grad: 36008.8477 LR: 9.3946e-05


Valid:   0%|          | 0/225 [00:00<?, ?batch/s]

Epoch: [1][0/225]Elapsed 0.11s | Loss: 0.4283
Epoch: [1][50/225]Elapsed 4.92s | Loss: 0.4450
Epoch: [1][100/225]Elapsed 9.72s | Loss: 0.4483
Epoch: [1][150/225]Elapsed 14.54s | Loss: 0.4443
Epoch: [1][200/225]Elapsed 19.35s | Loss: 0.3965


----------------------------------------------------------------------------------------------------
Epoch 1 - Average Train Loss: 0.3374 | Average Valid Loss: 0.3787 | Time: 62.94s
Best model found in epoch 1 | valid loss: 0.3787


Train:   0%|          | 0/276 [00:00<?, ?batch/s]

Epoch: [2][0/276]Elapsed 0.10s | Loss: 0.3010 Grad: nan LR: 9.3946e-05
Epoch: [2][50/276]Elapsed 7.49s | Loss: 0.2283 Grad: 73792.9531 LR: 9.9978e-05
Epoch: [2][100/276]Elapsed 14.97s | Loss: 0.2174 Grad: 110963.1953 LR: 9.9939e-05
Epoch: [2][150/276]Elapsed 22.42s | Loss: 0.2149 Grad: 54119.9414 LR: 9.9740e-05
Epoch: [2][200/276]Elapsed 29.88s | Loss: 0.2120 Grad: 64724.0703 LR: 9.9403e-05
Epoch: [2][250/276]Elapsed 37.34s | Loss: 0.2079 Grad: 41688.7227 LR: 9.8929e-05
Epoch: [2][275/276]Elapsed 41.12s | Loss: 0.2067 Grad: 74235.5156 LR: 9.8628e-05


Valid:   0%|          | 0/225 [00:00<?, ?batch/s]

Epoch: [2][0/225]Elapsed 0.11s | Loss: 0.3390
Epoch: [2][50/225]Elapsed 4.93s | Loss: 0.4379
Epoch: [2][100/225]Elapsed 9.74s | Loss: 0.4382
Epoch: [2][150/225]Elapsed 14.55s | Loss: 0.4350
Epoch: [2][200/225]Elapsed 19.38s | Loss: 0.3841


----------------------------------------------------------------------------------------------------
Epoch 2 - Average Train Loss: 0.2067 | Average Valid Loss: 0.3651 | Time: 62.98s
Best model found in epoch 2 | valid loss: 0.3651


Train:   0%|          | 0/276 [00:00<?, ?batch/s]

Epoch: [3][0/276]Elapsed 0.10s | Loss: 0.2917 Grad: 189928.5938 LR: 9.8628e-05
Epoch: [3][50/276]Elapsed 7.51s | Loss: 0.1907 Grad: 45940.3477 LR: 9.7948e-05
Epoch: [3][100/276]Elapsed 15.00s | Loss: 0.1847 Grad: 69405.4219 LR: 9.7135e-05
Epoch: [3][150/276]Elapsed 22.44s | Loss: 0.1843 Grad: 48978.7695 LR: 9.6191e-05
Epoch: [3][200/276]Elapsed 29.87s | Loss: 0.1826 Grad: 32440.0586 LR: 9.5119e-05
Epoch: [3][250/276]Elapsed 37.29s | Loss: 0.1800 Grad: 35709.1992 LR: 9.3922e-05
Epoch: [3][275/276]Elapsed 41.06s | Loss: 0.1791 Grad: 77738.2109 LR: 9.3251e-05


Valid:   0%|          | 0/225 [00:00<?, ?batch/s]

Epoch: [3][0/225]Elapsed 0.10s | Loss: 0.3279
Epoch: [3][50/225]Elapsed 4.91s | Loss: 0.4229
Epoch: [3][100/225]Elapsed 9.72s | Loss: 0.4236
Epoch: [3][150/225]Elapsed 14.53s | Loss: 0.4205
Epoch: [3][200/225]Elapsed 19.36s | Loss: 0.3712


----------------------------------------------------------------------------------------------------
Epoch 3 - Average Train Loss: 0.1791 | Average Valid Loss: 0.3526 | Time: 62.90s
Best model found in epoch 3 | valid loss: 0.3526


Train:   0%|          | 0/276 [00:00<?, ?batch/s]

Epoch: [4][0/276]Elapsed 0.11s | Loss: 0.1955 Grad: 124167.5781 LR: 9.3251e-05
Epoch: [4][50/276]Elapsed 7.48s | Loss: 0.1677 Grad: 130851.2500 LR: 9.1870e-05
Epoch: [4][100/276]Elapsed 14.92s | Loss: 0.1654 Grad: 141889.1250 LR: 9.0373e-05
Epoch: [4][150/276]Elapsed 22.38s | Loss: 0.1646 Grad: 102266.6250 LR: 8.8763e-05
Epoch: [4][200/276]Elapsed 29.82s | Loss: 0.1639 Grad: 57032.6875 LR: 8.7047e-05
Epoch: [4][250/276]Elapsed 37.25s | Loss: 0.1617 Grad: 77901.7109 LR: 8.5227e-05
Epoch: [4][275/276]Elapsed 41.06s | Loss: 0.1612 Grad: 119717.8359 LR: 8.4242e-05


Valid:   0%|          | 0/225 [00:00<?, ?batch/s]

Epoch: [4][0/225]Elapsed 0.10s | Loss: 0.3233
Epoch: [4][50/225]Elapsed 4.92s | Loss: 0.4179
Epoch: [4][100/225]Elapsed 9.75s | Loss: 0.4181
Epoch: [4][150/225]Elapsed 14.58s | Loss: 0.4152
Epoch: [4][200/225]Elapsed 19.41s | Loss: 0.3667


----------------------------------------------------------------------------------------------------
Epoch 4 - Average Train Loss: 0.1612 | Average Valid Loss: 0.3483 | Time: 62.96s
Best model found in epoch 4 | valid loss: 0.3483


Train:   0%|          | 0/276 [00:00<?, ?batch/s]

Epoch: [5][0/276]Elapsed 0.10s | Loss: 0.1618 Grad: 83491.6250 LR: 8.4242e-05
Epoch: [5][50/276]Elapsed 7.49s | Loss: 0.1542 Grad: 91924.5938 LR: 8.2275e-05
Epoch: [5][100/276]Elapsed 14.95s | Loss: 0.1488 Grad: 103052.4688 LR: 8.0220e-05
Epoch: [5][150/276]Elapsed 22.44s | Loss: 0.1489 Grad: 74675.7109 LR: 7.8080e-05
Epoch: [5][200/276]Elapsed 29.85s | Loss: 0.1483 Grad: 52619.8203 LR: 7.5863e-05
Epoch: [5][250/276]Elapsed 37.27s | Loss: 0.1469 Grad: 66618.9922 LR: 7.3573e-05
Epoch: [5][275/276]Elapsed 41.04s | Loss: 0.1464 Grad: 117523.7812 LR: 7.2357e-05


Valid:   0%|          | 0/225 [00:00<?, ?batch/s]

Epoch: [5][0/225]Elapsed 0.11s | Loss: 0.3316
Epoch: [5][50/225]Elapsed 4.94s | Loss: 0.4296
Epoch: [5][100/225]Elapsed 9.75s | Loss: 0.4273
Epoch: [5][150/225]Elapsed 14.56s | Loss: 0.4233
Epoch: [5][200/225]Elapsed 19.38s | Loss: 0.3738


----------------------------------------------------------------------------------------------------
Epoch 5 - Average Train Loss: 0.1464 | Average Valid Loss: 0.3549 | Time: 62.91s


Train:   0%|          | 0/276 [00:00<?, ?batch/s]

Epoch: [6][0/276]Elapsed 0.10s | Loss: 0.1440 Grad: 76021.1562 LR: 7.2357e-05
Epoch: [6][50/276]Elapsed 7.49s | Loss: 0.1365 Grad: 77988.0078 LR: 6.9971e-05
Epoch: [6][100/276]Elapsed 14.97s | Loss: 0.1362 Grad: 129757.1172 LR: 6.7529e-05
Epoch: [6][150/276]Elapsed 22.44s | Loss: 0.1367 Grad: 106577.1016 LR: 6.5039e-05
Epoch: [6][200/276]Elapsed 29.89s | Loss: 0.1365 Grad: 72352.9766 LR: 6.2507e-05
Epoch: [6][250/276]Elapsed 37.34s | Loss: 0.1341 Grad: 110907.5156 LR: 5.9941e-05
Epoch: [6][275/276]Elapsed 41.12s | Loss: 0.1339 Grad: 55976.5547 LR: 5.8595e-05


Valid:   0%|          | 0/225 [00:00<?, ?batch/s]

Epoch: [6][0/225]Elapsed 0.10s | Loss: 0.2966
Epoch: [6][50/225]Elapsed 4.91s | Loss: 0.4231
Epoch: [6][100/225]Elapsed 9.72s | Loss: 0.4216
Epoch: [6][150/225]Elapsed 14.54s | Loss: 0.4169
Epoch: [6][200/225]Elapsed 19.37s | Loss: 0.3681


----------------------------------------------------------------------------------------------------
Epoch 6 - Average Train Loss: 0.1339 | Average Valid Loss: 0.3495 | Time: 62.98s
Fold 0 Valid Loss: 
Easy: 0.8287 | Hard: 0.3999
Elapse: 17.70 min 
Fold: 1 || Valid size 3595 
- First Stage 


Train:   0%|          | 0/622 [00:00<?, ?batch/s]

Epoch: [1][0/622]Elapsed 0.10s | Loss: 0.8383 Grad: 79178.3672 LR: 4.0000e-06
Epoch: [1][50/622]Elapsed 7.43s | Loss: 0.8258 Grad: 52341.9453 LR: 5.0647e-06
Epoch: [1][100/622]Elapsed 14.86s | Loss: 0.8205 Grad: 75250.2812 LR: 8.2116e-06
Epoch: [1][150/622]Elapsed 22.31s | Loss: 0.8140 Grad: 91786.8750 LR: 1.3301e-05
Epoch: [1][200/622]Elapsed 29.75s | Loss: 0.8066 Grad: 84235.8203 LR: 2.0107e-05
Epoch: [1][250/622]Elapsed 37.17s | Loss: 0.7937 Grad: 39432.7109 LR: 2.8328e-05
Epoch: [1][300/622]Elapsed 44.58s | Loss: 0.7810 Grad: 44135.5938 LR: 3.7599e-05
Epoch: [1][350/622]Elapsed 51.97s | Loss: 0.7629 Grad: 54122.6562 LR: 4.7509e-05
Epoch: [1][400/622]Elapsed 59.38s | Loss: 0.7434 Grad: 95066.4609 LR: 5.7619e-05
Epoch: [1][450/622]Elapsed 66.77s | Loss: 0.7228 Grad: 52241.8828 LR: 6.7479e-05
Epoch: [1][500/622]Elapsed 74.16s | Loss: 0.7049 Grad: 46159.8555 LR: 7.6652e-05
Epoch: [1][550/622]Elapsed 81.56s | Loss: 0.6864 Grad: 81157.5078 LR: 8.4732e-05
Epoch: [1][600/622]Elapsed 88.96s

Valid:   0%|          | 0/225 [00:00<?, ?batch/s]

Epoch: [1][0/225]Elapsed 0.10s | Loss: 0.4669
Epoch: [1][50/225]Elapsed 4.89s | Loss: 0.5039
Epoch: [1][100/225]Elapsed 9.68s | Loss: 0.5290
Epoch: [1][150/225]Elapsed 14.47s | Loss: 0.5342
Epoch: [1][200/225]Elapsed 19.29s | Loss: 0.4900


----------------------------------------------------------------------------------------------------
Epoch 1 - Average Train Loss: 0.6622 | Average Valid Loss: 0.4750 | Time: 113.90s
Best model found in epoch 1 | valid loss: 0.4750


Train:   0%|          | 0/622 [00:00<?, ?batch/s]

Epoch: [2][0/622]Elapsed 0.10s | Loss: 0.4490 Grad: 187277.1562 LR: 9.3737e-05
Epoch: [2][50/622]Elapsed 7.54s | Loss: 0.4302 Grad: 85278.2031 LR: 9.7777e-05
Epoch: [2][100/622]Elapsed 15.01s | Loss: 0.4135 Grad: 61995.9531 LR: 9.9786e-05
Epoch: [2][150/622]Elapsed 22.50s | Loss: 0.4132 Grad: 62129.7383 LR: 9.9996e-05
Epoch: [2][200/622]Elapsed 29.96s | Loss: 0.4079 Grad: 82348.5078 LR: 9.9967e-05
Epoch: [2][250/622]Elapsed 37.39s | Loss: 0.4015 Grad: 81442.2031 LR: 9.9911e-05
Epoch: [2][300/622]Elapsed 44.83s | Loss: 0.4008 Grad: 111012.2812 LR: 9.9828e-05
Epoch: [2][350/622]Elapsed 52.27s | Loss: 0.3963 Grad: 70084.2188 LR: 9.9717e-05
Epoch: [2][400/622]Elapsed 59.68s | Loss: 0.3922 Grad: 91746.7344 LR: 9.9579e-05
Epoch: [2][450/622]Elapsed 67.09s | Loss: 0.3873 Grad: 61504.5586 LR: 9.9415e-05
Epoch: [2][500/622]Elapsed 74.49s | Loss: 0.3827 Grad: 117298.9531 LR: 9.9223e-05


In [None]:
# Smoke test: push one dataset sample through a freshly built model to check tensor shapes.
dataset = CustomDataset(train_easy, TARGETS, ModelConfig, all_specs, all_eegs, mode='test')

X, y = dataset[0]
print(X.shape, y.shape)

model = CustomModel(ModelConfig, num_classes=6, pretrained=True)
# Inference mode: disables dropout and uses BatchNorm running stats, so a batch of one
# gives a deterministic forward pass. No autograd graph is needed for a shape check.
model.eval()
with torch.no_grad():
    y_pred = model(X.unsqueeze(0))  # add a leading batch dimension

print(y_pred.shape)

In [None]:
from kl_divergence import score as kl_score


def calc_kl_div(p, q, criterion):
    """KL divergence for a single sample.

    Args:
        p: 1-D array of raw model outputs (logits); log-softmax is applied here.
        q: 1-D array holding the target distribution.
        criterion: a KL loss callable taking (log_probs, target), e.g. nn.KLDivLoss.

    Returns:
        The loss as a Python float.
    """
    def _as_batch(values):
        # Cast to float32 and add a leading batch axis of size one.
        return torch.tensor(values.astype(np.float32))[None, ...]

    log_probs = F.log_softmax(_as_batch(p), dim=1)
    target = _as_batch(q)
    return criterion(log_probs, target).item()

def calc_kaggle_score(solution, submission):
    """Competition KL score for a single OOF row.

    Args:
        solution: row Series indexed by ['eeg_id'] + TARGETS (true distribution).
        submission: row Series indexed by ['eeg_id'] + the prediction columns, in
            the same class order as TARGETS; the columns are renamed positionally.

    Returns:
        Whatever ``kl_score`` returns for the one-row solution/submission pair.
    """
    # Turn each row Series into a one-row DataFrame.
    solution_df = solution.to_frame().T
    solution_df[TARGETS] = solution_df[TARGETS].astype(np.float32)

    submission_df = submission.to_frame().T
    submission_df.columns = ['eeg_id'] + TARGETS  # positional rename: pred cols -> target names
    submission_df[TARGETS] = submission_df[TARGETS].astype(np.float32)

    return kl_score(solution_df, submission_df, 'eeg_id')

In [None]:
def evaluate_oof(oof_csv_path):
    """Score an out-of-fold prediction CSV with both KL formulations.

    The CSV must contain the true vote distributions in ``TARGETS`` and the
    model outputs in ``TARGETS_PRED``. The outputs are treated as logits: they go
    through ``log_softmax`` for the torch KL loss and ``softmax`` before the
    Kaggle metric.

    Args:
        oof_csv_path: path to the OOF CSV (presumably written by train_folds).

    Returns:
        (oof_df, kl_loss_all, kaggle_score_all):
          - oof_df: the CSV contents, with ``TARGETS_PRED`` replaced by softmax
            probabilities and extra per-row ``kl_loss`` / ``kaggle_score`` columns.
          - kl_loss_all: batch-mean ``KLDivLoss`` tensor over all rows.
          - kaggle_score_all: ``kl_score`` over all rows.
    """
    oof_df = pd.read_csv(oof_csv_path)
    softmax = nn.Softmax(dim=1)
    criterion = nn.KLDivLoss(reduction="batchmean")

    # Per-row KL, computed on the raw logits before they are overwritten below.
    oof_df["kl_loss"] = oof_df.apply(
        lambda row: calc_kl_div(row[TARGETS_PRED].values, row[TARGETS].values, criterion),
        axis=1,
    )

    kl_loss_all = criterion(
        F.log_softmax(torch.tensor(oof_df[TARGETS_PRED].values.astype(np.float32)), dim=1),
        torch.tensor(oof_df[TARGETS].values.astype(np.float32)),
    )

    print(f"KL Loss All: {kl_loss_all}")
    print(f"KL Loss Mean: {oof_df['kl_loss'].mean()}")

    # Convert the prediction logits -- not the ground-truth targets -- into probabilities.
    pred_logits = oof_df[TARGETS_PRED].values.astype(np.float32)
    oof_df[TARGETS_PRED] = softmax(torch.tensor(pred_logits)).numpy()

    solution = oof_df[['eeg_id'] + TARGETS].copy()
    submission = oof_df[['eeg_id'] + TARGETS_PRED].copy()
    submission.columns = ['eeg_id'] + TARGETS  # kl_score expects matching column names

    kaggle_score_all = kl_score(solution, submission, 'eeg_id')

    oof_df['kaggle_score'] = oof_df.apply(
        lambda row: calc_kaggle_score(row[['eeg_id'] + TARGETS], row[['eeg_id'] + TARGETS_PRED]),
        axis=1,
    )

    print(f"Kaggle Score All: {kaggle_score_all}")
    print(f"Kaggle Score Mean: {oof_df['kaggle_score'].mean()}")

    return oof_df, kl_loss_all, kaggle_score_all


In [None]:
# Score both OOF files. Each run keeps its own metric variables so the first
# result is not overwritten by the second. (oof_1 / oof_2 presumably correspond
# to the first / second training stage -- confirm in engine_hms_trainer.)
oof_1, kl_loss_1, kaggle_score_1 = evaluate_oof(f"{JobConfig.OUTPUT_DIR}/oof_1.csv")
oof_2, kl_loss_2, kaggle_score_2 = evaluate_oof(f"{JobConfig.OUTPUT_DIR}/oof_2.csv")

In [None]:
# Spot-check 16 randomly sampled rows of oof_1: true vs. predicted curve per row,
# titled with the row label, its `target` label and its per-row KL loss.
fig, axes = plt.subplots(4, 4, figsize=(10, 10), sharex=True, sharey=True)
panel_axes = axes.ravel()
sampled = oof_1.sample(len(panel_axes))

for ax, (row_label, sample) in zip(panel_axes, sampled.iterrows()):
    ax.plot(sample[TARGETS].values, label='True')
    ax.plot(sample[TARGETS_PRED].values, label='Pred')
    ax.set_title(f"{row_label} | {sample['target']} | KL: {sample['kl_loss']:.4f}")
    ax.set_xticks(range(6))
    ax.set_xticklabels(BRAIN_ACTIVITY)
    ax.grid(True)
    ax.legend()

fig.tight_layout()
fig.savefig(f"{JobConfig.OUTPUT_DIR}/oof_examples_1.png")
plt.show()

In [None]:
# Spot-check 16 randomly sampled rows of oof_2. Besides true and predicted curves,
# a min-max-rescaled copy of the targets ('True Norm') is drawn for shape comparison.
fig, axes = plt.subplots(4, 4, figsize=(10, 10), sharex=True, sharey=True)

rows = oof_2.sample(len(axes.ravel()))

for ax, (idx, row) in zip(axes.ravel(), rows.iterrows()):
    # Rows mix numeric and string columns, so cast explicitly to float.
    y_true = row[TARGETS].to_numpy(dtype=np.float32)
    y_pred = row[TARGETS_PRED].to_numpy(dtype=np.float32)

    y_range = y_true.max() - y_true.min()
    # A flat target (all classes equal) would make min-max scaling 0/0; draw zeros instead.
    y_norm = (y_true - y_true.min()) / y_range if y_range > 0 else np.zeros_like(y_true)

    ax.plot(y_true, label='True')
    ax.plot(y_pred, label='Pred')
    ax.plot(y_norm, "b:", label='True Norm')

    ax.set_title(f"{idx} | {row['target']} | KL: {row['kl_loss']:.4f}")
    ax.set_xticks(range(6))
    ax.set_xticklabels(BRAIN_ACTIVITY)
    ax.grid(True)
    ax.legend()

fig.tight_layout()
fig.savefig(f"{JobConfig.OUTPUT_DIR}/oof_examples_2.png")
plt.show()

In [None]:
# Inspect one oof_2 row: the range of its predicted values, and its target
# distribution after min-max rescaling and re-normalising to sum to 1.
INSPECT_IDX = 6  # arbitrary row label chosen for inspection
row = oof_2.loc[INSPECT_IDX]

min_pred = row[TARGETS_PRED].min()
max_pred = row[TARGETS_PRED].max()
print(min_pred, max_pred)

print(row[TARGETS_PRED])

targets = row[TARGETS].astype(float)
targets_range = targets.max() - targets.min()
if targets_range > 0:
    targets_norm = (targets - targets.min()) / targets_range
else:
    # All classes equal: min-max scaling would be 0/0, so keep the flat distribution as-is.
    targets_norm = targets.copy()

targets_norm = targets_norm / targets_norm.sum()

print(targets_norm)