## Model stuffs

In [13]:
import torch.nn as nn
import torch

In [15]:
class TransformerEncoderWithCLS(nn.Module):
    def __init__(self, embedder, embedding_size, num_heads, num_layers, num_classes):
        super(TransformerEncoderWithCLS, self).__init__()

        self.embedder = embedder

        self.positional_embedding = nn.Embedding(10 * 250, embedding_size) # Positional embedding

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embedding_size, num_heads),
            num_layers
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_size))  # Learnable <cls> token
        self.activation = nn.ReLU()
        self.classifier = nn.Linear(embedding_size, num_classes)  # Classification layer

    def forward(self, x):
        x = self.embedder(x)  # Embed the input sequence
        batch_size, seq_len, embedding_size = x.size()

        # Add positional embedding
        positions = torch.arange(0, seq_len).expand(batch_size, seq_len).to(device)
        x = x + self.positional_embedding(positions)

        cls_tokens = self.cls_token.expand(batch_size, 1, embedding_size)  # Shape: (batch_size, 1, embedding_size)
        x_cls = torch.cat([cls_tokens, x], dim=1)  # Shape: (batch_size, seq_length + 1, embedding_size)

        output = self.transformer_encoder(x_cls)  # Apply TransformerEncoder
        
        out_cls = output[:, 0]  # Return the representation of the <cls> token

        out_cls = self.activation(out_cls)  # Apply activation function

        return self.classifier(out_cls)  # Classify the <cls> token representation


In [16]:
from einops import rearrange
class InceptionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(InceptionBlock, self).__init__()

        # 1x1 convolution branch
        self.branch1x1 = nn.Conv1d(in_channels, out_channels[0], kernel_size=1)

        # 3x3 convolution branch
        self.branch3x3 = nn.Sequential(
            nn.Conv1d(in_channels, out_channels[1], kernel_size=1),
            nn.Conv1d(out_channels[1], out_channels[2], kernel_size=3, padding=1)
        )

        # 5x5 convolution branch
        self.branch5x5 = nn.Sequential(
            nn.Conv1d(in_channels, out_channels[3], kernel_size=1),
            nn.Conv1d(out_channels[3], out_channels[4], kernel_size=5, padding=2)
        )

        # Max pooling branch
        self.branch_pool = nn.Sequential(
            nn.MaxPool1d(kernel_size=3, stride=1, padding=1),
            nn.Conv1d(in_channels, out_channels[5], kernel_size=1)
        )

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        branch5x5 = self.branch5x5(x)
        branch_pool = self.branch_pool(x)

        # Concatenate the branch outputs along the channel dimension
        outputs = [branch1x1, branch3x3, branch5x5, branch_pool]
        outputs = torch.cat(outputs, dim=1)
        outputs = rearrange(outputs, 'b c s -> b s c')  # Convert back to (batch_size, seq_len, channels)

        return outputs

In [18]:
incept = InceptionBlock(1, [16, 8, 16, 8, 16, 16])
batch_size = 64
freq = 250
seq_len = 10 * freq
in_size = 1
x = torch.randn(batch_size, in_size, seq_len) 
embedding_size = 64

transformer = TransformerEncoderWithCLS(incept, embedding_size, 8, 2, 5)


In [19]:
from torchsummary import summary

summary(transformer, (1, seq_len))

Layer (type:depth-idx)                        Output Shape              Param #
├─InceptionBlock: 1-1                         [-1, 2500, 64]            --
|    └─Conv1d: 2-1                            [-1, 16, 2500]            32
|    └─Sequential: 2-2                        [-1, 16, 2500]            --
|    |    └─Conv1d: 3-1                       [-1, 8, 2500]             16
|    |    └─Conv1d: 3-2                       [-1, 16, 2500]            400
|    └─Sequential: 2-3                        [-1, 16, 2500]            --
|    |    └─Conv1d: 3-3                       [-1, 8, 2500]             16
|    |    └─Conv1d: 3-4                       [-1, 16, 2500]            656
|    └─Sequential: 2-4                        [-1, 16, 2500]            --
|    |    └─MaxPool1d: 3-5                    [-1, 1, 2500]             --
|    |    └─Conv1d: 3-6                       [-1, 16, 2500]            32
├─Embedding: 1-2                              [-1, 2500, 64]            160,000
├─Transformer

Layer (type:depth-idx)                        Output Shape              Param #
├─InceptionBlock: 1-1                         [-1, 2500, 64]            --
|    └─Conv1d: 2-1                            [-1, 16, 2500]            32
|    └─Sequential: 2-2                        [-1, 16, 2500]            --
|    |    └─Conv1d: 3-1                       [-1, 8, 2500]             16
|    |    └─Conv1d: 3-2                       [-1, 16, 2500]            400
|    └─Sequential: 2-3                        [-1, 16, 2500]            --
|    |    └─Conv1d: 3-3                       [-1, 8, 2500]             16
|    |    └─Conv1d: 3-4                       [-1, 16, 2500]            656
|    └─Sequential: 2-4                        [-1, 16, 2500]            --
|    |    └─MaxPool1d: 3-5                    [-1, 1, 2500]             --
|    |    └─Conv1d: 3-6                       [-1, 16, 2500]            32
├─Embedding: 1-2                              [-1, 2500, 64]            160,000
├─Transformer

In [None]:
transformer = transformer.to(device)
x = x.to(device)
out = transformer(x)
out.shape

torch.Size([64, 5])

## Data loader stuffs

In [1]:
import random
from torch.utils.data import Dataset, DataLoader, Sampler

class SleepStageDataset(Dataset):
    def __init__(self, subjects, data, labels, seq_len, freq):
        '''
        This class takes in a list of subject, a path to the MASS directory 
        and reads the files associated with the given subjects as well as the sleep stage annotations
        '''
        super().__init__()

        self.seq_len = seq_len

        # Get the sleep stage labels
        self.full_signal = []
        self.full_labels = []

        self.subject_list = []
        for subject in subjects:
            if subject not in data.keys():
                print(f"Subject {subject} not found in the pretraining dataset")
                continue

            # Get the signal for the given subject
            signal = torch.tensor(data[subject]['signal'], dtype=torch.float)

            # Get all the labels for the given subject
            label = torch.tensor([SleepStageDataset.get_labels().index(lab) for lab in labels[subject]]).type(torch.uint8)

            # Repeat the labels freq times to match the signal using a pytorch function
            label = torch.tensor(label).repeat_interleave(freq) 

            # Add some '?' padding at the end to make sure the length of signal and label match
            missing = len(signal) - len(label)
            label = torch.cat([label, torch.full((missing, ), SleepStageDataset.get_labels().index('?')).type(torch.uint8)])

            # Make sure that the signal and the labels are the same length
            assert len(signal) == len(label)

            # Add to full signal and full label
            self.full_labels.append(label)
            self.full_signal.append(signal)
            del data[subject], signal, label
        
        self.full_signal = torch.cat(self.full_signal)
        self.full_labels = torch.cat(self.full_labels)

    @staticmethod
    def get_labels():
        return ['1', '2', '3', 'R', 'W', '?']

    def __getitem__(self, index):
        # Get data and label at the given index
        signal = self.full_signal[index - self.seq_len:index]
        label = self.full_labels[index]
        signal = signal.unsqueeze(0)

        return signal, label.type(torch.LongTensor)

    def __len__(self):
        return len(self.full_signal)
    
class SleepStageSampler(Sampler):
    def __init__(self, dataset, seq_len, nb_batch_per_epoch, batch_size):
        self.dataset = dataset
        self.seq_len = seq_len
        self.max_len = len(dataset)
        self.limit = nb_batch_per_epoch * batch_size
        self.nb_batch_per_epoch = nb_batch_per_epoch

    def __iter__(self):
        for i in range(self.limit): 
            while True:
                index = random.randint(self.seq_len, self.max_len - 1)
                # Make sure that the label at the end of the window is not '?'
                label = self.dataset.full_labels[index]
                if label != SleepStageDataset.get_labels().index('?'):
                    break
            yield index

    def __len__(self):
        return self.nb_batch_per_epoch

In [2]:
import time

import numpy as np

from portiloop_software.portiloop_python.ANN.utils import get_configs

from portiloop_software.portiloop_python.ANN.data.mass_data import read_pretraining_dataset, read_sleep_staging_labels, read_spindle_trains_labels
import torch


experiment_name = 'test_sleep_staging'
seed = 42

config = get_configs(experiment_name, False, seed)
# config['nb_conv_layers'] = 4
# config['hidden_size'] = 64
# config['nb_rnn_layers'] = 4

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Run some testing on subject 1
# Load the data
labels = read_spindle_trains_labels(config['old_dataset'])
ss_labels = read_sleep_staging_labels(config['path_dataset'])
# for index, patient_id in enumerate(ss_labels.keys()):


data = read_pretraining_dataset(config['MASS_dir'])

In [7]:
dataset = SleepStageDataset(list(data.keys()), data, ss_labels, seq_len, freq)
loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    sampler=SleepStageSampler(dataset, seq_len, 1000, batch_size),
    num_workers=0,
    pin_memory=True,
    drop_last=True
)

  label = torch.tensor(label).repeat_interleave(freq)


In [33]:
x, y = next(iter(loader))
x.shape, y.shape

(torch.Size([64, 1, 2500]), torch.Size([64]))

In [27]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score

def evaluate(model, data_loader, criterion):
    model.eval()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in data_loader:
            X, y = batch
            X = X.to(device)
            y = y.to(device)
            y_true.extend(y.cpu().numpy())
            y_pred.extend(torch.argmax(model(X), dim=1).cpu().numpy())
            loss = criterion(model(X), y)
            total_loss += loss.item()
            total_correct += torch.sum(torch.argmax(model(X), dim=1) == y)
            total_samples += len(X)
    loss  = total_loss / len(data_loader)
    accuracy = total_correct / total_samples
    conf_mat = confusion_matrix(y_true, y_pred)
    class_report = classification_report(y_true, y_pred, target_names=SleepStageDataset.get_labels()[:-1])
    return loss, accuracy, conf_mat, class_report

In [40]:
import torch.optim as optim

def train_epoch(model, data_loader, loss_function, optimizer, device):
    model.train()
    epoch_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    for inputs, labels in data_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        _, predicted = torch.max(outputs, dim=1)

        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        correct_predictions += (predicted == labels).sum().item()
        total_predictions += labels.size(0)

    epoch_loss /= len(data_loader)
    accuracy = correct_predictions / total_predictions

    return epoch_loss, accuracy

In [41]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(transformer.parameters(), lr=1e-5)

max_epochs = 100
for i in range(max_epochs):
    train_loss, train_acc = train_epoch(transformer, loader, criterion, optimizer, device)
    test_loss, test_accuracy, conf_mat, class_report = evaluate(transformer, loader, criterion)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [42]:
train_loss

95.96601337591807

In [48]:
train_acc

0.44903125

In [43]:
test_loss

94.30902067820232

In [44]:
test_accuracy

tensor(0.4695, device='cuda:0')

In [47]:
print(class_report)

              precision    recall  f1-score   support

           1       0.00      0.00      0.00      5842
           2       0.47      1.00      0.64     30051
           3       0.00      0.00      0.00      8026
           R       0.00      0.00      0.00     10858
           W       0.00      0.00      0.00      9223

    accuracy                           0.47     64000
   macro avg       0.09      0.20      0.13     64000
weighted avg       0.22      0.47      0.30     64000

