In [None]:
import torch 
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from torch import device, cuda
from torch.optim import Adam
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
from sklearn.metrics import f1_score, confusion_matrix, precision_recall_curve, ConfusionMatrixDisplay
import numpy as np 
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from time import time
import os

import sys
sys.path.append('..')  
from src.utils.RawDataReader import RawDataReader


In [None]:
reader = RawDataReader()
labels = reader.get_labels()
datasets, means, stds = reader.get_normalised_dataset()

In [None]:
from torch.nn import Module, Sequential, Conv1d, LeakyReLU, MaxPool1d, Flatten, Linear, Sigmoid

class SingleHeadCNNModel(Module):
    def __init__(self, feature_count:int, hidden_count:int=5):
        super(SingleHeadCNNModel, self).__init__()

        self.slope = 0.001

        self.channel_counts = 2

        self.cnn_model = Sequential(
            *[self.cnn_block((self.channel_counts**_h_index),(self.channel_counts**(_h_index+1))) for _h_index in range(hidden_count)],
            Flatten()
        )
        self.fc_model = Sequential(
            self.linear_block(in_features=feature_count, out_features=100),
            self.linear_block(in_features=100, out_features=10),
            self.linear_block(in_features=10, out_features=1),
            Sigmoid()
        )

    def cnn_block(self, in_channels:int, out_channels:int, kernel_size:int=3, padding=1):
        return Sequential(
            Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding),
            LeakyReLU(negative_slope=self.slope),
            MaxPool1d(kernel_size=2)
        )
    
    def linear_block(self, in_features, out_features):
        return Sequential(
            Linear(in_features=in_features, out_features=out_features),
            LeakyReLU(negative_slope=self.slope)
        )
    
    def forward(self, x):
        return self.fc_model(self.cnn_model(x))

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on {device}")

Training on cuda


In [13]:
r = 2 
batch_size = 2048
learning_rate = 0.01
test_split = 0.2
epochs = 9999

random_state = 1
hidden_count = 0

num_lables = labels.shape[1]

In [14]:
curr_dataset = datasets[r]
feature_count = curr_dataset.shape[1]
for label_index in range(num_lables):
    curr_label = labels[:, [label_index]]
    x_train, x_test, y_train, y_test = train_test_split(curr_dataset, curr_label, test_size=test_split, random_state=random_state, stratify=curr_label)
    x_train = x_train.unsqueeze(1).to(device)
    y_train.to(device)
    x_test = x_test.unsqueeze(1).to(device)

    train_dataset = TensorDataset(x_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    head_model = SingleHeadCNNModel(feature_count=feature_count, hidden_count=hidden_count)
    head_model.to(device)
    
    criterion = torch.nn.BCELoss().to(device)
    optimizer = Adam(head_model.parameters(), lr=learning_rate)

    for epoch in tqdm(range(epochs)):
        head_model.train()
        total_loss = 0
        for batch_features, batch_labels in train_loader:
            optimizer.zero_grad()
            predictions = head_model(batch_features.to(device))
            loss = criterion(predictions, batch_labels.to(device))
            loss.backward()
            optimizer.step()
            total_loss += loss
        total_loss = (total_loss / len(train_loader)).item()

        y_test_pred = head_model(x_test)
        f1 = f1_score(y_test,(y_test_pred.cpu() >0.5).float())
        print(total_loss, f1)
    break
    

  0%|          | 2/9999 [00:00<16:59,  9.80it/s]

0.6787839531898499 0.5960784313725491
0.6609547734260559 0.6406685236768802


  0%|          | 4/9999 [00:00<20:20,  8.19it/s]

0.6508848667144775 0.6505576208178439
0.6331400275230408 0.6580406654343808


  0%|          | 6/9999 [00:00<21:35,  7.71it/s]

0.5890602469444275 0.6115384615384615
0.48737096786499023 0.6090373280943027


  0%|          | 8/9999 [00:01<23:15,  7.16it/s]

0.3474743068218231 0.594059405940594
0.25762778520584106 0.5873015873015872


  0%|          | 10/9999 [00:01<20:08,  8.27it/s]

0.22328992187976837 0.5791583166332664
0.2085830420255661 0.596039603960396


  0%|          | 12/9999 [00:01<21:04,  7.90it/s]

0.19823309779167175 0.6176185866408518
0.1899748146533966 0.6137667304015297


  0%|          | 15/9999 [00:01<20:05,  8.28it/s]

0.1836894303560257 0.644506001846722
0.17960704863071442 0.6501831501831502
0.17514614760875702 0.6481994459833794


  0%|          | 18/9999 [00:02<19:56,  8.34it/s]

0.17069898545742035 0.6550137994480221
0.16902302205562592 0.6630727762803234
0.17108440399169922 0.6648451730418943


  0%|          | 20/9999 [00:02<20:46,  8.00it/s]

0.16967836022377014 0.6660583941605839
0.16564500331878662 0.6543438077634011


  0%|          | 23/9999 [00:02<20:08,  8.25it/s]

0.16358967125415802 0.6598702502316961
0.1639353483915329 0.6541353383458648
0.16351425647735596 0.6739327883742052


  0%|          | 26/9999 [00:03<19:51,  8.37it/s]

0.16291742026805878 0.6567717996289424
0.1606413722038269 0.66
0.15605133771896362 0.6690712353471597


  0%|          | 27/9999 [00:03<22:03,  7.53it/s]

0.15530402958393097 0.6779964221824688
0.1527824103832245 0.6750675067506751


  0%|          | 30/9999 [00:03<21:03,  7.89it/s]

0.15192946791648865 0.669683257918552
0.15166060626506805 0.6786034019695614


  0%|          | 32/9999 [00:04<22:19,  7.44it/s]

0.15535248816013336 0.6144814090019569
0.16748765110969543 0.6079207920792079


  0%|          | 33/9999 [00:04<20:52,  7.96it/s]

0.16229753196239471 0.6435921421889615
0.15692545473575592 0.6672777268560953


  0%|          | 36/9999 [00:04<20:34,  8.07it/s]

0.1528528779745102 0.6690582959641256
0.15013417601585388 0.674439461883408


  0%|          | 38/9999 [00:04<22:11,  7.48it/s]

0.14783543348312378 0.6802120141342756
0.14598610997200012 0.6654445462878095


  0%|          | 40/9999 [00:05<23:08,  7.17it/s]

0.14416269958019257 0.6858168761220825
0.1426238864660263 0.6810035842293907


  0%|          | 42/9999 [00:05<20:12,  8.21it/s]

0.14199891686439514 0.6823104693140793
0.14065535366535187 0.6843971631205674


  0%|          | 45/9999 [00:05<19:44,  8.40it/s]

0.13936373591423035 0.6939843068875327
0.1377379298210144 0.6844919786096256
0.1374596208333969 0.6843971631205674


  0%|          | 46/9999 [00:05<21:56,  7.56it/s]

0.13599969446659088 0.6906474820143884
0.1353372037410736 0.6889279437609842


  0%|          | 49/9999 [00:06<21:00,  7.90it/s]

0.13433261215686798 0.6919014084507041
0.13401971757411957 0.6905187835420394
0.13626925647258759 0.6963490650044523


  1%|          | 53/9999 [00:06<19:38,  8.44it/s]

0.13676504790782928 0.6769509981851178
0.1358499377965927 0.6822682268226823
0.1355120688676834 0.692927484333035


  1%|          | 55/9999 [00:07<20:19,  8.15it/s]

0.13527147471904755 0.6751592356687898
0.13344496488571167 0.6868327402135231


  1%|          | 57/9999 [00:07<21:06,  7.85it/s]

0.13418732583522797 0.6678899082568808
0.13377457857131958 0.6763372620126926


  1%|          | 59/9999 [00:07<21:39,  7.65it/s]

0.13284847140312195 0.6672777268560953
0.1423691362142563 0.6425855513307985


  1%|          | 61/9999 [00:07<19:42,  8.40it/s]

0.14241118729114532 0.6585820895522388
0.13833259046077728 0.667279411764706


  1%|          | 63/9999 [00:08<20:36,  8.03it/s]

0.13319651782512665 0.677536231884058
0.1320609450340271 0.6892857142857144


  1%|          | 65/9999 [00:08<21:39,  7.65it/s]

0.1363610178232193 0.676416819012797
0.13665564358234406 0.6920353982300885


  1%|          | 67/9999 [00:08<22:55,  7.22it/s]

0.13406561315059662 0.685251798561151
0.13303832709789276 0.6721915285451197


  1%|          | 69/9999 [00:08<20:16,  8.16it/s]

0.13533936440944672 0.6624319419237749
0.13892991840839386 0.6690712353471597


  1%|          | 71/9999 [00:09<21:07,  7.83it/s]

0.13505445420742035 0.6721014492753623
0.13334932923316956 0.6504672897196262
0.13275547325611115 0.6876640419947506


  1%|          | 74/9999 [00:09<20:32,  8.05it/s]

0.13367272913455963 0.6636280765724704
0.12933646142482758 0.6825396825396826


  1%|          | 76/9999 [00:09<21:07,  7.83it/s]

0.1323658674955368 0.6696428571428572
0.12985579669475555 0.6822262118491922
0.12454604357481003 0.6883802816901409


  1%|          | 79/9999 [00:10<20:28,  8.07it/s]

0.1259363442659378 0.6779661016949153
0.1290884166955948 0.6866197183098592
0.12625136971473694 0.6888694127957931


  1%|          | 82/9999 [00:10<20:20,  8.13it/s]

0.1247270256280899 0.6889279437609842
0.12545444071292877 0.6808888888888889


  1%|          | 84/9999 [00:10<20:54,  7.90it/s]

0.1267888993024826 0.6856127886323269
0.12708412110805511 0.6884955752212389
0.1289074867963791 0.65237651444548


  1%|          | 87/9999 [00:11<20:22,  8.10it/s]

0.13475151360034943 0.6636113657195233
0.1330844908952713 0.6493506493506493
0.13303585350513458 0.6624203821656051


  1%|          | 91/9999 [00:11<19:20,  8.53it/s]

0.129185751080513 0.689594356261023
0.12673644721508026 0.6672727272727274
0.12469365447759628 0.6904969485614648


  1%|          | 93/9999 [00:11<20:13,  8.17it/s]

0.12213049083948135 0.6908768821966342
0.11969844251871109 0.6894736842105262


  1%|          | 95/9999 [00:12<21:06,  7.82it/s]

0.1356116384267807 0.6894736842105262
0.125189408659935 0.6996527777777778
0.1248338371515274 0.7019982623805386


  1%|          | 98/9999 [00:12<20:35,  8.01it/s]

0.12245786190032959 0.697754749568221
0.11621644347906113 0.7010309278350516
0.12070490419864655 0.6869565217391305


  1%|          | 101/9999 [00:12<20:24,  8.09it/s]

0.12396664917469025 0.70076726342711
0.12428028881549835 0.7084745762711864


  1%|          | 103/9999 [00:13<21:03,  7.83it/s]

0.11847292631864548 0.6979982593559618
0.11701317131519318 0.6993127147766324
0.11080911010503769 0.6999140154772142


  1%|          | 106/9999 [00:13<20:20,  8.10it/s]

0.11665143817663193 0.7006802721088434
0.11171642690896988 0.7004291845493561
0.11111583560705185 0.6857142857142857


  1%|          | 109/9999 [00:13<20:05,  8.20it/s]

0.11269576847553253 0.6979982593559618
0.11203588545322418 0.6974716652136007


  1%|          | 111/9999 [00:14<20:50,  7.90it/s]

0.11148102581501007 0.694229112833764
0.11199349164962769 0.689165186500888


  1%|          | 113/9999 [00:14<21:27,  7.68it/s]

0.10976042598485947 0.6937014667817084
0.1090252548456192 0.6922406277244987


  1%|          | 115/9999 [00:14<19:28,  8.46it/s]

0.11614435911178589 0.6863084922010398
0.12734876573085785 0.6864628820960698


  1%|          | 117/9999 [00:14<20:31,  8.02it/s]

0.1309175342321396 0.6938421509106678
0.12287242710590363 0.679646017699115
0.121223583817482 0.6903114186851211


  1%|          | 120/9999 [00:15<20:16,  8.12it/s]

0.11886036396026611 0.6984402079722704
0.11634500324726105 0.7003484320557491


  1%|          | 122/9999 [00:15<21:11,  7.77it/s]

0.11938386410474777 0.6945169712793734
0.11975988000631332 0.6878868258178603
0.11626086384057999 0.7048759623609924


  1%|▏         | 126/9999 [00:16<19:31,  8.43it/s]

0.11756954342126846 0.6883230904302019
0.12226007133722305 0.6970227670753065
0.11689936369657516 0.7040552200172562


  1%|▏         | 128/9999 [00:16<20:18,  8.10it/s]

0.1117832288146019 0.7030303030303029
0.10967067629098892 0.6999140154772142


  1%|▏         | 130/9999 [00:16<20:52,  7.88it/s]

0.11027944087982178 0.7019982623805386
0.10778875648975372 0.7112253641816624
0.1067865714430809 0.7080103359173127


  1%|▏         | 133/9999 [00:17<20:21,  8.08it/s]

0.10838482528924942 0.7161125319693095
0.10864384472370148 0.697391304347826
0.11120420694351196 0.7065026362038663


  1%|▏         | 136/9999 [00:17<20:08,  8.16it/s]

0.10918853431940079 0.7086206896551724
0.10790250450372696 0.6927175843694494


  1%|▏         | 138/9999 [00:17<20:43,  7.93it/s]

0.10441895574331284 0.7218045112781954
0.10679581016302109 0.7058823529411765
0.10550935566425323 0.6993970714900948


  1%|▏         | 141/9999 [00:18<20:19,  8.09it/s]

0.1100509762763977 0.6897163120567377
0.10835295170545578 0.6972318339100346


  1%|▏         | 143/9999 [00:18<21:41,  7.57it/s]

0.10790026187896729 0.6952714535901926
0.10484911501407623 0.7214225232853514


  1%|▏         | 145/9999 [00:18<19:23,  8.47it/s]

0.10367148369550705 0.7059843885516045
0.11121570318937302 0.7024722932651322


  1%|▏         | 147/9999 [00:18<20:45,  7.91it/s]

0.10835975408554077 0.6952714535901926
0.10768377035856247 0.7131782945736435


  1%|▏         | 149/9999 [00:19<21:17,  7.71it/s]

0.10404335707426071 0.7034482758620689
0.10178182274103165 0.7095115681233932
0.09923422336578369 0.7085124677558039


  2%|▏         | 153/9999 [00:19<19:28,  8.43it/s]

0.100093774497509 0.7037671232876712
0.10482411831617355 0.6774774774774776
0.11080346256494522 0.6780883678990081


  2%|▏         | 155/9999 [00:19<20:11,  8.13it/s]

0.10820484161376953 0.6848112379280071
0.1052449569106102 0.6897147796024201


  2%|▏         | 157/9999 [00:20<20:57,  7.83it/s]

0.10574741661548615 0.6845397676496873
0.11360158026218414 0.6619217081850534
0.12062215059995651 0.6807017543859649


  2%|▏         | 160/9999 [00:20<20:28,  8.01it/s]

0.11809835582971573 0.6725978647686833
0.11148841679096222 0.6799307958477508
0.11114799976348877 0.6836555360281195


  2%|▏         | 163/9999 [00:20<20:04,  8.16it/s]

0.11209080368280411 0.6732673267326733
0.11579727381467819 0.6921075455333912


  2%|▏         | 165/9999 [00:21<20:40,  7.93it/s]

0.11832445859909058 0.6933797909407666
0.10860023647546768 0.6878761822871883


  2%|▏         | 167/9999 [00:21<22:03,  7.43it/s]

0.10996861755847931 0.6900790166812995
0.1159716472029686 0.6847161572052402


  2%|▏         | 168/9999 [00:21<20:52,  7.85it/s]

0.11886485666036606 0.6732495511669658
0.11034536361694336 0.6895338610378189


  2%|▏         | 171/9999 [00:21<20:21,  8.05it/s]

0.10721414536237717 0.6913155631986242
0.1083812341094017 0.6944444444444444
0.105503611266613 0.7084048027444254


  2%|▏         | 173/9999 [00:22<21:07,  7.75it/s]

0.09889774769544601 0.7086882453151617
0.1124722957611084 0.7027491408934707


  2%|▏         | 176/9999 [00:22<20:48,  7.87it/s]

0.10007240623235703 0.7044673539518901
0.11372645944356918 0.6967071057192374


  2%|▏         | 178/9999 [00:22<22:07,  7.40it/s]

0.11222430318593979 0.7077977720651243
0.10692206770181656 0.6992153443766347


  2%|▏         | 180/9999 [00:23<19:38,  8.33it/s]

0.10256991535425186 0.699228791773779
0.10116666555404663 0.6920473773265651


  2%|▏         | 181/9999 [00:23<22:23,  7.31it/s]

0.10855691879987717 0.6946826758147513
0.10506856441497803 0.7005937234944869


  2%|▏         | 184/9999 [00:23<21:06,  7.75it/s]

0.10368136316537857 0.7086481947942905
0.10229597240686417 0.6986301369863015
0.10063182562589645 0.7083685545224006


  2%|▏         | 188/9999 [00:24<19:21,  8.45it/s]

0.09825941175222397 0.713440405748098
0.09881089627742767 0.7014297729184188
0.09637479484081268 0.7057837384744342


  2%|▏         | 190/9999 [00:24<20:08,  8.12it/s]

0.09467777609825134 0.6940486169321041
0.10002362728118896 0.7064760302775441


  2%|▏         | 192/9999 [00:24<20:41,  7.90it/s]

0.10241510719060898 0.6988783433994822
0.1001332625746727 0.6892361111111112


  2%|▏         | 194/9999 [00:24<21:08,  7.73it/s]

0.11136267334222794 0.688695652173913
0.12165742367506027 0.6824978012313104


  2%|▏         | 195/9999 [00:24<20:04,  8.14it/s]

0.11949415504932404 0.6790780141843972
0.10910645872354507 0.6912280701754386


  2%|▏         | 199/9999 [00:25<18:54,  8.64it/s]

0.10396259278059006 0.7112253641816624
0.09981366246938705 0.7115222876366695
0.09589390456676483 0.7100591715976332


  2%|▏         | 200/9999 [00:25<20:46,  7.86it/s]

0.09347302466630936 0.7134502923976608
0.09787683933973312 0.7072758037225042


  2%|▏         | 204/9999 [00:26<19:17,  8.47it/s]

0.1167374849319458 0.6955767562879445
0.10153350979089737 0.7105038428693424
0.09704195708036423 0.7107296137339055


  2%|▏         | 207/9999 [00:26<19:06,  8.54it/s]

0.0933806523680687 0.7096774193548387
0.09237884730100632 0.7076923076923077
0.09045523405075073 0.7049742710120067


  2%|▏         | 208/9999 [00:26<20:51,  7.83it/s]

0.08878578990697861 0.7094017094017093
0.08704667538404465 0.7191383595691798


  2%|▏         | 211/9999 [00:26<20:31,  7.95it/s]

0.08736690133810043 0.7091836734693877
0.08620241284370422 0.7151310228233306


  2%|▏         | 213/9999 [00:27<21:04,  7.74it/s]

0.08529097586870193 0.7125850340136054
0.08473581075668335 0.7190635451505016


  2%|▏         | 214/9999 [00:27<20:17,  8.04it/s]

0.08536504954099655 0.7101200686106347
0.09665045142173767 0.7142857142857143


  2%|▏         | 217/9999 [00:27<20:25,  7.98it/s]

0.0875076949596405 0.7006920415224914
0.08798716217279434 0.7083685545224006


  2%|▏         | 219/9999 [00:28<21:22,  7.63it/s]

0.08566557615995407 0.7021097046413503
0.08494880050420761 0.715008431703204


  2%|▏         | 221/9999 [00:28<22:50,  7.13it/s]

0.08732005208730698 0.7030716723549487
0.09587482362985611 0.6864628820960698


  2%|▏         | 223/9999 [00:28<20:11,  8.07it/s]

0.09154081344604492 0.7046979865771812
0.0880134105682373 0.7101694915254237


  2%|▏         | 225/9999 [00:28<20:54,  7.79it/s]

0.08570417016744614 0.7158424140821459
0.08328160643577576 0.71875


  2%|▏         | 227/9999 [00:29<22:33,  7.22it/s]

0.09282400459051132 0.7013888888888888
0.08675747364759445 0.7180327868852459


  2%|▏         | 229/9999 [00:29<23:31,  6.92it/s]

0.09158062189817429 0.6857654431512982
0.09422501176595688 0.7021459227467811


  2%|▏         | 231/9999 [00:29<20:34,  7.91it/s]

0.09071766585111618 0.7108843537414966
0.08676969259977341 0.7182506307821698


  2%|▏         | 233/9999 [00:29<20:58,  7.76it/s]

0.10621409863233566 0.7089859851607584
0.14958064258098602 0.7023295944779983


  2%|▏         | 235/9999 [00:30<22:13,  7.32it/s]

0.10800069570541382 0.7112253641816624
0.10524942725896835 0.6985230234578628


  2%|▏         | 237/9999 [00:30<23:20,  6.97it/s]

0.09593210369348526 0.7184300341296929
0.09550618380308151 0.719195305951383


  2%|▏         | 238/9999 [00:30<21:25,  7.60it/s]

0.11422812938690186 0.6947368421052631
0.10198328644037247 0.7067796610169492


  2%|▏         | 242/9999 [00:30<19:27,  8.36it/s]

0.09363739937543869 0.7240802675585284
0.09504054486751556 0.7046025104602511
0.09346195310354233 0.6921029281277729


  2%|▏         | 244/9999 [00:31<20:24,  7.97it/s]

0.09466614574193954 0.711340206185567
0.08951947093009949 0.7049742710120067


  2%|▏         | 246/9999 [00:31<21:29,  7.56it/s]

0.08867479860782623 0.6999154691462385
0.09815668314695358 0.679177837354781


  2%|▏         | 248/9999 [00:31<22:36,  7.19it/s]

0.09600876271724701 0.6996644295302012
0.08763429522514343 0.697713801862828


  3%|▎         | 250/9999 [00:32<20:06,  8.08it/s]

0.08495064824819565 0.7066666666666666
0.10000776499509811 0.7033333333333333


  3%|▎         | 252/9999 [00:32<21:03,  7.71it/s]

0.10005432367324829 0.70076726342711
0.11208011955022812 0.6962576153176675


  3%|▎         | 254/9999 [00:32<22:36,  7.18it/s]

0.11338581889867783 0.6979166666666666
0.10357315838336945 0.6983857264231096


  3%|▎         | 256/9999 [00:32<23:31,  6.90it/s]

0.09608353674411774 0.6955046649703138
0.09499474614858627 0.7045075125208681


  3%|▎         | 258/9999 [00:33<20:10,  8.05it/s]

0.08874605596065521 0.7119796091758709
0.0841807872056961 0.7132043734230445


  3%|▎         | 260/9999 [00:33<20:51,  7.78it/s]

0.08172670751810074 0.715
0.09128759056329727 0.6969178082191781
0.08331949263811111 0.7222222222222222


  3%|▎         | 262/9999 [00:33<20:52,  7.78it/s]

0.08311136066913605 0.7140468227424749





KeyboardInterrupt: 

In [None]:
head_model(curr_dataset.unsqueeze(1).to(device))

tensor([[1.],
        [1.],
        [1.],
        ...,
        [1.],
        [1.],
        [1.]], device='cuda:0', grad_fn=<SigmoidBackward0>)