In [1]:
import os, sys
module_path = os.path.abspath(os.path.join('..'))
sys.path.append(module_path)

from MicroneurographyDataloader import *
from _external.WHVPNet_pytorch.networks import *
from _external.WHVPNet_tensorflow.VPLayer import *
from spike_classification import *
from XAIProject import *

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.collections import LineCollection
import matplotlib.ticker as mticker
import torch.nn.functional as F

In [2]:
"""
Load the model.
"""

# NB: 'widnow' (sic) is part of the on-disk checkpoint file name.
model_name = 'trained_models/widnow_15_overlapping_11_hidden_6_nweight_4_id_6'

dtype = torch.float64
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
window_size = 15
overlapping_size = 11

MNG_dataloader = MicroneurographyDataloader()
path = f'window_{window_size}_overlap_{overlapping_size}_corrected.pkl'
# NOTE(review): the .pkl extension suggests this unpickles the file — only load trusted files.
MNG_dataloader.load_samples_and_labels_from_file(path)

dataloaders = MNG_dataloader.sequential_split_with_resampling(1024)

# Add a singleton channel axis: (num_windows, window_len) -> (num_windows, 1, window_len).
sample_windows = torch.tensor(MNG_dataloader.raw_data_windows, dtype=dtype).unsqueeze(1)
n_channels, n_in = sample_windows[0].shape
n_out = len(MNG_dataloader.binary_labels_onehot[0])
num_VP_features = 6
num_weights = 4
fcn_neurons = 6
affin = torch.tensor([6 / n_in, -0.3606]).tolist()
# Unseeded random initial weights; presumably overwritten by the checkpoint loaded below.
weight = ((torch.rand(num_weights)-0.5)*8).tolist()


model = VPNet(n_in, n_channels, num_VP_features, VPTypes.FEATURES, affin + weight, WeightedHermiteSystem(n_in, num_VP_features, num_weights), [fcn_neurons], n_out, device=device, dtype=dtype)
# map_location remaps tensors saved on another device (e.g. a GPU) onto `device`,
# so the checkpoint also loads on CPU-only machines.
model.load_state_dict(torch.load(model_name, map_location=device, weights_only=True))

  all_spike['track'] = all_spike['track'].replace(self.track_replacement_dict).infer_objects(copy=False)


<All keys matched successfully>

In [8]:
"""
Evaluation of the trained model on one split (set `dataset_name` to 'val' or 'test').
"""
dataset_name = 'val'  # or 'test'
decision_boundary = 0.8

# Class-weighted cross-entropy: class 1 is weighted far more heavily than class 0
# (presumably to counter the class imbalance — TODO confirm).
class_weights = torch.tensor([0.003, 0.997]).to(device)
weighted_criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
criterion = VPLoss(weighted_criterion, 0.1)

eval_loader = dataloaders[f'{dataset_name}_loader']
(accuracy, loss, binary_labels, multiple_labels,
 predicted_classes, predicted_probabilities) = test(model, eval_loader, criterion, decision_boundary)

# Optional detailed reports — uncomment to display them:
# compute_common_metrics(binary_labels, predicted_classes)
# compute_merged_metrics(binary_labels, predicted_classes)
# create_decision_ceratinty_boxplots(binary_labels, multiple_labels, predicted_classes, predicted_probabilities)

Label 0:
  TP: 0.0
  FN: 0.0
  FP: 15258.0
  TN: 387548.0

Label 1:
  TP: 259.0
  FN: 0.0
  FP: 0.0
  TN: 0.0

Label 2:
  TP: 255.0
  FN: 0.0
  FP: 0.0
  TN: 0.0

Label 3:
  TP: 236.0
  FN: 17.0
  FP: 0.0
  TN: 0.0

Accuracy: 96.22%, loss: 10.2869
Weighted Balanced Accuracy: 0.9700


In [4]:
"""
Collect the evaluated split's input windows (moved to CPU) and their timestamps.
"""

# Each batch yields (inputs, binary labels, multiple labels); only the inputs are kept.
# squeeze(1) drops dimension 1 (presumably the singleton channel axis added at load time).
sample_windows = torch.cat(
    [inputs.cpu() for inputs, _binary, _multiple in dataloaders[f'{dataset_name}_loader']]
).squeeze(1)
timestamp_windows = dataloaders[f'{dataset_name}_timestamps']

In [5]:
"""
Retrieve information for the output plot.
"""

def reconstruct_original_sequence(overlapping_windows, window_size, overlapping):
    """
    Transform overlapping windows back into a one-dimensional sequence.

    Consecutive windows are placed ``window_size - overlapping`` positions apart.
    Each window is written at its offset; where windows overlap, the later window
    overwrites the earlier one. NaNs in the result are replaced via ``np.nan_to_num``.

    Parameters
    ----------
    overlapping_windows : array-like
        Sequence of windows; each window must fit a slice of length ``window_size``.
    window_size : int
        Number of datapoints per window.
    overlapping : int
        Number of datapoints shared by consecutive windows; must be smaller than
        ``window_size``.

    Returns
    -------
    np.ndarray
        Float array of length ``(num_windows - 1) * stride + window_size`` with
        ``stride = window_size - overlapping``; an empty array when there are no windows.

    Raises
    ------
    ValueError
        If ``overlapping >= window_size`` (non-positive stride).
    """
    overlapping_windows = np.asarray(overlapping_windows)
    stride = window_size - overlapping
    if stride <= 0:
        raise ValueError(f"overlapping ({overlapping}) must be smaller than window_size ({window_size})")
    num_windows = len(overlapping_windows)
    if num_windows == 0:
        # The length formula below would otherwise yield `overlapping` spurious zeros.
        return np.zeros(0)
    original_length = (num_windows - 1) * stride + window_size

    reconstructed_sequence = np.zeros(original_length)
    for i, window in enumerate(overlapping_windows):
        start_index = i * stride
        # Overlapping positions are overwritten by the later window; they presumably hold
        # identical values because the windows are cut from one continuous signal.
        reconstructed_sequence[start_index:start_index + window_size] = window
    return np.nan_to_num(reconstructed_sequence)


# Rebuild continuous signal and timestamp sequences from the evaluated split's windows,
# using the same window geometry the dataset file was created with.
original_samples_seq = reconstruct_original_sequence(sample_windows, window_size, overlapping_size)
original_timestamps_seq = reconstruct_original_sequence(timestamp_windows, window_size, overlapping_size)

# Keep only the ground-truth spikes whose timestamps fall within the reconstructed span.
# NOTE(review): searchsorted assumes all_spikes_df['ts'] is sorted ascending — TODO confirm.
start_ts = original_timestamps_seq[0]
end_ts = original_timestamps_seq[-1]
start_index = MNG_dataloader.all_spikes_df['ts'].searchsorted(start_ts, side='left')
end_index = MNG_dataloader.all_spikes_df['ts'].searchsorted(end_ts, side='right')
ground_truth_spikes_in_dataset = MNG_dataloader.all_spikes_df.iloc[start_index:end_index]

In [6]:
def transform_predictions_to_sequence(overlapping_windows, window_size, overlapping, window_probabilities=None):
    """
    Transform per-window prediction probabilities back into a one-dimensional sequence.

    Each window carries a single probability, which is assigned to every datapoint the
    window covers. Because windows overlap, a datapoint receives one probability from each
    window it appears in; the returned sequence holds, for every datapoint, the average of
    those probabilities.

    Parameters
    ----------
    overlapping_windows : array-like
        The windows (e.g. the timestamp windows). Only their count is used; it fixes the
        number of windows and hence the output length.
    window_size : int
        Number of datapoints per window.
    overlapping : int
        Number of datapoints shared by consecutive windows; must be smaller than
        ``window_size``.
    window_probabilities : array-like of shape (num_windows,), optional
        One probability per window, aligned with ``overlapping_windows``. When omitted,
        column 1 of the notebook-global ``predicted_probabilities`` (set by the evaluation
        cell) is used, which preserves the original behaviour.

    Returns
    -------
    np.ndarray
        Averaged probabilities of length ``(num_windows - 1) * stride + window_size`` with
        ``stride = window_size - overlapping``; an empty array when there are no windows.

    Raises
    ------
    ValueError
        If ``overlapping >= window_size``, or if the number of probabilities does not
        match the number of windows.
    """
    if window_probabilities is None:
        # Backward-compatible fallback to the notebook-global evaluation output.
        window_probabilities = predicted_probabilities[:, 1]
    window_probs = np.asarray(window_probabilities)
    num_windows = len(np.asarray(overlapping_windows))
    stride = window_size - overlapping
    if stride <= 0:
        raise ValueError(f"overlapping ({overlapping}) must be smaller than window_size ({window_size})")
    if len(window_probs) != num_windows:
        # A mismatch means probabilities would be attributed to the wrong datapoints.
        raise ValueError(f"got {len(window_probs)} window probabilities for {num_windows} windows; they must align one-to-one")
    if num_windows == 0:
        return np.zeros(0)
    original_length = (num_windows - 1) * stride + window_size

    prob_sums = np.zeros(original_length)
    counts = np.zeros(original_length)
    for i, prob in enumerate(window_probs):
        start_index = i * stride
        prob_sums[start_index:start_index + window_size] += prob
        counts[start_index:start_index + window_size] += 1
    # Positions not covered by any window (only possible when overlapping < 0) stay 0.
    covered = counts > 0
    prob_sums[covered] /= counts[covered]  # average of the predictions from all windows covering the datapoint
    return np.nan_to_num(prob_sums)

# Use the configured window geometry rather than repeating the literal values.
prediction_sequence = transform_predictions_to_sequence(timestamp_windows, window_size, overlapping_size)

In [7]:
# Probability bins (in percent) shared by the trace colouring and the legend.
_PROB_BIN_COLORS = ['#C0C0C0', 'khaki', '#FFA500', '#FF0000']
_PROB_BIN_LIMITS = [0, 50, 80, 98, 100]
_PROB_BIN_LABELS = ['0-50%', '50-80%', '80-98%', '98-100%']
_TRACK_MARKER_STYLES = ['s', '^', 'D', 'p', '*', 'h', 'H', '+', 'x', '|', '_', 'v', '<', '>', '8', 'P', 'X']


def _build_legend(unique_tracks):
    """Return (legend handles, track -> marker map) for the probability bins and spike tracks."""
    prob_handles = [
        Line2D([0], [0], marker='o', color='w', label=label, markerfacecolor=color, markersize=10)
        for label, color in zip(_PROB_BIN_LABELS, _PROB_BIN_COLORS)
    ]
    track_markers = {track: _TRACK_MARKER_STYLES[i % len(_TRACK_MARKER_STYLES)]
                     for i, track in enumerate(unique_tracks)}
    # Track 1 is labelled as the stimulus; every other track k is shown as "AP k-1".
    track_handles = [
        Line2D([0], [0], color='black', marker=track_markers[track], linestyle='None', markersize=10,
               label=f'AP {track-1}' if track != 1 else 'Stimulus')
        for track in unique_tracks
    ]
    return prob_handles + track_handles, track_markers


def _draw_probability_trace(ax, timestamps, samples, probs):
    """Draw the signal as line segments coloured by the probability bin of each segment's start point."""
    # np.digitize returns 1..len(bins); shift to 0-based bins and fold values >= 100% into the top bin.
    bin_idx = np.clip(np.digitize(probs * 100, bins=_PROB_BIN_LIMITS) - 1, 0, len(_PROB_BIN_COLORS) - 1)
    points = np.array([timestamps, samples]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    colors = [_PROB_BIN_COLORS[b] for b in bin_idx[:-1]]
    ax.add_collection(LineCollection(segments, colors=colors, linewidths=2, alpha=0.6))


def _draw_ground_truth_spikes(ax, spikes, y_top, y_bottom, track_markers):
    """Mark each ground-truth spike with a dashed vertical line and its track marker at y_top and y_bottom."""
    for ts, track in zip(spikes['ts'], spikes['track']):
        marker = track_markers.get(int(track), 'o')  # unknown tracks fall back to a circle
        ax.scatter(ts, y_top, color='black', marker=marker, s=120)
        ax.scatter(ts, y_bottom, color='black', marker=marker, s=120)
        ax.axvline(x=ts, color='black', linestyle='--', lw=0.5)


def plot_output_windows(sample_size=3000, plotting_start_idx=181200, num_of_plots=2):
    """
    Plot consecutive stretches of the reconstructed signal, coloured by predicted probability,
    with the ground-truth spikes overlaid.

    Each plot draws ``original_samples_seq`` against ``original_timestamps_seq``; every line
    segment is coloured by the probability bin of ``prediction_sequence`` at its start point.
    Ground-truth spikes from ``ground_truth_spikes_in_dataset`` that fall inside the plotted
    time range get a dashed vertical line and a track-specific marker above and below the trace.

    Parameters
    ----------
    sample_size : int
        How many datapoints are included in one plot.
    plotting_start_idx : int
        Index into the reconstructed sequences where plotting starts.
    num_of_plots : int
        Number of consecutive plots to generate.

    Notes
    -----
    Reads the notebook globals ``MNG_dataloader``, ``original_timestamps_seq``,
    ``original_samples_seq``, ``prediction_sequence`` and ``ground_truth_spikes_in_dataset``.
    Stops early (with a message) when a plot would start past the end of the sequence.
    """
    unique_tracks = list(MNG_dataloader.track_replacement_dict.values())
    legend_handles, track_markers = _build_legend(unique_tracks)

    # The y-range is shared by every plot so amplitudes are comparable across plots.
    y_min_all = original_samples_seq.min() * 1.2
    y_max_all = original_samples_seq.max() * 1.2

    stop_idx = plotting_start_idx + sample_size * num_of_plots
    for start_index in range(plotting_start_idx, stop_idx, sample_size):
        end_index = start_index + sample_size
        timestamps_np = np.asarray(original_timestamps_seq[start_index:end_index])
        samples_np = original_samples_seq[start_index:end_index]
        probs_np = prediction_sequence[start_index:end_index]

        if timestamps_np.size == 0:
            # .min()/.max() below would raise on an empty slice.
            print(f"Start index {start_index} is past the end of the sequence "
                  f"({len(original_timestamps_seq)} points); stopping.")
            break

        fig, ax = plt.subplots(figsize=(20, 6))
        _draw_probability_trace(ax, timestamps_np, samples_np, probs_np)

        ts_min, ts_max = timestamps_np.min(), timestamps_np.max()
        in_range = ground_truth_spikes_in_dataset['ts'].between(ts_min, ts_max)
        spikes_to_plot = ground_truth_spikes_in_dataset[in_range]
        if spikes_to_plot.empty:
            print(f"No ground truth spikes present between timestamps {ts_min} - {ts_max}")

        # Marker rows at 1.1x the local extremes (just outside the trace when it crosses zero).
        _draw_ground_truth_spikes(ax, spikes_to_plot, samples_np.max() * 1.1, samples_np.min() * 1.1, track_markers)

        ax.set_ylim(y_min_all, y_max_all)
        ax.set_xlim(ts_min, ts_max)
        ax.legend(handles=legend_handles, title="Probability and AP", bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=14, title_fontsize=16)
        ax.set_xlabel('Timestamp', fontsize=14)
        ax.set_ylabel('Amplitude', fontsize=14)
        ax.xaxis.set_major_locator(mticker.MaxNLocator(nbins=50))
        ax.xaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{x:.4f}'))
        # plt.xticks / plt.yticks act on the current axes, i.e. `ax`.
        plt.xticks(rotation=45, ha='right', fontsize=12)
        plt.yticks(fontsize=12)
        ax.grid(False)
        fig.tight_layout()
        plt.show()
        plt.close(fig)

plot_output_windows()


No ground truth spikes present between timestamps 2504.107794082538 - 2504.407694082538
