In [55]:
import time

import numpy as np

from portiloop_software.portiloop_python.ANN.utils import get_configs

from portiloop_software.portiloop_python.ANN.data.mass_data import read_pretraining_dataset, read_sleep_staging_labels, read_spindle_trains_labels
from portiloop_software.portiloop_python.ANN.models.lstm import get_trained_model
import torch


experiment_name = 'test_adapation'
seed = 42
model_path = 'no_att_baseline'
subject_id = '01-02-0019'

config = get_configs(experiment_name, False, seed)
# config['nb_conv_layers'] = 4
# config['hidden_size'] = 64
# config['nb_rnn_layers'] = 4

# Load the model
net = get_trained_model(config, config['path_models'] / model_path)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Run some testing on subject 1
# Load the data
labels = read_spindle_trains_labels(config['old_dataset'])
ss_labels = read_sleep_staging_labels(config['path_dataset'])
# for index, patient_id in enumerate(ss_labels.keys()):

config['subject_id'] = subject_id

data = read_pretraining_dataset(config['MASS_dir'], patients_to_keep=[subject_id])

558


In [12]:
import torch

class DataBuffer:
    """
    A class to get the data in the right format for the model from a stream of data
    """
    def __init__(self, seq_len, window_size, seq_stride):
        self.seq_len = seq_len
        self.window_size = window_size
        self.seq_stride = seq_stride

        # Compute the total number of points to keep in memory as the buffer
        buffer_size = (seq_len - 1) * seq_stride + window_size
        self.data = torch.zeros(buffer_size, dtype=torch.float32)

    def step(self, point):
        # Shift the data
        self.data[:-1] = self.data.clone()[1:]
        self.data[-1] = point
        current_data = self.data.clone().unfold(0, self.window_size, self.seq_stride)
        current_data = current_data.unsqueeze(0).unsqueeze(2)
        return current_data

In [137]:
from torch.utils.data import Dataset
from portiloop_software.portiloop_python.ANN.utils import RMSScorer
import random

class AdaptationSampler(torch.utils.data.Sampler):
    def __init__(self, dataset):
        """
        Sample random items from a dataset
        """
        self.dataset = dataset

    def __iter__(self):
        """
        Returns an iterator over the dataset
        """
        while True:
            toss = random.random()
            if toss > 0.5:
                # Get a random index from the spindle indexes
                yield 1
            else:
                yield 0


class AdapatationDatasetRMS(Dataset):
    """
    Dataset for adaptation using the RMS score metric. 
    """
    def __init__(self, seq_len, window_size, replacement=False, candidate_threshold=0.5, rms_threshold=2.5, real_threshold=0.95, buffer_time=1250):
        self.buffer_time = buffer_time # 5 seconds: time before and after each window for the RMS
        self.buffer_size = max(seq_len * window_size, self.buffer_time) + self.buffer_time
        self.data = []
        self.labels = []
        self.seq_len = seq_len
        self.window_size = window_size
        self.candidate_threshold = candidate_threshold
        self.rms_threshold = rms_threshold
        self.real_threshold = real_threshold
        self.scorer = RMSScorer()

        self.replacement = replacement

        self.interval = 100 # 400 ms

        # List of samples
        self.positive_samples = []
        self.negative_samples = []
        # self.spindle_indexes = []
        # self.non_spindle_indexes = []

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        '''
        Returns a sample from the dataset. If replacement is True, the sample is not removed from the dataset.
        '''
        # Choose a random index in one of the lists
        index_in_list = random.randint(0, len(self.positive_samples)-1) if index == 1 else random.randint(0, len(self.negative_samples)-1)

        list_to_sample = self.positive_samples if index == 1 else self.negative_samples

        sample = list_to_sample[index_in_list] if self.replacement else list_to_sample.pop(index_in_list)

        return sample, index
    
    def spindle_percentage(self):
        sum_spindles = sum([i[1] for i in self.samples if i[1] == 1])
        return sum_spindles / len(self)

    def add_sample(self, sample, label):
        """
        Takes a sample and adds it to the dataset.
        """
        # self.samples.append((sample, label))
        # if label == 1:
        #     self.spindle_indexes.append(len(self.samples) - 1)
        # else:
        #     self.non_spindle_indexes.append(len(self.samples) - 1)
        if label == 1:
            self.positive_samples.append(sample)
        else:
            self.negative_samples.append(sample)

    def step(self, point, label):
        # Shift the data
        self.data.append(point)
        self.labels.append(label)
    
    def compute(self, max_time=-1):
        """
        Compute the RMS score for the whole buffer and add to the dataset
        """
        # start = time.time()
        # Filter the whole buffer
        self.filtered_data = self.scorer.filter(self.data)

        labels = np.array(self.labels)
    
        # Get all indexes where the labels are above the candidate threshold
        indexes = np.where(labels >= self.candidate_threshold)[0]

        # # Remove the indexes where the previous is less than 400 ms before
        # indexes = indexes[np.insert(np.diff(indexes) >= self.interval, 0, True)]

        # Get the RMS score for each index
        rms_scores = []
        rms_scores_detect = []
        for i, index in enumerate(indexes):
            if index - self.buffer_time < 0 or index + self.buffer_time >= len(self.data):
                continue
            score = self.scorer.get_score(self.filtered_data[index - self.buffer_time:index + self.buffer_time], filter=False)
            rms_scores.append(score)
            if labels[index] >= self.real_threshold and i == 0:
                rms_scores_detect.append(score)
            elif labels[index] >= self.real_threshold and index - indexes[i-1] >= self.interval:
                rms_scores_detect.append(score)
        rms_scores = np.array(rms_scores)

        # Get the indexes where the RMS score is above the threshold
        indexes_positive = indexes[np.where(rms_scores >= self.rms_threshold)[0]]

        # Get the indexes where the RMS score is below the threshold
        indexes_negative = indexes[np.where(rms_scores < self.rms_threshold)[0]]

        # Add the positive samples to the dataset
        for index in indexes_positive:
            if index - self.seq_len * self.window_size < 0:
                continue
            ordered_data = torch.tensor(self.data[index-(self.seq_len * self.window_size):index]).clone().reshape(self.seq_len, 1, self.window_size)
            self.add_sample(ordered_data, 1)

        # Add the negative samples to the dataset
        for index in indexes_negative:
            if index - self.seq_len * self.window_size < 0:
                continue
            ordered_data = torch.tensor(self.data[index-(self.seq_len * self.window_size):index]).clone().reshape(self.seq_len, 1, self.window_size)
            self.add_sample(ordered_data, 0)

        # end = time.time()
        # print(f"Time to compute: {end - start}")

        self.data = []
        self.labels = []
        
        return rms_scores_detect
        

    def num_samples(self):
        '''
        Returns the minimum between the positive and the negative samples
        '''
        return min(len(self.positive_samples), len(self.negative_samples)) 


In [138]:
import copy
from torch import nn
from torch import optim
from tqdm import tqdm


train = True

buffer = DataBuffer(config['seq_len'], config['window_size'], config['seq_stride'])
dataset = AdapatationDatasetRMS(config['seq_len'], config['window_size'])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], sampler=AdaptationSampler(dataset))

net_copy = copy.deepcopy(net)
net_copy = net_copy.to(device)
net_copy = net_copy.train()

h1 = torch.zeros((config['nb_rnn_layers'], 1, config['hidden_size']), device=device)

# Initialize optimizer and criterion
optimizer = optim.AdamW(net_copy.parameters(), lr=config['lr_adam'], weight_decay=config['adam_w'])
criterion = nn.BCELoss(reduction='none')

rms_scores = []
real_indexes = []

signal = data[subject_id]['signal']
output = 0

for index, point in enumerate(tqdm(signal)):

    current_data = buffer.step(point)

    if index < config['seq_len'] * config['window_size']:
        continue

    with torch.no_grad():
                
        if index % config['seq_stride'] == 0:
            # Get the output of the network if we have waited the seq stride steps
            output, h1, _ = net_copy(current_data.to(device), h1)
            output = output.squeeze(-1)
            output = output[-1].item()

        # # Put the data into the dataset
        dataset.step(point, output)

        # real_indexes.append(output)

        if index % 10000 == 0:
            rms_scores += dataset.compute()

    # If we have enough data, we train the network
    if dataset.num_samples() > config['batch_size'] and train:

        print(f"Training, {dataset.num_samples()} samples")

        train_sample, train_label = next(iter(dataloader))
        train_sample = train_sample.to(device)
        train_label = train_label.to(device)

        optimizer.zero_grad()

        # Get the output of the network
        h_zero = torch.zeros((config['nb_rnn_layers'], train_sample.size(0), config['hidden_size']), device=device)
        output, _, _ = net_copy(train_sample, h_zero)
        
        # Compute the loss
        output = output.squeeze(-1)
        train_label = train_label.squeeze(-1).float()
        loss = criterion(output, train_label)
        
        loss = loss.mean()
        loss.backward()
        optimizer.step()


  0%|          | 20000/6598000 [00:02<12:52, 8516.09it/s] 

Training, 77 samples





RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same

In [98]:
rms_scores = np.array(rms_scores)
rms_scores.mean(), rms_scores.std()


(3.459782155121597, 2.504105168078916)

In [99]:
sum(rms_scores >= 3) / len(rms_scores)

0.4883720930232558

In [84]:
real_indexes = np.append(np.zeros(2700), real_indexes)

In [87]:
labels_real_model = real_indexes >= 0.95

In [90]:
np.unique(labels_model, return_counts=True)

(array([0., 1.]), array([6407992,  190008]))

In [89]:
np.unique(labels_real_model, return_counts=True)

(array([False,  True]), array([6406270,  191730]))

In [37]:
import time

import numpy as np
from portiloop_software.portiloop_python.ANN.adaptation_training import run_adaptation
from portiloop_software.portiloop_python.ANN.data.mass_data import SingleSubjectDataset, SingleSubjectSampler, read_pretraining_dataset
from portiloop_software.portiloop_python.ANN.utils import get_metrics


config['subject_id'] = subject_id

data = read_pretraining_dataset(config['MASS_dir'], patients_to_keep=[subject_id])

assert subject_id in data.keys(), 'Subject not in the dataset'
assert subject_id in labels.keys(), 'Subject not in the dataset'

dataset = SingleSubjectDataset(config['subject_id'], data=data, labels=labels, config=config, ss_labels=ss_labels)  
sampler = SingleSubjectSampler(len(dataset), config['seq_stride'])
dataloader = torch.utils.data.DataLoader(
    dataset, 
    batch_size=1, 
    sampler=sampler, 
    num_workers=0)

# Run the adaptation
start = time.time()
# run_adaptation(dataloader, net, device, config)
output_total, window_labels_total, loss, net_copy = run_adaptation(dataloader, net, device, config, train)
end = time.time()
print('Time: ', end - start)

print("Distribution of the predictions:")
print(np.unique(output_total.cpu().numpy(), return_counts=True))
print("Distribution of the labels:")
print(np.unique(window_labels_total.cpu().numpy(), return_counts=True))

# Get the metrics
acc, f1, precision, recall = get_metrics(output_total, window_labels_total)

Number of spindles: 73
Number of spindle labels: 12410
len of full signal: 6598000
Length of sampler: 157094
Doing index: 0/157094
Doing index: 10000/157094
Doing index: 20000/157094
Doing index: 30000/157094
Doing index: 40000/157094
Doing index: 50000/157094
Doing index: 60000/157094
Doing index: 70000/157094
Doing index: 80000/157094
Doing index: 90000/157094
Doing index: 100000/157094
Doing index: 110000/157094
Doing index: 120000/157094
Doing index: 130000/157094
Doing index: 140000/157094
Doing index: 150000/157094
Time:  239.07770156860352
Distribution of the predictions:
(array([0., 1.], dtype=float32), array([152570,   4524]))
Distribution of the labels:
(array([0., 1.], dtype=float32), array([156799,    295]))


In [41]:
labels_model = np.repeat(output_total.cpu().numpy(), 42)
labels_model = np.append(labels_model, np.zeros(52))

In [45]:
old_indexes = labels_model == 1

array([False, False, False, ..., False, False, False])

In [48]:
old_indexes = np.where(old_indexes)[0]

In [50]:
old_indexes = old_indexes[np.insert(np.diff(old_indexes) > 100, 0, True)]

In [52]:
old_indexes.shape

(1695,)

In [43]:
labels_model = np.diff(labels_model)
labels_model[labels_model == -1] = 0

(6598000,)

In [25]:
rms_scores = torch.tensor(rms_scores)
rms_scores.mean(), rms_scores.std()

  rms_scores = torch.tensor(rms_scores)


(tensor(3.4736, dtype=torch.float64), tensor(2.4870, dtype=torch.float64))

In [27]:
sum(rms_scores > 2) / len(rms_scores)

tensor(0.6690)