In [2]:
import os, sys
from typing import Callable
import torch
import time
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.collections import LineCollection
import matplotlib.ticker as mticker
import torch.nn.functional as F


start_time = time.time()

module_path = os.path.abspath(os.path.join('..'))
sys.path.append(module_path)


from MicroneurographyDataloader import *
from metrics import *
from _external.WHVPNet_pytorch.networks import *


In [2]:
def train(model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, \
          n_epoch: int, optimizer: torch.optim.Optimizer, \
          criterion: Callable[[torch.Tensor | list[torch.Tensor]], torch.Tensor], decision_boundary: float = 0.5) -> None:
    n_digits = len(str(n_epoch))
    for epoch in range(n_epoch):
        total_loss = 0
        total_accuracy = 0
        total_length = 0

        total_true_positives = 0
        total_false_positives = 0
        total_false_negatives = 0
        total_true_negatives = 0

        for data in data_loader:
            input_data, binary_labels, multiple_labels = data
            optimizer.zero_grad()
            outputs = model(input_data)

            binary_classes = binary_labels.argmax(dim=-1)
            loss = criterion(outputs, binary_labels)
           
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
            predicted_labels = outputs[0] if isinstance(outputs, tuple) else outputs
            predicted_classes = (predicted_labels[:, 1] > decision_boundary).float()

            total_accuracy += (binary_classes == predicted_classes).sum().item()
            total_length += binary_labels.size(0)
            total_true_positives += ((binary_classes == 1) & (predicted_classes == 1)).sum().item()
            total_false_positives += ((binary_classes == 0) & (predicted_classes == 1)).sum().item()
            total_false_negatives += ((binary_classes == 1) & (predicted_classes == 0)).sum().item()
            total_true_negatives += ((binary_classes == 0) & (predicted_classes == 0)).sum().item()


        precision = total_true_positives / (total_true_positives + total_false_positives) if total_true_positives + total_false_positives > 0 else 0.0
        recall = total_true_positives / (total_true_positives + total_false_negatives) if total_true_positives + total_false_negatives > 0 else 0.0
        total_accuracy /= total_length / 100
        avg_loss = total_loss / len(data_loader)


        sensitivity = total_true_positives / (total_true_positives + total_false_negatives) if total_true_positives + total_false_negatives > 0 else 0.0
        specificity = total_true_negatives / (total_true_negatives + total_false_positives) if total_true_negatives + total_false_positives > 0 else 0.0
        balanced_accuracy = 0.5 * (sensitivity + specificity)

        print(f'Epoch: {epoch+1:0{n_digits}d} / {n_epoch}, '
              f'accuracy: {total_accuracy:.2f}%, loss: {total_loss:.4f}, '
              f'Precision: {precision:.4f}, Recall: {recall:.4f}, '
              f'Balanced Accuracy: {balanced_accuracy:.4f}')

In [3]:
def test(model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, \
         criterion: Callable[[torch.Tensor | list[torch.Tensor]], torch.Tensor], decision_boundary: float = 0.5) -> tuple[float, float]:
    with torch.no_grad():
        total_loss = 0
        total_accuracy = 0
        total_number = 0

        all_binary_labels = []
        all_multiple_labels = []
        all_predicted_classes = []
        all_predicted_probabilities = []

        total_true_positives = 0
        total_false_positives = 0
        total_false_negatives = 0
        total_true_negatives = 0
        for data in data_loader:
            input_data, binary_labels, multiple_labels = data
            outputs = model(input_data)
            loss = criterion(outputs, binary_labels)
            total_loss += loss.item()
            
            binary_classes = binary_labels.argmax(dim=-1)
            predicted_labels = outputs[0] if isinstance(outputs, tuple) else outputs
            predicted_classes = (predicted_labels[:, 1] > decision_boundary).float()

            total_accuracy += (binary_classes == predicted_classes).sum().item()
            total_number += binary_labels.size(0)
            total_true_positives += ((binary_classes == 1) & (predicted_classes == 1)).sum().item()
            total_false_positives += ((binary_classes == 0) & (predicted_classes == 1)).sum().item()
            total_false_negatives += ((binary_classes == 1) & (predicted_classes == 0)).sum().item()
            total_true_negatives += ((binary_classes == 0) & (predicted_classes == 0)).sum().item()


            multiple_classes = multiple_labels.argmax(dim=-1)
            all_multiple_labels.append(multiple_classes.cpu())
            all_binary_labels.append(binary_classes.cpu())
            all_predicted_classes.append(predicted_classes.cpu())
            all_predicted_probabilities.append(predicted_labels.cpu())

        all_binary_labels = torch.cat(all_binary_labels)
        all_multiple_labels = torch.cat(all_multiple_labels)
        all_predicted_classes = torch.cat(all_predicted_classes)
        all_predicted_probabilities = torch.cat(all_predicted_probabilities)

        sensitivity = total_true_positives / (total_true_positives + total_false_negatives) if total_true_positives + total_false_negatives > 0 else 0.0
        specificity = total_true_negatives / (total_true_negatives + total_false_positives) if total_true_negatives + total_false_positives > 0 else 0.0
        balanced_accuracy = 0.5 * (sensitivity + specificity)

        compare_predictions_to_multilabels(all_binary_labels, all_multiple_labels, all_predicted_classes)

        total_accuracy /= total_number / 100
        print("=" * 40)
        print(f'Accuracy: {total_accuracy:.2f}%, loss: {total_loss:.4f}')
        print(f'Weighted Balanced Accuracy: {balanced_accuracy:.4f}')
        return total_accuracy, total_loss, all_binary_labels, all_multiple_labels, all_predicted_classes, all_predicted_probabilities

In [7]:
"""
Data generation and statistics with MNG dataloader
"""
MNG_dataloader = MicroneurographyDataloader()
filename = 'window_15_overlap_11_corrected_or.pkl'
MNG_dataloader.load_samples_and_labels_from_file(filename)
# MNG_dataloader.generate_raw_windows(window_size=20, overlapping=15)
# MNG_dataloader.generate_labels()
# MNG_dataloader.generate_labels_stimuli_relabel()
# MNG_dataloader.write_samples_and_labels_into_file(filename)

# MNG_dataloader.get_statistics_of_spikes()
# MNG_dataloader.plot_raw_data_window_by_label(0, 5)
MNG_dataloader.get_value_statistics_for_classes()
MNG_dataloader.plot_raw_data_window_by_label(1, 6)
# MNG_dataloader.plot_raw_data_window_by_label(2, 5)
# MNG_dataloader.plot_raw_data_window_by_label(3, 5)

  all_spike['track'] = all_spike['track'].replace(self.track_replacement_dict).infer_objects(copy=False)


Value statistics
Label 0: Overall Min = -10.230488, Overall Max = 9.450225
Label 1: Overall Min = -10.242657, Overall Max = 9.476206
Label 2: Overall Min = -9.928562, Overall Max = 9.464366
Label 3: Overall Min = -8.4738655, Overall Max = 8.722748


In [None]:
""" 
First get the statistics of spikes to see what is the minimum gap between two timestamps. If the window size is not bigger than that, there will be no occurence of two spikes being in one window.
"""
MNG_dataloader = MicroneurographyDataloader(raw_data_relative_path='../data/5_nerve/raw_data.csv',
                                            spikes_relative_path='../data/5_nerve/spike_timestamps.csv',
                                            stimulation_relative_path='../data/5_nerve/stimulation_timestamps.csv')
MNG_dataloader.get_statistics_of_spikes()

In [4]:
""" 
FULL TRAINING
"""
"""
    CHANGE the decision boundary for the metrics here.
"""
decision_boundary = 0.8
epoch = 10
lr = 0.01
dtype = torch.float64
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

"""
    CHANGE window size and overlapping size here.
"""
window_size = 13
overlapping_size = 10
""" 
    CHANGE the file name and path here for the already generated dataset.
    If the file does not exist, the overlapping windows will be generater now and saved to this file path.
"""
MNG_dataloader_filepath = f'window_{window_size}_overlap_{overlapping_size}_corrected_5nerve.pkl'

"""
CHANGE the file source paths here.
"""
MNG_dataloader = MicroneurographyDataloader(raw_data_relative_path='../data/5_nerve/raw_data.csv',
                                            spikes_relative_path='../data/5_nerve/spike_timestamps.csv',
                                            stimulation_relative_path='../data/5_nerve/stimulation_timestamps.csv')
""" 
CHANGE: this function prints the min distance between the spike instances. Useful for window size determination.
"""
MNG_dataloader.get_statistics_of_spikes()

full_path = os.path.join('../data', MNG_dataloader_filepath)
if os.path.exists(full_path):
    print("Dataset loading.")
    MNG_dataloader.load_samples_and_labels_from_file(MNG_dataloader_filepath)
else:
    print("Dataset generating.")
    MNG_dataloader.generate_raw_windows(window_size=window_size, overlapping=overlapping_size)
    """ 
    CHANGE If no relabel is necessary.
    """
    #dataSet.generate_labels()
    # on the first dataset with two nerves, negative_stimulus_limit=-10, positive_stimulus_limit=9 works great
    """ 
    CHANGE the filtering thresholds here.
    Unfortunately, it needs a little testing.
    the filter gets windows which has values {negative_stimulus_limit} {logigal_operator} {positive_stimulus_limit}
    For the first file, -10 and 9 was fine.
    For this file, -9 or 8 is fine.
    If after relabel, the number of stimulus relabels are approximately the same as the number of stimulus windows there was before, the relabel is correct.
    """
    MNG_dataloader.generate_labels_stimuli_relabel(logigal_operator="or")
    MNG_dataloader.write_samples_and_labels_into_file(MNG_dataloader_filepath)

""" 
    CHANGE the over and undersampling ratio here. 
"""
dataloaders = MNG_dataloader.sequential_split_with_resampling(batch_size=1024, minor_upsample_count=25000, major_downsample_count=75000)

""" 
CHANGE: this function prints the min and max value for every multi class label. Also useful for stimuli relabel.
"""
MNG_dataloader.get_value_statistics_for_classes()
""" 
CHANGE: this function prints how many labels are  for every multi class label.
If the number of stimulus labels are not in correlation with the number of other labels, the stimulus filter should be changed by the logical operator.
"""
MNG_dataloader.get_statistics_of_labels()


if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = False
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

# WHVPNet params
samples = torch.tensor(MNG_dataloader.raw_data_windows, dtype=torch.float64).unsqueeze(1)#.to(MNG_dataloader.device)
n_channels, n_in = samples[0].shape
n_out = len(MNG_dataloader.binary_labels_onehot[0])
num_VP_features = 6
num_weights = 4
fcn_neurons = 6
affin = torch.tensor([6 / n_in, -0.3606]).tolist()
weight = ((torch.rand(num_weights)-0.5)*8).tolist()


model = VPNet(n_in, n_channels, num_VP_features, VPTypes.FEATURES, affin + weight, WeightedHermiteSystem(n_in, num_VP_features, num_weights), [fcn_neurons], n_out, device=device, dtype=dtype)
""" 
CHANGE the weigths here for the class imbalances. They should sum up to 1, and should be in close correlation with the oversampling undersampling ratio.
"""
class_weights = torch.tensor([0.3, 0.7]).to(device)
weighted_criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
criterion = VPLoss(weighted_criterion, 0.1)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
train(model, dataloaders['train_loader_under'], epoch, optimizer, criterion)

if isinstance(model, VPNet):
    print(*list(model.vp_layer.parameters()))

class_weights = torch.tensor([0.003, 0.997]).to(device)
weighted_criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
criterion = VPLoss(weighted_criterion, 0.1)
print("=" * 40)
print("VALIDATION/TESTING:")
"""
CHANGE the dataloader on which the validation should be.
dataloaders['val_loader'] is for validation dataset.
dataloaders['test_loader'] If you finished with the training, use this test set for the function.
"""
accuracy, loss, all_binary_labels, all_multiple_labels, all_predicted_classes, all_predicted_probabilities = test(model, dataloaders['val_loader'], criterion, decision_boundary)
compute_common_metrics(all_binary_labels, all_predicted_classes)
compute_merged_metrics(all_binary_labels, all_predicted_classes)
create_decision_ceratinty_boxplots(all_binary_labels, all_multiple_labels, all_predicted_classes, all_predicted_probabilities)
print()

"""
CHANGE: model path for saving.
"""
#torch.save(model.state_dict(), f'_trained_models/widnow_{window_size_}_overlapping_{overlapping_size_}_hidden_{hidden1}_nweight_{weight_num}')

  all_spike['track'] = all_spike['track'].replace(self.track_replacement_dict).infer_objects(copy=False)


Minimum index gap between spike timestamps: 13
Maximum index gap between spike timestamps: 199982
Minimum time gap between spike timestamps: 0.0010999999931300408
Maximum time gap between spike timestamps: 4.000000000000455
Minimum gap between sampled values: 9.999999929277692e-05
Spike min gap / sample freq gap: 11.000000009094947
Dataset loading.
Value statistics
Label 0: Overall Min = -9.671367, Overall Max = 8.843775
Label 1: Overall Min = -9.671367, Overall Max = 8.832592
Label 2: Overall Min = -6.2966104, Overall Max = 8.561269
Label 3: Overall Min = -9.638477, Overall Max = 8.832592
Label 4: Overall Min = -9.666762, Overall Max = 8.836868
Label 5: Overall Min = -9.4043045, Overall Max = 8.84213
Label 6: Overall Min = -8.598845, Overall Max = 8.814833
Epoch: 01 / 10, accuracy: 78.78%, loss: 25.8074, Precision: 0.5898, Recall: 0.4971, Balanced Accuracy: 0.6909
Epoch: 02 / 10, accuracy: 89.23%, loss: 19.9364, Precision: 0.7676, Recall: 0.8167, Balanced Accuracy: 0.8671
Epoch: 03 / 


Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.boxplot(
The palette list has fewer values (3) than needed (6) and will cycle, which may produce an uninterpretable plot.
  sns.boxplot(





'\nCHANGE: model path for saving.\n'

In [12]:
"""
Load pretrained model.
"""
"""
CHANGE the path of the trained model here.
"""
model_name = '_trained_models/widnow_15_overlapping_11_hidden_6_nweight_4_id_6'
"""
    CHANGE the parameters to which the pretrained model was trained.
"""
dtype = torch.float64
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
window_size = 15
overlapping_size = 11

MNG_dataloader = MicroneurographyDataloader()
path = f'window_{window_size}_overlap_{overlapping_size}_corrected.pkl'
MNG_dataloader.load_samples_and_labels_from_file(path)

dataloaders = MNG_dataloader.sequential_split_with_resampling(1024)

sample_windows = torch.tensor(MNG_dataloader.raw_data_windows, dtype=torch.float64).unsqueeze(1)
n_channels, n_in = sample_windows[0].shape
n_out = len(MNG_dataloader.binary_labels_onehot[0])
num_VP_features = 6
num_weights = 4
fcn_neurons = 6
affin = torch.tensor([6 / n_in, -0.3606]).tolist()
weight = ((torch.rand(num_weights)-0.5)*8).tolist()


model = VPNet(n_in, n_channels, num_VP_features, VPTypes.FEATURES, affin + weight, WeightedHermiteSystem(n_in, num_VP_features, num_weights), [fcn_neurons], n_out, device=device, dtype=dtype)
model.load_state_dict(torch.load(model_name, weights_only=True))

  all_spike['track'] = all_spike['track'].replace(self.track_replacement_dict).infer_objects(copy=False)


<All keys matched successfully>

In [None]:
"""
Evaluation.
"""
""" 
CHANGE: the dataset name to val for validation and test for testing. Adjust the decision boundary.
It can also be explanatory, the model does not need retraining only for a decision boundary adjustment.
"""
dataset_name='val' # 'or test
decision_boundary = 0.8
class_weights = torch.tensor([0.003, 0.997]).to(device)
weighted_criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
criterion = VPLoss(weighted_criterion, 0.1)
accuracy, loss, binary_labels, multiple_labels, predicted_classes, predicted_probabilities = test(model, dataloaders[f'{dataset_name}_loader'], criterion, decision_boundary)
# compute_common_metrics(binary_labels, predicted_classes)
# compute_merged_metrics(binary_labels, predicted_classes)
# create_decision_ceratinty_boxplots(binary_labels, multiple_labels, predicted_classes, predicted_probabilities)

Label 0:
  TP: 0.0
  FN: 0.0
  FP: 15258.0
  TN: 387548.0

Label 1:
  TP: 259.0
  FN: 0.0
  FP: 0.0
  TN: 0.0

Label 2:
  TP: 255.0
  FN: 0.0
  FP: 0.0
  TN: 0.0

Label 3:
  TP: 236.0
  FN: 17.0
  FP: 0.0
  TN: 0.0

Accuracy: 96.22%, loss: 10.2869
Weighted Balanced Accuracy: 0.9700


In [6]:
"""
Get samples and timestamps from the dataloader.
"""

sample_windows = []
for data in dataloaders[f'{dataset_name}_loader']:
    x, labels, multiple = data
    sample_windows.append(x.cpu())
sample_windows = torch.cat(sample_windows).squeeze(1)
timestamp_windows = dataloaders[f'{dataset_name}_timestamps']

In [8]:
"""
Retrieve information for the output plot.
"""

def reconstruct_original_sequence(overlapping_windows, window_size, overlapping):
    """
    Transforms back the overlapping windows into a one dimensional sequence.
    """
    overlapping_windows = np.asarray(overlapping_windows)
    num_windows = len(overlapping_windows)
    stride = window_size - overlapping
    original_length = (num_windows - 1) * stride + window_size

    reconstructed_sequence = np.zeros(original_length)
    for i in range(num_windows):
        start_index = i * stride
        end_index = start_index + window_size
        reconstructed_sequence[start_index:end_index] = overlapping_windows[i]
    reconstructed_sequence = np.nan_to_num(reconstructed_sequence)
    return reconstructed_sequence


original_samples_seq = reconstruct_original_sequence(sample_windows, window_size, overlapping_size)
original_timestamps_seq = reconstruct_original_sequence(timestamp_windows, window_size, overlapping_size)

start_ts = original_timestamps_seq[0]
end_ts = original_timestamps_seq[-1]
start_index = MNG_dataloader.all_spikes_df['ts'].searchsorted(start_ts, side='left')
end_index = MNG_dataloader.all_spikes_df['ts'].searchsorted(end_ts, side='right')
ground_truth_spikes_in_dataset = MNG_dataloader.all_spikes_df.iloc[start_index:end_index]

In [9]:
def transform_predictions_to_sequence(overlapping_windows, window_size, overlapping):
    """
    Transforms the overlapping windows with prediction probabilities back into a one dimensional sequence.
    Every datapoint got different probabilities from the different windows it was present in.
    The sequence contains for every datapoint the average probs of all the probs it got in the repetitive windows.
    """
    predicted_probs_np = np.asarray(predicted_probabilities[:, 1])
    overlapping_windows = np.asarray(overlapping_windows)
    num_windows = len(overlapping_windows)
    stride = window_size - overlapping
    original_length = (num_windows - 1) * stride + window_size

    transformed_sequence = np.zeros(original_length)
    count_array = np.zeros(original_length)
    for i in range(num_windows):
        start_index = i * stride
        end_index = start_index + window_size
        transformed_sequence[start_index:end_index] += predicted_probs_np[i]
        count_array[start_index:end_index] += 1
    non_zero_count_mask = count_array > 0
    transformed_sequence[non_zero_count_mask] /= count_array[non_zero_count_mask] #avg of the predictions one datapoint got from all windows it was present in
    transformed_sequence = np.nan_to_num(transformed_sequence)
    return transformed_sequence

prediction_sequence = transform_predictions_to_sequence(timestamp_windows, 15, 11)

In [10]:
def plot_output_windows():
    """
    sample_size : how many datapoints should be included in one plot.
    plotting_start_idx = from where the plotting should start.
    num_of_plots = the number of plots the code should generate after each other.
    """
    sample_size = 3000
    plotting_start_idx = 181200
    num_of_plots = 2

    # probability bins for the coloring and legend
    prob_bin_colors = ['#C0C0C0', 'khaki', '#FFA500', '#FF0000']
    bin_limits = [0, 50, 80, 98, 100]
    percentage_labels = ['0-50%', '50-80%', '80-98%', '98-100%']
    legend_elements = [Line2D([0], [0], marker='o', color='w', label=label, markerfacecolor=prob_bin_colors[i], markersize=10)
                       for i, label in enumerate(percentage_labels)]

    # spikes for legend
    unique_tracks = MNG_dataloader.track_replacement_dict.values()
    marker_styles = ['s', '^', 'D', 'p', '*', 'h', 'H', '+', 'x', '|', '_', 'v', '<', '>', '8', 'P', 'X']
    track_markers = {track: marker_styles[i % len(marker_styles)] for i, track in enumerate(unique_tracks)}
    spike_labels = [f'AP {track-1}' if track != 1 else 'Stimulus' for track in unique_tracks]

    track_legend_elements = [
        Line2D([0], [0], color='black', marker=track_markers[track_num], linestyle='None', markersize=10, label=label)
        for track_num, label in zip(unique_tracks, spike_labels)
    ]

    for start_index in range(plotting_start_idx, plotting_start_idx+sample_size*num_of_plots, sample_size):
        end_index = start_index + sample_size
        timesamps_to_plot_np = np.asarray(original_timestamps_seq[start_index:end_index])
        samples_to_plot_np = original_samples_seq[start_index:end_index]
        probs_to_plot_np = prediction_sequence[start_index:end_index]

        plt.figure(figsize=(20, 6))

        # Define color bins based on val_probabilities_class1
        probability_map = np.digitize(probs_to_plot_np * 100, bins=bin_limits) - 1
        probability_map = np.clip(probability_map, 0, len(prob_bin_colors) - 1)

        ts_sample_points = np.array([timesamps_to_plot_np, samples_to_plot_np]).T.reshape(-1, 1, 2)
        segments = np.concatenate([ts_sample_points[:-1], ts_sample_points[1:]], axis=1)
        line_colors = [prob_bin_colors[probability_map[i]] for i in range(len(probability_map) - 1)]
        lc = LineCollection(segments, colors=line_colors, linewidths=2, alpha=0.6)
        ax = plt.gca()
        ax.add_collection(lc)

        ground_truth_spikes_to_plot = ground_truth_spikes_in_dataset[(ground_truth_spikes_in_dataset['ts'] >= timesamps_to_plot_np.min()) & 
                                              (ground_truth_spikes_in_dataset['ts'] <= timesamps_to_plot_np.max())]

        if ground_truth_spikes_to_plot.empty:
            print(f"No ground truth spikes present between timestamps {timesamps_to_plot_np.min()} - {timesamps_to_plot_np.max()}")
            plt.plot()

        y_top = samples_to_plot_np.max() * 1.1
        y_bottom = samples_to_plot_np.min() * 1.1

        track_marker_map = {track: marker_styles[i % len(marker_styles)] for i, track in enumerate(unique_tracks)}
        for index, spike_row in ground_truth_spikes_to_plot.iterrows():
            color = 'black'
            marker = track_marker_map.get(int(spike_row['track']), 'o')

            # vertical marks for the spikes
            plt.scatter(spike_row['ts'], y_top, color=color, marker=marker, s=120)
            plt.scatter(spike_row['ts'], y_bottom, color=color, marker=marker, s=120, label=spike_row['track'])
            plt.axvline(x=spike_row['ts'], color=color, linestyle='--', lw=0.5)
        
        y_min_all = original_samples_seq.min() * 1.2
        y_max_all = original_samples_seq.max() * 1.2
        plt.ylim(y_min_all, y_max_all)
        plt.xlim(timesamps_to_plot_np.min(), timesamps_to_plot_np.max())
        plt.legend(handles=legend_elements + track_legend_elements, title="Probability and AP", bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=14, title_fontsize=16)
        plt.xlabel('Timestamp', fontsize=14)
        plt.ylabel('Amplitude', fontsize=14)
        plt.grid(axis='y', linestyle=':', linewidth=1)
        plt.gca().xaxis.set_major_locator(mticker.MaxNLocator(nbins=50))
        plt.gca().xaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{x:.4f}'))
        plt.xticks(rotation=45, ha='right', fontsize=12)
        plt.yticks(fontsize=12)
        plt.tight_layout()
        plt.grid(False)
        plt.show()
        plt.close()

plot_output_windows()


No ground truth spikes present between timestamps 2068.1215416 - 2068.4214416
No ground truth spikes present between timestamps 2068.4215416 - 2068.7214416
