In [1]:
import os
import json

h5_directory = './data/h5'  # adapt if you used a different DOWNLOAD_PATH when running `make download_example`
import torch
import tempfile
import json
import random
import sys
sys.path.append("..")

from dosed.utils import Compose
from dosed.datasets import BalancedEventDataset as dataset
from dosed.models import DOSED3 as model
from dosed.models import TC
from dosed.datasets import get_train_validation_test
from dosed.trainers import trainers
from dosed.preprocessing import GaussianNoise, RescaleNormal, Invert
from dosed.functions import augmentations

from config_files import EDF_Configs

seed = 2019
configs = EDF_Configs.Config()
train, validation, test = get_train_validation_test(h5_directory,
                                                    percent_test=25,
                                                    percent_validation=33,
                                                    seed=seed)

print("Number of records train:", len(train))
print("Number of records validation:", len(validation))
print("Number of records test:", len(test))
window = 10  # window duration in seconds
ratio_positive = 0.5  # When creating the batch, sample containing at least one spindle will be drawn with that probability
# 一个训练批次中，每个样本包含至少一个脑电活动（spindle）的概率
fs = 32

signals = [
    {
        'h5_path': '/eeg_0',
        'fs': 64,
        'processing': {
            "type": "clip_and_normalize",
            "args": {
                    "min_value": -150,
                "max_value": 150,
            }
        }
    },
    {
        'h5_path': '/eeg_1',
        'fs': 64,
        'processing': {
            "type": "clip_and_normalize",
            "args": {
                    "min_value": -150,
                "max_value": 150,
            }
        }
    }
]

events = [
    {
        "name": "spindle",
        "h5_path": "spindle",
    },
]
dataset_parameters = {
    "h5_directory": h5_directory,#h5目录
    "signals": signals,#信号，包括两个通道
    "events": events,#事件，只有一个事件，即脑电活动（spindle）
    "window": window,#窗口，10s
    "fs": fs,#采样率32
    "ratio_positive": ratio_positive,#一个训练批次中，每个样本包含至少一个脑电活动（spindle）的概率
    "n_jobs": -1,  # Make use of parallel computing to extract and normalize signals from h5
    "cache_data": True,  # by default will store normalized signals extracted from h5 in h5_directory + "/.cache" directory
}

dataset_validation = dataset(records=validation, **dataset_parameters)
dataset_test = dataset(records=test, **dataset_parameters)

# for training add data augmentation
dataset_parameters_train = {
    "transformations": Compose([
        GaussianNoise(),
        RescaleNormal(),
        Invert(),
    ])
}


dataset_parameters_train.update(dataset_parameters)#更新dataset_parameters的键值对到dataset_parameters_train
dataset_train = dataset(records=train, **dataset_parameters_train)# inputsize=2,320
default_event_sizes = [0.7, 1, 1.3]#如何选取默认事件的大小，要检测的事件是脑电活动（spindle），大小约为1s
k_max = 5
kernel_size = 5
probability_dropout = 0.1
device = torch.device("cuda")
sampling_frequency = dataset_train.fs

net_parameters = {
    "detection_parameters": {
        "overlap_non_maximum_suppression": 0.5,
        "classification_threshold": 0.7
    },
    "default_event_sizes": [
        default_event_size * sampling_frequency
        for default_event_size in default_event_sizes
    ],
    "k_max": k_max,
    "kernel_size": kernel_size,
    "pdrop": probability_dropout,
    "fs": sampling_frequency,   # just used to print architecture info with right time
    "input_shape": dataset_train.input_shape,# 2，10s，采样率32hz
    "number_of_classes": dataset_train.number_of_classes,# 1
}
self_supervised_net = model(**net_parameters)
supervised_net = model(**net_parameters)
net = model(**net_parameters)
net = net.to(device)

temporal_contr_model = TC(configs, device).to(device)

optimizer_parameters = {
    "lr": 5e-3,
    "weight_decay": 1e-8,
}
loss_specs = {
    "type": "focal",
    "parameters": {
        "number_of_classes": dataset_train.number_of_classes,
        "device": device,
    }
}


epochs = 50


  from .autonotebook import tqdm as notebook_tqdm


Number of records train: 11
Number of records validation: 5
Number of records test: 5


In [2]:
self_supervised = torch.load('D:\Desktop\checkpoint_self_supervised_train[0%3A2].pth')
supervised = torch.load('D:\Desktop\checkpoint_train[0%3A2].pth')

In [3]:
self_supervised_net.load_state_dict(self_supervised['model_state_dict'])
supervised_net.load_state_dict(supervised['model_state_dict'])

<All keys matched successfully>

In [7]:
self_supervised

{'epoch': 50,
 'model_state_dict': OrderedDict([('blocks.0.conv_0.weight',
               tensor([[[-0.1502, -0.5412,  0.3422, -0.0143, -0.0295],
                        [ 0.1853,  0.6783, -0.0387,  0.3703,  0.0646]],
               
                       [[ 0.1973,  0.1151,  0.3434,  0.2516,  0.3903],
                        [ 0.1430,  0.4077,  0.2206,  0.1814,  0.3940]],
               
                       [[ 0.1839, -0.2446,  0.1052, -0.0257,  0.0969],
                        [ 0.1478, -0.2293,  0.7194, -0.3789,  0.1863]],
               
                       [[ 0.3270, -0.4924,  0.6544,  0.0314, -0.1171],
                        [-0.4021, -0.0881, -0.4421,  0.0205, -0.1321]],
               
                       [[ 0.4761,  0.3598,  0.3645,  0.3924,  0.3815],
                        [-0.0393, -0.0371,  0.1946, -0.0772,  0.1449]],
               
                       [[-0.3591, -0.4261, -0.5309, -0.4892, -0.3858],
                        [ 0.0457, -0.1704, -0.2404, -0.1362

In [8]:
from dosed.functions import compute_metrics_dataset
metrics_self_supervised = compute_metrics_dataset(
    self_supervised_net,
    dataset_test,
    self_supervised['best_threshold'],
)
metrics_supervised = compute_metrics_dataset(
    supervised_net,
    dataset_test,
    supervised['best_threshold_train'],
)

In [9]:
metrics_self_supervised

[{'precision': 0.18756524461325624,
  'recall': 0.25529592496560927,
  'f1': 0.20279879685421057}]

In [10]:
metrics_supervised

[{'precision': 0.33462632883286975,
  'recall': 0.3280181241504568,
  'f1': 0.2984054659905132}]

In [None]:
best_net_train = torch.load('best_net_train.pth')

In [None]:
import pickle

with open('minimum_example/hyperparameters.pkl', 'rb') as f:
    hyperparameters = pickle.load(f)

In [None]:
hyperparameters

In [None]:
from dosed.functions import compute_metrics_dataset
metrics_test = compute_metrics_dataset(
    best_net_train,
    dataset_test,
    hyperparameters['best_threshold_train']
)

In [None]:
metrics_test