In [1]:
import time

from torch.optim import lr_scheduler

from model import ECGDataset, ECG_Classifier_LSTM
from torch.utils.data import DataLoader, random_split
from MyEDFImports import load_all_data, load_all_labels, remove_ecg_artifacts, three_stages_transform
import torch
import torch.nn as nn
import os
from tempfile import TemporaryDirectory

In [2]:
all_unprepared_data = load_all_data()
all_unprepared_labels = load_all_labels()

print(len(all_unprepared_data))
filtered_data, filter_labels = remove_ecg_artifacts(all_unprepared_data, all_unprepared_labels)
print(len(filtered_data))
# going from 6 labels to three Wake, Nrem, REM
filter_labels = three_stages_transform(filter_labels)

# this data right now is filtered but not normlized
# TODO: normalize

Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/CN223100.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/CP229110.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/CX230050.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/DG220020.edf...
EDF file detected
Setting chann

  raw = mne.io.read_raw_edf(path + "//" + name)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/LA216100.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/LM230010.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_states/Jean-Pol_repaired_headers/TK221110.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Extracting EDF parameters from /home/tadeusz/Desktop/Tadeusz/mgr_sleep_

  raw = mne.io.read_raw_edf(path + "//" + name)


<RawEDF | CN223100.edf, 1 x 15611000 (31222.0 s), ~6 kB, data not loaded> with 1561 windows
<RawEDF | CP229110.edf, 1 x 20078000 (40156.0 s), ~6 kB, data not loaded> with 2007 windows
<RawEDF | CX230050.edf, 1 x 17981000 (35962.0 s), ~6 kB, data not loaded> with 1798 windows
<RawEDF | DG220020.edf, 1 x 17756000 (35512.0 s), ~6 kB, data not loaded> with 1775 windows
<RawEDF | DO223050.edf, 1 x 18066500 (36133.0 s), ~6 kB, data not loaded> with 1806 windows
<RawEDF | LA216100.edf, 1 x 16333500 (32667.0 s), ~6 kB, data not loaded> with 1633 windows
<RawEDF | LM230010.edf, 1 x 17246500 (34493.0 s), ~6 kB, data not loaded> with 1724 windows
<RawEDF | TK221110.edf, 1 x 15991000 (31982.0 s), ~6 kB, data not loaded> with 1599 windows
<RawEDF | VC209100.edf, 1 x 18434500 (36869.0 s), ~6 kB, data not loaded> with 1843 windows
<RawEDF | VP214110.edf, 1 x 17252500 (34505.0 s), ~6 kB, data not loaded> with 1725 windows
<RawEDF | WD224010.edf, 1 x 17774000 (35548.0 s), ~6 kB, data not loaded> with 1

In [3]:
# Hyperparameters
input_size = 64  # Adjust based on the output from conv layers
hidden_size = 128
num_layers = 2
num_classes = 3  # [Wake, NonREM, REM]
learning_rate = 0.1
batch_size=4

In [4]:
stages = ['train', 'val']
dataset_all = ECGDataset(filtered_data, filter_labels)
train_data, test_data = random_split(dataset_all, [0.8, 0.2])
datasets = {'train': train_data, 'val': test_data}
dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in
               stages}
dataset_sizes = {x: len(datasets[x]) for x in ['train', 'val']}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ECG_Classifier_LSTM(input_size, hidden_size, num_layers, num_classes)
model = model.to(device)  # Move the model to GPU

In [5]:
# inbalanced dataset 4 to 1 so adding weights to criterion
crit_weitghts = torch.tensor([4., 1., 4.]).to(device)
criterion = nn.CrossEntropyLoss(weight=crit_weitghts)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
model = ECG_Classifier_LSTM(input_size=input_size, hidden_size=hidden_size, num_classes=num_classes, num_layers=num_layers)
model = model.to(device)

In [18]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    # Create a temporary directory to save training checkpoints
    with TemporaryDirectory() as tempdir:
        best_model_name = f'best_model_params_{type(criterion).__name__}_{type(optimizer).__name__}_{num_epochs}.pt'
        best_model_params_path = os.path.join(tempdir, best_model_name)

        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch + 1}/{num_epochs}')
            print('-' * 10)

            for phase in ['train', 'val']:
                if phase == 'train':
                    print('in training')
                    model.train()  # Set model to training mode
                else:
                    print('in validation')
                    model.eval()  # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0
            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                # No idea why now for inputs why I need to transfer it to a float from a double
                inputs = inputs.unsqueeze(1).to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):

                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        print('updating shit')
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                print('updating sched')
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), best_model_params_path)
        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:4f}')
        model.load_state_dict(torch.load(best_model_params_path))
    return model

In [19]:
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x7fb85f016340>,
 'val': <torch.utils.data.dataloader.DataLoader at 0x7fb85f016160>}

In [22]:
model = train_model(model, criterion, optimizer, exp_lr_scheduler)

Epoch 1/25
----------
in training
in validation
val Loss: 1.0960 Acc: 0.6510
Epoch 2/25
----------
in training
in validation
val Loss: 1.0959 Acc: 0.6510
Epoch 3/25
----------
in training
in validation
val Loss: 1.0961 Acc: 0.6510
Epoch 4/25
----------
in training
in validation
val Loss: 1.0961 Acc: 0.6510
Epoch 5/25
----------
in training
in validation
val Loss: 1.0962 Acc: 0.6510
Epoch 6/25
----------
in training
in validation
val Loss: 1.0962 Acc: 0.6510
Epoch 7/25
----------
in training
in validation
val Loss: 1.0962 Acc: 0.6510
Epoch 8/25
----------
in training
in validation
val Loss: 1.0962 Acc: 0.6510
Epoch 9/25
----------
in training
in validation
val Loss: 1.0962 Acc: 0.6510
Epoch 10/25
----------
in training
in validation
val Loss: 1.0962 Acc: 0.6510
Epoch 11/25
----------
in training
in validation
val Loss: 1.0962 Acc: 0.6510
Epoch 12/25
----------
in training
in validation
val Loss: 1.0962 Acc: 0.6510
Epoch 13/25
----------
in training
in validation
val Loss: 1.0960 Acc: 0.

In [9]:
inputs, labels = next(iter(dataloaders['train']))
inputs.unsqueeze(1).shape

torch.Size([4, 1, 10000])

In [14]:
len(dataloaders['train'])

3004