In [1]:

#!python3 -m pip install pytorch_lightning
#!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import save
from torch.utils.data import DataLoader
import neptune.new as neptune
from pytorch_lightning import seed_everything, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import NeptuneLogger

from model import LSTM_Model



# Seed torch's global RNG as early as possible. The training cells call
# seed_everything(42), which reseeds torch and overrides this value.
torch.manual_seed(seed=1)

# Optional environment setup, run once in the current kernel if packages are missing:
#   %pip install neptune-client neptune-notebooks torchinfo
#   !jupyter nbextension enable --py neptune-notebooks
# `sys` is imported so install commands can target {sys.executable} if preferred.
import sys


Dataset version: 2.0


In [None]:
# Train LSTM_Model with occlusion mask [1]*8 + [0]*8 and log the run to Neptune.
# NOTE(review): the meaning of the occlusion mask is defined in model.py, so confirm it there.
seed_everything(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Check for CUDA
print(f'Using {device} device')

# The Neptune API token is read from the NEPTUNE_API_TOKEN environment variable
# rather than committed in the notebook. The token previously hard-coded here is
# exposed in version history and should be revoked.
neptune_logger = NeptuneLogger(
    project="NTLAB/BCM-activity-classification",
    api_token=os.environ.get("NEPTUNE_API_TOKEN"),
    source_files=["train_model.ipynb", "model.py", "BCM_dataset_v2.py"]
)

model = LSTM_Model('data/bcm/', window_size=3, stride=0.032, lstm_hidden_size=128,
                   occlusion=[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0])

# Pick the accelerator from the device detected above instead of hard-coding "gpu",
# so the cell also runs on a CPU-only machine. trainer.tune() is not called: with
# auto_lr_find and auto_scale_batch_size both False it performs no tuning.
trainer = Trainer(max_epochs=200, min_epochs=1, auto_lr_find=False, auto_scale_batch_size=False,
                  enable_checkpointing=False,
                  accelerator="gpu" if device == "cuda" else "cpu", devices=1,
                  logger=neptune_logger)
try:
    trainer.fit(model)
finally:
    # A Neptune run stays open until it is stopped explicitly or the kernel exits,
    # so close it here even if training fails.
    neptune_logger.experiment.stop()


Global seed set to 42


Using cuda device
Validation set
data/bcm//validation/0.npy
data/bcm//validation/1.npy
data/bcm//validation/2.npy
data/bcm//validation/3.npy
data/bcm//validation/4.npy
Training set
data/bcm//train/0.npy
data/bcm//train/1.npy
data/bcm//train/2.npy
data/bcm//train/3.npy
data/bcm//train/4.npy


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


https://app.neptune.ai/NTLAB/BCM-activity-classification/e/BCMAC-26
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.



  | Name         | Type             | Params
--------------------------------------------------
0 | lstm         | LSTM             | 149 K 
1 | flatten      | Flatten          | 0     
2 | fc           | Linear           | 1.3 K 
3 | output       | Sigmoid          | 0     
4 | sm           | Softmax          | 0     
5 | loss         | CrossEntropyLoss | 0     
6 | accuracy     | Accuracy         | 0     
7 | val_accuracy | Accuracy         | 0     
--------------------------------------------------
150 K     Trainable params
0         Non-trainable params
150 K     Total params
0.603     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.6028403043746948}, 'log': {'val_loss': 1.6028403043746948}, 'val_loss': 1.6028403043746948}
Accuracy: 0.0


  value = torch.tensor(value, device=self.device)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4373046159744263}, 'log': {'val_loss': 1.4373046159744263}, 'val_loss': 1.4373046159744263}
Accuracy: 0.7815362811088562
Accuracy: 0.7594972848892212


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4344414472579956}, 'log': {'val_loss': 1.4344414472579956}, 'val_loss': 1.4344414472579956}
Accuracy: 0.7944608926773071
Accuracy: 0.8172386288642883


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4234328269958496}, 'log': {'val_loss': 1.4234328269958496}, 'val_loss': 1.4234328269958496}
Accuracy: 0.9457919001579285
Accuracy: 0.8915945887565613


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.423644781112671}, 'log': {'val_loss': 1.423644781112671}, 'val_loss': 1.423644781112671}
Accuracy: 0.9463117122650146
Accuracy: 0.9637995362281799


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.423775315284729}, 'log': {'val_loss': 1.423775315284729}, 'val_loss': 1.423775315284729}
Accuracy: 0.9459173679351807
Accuracy: 0.9808419942855835


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4230237007141113}, 'log': {'val_loss': 1.4230237007141113}, 'val_loss': 1.4230237007141113}
Accuracy: 0.9486241936683655
Accuracy: 0.9873488545417786


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4221605062484741}, 'log': {'val_loss': 1.4221605062484741}, 'val_loss': 1.4221605062484741}
Accuracy: 0.952836811542511
Accuracy: 0.9913123846054077


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4226906299591064}, 'log': {'val_loss': 1.4226906299591064}, 'val_loss': 1.4226906299591064}
Accuracy: 0.9514564871788025
Accuracy: 0.9940041899681091


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4256447553634644}, 'log': {'val_loss': 1.4256447553634644}, 'val_loss': 1.4256447553634644}
Accuracy: 0.9394102096557617
Accuracy: 0.995168924331665


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4229965209960938}, 'log': {'val_loss': 1.4229965209960938}, 'val_loss': 1.4229965209960938}
Accuracy: 0.9503809213638306
Accuracy: 0.9966545104980469


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4223582744598389}, 'log': {'val_loss': 1.4223582744598389}, 'val_loss': 1.4223582744598389}
Accuracy: 0.952836811542511
Accuracy: 0.9966307282447815


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.422764539718628}, 'log': {'val_loss': 1.422764539718628}, 'val_loss': 1.422764539718628}
Accuracy: 0.9510621428489685
Accuracy: 0.9977062940597534


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4213836193084717}, 'log': {'val_loss': 1.4213836193084717}, 'val_loss': 1.4213836193084717}
Accuracy: 0.9556869864463806
Accuracy: 0.9974032044410706


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4267674684524536}, 'log': {'val_loss': 1.4267674684524536}, 'val_loss': 1.4267674684524536}
Accuracy: 0.9340324401855469
Accuracy: 0.9979677200317383


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.421324610710144}, 'log': {'val_loss': 1.421324610710144}, 'val_loss': 1.421324610710144}
Accuracy: 0.956063449382782
Accuracy: 0.9973199963569641


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4272834062576294}, 'log': {'val_loss': 1.4272834062576294}, 'val_loss': 1.4272834062576294}
Accuracy: 0.9323474168777466
Accuracy: 0.9984490871429443


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4218504428863525}, 'log': {'val_loss': 1.4218504428863525}, 'val_loss': 1.4218504428863525}
Accuracy: 0.9549878835678101
Accuracy: 0.9980806112289429


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4247570037841797}, 'log': {'val_loss': 1.4247570037841797}, 'val_loss': 1.4247570037841797}
Accuracy: 0.9424038529396057
Accuracy: 0.9982708096504211


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4233620166778564}, 'log': {'val_loss': 1.4233620166778564}, 'val_loss': 1.4233620166778564}
Accuracy: 0.948928952217102
Accuracy: 0.998425304889679


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.422513484954834}, 'log': {'val_loss': 1.422513484954834}, 'val_loss': 1.422513484954834}
Accuracy: 0.9513668417930603
Accuracy: 0.998092532157898


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4247928857803345}, 'log': {'val_loss': 1.4247928857803345}, 'val_loss': 1.4247928857803345}
Accuracy: 0.9423680305480957
Accuracy: 0.9986332654953003


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4241403341293335}, 'log': {'val_loss': 1.4241403341293335}, 'val_loss': 1.4241403341293335}
Accuracy: 0.9445370435714722
Accuracy: 0.9987699389457703


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4243874549865723}, 'log': {'val_loss': 1.4243874549865723}, 'val_loss': 1.4243874549865723}
Accuracy: 0.9448597431182861
Accuracy: 0.9989422559738159


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4233713150024414}, 'log': {'val_loss': 1.4233713150024414}, 'val_loss': 1.4233713150024414}
Accuracy: 0.9485524892807007
Accuracy: 0.9987818002700806


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.423801064491272}, 'log': {'val_loss': 1.423801064491272}, 'val_loss': 1.423801064491272}
Accuracy: 0.9460070133209229
Accuracy: 0.9988769292831421


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4264295101165771}, 'log': {'val_loss': 1.4264295101165771}, 'val_loss': 1.4264295101165771}
Accuracy: 0.9365779161453247
Accuracy: 0.9986094832420349


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4237219095230103}, 'log': {'val_loss': 1.4237219095230103}, 'val_loss': 1.4237219095230103}
Accuracy: 0.947405219078064
Accuracy: 0.9987639784812927


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4233957529067993}, 'log': {'val_loss': 1.4233957529067993}, 'val_loss': 1.4233957529067993}
Accuracy: 0.9491081833839417
Accuracy: 0.9987818002700806


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4222631454467773}, 'log': {'val_loss': 1.4222631454467773}, 'val_loss': 1.4222631454467773}
Accuracy: 0.953033983707428
Accuracy: 0.9987818002700806


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4243690967559814}, 'log': {'val_loss': 1.4243690967559814}, 'val_loss': 1.4243690967559814}
Accuracy: 0.9437482953071594
Accuracy: 0.9987521171569824


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4222477674484253}, 'log': {'val_loss': 1.4222477674484253}, 'val_loss': 1.4222477674484253}
Accuracy: 0.9531415104866028
Accuracy: 0.9985678791999817


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4214268922805786}, 'log': {'val_loss': 1.4214268922805786}, 'val_loss': 1.4214268922805786}
Accuracy: 0.9560813903808594
Accuracy: 0.9985500574111938


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4260491132736206}, 'log': {'val_loss': 1.4260491132736206}, 'val_loss': 1.4260491132736206}
Accuracy: 0.9381195902824402
Accuracy: 0.9992631673812866


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4239600896835327}, 'log': {'val_loss': 1.4239600896835327}, 'val_loss': 1.4239600896835327}
Accuracy: 0.9450031518936157
Accuracy: 0.9982054233551025


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4216580390930176}, 'log': {'val_loss': 1.4216580390930176}, 'val_loss': 1.4216580390930176}
Accuracy: 0.955704927444458
Accuracy: 0.9972368478775024


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4234206676483154}, 'log': {'val_loss': 1.4234206676483154}, 'val_loss': 1.4234206676483154}
Accuracy: 0.9493770599365234
Accuracy: 0.999061107635498


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4222362041473389}, 'log': {'val_loss': 1.4222362041473389}, 'val_loss': 1.4222362041473389}
Accuracy: 0.9527471661567688
Accuracy: 0.9988887906074524


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4225276708602905}, 'log': {'val_loss': 1.4225276708602905}, 'val_loss': 1.4225276708602905}
Accuracy: 0.9516895413398743
Accuracy: 0.9984549880027771


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4243286848068237}, 'log': {'val_loss': 1.4243286848068237}, 'val_loss': 1.4243286848068237}
Accuracy: 0.9444832801818848
Accuracy: 0.9990373253822327


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.424156904220581}, 'log': {'val_loss': 1.424156904220581}, 'val_loss': 1.424156904220581}
Accuracy: 0.9460607767105103
Accuracy: 0.9992750287055969


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4236855506896973}, 'log': {'val_loss': 1.4236855506896973}, 'val_loss': 1.4236855506896973}
Accuracy: 0.9479967951774597
Accuracy: 0.9986035823822021


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4259670972824097}, 'log': {'val_loss': 1.4259670972824097}, 'val_loss': 1.4259670972824097}
Accuracy: 0.9376355409622192
Accuracy: 0.9987224340438843


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4228447675704956}, 'log': {'val_loss': 1.4228447675704956}, 'val_loss': 1.4228447675704956}
Accuracy: 0.9509366154670715
Accuracy: 0.9992572069168091


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4240798950195312}, 'log': {'val_loss': 1.4240798950195312}, 'val_loss': 1.4240798950195312}
Accuracy: 0.9457560181617737
Accuracy: 0.9990730285644531


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4237595796585083}, 'log': {'val_loss': 1.4237595796585083}, 'val_loss': 1.4237595796585083}
Accuracy: 0.9457380771636963
Accuracy: 0.998817503452301


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.424829125404358}, 'log': {'val_loss': 1.424829125404358}, 'val_loss': 1.424829125404358}
Accuracy: 0.9438379406929016
Accuracy: 0.9989244341850281


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4235737323760986}, 'log': {'val_loss': 1.4235737323760986}, 'val_loss': 1.4235737323760986}
Accuracy: 0.9478175044059753
Accuracy: 0.9991859197616577


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.423141360282898}, 'log': {'val_loss': 1.423141360282898}, 'val_loss': 1.423141360282898}
Accuracy: 0.9480146765708923
Accuracy: 0.9986273050308228


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4218133687973022}, 'log': {'val_loss': 1.4218133687973022}, 'val_loss': 1.4218133687973022}
Accuracy: 0.9541453719139099
Accuracy: 0.9992631673812866


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4214719533920288}, 'log': {'val_loss': 1.4214719533920288}, 'val_loss': 1.4214719533920288}
Accuracy: 0.9546831846237183
Accuracy: 0.9982826709747314


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4235621690750122}, 'log': {'val_loss': 1.4235621690750122}, 'val_loss': 1.4235621690750122}
Accuracy: 0.9473335146903992
Accuracy: 0.999001681804657


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4238145351409912}, 'log': {'val_loss': 1.4238145351409912}, 'val_loss': 1.4238145351409912}
Accuracy: 0.9465089440345764
Accuracy: 0.9990076422691345


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4230380058288574}, 'log': {'val_loss': 1.4230380058288574}, 'val_loss': 1.4230380058288574}
Accuracy: 0.9510621428489685
Accuracy: 0.9992750287055969


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4228252172470093}, 'log': {'val_loss': 1.4228252172470093}, 'val_loss': 1.4228252172470093}
Accuracy: 0.949878990650177
Accuracy: 0.9986273050308228


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4264299869537354}, 'log': {'val_loss': 1.4264299869537354}, 'val_loss': 1.4264299869537354}
Accuracy: 0.9361476898193359
Accuracy: 0.9990967512130737


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.423315405845642}, 'log': {'val_loss': 1.423315405845642}, 'val_loss': 1.423315405845642}
Accuracy: 0.9476920366287231
Accuracy: 0.9988947510719299


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4236186742782593}, 'log': {'val_loss': 1.4236186742782593}, 'val_loss': 1.4236186742782593}
Accuracy: 0.9473693370819092
Accuracy: 0.9986748695373535


In [None]:
# Train LSTM_Model with occlusion mask [0]*8 + [1]*8 and log the run to Neptune.
# NOTE(review): the meaning of the occlusion mask is defined in model.py, so confirm it there.
seed_everything(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Check for CUDA
print(f'Using {device} device')

# The Neptune API token is read from the NEPTUNE_API_TOKEN environment variable
# rather than committed in the notebook. The token previously hard-coded here is
# exposed in version history and should be revoked.
neptune_logger2 = NeptuneLogger(
    project="NTLAB/BCM-activity-classification",
    api_token=os.environ.get("NEPTUNE_API_TOKEN"),
    source_files=["train_model.ipynb", "model.py", "BCM_dataset_v2.py"]
)

model = LSTM_Model('data/bcm/', window_size=3, stride=0.032, lstm_hidden_size=128,
                   occlusion=[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1])

# Pick the accelerator from the device detected above instead of hard-coding "gpu",
# so the cell also runs on a CPU-only machine. trainer.tune() is not called: with
# auto_lr_find and auto_scale_batch_size both False it performs no tuning.
trainer = Trainer(max_epochs=200, min_epochs=1, auto_lr_find=False, auto_scale_batch_size=False,
                  enable_checkpointing=False,
                  accelerator="gpu" if device == "cuda" else "cpu", devices=1,
                  logger=neptune_logger2)
try:
    trainer.fit(model)
finally:
    # A Neptune run stays open until it is stopped explicitly or the kernel exits,
    # so close it here even if training fails.
    neptune_logger2.experiment.stop()

In [None]:
# Export the most recently trained `model` to TorchScript.
# LightningModule.to_torchscript() scripts the module in eval mode and then restores
# its previous mode. A plain torch.jit.script(model) right after fit() would capture
# whatever mode the module was left in.
# Set EXPORT_PATH to a file path (e.g. "trained_models/oct_14") to also write the
# scripted model to disk. With None nothing is saved.
# (To save only the weights instead: torch.save(model.state_dict(), path).)
EXPORT_PATH = None

model_scripted = model.to_torchscript(file_path=EXPORT_PATH, method="script")