In [1]:
import pandas as pd 
import numpy as np 
from scipy.stats import entropy
from sklearn.model_selection import GroupKFold
import matplotlib.pyplot as plt

from engine_hms_trainer import *
from engine_hms_model import CustomModel, JobConfig, ModelConfig

import torch
from torch import nn
import torch.nn.functional as F

  _torch_pytree._register_pytree_node(


In [2]:
# Fix seeds before anything stochastic is built (presumably seeds the relevant
# RNGs; see engine_hms_trainer.seed_everything to confirm).
seed_everything(JobConfig.SEED)

# --- Experiment overrides on the shared ModelConfig class --------------------
ModelConfig.EPOCHS = 6
ModelConfig.MODEL_BACKBONE = 'tf_efficientnet_b2'
ModelConfig.MODEL_NAME = "ENet_b2_xymask_cutmix"

# Input sources. USE_EEG_SPECTROGRAMS used to be assigned False and then True
# further down in this cell; only the final (effective) value is kept.
ModelConfig.USE_KAGGLE_SPECTROGRAMS = True
ModelConfig.USE_EEG_SPECTROGRAMS = True

# Augmentation switches.
ModelConfig.AUGMENT = True
ModelConfig.AUGMENTATIONS = ['xy_masking', 'cut_mix']

# Build the predictor only after every override above is in place.
hms_predictor = HMSPredictor(JobConfig, ModelConfig)

****************************************************************************************************
Script Start: Sat Mar  9 19:40:08 2024
Initializing HMS Predictor...
Model Name: ENet_b2_xymask_cutmix
Drop Rate: 0.15
Drop Path Rate: 0.25
Augment: True
Augmentations: ['xy_masking', 'cut_mix']
Enropy Split: 5.5
Device: cuda
Output Dir: ./outputs/
****************************************************************************************************


In [3]:
# Load the two training splits plus the spectrogram / EEG lookups.
train_easy, train_hard, all_specs, all_eegs = hms_predictor.load_train_data()

# Row/column counts for each split.
for frame in (train_easy, train_hard):
    print(frame.shape)

# Total number of missing cells per split (0 means no NaN anywhere).
for frame in (train_easy, train_hard):
    print(frame.isnull().sum().sum())

# Peek at the first rows of each split, separated by a blank-ish line.
display(train_easy.head())
print(" ")
display(train_hard.head())

(11999, 14)
(5090, 14)
0
0


Unnamed: 0,eeg_id,spectrogram_id,min,max,patient_id,target,total_votes,entropy,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
0,642382,14960202,1008.0,1032.0,5955,Other,2,7.802343,0.0,0.0,0.0,0.0,0.0,1.0
1,751790,618728447,908.0,908.0,38549,GPD,1,7.802343,0.0,0.0,1.0,0.0,0.0,0.0
2,778705,52296320,0.0,0.0,40955,Other,2,7.68682,0.0,0.0,0.0,0.0,0.0,1.0
3,1629671,2036345030,0.0,160.0,37481,Seizure,51,7.619243,1.0,0.0,0.0,0.0,0.0,0.0
4,2061593,320962633,1450.0,1450.0,23828,Other,1,7.802343,0.0,0.0,0.0,0.0,0.0,1.0


 


Unnamed: 0,eeg_id,spectrogram_id,min,max,patient_id,target,total_votes,entropy,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
0,568657,789577333,0.0,16.0,20654,Other,48,3.341757,0.0,0.0,0.25,0.0,0.166667,0.583333
1,582999,1552638400,0.0,38.0,20230,LPD,154,3.550549,0.0,0.857143,0.0,0.071429,0.0,0.071429
2,1895581,128369999,1138.0,1138.0,47999,Other,13,3.565051,0.076923,0.0,0.0,0.0,0.076923,0.846154
3,2482631,978166025,1902.0,1944.0,20606,Other,105,1.431066,0.0,0.0,0.133333,0.066667,0.133333,0.666667
4,2521897,673742515,0.0,4.0,62117,Other,24,1.516203,0.0,0.0,0.083333,0.083333,0.333333,0.5


In [4]:
# Flip to True to train on only the first half of each split (fast debugging).
FAST_DEBUG = False

# Derive new names instead of overwriting the loaded frames, so re-running this
# cell can never halve the data a second time.
train_easy_fit = train_easy[:len(train_easy) // 2] if FAST_DEBUG else train_easy
train_hard_fit = train_hard[:len(train_hard) // 2] if FAST_DEBUG else train_hard

hms_predictor.train_folds(train_easy_fit, train_hard_fit, all_specs, all_eegs)

Fold: 0 First Training


Train:   0%|          | 0/599 [00:00<?, ?batch/s]

Epoch: [1][0/599]Elapsed 1.07s | Loss: 0.8345 Grad: 69042.6250 LR: 4.0000e-06
Epoch: [1][50/599]Elapsed 8.23s | Loss: 0.8308 Grad: 75636.1328 LR: 5.1479e-06
Epoch: [1][100/599]Elapsed 15.30s | Loss: 0.8226 Grad: 61634.6484 LR: 8.5368e-06
Epoch: [1][150/599]Elapsed 22.38s | Loss: 0.8171 Grad: 57516.7031 LR: 1.4005e-05
Epoch: [1][200/599]Elapsed 29.47s | Loss: 0.8091 Grad: 70790.9766 LR: 2.1290e-05
Epoch: [1][250/599]Elapsed 36.56s | Loss: 0.7969 Grad: 69634.0078 LR: 3.0044e-05
Epoch: [1][300/599]Elapsed 43.66s | Loss: 0.7804 Grad: 79607.8828 LR: 3.9848e-05
Epoch: [1][350/599]Elapsed 50.78s | Loss: 0.7615 Grad: 129872.3438 LR: 5.0233e-05
Epoch: [1][400/599]Elapsed 57.89s | Loss: 0.7422 Grad: 53024.8203 LR: 6.0703e-05
Epoch: [1][450/599]Elapsed 65.03s | Loss: 0.7222 Grad: 71614.6484 LR: 7.0757e-05
Epoch: [1][500/599]Elapsed 72.17s | Loss: 0.7026 Grad: 67804.5078 LR: 7.9913e-05
Epoch: [1][550/599]Elapsed 79.32s | Loss: 0.6828 Grad: 75003.4922 LR: 8.7735e-05
Epoch: [1][598/599]Elapsed 86.21

Valid:   0%|          | 0/150 [00:00<?, ?batch/s]

Epoch: [1][0/150]Elapsed 0.10s | Loss: 0.7545
Epoch: [1][50/150]Elapsed 4.85s | Loss: 0.4970
Epoch: [1][100/150]Elapsed 9.59s | Loss: 0.5004


----------------------------------------------------------------------------------------------------
Epoch 1 - Average Train Loss: 0.6654 | Average Valid Loss: 0.4962 | Time: 100.69s
Best model found in epoch 1 | valid loss: 0.4962


Train:   0%|          | 0/599 [00:00<?, ?batch/s]

Epoch: [2][0/599]Elapsed 0.10s | Loss: 0.3842 Grad: 142093.0469 LR: 9.3639e-05
Epoch: [2][50/599]Elapsed 7.29s | Loss: 0.4178 Grad: 46330.9102 LR: 9.7834e-05
Epoch: [2][100/599]Elapsed 14.49s | Loss: 0.4063 Grad: 33928.4258 LR: 9.9837e-05
Epoch: [2][150/599]Elapsed 21.70s | Loss: 0.4143 Grad: 89391.2266 LR: 9.9994e-05
Epoch: [2][200/599]Elapsed 28.95s | Loss: 0.4045 Grad: 31731.7090 LR: 9.9961e-05
Epoch: [2][250/599]Elapsed 36.17s | Loss: 0.3976 Grad: 41026.5625 LR: 9.9899e-05
Epoch: [2][300/599]Elapsed 43.41s | Loss: 0.3959 Grad: 60293.5391 LR: 9.9807e-05
Epoch: [2][350/599]Elapsed 50.66s | Loss: 0.3928 Grad: 25876.3184 LR: 9.9685e-05
Epoch: [2][400/599]Elapsed 57.91s | Loss: 0.3892 Grad: 30222.4316 LR: 9.9535e-05
Epoch: [2][450/599]Elapsed 65.16s | Loss: 0.3836 Grad: 47859.9648 LR: 9.9355e-05
Epoch: [2][500/599]Elapsed 72.42s | Loss: 0.3812 Grad: 43956.9141 LR: 9.9146e-05
Epoch: [2][550/599]Elapsed 79.66s | Loss: 0.3775 Grad: 68269.3828 LR: 9.8908e-05
Epoch: [2][598/599]Elapsed 86.62

Valid:   0%|          | 0/150 [00:00<?, ?batch/s]

Epoch: [2][0/150]Elapsed 0.10s | Loss: 0.6648
Epoch: [2][50/150]Elapsed 4.85s | Loss: 0.4353
Epoch: [2][100/150]Elapsed 9.60s | Loss: 0.4526


----------------------------------------------------------------------------------------------------
Epoch 2 - Average Train Loss: 0.3741 | Average Valid Loss: 0.4538 | Time: 101.11s
Best model found in epoch 2 | valid loss: 0.4538


Train:   0%|          | 0/599 [00:00<?, ?batch/s]

Epoch: [3][0/599]Elapsed 0.11s | Loss: 0.2464 Grad: 109735.1641 LR: 9.8653e-05
Epoch: [3][50/599]Elapsed 7.34s | Loss: 0.2902 Grad: 126770.3516 LR: 9.8359e-05
Epoch: [3][100/599]Elapsed 14.67s | Loss: 0.2907 Grad: 28631.8047 LR: 9.8036e-05
Epoch: [3][150/599]Elapsed 22.02s | Loss: 0.3047 Grad: 37765.5352 LR: 9.7685e-05
Epoch: [3][200/599]Elapsed 29.37s | Loss: 0.3014 Grad: 37428.9531 LR: 9.7306e-05
Epoch: [3][250/599]Elapsed 36.72s | Loss: 0.2984 Grad: 36949.7891 LR: 9.6899e-05
Epoch: [3][300/599]Elapsed 44.08s | Loss: 0.3000 Grad: 36849.8672 LR: 9.6464e-05
Epoch: [3][350/599]Elapsed 51.42s | Loss: 0.2989 Grad: 49472.4336 LR: 9.6002e-05
Epoch: [3][400/599]Elapsed 58.76s | Loss: 0.2971 Grad: 30841.2441 LR: 9.5513e-05
Epoch: [3][450/599]Elapsed 66.12s | Loss: 0.2947 Grad: 46519.9570 LR: 9.4997e-05
Epoch: [3][500/599]Elapsed 73.48s | Loss: 0.2944 Grad: 26111.7148 LR: 9.4455e-05
Epoch: [3][550/599]Elapsed 80.85s | Loss: 0.2919 Grad: 55629.1445 LR: 9.3886e-05
Epoch: [3][598/599]Elapsed 87.9

Valid:   0%|          | 0/150 [00:00<?, ?batch/s]

Epoch: [3][0/150]Elapsed 0.10s | Loss: 0.6345
Epoch: [3][50/150]Elapsed 4.87s | Loss: 0.4356
Epoch: [3][100/150]Elapsed 9.63s | Loss: 0.4567


----------------------------------------------------------------------------------------------------
Epoch 3 - Average Train Loss: 0.2912 | Average Valid Loss: 0.4587 | Time: 102.46s


Train:   0%|          | 0/599 [00:00<?, ?batch/s]

Epoch: [4][0/599]Elapsed 0.10s | Loss: 0.2141 Grad: 123974.9453 LR: 9.3316e-05
Epoch: [4][50/599]Elapsed 7.40s | Loss: 0.2411 Grad: 79364.4219 LR: 9.2697e-05
Epoch: [4][100/599]Elapsed 14.78s | Loss: 0.2421 Grad: 30797.0547 LR: 9.2053e-05
Epoch: [4][150/599]Elapsed 22.19s | Loss: 0.2558 Grad: 39875.2695 LR: 9.1384e-05
Epoch: [4][200/599]Elapsed 29.59s | Loss: 0.2524 Grad: 27349.4492 LR: 9.0691e-05
Epoch: [4][250/599]Elapsed 36.94s | Loss: 0.2509 Grad: 24700.2266 LR: 8.9973e-05
Epoch: [4][300/599]Elapsed 44.29s | Loss: 0.2504 Grad: 33373.4922 LR: 8.9233e-05
Epoch: [4][350/599]Elapsed 51.62s | Loss: 0.2494 Grad: 30650.9355 LR: 8.8469e-05
Epoch: [4][400/599]Elapsed 58.93s | Loss: 0.2485 Grad: 62647.5078 LR: 8.7682e-05
Epoch: [4][450/599]Elapsed 66.23s | Loss: 0.2457 Grad: 58746.7539 LR: 8.6873e-05
Epoch: [4][500/599]Elapsed 73.53s | Loss: 0.2468 Grad: 35471.3125 LR: 8.6043e-05
Epoch: [4][550/599]Elapsed 80.84s | Loss: 0.2455 Grad: 63155.4219 LR: 8.5191e-05
Epoch: [4][598/599]Elapsed 87.84

Valid:   0%|          | 0/150 [00:00<?, ?batch/s]

Epoch: [4][0/150]Elapsed 0.10s | Loss: 0.5775
Epoch: [4][50/150]Elapsed 4.87s | Loss: 0.4239
Epoch: [4][100/150]Elapsed 9.63s | Loss: 0.4504


----------------------------------------------------------------------------------------------------
Epoch 4 - Average Train Loss: 0.2450 | Average Valid Loss: 0.4561 | Time: 102.37s


Train:   0%|          | 0/599 [00:00<?, ?batch/s]

Epoch: [5][0/599]Elapsed 0.10s | Loss: 0.1996 Grad: 141908.5156 LR: 8.4354e-05
Epoch: [5][50/599]Elapsed 7.36s | Loss: 0.2061 Grad: 95460.1484 LR: 8.3462e-05
Epoch: [5][100/599]Elapsed 14.74s | Loss: 0.2032 Grad: 50888.0391 LR: 8.2550e-05
Epoch: [5][150/599]Elapsed 22.10s | Loss: 0.2116 Grad: 30870.4570 LR: 8.1619e-05
Epoch: [5][200/599]Elapsed 29.47s | Loss: 0.2079 Grad: 34410.8750 LR: 8.0670e-05
Epoch: [5][250/599]Elapsed 36.80s | Loss: 0.2070 Grad: 26766.7988 LR: 7.9702e-05
Epoch: [5][300/599]Elapsed 44.11s | Loss: 0.2075 Grad: 60561.5820 LR: 7.8717e-05
Epoch: [5][350/599]Elapsed 51.42s | Loss: 0.2089 Grad: 37366.2695 LR: 7.7715e-05
Epoch: [5][400/599]Elapsed 58.73s | Loss: 0.2076 Grad: 22367.2188 LR: 7.6697e-05
Epoch: [5][450/599]Elapsed 66.05s | Loss: 0.2054 Grad: 45877.0000 LR: 7.5663e-05
Epoch: [5][500/599]Elapsed 73.36s | Loss: 0.2057 Grad: 56285.5156 LR: 7.4614e-05
Epoch: [5][550/599]Elapsed 80.70s | Loss: 0.2037 Grad: 75790.4531 LR: 7.3550e-05
Epoch: [5][598/599]Elapsed 87.75

Valid:   0%|          | 0/150 [00:00<?, ?batch/s]

Epoch: [5][0/150]Elapsed 0.11s | Loss: 0.6307
Epoch: [5][50/150]Elapsed 4.87s | Loss: 0.4379
Epoch: [5][100/150]Elapsed 9.62s | Loss: 0.4559


----------------------------------------------------------------------------------------------------
Epoch 5 - Average Train Loss: 0.2033 | Average Valid Loss: 0.4626 | Time: 102.26s


Train:   0%|          | 0/599 [00:00<?, ?batch/s]

Epoch: [6][0/599]Elapsed 0.10s | Loss: 0.2046 Grad: inf LR: 7.2516e-05
Epoch: [6][50/599]Elapsed 7.40s | Loss: 0.1709 Grad: 65081.3008 LR: 7.1426e-05
Epoch: [6][100/599]Elapsed 14.82s | Loss: 0.1697 Grad: 56867.7539 LR: 7.0323e-05
Epoch: [6][150/599]Elapsed 22.21s | Loss: 0.1770 Grad: 46316.9141 LR: 6.9208e-05
Epoch: [6][200/599]Elapsed 29.64s | Loss: 0.1752 Grad: 31888.0469 LR: 6.8082e-05
Epoch: [6][250/599]Elapsed 37.04s | Loss: 0.1748 Grad: 27573.3027 LR: 6.6945e-05
Epoch: [6][300/599]Elapsed 44.41s | Loss: 0.1743 Grad: 29712.3477 LR: 6.5799e-05
Epoch: [6][350/599]Elapsed 51.77s | Loss: 0.1733 Grad: 54139.2969 LR: 6.4642e-05
Epoch: [6][400/599]Elapsed 59.13s | Loss: 0.1729 Grad: 20577.5410 LR: 6.3478e-05
Epoch: [6][450/599]Elapsed 66.47s | Loss: 0.1702 Grad: 42005.5234 LR: 6.2305e-05
Epoch: [6][500/599]Elapsed 73.80s | Loss: 0.1710 Grad: 29141.4336 LR: 6.1125e-05
Epoch: [6][550/599]Elapsed 81.15s | Loss: 0.1693 Grad: 43241.3789 LR: 5.9939e-05
Epoch: [6][598/599]Elapsed 88.20s | Loss

Valid:   0%|          | 0/150 [00:00<?, ?batch/s]

Epoch: [6][0/150]Elapsed 0.11s | Loss: 0.5734
Epoch: [6][50/150]Elapsed 4.88s | Loss: 0.4446
Epoch: [6][100/150]Elapsed 9.65s | Loss: 0.4778


----------------------------------------------------------------------------------------------------
Epoch 6 - Average Train Loss: 0.1696 | Average Valid Loss: 0.4815 | Time: 102.75s
Fold: 1 First Training


Train:   0%|          | 0/599 [00:00<?, ?batch/s]

Epoch: [1][0/599]Elapsed 0.10s | Loss: 0.8187 Grad: 90367.3984 LR: 4.0000e-06
Epoch: [1][50/599]Elapsed 7.40s | Loss: 0.8193 Grad: 61856.2148 LR: 5.1479e-06
Epoch: [1][100/599]Elapsed 14.77s | Loss: 0.8062 Grad: 69708.6875 LR: 8.5368e-06
Epoch: [1][150/599]Elapsed 22.18s | Loss: 0.8008 Grad: 65168.4062 LR: 1.4005e-05
Epoch: [1][200/599]Elapsed 29.58s | Loss: 0.7945 Grad: 60428.7070 LR: 2.1290e-05
Epoch: [1][250/599]Elapsed 36.97s | Loss: 0.7847 Grad: 65722.2031 LR: 3.0044e-05
Epoch: [1][300/599]Elapsed 44.34s | Loss: 0.7691 Grad: 88584.0000 LR: 3.9848e-05
Epoch: [1][350/599]Elapsed 51.72s | Loss: 0.7533 Grad: 84640.8047 LR: 5.0233e-05
Epoch: [1][400/599]Elapsed 59.08s | Loss: 0.7345 Grad: 102233.5078 LR: 6.0703e-05
Epoch: [1][450/599]Elapsed 66.44s | Loss: 0.7145 Grad: 79600.0234 LR: 7.0757e-05
Epoch: [1][500/599]Elapsed 73.80s | Loss: 0.6955 Grad: 49560.7461 LR: 7.9913e-05
Epoch: [1][550/599]Elapsed 81.15s | Loss: 0.6765 Grad: 57476.7422 LR: 8.7735e-05
Epoch: [1][598/599]Elapsed 88.19

Valid:   0%|          | 0/150 [00:00<?, ?batch/s]

Epoch: [1][0/150]Elapsed 0.10s | Loss: 0.3597
Epoch: [1][50/150]Elapsed 4.86s | Loss: 0.4746
Epoch: [1][100/150]Elapsed 9.62s | Loss: 0.4953


----------------------------------------------------------------------------------------------------
Epoch 1 - Average Train Loss: 0.6594 | Average Valid Loss: 0.5030 | Time: 102.71s
Best model found in epoch 1 | valid loss: 0.5030


Train:   0%|          | 0/599 [00:00<?, ?batch/s]

Epoch: [2][0/599]Elapsed 0.10s | Loss: 0.5007 Grad: 150057.3750 LR: 9.3639e-05
Epoch: [2][50/599]Elapsed 7.41s | Loss: 0.4185 Grad: 75971.0703 LR: 9.7834e-05
Epoch: [2][100/599]Elapsed 14.79s | Loss: 0.4069 Grad: 71578.7891 LR: 9.9837e-05
Epoch: [2][150/599]Elapsed 22.16s | Loss: 0.4100 Grad: 59593.5156 LR: 9.9994e-05
Epoch: [2][200/599]Elapsed 29.52s | Loss: 0.4034 Grad: 28914.7852 LR: 9.9961e-05
Epoch: [2][250/599]Elapsed 36.87s | Loss: 0.3993 Grad: 39601.1992 LR: 9.9899e-05
Epoch: [2][300/599]Elapsed 44.21s | Loss: 0.3955 Grad: 39854.2031 LR: 9.9807e-05
Epoch: [2][350/599]Elapsed 51.55s | Loss: 0.3918 Grad: 46848.2617 LR: 9.9685e-05
Epoch: [2][400/599]Elapsed 58.92s | Loss: 0.3854 Grad: 28335.2754 LR: 9.9535e-05
Epoch: [2][450/599]Elapsed 66.27s | Loss: 0.3800 Grad: 52756.7930 LR: 9.9355e-05
Epoch: [2][500/599]Elapsed 73.61s | Loss: 0.3772 Grad: 37780.6367 LR: 9.9146e-05
Epoch: [2][550/599]Elapsed 80.96s | Loss: 0.3733 Grad: 28562.9863 LR: 9.8908e-05
Epoch: [2][598/599]Elapsed 88.01

Valid:   0%|          | 0/150 [00:00<?, ?batch/s]

Epoch: [2][0/150]Elapsed 0.10s | Loss: 0.3118
Epoch: [2][50/150]Elapsed 4.85s | Loss: 0.3850
Epoch: [2][100/150]Elapsed 9.62s | Loss: 0.3962


----------------------------------------------------------------------------------------------------
Epoch 2 - Average Train Loss: 0.3705 | Average Valid Loss: 0.3977 | Time: 102.53s
Best model found in epoch 2 | valid loss: 0.3977


Train:   0%|          | 0/599 [00:00<?, ?batch/s]

Epoch: [3][0/599]Elapsed 0.10s | Loss: 0.3298 Grad: nan LR: 9.8653e-05
Epoch: [3][50/599]Elapsed 7.43s | Loss: 0.3010 Grad: 21338.7129 LR: 9.8359e-05
Epoch: [3][100/599]Elapsed 14.83s | Loss: 0.3003 Grad: 39418.8438 LR: 9.8036e-05
Epoch: [3][150/599]Elapsed 22.22s | Loss: 0.3095 Grad: 40957.5391 LR: 9.7685e-05
Epoch: [3][200/599]Elapsed 29.61s | Loss: 0.3049 Grad: 21981.3145 LR: 9.7306e-05
Epoch: [3][250/599]Elapsed 36.99s | Loss: 0.3028 Grad: 26094.0352 LR: 9.6899e-05
Epoch: [3][300/599]Elapsed 44.37s | Loss: 0.3040 Grad: 45831.8789 LR: 9.6464e-05
Epoch: [3][350/599]Elapsed 51.74s | Loss: 0.3011 Grad: 42579.3984 LR: 9.6002e-05
Epoch: [3][400/599]Elapsed 59.11s | Loss: 0.2969 Grad: 27818.4473 LR: 9.5513e-05
Epoch: [3][450/599]Elapsed 66.46s | Loss: 0.2935 Grad: 55390.7617 LR: 9.4997e-05
Epoch: [3][500/599]Elapsed 73.82s | Loss: 0.2934 Grad: 36681.0156 LR: 9.4455e-05
Epoch: [3][550/599]Elapsed 81.15s | Loss: 0.2925 Grad: 29872.9043 LR: 9.3886e-05
Epoch: [3][598/599]Elapsed 88.22s | Loss

Valid:   0%|          | 0/150 [00:00<?, ?batch/s]

Epoch: [3][0/150]Elapsed 0.11s | Loss: 0.2976
Epoch: [3][50/150]Elapsed 4.88s | Loss: 0.3757
Epoch: [3][100/150]Elapsed 9.65s | Loss: 0.3833


----------------------------------------------------------------------------------------------------
Epoch 3 - Average Train Loss: 0.2922 | Average Valid Loss: 0.3843 | Time: 102.77s
Best model found in epoch 3 | valid loss: 0.3843


Train:   0%|          | 0/599 [00:00<?, ?batch/s]

Epoch: [4][0/599]Elapsed 0.10s | Loss: 0.3630 Grad: inf LR: 9.3316e-05
Epoch: [4][50/599]Elapsed 7.41s | Loss: 0.2468 Grad: 29989.4531 LR: 9.2697e-05
Epoch: [4][100/599]Elapsed 14.77s | Loss: 0.2463 Grad: 47235.0469 LR: 9.2053e-05
Epoch: [4][150/599]Elapsed 22.13s | Loss: 0.2568 Grad: 39901.6680 LR: 9.1384e-05
Epoch: [4][200/599]Elapsed 29.49s | Loss: 0.2511 Grad: 22997.1211 LR: 9.0691e-05
Epoch: [4][250/599]Elapsed 36.86s | Loss: 0.2494 Grad: 22708.4727 LR: 8.9973e-05
Epoch: [4][300/599]Elapsed 44.21s | Loss: 0.2499 Grad: 59539.6602 LR: 8.9233e-05
Epoch: [4][350/599]Elapsed 51.57s | Loss: 0.2494 Grad: 40909.0781 LR: 8.8469e-05
Epoch: [4][400/599]Elapsed 58.92s | Loss: 0.2449 Grad: 37714.2500 LR: 8.7682e-05
Epoch: [4][450/599]Elapsed 66.26s | Loss: 0.2419 Grad: 78410.5000 LR: 8.6873e-05
Epoch: [4][500/599]Elapsed 73.62s | Loss: 0.2416 Grad: 40052.4180 LR: 8.6043e-05
Epoch: [4][550/599]Elapsed 80.95s | Loss: 0.2405 Grad: 32881.4570 LR: 8.5191e-05
Epoch: [4][598/599]Elapsed 88.00s | Loss

Valid:   0%|          | 0/150 [00:00<?, ?batch/s]

Epoch: [4][0/150]Elapsed 0.10s | Loss: 0.2862
Epoch: [4][50/150]Elapsed 4.85s | Loss: 0.3762
Epoch: [4][100/150]Elapsed 9.60s | Loss: 0.3900


----------------------------------------------------------------------------------------------------
Epoch 4 - Average Train Loss: 0.2397 | Average Valid Loss: 0.3935 | Time: 102.50s


Train:   0%|          | 0/599 [00:00<?, ?batch/s]

Epoch: [5][0/599]Elapsed 0.10s | Loss: 0.2920 Grad: 204411.2812 LR: 8.4354e-05
Epoch: [5][50/599]Elapsed 7.41s | Loss: 0.2099 Grad: 36181.5898 LR: 8.3462e-05
Epoch: [5][100/599]Elapsed 14.81s | Loss: 0.2079 Grad: 45585.6836 LR: 8.2550e-05
Epoch: [5][150/599]Elapsed 22.21s | Loss: 0.2181 Grad: 37846.0625 LR: 8.1619e-05


KeyboardInterrupt: 

In [None]:
# Smoke test: pull one sample from the dataset and push it through a freshly
# constructed model to check that input and output shapes line up.
dataset = CustomDataset(train_easy, TARGETS, ModelConfig, all_specs, all_eegs, mode='test')

X, y = dataset[0]
print(X.shape, y.shape)

model = CustomModel(ModelConfig, num_classes=6, pretrained=True)

# This is inference only: switch off training-mode layers (dropout, batch-norm
# statistics updates) and skip building an autograd graph.
model.eval()
with torch.no_grad():
    # unsqueeze(0) adds the batch axis the model's forward pass is called with.
    y_pred = model(X.unsqueeze(0))

print(y_pred.shape)

In [None]:
from kl_divergence import score as kl_score


def calc_kl_div(p, q, criterion):
    """Score a single prediction/target pair with ``criterion``.

    Both arrays are cast to float32 and given a leading batch axis of size 1.
    ``p`` is passed through ``log_softmax`` (over axis 1) before scoring, while
    ``q`` is handed to ``criterion`` unchanged.

    Args:
        p: NumPy array treated as the raw prediction.
        q: NumPy array treated as the target.
        criterion: Callable taking ``(log_probs, target)`` and returning a
            one-element tensor.

    Returns:
        The criterion value as a Python float.
    """
    # Indexing with None prepends the batch axis.
    pred, target = (torch.tensor(arr.astype(np.float32))[None] for arr in (p, q))
    log_probs = F.log_softmax(pred, dim=1)
    loss = criterion(log_probs, target)
    return loss.item()

def calc_kaggle_score(solution, submission):
    """Compute the Kaggle KL score for a single row.

    Args:
        solution: Series holding ``eeg_id`` followed by the ``TARGETS`` values.
        submission: Series holding ``eeg_id`` followed by the predicted values;
            its labels are replaced with ``['eeg_id'] + TARGETS``.

    Returns:
        Whatever ``kl_score`` returns for the two one-row frames.
    """
    def _one_row_frame(series, relabel):
        # A Series becomes a single-row frame once transposed.
        frame = series.to_frame().T
        if relabel:
            frame.columns = ['eeg_id'] + TARGETS
        frame[TARGETS] = frame[TARGETS].astype(np.float32)
        return frame

    solution_frame = _one_row_frame(solution, relabel=False)
    submission_frame = _one_row_frame(submission, relabel=True)
    return kl_score(solution_frame, submission_frame, 'eeg_id')

In [None]:
def evaluate_oof(oof_csv_path):
    """Score an out-of-fold prediction CSV with KL divergence and the Kaggle metric.

    The CSV must provide an ``eeg_id`` column, the label columns ``TARGETS`` and
    the model output columns ``TARGETS_PRED``. The outputs are treated as
    logits: ``log_softmax`` is applied for the KL loss and ``softmax`` for the
    Kaggle score.

    Side effects:
        Prints the four aggregate metrics. In the returned frame the
        ``TARGETS_PRED`` columns hold probabilities (softmax of the stored
        logits), and ``kl_loss`` / ``kaggle_score`` columns are added.

    Args:
        oof_csv_path: Path of the out-of-fold CSV to read.

    Returns:
        Tuple ``(oof_df, kl_loss_all, kaggle_score_all)`` where ``kl_loss_all``
        is the batch-mean KL loss tensor over all rows and ``kaggle_score_all``
        is the value returned by ``kl_score`` over all rows.
    """
    oof_df = pd.read_csv(oof_csv_path)
    criterion = nn.KLDivLoss(reduction="batchmean")

    # Per-row KL loss, computed on the raw logits before they are replaced below.
    oof_df["kl_loss"] = oof_df.apply(
        lambda row: calc_kl_div(row[TARGETS_PRED].values, row[TARGETS].values, criterion),
        axis=1,
    )

    logits = torch.tensor(oof_df[TARGETS_PRED].values.astype(np.float32))
    labels = torch.tensor(oof_df[TARGETS].values.astype(np.float32))
    kl_loss_all = criterion(F.log_softmax(logits, dim=1), labels)

    print(f"KL Loss All: {kl_loss_all}")
    print(f"KL Loss Mean: {oof_df['kl_loss'].mean()}")

    # Turn the predicted logits into probabilities for the Kaggle metric.
    # The previous code took the softmax of the *label* columns here, which
    # replaced the model predictions with a transformed copy of the labels.
    oof_df[TARGETS_PRED] = F.softmax(logits, dim=1).numpy()

    solution = oof_df[['eeg_id'] + TARGETS].copy()
    submission = oof_df[['eeg_id'] + TARGETS_PRED].copy()
    # The metric expects both frames to share the same column names.
    submission.columns = ['eeg_id'] + TARGETS

    kaggle_score_all = kl_score(solution, submission, 'eeg_id')

    # Per-row Kaggle score, useful for spotting the hardest samples.
    oof_df['kaggle_score'] = oof_df.apply(
        lambda row: calc_kaggle_score(
            row[['eeg_id'] + TARGETS], row[['eeg_id'] + TARGETS_PRED]
        ),
        axis=1,
    )

    print(f"Kaggle Score All: {kaggle_score_all}")
    print(f"Kaggle Score Mean: {oof_df['kaggle_score'].mean()}")

    return oof_df, kl_loss_all, kaggle_score_all


In [None]:
# Keep the metrics of each OOF file under its own name; previously the second
# call overwrote the first call's results.
oof_1, kl_loss_1, kaggle_score_1 = evaluate_oof(f"{JobConfig.OUTPUT_DIR}/oof_1.csv")
oof_2, kl_loss_2, kaggle_score_2 = evaluate_oof(f"{JobConfig.OUTPUT_DIR}/oof_2.csv")

# Legacy names: as before, these end up holding the oof_2 metrics.
kl_loss_all, kaggle_score_all = kl_loss_2, kaggle_score_2

In [None]:
# 4x4 grid comparing label and prediction vectors for random oof_1 rows.
fig, axes = plt.subplots(4, 4, figsize=(10, 10), sharex=True, sharey=True)

panels = axes.ravel()
# One random row per panel (alternative: take the last rows with .iloc).
rows = oof_1.sample(len(panels))

for ax, (idx, row) in zip(panels, rows.iterrows()):
    ax.plot(row[TARGETS].values, label='True')
    ax.plot(row[TARGETS_PRED].values, label='Pred')

    title = f"{idx} | {row['target']} | KL: {row['kl_loss']:.4f}"
    ax.set_title(title)

    # One tick per class, labelled with the class names.
    ax.set_xticks(range(6))
    ax.set_xticklabels(BRAIN_ACTIVITY)
    ax.grid(True)
    ax.legend()

fig.tight_layout()
fig.savefig(f"{JobConfig.OUTPUT_DIR}/oof_examples_1.png")
plt.show()

In [None]:
# 4x4 grid comparing label, prediction and min-max-normalised label vectors
# for random oof_2 rows.
fig, axes = plt.subplots(4, 4, figsize=(10, 10), sharex=True, sharey=True)

# rows = oof_2.iloc[5:5+len(axes.ravel()), :]
rows = oof_2.sample(len(axes.ravel()))

for i, (idx, row) in enumerate(rows.iterrows()):

    ax = axes.ravel()[i]
    y_true = row[TARGETS].values
    y_pred = row[TARGETS_PRED].values

    # Min-max normalise the labels; a constant label vector has zero range,
    # so fall back to zeros instead of dividing by zero.
    span = y_true.max() - y_true.min()
    if span > 0:
        y_norm = (y_true - y_true.min()) / span
    else:
        y_norm = np.zeros(len(y_true))

    ax.plot(y_true, label='True')
    ax.plot(y_pred, label='Pred')
    ax.plot(y_norm, "b:", label='True Norm')

    ax.set_title(f"{idx} | {row['target']} | KL: {row['kl_loss']:.4f}")
    # One tick per class, labelled with the class names.
    ax.set_xticks(range(6))
    ax.set_xticklabels(BRAIN_ACTIVITY)
    ax.grid(True)
    ax.legend()

fig.tight_layout()
fig.savefig(f"{JobConfig.OUTPUT_DIR}/oof_examples_2.png")
plt.show()

In [None]:
# Inspect one OOF row (index label 6) in detail.
row = oof_2.loc[6]

# Range of the predicted values, followed by the full prediction vector.
preds = row[TARGETS_PRED]
min_pred, max_pred = preds.min(), preds.max()
print(min_pred, max_pred)

print(preds)

# Min-max scale the labels, then rescale so that they sum to one.
labels = row[TARGETS]
label_floor = labels.min()
targets_norm = (labels - label_floor) / (labels.max() - label_floor)

targets_norm = targets_norm / targets_norm.sum()

print(targets_norm)