In [1]:

#!python3 -m pip install pytorch_lightning
#!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from model import LSTM_Model
from pytorch_lightning import seed_everything, LightningModule, Trainer
from torch import save
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader
import neptune.new as neptune
from pytorch_lightning.loggers import NeptuneLogger
import numpy as np
from pytorch_lightning.callbacks import ModelCheckpoint
from BCM_dataset_v2 import bcmDataset, concat_train_test_datasets
import glob
import os
import sys
from collect_dataset import*

Dataset version: 5


In [2]:

ft_datasets = create_ft_datasets_da()
                                                   

#loso_datasets = create_loso_datasets()


Fint-tuning datasets


train

0
bcm_behaviour_data_multi_subject/subject2/2022-11-09_14-34-16
bcm_behaviour_data_multi_subject/subject2/2022-11-09_14-52-7
bcm_behaviour_data_multi_subject/subject2/2022-11-09_14-15-6
bcm_behaviour_data_multi_subject/subject2/2022-11-09_13-56-0
bcm_behaviour_data_multi_subject/subject3/2022-11-09_15-47-1
bcm_behaviour_data_multi_subject/subject3/2022-11-09_16-26-13
bcm_behaviour_data_multi_subject/subject3/2022-11-09_15-27-14
bcm_behaviour_data_multi_subject/subject3/2022-11-09_16-7-9
bcm_behaviour_data_multi_subject/subject2/2022-11-09_14-15-6speed
bcm_behaviour_data_multi_subject/subject2/2022-11-09_13-56-0speed
bcm_behaviour_data_multi_subject/subject2/2022-11-09_14-34-16speed
bcm_behaviour_data_multi_subject/subject2/2022-11-09_14-52-7speed
bcm_behaviour_data_multi_subject/subject3/2022-11-09_15-47-1speed
bcm_behaviour_data_multi_subject/subject3/2022-11-09_16-26-13speed
bcm_behaviour_data_multi_subject/subject3/2022-11-09_15-27-14speed
bcm_behaviou

In [3]:
seed = 555
seed_everything(seed)
    
neptune_logger = NeptuneLogger(
    project="NTLAB/BCM-activity-classification",
    api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiIxYTA4NzcxMy1lYmQ2LTQ3NTctYjRhNC02Mzk1NjdjMWM0NmYifQ==",
    source_files=["train_model.ipynb", "model.py", "BCM_dataset_v2.py"]
)

checkpoint_callback = ModelCheckpoint(dirpath = 'checkpoints/', filename = 'latest_checkpoint')

model = LSTM_Model(ft_datasets[0], window_size = 3, stride = 1 ,lstm_hidden_size = 128, seed = seed, temporal_cutout = False)

trainer = Trainer(max_epochs=40, 
                min_epochs=1, 
                auto_lr_find=False, 
                auto_scale_batch_size=False,
                enable_checkpointing=False, 
                accelerator="gpu", 
                devices = 1, 
                logger=neptune_logger,
                callbacks = [checkpoint_callback])
trainer.tune(model)
trainer.fit(model)
        
neptune_logger.experiment["metadata/augmentation"].log("best")

# Find the latest checkpoint
list_of_checkpoints = glob.glob('checkpoints/*') 
latest_checkpoint = max(list_of_checkpoints, key = os.path.getctime) 
print(latest_checkpoint) # Sanity check
    
    
    
trainer_finetune = Trainer(max_epochs=80, 
                    min_epochs=100, 
                    auto_lr_find=False, 
                    auto_scale_batch_size=False,
                    enable_checkpointing=False, 
                    accelerator="gpu", 
                    devices = 1, 
                    logger=neptune_logger,
                    callbacks = [checkpoint_callback])
trainer_finetune.tune(model)
    
# Create a dataloader from the new subject specific data

dataset = ft_datasets[0]

dataset_finetune, _ = concat_train_test_datasets([dataset[2],dataset[2]], window_size = 3, stride = 1)
    
neptune_logger.experiment["metadata/train_set_length"].log(len(dataset_finetune))
    
finetuning_dataloader = DataLoader(dataset_finetune, batch_size=32, shuffle=True, num_workers = 8)
    
trainer_finetune.fit(model, finetuning_dataloader, ckpt_path=latest_checkpoint)
        
        
model_scripted = torch.jit.script(model) # Export to TorchScript
model_scripted.save(f"trained_models/ft_final_ft_file_2022-09-20_15-38-11") # Save
        
        
# Stop logging
neptune_logger.experiment.stop()
        

Global seed set to 555


_____________________
<class 'list'>
_____________________
Validation set
bcm_behaviour_data_multi_subject/subject1/2022-09-20_14-58-39/0.npy
bcm_behaviour_data_multi_subject/subject1/2022-09-20_14-58-39/1.npy
bcm_behaviour_data_multi_subject/subject1/2022-09-20_14-58-39/2.npy
bcm_behaviour_data_multi_subject/subject1/2022-09-20_14-58-39/3.npy
bcm_behaviour_data_multi_subject/subject1/2022-09-20_14-58-39/4.npy
bcm_behaviour_data_multi_subject/subject1/2022-09-20_15-18-27/0.npy
bcm_behaviour_data_multi_subject/subject1/2022-09-20_15-18-27/1.npy
bcm_behaviour_data_multi_subject/subject1/2022-09-20_15-18-27/2.npy
bcm_behaviour_data_multi_subject/subject1/2022-09-20_15-18-27/3.npy
bcm_behaviour_data_multi_subject/subject1/2022-09-20_15-18-27/4.npy
bcm_behaviour_data_multi_subject/subject1/2022-09-20_15-57-37/0.npy
bcm_behaviour_data_multi_subject/subject1/2022-09-20_15-57-37/1.npy
bcm_behaviour_data_multi_subject/subject1/2022-09-20_15-57-37/2.npy
bcm_behaviour_data_multi_subject/subject1/

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


https://app.neptune.ai/NTLAB/BCM-activity-classification/e/BCMAC-1365
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.



  | Name         | Type             | Params
--------------------------------------------------
0 | lstm         | LSTM             | 149 K 
1 | flatten      | Flatten          | 0     
2 | fc           | Linear           | 1.3 K 
3 | output       | Sigmoid          | 0     
4 | sm           | Softmax          | 0     
5 | loss         | CrossEntropyLoss | 0     
6 | accuracy     | Accuracy         | 0     
7 | val_accuracy | Accuracy         | 0     
--------------------------------------------------
150 K     Trainable params
0         Non-trainable params
150 K     Total params
0.603     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.5980808734893799}, 'log': {'val_loss': 1.5980808734893799}, 'val_loss': 1.5980808734893799}
Accuracy: 0.1875


  value = torch.tensor(value, device=self.device)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4719443321228027}, 'log': {'val_loss': 1.4719443321228027}, 'val_loss': 1.4719443321228027}
Accuracy: 0.7218198776245117
Accuracy: 0.7470157742500305


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4542781114578247}, 'log': {'val_loss': 1.4542781114578247}, 'val_loss': 1.4542781114578247}
Accuracy: 0.7758588790893555
Accuracy: 0.7685090899467468


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.457258701324463}, 'log': {'val_loss': 1.457258701324463}, 'val_loss': 1.457258701324463}
Accuracy: 0.7645310759544373
Accuracy: 0.7778555750846863


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.463701844215393}, 'log': {'val_loss': 1.463701844215393}, 'val_loss': 1.463701844215393}
Accuracy: 0.7478179931640625
Accuracy: 0.7818812131881714


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4581410884857178}, 'log': {'val_loss': 1.4581410884857178}, 'val_loss': 1.4581410884857178}
Accuracy: 0.7792015075683594
Accuracy: 0.7812861204147339


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4537214040756226}, 'log': {'val_loss': 1.4537214040756226}, 'val_loss': 1.4537214040756226}
Accuracy: 0.7740018367767334
Accuracy: 0.7832463979721069


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4596443176269531}, 'log': {'val_loss': 1.4596443176269531}, 'val_loss': 1.4596443176269531}
Accuracy: 0.7483751177787781
Accuracy: 0.7841565608978271


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4646663665771484}, 'log': {'val_loss': 1.4646663665771484}, 'val_loss': 1.4646663665771484}
Accuracy: 0.7199628353118896
Accuracy: 0.7864319086074829


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4614661931991577}, 'log': {'val_loss': 1.4614661931991577}, 'val_loss': 1.4614661931991577}
Accuracy: 0.7277622818946838
Accuracy: 0.7876920700073242


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4757494926452637}, 'log': {'val_loss': 1.4757494926452637}, 'val_loss': 1.4757494926452637}
Accuracy: 0.6939647197723389
Accuracy: 0.7964784502983093


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.462219476699829}, 'log': {'val_loss': 1.462219476699829}, 'val_loss': 1.462219476699829}
Accuracy: 0.7376044392585754
Accuracy: 0.7933979630470276


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4702645540237427}, 'log': {'val_loss': 1.4702645540237427}, 'val_loss': 1.4702645540237427}
Accuracy: 0.7244197130203247
Accuracy: 0.8053698539733887


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4572168588638306}, 'log': {'val_loss': 1.4572168588638306}, 'val_loss': 1.4572168588638306}
Accuracy: 0.763602614402771
Accuracy: 0.8195470571517944


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.463768720626831}, 'log': {'val_loss': 1.463768720626831}, 'val_loss': 1.463768720626831}
Accuracy: 0.7259052991867065
Accuracy: 0.8083803057670593


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4675379991531372}, 'log': {'val_loss': 1.4675379991531372}, 'val_loss': 1.4675379991531372}
Accuracy: 0.7062209844589233
Accuracy: 0.8147863149642944


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4517829418182373}, 'log': {'val_loss': 1.4517829418182373}, 'val_loss': 1.4517829418182373}
Accuracy: 0.8376973271369934
Accuracy: 0.9038401246070862


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4541027545928955}, 'log': {'val_loss': 1.4541027545928955}, 'val_loss': 1.4541027545928955}
Accuracy: 0.8246982097625732
Accuracy: 0.9491370916366577


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4774640798568726}, 'log': {'val_loss': 1.4774640798568726}, 'val_loss': 1.4774640798568726}
Accuracy: 0.735561728477478
Accuracy: 0.9519375562667847


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.479353427886963}, 'log': {'val_loss': 1.479353427886963}, 'val_loss': 1.479353427886963}
Accuracy: 0.7225626707077026
Accuracy: 0.9534428119659424


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4567867517471313}, 'log': {'val_loss': 1.4567867517471313}, 'val_loss': 1.4567867517471313}
Accuracy: 0.8098421692848206
Accuracy: 0.9538278579711914


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4595519304275513}, 'log': {'val_loss': 1.4595519304275513}, 'val_loss': 1.4595519304275513}
Accuracy: 0.7981429696083069
Accuracy: 0.9549829959869385


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4758808612823486}, 'log': {'val_loss': 1.4758808612823486}, 'val_loss': 1.4758808612823486}
Accuracy: 0.7335190176963806
Accuracy: 0.9564182162284851


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4568053483963013}, 'log': {'val_loss': 1.4568053483963013}, 'val_loss': 1.4568053483963013}
Accuracy: 0.8148560523986816
Accuracy: 0.9572233557701111


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4533865451812744}, 'log': {'val_loss': 1.4533865451812744}, 'val_loss': 1.4533865451812744}
Accuracy: 0.8280408382415771
Accuracy: 0.9564182162284851


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4616121053695679}, 'log': {'val_loss': 1.4616121053695679}, 'val_loss': 1.4616121053695679}
Accuracy: 0.7842153906822205
Accuracy: 0.9577484726905823


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4599261283874512}, 'log': {'val_loss': 1.4599261283874512}, 'val_loss': 1.4599261283874512}
Accuracy: 0.7968431115150452
Accuracy: 0.9589036107063293


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4620989561080933}, 'log': {'val_loss': 1.4620989561080933}, 'val_loss': 1.4620989561080933}
Accuracy: 0.7897864580154419
Accuracy: 0.9585185647010803


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4738988876342773}, 'log': {'val_loss': 1.4738988876342773}, 'val_loss': 1.4738988876342773}
Accuracy: 0.7392757534980774
Accuracy: 0.9606888890266418


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4536787271499634}, 'log': {'val_loss': 1.4536787271499634}, 'val_loss': 1.4536787271499634}
Accuracy: 0.821169912815094
Accuracy: 0.9593936800956726


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4524742364883423}, 'log': {'val_loss': 1.4524742364883423}, 'val_loss': 1.4524742364883423}
Accuracy: 0.8367688059806824
Accuracy: 0.9613890051841736


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4685289859771729}, 'log': {'val_loss': 1.4685289859771729}, 'val_loss': 1.4685289859771729}
Accuracy: 0.7628598213195801
Accuracy: 0.9606539011001587


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4612867832183838}, 'log': {'val_loss': 1.4612867832183838}, 'val_loss': 1.4612867832183838}
Accuracy: 0.7948004007339478
Accuracy: 0.9621241092681885


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4621622562408447}, 'log': {'val_loss': 1.4621622562408447}, 'val_loss': 1.4621622562408447}
Accuracy: 0.7961002588272095
Accuracy: 0.9622991681098938


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4681544303894043}, 'log': {'val_loss': 1.4681544303894043}, 'val_loss': 1.4681544303894043}
Accuracy: 0.7682451009750366
Accuracy: 0.9620891213417053


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4602898359298706}, 'log': {'val_loss': 1.4602898359298706}, 'val_loss': 1.4602898359298706}
Accuracy: 0.8048282265663147
Accuracy: 0.9654146432876587


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4758131504058838}, 'log': {'val_loss': 1.4758131504058838}, 'val_loss': 1.4758131504058838}
Accuracy: 0.7455896139144897
Accuracy: 0.9646444916725159


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4678961038589478}, 'log': {'val_loss': 1.4678961038589478}, 'val_loss': 1.4678961038589478}
Accuracy: 0.7717734575271606
Accuracy: 0.9667448401451111


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.4594018459320068}, 'log': {'val_loss': 1.4594018459320068}, 'val_loss': 1.4594018459320068}
Accuracy: 0.8042711019515991
Accuracy: 0.9671298861503601


Validation: 0it [00:00, ?it/s]

{'progress_bar': {'val_loss': 1.462050199508667}, 'log': {'val_loss': 1.462050199508667}, 'val_loss': 1.462050199508667}
Accuracy: 0.7877437472343445
Accuracy: 0.9657996892929077


Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=40` reached.


{'progress_bar': {'val_loss': 1.4588369131088257}, 'log': {'val_loss': 1.4588369131088257}, 'val_loss': 1.4588369131088257}
Accuracy: 0.8064995408058167
Accuracy: 0.9652045965194702


ValueError: max() arg is an empty sequence

In [None]:


#model_scripted = torch.jit.script(model) # Export to TorchScript
#model_scripted.save("trained_models/oct_31") # Save