In [1]:
# Notebook version of train.py, created to make model tuning faster.
# Data loading (cell 2) is decoupled from training (cell 3).

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np

from preprocessing.main import main, get_edf_files
from crnn_tf_v3 import CRNN
from dataloader import ANNEDataset

import json
import math
import time
import random

# Seed every RNG this notebook has in scope so a Restart & Run All is repeatable.
# NumPy is seeded alongside `random` and `torch` so that any NumPy-based
# randomness (here or in imported helpers that use the global NumPy RNG) is
# reproducible as well.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Timestamp taken once at start-up; it is embedded in checkpoint file names so
# that separate runs do not overwrite each other's files.
timestr = time.strftime("%Y%m%d-%H%M%S")

# Number of target classes: 3 keeps the labels as loaded; 2 makes the data cell
# fold label 2 into label 1 and selects the two-class loss weights.
N_CLASSES = 3


class CosineWithWarmupLR(LambdaLR):
    """Linear warm-up followed by cosine decay down to a floor.

    The schedule is expressed as a multiplicative factor handed to ``LambdaLR``,
    which applies it to the learning rate the optimizer was created with.
    ``max_lr`` should therefore equal that optimizer learning rate; it is only
    used here to turn ``min_lr`` into a ratio.

    Factor per epoch:
      * ``epoch < warmup_epochs``: ``epoch / warmup_epochs`` (linear ramp from 0).
        NOTE(review): the factor is 0 at epoch 0, so the very first epoch runs
        with a learning rate of 0 -- confirm this is intended.
      * ``warmup_epochs <= epoch < max_epochs``: half-cosine from 1 down towards
        ``min_lr / max_lr``.
      * ``epoch >= max_epochs``: constant ``min_lr / max_lr``.

    Args:
        optimizer: Wrapped optimizer.
        warmup_epochs: Number of warm-up epochs.
        max_epochs: Epoch at which the decay reaches the floor.
        max_lr: Peak learning rate (expected to match the optimizer's lr).
        min_lr: Floor learning rate reached at ``max_epochs`` and kept afterwards.

    Raises:
        ValueError: If ``max_lr`` is not positive.
    """

    def __init__(self, optimizer, warmup_epochs, max_epochs, max_lr=0.001, min_lr=0.0001):
        if max_lr <= 0:
            raise ValueError(f"max_lr must be positive, got {max_lr}")

        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        min_ratio = min_lr / max_lr

        def lr_factor(epoch):
            if epoch < warmup_epochs:
                return epoch / warmup_epochs
            if epoch < max_epochs:
                progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
                cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
                # Rescale the cosine from [0, 1] onto [min_ratio, 1]. Without this
                # the factor decays to ~0 and then jumps *up* to min_ratio at
                # max_epochs, so min_lr is not actually a lower bound.
                return min_ratio + (1.0 - min_ratio) * cosine
            return min_ratio

        super(CosineWithWarmupLR, self).__init__(optimizer, lr_lambda=lr_factor)


def save_strings_to_json(strings_list, filename):
    """Write ``strings_list`` to ``filename`` as JSON under the single key "strings".

    The file is overwritten if it already exists and is indented by 4 spaces.
    """
    with open(filename, "w") as handle:
        json.dump({"strings": strings_list}, handle, indent=4)

def train_model(model, optimizer, train_loaders, test_loaders, lr_scheduler, epochs=100, print_every=10):
    """Train ``model`` with class-weighted cross-entropy and evaluate it after every epoch.

    Each epoch iterates over every loader in ``train_loaders``, steps the
    learning-rate scheduler once, then averages ``evaluate`` results over
    ``test_loaders``. Whenever the averaged test loss improves, a TorchScript
    copy of the model is written to ``checkpoints/es_<timestr>.pt``.

    Args:
        model: Module called as ``model(inputs, inputs_freq, inputs_scl)``.
        optimizer: Optimizer updating the parameters of ``model``.
        train_loaders: Re-iterable collection of loaders yielding
            ``(inputs, inputs_freq, inputs_scl, labels, lengths)`` batches.
        test_loaders: Sized collection of loaders with the same batch layout.
        lr_scheduler: Scheduler stepped once per epoch.
        epochs: Maximum number of epochs.
        print_every: Print accuracy/loss every ``print_every`` epochs.

    Returns:
        ``(train_accs, test_accs, train_losses, test_losses, learning_rates)``,
        each a list with one entry per completed epoch.

    Raises:
        ValueError: If ``test_loaders`` is empty.
    """
    if len(test_loaders) == 0:
        # Fail before training instead of with a ZeroDivisionError after the first epoch.
        raise ValueError("test_loaders must contain at least one loader")

    if torch.cuda.is_available():
        print("Using cuda")
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Per-class loss weights: inverse of a hard-coded per-class figure raised to a
    # power. NOTE(review): 27/62/11 and 27/73 look like class percentages of the
    # data set -- confirm and keep in sync with the data.
    if N_CLASSES == 3:
        xentropy_weight = torch.tensor([1 / 27 ** 1.5, 1 / 62 ** 1.5, 1 / 11 ** 1.5]).to(device)
    else:
        xentropy_weight = torch.tensor([1 / 27 ** 1.75 , 1 / 73 ** 1.75]).to(device)

    criterion = nn.CrossEntropyLoss(weight=xentropy_weight)
    train_accs = []
    test_accs = []
    train_losses = []
    test_losses = []
    learning_rates = []

    # Early-stopping state. With patience = 1000 the stop can only trigger when
    # `epochs` exceeds 1000; the branch mainly drives best-model checkpointing.
    patience = 1000       # epochs without improvement tolerated before stopping
    counter = 0
    min_delta = 0         # required decrease in test loss to count as an improvement
    best_test_loss = float("inf")

    # Move the model to the selected device
    model.to(device)
    model.train()

    for epoch in range(epochs):
        xentropy_loss_total = 0.
        correct = 0.
        total = 0.
        total_trainer_len = 0
        for train_loader in train_loaders:
            for inputs, inputs_freq, inputs_scl, labels, lengths in train_loader:
                model.zero_grad()
                inputs = inputs.to(device)
                inputs_freq = inputs_freq.to(device)
                inputs_scl = inputs_scl.to(device)
                labels = labels.to(device)

                pred = model(inputs, inputs_freq, inputs_scl)
                xentropy_loss = criterion(pred, labels)
                xentropy_loss.backward()
                # Step once per batch. Gradients are cleared at the top of every
                # batch, so stepping only after the whole loader would discard
                # every batch but the last one.
                optimizer.step()

                xentropy_loss_total += xentropy_loss.item()

                # Running totals for the epoch's training accuracy
                predicted = pred.detach().argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            total_trainer_len += len(train_loader)

        # One scheduler step per epoch
        lr_scheduler.step()
        current_lr = lr_scheduler.get_last_lr()[0]
        learning_rates.append(current_lr)
        print(f"Learning Rate {current_lr}")

        accuracy = correct / total
        # Mean of the per-batch losses over all training batches of the epoch
        avg_xentropy_loss = xentropy_loss_total / total_trainer_len

        test_acc_sum = 0
        test_loss_sum = 0

        for test_loader in test_loaders:
            test_acc, test_loss = evaluate(model, test_loader, criterion, device)
            test_acc_sum += test_acc
            test_loss_sum += test_loss

        # Unweighted mean over loaders: every test loader counts equally,
        # regardless of how many samples it holds.
        test_acc = test_acc_sum / len(test_loaders)
        test_loss = test_loss_sum / len(test_loaders)

        if epoch % print_every == 0:
            print("Epoch {}, Train acc: {:.2f}%, Test acc: {:.2f}%".format(epoch, accuracy * 100, test_acc * 100))
            print("Epoch {}, Train loss: {:.2f}, Test loss: {:.2f}".format(epoch, avg_xentropy_loss, test_loss))

        train_accs.append(accuracy)
        test_accs.append(test_acc)
        train_losses.append(avg_xentropy_loss)
        test_losses.append(test_loss)

        # Checkpoint on improvement, otherwise count towards early stopping
        if test_loss < best_test_loss - min_delta:
            best_test_loss = test_loss
            model_scripted = torch.jit.script(model)
            model_scripted.save(f"checkpoints/es_{timestr}.pt")
            print("saved new model")
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping triggered.")
                break

    return train_accs, test_accs, train_losses, test_losses, learning_rates


def evaluate(model, loader, criterion, device):
    """Compute accuracy and mean per-batch loss of ``model`` over ``loader``.

    The model is put in eval mode for the duration of the call and afterwards
    returned to whichever mode (train/eval) it was in on entry.

    Args:
        model: Module called as ``model(inputs, inputs_freq, inputs_scl)``.
        loader: Sized iterable yielding
            ``(inputs, inputs_freq, inputs_scl, labels, lengths)`` batches.
        criterion: Loss callable applied as ``criterion(pred, labels)``.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        ``(val_acc, val_loss)``: fraction of correctly classified samples and
        the batch losses summed and divided by ``len(loader)``.
    """
    was_training = model.training
    model.eval()  # eval mode for the forward passes below
    correct = 0.
    total = 0.
    val_loss = 0.
    with torch.no_grad():
        for inputs, inputs_freq, inputs_scl, labels, lengths in loader:
            inputs = inputs.to(device)
            inputs_freq = inputs_freq.to(device)
            # Moved to the device like the other inputs (train_model does the same).
            inputs_scl = inputs_scl.to(device)
            labels = labels.to(device)

            pred = model(inputs, inputs_freq, inputs_scl)
            val_loss += criterion(pred, labels).item()

            total += labels.size(0)
            correct += (pred.argmax(dim=1) == labels).sum().item()

    val_acc = correct / total
    val_loss = val_loss / len(loader)
    # Restore the caller's mode rather than unconditionally forcing train mode.
    model.train(was_training)
    return val_acc, val_loss


def plot_grad_histograms(grad_list, epoch, init=False):
    """Show one histogram per entry of ``grad_list``, side by side in a single row.

    Panel ``i`` is titled for the weights between layers ``i`` and ``i + 1``; the
    title ends in "Init" when ``init`` is true and in the epoch number otherwise.
    """
    n_panels = len(grad_list)
    _, axes = plt.subplots(nrows=1, ncols=n_panels, figsize=(5 * n_panels, 5), squeeze=False)
    for idx, (panel, grad) in enumerate(zip(axes[0], grad_list)):
        panel.hist(grad)
        if init:
            heading = "Grads for Weights (Layer {}-{}) Init".format(idx, idx + 1)
        else:
            heading = "Grads for Weights (Layer {}-{}) Epoch {}".format(idx, idx + 1, epoch)
        panel.set_title(heading)
    plt.show()


def plot_act_histograms(act_list, epoch, init=False):
    """Show one histogram per entry of ``act_list``, side by side in a single row.

    Panel ``i`` is titled for layer ``i + 1``; the title ends in "Init" when
    ``init`` is true and in the epoch number otherwise.
    """
    n_panels = len(act_list)
    _, axes = plt.subplots(nrows=1, ncols=n_panels, figsize=(5 * n_panels, 5), squeeze=False)
    for idx, (panel, act) in enumerate(zip(axes[0], act_list)):
        panel.hist(act)
        if init:
            heading = "Activations for Layer {} Init".format(idx + 1)
        else:
            heading = "Activations for Layer {} Epoch {}".format(idx + 1, epoch)
        panel.set_title(heading)
    plt.show()

In [2]:
# Select the compute device: first GPU when CUDA is available, CPU otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

# Collect the recordings and hold 20 of them out for validation. The sample is
# drawn with the `random` module, so it follows the seed set in the first cell.
# The held-out paths are written to disk so the split can be looked up later.
train_list = get_edf_files("/mnt/Common/data")
validation_list = random.sample(train_list, 20)
print(validation_list)
save_strings_to_json(validation_list, "./validation.json")

# One DataLoader per recording; each loader yields the whole recording as a
# single batch (batch_size == number of samples in the recording).
train_dataloaders = []
valid_dataloaders = []
for path in train_list:
    # NOTE: a failure in main() aborts the whole cell; there is no per-file
    # error handling here.
    X, X_freq, X_scl, t = main(path)
    if N_CLASSES == 2:
        # Binary classification: fold label 2 into label 1.
        t = np.where(t == 2, 1, t)
    # X_scl from preprocessing is not used; a zero placeholder with one column
    # per sample is passed to the dataset in its place.
    scl_placeholder = np.zeros(shape = (len(X), 1))
    dataset = ANNEDataset(X, X_freq, scl_placeholder, t, device)
    size = len(X)
    loader = DataLoader(dataset=dataset, batch_size=size)
    if path in validation_list:
        valid_dataloaders.append(loader)
    else:
        train_dataloaders.append(loader)

['/mnt/Common/data/20-10-02-21_48_39.C1390.L1217.205-annotated.edf', '/mnt/Common/data/20-02-10-20_07_16.C823.L931.7-annotated.edf', '/mnt/Common/data/21-03-22-22_36_53.C1442.L1215.261-annotated.edf', '/mnt/Common/data/21-02-18-21_07_23.C1442.L1215.250-annotated.edf', '/mnt/Common/data/21-01-12-21_30_33.C1442.L1215.238-annotated.edf', '/mnt/Common/data/20-10-23-21_08_59.C1442.L1215.213-annotated.edf', '/mnt/Common/data/20-09-10-23_48_07.C1425.L1205.194-annotated.edf', '/mnt/Common/data/20-08-24-21_44_04.C1419.L1362.186-annotated.edf', '/mnt/Common/data/21-09-15-22_59_43.C1390.L3562.316-annotated.edf', '/mnt/Common/data/20-02-19-19_54_00.C823.L931.10-annotated.edf', '/mnt/Common/data/20-02-18-20_05_51.C823.L931.9-annotated.edf', '/mnt/Common/data/20-08-26-22_34_39.C1390.L1217.188-annotated.edf', '/mnt/Common/data/21-01-06-22_45_53.C1425.L1205.236-annotated.edf', '/mnt/Common/data/21-02-01-21_36_56.C1425.L1205.244-annotated.edf', '/mnt/Common/data/23-05-03-21_29_35.C4359.L3786.589-annota

PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-01-14-20_23_27.C823.L775.4-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-01-15-20_13_00.C823.L775.5-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-02-04-20_34_29.C823.L931.6-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-02-10-20_07_16.C823.L931.7-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-02-18-20_05_51.C823.L931.9-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-02-19-19_54_00.C823.L931.10-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-03-09-21_22_39.C823.L931.11-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-03-11-21_23_00.C823.L931.13-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-07-21-21_07_11.C1442.L1198.170-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-07-28-21_50_47.C1459.L1198.174-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-07-30-21_10_14.C1390.L1217.176-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-08-08-21_59_12.C1459.L1198.178-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-08-10-21_18_52.C1459.L1198.179-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-08-13-22_15_43.C1425.L1205.181-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-08-17-21_22_44.C1442.L1215.183-annotate

PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-08-24-21_44_04.C1419.L1362.186-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-08-26-22_34_39.C1390.L1217.188-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-09-02-21_00_21.C1390.L1217.189-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-09-09-20_52_35.C1390.L1217.193-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-09-10-23_48_07.C1425.L1205.194-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-09-30-21_56_35.C1419.L1362.203-annotate

PLM.events, SpO2.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-10-05-21_31_56.C1459.L1198.206-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


resp.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-10-07-22_00_34.C1425.L1205.207-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-10-08-22_12_23.C1419.L1362.208-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-10-13-21_31_56.C1442.L1215.209-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-10-14-21_31_33.C1459.L1198.210-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-10-15-21_57_07.C1425.L1205.211-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-10-23-21_08_59.C1442.L1215.213-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-10-28-21_32_43.C1459.L1198.214-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-11-03-22_20_46.C1442.L1215.216-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-11-04-22_03_00.C1459.L1198.217-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-11-09-21_34_42.C1425.L1205.218-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-11-11-22_10_54.C1442.L1215.220-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events, SpO2.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-11-17-22_12_58.C1459.L1198.221-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-11-19-22_30_28.C1425.L1205.222-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-11-24-21_39_36.C1459.L1198.223-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-11-26-22_10_11.C1442.L1215.224-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-12-01-22_20_13.C1425.L1205.225-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-12-02-22_32_08.C1459.L1198.226-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-12-08-21_39_09.C1425.L1205.228-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-12-09-22_34_06.C1442.L1215.229-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-12-10-22_33_54.C1459.L1198.230-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events, SpO2.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-12-15-22_02_08.C1425.L1205.231-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/20-12-16-21_57_22.C1442.L1215.232-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events, SpO2.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/20-12-17-22_24_00.C1459.L1198.233-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-01-04-21_42_22.C1459.L1198.234-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/21-01-06-22_17_14.C1442.L1215.235-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-01-06-22_45_53.C1425.L1205.236-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-01-11-21_39_04.C1425.L1205.237-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-01-12-21_30_33.C1442.L1215.238-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-01-25-21_33_14.C1425.L1205.242-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-02-01-21_36_56.C1425.L1205.244-annotate

PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/21-02-18-21_48_01.C1425.L1205.251-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-02-23-21_30_23.C1425.L1205.252-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-03-01-21_32_46.C1425.L1205.253-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-03-03-21_54_30.C1425.L1205.254-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-03-09-21_20_56.C1442.L1215.256-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-03-15-21_16_58.C1442.L1215.258-annotate

PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/21-06-01-21_57_58.C1390.L1215.283-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-06-07-21_00_39.C1390.L1205.284-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-06-08-22_17_20.C1390.L1215.285-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-06-09-21_31_23.C1390.L1215.286-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-06-15-21_40_44.C1390.L1215.287-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-06-16-21_34_33.C1390.L1215.288-annotate

PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/21-08-19-22_25_46.C1425.L1205.304-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/21-08-23-21_08_24.C1390.L1215.305-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-08-24-21_13_48.C1390.L1215.306-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-08-25-21_26_38.C1425.L1205.307-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-08-26-23_01_05.C1390.L1215.308-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-08-31-21_32_49.C1390.L3562.309-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-09-01-21_31_56.C1390.L3562.310-annotate

SpO2.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/21-09-14-21_30_45.C1390.L3562.315-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-09-15-22_59_43.C1390.L3562.316-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-09-16-21_46_18.C1390.L3562.317-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-09-28-22_11_46.C1425.L3562.319-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-09-29-21_52_36.C1425.L3562.320-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/21-10-07-22_11_28.C3884.L3562.322-annotate

PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/21-12-14-22_27_58.C3882.L3562.342-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/22-05-12-22_05_29.C4179.L3806.394-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/22-07-21-21_50_04.C4179.L3806.433-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/23-02-27-21_30_50.C4359.L3786.556-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events, SpO2.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/23-03-06-20_52_45.C4359.L3786.559-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/23-03-14-20_49_13.C4359.L3786.565-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/23-03-21-21_05_54.C4359.L3786.569-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/23-03-22-21_41_32.C4359.L3786.570-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/23-03-27-21_10_08.C4359.L3786.572-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/23-04-05-21_10_08.C4359.L3786.575-annotate

SpO2.events
  data = mne.io.read_raw_edf(path)


Extracting EDF parameters from /mnt/Common/data/23-04-20-21_40_19.C4359.L3786.580-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  edf_info["physical_min"] - edf_info["digital_min"] * edf_info["cal"]
  ch_data = ch_data * cal[orig_idx]


Extracting EDF parameters from /mnt/Common/data/23-04-24-20_54_29.C4181.L3766.582-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/23-04-26-20_42_32.C4181.L3766.584-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /mnt/Common/data/23-05-03-21_29_35.C4359.L3786.589-annotated.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


PLM.events
  data = mne.io.read_raw_edf(path)


In [3]:
# Build the network. The channel counts are read from X / X_freq as left behind
# by the data-loading cell, i.e. from the last recording that cell processed.
model = CRNN(num_classes=N_CLASSES, in_channels=X.shape[1], in_channels_f=X_freq.shape[1], in_channels_s=0, model='gru')

# To resume from a saved model instead:
# MODEL_PATH = ""
# model = torch.load(MODEL_PATH)

# To export the architecture for visualisation:
# dummy_input = torch.randn(4096, X.shape[1], 25*30)
# torch.onnx.export(model, dummy_input, "./model.onnx")

# Hyper-parameters
learning_rate = 0.0000675
epochs = 400

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.95), weight_decay=0.1)

# Warm up for 10 epochs, then decay until epoch 300; max_lr is set to the
# optimizer's learning rate, which the scheduler uses to turn min_lr into a ratio.
scheduler = CosineWithWarmupLR(optimizer, warmup_epochs=10, max_epochs=300, max_lr=learning_rate, min_lr=0.000001)
# Alternative schedule:
# scheduler = CyclicLR(optimizer, max_lr = 0.01, base_lr =0.0000001, step_size_up=15, step_size_down=20,
#                      gamma=0.85, cycle_momentum=False, mode="triangular2")

# Run the training loop
train_accs, test_accs, train_losses, test_losses, learning_rates = train_model(model, optimizer, train_dataloaders,
                                                                               valid_dataloaders,
                                                                               scheduler,
                                                                               epochs=epochs,
                                                                               print_every=1)

# Save the model as it is after the last epoch. This is separate from the
# best-validation-loss checkpoint that train_model writes as es_<timestr>.pt.
model_scripted = torch.jit.script(model)
model_scripted.save(f"checkpoints/model_{timestr}.pt")
print("Model Saved")

# Learning curves; labelled so the train and validation lines can be told apart.
plt.plot(train_losses, label="train")
plt.plot(test_losses, label="validation")
plt.title("loss")
plt.xlabel("epoch")
plt.legend()
plt.show()

plt.plot(train_accs, label="train")
plt.plot(test_accs, label="validation")
plt.title("accuracy")
plt.xlabel("epoch")
plt.legend()
plt.show()

# The learning-rate curve is plotted once (the duplicated plot call was removed).
plt.plot(learning_rates)
plt.title("learning_rates")
plt.xlabel("epoch")
plt.show()

Using cuda
Learning Rate 6.750000000000001e-06
Epoch 0, Train acc: 30.42%, Test acc: 29.11%
Epoch 0, Train loss: 1.25, Test loss: 1.19
saved new model
Learning Rate 1.3500000000000001e-05
Epoch 1, Train acc: 31.70%, Test acc: 34.66%
Epoch 1, Train loss: 1.21, Test loss: 1.15
saved new model
Learning Rate 2.025e-05
Epoch 2, Train acc: 31.33%, Test acc: 38.68%
Epoch 2, Train loss: 1.17, Test loss: 1.12
saved new model
Learning Rate 2.7000000000000002e-05
Epoch 3, Train acc: 29.29%, Test acc: 39.52%
Epoch 3, Train loss: 1.15, Test loss: 1.11
saved new model
Learning Rate 3.375e-05
Epoch 4, Train acc: 29.25%, Test acc: 40.76%
Epoch 4, Train loss: 1.13, Test loss: 1.10
saved new model
Learning Rate 4.05e-05
Epoch 5, Train acc: 29.36%, Test acc: 41.51%
Epoch 5, Train loss: 1.12, Test loss: 1.09
saved new model
Learning Rate 4.7249999999999997e-05
Epoch 6, Train acc: 29.43%, Test acc: 42.15%
Epoch 6, Train loss: 1.11, Test loss: 1.09
saved new model
Learning Rate 5.4000000000000005e-05
Epoch 