In [None]:
# Identify whether a CUDA-enabled GPU is available
import torch

import os
import copy
import mne
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mne.datasets.sleep_physionet.age import fetch_data
%matplotlib inline

# Decide once which device the model and batches will live on, then report it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device != 'cuda':
    print('No GPU found. Training will be carried out on CPU, which might be '
          'slower.\n\nIf running on Google Colab, you can request a GPU runtime by'
          ' clicking\n`Runtime/Change runtime type` in the top bar menu, then '
          'selecting \'GPU\'\nunder \'Hardware accelerator\'.')
else:
    print('CUDA-enabled GPU found. Training should be faster.')

# Only surface MNE errors so the cell outputs are not flooded with log messages.
mne.set_log_level('ERROR')

In [None]:
import glob
import os



# Local directory where the dataset is stored
local_dir = "/kaggle/input/sleep-edf-and-apnea/sleep-edf-database-expanded-1.0.0/sleep-edf-database-expanded-1.0.0/sleep-cassette"

# All .edf files, sorted so that each recording's PSG file is directly
# followed by its hypnogram (e.g. 'SC4001E0-PSG.edf', 'SC4001EC-Hypnogram.edf').
edf_files = sorted(glob.glob(os.path.join(local_dir, "*.edf")))

list_small = []  # the PSG path still waiting for its hypnogram (empty if none)
list_big = []    # completed [psg_path, hypnogram_path] pairs

for edf_file in edf_files:
    # os.path.basename keeps the parsing independent of the path separator.
    base_name = os.path.basename(edf_file)
    file_name = base_name.split(".")[0]
    signal_type = file_name.split("-")[-1]

    if signal_type == 'PSG':
        # Start a fresh pair. This also drops a previous PSG that never got a
        # hypnogram, so it cannot leak into the next recording's pair.
        list_small = [edf_file]

    elif signal_type == 'Hypnogram':
        # Pair only when a PSG is pending and both files belong to the same
        # recording. The first six characters (e.g. 'SC4001') are assumed to
        # identify the recording, matching the subject/night parsing done on
        # the file name later in this notebook.
        if list_small and os.path.basename(list_small[0])[:6] == base_name[:6]:
            list_big.append([list_small[0], edf_file])
        list_small = []

if not list_big:
    raise FileNotFoundError(
        f"No PSG/Hypnogram pairs found in {local_dir!r}; check the dataset path.")

# Display the first pair as a sanity check
print(list_big[0])
fnames1=list_big

In [None]:
# Use every PSG/hypnogram pair collected above as the file list for the experiment.
fnames = fnames1
print(len(fnames))  # number of [PSG, hypnogram] pairs that will be loaded

In [None]:
## DATASET FOR EXPERIMENT

# fnames = fnames_cst  # ONLY IMPORT CST FILES  (78 subjects - 153 files)
# # fnames = fnames_tlm  # ONLY IMPORT TLM FILES  (44 Subjects - 88 Files)
# # fnames = fnames_cst+fnames_tml  # IMPORT BOTH CST AND TLM FILES (122 SUBJECTS)
# print(len(fnames))


'''We import the EEG signal from 20 minutes before sleep onset until 1 minute after the last sleep epoch.
   However, some recordings do not have enough data before sleep onset; for those, both margins are set to 0 minutes and the recording is not cropped. '''

def load_sleep_physionet_raw(raw_fname, annot_fname, load_eeg_only=True,
                             crop_wake_mins_start=20, crop_wake_mins_end=1):
    """Load one PSG recording with its hypnogram, keeping only 'EEG Fpz-Cz'.

    The recording is cropped to the sleep period plus a wake margin. If that
    window cannot be applied (no sleep-stage annotation, or the window falls
    outside the recording), the full, uncropped recording is returned instead.

    Parameters
    ----------
    raw_fname : str
        Path to the PSG ``.edf`` file. Its base name is parsed for the subject
        number (characters 3-4) and the recording number (character 5).
    annot_fname : str
        Path to the hypnogram ``.edf`` file holding the annotations.
    load_eeg_only : bool
        If True, the non-EEG channels listed in ``mapping`` are excluded while
        reading the file.
    crop_wake_mins_start : float
        Minutes kept before the first sleep-stage annotation. A value <= 0
        disables cropping altogether.
    crop_wake_mins_end : float
        Minutes kept after the end of the last sleep-stage annotation.

    Returns
    -------
    raw : mne.io.Raw
        The loaded recording with annotations attached, the channel renamed
        from 'EEG Fpz-Cz' to 'Fpz-Cz', ``info['subject_info']`` set to
        ``{'id': subj_nb}`` and ``info['description']`` to ``'rec_id:<rec_nb>'``.
    """
    mapping = {'EOG horizontal': 'eog',
               'Resp oro-nasal': 'misc',
               'EMG submental': 'misc',
               'Temp rectal': 'misc',
               'Event marker': 'misc'}
    exclude = list(mapping) if load_eeg_only else ()
    raw = mne.io.read_raw_edf(raw_fname, exclude=exclude, preload=True)
    # channels_to_pick = ['EEG Fpz-Cz','EEG Pz-Oz']
    channels_to_pick = ['EEG Fpz-Cz']
    raw.pick_channels(channels_to_pick)
    annots = mne.read_annotations(annot_fname)
    raw.set_annotations(annots, emit_warning=False)

    if not load_eeg_only:
        # Only retype channels that survived the pick above; asking MNE to
        # retype channels that are no longer present raises an error.
        present_types = {ch: ch_type for ch, ch_type in mapping.items()
                         if ch in raw.ch_names}
        if present_types:
            raw.set_channel_types(present_types)

    if crop_wake_mins_start > 0:
        try:
            # Sleep-stage descriptions end in '1'-'4' or 'R'.
            mask = [x[-1] in ['1', '2', '3', '4', 'R']
                    for x in annots.description]
            sleep_event_inds = np.where(mask)[0]
            first_sleep = annots[int(sleep_event_inds[0])]
            last_sleep = annots[int(sleep_event_inds[-1])]
            tmin = first_sleep['onset'] - crop_wake_mins_start * 60
            tmax = (last_sleep['onset'] + last_sleep['duration']
                    + crop_wake_mins_end * 60)
            raw.crop(tmin=tmin, tmax=tmax)
        except (IndexError, ValueError):
            # IndexError: no sleep-stage annotation at all.
            # ValueError: the requested window lies outside the recording.
            # Either way, keep the whole recording (zero margins, no crop).
            pass

    ch_names = {i: i.replace('EEG ', '')
                for i in raw.ch_names if 'EEG' in i}
    mne.rename_channels(raw.info, ch_names)
    basename = os.path.basename(raw_fname)
    subj_nb, rec_nb = int(basename[3:5]), int(basename[5])

    raw.info['subject_info'] = {'id': subj_nb}
    raw.info['description'] = f"rec_id:{rec_nb}"

    return raw

# Load every recording; a file that fails to load is reported and skipped.
raws = []   # successfully loaded recordings
error = []  # indices (into fnames) of the pairs that failed to load
for er,f in enumerate(fnames):
    try:
        # f is a [psg_path, hypnogram_path] pair
        raw = load_sleep_physionet_raw(f[0], f[1])
        raws.append(raw)
    except Exception as e:
        print(f"Ignoring error: {e}")
        error.append(er)
    # Progress message every 10 files
    if er%10==0 and er!=0:
        print(f"iteration completed:{er}")
print(f"Total files loaded: {len(raws)}")
print(f"Total error in anotation:{len(error)}")

In [None]:
from torch.utils.data import Dataset, ConcatDataset

def extract_epochs(raw, chunk_duration=30., epoch_length=5):
    """Cut a recording into labelled windows of consecutive sleep epochs.

    Annotations are first turned into fixed-length epochs of
    ``chunk_duration`` seconds. Every run of ``epoch_length`` consecutive
    epochs is then concatenated along the time axis into one window, which is
    labelled with the sleep stage of its last epoch.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording with sleep-stage annotations attached.
    chunk_duration : float
        Length of one epoch, in seconds.
    epoch_length : int
        Number of consecutive epochs concatenated into one window.

    Returns
    -------
    data : np.ndarray
        One entry per window; each window concatenates ``epoch_length`` epochs
        along the last axis. Empty when the recording yields fewer than
        ``epoch_length`` epochs.
    labels : np.ndarray
        Stage label of the last epoch of each window:
        0=W, 1=N1, 2=N2, 3=N3 (stages 3 and 4 merged), 4=R.
    """
    annotation_desc_2_event_id = {
        'Sleep stage W': 1,
        'Sleep stage 1': 2,
        'Sleep stage 2': 3,
        'Sleep stage 3': 4,
        'Sleep stage 4': 4,
        'Sleep stage R': 5}
    events, _ = mne.events_from_annotations(
        raw, event_id=annotation_desc_2_event_id,
        chunk_duration=chunk_duration)

    all_event_ids = {
        'Sleep stage W': 1,
        'Sleep stage 1': 2,
        'Sleep stage 2': 3,
        'Sleep stage 3/4': 4,
        'Sleep stage R': 5}
    # Only request the stages that actually occur in this recording, so that
    # a recording lacking any one stage can still be epoched.
    present_codes = set(events[:, 2].tolist())
    event_id = {name: code for name, code in all_event_ids.items()
                if code in present_codes}

    # Stop one sample short so each epoch spans exactly chunk_duration seconds.
    tmax = chunk_duration - (1. / raw.info['sfreq'])
    picks = mne.pick_types(raw.info, eeg=True, eog=True)
    epochs = mne.Epochs(raw=raw, events=events, picks=picks, preload=True,
                        event_id=event_id, tmin=0., tmax=tmax, baseline=None)
    # Event codes are 1-based; labels are 0-based.
    data, labels = epochs.get_data(), epochs.events[:, 2] - 1

    # Sliding window with a stride of one epoch.
    n_windows = len(data) - epoch_length + 1
    combined_data = [np.concatenate(data[i:i + epoch_length], axis=-1)
                     for i in range(n_windows)]
    combined_labels = [labels[i + epoch_length - 1] for i in range(n_windows)]
    return np.array(combined_data), np.array(combined_labels)



class EpochsDataset(Dataset):
    """Torch dataset over the windows and labels of a single recording.

    Parameters
    ----------
    epochs_data : sequence
        One entry per example; must have the same length as ``epochs_labels``.
    epochs_labels : sequence
        Label of each example.
    subj_nb, rec_nb : optional
        Subject and recording identifiers, stored as attributes.
    transform : callable, optional
        Applied to an example's data each time it is fetched.
    """
    def __init__(self, epochs_data, epochs_labels, subj_nb=None, 
                 rec_nb=None, transform=None):
        assert len(epochs_data) == len(epochs_labels)
        self.epochs_data, self.epochs_labels = epochs_data, epochs_labels
        self.subj_nb, self.rec_nb = subj_nb, rec_nb
        self.transform = transform

    def __len__(self):
        """Number of examples."""
        return len(self.epochs_labels)

    def __getitem__(self, idx):
        """Return ``(data, label)``; data is a tensor with a new leading axis."""
        sample = self.epochs_data[idx]
        target = self.epochs_labels[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        # Prepend a singleton axis in front of the example's own dimensions.
        return torch.as_tensor(sample[None, ...]), target
def scale(X):
    """Standardize each row of ``X`` along axis 1 to zero mean and unit std.

    The input is left untouched: this function is used as a dataset transform
    on slices of the stored data, so an in-place update would silently
    overwrite the dataset. Rows with zero standard deviation (flat signal) are
    returned as zeros instead of NaN/inf.

    Parameters
    ----------
    X : array-like with at least 2 dimensions
        Statistics are computed along axis 1.

    Returns
    -------
    np.ndarray
        A new, standardized array of the same shape.
    """
    X = np.asarray(X)
    centered = X - np.mean(X, axis=1, keepdims=True)
    std = np.std(centered, axis=1, keepdims=True)
    std[std == 0] = 1  # avoid dividing by zero on constant rows
    return centered / std




import os

# One EpochsDataset per recording. Recordings that fail are listed in error_2.
all_datasets, error_2, valid_sub = [], [], []

for i, raw in enumerate(raws):
    try:
        # Windows and labels for this recording
        X_data, y_labels = extract_epochs(raw)

        # Subject / recording numbers come from the file name,
        # e.g. 'SC4001E0-PSG.edf' -> basename[3:5] == '00', basename[5] == '1'
        basename = os.path.basename(raw.filenames[0])
        subj_nb, rec_nb = int(basename[3:5]), int(basename[5])

        # Each example is standardized by `scale` when it is fetched
        dataset = EpochsDataset(X_data, y_labels, subj_nb=subj_nb,
                                rec_nb=rec_nb, transform=scale)
        all_datasets.append(dataset)
        valid_sub.append(subj_nb)
    except Exception as e:
        print(f"found error in {i}")
        print(e)
        error_2.append(i)

    # Stop after the recording with index 200
    if i == 200:
        break
    # Progress message every 10 recordings
    if i % 10 == 0 and i != 0:
        print(f"Total Recording Loaded:{i}")

print(f"Total error file: {len(error_2)}")

# Single dataset spanning all recordings; the per-recording datasets stay
# reachable through `dataset.datasets`.
dataset = ConcatDataset(all_datasets)


In [None]:
from sklearn.model_selection import LeavePGroupsOut

def pick_recordings(dataset, subj_rec_nbs):
    """Split a ConcatDataset into the requested recordings and all the others.

    Parameters
    ----------
    dataset : ConcatDataset
        Its sub-datasets must expose ``subj_nb`` and ``rec_nb`` attributes.
    subj_rec_nbs : iterable of (subj_nb, rec_nb)
        Recordings to pick.

    Returns
    -------
    pick_ds : ConcatDataset
        The matching recordings, in the order of ``subj_rec_nbs``.
    remaining_ds : ConcatDataset or None
        Every other recording, or None when nothing is left.
    """
    recordings = dataset.datasets
    pick_idx = [
        pos
        for subj_nb, rec_nb in subj_rec_nbs
        for pos, ds in enumerate(recordings)
        if (ds.subj_nb == subj_nb) and (ds.rec_nb == rec_nb)
    ]
    print(f"test files: {pick_idx}")

    # Positions that were not picked, in ascending order
    remaining_idx = np.setdiff1d(range(len(recordings)), pick_idx)
    print(f" train+val {remaining_idx}")

    pick_ds = ConcatDataset([recordings[pos] for pos in pick_idx])
    if len(remaining_idx) == 0:
        return pick_ds, None
    remaining_ds = ConcatDataset([recordings[pos] for pos in remaining_idx])
    return pick_ds, remaining_ds
    
def train_test_split(dataset, n_groups, split_by='subj_nb'):
    """Hold out ``n_groups`` groups of recordings from a ConcatDataset.

    Recordings are grouped by the attribute named ``split_by``. Only the first
    split produced by ``LeavePGroupsOut`` is used.

    Returns
    -------
    train_ds, test_ds : ConcatDataset
        The kept recordings and the held-out recordings.
    """
    recordings = dataset.datasets
    groups = [getattr(rec, split_by) for rec in recordings]
    print(groups)

    splitter = LeavePGroupsOut(n_groups)
    train_idx, test_idx = next(splitter.split(X=groups, groups=groups))
    print(len(train_idx))
    print(len(test_idx))

    def _subset(indices):
        # Rebuild a ConcatDataset from the recordings at the given positions.
        return ConcatDataset([recordings[pos] for pos in indices])

    train_ds = _subset(train_idx)
    test_ds = _subset(test_idx)
    return train_ds, test_ds

# --- Test set: (subject, recording) pairs for a block of consecutive subjects ---
result = []
fst_sub = dataset.datasets[0].subj_nb
total_sub=14   ## Size of the subject-id range scanned for the test set
for i in range(fst_sub,fst_sub+total_sub):
    # Subjects 2 and 4 are never placed in the test set, so the test set holds
    # fewer than `total_sub` subjects whenever they fall inside the range.
    if i in [2, 4]:
        continue
    # Recording numbers 1 and 2 of each subject
    for j in range(1, 3):
        result.append((i, j))
print(result)
test_recs = list(result)
test_ds, train_ds = pick_recordings(dataset, test_recs)

# --- Validation set: hold out roughly 1/split of the remaining subjects ---
split = 25
# True division: `1 // split` is 0 for any split > 1, which silently pinned the
# validation set to a single subject.
k = 1 / split
# k=0.02

# LeavePGroupsOut counts groups (subjects), so the fraction is applied to the
# number of distinct subjects rather than to the number of recordings.
n_train_subjects = len({ds.subj_nb for ds in train_ds.datasets})
n_subjects_valid = max(1, int(n_train_subjects * k))
train_ds, valid_ds = train_test_split(train_ds, n_subjects_valid, split_by='subj_nb')
print('Number of examples in each set:')
print(f'Number of Training Segments: {len(train_ds)}')
print(f'Number of Validation Segments: {len(valid_ds)}')
print(f'Number of Test Segments: {len(test_ds)}')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Pie chart of the sleep-stage distribution in the training set.
classes_mapping = {0: 'W', 1: 'N1', 2: 'N2', 3: 'N3', 4: 'R'}

# Read the labels straight from the per-recording datasets. Iterating over
# `train_ds` would load and transform every example just to discard the data.
train_labels = np.concatenate([ds.epochs_labels for ds in train_ds.datasets])
y_train = pd.Series(train_labels).map(classes_mapping)

class_order = [classes_mapping[i] for i in range(len(classes_mapping))]
value_counts = y_train.value_counts()
# fill_value=0 keeps a stage that never occurs as an empty slice instead of NaN
ordered_value_counts = value_counts.reindex(class_order, fill_value=0)
colors = sns.color_palette("BuGn_r", len(class_order))
fig, ax = plt.subplots()

ax.pie(ordered_value_counts,
       labels=class_order,
       colors=colors,
       autopct='%1.1f%%',
       startangle=90,
       counterclock=False)
plt.show()


In [None]:
import torch
from torch import nn

class SleepStagerSiam(nn.Module):
    """Multi-scale temporal CNN followed by densely connected LSTMs.

    Three parallel temporal convolutions (kernel lengths of 1x, 2x and 4x
    ``time_conv_size_s``) are each followed by ReLU and max-pooling, then
    concatenated along the filter axis. The resulting sequence goes through
    four LSTMs, where LSTM 3 and LSTM 4 receive the concatenated outputs of
    all earlier LSTMs. The flattened output of LSTM 4 feeds a dropout + linear
    classifier.

    Parameters
    ----------
    n_channels : int
        Number of input channels. When > 1, a spatial convolution across
        channels is applied first.
    sfreq : float
        Sampling frequency in Hz; converts the ``*_s`` arguments to samples.
    Seq_Len : int
        Number of ``input_size_s``-second epochs concatenated along time in
        each input window.
    n_conv_chs : int
        Number of filters per temporal convolution branch.
    time_conv_size_s : float
        Kernel length of the shortest temporal convolution, in seconds.
    max_pool_size_s : float
        Max-pooling length, in seconds.
    n_classes : int
        Number of output classes.
    input_size_s : float
        Duration of one epoch, in seconds.
    dropout : float
        Dropout probability before the final linear layer.
    lstm_hidden_size : int
        Hidden size of every LSTM.

    Notes
    -----
    ``forward`` expects input shaped (batch, 1, n_channels,
    Seq_Len * input_size_s * sfreq). The three branches must pool to the same
    length, otherwise the concatenation in ``forward`` fails. This holds for
    the defaults with ``sfreq=100``.
    """
    def __init__(self, n_channels, sfreq,Seq_Len,n_conv_chs=8, time_conv_size_s=0.50,
                 max_pool_size_s=0.125, n_classes=5, input_size_s=30,
                 dropout=0.25, lstm_hidden_size=32):
        super().__init__()

        time_conv_size = int(time_conv_size_s * sfreq)
        max_pool_size = int(max_pool_size_s * sfreq)
        input_size = int(input_size_s * sfreq)
        pad_size = time_conv_size // 2
        self.n_channels = n_channels

        if n_channels > 1:
            self.spatial_conv = nn.Conv2d(1, n_channels, (n_channels, 1))

        # Branch 1: base kernel length
        self.feature_extractor1_1 = nn.Sequential(
            nn.Conv2d(1, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)),
            nn.ReLU(),
            nn.MaxPool2d((1, max_pool_size))
        )

        # Branch 2: kernel twice as long
        self.feature_extractor1_2 = nn.Sequential(
            nn.Conv2d(1, n_conv_chs, (1, time_conv_size*2), padding=(0, pad_size*2+1)),
            nn.ReLU(),
            nn.MaxPool2d((1, max_pool_size))
        )

        # Branch 3: kernel four times as long
        self.feature_extractor1_3 = nn.Sequential(
            nn.Conv2d(1, n_conv_chs, (1, time_conv_size*4), padding=(0, 3 + pad_size*4)),
            nn.ReLU(),
            nn.MaxPool2d((1, max_pool_size))
        )

        hidden = lstm_hidden_size
        # The LSTM input size is n_conv_chs * n_channels * 3 because the
        # outputs of the three convolution branches are concatenated.
        self.lstm1 = nn.LSTM(input_size=n_conv_chs * n_channels*3, hidden_size=hidden, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=hidden, hidden_size=hidden, batch_first=True)
        # LSTM 3 sees [lstm1, lstm2]; LSTM 4 sees [lstm1, lstm2, lstm3]
        self.lstm3 = nn.LSTM(input_size=hidden*2, hidden_size=hidden, batch_first=True)
        self.lstm4 = nn.LSTM(input_size=hidden*3, hidden_size=hidden, batch_first=True)

        # Number of time steps left after branch 1's convolution and pooling.
        # With the defaults and sfreq=100 this equals Seq_Len * 3000 // 12.
        n_times = Seq_Len * input_size
        pooled_len = (n_times + 2 * pad_size - time_conv_size + 1) // max_pool_size
        len_last_layer = pooled_len * hidden
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(len_last_layer, n_classes)
        )

    def forward(self, x):
        """Return class scores of shape (batch, n_classes)."""
        if self.n_channels > 1:
            x = self.spatial_conv(x)
            x = x.transpose(1, 2)

        # Each branch outputs (batch, n_conv_chs, n_channels, pooled_len)
        x1 = self.feature_extractor1_1(x)
        x2 = self.feature_extractor1_2(x)
        x3 = self.feature_extractor1_3(x)

        # Stack the branches along the filter axis
        x = torch.cat([x1, x2, x3], dim=1)

        # Merge filter and channel axes -> (batch, features, pooled_len),
        # then move time to axis 1 for the batch_first LSTMs
        x = x.view(x.size(0), -1, x.size(3))
        x = x.permute(0, 2, 1)

        x_l1, _ = self.lstm1(x)
        x_l2, _ = self.lstm2(x_l1)
        x_c1 = torch.cat([x_l1, x_l2], dim=2)
        x_l3, _ = self.lstm3(x_c1)
        x_c2 = torch.cat([x_l1, x_l2, x_l3], dim=2)
        x_l4, _ = self.lstm4(x_c2)

        # Flatten all time steps of the last LSTM for the classifier
        x = x_l4.contiguous().view(x_l4.size(0), -1)
        return self.fc(x)


# sfreq = raws[0].info['sfreq']  # Sampling frequency
# n_channels = raws[0].info['nchan']  # Number of channels

# Signal properties are fixed by hand here instead of being read from `raws`.
sfreq, n_channels = 100, 1

# Each input window holds 5 consecutive epochs (Seq_Len=5).
model = SleepStagerSiam(n_channels=n_channels, sfreq=sfreq, Seq_Len=5, n_classes=5)
print(f'Using device \'{device}\'.')
model = model.to(device)

In [None]:
from torch.utils.data import DataLoader

# Create dataloaders
# train_batch_size = 128  # Important hyperparameter
train_batch_size = 512
valid_batch_size = 256  # Can be made as large as what fits in memory; won't impact performance
num_workers = 0  # Number of processes to use for the data loading process; 0 is the main Python process

loader_train = DataLoader(
    train_ds, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)
loader_valid = DataLoader(
    valid_ds, batch_size=valid_batch_size, shuffle=False, num_workers=num_workers)
loader_test = DataLoader(
    test_ds, batch_size=valid_batch_size, shuffle=False, num_workers=num_workers)

# True labels of the test set, in the order `loader_test` yields them
# (shuffle=False). They are read directly from the per-recording datasets, so
# no batch has to be loaded, transformed or moved to the device.
y_true = np.concatenate(
    [ds.epochs_labels for ds in test_ds.datasets]).astype(np.int64)
rec_ids = np.concatenate(  # indicates which recording each example comes from
    [[i] * len(ds) for i, ds in enumerate(test_ds.datasets)])

In [None]:
import torch
import numpy as np
import copy
from sklearn.metrics import balanced_accuracy_score, cohen_kappa_score, accuracy_score

def _do_train(model, loader, optimizer, criterion, device, metric):
    # training loop
    model.train()
    
    train_loss = np.zeros(len(loader))
    y_pred_all, y_true_all = list(), list()
    for idx_batch, (batch_x, batch_y) in enumerate(loader):
        optimizer.zero_grad()
        batch_x = batch_x.to(device=device, dtype=torch.float32)
        batch_y = batch_y.to(device=device, dtype=torch.int64)

        output = model(batch_x)
        loss = criterion(output, batch_y)

        loss.backward()
        optimizer.step()
        
        y_pred_all.append(torch.argmax(output, axis=1).cpu().numpy())
        y_true_all.append(batch_y.cpu().numpy())

        train_loss[idx_batch] = loss.item()
        
    y_pred = np.concatenate(y_pred_all)
    y_true = np.concatenate(y_true_all)
    perf = metric(y_true, y_pred)
    
    return np.mean(train_loss), perf
        

def _validate(model, loader, criterion, device, metric):
    # validation loop
    model.eval()
    
    val_loss = np.zeros(len(loader))
    y_pred_all, y_true_all = list(), list()
    with torch.no_grad():
        for idx_batch, (batch_x, batch_y) in enumerate(loader):
            batch_x = batch_x.to(device=device, dtype=torch.float32)
            batch_y = batch_y.to(device=device, dtype=torch.int64)
            output = model.forward(batch_x)

            loss = criterion(output, batch_y)
            val_loss[idx_batch] = loss.item()
            
            y_pred_all.append(torch.argmax(output, axis=1).cpu().numpy())
            y_true_all.append(batch_y.cpu().numpy())
            
    y_pred = np.concatenate(y_pred_all)
    y_true = np.concatenate(y_true_all)
    perf = metric(y_true, y_pred)

    return np.mean(val_loss), perf


def train(model, loader_train, loader_valid, optimizer, criterion, n_epochs, 
          patience, device, metric=None):
    """Train ``model`` with early stopping on the validation loss.

    Parameters
    ----------
    model : torch.nn.Module
        Model to train; updated in place.
    loader_train, loader_valid : iterable of (batch_x, batch_y)
        Training and validation batches.
    optimizer, criterion
        Optimizer and loss function.
    n_epochs : int
        Maximum number of epochs.
    patience : int
        Number of consecutive epochs without a validation-loss improvement
        after which training stops.
    device : str or torch.device
        Device the batches are moved to.
    metric : callable(y_true, y_pred), optional
        Performance measure reported each epoch; defaults to ``accuracy_score``.

    Returns
    -------
    model : torch.nn.Module
        The same model object, with the weights of the best-validation-loss
        epoch restored (left as-is if no epoch ever improved).
    history : list of dict
        One entry per epoch with keys 'epoch', 'train_loss', 'valid_loss',
        'train_perf' and 'valid_perf'.
    """
    best_valid_loss = np.inf
    best_model_state = None
    waiting = 0
    history = list()
    
    if metric is None:
        metric = accuracy_score
        
    print('epoch \t train_loss \t valid_loss \t train_perf \t valid_perf')
    print('-------------------------------------------------------------------')

    for epoch in range(1, n_epochs + 1):
        train_loss, train_perf = _do_train(
            model, loader_train, optimizer, criterion, device, metric=metric)
        valid_loss, valid_perf = _validate(
            model, loader_valid, criterion, device, metric=metric)
        history.append(
            {'epoch': epoch, 
             'train_loss': train_loss, 'valid_loss': valid_loss,
             'train_perf': train_perf, 'valid_perf': valid_perf})
        
        print(f'{epoch} \t {train_loss:0.4f} \t {valid_loss:0.4f} '
              f'\t {train_perf:0.4f} \t {valid_perf:0.4f}')

        # model saving
        if valid_loss < best_valid_loss:
            print(f'best val loss {best_valid_loss:.4f} -> {valid_loss:.4f}')
            best_valid_loss = valid_loss
            # Deep copy: state_dict() returns references to the live tensors
            best_model_state = copy.deepcopy(model.state_dict())
            waiting = 0
        else:
            waiting += 1

        # model early stopping
        if waiting >= patience:
            print(f'Stop training at epoch {epoch}')
            print(f'Best val loss : {best_valid_loss:.4f}')
            break

    # Load the best model state before returning. No state was saved if no
    # epoch ran or the validation loss never improved (e.g. it was always NaN);
    # in that case the model keeps its current weights.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        print('No improvement in validation loss was recorded; '
              'returning the model with its current weights.')
    return model, history


In [None]:
from sklearn.utils.class_weight import compute_class_weight

# Labels of every training example, gathered recording by recording.
train_y = np.concatenate([rec.epochs_labels for rec in train_ds.datasets])

# 'balanced' class weights, one per class that occurs in the training labels.
observed_classes = np.unique(train_y)
class_weights = compute_class_weight('balanced', classes=observed_classes, y=train_y)
class_weights

In [None]:
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

# Adam with the same value (1e-3) for learning rate and weight decay.
optimizer = Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-3)

# Class-weighted cross-entropy; the weights must sit on the model's device.
loss_weights = torch.Tensor(class_weights).to(device)
criterion = CrossEntropyLoss(weight=loss_weights)

In [None]:
n_epochs = 200  # maximum number of epochs
patience = 35   # epochs without validation-loss improvement before stopping

# Balanced accuracy is the metric reported per epoch; pass
# metric=cohen_kappa_score instead to track Cohen's kappa.
best_model, history = train(
    model=model,
    loader_train=loader_train,
    loader_valid=loader_valid,
    optimizer=optimizer,
    criterion=criterion,
    n_epochs=n_epochs,
    patience=patience,
    device=device,
    metric=balanced_accuracy_score,
)

In [None]:
# Visualizing the learning curves

history_df = pd.DataFrame(history)
ax1 = history_df.plot(x='epoch', y=['train_loss', 'valid_loss'], marker='o')
ax1.set_ylabel('Loss')
ax2 = history_df.plot(x='epoch', y=['train_perf', 'valid_perf'], marker='o')
# The *_perf columns hold whatever metric was passed to train(); the training
# cell uses balanced_accuracy_score, so the axis is labelled accordingly.
# ax2.set_ylabel('Cohen\'s kappa')
ax2.set_ylabel('Balanced accuracy')

In [None]:
# Predict the whole test set with the best model returned by training.
best_model.eval()

y_pred_all, y_true_all = list(), list()
# Inference only: disabling autograd avoids building a graph for every batch.
with torch.no_grad():
    for batch_x, batch_y in loader_test:
        batch_x = batch_x.to(device=device, dtype=torch.float32)
        batch_y = batch_y.to(device=device, dtype=torch.int64)
        # Use the model that was put in eval mode above, called as a module
        output = best_model(batch_x)
        y_pred_all.append(torch.argmax(output, axis=1).cpu().numpy())
        y_true_all.append(batch_y.cpu().numpy())
    
y_pred = np.concatenate(y_pred_all)
y_true = np.concatenate(y_true_all)
rec_ids = np.concatenate(  # indicates which recording each example comes from
    [[i] * len(ds) for i, ds in enumerate(test_ds.datasets)])

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm, cm_normalized, class_names, title='Confusion Matrix',
                          cmap=plt.cm.Reds, linewidths=0.5, linecolor='k', save_path=None):
    """Draw a confusion matrix heatmap annotated with percentages and counts.

    Parameters
    ----------
    cm : 2-D array
        Raw counts (rows: ground truth, columns: predictions).
    cm_normalized : 2-D array
        Values in [0, 1] with the same shape as ``cm``; they set the cell
        colours and are shown as percentages.
    class_names : list of str
        Tick labels for both axes.
    title : str
        Accepted for compatibility; the title is currently not drawn.
    cmap : colormap or str
        Colormap of the heatmap.
    linewidths : float
        Width of the lines separating the cells.
    linecolor : color
        Colour of the lines separating the cells.
    save_path : str, optional
        If given, the figure is also saved there as a JPEG at 1200 dpi.
    """
    plt.figure(figsize=(6.2, 6))

    # Cell text: normalized value as a percentage above the raw count. Building
    # the array from the finished strings lets numpy size the string dtype, so
    # no label can be truncated.
    annot = np.array([[f"{cm_normalized[i, j]*100:.1f}%\n{cm[i, j]}"
                       for j in range(cm.shape[1])]
                      for i in range(cm.shape[0])])

    sns.heatmap(cm_normalized, annot=annot, fmt="", cmap=cmap,
                xticklabels=class_names, yticklabels=class_names,
                linewidths=linewidths, linecolor=linecolor, cbar=False,
                annot_kws={"size": 12, "weight": 'bold'})

    # plt.title(title)
    plt.ylabel('Ground Truth', fontweight='bold')
    plt.xlabel('Predicted', fontweight='bold')
    plt.xticks(fontsize=12, fontweight='bold')
    plt.yticks(fontsize=12, fontweight='bold')
    if save_path:
        # Save before show(): the figure may be cleared once it is displayed
        plt.savefig(save_path, format='jpg', bbox_inches='tight', dpi=1200)
    plt.show()


# Compute the confusion matrix over every known class, so that its rows and
# columns always line up with `class_names`, even if a class is absent from
# both the test labels and the predictions.
class_labels = sorted(classes_mapping)
conf_mat = confusion_matrix(y_true, y_pred, labels=class_labels)

# Normalize the confusion matrix per row (recall). A class with no true
# examples has an all-zero row, which stays zero instead of dividing by zero.
row_totals = conf_mat.sum(axis=1, keepdims=True)
conf_mat_normalized = np.divide(
    conf_mat.astype('float'), row_totals,
    out=np.zeros(conf_mat.shape, dtype=float), where=row_totals != 0)

# Class names from mapping
class_names = [classes_mapping[i] for i in class_labels]

# Save path
save_path = "/kaggle/working/cm_uncertaintyyy_uncertain.jpg"

# Plot
plot_confusion_matrix(conf_mat, conf_mat_normalized, class_names,
                      cmap='Purples', linewidths=1.3, linecolor='k',
                      save_path=save_path)


In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, balanced_accuracy_score
from sklearn.metrics import roc_auc_score, log_loss
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Test-set labels and predictions computed in the evaluation cell above
y_true_class, y_pred_class = y_true, y_pred

# Classification metrics: each one is computed and printed in turn
macro = {'average': 'macro'}
for label, scorer, options in [
    ("Accuracy:", accuracy_score, {}),
    ("Balanced Accuracy:", balanced_accuracy_score, {}),
    ("F1-Score (macro):", f1_score, macro),
    ("Precision (macro):", precision_score, macro),
    ("Recall (macro):", recall_score, macro),
]:
    print(label, scorer(y_true_class, y_pred_class, **options))

test_kappa = cohen_kappa_score(y_true, y_pred)
print(f'Cohen\'s kappa: {test_kappa:0.5f}')
print("Confusion Matrix:\n", confusion_matrix(y_true_class, y_pred_class))

In [None]:
from sklearn.metrics import f1_score

# Test-set labels and predictions computed in the evaluation cell above
y_true_class, y_pred_class = y_true, y_pred

# average=None returns one F1 score per class instead of a single average
per_class_f1_scores = f1_score(y_true_class, y_pred_class, average=None)

# Report each score under its position in the returned array
for class_idx in range(len(per_class_f1_scores)):
    print(f"F1-Score for class {class_idx}: {per_class_f1_scores[class_idx]:.5f}")


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Requires y_true, y_pred and rec_ids from the test-set evaluation cell.

mask = rec_ids == 0  # examples belonging to the first test recording

fs = 100           # sampling frequency in Hz
epoch_length = 30  # duration of one epoch in seconds

# Time window to display, in hours
start_time = 0
end_time = start_time+7.0

# Same window expressed in seconds
start_time_seconds = start_time * 3600
end_time_seconds = end_time * 3600

# ...and as epoch indices within the recording
start_epoch = int(start_time_seconds / epoch_length)
end_epoch = int(end_time_seconds / epoch_length)
period = slice(start_epoch, end_epoch)

# True and predicted stages of the selected recording inside the window
y_true_period = y_true[mask][period]
y_pred_period = y_pred[mask][period]

# Time axis in hours, one point per epoch
t_period = np.arange(len(y_true_period)) * epoch_length / 3600 + start_time

fig, ax = plt.subplots(figsize=(12, 3))
for series, style in [
    (y_true_period, dict(color='blue', label='Annotated by sleep experts')),
    (y_pred_period, dict(alpha=0.7, color='red', label='Predicted by our model')),
]:
    ax.plot(t_period, series, **style)
ax.set_yticks([0, 1, 2, 3, 4])
ax.set_yticklabels(['W', 'N1', 'N2', 'N3', 'R'])
ax.set_xlabel('Time (h)')
ax.set_title(f'Hypnogram ({start_time}-{end_time} Hours)')
ax.legend()
plt.show()
