In [None]:
!pip install iterative-stratification

Collecting iterative-stratification
  Downloading iterative_stratification-0.1.7-py3-none-any.whl (8.5 kB)
Installing collected packages: iterative-stratification
Successfully installed iterative-stratification-0.1.7


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os
import pandas as pd 
import numpy as np
import json
import random, string
import seaborn as sns
import matplotlib.pyplot as plt
import logging
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import QuantileTransformer
from sklearn.preprocessing import StandardScaler

%matplotlib inline

In [None]:
def seed_everything(seed=1903):
    """Seed every random number generator this notebook touches.

    Seeds Python's `random`, NumPy's global RNG, PyTorch (CPU) and the current
    CUDA device, exports PYTHONHASHSEED (only honored by interpreters started
    afterwards), and asks cuDNN for deterministic kernels.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed=42)

In [None]:
# Raw competition CSVs are read from `maindir`.
maindir = "data" # Directory with your files
traincsv = f"{maindir}/Train.csv"
testcsv = f"{maindir}/Test.csv"

train, test = pd.read_csv(traincsv), pd.read_csv(testcsv)

In [None]:
# (rows, columns) of the test set as loaded, before the quantile columns are added.
test.shape

(3660, 173)

In [None]:
# Quantile-normalise each raw absorbance column to a normal output distribution,
# appending the result as "<col>_q" to both train and test.
# Each transformer is fit on train + test stacked together, so both splits share
# the same mapping. One transformer per column is kept deliberately: with a fixed
# random_state this reproduces the per-column quantile fit exactly.
cols = [col for col in train.columns if 'absorbance' in col and not col.endswith('_q')]
q_cols = [col + '_q' for col in cols]

# Stack the absorbance columns once instead of concatenating the full frames per column.
combined = pd.concat([train[cols], test[cols]], ignore_index=True)

train_q = {}
test_q = {}
for col, q_col in zip(cols, q_cols):
    transformer = QuantileTransformer(n_quantiles = 9, random_state = 42, output_distribution = "normal")
    transformer.fit(combined[col].values.reshape(-1, 1))

    train_q[q_col] = transformer.transform(train[col].values.reshape(-1, 1)).ravel()
    test_q[q_col] = transformer.transform(test[col].values.reshape(-1, 1)).ravel()

# Attach all new columns in a single step (column-by-column insertion fragments the frame).
train = pd.concat([train, pd.DataFrame(train_q, index=train.index)], axis=1)
test = pd.concat([test, pd.DataFrame(test_q, index=test.index)], axis=1)

In [None]:
# Preview: raw absorbance columns followed by their quantile-transformed "_q" counterparts.
train.head()

Unnamed: 0,Reading_ID,absorbance0,absorbance1,absorbance2,absorbance3,absorbance4,absorbance5,absorbance6,absorbance7,absorbance8,absorbance9,absorbance10,absorbance11,absorbance12,absorbance13,absorbance14,absorbance15,absorbance16,absorbance17,absorbance18,absorbance19,absorbance20,absorbance21,absorbance22,absorbance23,absorbance24,absorbance25,absorbance26,absorbance27,absorbance28,absorbance29,absorbance30,absorbance31,absorbance32,absorbance33,absorbance34,absorbance35,absorbance36,absorbance37,absorbance38,...,absorbance130_q,absorbance131_q,absorbance132_q,absorbance133_q,absorbance134_q,absorbance135_q,absorbance136_q,absorbance137_q,absorbance138_q,absorbance139_q,absorbance140_q,absorbance141_q,absorbance142_q,absorbance143_q,absorbance144_q,absorbance145_q,absorbance146_q,absorbance147_q,absorbance148_q,absorbance149_q,absorbance150_q,absorbance151_q,absorbance152_q,absorbance153_q,absorbance154_q,absorbance155_q,absorbance156_q,absorbance157_q,absorbance158_q,absorbance159_q,absorbance160_q,absorbance161_q,absorbance162_q,absorbance163_q,absorbance164_q,absorbance165_q,absorbance166_q,absorbance167_q,absorbance168_q,absorbance169_q
0,ID_3SSHI56C,0.479669,0.477423,0.487956,0.491831,0.500516,0.50259,0.511561,0.514639,0.524245,0.53617,0.546407,0.561557,0.568417,0.571877,0.570884,0.569032,0.567476,0.565662,0.561901,0.559722,0.557474,0.554371,0.552386,0.548702,0.544238,0.542579,0.540514,0.53898,0.53665,0.536483,0.535447,0.537577,0.535715,0.536895,0.539589,0.541081,0.544893,0.547765,0.551773,...,0.005347,-0.200495,-0.169134,-0.201174,-0.2687,-0.360849,-0.299175,-0.374937,-0.368232,-0.446304,-0.481743,-0.536963,-0.546214,-0.592069,-0.548906,-0.652757,-0.559784,-0.701215,-0.655335,-0.632611,-0.791866,-0.784521,-0.713666,-0.753014,-0.735886,-0.76753,-0.892156,-0.7813,-0.74402,-0.801688,-0.854736,-0.823175,-0.759883,-0.867688,-0.590366,-0.637286,-0.584662,-0.68398,-0.768665,-0.815892
1,ID_599OOLZA,0.471537,0.474113,0.479981,0.485528,0.491049,0.497942,0.50476,0.510543,0.522328,0.534423,0.548646,0.55842,0.565449,0.569717,0.570999,0.569969,0.568405,0.566628,0.564101,0.559951,0.556193,0.552271,0.550086,0.546207,0.542366,0.539789,0.537221,0.534336,0.533868,0.533018,0.532227,0.530818,0.532171,0.533658,0.535266,0.538939,0.542399,0.546479,0.550606,...,0.948714,0.729946,0.97718,0.902091,0.946421,0.75145,0.781482,0.828006,0.796453,0.723199,0.732754,0.704139,0.750715,0.728255,0.644411,0.635164,0.64815,0.558396,0.721934,0.552997,0.594627,0.631096,0.471492,0.507546,0.452656,0.485731,0.444692,0.55886,0.691252,0.551639,0.468825,0.64089,0.528006,0.351806,0.376244,0.369352,0.595002,0.423373,-0.126827,-0.011632
2,ID_MVJGPQ75,0.444998,0.458034,0.447386,0.456921,0.463225,0.475983,0.476817,0.481565,0.49001,0.505892,0.518125,0.530362,0.53853,0.543128,0.546287,0.547001,0.54712,0.546351,0.544254,0.542802,0.542207,0.539779,0.536417,0.53338,0.531117,0.529093,0.526101,0.524599,0.522952,0.521551,0.521149,0.520478,0.521432,0.521473,0.523567,0.525816,0.527889,0.530697,0.533416,...,0.188565,0.159944,0.069883,0.288553,0.230258,0.29395,0.271335,0.251015,0.330716,0.186716,0.238123,0.378758,0.367895,0.285911,0.325094,0.349507,0.22716,0.498355,0.458895,0.435385,0.264527,0.344995,0.300342,0.215132,0.3112,0.423823,0.357111,0.567359,0.295868,0.482974,0.504709,0.440702,0.337791,0.147785,0.098,0.59299,0.02935,0.722413,0.123484,0.759315
3,ID_CK6RF8YV,0.513434,0.513303,0.522609,0.521068,0.523146,0.530132,0.539517,0.546364,0.552414,0.565502,0.581143,0.594354,0.599457,0.604529,0.605267,0.606276,0.604895,0.603716,0.600683,0.598087,0.594303,0.589403,0.585883,0.581369,0.578962,0.575181,0.573274,0.570471,0.568241,0.565671,0.564579,0.563724,0.561978,0.562744,0.563455,0.565163,0.566505,0.569239,0.572075,...,-0.634534,-0.646776,-0.662182,-0.703819,-0.636916,-0.678044,-0.758524,-0.747302,-0.749403,-0.733126,-0.822535,-0.770965,-0.832533,-0.83797,-0.875521,-0.844299,-0.791368,-0.892418,-0.964479,-0.926125,-0.895796,-0.877917,-0.930868,-0.847778,-0.834806,-0.922771,-0.841652,-0.882441,-0.927972,-0.826974,-0.821212,-0.751495,-0.440116,-0.351561,0.043776,0.303395,0.359714,0.472551,0.256198,0.794309
4,ID_82N6QE6I,0.510485,0.519359,0.524225,0.528419,0.535273,0.545342,0.550314,0.557129,0.56703,0.577731,0.589192,0.604401,0.611372,0.614571,0.619713,0.619805,0.622708,0.620036,0.61807,0.61647,0.614592,0.611658,0.609762,0.608088,0.604118,0.602248,0.598901,0.598259,0.597334,0.59473,0.593618,0.593828,0.595201,0.596143,0.597089,0.599811,0.602078,0.607372,0.610382,...,-0.854737,-0.772028,-0.863299,-0.796284,-0.755787,-0.771807,-0.718124,-0.63196,-0.675858,-0.593209,-0.498462,-0.537186,-0.526856,-0.510368,-0.491973,-0.493553,-0.37352,-0.363134,-0.434087,-0.416978,-0.43814,-0.322276,-0.241307,-0.415354,-0.270319,-0.400733,-0.122631,-0.158394,-0.116011,-0.091576,-0.099951,-0.052423,0.034232,-0.2399,-0.11642,-0.027904,0.251437,0.743818,0.437047,0.78763


In [None]:
# Preview of the test frame after the same quantile columns were appended.
test.head()

Unnamed: 0,Reading_ID,absorbance0,absorbance1,absorbance2,absorbance3,absorbance4,absorbance5,absorbance6,absorbance7,absorbance8,absorbance9,absorbance10,absorbance11,absorbance12,absorbance13,absorbance14,absorbance15,absorbance16,absorbance17,absorbance18,absorbance19,absorbance20,absorbance21,absorbance22,absorbance23,absorbance24,absorbance25,absorbance26,absorbance27,absorbance28,absorbance29,absorbance30,absorbance31,absorbance32,absorbance33,absorbance34,absorbance35,absorbance36,absorbance37,absorbance38,...,absorbance130_q,absorbance131_q,absorbance132_q,absorbance133_q,absorbance134_q,absorbance135_q,absorbance136_q,absorbance137_q,absorbance138_q,absorbance139_q,absorbance140_q,absorbance141_q,absorbance142_q,absorbance143_q,absorbance144_q,absorbance145_q,absorbance146_q,absorbance147_q,absorbance148_q,absorbance149_q,absorbance150_q,absorbance151_q,absorbance152_q,absorbance153_q,absorbance154_q,absorbance155_q,absorbance156_q,absorbance157_q,absorbance158_q,absorbance159_q,absorbance160_q,absorbance161_q,absorbance162_q,absorbance163_q,absorbance164_q,absorbance165_q,absorbance166_q,absorbance167_q,absorbance168_q,absorbance169_q
0,ID_37BEI22R,0.449736,0.449798,0.447488,0.464694,0.466377,0.48535,0.488915,0.495073,0.504129,0.51269,0.528313,0.54002,0.550252,0.555062,0.555983,0.562491,0.559443,0.562695,0.558805,0.559067,0.557602,0.554924,0.5533,0.549671,0.548033,0.544539,0.54264,0.541228,0.540335,0.539378,0.539134,0.538375,0.538068,0.54034,0.541754,0.54392,0.547749,0.55049,0.55355,...,-1.440739,-1.349291,-1.358633,-1.321949,-1.281167,-1.268411,-1.243394,-1.223964,-1.207322,-1.199001,-1.176348,-1.021118,-1.020921,-1.100373,-0.994048,-0.878637,-0.877405,-0.848976,-0.759211,-0.787022,-0.865335,-0.68249,-0.707377,-0.574728,-0.493538,-0.731705,-0.589431,-0.551942,-0.623715,-0.759709,-0.798007,-0.762391,-1.036127,-1.212637,-0.940384,-1.179244,-1.183394,-0.9888,-1.157198,-0.634158
1,ID_4W85V5DV,0.495429,0.505488,0.510239,0.51888,0.533147,0.543142,0.55167,0.558261,0.564027,0.575223,0.58878,0.60326,0.609797,0.613326,0.61653,0.6174,0.617284,0.615343,0.611668,0.608864,0.606411,0.602919,0.599854,0.597024,0.5928,0.590059,0.586417,0.585922,0.583848,0.583204,0.582259,0.581994,0.582528,0.584993,0.587332,0.590686,0.591674,0.595796,0.599694,...,0.255984,0.307102,0.343937,0.393466,0.3861,0.393939,0.354894,0.336191,0.476564,0.418264,0.336158,0.368383,0.462564,0.436281,0.45635,0.394539,0.487869,0.550592,0.462785,0.491374,0.421365,0.530339,0.585208,0.588573,0.542323,0.58309,0.413273,0.530809,0.582355,0.28088,0.370204,0.404643,0.273017,0.078774,-0.147432,-0.329462,-0.308039,-0.474419,-0.234631,-0.078461
2,ID_L4YR3NDY,0.437904,0.439064,0.442527,0.450437,0.455363,0.465817,0.471249,0.479145,0.482595,0.497043,0.508849,0.520005,0.526073,0.529009,0.530775,0.530869,0.529993,0.529816,0.525386,0.52227,0.518925,0.516824,0.514363,0.510227,0.50654,0.503605,0.501884,0.499315,0.498547,0.497386,0.496028,0.495754,0.495847,0.495887,0.497499,0.499683,0.501803,0.504862,0.508623,...,-1.231802,-1.233313,-1.2707,-1.302708,-1.297397,-1.362261,-1.369247,-1.395393,-1.388909,-1.445083,-1.454765,-1.461159,-1.456635,-1.49657,-1.517256,-1.540893,-1.531749,-1.572441,-1.588435,-1.58274,-1.616694,-1.563345,-1.520616,-1.590241,-1.61032,-1.635457,-1.644092,-1.62708,-1.551553,-1.656906,-1.597835,-1.556996,-1.484154,-1.395691,-1.27993,-1.19703,-1.154629,-1.094291,-0.942452,-0.922984
3,ID_U88E3SQ6,0.495038,0.506246,0.50873,0.518995,0.529961,0.537583,0.539696,0.5404,0.547279,0.561166,0.572493,0.583802,0.588819,0.59178,0.596486,0.595962,0.595182,0.588548,0.584253,0.579974,0.576841,0.573102,0.569567,0.565961,0.563061,0.560563,0.556971,0.55563,0.554065,0.554014,0.552711,0.552815,0.552691,0.555071,0.557024,0.558817,0.563014,0.566382,0.571307,...,-1.299169,-1.328658,-1.334808,-1.314343,-1.289187,-1.316928,-1.314219,-1.289618,-1.297391,-1.293348,-1.288336,-1.273782,-1.279009,-1.313775,-1.281225,-1.294503,-1.268176,-1.288147,-1.283527,-1.276344,-1.272091,-1.305849,-1.295253,-1.352451,-1.282066,-1.299092,-1.306644,-1.312904,-1.310104,-1.309388,-1.315472,-1.400438,-1.440712,-1.566567,-1.565088,-1.776853,-1.688696,-1.785187,-1.710679,-1.688173
4,ID_NW7Z3XU7,0.531306,0.525309,0.535306,0.541387,0.551364,0.559821,0.564851,0.570824,0.577426,0.589114,0.601409,0.616401,0.621386,0.626131,0.626661,0.627811,0.626961,0.624922,0.621003,0.619719,0.615285,0.612897,0.609494,0.607091,0.603417,0.600907,0.599359,0.597534,0.595879,0.593052,0.590476,0.590287,0.591087,0.591824,0.592791,0.59354,0.597088,0.60095,0.603265,...,0.131514,0.108921,0.236976,0.164226,0.109996,0.077427,0.179683,0.074284,0.201988,0.086129,0.206552,0.101297,0.241019,0.020164,0.021391,0.12666,0.207654,0.184678,0.030096,0.046733,0.157799,0.126716,0.059818,0.187054,0.212648,0.060491,0.201111,-0.024547,-0.062209,0.147198,0.212886,0.269482,-0.09002,-0.027121,0.254864,0.128129,0.049122,0.127416,-0.272043,0.268737


In [None]:
def double_spectral_collator(batch):
    """Collate labelled items that carry a spectrum plus auxiliary inputs.

    Returns (x, x_aux, y) as float tensors. x_aux is each item's 'x_env'
    followed by its 'x_q', concatenated column-wise. Unless the global
    `use_raw_features` is set, x is passed through `filter_signal`.
    """
    def stack(key):
        return np.array([item[key] for item in batch])

    spectra = stack('x')
    aux = np.hstack((stack('x_env'), stack('x_q')))
    labels = stack('y')

    spectra = torch.tensor(spectra, dtype = torch.float)
    aux = torch.tensor(aux, dtype = torch.float)

    if not use_raw_features:
        spectra = filter_signal(spectra)

    return spectra, aux, torch.tensor(labels, dtype = torch.float)

def test_double_spectral_collator(batch):
    
    x  = np.array([el['x'] for el in batch])
    x_env = np.array([el['x_env'] for el in batch])
    x_q = np.array([el['x_q'] for el in batch])

    x_aux = np.hstack((x_env, x_q))
    
    x = torch.tensor(x, dtype = torch.float)
    x_aux = torch.tensor(x_aux, dtype = torch.float)
    
    if not use_raw_features:
        x = filter_signal(x)
            
    return x, x_aux


def single_spectral_collator(batch):
    """Collate labelled, spectrum-only items into float tensors (x, y).

    x goes through `filter_signal` unless the global `use_raw_features` is set.
    """
    x = torch.tensor(np.array([el['x'] for el in batch]), dtype = torch.float)

    # Targets must be gathered from the batch items; the previous version
    # converted a never-assigned `y` and raised UnboundLocalError on every call.
    y = torch.tensor(np.array([el['y'] for el in batch]), dtype = torch.float)

    if not use_raw_features:
        x = filter_signal(x)

    return x, y

def test_single_spectral_collator(batch):
    
    x  = np.array([el['x'] for el in batch])
    
    x = torch.tensor(x, dtype = torch.float)
    
    if not use_raw_features:
        x = filter_signal(x)
            
    return x


def filter_signal(signal):
    """Map a batch of feature vectors to Fourier-domain features.

    Parameters
    ----------
    signal : torch.Tensor
        2-D tensor indexed as (sample, position), as built by the collators.

    Returns
    -------
    torch.Tensor
        Controlled by the globals `use_real_only`, `use_imag_only`,
        `use_threshold` and `threshold`:

        * `use_real_only`: real part of each sample's spectrum, shape (bs, n).
        * `use_imag_only`: imaginary part, shape (bs, n).
        * otherwise: a zero (bs, t, t) grid whose row 1 holds the real part
          and row 2 the imaginary part of the first t coefficients,
          flattened to (bs, t * t). Here t = `threshold` if `use_threshold`
          else n (requires t >= 3 and n >= t).
    """
    # Transform each sample independently along its last axis. The previous
    # `torch.fft.fft2` also transformed across the batch axis, so a sample's
    # features depended on the other samples in its batch and on batch size.
    sig = torch.fft.fft(signal, dim = -1)

    if use_real_only:
        return sig.real

    if use_imag_only:
        return sig.imag

    bs, sig_dim = sig.shape[0], sig.shape[1]
    width = threshold if use_threshold else sig_dim

    # Same layout as the original per-sample loop, so `num_features`
    # (threshold**2 or n**2) still matches the flattened width.
    arr = torch.zeros((bs, width, width, 1))
    arr[:, 1, :, 0] = sig.real[:, :width]
    arr[:, 2, :, 0] = sig.imag[:, :width]

    return arr.view(bs, -1)


class BloodDataset(Dataset):
    """Map-style dataset over row-aligned feature / auxiliary / target arrays.

    Each item is a dict. 'x' is always present; 'x_env' and 'x_q' are added
    unless the global `single_inp` is truthy; 'y' is added in train mode.
    Arrays are indexed as `array[item, :]`, so 2-D inputs are expected.
    """

    def __init__(self, features, q_features = None, env_features = None,
                 targets = None, train_mode = True):
        self.train_mode = train_mode
        self.features = features
        self.env_features = env_features
        self.q_features = q_features
        # Targets are only stored when the dataset is labelled.
        if train_mode:
            self.targets = targets

    def __len__(self):
        return len(self.features)

    def __getitem__(self, item):
        x = self.features[item,:]
        y = self.targets[item,:] if self.train_mode else None

        sample = {'x': x}
        if not single_inp:
            sample['x_env'] = self.env_features[item, :]
            sample['x_q'] = self.q_features[item, :]
        if self.train_mode:
            sample['y'] = y

        return sample

In [None]:
# Work on a copy: the label-engineering cells below add columns to `folds`
# while `train` keeps its original layout.
folds = train.copy()

In [None]:
# Create new labels - Flatten 3 to 9 multilabel dataset

new_cols = [
    f'{target}_{status}'
    for target in ('hdl_cholesterol_human', 'cholesterol_ldl_human', 'hemoglobin(hgb)_human')
    for status in ('ok', 'high', 'low')
]

In [None]:
# One-hot encode each three-way target: column "<target>_<status>" is 1 where
# the original <target> column equals <status>, otherwise 0.
for col in new_cols:
    name, status = col.rsplit('_', 1)
    if status in ('ok', 'high', 'low'):
        folds.loc[:,col] = np.where(folds.loc[:, name] == status , 1, 0)


In [None]:
# Training frame with the appended "_q" quantile columns (pandas truncates the display).
train

Unnamed: 0,Reading_ID,absorbance0,absorbance1,absorbance2,absorbance3,absorbance4,absorbance5,absorbance6,absorbance7,absorbance8,absorbance9,absorbance10,absorbance11,absorbance12,absorbance13,absorbance14,absorbance15,absorbance16,absorbance17,absorbance18,absorbance19,absorbance20,absorbance21,absorbance22,absorbance23,absorbance24,absorbance25,absorbance26,absorbance27,absorbance28,absorbance29,absorbance30,absorbance31,absorbance32,absorbance33,absorbance34,absorbance35,absorbance36,absorbance37,absorbance38,...,absorbance130_q,absorbance131_q,absorbance132_q,absorbance133_q,absorbance134_q,absorbance135_q,absorbance136_q,absorbance137_q,absorbance138_q,absorbance139_q,absorbance140_q,absorbance141_q,absorbance142_q,absorbance143_q,absorbance144_q,absorbance145_q,absorbance146_q,absorbance147_q,absorbance148_q,absorbance149_q,absorbance150_q,absorbance151_q,absorbance152_q,absorbance153_q,absorbance154_q,absorbance155_q,absorbance156_q,absorbance157_q,absorbance158_q,absorbance159_q,absorbance160_q,absorbance161_q,absorbance162_q,absorbance163_q,absorbance164_q,absorbance165_q,absorbance166_q,absorbance167_q,absorbance168_q,absorbance169_q
0,ID_3SSHI56C,0.479669,0.477423,0.487956,0.491831,0.500516,0.502590,0.511561,0.514639,0.524245,0.536170,0.546407,0.561557,0.568417,0.571877,0.570884,0.569032,0.567476,0.565662,0.561901,0.559722,0.557474,0.554371,0.552386,0.548702,0.544238,0.542579,0.540514,0.538980,0.536650,0.536483,0.535447,0.537577,0.535715,0.536895,0.539589,0.541081,0.544893,0.547765,0.551773,...,0.005347,-0.200495,-0.169134,-0.201174,-0.268700,-0.360849,-0.299175,-0.374937,-0.368232,-0.446304,-0.481743,-0.536963,-0.546214,-0.592069,-0.548906,-0.652757,-0.559784,-0.701215,-0.655335,-0.632611,-0.791866,-0.784521,-0.713666,-0.753014,-0.735886,-0.767530,-0.892156,-0.781300,-0.744020,-0.801688,-0.854736,-0.823175,-0.759883,-0.867688,-0.590366,-0.637286,-0.584662,-0.683980,-0.768665,-0.815892
1,ID_599OOLZA,0.471537,0.474113,0.479981,0.485528,0.491049,0.497942,0.504760,0.510543,0.522328,0.534423,0.548646,0.558420,0.565449,0.569717,0.570999,0.569969,0.568405,0.566628,0.564101,0.559951,0.556193,0.552271,0.550086,0.546207,0.542366,0.539789,0.537221,0.534336,0.533868,0.533018,0.532227,0.530818,0.532171,0.533658,0.535266,0.538939,0.542399,0.546479,0.550606,...,0.948714,0.729946,0.977180,0.902091,0.946421,0.751450,0.781482,0.828006,0.796453,0.723199,0.732754,0.704139,0.750715,0.728255,0.644411,0.635164,0.648150,0.558396,0.721934,0.552997,0.594627,0.631096,0.471492,0.507546,0.452656,0.485731,0.444692,0.558860,0.691252,0.551639,0.468825,0.640890,0.528006,0.351806,0.376244,0.369352,0.595002,0.423373,-0.126827,-0.011632
2,ID_MVJGPQ75,0.444998,0.458034,0.447386,0.456921,0.463225,0.475983,0.476817,0.481565,0.490010,0.505892,0.518125,0.530362,0.538530,0.543128,0.546287,0.547001,0.547120,0.546351,0.544254,0.542802,0.542207,0.539779,0.536417,0.533380,0.531117,0.529093,0.526101,0.524599,0.522952,0.521551,0.521149,0.520478,0.521432,0.521473,0.523567,0.525816,0.527889,0.530697,0.533416,...,0.188565,0.159944,0.069883,0.288553,0.230258,0.293950,0.271335,0.251015,0.330716,0.186716,0.238123,0.378758,0.367895,0.285911,0.325094,0.349507,0.227160,0.498355,0.458895,0.435385,0.264527,0.344995,0.300342,0.215132,0.311200,0.423823,0.357111,0.567359,0.295868,0.482974,0.504709,0.440702,0.337791,0.147785,0.098000,0.592990,0.029350,0.722413,0.123484,0.759315
3,ID_CK6RF8YV,0.513434,0.513303,0.522609,0.521068,0.523146,0.530132,0.539517,0.546364,0.552414,0.565502,0.581143,0.594354,0.599457,0.604529,0.605267,0.606276,0.604895,0.603716,0.600683,0.598087,0.594303,0.589403,0.585883,0.581369,0.578962,0.575181,0.573274,0.570471,0.568241,0.565671,0.564579,0.563724,0.561978,0.562744,0.563455,0.565163,0.566505,0.569239,0.572075,...,-0.634534,-0.646776,-0.662182,-0.703819,-0.636916,-0.678044,-0.758524,-0.747302,-0.749403,-0.733126,-0.822535,-0.770965,-0.832533,-0.837970,-0.875521,-0.844299,-0.791368,-0.892418,-0.964479,-0.926125,-0.895796,-0.877917,-0.930868,-0.847778,-0.834806,-0.922771,-0.841652,-0.882441,-0.927972,-0.826974,-0.821212,-0.751495,-0.440116,-0.351561,0.043776,0.303395,0.359714,0.472551,0.256198,0.794309
4,ID_82N6QE6I,0.510485,0.519359,0.524225,0.528419,0.535273,0.545342,0.550314,0.557129,0.567030,0.577731,0.589192,0.604401,0.611372,0.614571,0.619713,0.619805,0.622708,0.620036,0.618070,0.616470,0.614592,0.611658,0.609762,0.608088,0.604118,0.602248,0.598901,0.598259,0.597334,0.594730,0.593618,0.593828,0.595201,0.596143,0.597089,0.599811,0.602078,0.607372,0.610382,...,-0.854737,-0.772028,-0.863299,-0.796284,-0.755787,-0.771807,-0.718124,-0.631960,-0.675858,-0.593209,-0.498462,-0.537186,-0.526856,-0.510368,-0.491973,-0.493553,-0.373520,-0.363134,-0.434087,-0.416978,-0.438140,-0.322276,-0.241307,-0.415354,-0.270319,-0.400733,-0.122631,-0.158394,-0.116011,-0.091576,-0.099951,-0.052423,0.034232,-0.239900,-0.116420,-0.027904,0.251437,0.743818,0.437047,0.787630
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
13135,ID_NGPC0DA3,0.483513,0.482732,0.487531,0.497946,0.500697,0.505740,0.511904,0.519018,0.529455,0.540356,0.554324,0.563337,0.568932,0.573057,0.575291,0.574014,0.573149,0.570979,0.567632,0.564020,0.560878,0.557749,0.554752,0.552457,0.548313,0.547369,0.544267,0.542225,0.542338,0.539435,0.539836,0.539883,0.540450,0.541485,0.542836,0.544809,0.548269,0.552457,0.556129,...,-1.004504,-1.158631,-1.063012,-1.094996,-1.161065,-1.092118,-1.178779,-1.144267,-1.150592,-1.169634,-1.194748,-1.185083,-1.176600,-1.189944,-1.210122,-1.204190,-1.144455,-1.184090,-1.160018,-1.176165,-1.205137,-1.178304,-1.199179,-1.167092,-1.211026,-1.203026,-1.086229,-1.155308,-1.208478,-1.169114,-1.180868,-1.177219,-1.163229,-1.150258,-0.978887,-0.940301,-0.893090,-0.591298,-0.837573,-0.749748
13136,ID_XRBUD5U8,0.525435,0.527563,0.528863,0.531776,0.541156,0.547318,0.552549,0.559060,0.566892,0.581571,0.595248,0.603612,0.611313,0.614295,0.617476,0.617077,0.618908,0.616631,0.610841,0.604681,0.601588,0.598501,0.594149,0.589455,0.587673,0.584746,0.581093,0.580218,0.578071,0.577839,0.576454,0.576026,0.576220,0.579007,0.581564,0.584409,0.587786,0.591168,0.595874,...,0.862627,1.072665,1.082042,1.081815,0.994342,1.103853,1.085500,1.040155,1.156730,0.997372,1.138976,1.153781,1.177291,1.163146,1.176169,1.072870,1.025334,1.198652,1.050350,1.075465,1.161095,1.160210,1.230940,1.207270,1.158279,1.151956,1.169783,1.175273,1.209241,1.179995,1.174539,0.925073,1.175723,1.157584,1.016327,1.190866,1.100848,1.044536,1.080872,1.133224
13137,ID_2M9L5NV2,0.512718,0.517815,0.524857,0.525466,0.536542,0.542930,0.550628,0.558939,0.567593,0.579192,0.595752,0.609749,0.617717,0.621270,0.625333,0.626705,0.623455,0.622948,0.618986,0.616963,0.614392,0.610574,0.606732,0.604206,0.601527,0.597425,0.595800,0.592601,0.590437,0.589152,0.588726,0.587134,0.588214,0.589699,0.590376,0.592544,0.594104,0.599264,0.603341,...,0.633195,0.607147,0.635448,0.736553,0.635200,0.711113,0.544655,0.734594,0.695521,0.704520,0.727205,0.802097,0.695919,0.793078,0.735224,0.722313,0.798280,0.820054,0.850148,0.988690,0.816417,0.973748,0.978803,0.958757,0.946359,0.834825,0.884191,0.996301,1.108747,1.047904,1.003578,1.045255,0.969722,1.196814,1.191217,1.228748,1.129584,1.302465,1.260595,1.096716
13138,ID_C5V5SD2D,0.456747,0.472575,0.466935,0.466698,0.478546,0.486451,0.494838,0.496540,0.508274,0.521304,0.532652,0.542257,0.550854,0.556305,0.558084,0.559439,0.558418,0.554719,0.554852,0.552821,0.551641,0.549722,0.545233,0.542932,0.539894,0.536520,0.534783,0.530619,0.531076,0.530187,0.528713,0.527754,0.529230,0.529761,0.531586,0.533869,0.535897,0.539685,0.543911,...,-1.345986,-1.432086,-1.384386,-1.355260,-1.336108,-1.340254,-1.322521,-1.341003,-1.252699,-1.300171,-1.274233,-1.241261,-1.257020,-1.253388,-1.246544,-1.227465,-1.213584,-1.248587,-1.274239,-1.218337,-1.249783,-1.187209,-1.239202,-1.263199,-1.251528,-1.224649,-1.249678,-1.184693,-1.174361,-1.224818,-1.210134,-1.193675,-1.142709,-1.148653,-1.172761,-1.037771,-1.005407,-0.904384,-1.166684,-0.506224


In [None]:
targets = ['hdl_cholesterol_human', 'cholesterol_ldl_human', 'hemoglobin(hgb)_human']

# Drop the original string-valued target columns now that they are one-hot
# encoded in `new_cols`. `DataFrame.drop` returns a new frame, so the result
# must be assigned back; the bare call previously left `folds` unchanged.
# errors="ignore" keeps this cell re-runnable once the columns are gone.
folds = folds.drop(columns = targets, errors = "ignore")
folds.head()

Unnamed: 0,Reading_ID,absorbance0,absorbance1,absorbance2,absorbance3,absorbance4,absorbance5,absorbance6,absorbance7,absorbance8,absorbance9,absorbance10,absorbance11,absorbance12,absorbance13,absorbance14,absorbance15,absorbance16,absorbance17,absorbance18,absorbance19,absorbance20,absorbance21,absorbance22,absorbance23,absorbance24,absorbance25,absorbance26,absorbance27,absorbance28,absorbance29,absorbance30,absorbance31,absorbance32,absorbance33,absorbance34,absorbance35,absorbance36,absorbance37,absorbance38,...,absorbance139_q,absorbance140_q,absorbance141_q,absorbance142_q,absorbance143_q,absorbance144_q,absorbance145_q,absorbance146_q,absorbance147_q,absorbance148_q,absorbance149_q,absorbance150_q,absorbance151_q,absorbance152_q,absorbance153_q,absorbance154_q,absorbance155_q,absorbance156_q,absorbance157_q,absorbance158_q,absorbance159_q,absorbance160_q,absorbance161_q,absorbance162_q,absorbance163_q,absorbance164_q,absorbance165_q,absorbance166_q,absorbance167_q,absorbance168_q,absorbance169_q,hdl_cholesterol_human_ok,hdl_cholesterol_human_high,hdl_cholesterol_human_low,cholesterol_ldl_human_ok,cholesterol_ldl_human_high,cholesterol_ldl_human_low,hemoglobin(hgb)_human_ok,hemoglobin(hgb)_human_high,hemoglobin(hgb)_human_low
0,ID_3SSHI56C,0.479669,0.477423,0.487956,0.491831,0.500516,0.502590,0.511561,0.514639,0.524245,0.536170,0.546407,0.561557,0.568417,0.571877,0.570884,0.569032,0.567476,0.565662,0.561901,0.559722,0.557474,0.554371,0.552386,0.548702,0.544238,0.542579,0.540514,0.538980,0.536650,0.536483,0.535447,0.537577,0.535715,0.536895,0.539589,0.541081,0.544893,0.547765,0.551773,...,-0.446304,-0.481743,-0.536963,-0.546214,-0.592069,-0.548906,-0.652757,-0.559784,-0.701215,-0.655335,-0.632611,-0.791866,-0.784521,-0.713666,-0.753014,-0.735886,-0.767530,-0.892156,-0.781300,-0.744020,-0.801688,-0.854736,-0.823175,-0.759883,-0.867688,-0.590366,-0.637286,-0.584662,-0.683980,-0.768665,-0.815892,1,0,0,1,0,0,1,0,0
1,ID_599OOLZA,0.471537,0.474113,0.479981,0.485528,0.491049,0.497942,0.504760,0.510543,0.522328,0.534423,0.548646,0.558420,0.565449,0.569717,0.570999,0.569969,0.568405,0.566628,0.564101,0.559951,0.556193,0.552271,0.550086,0.546207,0.542366,0.539789,0.537221,0.534336,0.533868,0.533018,0.532227,0.530818,0.532171,0.533658,0.535266,0.538939,0.542399,0.546479,0.550606,...,0.723199,0.732754,0.704139,0.750715,0.728255,0.644411,0.635164,0.648150,0.558396,0.721934,0.552997,0.594627,0.631096,0.471492,0.507546,0.452656,0.485731,0.444692,0.558860,0.691252,0.551639,0.468825,0.640890,0.528006,0.351806,0.376244,0.369352,0.595002,0.423373,-0.126827,-0.011632,1,0,0,0,1,0,0,1,0
2,ID_MVJGPQ75,0.444998,0.458034,0.447386,0.456921,0.463225,0.475983,0.476817,0.481565,0.490010,0.505892,0.518125,0.530362,0.538530,0.543128,0.546287,0.547001,0.547120,0.546351,0.544254,0.542802,0.542207,0.539779,0.536417,0.533380,0.531117,0.529093,0.526101,0.524599,0.522952,0.521551,0.521149,0.520478,0.521432,0.521473,0.523567,0.525816,0.527889,0.530697,0.533416,...,0.186716,0.238123,0.378758,0.367895,0.285911,0.325094,0.349507,0.227160,0.498355,0.458895,0.435385,0.264527,0.344995,0.300342,0.215132,0.311200,0.423823,0.357111,0.567359,0.295868,0.482974,0.504709,0.440702,0.337791,0.147785,0.098000,0.592990,0.029350,0.722413,0.123484,0.759315,1,0,0,0,1,0,1,0,0
3,ID_CK6RF8YV,0.513434,0.513303,0.522609,0.521068,0.523146,0.530132,0.539517,0.546364,0.552414,0.565502,0.581143,0.594354,0.599457,0.604529,0.605267,0.606276,0.604895,0.603716,0.600683,0.598087,0.594303,0.589403,0.585883,0.581369,0.578962,0.575181,0.573274,0.570471,0.568241,0.565671,0.564579,0.563724,0.561978,0.562744,0.563455,0.565163,0.566505,0.569239,0.572075,...,-0.733126,-0.822535,-0.770965,-0.832533,-0.837970,-0.875521,-0.844299,-0.791368,-0.892418,-0.964479,-0.926125,-0.895796,-0.877917,-0.930868,-0.847778,-0.834806,-0.922771,-0.841652,-0.882441,-0.927972,-0.826974,-0.821212,-0.751495,-0.440116,-0.351561,0.043776,0.303395,0.359714,0.472551,0.256198,0.794309,0,0,1,0,1,0,1,0,0
4,ID_82N6QE6I,0.510485,0.519359,0.524225,0.528419,0.535273,0.545342,0.550314,0.557129,0.567030,0.577731,0.589192,0.604401,0.611372,0.614571,0.619713,0.619805,0.622708,0.620036,0.618070,0.616470,0.614592,0.611658,0.609762,0.608088,0.604118,0.602248,0.598901,0.598259,0.597334,0.594730,0.593618,0.593828,0.595201,0.596143,0.597089,0.599811,0.602078,0.607372,0.610382,...,-0.593209,-0.498462,-0.537186,-0.526856,-0.510368,-0.491973,-0.493553,-0.373520,-0.363134,-0.434087,-0.416978,-0.438140,-0.322276,-0.241307,-0.415354,-0.270319,-0.400733,-0.122631,-0.158394,-0.116011,-0.091576,-0.099951,-0.052423,0.034232,-0.239900,-0.116420,-0.027904,0.251437,0.743818,0.437047,0.787630,1,0,0,0,1,0,1,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
13135,ID_NGPC0DA3,0.483513,0.482732,0.487531,0.497946,0.500697,0.505740,0.511904,0.519018,0.529455,0.540356,0.554324,0.563337,0.568932,0.573057,0.575291,0.574014,0.573149,0.570979,0.567632,0.564020,0.560878,0.557749,0.554752,0.552457,0.548313,0.547369,0.544267,0.542225,0.542338,0.539435,0.539836,0.539883,0.540450,0.541485,0.542836,0.544809,0.548269,0.552457,0.556129,...,-1.169634,-1.194748,-1.185083,-1.176600,-1.189944,-1.210122,-1.204190,-1.144455,-1.184090,-1.160018,-1.176165,-1.205137,-1.178304,-1.199179,-1.167092,-1.211026,-1.203026,-1.086229,-1.155308,-1.208478,-1.169114,-1.180868,-1.177219,-1.163229,-1.150258,-0.978887,-0.940301,-0.893090,-0.591298,-0.837573,-0.749748,0,1,0,1,0,0,1,0,0
13136,ID_XRBUD5U8,0.525435,0.527563,0.528863,0.531776,0.541156,0.547318,0.552549,0.559060,0.566892,0.581571,0.595248,0.603612,0.611313,0.614295,0.617476,0.617077,0.618908,0.616631,0.610841,0.604681,0.601588,0.598501,0.594149,0.589455,0.587673,0.584746,0.581093,0.580218,0.578071,0.577839,0.576454,0.576026,0.576220,0.579007,0.581564,0.584409,0.587786,0.591168,0.595874,...,0.997372,1.138976,1.153781,1.177291,1.163146,1.176169,1.072870,1.025334,1.198652,1.050350,1.075465,1.161095,1.160210,1.230940,1.207270,1.158279,1.151956,1.169783,1.175273,1.209241,1.179995,1.174539,0.925073,1.175723,1.157584,1.016327,1.190866,1.100848,1.044536,1.080872,1.133224,1,0,0,1,0,0,1,0,0
13137,ID_2M9L5NV2,0.512718,0.517815,0.524857,0.525466,0.536542,0.542930,0.550628,0.558939,0.567593,0.579192,0.595752,0.609749,0.617717,0.621270,0.625333,0.626705,0.623455,0.622948,0.618986,0.616963,0.614392,0.610574,0.606732,0.604206,0.601527,0.597425,0.595800,0.592601,0.590437,0.589152,0.588726,0.587134,0.588214,0.589699,0.590376,0.592544,0.594104,0.599264,0.603341,...,0.704520,0.727205,0.802097,0.695919,0.793078,0.735224,0.722313,0.798280,0.820054,0.850148,0.988690,0.816417,0.973748,0.978803,0.958757,0.946359,0.834825,0.884191,0.996301,1.108747,1.047904,1.003578,1.045255,0.969722,1.196814,1.191217,1.228748,1.129584,1.302465,1.260595,1.096716,1,0,0,1,0,0,1,0,0
13138,ID_C5V5SD2D,0.456747,0.472575,0.466935,0.466698,0.478546,0.486451,0.494838,0.496540,0.508274,0.521304,0.532652,0.542257,0.550854,0.556305,0.558084,0.559439,0.558418,0.554719,0.554852,0.552821,0.551641,0.549722,0.545233,0.542932,0.539894,0.536520,0.534783,0.530619,0.531076,0.530187,0.528713,0.527754,0.529230,0.529761,0.531586,0.533869,0.535897,0.539685,0.543911,...,-1.300171,-1.274233,-1.241261,-1.257020,-1.253388,-1.246544,-1.227465,-1.213584,-1.248587,-1.274239,-1.218337,-1.249783,-1.187209,-1.239202,-1.263199,-1.251528,-1.224649,-1.249678,-1.184693,-1.174361,-1.224818,-1.210134,-1.193675,-1.142709,-1.148653,-1.172761,-1.037771,-1.005407,-0.904384,-1.166684,-0.506224,1,0,0,1,0,0,1,0,0


In [None]:
# Hyperparameters
#
# Plain module-level names: the collators, `filter_signal` and `BloodDataset`
# read several of these (use_raw_features, use_real_only, use_imag_only,
# use_threshold, threshold, single_inp) as globals at call time, so no
# `global` declaration is needed (it is a no-op at module level anyway).

use_raw_features = False   # True: collators skip `filter_signal` and pass raw values

use_crelu = True
use_screlu = False

use_cv = False             # True: multilabel-stratified K-fold; False: KMeans cluster folds

n_clusters = 9

single_inp = False         # True: BloodDataset yields only 'x' (+ 'y'), no 'x_env' / 'x_q'
use_smoothing_loss = True
use_threshold = False
use_real_only = True
use_imag_only = False
threshold = 10

# NOTE(review): "absorbance" also matches the "<col>_q" columns created by the
# quantile step, so feature_cols holds both raw and quantile-transformed columns,
# while the quantile columns are fed again via num_env_features — confirm intended.
feature_cols = [col for col in folds.columns if ("absorbance" in col)]
feature_cols_env = ['temperature' , 'humidity']

# Width of the spectral input after `filter_signal`: the raw, real-only and
# imaginary-only paths keep one value per feature column, while the
# real+imag grid is flattened to threshold**2 or len(feature_cols)**2.
# (Previously use_imag_only fell through to the squared width.)
if use_raw_features or use_real_only or use_imag_only:
    num_features = len(feature_cols)
elif use_threshold:
    num_features = threshold**2
else:
    num_features = len(feature_cols)**2

num_env_features = 2 + len(q_cols)   # temperature + humidity, then the quantile columns
hidden_size_env = 128
h_size_env_2 = 256
h_size_env_3 = 128

num_targets = 9                      # one per one-hot column in new_cols
hidden_size = 512

DEVICE = ('cuda' if torch.cuda.is_available() else 'cpu')
EPOCHS = 1000
LEARNING_RATE = 5e-3
WEIGHT_DECAY = 1e-5
NFOLDS = 10 
EARLY_STOPPING_STEPS = 50
EARLY_STOP = True
BATCH_SIZE = 64

model_output_folder = maindir

In [None]:
# Number of quantile-transformed columns (one per raw absorbance column).
len(q_cols)

170

In [None]:
from sklearn.cluster import KMeans

def fe_cluster(train, test, features, n_clusters= 2, SEED=42):
    """Assign a KMeans cluster id to every row of `train` and `test`.

    KMeans is fitted on the union of the train and test rows (restricted to
    `features`), and each row's predicted cluster is stored in a new `kfold`
    column so the clusters can be used as validation folds.

    Parameters
    ----------
    train, test : pandas.DataFrame
        Input frames; they are not modified.
    features : list of str
        Columns used for clustering.
    n_clusters : int
        Number of KMeans clusters.
    SEED : int
        KMeans `random_state`.

    Returns
    -------
    (train, test)
        Copies of the inputs with NaNs in every column replaced by 0 and a
        `kfold` column holding the cluster id.
    """
    train = train.fillna(0)
    test = test.fillna(0)

    train_ = train[features].copy()
    test_ = test[features].copy()

    kmeans = KMeans(random_state=SEED, n_clusters=n_clusters)
    kmeans.fit(pd.concat((train_, test_), axis=0).reset_index(drop=True))

    # Predict on DataFrames with the same columns used for fitting; passing
    # `.values` makes scikit-learn warn that X has no valid feature names.
    train['kfold'] = kmeans.predict(train_)
    test['kfold'] = kmeans.predict(test_)

    return train, test

In [None]:
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

if use_cv:
    # Multilabel-stratified K-fold over the one-hot target columns.
    mskf = MultilabelStratifiedKFold(n_splits = NFOLDS)

    # split() yields *positional* row indices, so record fold ids in a
    # positional array instead of `folds.loc[...]`, which is only correct
    # when `folds` has a default RangeIndex.
    fold_ids = np.full(len(folds), -1, dtype=int)
    for fold, (_, vl_idx) in enumerate(mskf.split(X = folds, y= folds[new_cols])):
        fold_ids[vl_idx] = fold

    folds['kfold'] = fold_ids

else:
    # No stratified CV: KMeans clusters (fitted on train + test) act as folds.
    # NOTE(review): feature_cols is built with an "absorbance" substring
    # filter, which also matches the "_q" names, so q_cols may appear twice in
    # this feature list -- confirm that double weighting is intended.
    folds, test = fe_cluster(folds, test, n_clusters = n_clusters , features = feature_cols + q_cols + feature_cols_env)

In [None]:
# Inspect the fold assignment: `kfold` is appended as the last column.
folds.head()

Unnamed: 0,Reading_ID,absorbance0,absorbance1,absorbance2,absorbance3,absorbance4,absorbance5,absorbance6,absorbance7,absorbance8,absorbance9,absorbance10,absorbance11,absorbance12,absorbance13,absorbance14,absorbance15,absorbance16,absorbance17,absorbance18,absorbance19,absorbance20,absorbance21,absorbance22,absorbance23,absorbance24,absorbance25,absorbance26,absorbance27,absorbance28,absorbance29,absorbance30,absorbance31,absorbance32,absorbance33,absorbance34,absorbance35,absorbance36,absorbance37,absorbance38,...,absorbance140_q,absorbance141_q,absorbance142_q,absorbance143_q,absorbance144_q,absorbance145_q,absorbance146_q,absorbance147_q,absorbance148_q,absorbance149_q,absorbance150_q,absorbance151_q,absorbance152_q,absorbance153_q,absorbance154_q,absorbance155_q,absorbance156_q,absorbance157_q,absorbance158_q,absorbance159_q,absorbance160_q,absorbance161_q,absorbance162_q,absorbance163_q,absorbance164_q,absorbance165_q,absorbance166_q,absorbance167_q,absorbance168_q,absorbance169_q,hdl_cholesterol_human_ok,hdl_cholesterol_human_high,hdl_cholesterol_human_low,cholesterol_ldl_human_ok,cholesterol_ldl_human_high,cholesterol_ldl_human_low,hemoglobin(hgb)_human_ok,hemoglobin(hgb)_human_high,hemoglobin(hgb)_human_low,kfold
0,ID_3SSHI56C,0.479669,0.477423,0.487956,0.491831,0.500516,0.50259,0.511561,0.514639,0.524245,0.53617,0.546407,0.561557,0.568417,0.571877,0.570884,0.569032,0.567476,0.565662,0.561901,0.559722,0.557474,0.554371,0.552386,0.548702,0.544238,0.542579,0.540514,0.53898,0.53665,0.536483,0.535447,0.537577,0.535715,0.536895,0.539589,0.541081,0.544893,0.547765,0.551773,...,-0.481743,-0.536963,-0.546214,-0.592069,-0.548906,-0.652757,-0.559784,-0.701215,-0.655335,-0.632611,-0.791866,-0.784521,-0.713666,-0.753014,-0.735886,-0.76753,-0.892156,-0.7813,-0.74402,-0.801688,-0.854736,-0.823175,-0.759883,-0.867688,-0.590366,-0.637286,-0.584662,-0.68398,-0.768665,-0.815892,1,0,0,1,0,0,1,0,0,5
1,ID_599OOLZA,0.471537,0.474113,0.479981,0.485528,0.491049,0.497942,0.50476,0.510543,0.522328,0.534423,0.548646,0.55842,0.565449,0.569717,0.570999,0.569969,0.568405,0.566628,0.564101,0.559951,0.556193,0.552271,0.550086,0.546207,0.542366,0.539789,0.537221,0.534336,0.533868,0.533018,0.532227,0.530818,0.532171,0.533658,0.535266,0.538939,0.542399,0.546479,0.550606,...,0.732754,0.704139,0.750715,0.728255,0.644411,0.635164,0.64815,0.558396,0.721934,0.552997,0.594627,0.631096,0.471492,0.507546,0.452656,0.485731,0.444692,0.55886,0.691252,0.551639,0.468825,0.64089,0.528006,0.351806,0.376244,0.369352,0.595002,0.423373,-0.126827,-0.011632,1,0,0,0,1,0,0,1,0,5
2,ID_MVJGPQ75,0.444998,0.458034,0.447386,0.456921,0.463225,0.475983,0.476817,0.481565,0.49001,0.505892,0.518125,0.530362,0.53853,0.543128,0.546287,0.547001,0.54712,0.546351,0.544254,0.542802,0.542207,0.539779,0.536417,0.53338,0.531117,0.529093,0.526101,0.524599,0.522952,0.521551,0.521149,0.520478,0.521432,0.521473,0.523567,0.525816,0.527889,0.530697,0.533416,...,0.238123,0.378758,0.367895,0.285911,0.325094,0.349507,0.22716,0.498355,0.458895,0.435385,0.264527,0.344995,0.300342,0.215132,0.3112,0.423823,0.357111,0.567359,0.295868,0.482974,0.504709,0.440702,0.337791,0.147785,0.098,0.59299,0.02935,0.722413,0.123484,0.759315,1,0,0,0,1,0,1,0,0,5
3,ID_CK6RF8YV,0.513434,0.513303,0.522609,0.521068,0.523146,0.530132,0.539517,0.546364,0.552414,0.565502,0.581143,0.594354,0.599457,0.604529,0.605267,0.606276,0.604895,0.603716,0.600683,0.598087,0.594303,0.589403,0.585883,0.581369,0.578962,0.575181,0.573274,0.570471,0.568241,0.565671,0.564579,0.563724,0.561978,0.562744,0.563455,0.565163,0.566505,0.569239,0.572075,...,-0.822535,-0.770965,-0.832533,-0.83797,-0.875521,-0.844299,-0.791368,-0.892418,-0.964479,-0.926125,-0.895796,-0.877917,-0.930868,-0.847778,-0.834806,-0.922771,-0.841652,-0.882441,-0.927972,-0.826974,-0.821212,-0.751495,-0.440116,-0.351561,0.043776,0.303395,0.359714,0.472551,0.256198,0.794309,0,0,1,0,1,0,1,0,0,7
4,ID_82N6QE6I,0.510485,0.519359,0.524225,0.528419,0.535273,0.545342,0.550314,0.557129,0.56703,0.577731,0.589192,0.604401,0.611372,0.614571,0.619713,0.619805,0.622708,0.620036,0.61807,0.61647,0.614592,0.611658,0.609762,0.608088,0.604118,0.602248,0.598901,0.598259,0.597334,0.59473,0.593618,0.593828,0.595201,0.596143,0.597089,0.599811,0.602078,0.607372,0.610382,...,-0.498462,-0.537186,-0.526856,-0.510368,-0.491973,-0.493553,-0.37352,-0.363134,-0.434087,-0.416978,-0.43814,-0.322276,-0.241307,-0.415354,-0.270319,-0.400733,-0.122631,-0.158394,-0.116011,-0.091576,-0.099951,-0.052423,0.034232,-0.2399,-0.11642,-0.027904,0.251437,0.743818,0.437047,0.78763,1,0,0,0,1,0,1,0,0,6


In [None]:
# model train and validation utils

def train_fn(model, train_dataloader, criterion, optimizer, scheduler, device):
    """Train `model` for one pass over `train_dataloader`.

    Each batch is moved to `device` and scored with `criterion`. Every batch is
    followed by an optimizer step and a scheduler step, so the scheduler
    advances once per batch, not once per epoch.

    Batches are `(x, x_env, y)` triples when the global `single_inp` is False,
    otherwise `(x, y)` pairs.

    Returns
    -------
    float
        Mean of the per-batch loss values.
    """
    logging.info("TRAIN")
    model.train()

    n_batches = len(train_dataloader)
    loss_sum = 0

    for batch in tqdm(iter(train_dataloader), leave=True, total=n_batches):
        if single_inp:
            features, labels = batch
            features, labels = features.to(device), labels.to(device)
            logits = model(features)
        else:
            features, env_features, labels = batch
            features = features.to(device)
            env_features = env_features.to(device)
            labels = labels.to(device)
            logits = model(features, env_features)

        optimizer.zero_grad()
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()
        scheduler.step()

        loss_sum += batch_loss.item()

    return loss_sum / n_batches

def val_fn(model, valid_dataloader, criterion, device):
    """Evaluate `model` on `valid_dataloader` without updating it.

    Batches are `(x, x_env, y)` triples when the global `single_inp` is False,
    otherwise `(x, y)` pairs.

    Returns
    -------
    (final_loss, valid_preds)
        final_loss: mean of the per-batch criterion values.
        valid_preds: sigmoid of the model outputs for every sample,
        concatenated across batches in loader order.
    """
    logging.info("VALID")

    model.eval()

    final_loss = 0
    valid_preds = []

    pbar = tqdm(iter(valid_dataloader), leave = True, total = len(valid_dataloader))

    # Validation needs no gradients; without no_grad every forward pass would
    # build an autograd graph that is immediately thrown away.
    with torch.no_grad():
        for data in pbar:
            if not single_inp:
                x, x_env, y = data
                inputs, inputs_env, targets = x.to(device), x_env.to(device), y.to(device)
                output = model(inputs, inputs_env)
            else:
                x, y = data
                inputs, targets = x.to(device), y.to(device)
                output = model(inputs)

            loss = criterion(output, targets)
            final_loss += loss.item()

            valid_preds.append(output.sigmoid().cpu().numpy())

    final_loss /= len(valid_dataloader)
    valid_preds = np.concatenate(valid_preds)

    return final_loss, valid_preds


def inference_fn(model, test_dataloader, device):
    """Predict sigmoid probabilities for every batch of `test_dataloader`.

    Batches are `(x, x_env)` pairs when the global `single_inp` is False,
    otherwise a single input tensor. The model is switched to eval mode and
    run without gradient tracking.

    Returns
    -------
    numpy.ndarray
        Per-sample probabilities, concatenated in loader order.
    """
    model.eval()

    collected = []
    batches = tqdm(iter(test_dataloader), leave=True, total=len(test_dataloader))

    for batch in batches:
        with torch.no_grad():
            if single_inp:
                logits = model(batch.to(device))
            else:
                spectra, env = batch
                logits = model(spectra.to(device), env.to(device))
        collected.append(logits.sigmoid().detach().cpu().numpy())

    return np.concatenate(collected)

In [None]:
import torch
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F


class SmoothBCEwLogits(_WeightedLoss):
    """Binary cross-entropy on logits with symmetric label smoothing.

    Targets are pulled toward 0.5 before the loss is computed:
    ``t * (1 - smoothing) + 0.5 * smoothing``.

    Parameters
    ----------
    weight : torch.Tensor, optional
        Rescaling weight forwarded to ``F.binary_cross_entropy_with_logits``.
    reduction : {'mean', 'sum', 'none'}
        How the per-element losses are reduced.
    smoothing : float
        Smoothing amount, must lie in [0, 1).
    """

    def __init__(self, weight=None, reduction='mean', smoothing=0.0):
        # _WeightedLoss stores `weight` (as a buffer) and `reduction`.
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing

    @staticmethod
    def _smooth(targets: torch.Tensor, n_labels: int, smoothing=0.0):
        """Return `targets` smoothed toward 0.5 (`n_labels` is unused; kept for compatibility)."""
        assert 0 <= smoothing < 1
        with torch.no_grad():
            targets = targets * (1.0 - smoothing) + 0.5 * smoothing
        return targets

    def forward(self, inputs, targets):
        targets = SmoothBCEwLogits._smooth(targets, inputs.size(-1),
                                           self.smoothing)
        # Ask for unreduced per-element losses so `self.reduction` is honored.
        # The functional's default reduction ('mean') would already collapse
        # the loss to a scalar, making 'sum' and 'none' silently return the mean.
        loss = F.binary_cross_entropy_with_logits(inputs, targets, weight=self.weight,
                                                  reduction='none')

        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()

        return loss
    

In [None]:
class SCReLU(nn.Module):
    """Signed CReLU activation.

    Variant of the CReLU activation (https://arxiv.org/pdf/1603.05201.pdf)
    that keeps the sign of the negative half:
    returns CONCAT(relu(x), -relu(-x)), i.e. CONCAT(max(x, 0), min(x, 0)),
    along `dim`. The output is twice as wide as the input along that axis.

    Parameters
    ----------
    dim : int
        Axis along which the two halves are concatenated (default 1).
    """

    def __init__(self, dim=1):
        super(SCReLU, self).__init__()
        self.dim = dim

    def forward(self, x):
        return torch.cat((F.relu(x), -F.relu(-x)), dim=self.dim)
    
class CReLU(nn.Module):
    """Concatenated ReLU (https://arxiv.org/pdf/1603.05201.pdf).

    Returns CONCAT(relu(x), relu(-x)) along dim 1, so the output has twice as
    many features as the input along that axis.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        positive_part = F.relu(x)
        negative_part = F.relu(-x)
        return torch.cat([positive_part, negative_part], dim=1)


class DModel(nn.Module):
    """Two-branch MLP: a spectral branch and an environment branch fused by concatenation.

    Spectral branch: BatchNorm -> weight-normed Linear -> activation ->
    BatchNorm -> Dropout -> Linear -> activation.
    Env branch: BatchNorm -> Linear(num_env_features, hidden_size_env) ->
    Linear(hidden_size_env, h_size_env_2) -> Linear(h_size_env_2, h_size_env_3),
    where h_size_env_2 / h_size_env_3 are notebook globals.
    Head: BatchNorm -> Dropout -> weight-normed Linear to `num_targets` logits.

    The activation is SCReLU if the global `use_screlu` is set, else CReLU if
    `use_crelu` is set, else plain ReLU. CReLU/SCReLU double the feature width
    and plain ReLU keeps it; the layer sizes follow that choice.

    forward(x, x_env) returns raw logits (no sigmoid).
    """

    def __init__(self, num_features, num_env_features, num_targets, hidden_size, hidden_size_env):
        super(DModel, self).__init__()
        # Width multiplier of the activation output (CReLU/SCReLU concatenate two halves).
        act_mult = 2 if (use_screlu or use_crelu) else 1
        # The env branch emits dense3_env's output, i.e. h_size_env_3 features.
        env_out = h_size_env_3

        # Layer creation order matches the original so seeded initialization is unchanged.
        self.batch_norm1 = nn.BatchNorm1d(num_features)
        self.batch_norm_env = nn.BatchNorm1d(num_env_features)
        self.dense1_env = nn.Linear(num_env_features, hidden_size_env)
        self.dense2_env = nn.Linear(hidden_size_env, h_size_env_2)
        self.dense3_env = nn.Linear(h_size_env_2, env_out)

        self.dense1 = nn.utils.weight_norm(
            nn.Linear(num_features, hidden_size))

        self.batch_norm2 = nn.BatchNorm1d(hidden_size * act_mult)
        self.dropout2 = nn.Dropout(0.25)
        self.dense2 = nn.Linear(hidden_size * act_mult, hidden_size)

        self.batch_norm3 = nn.BatchNorm1d(hidden_size * act_mult + env_out)
        self.dropout3 = nn.Dropout(0.25)
        self.dense3 = nn.utils.weight_norm(nn.Linear(hidden_size * act_mult + env_out, num_targets))

        if use_screlu:
            self.crelu = SCReLU()
        elif use_crelu:
            self.crelu = CReLU()
        else:
            self.crelu = nn.ReLU()

    def forward(self, x, x_env):
        # Spectral branch.
        x = self.batch_norm1(x)
        x = self.crelu(self.dense1(x))

        x = self.batch_norm2(x)
        x = self.dropout2(x)
        x = self.crelu(self.dense2(x))

        # Env branch.
        # NOTE(review): there is no activation between the three env layers,
        # so they compose to a single affine map -- confirm this is intended.
        x_env = self.batch_norm_env(x_env)
        x_env = self.dense1_env(x_env)
        x_env = self.dense2_env(x_env)
        x_env = self.dense3_env(x_env)

        # Fuse both branches and project to the target logits.
        x = torch.cat((x, x_env), dim = 1)

        x = self.batch_norm3(x)
        x = self.dropout3(x)
        x = self.dense3(x)

        return x

In [None]:
class SModel(nn.Module):
    """Single-input MLP over the spectral features.

    BatchNorm -> weight-normed Linear -> activation -> BatchNorm -> Dropout ->
    Linear -> activation -> BatchNorm -> Dropout -> weight-normed Linear to
    `num_targets` logits (no sigmoid).

    The activation is SCReLU if the global `use_screlu` is set, else CReLU if
    `use_crelu` is set, else plain ReLU. CReLU/SCReLU double the feature width
    and plain ReLU keeps it; the layer sizes follow that choice.
    """

    def __init__(self, num_features, num_targets, hidden_size):
        super(SModel, self).__init__()
        # Width multiplier of the activation output (CReLU/SCReLU concatenate two halves).
        act_mult = 2 if (use_screlu or use_crelu) else 1

        # Layer creation order matches the original so seeded initialization is unchanged.
        self.batch_norm1 = nn.BatchNorm1d(num_features)
        self.dense1 = nn.utils.weight_norm(
            nn.Linear(num_features, hidden_size))

        self.batch_norm2 = nn.BatchNorm1d(hidden_size * act_mult)
        self.dropout2 = nn.Dropout(0.25)
        self.dense2 = nn.Linear(hidden_size * act_mult, hidden_size)

        self.batch_norm3 = nn.BatchNorm1d(hidden_size * act_mult)
        self.dropout3 = nn.Dropout(0.25)
        self.dense3 = nn.utils.weight_norm(nn.Linear(hidden_size * act_mult, num_targets))

        if use_screlu:
            self.crelu = SCReLU()
        elif use_crelu:
            self.crelu = CReLU()
        else:
            self.crelu = nn.ReLU()

    def forward(self, x):
        x = self.batch_norm1(x)
        x = self.crelu(self.dense1(x))

        x = self.batch_norm2(x)
        x = self.dropout2(x)
        x = self.crelu(self.dense2(x))

        x = self.batch_norm3(x)
        x = self.dropout3(x)
        x = self.dense3(x)

        return x

In [None]:
# Snapshot of the test frame taken before predictions are written into `test`;
# run_training reads its test features from this copy.
test_ = test.copy()

In [None]:
# Inspect the quantile-transformed absorbance features (the `_q` columns).
folds[q_cols]

Unnamed: 0,absorbance0_q,absorbance1_q,absorbance2_q,absorbance3_q,absorbance4_q,absorbance5_q,absorbance6_q,absorbance7_q,absorbance8_q,absorbance9_q,absorbance10_q,absorbance11_q,absorbance12_q,absorbance13_q,absorbance14_q,absorbance15_q,absorbance16_q,absorbance17_q,absorbance18_q,absorbance19_q,absorbance20_q,absorbance21_q,absorbance22_q,absorbance23_q,absorbance24_q,absorbance25_q,absorbance26_q,absorbance27_q,absorbance28_q,absorbance29_q,absorbance30_q,absorbance31_q,absorbance32_q,absorbance33_q,absorbance34_q,absorbance35_q,absorbance36_q,absorbance37_q,absorbance38_q,absorbance39_q,...,absorbance130_q,absorbance131_q,absorbance132_q,absorbance133_q,absorbance134_q,absorbance135_q,absorbance136_q,absorbance137_q,absorbance138_q,absorbance139_q,absorbance140_q,absorbance141_q,absorbance142_q,absorbance143_q,absorbance144_q,absorbance145_q,absorbance146_q,absorbance147_q,absorbance148_q,absorbance149_q,absorbance150_q,absorbance151_q,absorbance152_q,absorbance153_q,absorbance154_q,absorbance155_q,absorbance156_q,absorbance157_q,absorbance158_q,absorbance159_q,absorbance160_q,absorbance161_q,absorbance162_q,absorbance163_q,absorbance164_q,absorbance165_q,absorbance166_q,absorbance167_q,absorbance168_q,absorbance169_q
0,-0.786971,-0.864180,-0.749764,-0.761295,-0.714189,-0.806760,-0.748991,-0.785651,-0.757505,-0.754075,-0.802392,-0.736700,-0.735698,-0.736765,-0.792381,-0.836177,-0.849975,-0.848249,-0.872517,-0.859094,-0.852051,-0.858251,-0.838466,-0.845570,-0.876431,-0.849254,-0.843322,-0.834989,-0.849357,-0.827548,-0.834901,-0.783344,-0.829768,-0.828900,-0.806071,-0.821135,-0.801307,-0.814189,-0.809571,-0.837569,...,0.005347,-0.200495,-0.169134,-0.201174,-0.268700,-0.360849,-0.299175,-0.374937,-0.368232,-0.446304,-0.481743,-0.536963,-0.546214,-0.592069,-0.548906,-0.652757,-0.559784,-0.701215,-0.655335,-0.632611,-0.791866,-0.784521,-0.713666,-0.753014,-0.735886,-0.767530,-0.892156,-0.781300,-0.744020,-0.801688,-0.854736,-0.823175,-0.759883,-0.867688,-0.590366,-0.637286,-0.584662,-0.683980,-0.768665,-0.815892
1,-0.925544,-0.924372,-0.880664,-0.870745,-0.883992,-0.904450,-0.878467,-0.859128,-0.790611,-0.783989,-0.763864,-0.790076,-0.786938,-0.774062,-0.790301,-0.818307,-0.832116,-0.829165,-0.828712,-0.854376,-0.879413,-0.903556,-0.888116,-0.901357,-0.919470,-0.912999,-0.918667,-0.941108,-0.913335,-0.905990,-0.908532,-0.934196,-0.909578,-0.901187,-0.900906,-0.866786,-0.852955,-0.840043,-0.832367,-0.841316,...,0.948714,0.729946,0.977180,0.902091,0.946421,0.751450,0.781482,0.828006,0.796453,0.723199,0.732754,0.704139,0.750715,0.728255,0.644411,0.635164,0.648150,0.558396,0.721934,0.552997,0.594627,0.631096,0.471492,0.507546,0.452656,0.485731,0.444692,0.558860,0.691252,0.551639,0.468825,0.640890,0.528006,0.351806,0.376244,0.369352,0.595002,0.423373,-0.126827,-0.011632
2,-1.193419,-1.165575,-1.210808,-1.200806,-1.206076,-1.194434,-1.213999,-1.213615,-1.215294,-1.196485,-1.200223,-1.198884,-1.194661,-1.189874,-1.187299,-1.187633,-1.183004,-1.179368,-1.177508,-1.172517,-1.164444,-1.162505,-1.165817,-1.165826,-1.162332,-1.158644,-1.161987,-1.159582,-1.160201,-1.161284,-1.160110,-1.161289,-1.158017,-1.163222,-1.160461,-1.158801,-1.161669,-1.163182,-1.166843,-1.153560,...,0.188565,0.159944,0.069883,0.288553,0.230258,0.293950,0.271335,0.251015,0.330716,0.186716,0.238123,0.378758,0.367895,0.285911,0.325094,0.349507,0.227160,0.498355,0.458895,0.435385,0.264527,0.344995,0.300342,0.215132,0.311200,0.423823,0.357111,0.567359,0.295868,0.482974,0.504709,0.440702,0.337791,0.147785,0.098000,0.592990,0.029350,0.722413,0.123484,0.759315
3,0.088019,0.005410,0.204558,-0.064415,-0.193357,-0.190716,-0.086935,-0.051213,-0.112525,-0.110604,-0.041412,-0.006799,-0.048340,-0.011503,-0.058449,-0.050447,-0.039812,0.025250,0.046266,0.075767,0.071656,0.020605,0.014171,-0.018131,-0.001476,-0.017892,0.004152,-0.023272,-0.038521,-0.075227,-0.085654,-0.102376,-0.164947,-0.170319,-0.195401,-0.207624,-0.247556,-0.282357,-0.303156,-0.327543,...,-0.634534,-0.646776,-0.662182,-0.703819,-0.636916,-0.678044,-0.758524,-0.747302,-0.749403,-0.733126,-0.822535,-0.770965,-0.832533,-0.837970,-0.875521,-0.844299,-0.791368,-0.892418,-0.964479,-0.926125,-0.895796,-0.877917,-0.930868,-0.847778,-0.834806,-0.922771,-0.841652,-0.882441,-0.927972,-0.826974,-0.821212,-0.751495,-0.440116,-0.351561,0.043776,0.303395,0.359714,0.472551,0.256198,0.794309
4,-0.020570,0.235269,0.269182,0.188956,0.215264,0.303266,0.269301,0.304612,0.369432,0.319413,0.250242,0.353985,0.361119,0.351334,0.435880,0.416650,0.554392,0.567539,0.607426,0.661393,0.697537,0.695498,0.720821,0.753877,0.738045,0.761260,0.747472,0.773677,0.789984,0.766801,0.761877,0.773118,0.792387,0.790627,0.776728,0.787270,0.774169,0.797309,0.780521,0.787507,...,-0.854737,-0.772028,-0.863299,-0.796284,-0.755787,-0.771807,-0.718124,-0.631960,-0.675858,-0.593209,-0.498462,-0.537186,-0.526856,-0.510368,-0.491973,-0.493553,-0.373520,-0.363134,-0.434087,-0.416978,-0.438140,-0.322276,-0.241307,-0.415354,-0.270319,-0.400733,-0.122631,-0.158394,-0.116011,-0.091576,-0.099951,-0.052423,0.034232,-0.239900,-0.116420,-0.027904,0.251437,0.743818,0.437047,0.787630
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
13135,-0.726522,-0.773684,-0.756408,-0.656745,-0.711156,-0.744710,-0.742791,-0.711536,-0.669946,-0.684975,-0.669075,-0.707322,-0.727004,-0.716830,-0.714560,-0.743937,-0.744682,-0.746687,-0.761531,-0.773453,-0.782270,-0.788862,-0.789471,-0.766248,-0.787955,-0.747193,-0.762891,-0.766141,-0.728208,-0.764537,-0.741253,-0.735849,-0.730714,-0.733229,-0.739357,-0.745469,-0.734612,-0.724107,-0.727950,-0.736794,...,-1.004504,-1.158631,-1.063012,-1.094996,-1.161065,-1.092118,-1.178779,-1.144267,-1.150592,-1.169634,-1.194748,-1.185083,-1.176600,-1.189944,-1.210122,-1.204190,-1.144455,-1.184090,-1.160018,-1.176165,-1.205137,-1.178304,-1.199179,-1.167092,-1.211026,-1.203026,-1.086229,-1.155308,-1.208478,-1.169114,-1.180868,-1.177219,-1.163229,-1.150258,-0.978887,-0.940301,-0.893090,-0.591298,-0.837573,-0.749748
13136,0.512873,0.508444,0.418884,0.311396,0.416590,0.366923,0.343281,0.368698,0.365065,0.440982,0.454031,0.331190,0.359372,0.343365,0.369023,0.333937,0.436914,0.459530,0.379138,0.294328,0.303841,0.304380,0.269728,0.235077,0.277211,0.282938,0.252053,0.284806,0.271599,0.300217,0.283546,0.281369,0.274515,0.331812,0.360035,0.379817,0.399321,0.385403,0.410399,0.472669,...,0.862627,1.072665,1.082042,1.081815,0.994342,1.103853,1.085500,1.040155,1.156730,0.997372,1.138976,1.153781,1.177291,1.163146,1.176169,1.072870,1.025334,1.198652,1.050350,1.075465,1.161095,1.160210,1.230940,1.207270,1.158279,1.151956,1.169783,1.175273,1.209241,1.179995,1.174539,0.925073,1.175723,1.157584,1.016327,1.190866,1.100848,1.044536,1.080872,1.133224
13137,0.060573,0.175970,0.294765,0.083622,0.264475,0.217749,0.280023,0.364702,0.387203,0.365010,0.470240,0.514183,0.556040,0.552785,0.613848,0.641209,0.578354,0.665447,0.638455,0.676526,0.693761,0.675299,0.654567,0.677992,0.687690,0.661516,0.687146,0.654851,0.642701,0.648342,0.662572,0.622096,0.644693,0.662303,0.630022,0.631157,0.588534,0.625284,0.637791,0.631189,...,0.633195,0.607147,0.635448,0.736553,0.635200,0.711113,0.544655,0.734594,0.695521,0.704520,0.727205,0.802097,0.695919,0.793078,0.735224,0.722313,0.798280,0.820054,0.850148,0.988690,0.816417,0.973748,0.978803,0.958757,0.946359,0.834825,0.884191,0.996301,1.108747,1.047904,1.003578,1.045255,0.969722,1.196814,1.191217,1.228748,1.129584,1.302465,1.260595,1.096716
13138,-1.160425,-0.953491,-1.135947,-1.167847,-1.151433,-1.155873,-1.099670,-1.151373,-1.067672,-1.037141,-1.072798,-1.113509,-1.078343,-1.036348,-1.052961,-1.037828,-1.041294,-1.091715,-1.025909,-1.011761,-0.982496,-0.961183,-1.000817,-0.979158,-0.979029,-0.992744,-0.978004,-1.034359,-0.981565,-0.974461,-0.994983,-1.010028,-0.980508,-0.995008,-0.988541,-0.982913,-0.999844,-0.987320,-0.972659,-0.976463,...,-1.345986,-1.432086,-1.384386,-1.355260,-1.336108,-1.340254,-1.322521,-1.341003,-1.252699,-1.300171,-1.274233,-1.241261,-1.257020,-1.253388,-1.246544,-1.227465,-1.213584,-1.248587,-1.274239,-1.218337,-1.249783,-1.187209,-1.239202,-1.263199,-1.251528,-1.224649,-1.249678,-1.184693,-1.174361,-1.224818,-1.210134,-1.193675,-1.142709,-1.148653,-1.172761,-1.037771,-1.005407,-0.904384,-1.166684,-0.506224


In [None]:
def _build_model():
    """Build an untrained SModel (single input) or DModel (spectral + env input) from the notebook globals."""
    if single_inp:
        return SModel(
            num_features=num_features,
            num_targets=num_targets,
            hidden_size=hidden_size)
    return DModel(
        num_features=num_features,
        num_env_features=num_env_features,
        num_targets=num_targets,
        hidden_size=hidden_size,
        hidden_size_env=hidden_size_env,
    )


def run_training(fold, seed):
    """Train on every fold except `fold`, validate on `fold`, and predict the test set.

    Parameters
    ----------
    fold : int
        Value of `folds['kfold']` used as the validation split.
    seed : int
        Passed to seed_everything first; also part of the checkpoint filename.

    Returns
    -------
    (oof, predictions)
        oof: array of shape (len(train), len(new_cols)) holding the validation
        probabilities from the best epoch in this fold's rows and zeros elsewhere.
        predictions: test-set probabilities from the best checkpoint.
    """
    seed_everything(seed)

    # Positional row indices, so `oof[valid_idx]` addresses the right rows
    # whatever index `folds` carries.
    is_valid = (folds['kfold'] == fold).to_numpy()
    valid_idx = np.flatnonzero(is_valid)

    train_df = folds[~is_valid].reset_index(drop=True)
    valid_df = folds[is_valid].reset_index(drop=True)

    x_train, y_train = train_df[feature_cols].values, train_df[new_cols].values
    x_valid, y_valid = valid_df[feature_cols].values, valid_df[new_cols].values

    x_train_q, x_valid_q = train_df[q_cols].values, valid_df[q_cols].values

    x_train_env, x_valid_env = train_df[feature_cols_env].values, valid_df[feature_cols_env].values

    x_test, x_test_env, x_test_q = test_[feature_cols].values, test_[feature_cols_env].values, test_[q_cols].values

    # NOTE(review): the scaler is fitted on every row of `folds`, including
    # this fold's validation rows.
    scaler = StandardScaler()
    scaler.fit(folds[feature_cols].values)

    x_train = scaler.transform(x_train)
    x_valid = scaler.transform(x_valid)
    x_test = scaler.transform(x_test)

    train_dataset = BloodDataset(features=x_train, q_features=x_train_q, env_features=x_train_env, targets=y_train, train_mode=True)
    valid_dataset = BloodDataset(features=x_valid, q_features=x_valid_q, env_features=x_valid_env, targets=y_valid, train_mode=True)
    testdataset = BloodDataset(features=x_test, q_features=x_test_q, env_features=x_test_env, targets=None, train_mode=False)

    train_collate = single_spectral_collator if single_inp else double_spectral_collator
    test_collate = test_single_spectral_collator if single_inp else test_double_spectral_collator

    trainloader = DataLoader(train_dataset, collate_fn=train_collate, batch_size=BATCH_SIZE, shuffle=True)
    validloader = DataLoader(valid_dataset, collate_fn=train_collate, batch_size=BATCH_SIZE, shuffle=False)
    testloader = DataLoader(testdataset, collate_fn=test_collate, batch_size=BATCH_SIZE, shuffle=False)

    model = _build_model()
    model.to(DEVICE)

    optimizer = optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # NOTE: OneCycleLR resets the optimizer's lr to max_lr / div_factor at
    # construction, so LEARNING_RATE above does not set the effective lr.
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer=optimizer, pct_start=0.1, div_factor=1e3,
                                              max_lr=1e-2, epochs=EPOCHS, steps_per_epoch=len(trainloader))
    if not use_smoothing_loss:
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = SmoothBCEwLogits(smoothing=0.001)

    oof = np.zeros((len(train), len(new_cols)))
    checkpoint_path = f"{model_output_folder}/SEED{seed}_FOLD{fold}_.pth"

    early_step = 0
    min_loss = np.inf

    for epoch in range(EPOCHS):

        logging.info(f"Epoch {epoch + 1}")

        train_loss = train_fn(model, trainloader, criterion, optimizer, scheduler, DEVICE)
        valid_loss, valid_preds = val_fn(model, validloader, criterion, DEVICE)

        if valid_loss < min_loss:
            min_loss = valid_loss
            oof[valid_idx] = valid_preds
            torch.save(model.state_dict(), checkpoint_path)
            # Patience counts *consecutive* epochs without improvement.
            early_step = 0

        elif EARLY_STOP:
            early_step += 1
            if early_step >= EARLY_STOPPING_STEPS:
                break

        if (epoch % 10 == 0) or (epoch == EPOCHS - 1):
            print(f"Fold {fold}--Seed {seed}--Epoch {epoch}--Train Loss {train_loss:.6f}--Valid Loss {valid_loss:.6f}--Best Loss {min_loss:.6f}")

    #--------------------- PREDICTION---------------------

    # Reload the best checkpoint into a fresh model; map_location lets a
    # checkpoint saved on one device load onto DEVICE.
    model = _build_model()
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    model.to(DEVICE)

    num_params = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
    print(f"Model Size: {num_params:,} trainable parameters")

    predictions = inference_fn(model, testloader, DEVICE)

    return oof, predictions

In [None]:
def run_k_fold(NFOLDS, seed):
    """Run run_training once per fold id in range(NFOLDS) with a single seed.

    Returns
    -------
    (oof, predictions)
        oof: sum of the per-fold out-of-fold arrays (each fold fills only its
        own validation rows), shape (len(train), len(new_cols)).
        predictions: test-set probabilities averaged over the folds,
        shape (len(test), len(new_cols)).
    """
    n_targets = len(new_cols)
    oof = np.zeros((len(train), n_targets))
    predictions = np.zeros((len(test), n_targets))

    for fold_id in range(NFOLDS):
        print(f"Fold ==== {fold_id}/{NFOLDS}")
        fold_oof, fold_preds = run_training(fold_id, seed)
        predictions += fold_preds / NFOLDS
        oof += fold_oof

    return oof, predictions

In [None]:
# Distinct fold ids; without CV these are the KMeans cluster labels.
folds['kfold'].unique()

array([5, 7, 6, 2, 3, 1, 8, 4, 0], dtype=int32)

In [None]:
# Seed averaging: run the whole fold loop once per seed and average the
# out-of-fold (train) and test probabilities with equal weight per seed.
# Fewer seeds are used without a GPU (presumably to limit runtime).
SEED = (
    [940 , 1513, 1269, 1321, 2491]
    if torch.cuda.is_available()
    else [940 , 1513, 1269]
)

# Without CV the folds are the KMeans clusters: one "fold" per cluster.
if not use_cv:
    NFOLDS = n_clusters

oof = np.zeros((len(train), len(new_cols)))
predictions = np.zeros((len(test), len(new_cols)))

for seed in SEED:
    oof_, predictions_ = run_k_fold(NFOLDS, seed)
    oof += oof_ / len(SEED)
    predictions += predictions_ / len(SEED)

# NOTE(review): this overwrites the label columns of `train` with OOF probabilities.
train[new_cols] = oof
test[new_cols] = predictions

In [None]:
# test.to_csv(maindir+"/test_submission_NNet_2inputs_threshold_5_9_clusters.csv", index=False)

In [None]:
# test = pd.read_csv(maindir+'/test_submission_NNet_2inputs_threshold_5_9_clusters.csv')
# test.head()

In [None]:
# Seed-averaged test probabilities for every target column, as a NumPy array.
predictions_ = test[new_cols].to_numpy()

In [None]:
# Each target is a mutually exclusive ok/high/low triple of consecutive
# columns, so pick the most probable class per group. Thresholding each
# column independently at 0.5 can leave a group with no class or with
# several classes selected.
group_size = 3
best_class = predictions_.reshape(len(predictions_), -1, group_size).argmax(axis=2)
preds = np.eye(group_size, dtype=int)[best_class].reshape(len(predictions_), -1)
preds

array([[1, 0, 0, ..., 1, 0, 0],
       [1, 0, 0, ..., 1, 0, 0],
       [1, 0, 0, ..., 1, 0, 0],
       ...,
       [1, 0, 0, ..., 1, 0, 0],
       [1, 0, 0, ..., 1, 0, 0],
       [0, 0, 0, ..., 1, 0, 0]])

In [None]:
# Overwrite the probability columns of `test` with the 0/1 predictions
# (the probabilities were captured earlier in `predictions_`).
test[new_cols] = preds

In [None]:
# Inspect the binarized target columns.
test[new_cols]

Unnamed: 0,hdl_cholesterol_human_ok,hdl_cholesterol_human_high,hdl_cholesterol_human_low,cholesterol_ldl_human_ok,cholesterol_ldl_human_high,cholesterol_ldl_human_low,hemoglobin(hgb)_human_ok,hemoglobin(hgb)_human_high,hemoglobin(hgb)_human_low
0,1,0,0,0,1,0,1,0,0
1,1,0,0,1,0,0,1,0,0
2,1,0,0,0,1,0,1,0,0
3,1,0,0,0,0,0,1,0,0
4,1,0,0,1,0,0,1,0,0
...,...,...,...,...,...,...,...,...,...
3655,0,0,0,1,0,0,1,0,0
3656,0,0,0,1,0,0,1,0,0
3657,1,0,0,0,0,0,1,0,0
3658,1,0,0,0,1,0,1,0,0


In [None]:
def inverse_transform(data, cols=None, step_size=3):
    """Decode one-hot target groups back into one categorical label per group.

    `cols` (default: the global `new_cols`) is split into consecutive groups of
    `step_size` columns, e.g. ['x_ok', 'x_high', 'x_low']. For each group i,
    two columns are added to a copy of `data`:

    * ``temp_col_<i>``: name of the group column holding the row's largest
      value (ties, including all-zero rows, resolve to the first column);
    * ``<prefix>``: the part of that column name after its last ``_``
      (e.g. 'ok'), where <prefix> is the group's first column name without
      its last ``_``-separated part.

    `data` itself is not modified. Returns the augmented copy.
    """
    if cols is None:
        cols = new_cols

    df = data.copy()

    for i, start in enumerate(range(0, len(cols), step_size)):
        cols_i = cols[start: (start + step_size)]

        print(f'Columns idexed from {start} to {start + step_size} --> {cols_i}')

        # Vectorized row-wise argmax (first maximum wins, same as np.argmax).
        best = df[cols_i].to_numpy().argmax(axis=1)

        df['temp_col_' + str(i)] = np.asarray(cols_i, dtype=object)[best]

        col_name = '_'.join(cols_i[0].split('_')[:-1])
        suffixes = np.asarray([c.split('_')[-1] for c in cols_i], dtype=object)
        df[col_name] = suffixes[best]

    return df

In [None]:
# Decode each ok/high/low one-hot triple of `test` into a single label column.
# NOTE(review): this rebinds `test_`, which run_training reads as its test
# feature frame; re-running training after this cell would use the decoded frame.
test_ = inverse_transform(test)
test_

Columns idexed from 0 to 3 --> ['hdl_cholesterol_human_ok', 'hdl_cholesterol_human_high', 'hdl_cholesterol_human_low']
Columns idexed from 3 to 6 --> ['cholesterol_ldl_human_ok', 'cholesterol_ldl_human_high', 'cholesterol_ldl_human_low']
Columns idexed from 6 to 9 --> ['hemoglobin(hgb)_human_ok', 'hemoglobin(hgb)_human_high', 'hemoglobin(hgb)_human_low']


Unnamed: 0,Reading_ID,absorbance0,absorbance1,absorbance2,absorbance3,absorbance4,absorbance5,absorbance6,absorbance7,absorbance8,absorbance9,absorbance10,absorbance11,absorbance12,absorbance13,absorbance14,absorbance15,absorbance16,absorbance17,absorbance18,absorbance19,absorbance20,absorbance21,absorbance22,absorbance23,absorbance24,absorbance25,absorbance26,absorbance27,absorbance28,absorbance29,absorbance30,absorbance31,absorbance32,absorbance33,absorbance34,absorbance35,absorbance36,absorbance37,absorbance38,...,absorbance146_q,absorbance147_q,absorbance148_q,absorbance149_q,absorbance150_q,absorbance151_q,absorbance152_q,absorbance153_q,absorbance154_q,absorbance155_q,absorbance156_q,absorbance157_q,absorbance158_q,absorbance159_q,absorbance160_q,absorbance161_q,absorbance162_q,absorbance163_q,absorbance164_q,absorbance165_q,absorbance166_q,absorbance167_q,absorbance168_q,absorbance169_q,kfold,hdl_cholesterol_human_ok,hdl_cholesterol_human_high,hdl_cholesterol_human_low,cholesterol_ldl_human_ok,cholesterol_ldl_human_high,cholesterol_ldl_human_low,hemoglobin(hgb)_human_ok,hemoglobin(hgb)_human_high,hemoglobin(hgb)_human_low,temp_col_0,hdl_cholesterol_human,temp_col_1,cholesterol_ldl_human,temp_col_2,hemoglobin(hgb)_human
0,ID_37BEI22R,0.449736,0.449798,0.447488,0.464694,0.466377,0.485350,0.488915,0.495073,0.504129,0.512690,0.528313,0.540020,0.550252,0.555062,0.555983,0.562491,0.559443,0.562695,0.558805,0.559067,0.557602,0.554924,0.553300,0.549671,0.548033,0.544539,0.542640,0.541228,0.540335,0.539378,0.539134,0.538375,0.538068,0.540340,0.541754,0.543920,0.547749,0.550490,0.553550,...,-0.877405,-0.848976,-0.759211,-0.787022,-0.865335,-0.682490,-0.707377,-0.574728,-0.493538,-0.731705,-0.589431,-0.551942,-0.623715,-0.759709,-0.798007,-0.762391,-1.036127,-1.212637,-0.940384,-1.179244,-1.183394,-0.988800,-1.157198,-0.634158,4,1,0,0,0,1,0,1,0,0,hdl_cholesterol_human_ok,ok,cholesterol_ldl_human_high,high,hemoglobin(hgb)_human_ok,ok
1,ID_4W85V5DV,0.495429,0.505488,0.510239,0.518880,0.533147,0.543142,0.551670,0.558261,0.564027,0.575223,0.588780,0.603260,0.609797,0.613326,0.616530,0.617400,0.617284,0.615343,0.611668,0.608864,0.606411,0.602919,0.599854,0.597024,0.592800,0.590059,0.586417,0.585922,0.583848,0.583204,0.582259,0.581994,0.582528,0.584993,0.587332,0.590686,0.591674,0.595796,0.599694,...,0.487869,0.550592,0.462785,0.491374,0.421365,0.530339,0.585208,0.588573,0.542323,0.583090,0.413273,0.530809,0.582355,0.280880,0.370204,0.404643,0.273017,0.078774,-0.147432,-0.329462,-0.308039,-0.474419,-0.234631,-0.078461,3,1,0,0,1,0,0,1,0,0,hdl_cholesterol_human_ok,ok,cholesterol_ldl_human_ok,ok,hemoglobin(hgb)_human_ok,ok
2,ID_L4YR3NDY,0.437904,0.439064,0.442527,0.450437,0.455363,0.465817,0.471249,0.479145,0.482595,0.497043,0.508849,0.520005,0.526073,0.529009,0.530775,0.530869,0.529993,0.529816,0.525386,0.522270,0.518925,0.516824,0.514363,0.510227,0.506540,0.503605,0.501884,0.499315,0.498547,0.497386,0.496028,0.495754,0.495847,0.495887,0.497499,0.499683,0.501803,0.504862,0.508623,...,-1.531749,-1.572441,-1.588435,-1.582740,-1.616694,-1.563345,-1.520616,-1.590241,-1.610320,-1.635457,-1.644092,-1.627080,-1.551553,-1.656906,-1.597835,-1.556996,-1.484154,-1.395691,-1.279930,-1.197030,-1.154629,-1.094291,-0.942452,-0.922984,4,1,0,0,0,1,0,1,0,0,hdl_cholesterol_human_ok,ok,cholesterol_ldl_human_high,high,hemoglobin(hgb)_human_ok,ok
3,ID_U88E3SQ6,0.495038,0.506246,0.508730,0.518995,0.529961,0.537583,0.539696,0.540400,0.547279,0.561166,0.572493,0.583802,0.588819,0.591780,0.596486,0.595962,0.595182,0.588548,0.584253,0.579974,0.576841,0.573102,0.569567,0.565961,0.563061,0.560563,0.556971,0.555630,0.554065,0.554014,0.552711,0.552815,0.552691,0.555071,0.557024,0.558817,0.563014,0.566382,0.571307,...,-1.268176,-1.288147,-1.283527,-1.276344,-1.272091,-1.305849,-1.295253,-1.352451,-1.282066,-1.299092,-1.306644,-1.312904,-1.310104,-1.309388,-1.315472,-1.400438,-1.440712,-1.566567,-1.565088,-1.776853,-1.688696,-1.785187,-1.710679,-1.688173,7,1,0,0,0,0,0,1,0,0,hdl_cholesterol_human_ok,ok,cholesterol_ldl_human_ok,ok,hemoglobin(hgb)_human_ok,ok
4,ID_NW7Z3XU7,0.531306,0.525309,0.535306,0.541387,0.551364,0.559821,0.564851,0.570824,0.577426,0.589114,0.601409,0.616401,0.621386,0.626131,0.626661,0.627811,0.626961,0.624922,0.621003,0.619719,0.615285,0.612897,0.609494,0.607091,0.603417,0.600907,0.599359,0.597534,0.595879,0.593052,0.590476,0.590287,0.591087,0.591824,0.592791,0.593540,0.597088,0.600950,0.603265,...,0.207654,0.184678,0.030096,0.046733,0.157799,0.126716,0.059818,0.187054,0.212648,0.060491,0.201111,-0.024547,-0.062209,0.147198,0.212886,0.269482,-0.090020,-0.027121,0.254864,0.128129,0.049122,0.127416,-0.272043,0.268737,6,1,0,0,1,0,0,1,0,0,hdl_cholesterol_human_ok,ok,cholesterol_ldl_human_ok,ok,hemoglobin(hgb)_human_ok,ok
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3655,ID_ADCBL266,0.506681,0.506536,0.510056,0.511791,0.518384,0.524765,0.530773,0.538229,0.549651,0.559759,0.574382,0.590426,0.599819,0.601076,0.602852,0.600078,0.599296,0.598479,0.598490,0.594793,0.593218,0.590962,0.587704,0.584707,0.583129,0.578749,0.576627,0.576120,0.574660,0.573868,0.572408,0.573434,0.572881,0.574630,0.576806,0.579274,0.582911,0.585880,0.589952,...,1.051884,0.892047,0.851556,0.989683,0.822971,0.903805,0.918390,0.851617,0.877269,0.783485,0.844495,0.920731,0.685098,0.795862,1.037097,0.771232,1.104058,0.876497,0.997663,0.965964,1.147378,1.046975,0.714406,0.426715,3,0,0,0,1,0,0,1,0,0,hdl_cholesterol_human_ok,ok,cholesterol_ldl_human_ok,ok,hemoglobin(hgb)_human_ok,ok
3656,ID_SW51B61O,0.488276,0.501509,0.498858,0.500627,0.511329,0.522876,0.530738,0.538328,0.542644,0.556195,0.566892,0.582109,0.586264,0.589912,0.592691,0.594288,0.593082,0.590602,0.588099,0.585223,0.582153,0.579395,0.576336,0.572624,0.570377,0.566799,0.564190,0.561980,0.560520,0.559923,0.557758,0.557790,0.557553,0.558330,0.560658,0.560559,0.564633,0.568568,0.570692,...,0.079219,-0.134436,-0.017194,-0.088585,-0.201595,-0.116964,-0.216065,-0.169608,-0.235793,-0.147448,-0.125162,-0.227040,-0.170962,-0.002497,-0.303555,-0.274957,-0.314911,-0.205031,0.068969,-0.003583,0.051854,-0.095100,0.248121,0.228775,0,0,0,0,1,0,0,1,0,0,hdl_cholesterol_human_ok,ok,cholesterol_ldl_human_ok,ok,hemoglobin(hgb)_human_ok,ok
3657,ID_CO8IHJRA,0.494581,0.501446,0.499981,0.509865,0.512139,0.519129,0.521906,0.527789,0.538997,0.549444,0.562798,0.572487,0.579927,0.580615,0.582973,0.582676,0.580580,0.579033,0.575704,0.574223,0.570310,0.566706,0.563887,0.560281,0.557116,0.553743,0.552781,0.549288,0.548563,0.547542,0.546325,0.547862,0.547632,0.549180,0.550944,0.552574,0.557007,0.559878,0.563191,...,-0.775995,-0.851013,-0.807029,-0.866671,-0.804923,-0.808710,-0.839263,-0.831940,-0.770288,-0.870117,-0.946281,-0.787531,-0.859056,-0.693800,-0.868922,-0.773375,-0.557429,-0.639349,-0.603257,-0.040051,-0.276653,-0.115542,0.126915,0.492801,1,1,0,0,0,0,0,1,0,0,hdl_cholesterol_human_ok,ok,cholesterol_ldl_human_ok,ok,hemoglobin(hgb)_human_ok,ok
3658,ID_VN5CP3ZZ,0.431551,0.434236,0.433433,0.437899,0.451583,0.461391,0.471832,0.468035,0.471895,0.487380,0.495404,0.510000,0.514567,0.520869,0.520769,0.519822,0.521647,0.517721,0.516848,0.515657,0.513473,0.509555,0.507770,0.505066,0.501918,0.499459,0.496803,0.493903,0.493249,0.492733,0.493896,0.492419,0.491682,0.493235,0.494788,0.497434,0.500312,0.503646,0.506452,...,-1.124943,-1.020317,-1.099463,-1.179474,-1.215190,-1.176344,-1.155791,-1.200006,-1.119665,-0.949163,-1.167303,-1.160329,-1.192286,-1.203652,-1.215622,-1.237119,-1.292134,-1.324468,-1.376866,-1.377347,-1.461733,-1.210102,-1.297050,-0.907905,7,1,0,0,0,1,0,1,0,0,hdl_cholesterol_human_ok,ok,cholesterol_ldl_human_high,high,hemoglobin(hgb)_human_ok,ok


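#### Convert the predictions into the sample-submission format

Each test reading becomes three rows, one per target (`hdl_cholesterol_human`, `hemoglobin(hgb)_human`, `cholesterol_ldl_human`). In each row, `Reading_ID` is `<Reading_ID>_<target>` and `target` is the predicted label (`ok`, `high` or `low`).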

In [None]:
def transform_c_hdl(row):
    """Pack one test row into "<Reading_ID>_hdl_cholesterol_human-<label>".

    The hyphen separates the submission ID from the predicted HDL label so the
    packed string can be split back into the two submission columns later.
    """
    submission_id = str(row["Reading_ID"]) + "_hdl_cholesterol_human"
    label = row["hdl_cholesterol_human"]
    return submission_id + "-" + label

In [None]:
hdl_packed = test_[['Reading_ID'] + targets].apply(transform_c_hdl, axis=1)  # one packed string per test row
hdl_rows = pd.DataFrame(hdl_packed)

In [None]:
def transform_hemo(row):
    """Pack one test row into "<Reading_ID>_hemoglobin(hgb)_human-<label>".

    The hyphen separates the submission ID from the predicted hemoglobin label
    so the packed string can be split back into the two submission columns later.
    """
    submission_id = str(row["Reading_ID"]) + "_hemoglobin(hgb)_human"
    label = row["hemoglobin(hgb)_human"]
    return submission_id + "-" + label

In [None]:
hemo_packed = test_[['Reading_ID'] + targets].apply(transform_hemo, axis=1)  # one packed string per test row
hemo_rows = pd.DataFrame(hemo_packed)

In [None]:
def transform_c_ldl(row):
    """Pack one test row into "<Reading_ID>_cholesterol_ldl_human-<label>".

    The hyphen separates the submission ID from the predicted LDL label so the
    packed string can be split back into the two submission columns later.
    """
    submission_id = str(row["Reading_ID"]) + "_cholesterol_ldl_human"
    label = row["cholesterol_ldl_human"]
    return submission_id + "-" + label

In [None]:
ldl_packed = test_[['Reading_ID'] + targets].apply(transform_c_ldl, axis=1)  # one packed string per test row
ldl_rows = pd.DataFrame(ldl_packed)

In [None]:
# Stack the three per-target frames (hdl, hemoglobin, ldl) and renumber rows 0..n-1.
ss = pd.concat([hdl_rows, hemo_rows, ldl_rows], ignore_index=True)

In [None]:
# Unpack "<Reading_ID>_<target name>-<label>" into the submission ID (column 0) and the label.
# Split once on the LAST hyphen so a hyphen inside a Reading_ID cannot shift or truncate
# the label; both columns come from the same split.
id_label = ss[0].str.rsplit("-", n=1, expand=True)
ss["target"] = id_label[1]
ss[0] = id_label[0]

In [None]:
ss = ss.rename({0: "Reading_ID"}, axis="columns")  # name the unlabeled ID column as in the sample submission

In [None]:
ss["target"].value_counts()  # label distribution across all submission rows

ok      10408
high      572
Name: target, dtype: int64

In [None]:
SUBMISSION_PATH = "submission_NNet_2inputs_threshold_5_9_clusters.csv"
ss.to_csv(SUBMISSION_PATH, index=False)  # the row index is not part of the submission format

In [None]:
ss  # display the final submission frame (Reading_ID, target) for a visual check

Unnamed: 0,Reading_ID,target
0,ID_37BEI22R_hdl_cholesterol_human,ok
1,ID_4W85V5DV_hdl_cholesterol_human,ok
2,ID_L4YR3NDY_hdl_cholesterol_human,ok
3,ID_U88E3SQ6_hdl_cholesterol_human,ok
4,ID_NW7Z3XU7_hdl_cholesterol_human,ok
...,...,...
10975,ID_ADCBL266_cholesterol_ldl_human,ok
10976,ID_SW51B61O_cholesterol_ldl_human,ok
10977,ID_CO8IHJRA_cholesterol_ldl_human,ok
10978,ID_VN5CP3ZZ_cholesterol_ldl_human,high
