# Temporal TCN Model

For example testing, run the sections "Model Definition" and "Testing".

## Dataset Preparation

In [17]:
import torch
import numpy as np
from torch.utils.data import DataLoader
import utilities, configuration
from sklearn.preprocessing import RobustScaler
from pickle import dump

# Group the dataframe by 'walk_id' so every walk is windowed independently
groups = configuration.df.groupby("walk_id")

# ===========================================================================
# Cut every walk into windows and collect flat per-frame inputs, labels, IDs

# Per-frame lists. Each window contributes its rows one by one (via extend);
# rows are assumed 1-D, i.e. sliding_window yields (window, features) arrays
# — TODO confirm against utilities.sliding_window.
data_sequences = []
data_labels = []
data_ids = []  # the walk's `id` repeated once per frame of each label window

dataset_number = 0  # number of walks long enough to be windowed
for name, group in groups:
    data_temp = group[["acc_x", "acc_y", "acc_z"]].to_numpy()  # Input data
    # copy=True: the remapping below writes in place, and without a copy the
    # array may be a view into the DataFrame (read-only under pandas
    # Copy-on-Write, which would make the assignments raise).
    event_data = group["events"].to_numpy(copy=True)  # Target labels

    event_data[event_data == 2] = 1  # events 1,2 -> IC (class 1)
    event_data[event_data == 3] = 2  # events 3,4 -> FC (class 2)
    event_data[event_data == 4] = 2

    # One-hot targets per frame: column 0 = IC, column 1 = FC -> (frames, 2)
    data_labels_temp = np.empty([2, len(event_data)])
    data_labels_temp[0] = np.where(event_data == 1, 1, 0)
    data_labels_temp[1] = np.where(event_data == 2, 1, 0)
    data_labels_temp = data_labels_temp.transpose(1, 0)

    # take data in account if longer than window_size
    if len(data_labels_temp) > configuration.window_size:

        # Apply sliding window
        seq, lbl = utilities.sliding_window(data_temp, data_labels_temp, configuration.window_size, configuration.stride)

        for seq_window in seq:
            data_sequences.extend(torch.tensor(seq_window, dtype=torch.float32))

        walk_id_value = group["id"][group.first_valid_index()]
        for lbl_window in lbl:
            data_labels.extend(torch.tensor(lbl_window, dtype=torch.float32))
            data_ids.extend([walk_id_value] * len(lbl_window))

        dataset_number += 1

# ===========================================================================
# Stack the per-frame rows into 2-D tensors
data_sequences = torch.vstack(data_sequences)  # shape (frames, 3)
data_labels = torch.vstack(data_labels)  # shape (frames, 2) -> [IC, FC]

# Split Datasets: rows are grouped by data_ids (presumably a group-aware
# splitter, so one id never lands in both train and test — confirm in
# configuration.group_split)
train_index, test_index = next(configuration.group_split.split(data_sequences, data_labels, data_ids))

# ===========================================================================
# Prepare training datasets
X = data_sequences[train_index]
y = data_labels[train_index]
ids = [data_ids[i] for i in train_index]

# ===========================================================================
# Setup scaler only with training data; no scaling for class labels

# Convert the PyTorch tensor to a NumPy array
x_np = X.numpy()

# X is already 2-D (frames, features); the reshape only makes that explicit
n_samples, n_features = x_np.shape
x_np_reshaped = x_np.reshape(-1, n_features)

# Fit the RobustScaler on the training frames only (no test statistics leak in)
scaler_data = RobustScaler().fit(x_np_reshaped)

# Save scaler parameters. The folder may not exist yet on a fresh checkout,
# and the `with` block guarantees the file is flushed and closed.
import os

os.makedirs('checkpoints', exist_ok=True)
with open('checkpoints/scaler_temporal_input.pkl', 'wb') as scaler_file:
    dump(scaler_data, scaler_file)

x_scaled_np_reshaped = scaler_data.transform(x_np_reshaped)

# Reshape the array back to its original shape
x_scaled_np = x_scaled_np_reshaped.reshape(n_samples, n_features)

# Convert the scaled array back to a PyTorch tensor
X = torch.from_numpy(x_scaled_np)

# Create Fullwalk Datasets for unseen Testing
# ===========================================================================
# One entry per complete (un-windowed) test walk
data_sequences = []
data_labels = []
test_ids_fullwalk = []
walk_ids_fullwalk = []
walk_speed = []

ids_test = [data_ids[i] for i in test_index]

# Keep only rows whose `id` belongs to the test split
df_test = configuration.df[configuration.df["id"].isin(ids_test)]

# Group by "walk_id"
df_test_groups = df_test.groupby("walk_id")

dataset_number = 0
for name, group in df_test_groups:
    data_temp = group[["acc_x", "acc_y", "acc_z"]].to_numpy()  # Input data
    # copy=True: the in-place remapping must not write through a (possibly
    # read-only, under pandas Copy-on-Write) view of the DataFrame
    event_data = group["events"].to_numpy(copy=True)  # Target labels

    event_data[event_data == 2] = 1  # events 1,2 -> IC (class 1)
    event_data[event_data == 3] = 2  # events 3,4 -> FC (class 2)
    event_data[event_data == 4] = 2

    # One-hot targets per frame: column 0 = IC, column 1 = FC -> (frames, 2)
    data_labels_temp = np.empty([2, len(event_data)])
    data_labels_temp[0] = np.where(event_data == 1, 1, 0)
    data_labels_temp[1] = np.where(event_data == 2, 1, 0)
    data_labels_temp = data_labels_temp.transpose(1, 0)

    # Inputs are scaled with the scaler fitted on the training split only
    data_sequences.append(scaler_data.transform(torch.tensor(data_temp, dtype=torch.float32)))
    data_labels.append(torch.tensor(data_labels_temp, dtype=torch.float32))
    test_ids_fullwalk.extend([group["id"][group.first_valid_index()]])
    walk_ids_fullwalk.append(group["walk_id"][group.first_valid_index()])
    walk_speed.extend([group["speed"][group.first_valid_index()]])

# Prepare Test Dataloader
test_dataset_fullwalk = utilities.CustomDataset(data_sequences, data_labels)

# Keep up to three example walks (fewer if the test split is smaller)
n_fullwalk_examples = min(3, len(test_dataset_fullwalk))
testing_fullwalk_examples = [test_dataset_fullwalk[i] for i in range(n_fullwalk_examples)]

# Dump the examples into a pickle file; create the folder if it is missing
import os

os.makedirs('example_data', exist_ok=True)
with open('example_data/testing_fullwalk_examples_temporal.pkl', 'wb') as f:
    dump(testing_fullwalk_examples, f)

# Default batch_size=1: one batch per complete walk
test_loader_fullwalk = DataLoader(test_dataset_fullwalk)

## Model Definition

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import configuration
from torchsummary import summary
from torchview import draw_graph
import graphviz
# Make graphviz produce SVG, both for inline Jupyter display and as the
# default format when graphs are rendered/saved
graphviz.set_jupyter_format('svg')
graphviz.set_default_format('svg')

# Define a block of the Temporal Convolutional Network (TCN)
class TCNBlock(nn.Module):
    """One TCN stage: dilated Conv1d -> BatchNorm1d -> ReLU -> Dropout.

    Operates on Conv1d layout (batch, in_channels, seq_len) and returns
    (batch, out_channels, seq_len); the sequence length is preserved for any
    kernel size / dilation.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, dropout=0.2):
        super().__init__()
        # padding="same" keeps seq_len unchanged even when the total padding
        # dilation * (kernel_size - 1) is odd (e.g. kernel_size=4, dilation=1),
        # where the former symmetric (kernel_size - 1) * dilation // 2 padding
        # shortened the output. When the total is even (always the case for odd
        # kernel sizes) it is exactly that symmetric padding, so outputs and
        # saved checkpoints are unchanged.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding="same", dilation=dilation)
        # Normalise conv activations per channel
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        # Regularisation against overfitting
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU -> dropout to x (batch, channels, seq_len)."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        return self.dropout(out)

# Define the overall TCN structure
class TCN(nn.Module):
    """Stack of TCNBlocks, one per entry of ``dilation_list``.

    Takes and returns tensors laid out as (batch, seq_len, channels); the
    permutes to and from Conv1d's (batch, channels, seq_len) happen inside.
    """

    def __init__(self, input_dim, kernel_size=3, dilation_list=(1,), dropout=0.2, out_channels=16):
        # Tuple default: an immutable default cannot be shared/mutated across calls
        super().__init__()
        layers = []
        channels = input_dim
        # The first block maps input_dim -> out_channels; later blocks keep out_channels
        for dilation in dilation_list:
            layers.append(TCNBlock(channels, out_channels, kernel_size, dilation, dropout))
            channels = out_channels
        # Combine all TCN blocks into a sequential network
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Run x (batch, seq_len, channels) through all blocks; same layout out."""
        x = x.permute(0, 2, 1)  # -> (batch, channels, seq_len) for Conv1d
        x = self.network(x)
        return x.permute(0, 2, 1)  # back to (batch, seq_len, channels)

# Define the complete model structure
class TemporalTCNModel(nn.Module):
    """TCN feature extractor followed by a per-frame 1x1-conv classifier.

    forward() maps (batch, seq_len, input_dim) to raw logits of shape
    (batch, seq_len, num_classes). No sigmoid is applied in forward(): the
    training loss in this notebook is BCEWithLogitsLoss, and the evaluation
    code applies torch.sigmoid itself.
    """

    def __init__(self, input_dim, num_classes, kernel_size=5, dilation_list=(1, 2, 4), dropout=0.2, NB_FILTERS=16):
        # Tuple default: an immutable default cannot be shared/mutated across calls
        super().__init__()
        # Create the TCN part of the model
        self.tcn = TCN(input_dim=input_dim, kernel_size=kernel_size, dilation_list=dilation_list, dropout=dropout, out_channels=NB_FILTERS)
        # 1x1 convolution = independent linear classifier per frame
        self.classifier = nn.Conv1d(NB_FILTERS, num_classes, kernel_size=1)
        # Not used in forward() (outputs are logits); kept so existing code
        # that references model.sigmoid keeps working
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-frame logits (batch, seq_len, num_classes) for x (batch, seq_len, input_dim)."""
        x = self.tcn(x)
        x = x.permute(0, 2, 1)  # -> (batch, channels, seq_len) for the classifier
        x = self.classifier(x)
        return x.permute(0, 2, 1)  # back to (batch, seq_len, num_classes)

# Function to create and return the model, moved to the given device
def get_model(device, input_dim=3, num_classes=2, kernel_size=5, dilation_list=(1, 2, 4), dropout=0.2):
    """Build a TemporalTCNModel (16 filters per TCN block) and move it to ``device``.

    ``dilation_list`` defaults to an immutable tuple so the default cannot be
    shared or mutated between calls; any iterable of ints is accepted.
    """
    model = TemporalTCNModel(input_dim, num_classes, kernel_size, dilation_list, dropout, NB_FILTERS=16)
    return model.to(device)

# Get the model and move it to the specified device
model = get_model(configuration.device)

# Disabled debugging snippet: it sits inside a string literal and never runs.
# When enabled, it moves the model to the CPU, prints a torchsummary layer summary,
# and saves a torchview graph ('model_temporal_graph') for the first batch
# of test_loader_fullwalk.
""" for inputs, labels in test_loader_fullwalk:
        summary(model.to(torch.device("cpu")),inputs.to(torch.device("cpu"),dtype=torch.float32), depth=10)
        model_graph = draw_graph(model.to(torch.device("cpu")) , inputs.to(torch.device("cpu"),dtype=torch.float32), device='meta', save_graph=True, filename='model_temporal_graph', depth=10)
        break """


## Model Training

In [None]:
from torch.utils.data import DataLoader
import os
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.autonotebook import tqdm
import itertools
import configuration
import utilities  # Assuming utilities is a module containing required functions

%config InlineBackend.figure_formats = ['svg']

def custom_collate_fn(batch):
    """
    Collate a batch by splitting each item into its first two elements.

    Returns two parallel lists (sequences, labels). Nothing is stacked, so
    items of different lengths can share a batch.
    """
    sequences, labels = [], []
    for item in batch:
        sequences.append(item[0])
        labels.append(item[1])
    return sequences, labels


def _reset_model_parameters(module):
    """Re-initialise every sub-module of ``module`` that defines reset_parameters()."""
    for submodule in module.modules():
        if hasattr(submodule, "reset_parameters"):
            submodule.reset_parameters()


def _split_into_windows(frames, index):
    """Select rows ``index`` of a per-frame tensor and cut them into windows.

    Equivalent to stacking ``[frames[i] for i in index]`` in consecutive
    chunks of ``configuration.window_size`` rows (the last chunk may be
    shorter), but done with one indexing op instead of a per-frame Python loop.
    """
    if len(index) == 0:
        return []
    selected = frames[torch.as_tensor(index, dtype=torch.long)]
    return list(torch.split(selected, configuration.window_size))


def train_model(batch_size, learning_rate, epochs):
    """
    Train the global ``model`` with the given hyperparameters and save it.

    The model's parameters (and BatchNorm running statistics) are
    re-initialised at the start of every call, so each configuration of a
    hyperparameter search starts from scratch instead of continuing from the
    weights left by the previous call.

    Args:
    batch_size (int): Size of the batches used for training.
    learning_rate (float): Learning rate for the optimizer.
    epochs (int): Number of epochs to train the model.

    Returns:
    float: Validation loss averaged over folds within each epoch, then over epochs.
    str: Filename where the model is saved.
    """
    _reset_model_parameters(model)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # pos_weight=11 up-weights positive (event) frames; moved with the module
    # so the buffer lives on the same device as the logits
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(11.0)).to(configuration.device)
    val_loss_overall = 0

    for epoch in tqdm(range(epochs), desc="Epochs", leave=False):
        split_count = 0
        val_loss_epoch = 0

        # NOTE(review): the same model is trained on every fold in turn, so a
        # fold's validation frames were training frames in earlier folds and
        # epochs; the reported validation loss is therefore optimistic. A clean
        # cross-validation would re-initialise the model per fold.
        for train_index, val_index in configuration.group_kfold.split(X, y, ids):
            split_count += 1

            # Regroup the flat per-frame rows into window_size-long sequences
            X_train = _split_into_windows(X, train_index)
            X_val = _split_into_windows(X, val_index)
            y_train = _split_into_windows(y, train_index)
            y_val = _split_into_windows(y, val_index)

            train_dataset = utilities.CustomDataset(X_train, y_train)
            val_dataset = utilities.CustomDataset(X_val, y_val)

            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=batch_size)

            # Training pass
            model.train()
            total_train_loss_split = 0
            for sequences_batch, labels_batch in train_loader:
                sequences_batch = sequences_batch.to(configuration.device)
                labels_batch = labels_batch.to(configuration.device)

                optimizer.zero_grad()
                output = model(sequences_batch)
                loss = loss_fn(output, labels_batch)
                loss.backward()
                optimizer.step()

                total_train_loss_split += loss.item()

            train_loss_split = total_train_loss_split / len(train_loader)

            # Validation pass
            model.eval()
            total_val_loss_split = 0
            with torch.no_grad():
                for sequences_batch, labels_batch in val_loader:
                    sequences_batch = sequences_batch.to(configuration.device)
                    labels_batch = labels_batch.to(configuration.device)

                    output = model(sequences_batch)
                    loss = loss_fn(output, labels_batch)
                    total_val_loss_split += loss.item()

            avg_val_loss_split = total_val_loss_split / len(val_loader)
            val_loss_epoch += avg_val_loss_split

            print(
                f"Epoch {epoch + 1}/{epochs} Split {split_count}, Training Loss: {train_loss_split}, Validation Loss: {avg_val_loss_split}, Difference: {train_loss_split - avg_val_loss_split}"
            )

        val_loss_epoch /= split_count
        val_loss_overall += val_loss_epoch

        print(f"Epoch {epoch + 1}/{epochs}, Mean Validation Loss: {val_loss_epoch}")

    val_loss_overall /= epochs

    # Save the model with a filename unique to this configuration
    os.makedirs("checkpoints", exist_ok=True)
    model_filename = (
        f"checkpoints/model_temporal_bs{batch_size}_lr{learning_rate}_epochs{epochs}.pt"
    )
    torch.save(model.state_dict(), model_filename)
    print(f"Model saved as {model_filename}")

    return val_loss_overall, model_filename


# Hyperparameter grid
batch_size_options = [4, 8, 12]
learning_rate_options = [1e-3, 1e-4, 1e-5]
epochs_options = [10, 30, 75]

# One record per trained configuration: its hyperparameters, the mean
# validation loss returned by train_model ("performance") and the checkpoint path
results = []

# Exhaustive search over every (batch size, learning rate, epochs) combination
hyperparameter_grid = itertools.product(batch_size_options, learning_rate_options, epochs_options)
for batch_size, learning_rate, epochs in hyperparameter_grid:
    print(f"Training model with batch size={batch_size}, learning rate={learning_rate}, epochs={epochs}")

    model_performance, model_filename = train_model(batch_size, learning_rate, epochs)
    run_record = dict(
        batch_size=batch_size,
        learning_rate=learning_rate,
        epochs=epochs,
        performance=model_performance,
        filename=model_filename,
    )
    results.append(run_record)
    print(results)

# Lower mean validation loss is better
best_model = min(results, key=lambda record: record["performance"])
print("Best model parameters and performance:", best_model)
print(results)

# Optional: Train with specific parameters if required
# val_loss_overall, model_filename = train_model(12, 0.001, 30)

## Testing

In [None]:
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
import torch
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import configuration
import pandas as pd
import pickle
from torch.utils.data import DataLoader

%config InlineBackend.figure_formats = ['svg']

# Load the example walks written by the dataset-preparation cell
# (only for testing not for performance evaluation).
# pickle.load can execute arbitrary code: only load files you produced yourself.
with open('example_data/testing_fullwalk_examples_temporal.pkl', 'rb') as f:
    test_loader_fullwalk_examples = pickle.load(f)

# Wrap the examples in a DataLoader (default batch_size=1). NOTE: this
# replaces any test_loader_fullwalk built by the dataset-preparation cell, so
# cells run afterwards that iterate test_loader_fullwalk see only these examples.
test_loader_fullwalk = DataLoader(test_loader_fullwalk_examples)

# Placeholder IDs shown in the status label for the example walks
test_ids_fullwalk = [[1,1],[2,2],[3,3]]

# Define model path manually (optional).
# NOTE(review): train_model in this notebook writes
# model_temporal_bs{..}_lr{..}_epochs{..}.pt (without "_freezing"); presumably
# this checkpoint was produced separately — confirm it exists.
model_filename = 'checkpoints/model_temporal_bs12_lr0.001_epochs30_freezing.pt'

# Load the checkpoint onto the configured device. map_location lets a
# checkpoint saved from a GPU run load on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load(model_filename, map_location=configuration.device))
model.eval()

# State shared by the interactive viewer below
data_iterator = iter(test_loader_fullwalk)
current_index = -1  # advanced to 0 by the first load_and_plot(1)
total_batches = len(test_loader_fullwalk)  # one batch per walk (DataLoader default batch_size=1)
color_ic_fc = ["black", "orange"]  # plot colours: index 0 = IC, index 1 = FC
break_by_error = False  # set to True by select_gt_pd when a search window yields >1 peak
window_first_ic_matching = 20  # NOTE(review): not referenced in the visible code

def select_gt_pd(logits_torch, true_steps):
    """Select ground truth and predicted steps (IC and FC) from logits and true steps.

    Predicted events are first found independently per class ("rough" peaks
    of the sigmoid output). They are then refined into an alternating
    IC -> FC -> IC ... sequence: each FC is searched between the current IC
    and the next suitable rough IC, and each IC between that FC and the next
    suitable rough FC.

    Args:
        logits_torch: raw model outputs indexed as (frames, classes), with
            class 0 = IC and class 1 = FC.
        true_steps: per-frame targets indexed as (frames, 2); a 1 marks an
            IC (column 0) or FC (column 1) frame.

    Returns:
        Tuple ``(ic_pd_rough, fc_pd_rough, ic_pd_selected, fc_pd_selected,
        ic_gt_selected, fc_gt_selected)`` of frame indices. The two selected
        prediction lists are empty when there is no rough IC or no GT IC to
        anchor the alignment on.

    Side effects:
        Sets the global ``break_by_error`` when a refinement window yields
        more than one peak.
    """
    global break_by_error

    # Transform logits through sigmoid
    sigmoid_logits = torch.sigmoid(logits_torch)

    # Rough estimate peak finding of predicted values IC / FC
    rough_peaks = {0: [], 1: []}
    for class_idx in range(sigmoid_logits.shape[1]):
        peaks, _ = find_peaks(
            sigmoid_logits[:, class_idx],
            distance=configuration.rough_estimate_min_peak_distance,
        )
        rough_peaks[class_idx] = peaks

    ic_pd_rough = rough_peaks[0]
    fc_pd_rough = rough_peaks[1]

    # Delete FCs before the first IC (length checks prevent IndexError when a list runs out)
    while len(fc_pd_rough) > 0 and len(ic_pd_rough) > 0 and fc_pd_rough[0] < ic_pd_rough[0]:
        fc_pd_rough = fc_pd_rough[1:]

    # GT Calculations
    ic_gt_selected = np.where(true_steps[:, 0] == 1)[0]
    fc_gt_selected = np.where(true_steps[:, 1] == 1)[0]

    # Remove last FC from GT when after last IC
    if len(ic_gt_selected) > 0 and len(fc_gt_selected) > 0 and fc_gt_selected[-1] > ic_gt_selected[-1]:
        fc_gt_selected = fc_gt_selected[:-1]

    # PD Calculations
    fc_pd_selected = []
    ic_pd_selected = []

    # Nothing to anchor the alignment on: return rough peaks and GT only
    if len(ic_pd_rough) == 0 or len(ic_gt_selected) == 0:
        return (
            ic_pd_rough,
            fc_pd_rough,
            ic_pd_selected,
            fc_pd_selected,
            ic_gt_selected,
            fc_gt_selected,
        )

    ic_pd_rough_copy = ic_pd_rough.copy()
    fc_pd_rough_copy = fc_pd_rough.copy()

    # Find first best matching IC: advance while the next rough IC is at
    # least as close to the first GT IC. Bounded by the last candidate (an
    # unbounded search raised IndexError when the last rough IC was the best match).
    first_ic_gt = ic_gt_selected[0]
    i_first_selected = 0
    while i_first_selected + 1 < len(ic_pd_rough_copy) and abs(
        ic_pd_rough_copy[i_first_selected] - first_ic_gt
    ) >= abs(ic_pd_rough_copy[i_first_selected + 1] - first_ic_gt):
        i_first_selected += 1
    ic_pd_selected.append(ic_pd_rough_copy[i_first_selected])

    # Trim rough to the first matching IC; it starts the first search window
    ic_pd_rough_copy = ic_pd_rough_copy[i_first_selected:]
    ic_range_start = ic_pd_rough_copy[0]
    ic_pd_rough_copy = ic_pd_rough_copy[1:]

    # Each iteration consumes at least one rough IC, so the loop terminates
    while len(ic_pd_rough_copy) > 0 and len(fc_pd_rough_copy) > 0:
        # End of the FC search window: first rough IC after both the current
        # IC and the next rough FC (or the end of the walk)
        ic_range_stop = None
        while ic_range_stop is None:
            if len(ic_pd_rough_copy) == 0:
                ic_range_stop = len(logits_torch)
                break

            ic_range_stop_tmp = ic_pd_rough_copy[0]
            if ic_range_stop_tmp > ic_range_start and ic_range_stop_tmp > fc_pd_rough_copy[0]:
                ic_range_stop = ic_range_stop_tmp
                break

            ic_pd_rough_copy = ic_pd_rough_copy[1:]  # like pop()

        # Peak for FC in range of ic_range_start / ic_range_stop
        # (distance = window length keeps at most one peak)
        peak_fc_in_icrange, _ = find_peaks(
            torch.sigmoid(logits_torch[ic_range_start:ic_range_stop, 1]),
            distance=ic_range_stop - ic_range_start,
            height=0.1,
        )

        if len(peak_fc_in_icrange) > 1:
            break_by_error = True
            print("handle more than one PEAK for IC")
        elif len(peak_fc_in_icrange) == 1:
            peak_fc_in_icrange = ic_range_start + peak_fc_in_icrange[0]
        elif len(peak_fc_in_icrange) == 0:
            peak_fc_in_icrange = fc_pd_rough_copy[0]  # Fallback to rough estimate if not found

        fc_pd_selected.append(peak_fc_in_icrange)

        # End of the IC search window: first rough FC after the selected FC and
        # after the next rough IC. With no rough IC left, any later FC qualifies
        # (indexing the empty IC list used to raise IndexError here).
        fc_range_start = peak_fc_in_icrange
        fc_range_stop = None
        while fc_range_stop is None:
            if len(fc_pd_rough_copy) == 0:
                fc_range_stop = len(logits_torch)
                break

            fc_range_stop_tmp = fc_pd_rough_copy[0]
            if fc_range_stop_tmp > fc_range_start and (
                len(ic_pd_rough_copy) == 0 or fc_range_stop_tmp > ic_pd_rough_copy[0]
            ):
                fc_range_stop = fc_range_stop_tmp
                break

            fc_pd_rough_copy = fc_pd_rough_copy[1:]  # like pop()

        peak_ic_in_fcrange, _ = find_peaks(
            torch.sigmoid(logits_torch[fc_range_start:fc_range_stop, 0]),
            distance=fc_range_stop - fc_range_start,
            height=0.1,
        )

        if len(peak_ic_in_fcrange) > 1:
            break_by_error = True
            print("handle more than one PEAK for FC")
        elif len(peak_ic_in_fcrange) == 1:
            peak_ic_in_fcrange = fc_range_start + peak_ic_in_fcrange[0]
        elif len(peak_ic_in_fcrange) == 0:
            try:
                peak_ic_in_fcrange = ic_pd_rough_copy[0]  # Fallback to rough estimate
            except IndexError:
                # No rough IC left to fall back on: stop refining
                break

        ic_pd_rough_copy = ic_pd_rough_copy[1:]  # like pop
        ic_pd_selected.append(peak_ic_in_fcrange)

        # The selected IC starts the next cycle's FC search window
        ic_range_start = peak_ic_in_fcrange

    # Return results of IC and FC calculations
    return (
        ic_pd_rough,
        fc_pd_rough,
        ic_pd_selected,
        fc_pd_selected,
        ic_gt_selected,
        fc_gt_selected,
    )

def update_status_label(current_index, total_batches):
    """Show the 1-based batch position and the ID of the walk being displayed."""
    position = current_index + 1
    walk_id = test_ids_fullwalk[current_index]
    status_label.value = f"Batch: {position}/{total_batches} ID:{walk_id}"

def plot_data(sample_sequence, true_steps, logits):
    """Plot one walk: accelerometer input, gait events and model outputs.

    Args:
        sample_sequence: model input for one walk; after ``squeeze`` it is
            indexed as (frames, 3) accelerometer axes (assumes batch size 1).
        true_steps: per-frame targets; after ``squeeze`` indexed as
            (frames, 2) with column 0 = IC and column 1 = FC.
        logits: raw model output for the same walk.

    Figure rows: 0 accelerometer data; 1 ground truth (solid lines) plus the
    selected GT ("v") and predicted ("^") events; 2 raw logits; 3 sigmoid
    probabilities. Dotted vertical lines on every row mark rough predicted peaks.
    """
    # CPU copies: a torch tensor for select_gt_pd / sigmoid, a NumPy array for plotting
    logits_torch = torch.squeeze(logits).cpu()
    logits = np.array(logits_torch)
    true_steps = true_steps.cpu().numpy().squeeze()
    accelerometer_data = sample_sequence.cpu().numpy().squeeze()

    fig, axes = plt.subplots(4, 1, figsize=(12, 10))
    fig.suptitle("Example Step Detection", fontsize=16)

    if logits.ndim == 1:
        logits = logits.reshape(-1, 1)
    time_steps = np.arange(logits.shape[0])

    for i, axis in enumerate(["X", "Y", "Z"]):
        axes[0].plot(time_steps, accelerometer_data[:, i], label=f"{axis}-axis")
    axes[0].set_title("Accelerometer Data")
    axes[0].legend()

    (
        ic_pd_rough,
        fc_pd_rough,
        ic_pd_gaitcycle_conform,
        fc_pd_gaitcycle_conform,
        ic_gts_selected,
        fc_gts_selected,
    ) = select_gt_pd(logits_torch, true_steps)

    axes[1].set_title("Step Detection / Ground truth")
    axes[1].set_xlim(axes[0].get_xlim())
    axes[1].set_ylim([0, 1])
    axes[1].set_ylabel("Ground truth")

    # Solid lines: every GT event; "v" markers at the top: GT events kept by select_gt_pd
    ground_truth_rows = (
        (np.where(true_steps[:, 0] == 1)[0], ic_gts_selected, color_ic_fc[0]),
        (np.where(true_steps[:, 1] == 1)[0], fc_gts_selected, color_ic_fc[1]),
    )
    for gt_frames, gt_selected, color in ground_truth_rows:
        for gt_frame in gt_frames:
            axes[1].axvline(x=gt_frame, color=color)
        for peak in gt_selected:
            axes[1].plot(peak, 0.95, "v", markersize=10, color=color, alpha=0.5)

    # Dotted lines on every row: rough (per-class) predicted peaks
    for rough_peaks, color in ((ic_pd_rough, color_ic_fc[0]), (fc_pd_rough, color_ic_fc[1])):
        for peak in rough_peaks:
            for ax in (axes[3], axes[2], axes[1], axes[0]):
                ax.axvline(x=peak, linestyle=":", color=color)

    # "^" markers at the bottom: predictions after gait-cycle-conform selection
    for pd_selected, color in (
        (ic_pd_gaitcycle_conform, color_ic_fc[0]),
        (fc_pd_gaitcycle_conform, color_ic_fc[1]),
    ):
        for peak in pd_selected:
            axes[1].plot(peak, 0.05, "^", markersize=10, color=color, alpha=0.5)

    # Raw (pre-sigmoid) outputs; the probabilities are shown in the row below
    axes[2].plot(time_steps, logits[:, 0], label="IC Logits", color=color_ic_fc[0])
    axes[2].plot(time_steps, logits[:, 1], label="FC Logits", color=color_ic_fc[1])
    axes[2].set_title("Raw Logits")
    axes[2].set_ylabel("Value")
    axes[2].legend()

    for class_idx in range(logits.shape[1]):
        # Transform logits through sigmoid
        axes[3].plot(
            time_steps,
            torch.sigmoid(logits_torch[:, class_idx]),
            label=f"Class {class_idx} Sigmoid",
            color=color_ic_fc[class_idx],
        )
    axes[3].set_title("Logits with Sigmoid Function")
    axes[3].set_xlabel("Frames")
    axes[3].legend()

    plt.tight_layout()
    plt.show()

def load_and_plot(index_change):
    """Move ``index_change`` walks forward (+) or backward (-) and plot that walk.

    The position is kept in the global ``current_index`` and wrapped modulo
    ``total_batches``. "Previous" therefore really shows the preceding walk,
    and the status label never indexes past the available walks. Previously
    the iterator always advanced, and the index grew without bound.
    The batch is fetched by re-iterating ``test_loader_fullwalk`` up to the
    target position, so it goes through the loader's normal collate path.
    """
    global current_index, data_iterator
    import itertools

    if total_batches == 0:
        return  # nothing to show

    current_index = (current_index + index_change) % total_batches

    data_iterator = iter(test_loader_fullwalk)
    sample_sequence, true_steps = next(itertools.islice(data_iterator, current_index, None))

    sample_sequence = sample_sequence.to(configuration.device, dtype=torch.float32)
    true_steps = true_steps.to(configuration.device, dtype=torch.float32)

    model.eval()
    with torch.no_grad():
        logits = model(sample_sequence)

    # Redraw: clear the previous figure, then re-show label, plot and buttons
    clear_output(wait=True)
    update_status_label(current_index, total_batches)
    plot_data(sample_sequence, true_steps, logits)
    display(widgets.VBox([status_label, buttons]))

# Navigation callbacks (the button object passed by ipywidgets is unused)
def on_prev_button_clicked(b):
    """Navigate backwards by calling load_and_plot(-1)."""
    load_and_plot(-1)


def on_next_button_clicked(b):
    """Navigate forwards by calling load_and_plot(1)."""
    load_and_plot(1)


# Buttons for navigation, each wired to its callback
prev_button = widgets.Button(description="Previous")
next_button = widgets.Button(description="Next")
for nav_button, nav_handler in (
    (prev_button, on_prev_button_clicked),
    (next_button, on_next_button_clicked),
):
    nav_button.on_click(nav_handler)

# Layout: status line above the button row
buttons = widgets.HBox([prev_button, next_button])
status_label = widgets.Label()
display(widgets.VBox([status_label, buttons]))

# Update status label initially
update_status_label(current_index, total_batches)

# Load and plot the first batch
load_and_plot(1)

#for i in range(1, total_batches):
#    load_and_plot(1)
#    if break_by_error:
#       break_by_error = False

## Performance Testing

In [None]:
import custom_statistics


# Metrics such as precision can divide by zero for walks without events; keep
# numpy quiet about divide / invalid floating-point operations.
np.seterr(divide="ignore", invalid="ignore")

model.eval()  # switch the model to evaluation mode

# Per-class containers (0 = IC, 1 = FC); one entry is appended per walk.
class_metrics = {
    class_idx: {
        metric_name: []
        for metric_name in (
            "accuracy",
            "precision",
            "recall",
            "f1_score",
            "misplacement_rel",
        )
    }
    for class_idx in (0, 1)
}

class_counts = {
    class_idx: {count_name: [] for count_name in ("TP", "TN", "FP", "FN")}
    for class_idx in (0, 1)
}

# Ground-truth / predicted gait parameters, keyed as
# gt_pd_storage[parameter][proband]["gt" | "pd"][speed]; filled in the evaluation loop.
gt_pd_storage = {parameter: {} for parameter in ("stride", "swing", "double")}

# Walk ids rejected by the swing / double-support plausibility checks.
todelete = []


def _record_event_metrics(class_idx, predicted_frames, gt_frames, template):
    """Score one walk's event detection for ``class_idx`` (0 = IC, 1 = FC).

    ``predicted_frames`` / ``gt_frames`` are used as indices into 0/1 arrays
    shaped like ``template``. Those arrays are scored with
    ``custom_statistics.calculate_metrics_with_tolerance``, and the metrics and
    confusion counts are appended to the global ``class_metrics`` /
    ``class_counts``.
    """
    predicted_ones = np.zeros_like(template)
    predicted_ones[predicted_frames] = 1
    gt_ones = np.zeros_like(template)
    gt_ones[gt_frames] = 1

    accuracy, precision, recall, f1_score, TP, TN, FP, FN = (
        custom_statistics.calculate_metrics_with_tolerance(predicted_ones, gt_ones)
    )

    metrics = class_metrics[class_idx]
    metrics["accuracy"].append(accuracy)
    metrics["precision"].append(precision)
    metrics["recall"].append(recall)
    metrics["f1_score"].append(f1_score)

    counts = class_counts[class_idx]
    counts["TP"].append(TP)
    counts["TN"].append(TN)
    counts["FP"].append(FP)
    counts["FN"].append(FN)


def _empty_gait_entry():
    """Return an empty per-proband container with "gt" and "pd" sides.

    Each side holds, per speed, a flat value list (e.g. "slow") and a list of
    per-walk arrays (e.g. "slow_walks").
    """
    speeds = ("slow", "regular", "fast")
    keys = list(speeds) + [speed + "_walks" for speed in speeds]
    return {side: {key: [] for key in keys} for side in ("gt", "pd")}


def _store_gait_parameter(storage, proband, speed, values_gt, values_pd):
    """Add one walk's gt / pd values to ``storage[proband]`` for ``speed``.

    The values extend the flat per-speed list and are also appended as one
    array to ``speed + "_walks"``.
    """
    if proband not in storage:
        storage[proband] = _empty_gait_entry()
    for side, values in (("gt", values_gt), ("pd", values_pd)):
        storage[proband][side][speed].extend(values)
        storage[proband][side][speed + "_walks"].append(values)


def _stride_frames(ic_frames):
    """Return IC[n+2] - IC[n] for every n (empty for fewer than 3 ICs)."""
    ic_frames = np.asarray(ic_frames)
    return ic_frames[2:] - ic_frames[:-2]


def _swing_frames(ic_frames, fc_frames):
    """Return IC[n+1] - FC[n] for every IC after the first."""
    next_ic = np.array(ic_frames[1:])
    return next_ic - np.array(fc_frames[: len(next_ic)])


def _double_support_frames(ic_frames, fc_frames):
    """Return FC[n] - IC[n] over the common length of both event lists."""
    length = min(len(fc_frames), len(ic_frames))
    return np.array(fc_frames[:length]) - np.array(ic_frames[:length])


def _percent_of(numerator, denominator):
    """Return ``numerator / denominator * 100`` over the common prefix of both.

    Truncating to the shorter length avoids numpy broadcast errors (or silent
    broadcasting of a length-1 array) when the two series differ in length.
    """
    length = min(len(numerator), len(denominator))
    return (np.asarray(numerator[:length]) / np.asarray(denominator[:length])) * 100


# Created before the loop so these names exist even if the loader yields no
# walks; the statistics cell fills them in.
mean_metrics = {}
total_counts = {}

with torch.no_grad():

    # Number of walks whose IC and FC counts both match the ground truth.
    count_analysing_ic_fc = 0
    # Index of the current loader item; used to look up test_ids_fullwalk,
    # walk_speed and walk_ids_fullwalk.
    count_act_dataset = -1

    for inputs, labels in test_loader_fullwalk:

        count_act_dataset += 1

        # Predict, then move predictions and targets to the CPU for post-processing.
        outputs = model(inputs.to(configuration.device, dtype=torch.float32))
        logits_torch = outputs.cpu().squeeze()
        true_steps = labels.cpu().squeeze()

        # Event frames from select_gt_pd ("rough" vs "selected" is defined there).
        (
            ic_pd_rough,
            fc_pd_rough,
            ic_pd_selected,
            fc_pd_selected,
            ic_gt_selected,
            fc_gt_selected,
        ) = select_gt_pd(logits_torch, true_steps)

        # Event-detection scores use the rough predictions.
        _record_event_metrics(0, ic_pd_rough, ic_gt_selected, true_steps[:, 0])
        _record_event_metrics(1, fc_pd_rough, fc_gt_selected, true_steps[:, 1])

        ic_counts_match = len(ic_pd_selected) == len(ic_gt_selected)
        fc_counts_match = len(fc_pd_selected) == len(fc_gt_selected)

        # Mean signed offset (predicted - ground truth, in frames); only defined
        # when the counts allow an index-wise pairing.
        if ic_counts_match:
            class_metrics[0]["misplacement_rel"].append(
                np.mean(ic_pd_selected - ic_gt_selected)
            )
        if fc_counts_match:
            class_metrics[1]["misplacement_rel"].append(
                np.mean(fc_pd_selected - fc_gt_selected)
            )

        # Gait parameters need both IC and FC to pair one-to-one.
        if not (ic_counts_match and fc_counts_match):
            continue

        count_analysing_ic_fc += 1
        proband = test_ids_fullwalk[count_act_dataset]
        speed = walk_speed[count_act_dataset]

        # Stride time: IC[n+2] - IC[n]; always stored.
        stride_frames_predicted = _stride_frames(ic_pd_selected)
        stride_frames_gt = _stride_frames(ic_gt_selected)
        _store_gait_parameter(
            gt_pd_storage["stride"],
            proband,
            speed,
            stride_frames_gt,
            stride_frames_predicted,
        )

        # Swing time: keep the walk only if every ground-truth swing exceeds
        # 20 % of the stride time at the same index.
        swing_frames_predicted = _swing_frames(ic_pd_selected, fc_pd_selected)
        swing_frames_gt = _swing_frames(ic_gt_selected, fc_gt_selected)
        if len(swing_frames_predicted) == len(swing_frames_gt):
            if all(x > 20 for x in _percent_of(swing_frames_gt[1:], stride_frames_gt)):
                _store_gait_parameter(
                    gt_pd_storage["swing"],
                    proband,
                    speed,
                    swing_frames_gt,
                    swing_frames_predicted,
                )
            else:
                print("to delete")
                todelete.append(walk_ids_fullwalk[count_act_dataset])

        # Double-support time: keep the walk only if every ground-truth value
        # stays below 30 % of the stride time at the same index.
        doublesupport_frames_predicted = _double_support_frames(
            ic_pd_selected, fc_pd_selected
        )
        doublesupport_frames_gt = _double_support_frames(ic_gt_selected, fc_gt_selected)
        if len(doublesupport_frames_predicted) == len(doublesupport_frames_gt):
            if all(
                x < 30
                for x in _percent_of(doublesupport_frames_gt[1:], stride_frames_gt)
            ):
                _store_gait_parameter(
                    gt_pd_storage["double"],
                    proband,
                    speed,
                    doublesupport_frames_gt,
                    doublesupport_frames_predicted,
                )
            else:
                print("to delete")
                todelete.append(walk_ids_fullwalk[count_act_dataset])


"""

Statistics for IC / FC Classes
"""

print("============================================================================")

print("Statistic Results")

print(f"analysed full walks: {count_analysing_ic_fc}")

print("============================================================================")

print("")


for class_number in class_metrics:

    mean_metrics[class_number] = {
        metric: np.mean(values)
        for metric, values in class_metrics[class_number].items()
    }

    total_counts[class_number] = {
        count: np.sum(values) for count, values in class_counts[class_number].items()
    }


for class_number in range(2):

    print(f"CLASS:{class_number}")

    print("{:<10}                 {:<10}".format("Parameter", "Value"))

    print("======================================")

    for key, value in mean_metrics[class_number].items():

        print("{:<10}                 {:<10}".format(key, value))

    for key, value in total_counts[class_number].items():

        print("{:<10}                 {:<10}".format(key, value))

    print("\n\n")


""" 

Statistics for Gaitparameters
"""

print("============================================================================")

print("Gaitparameter")

print("============================================================================")

print("")


statistics = {
    "stride": {
        "mean": {
            "gt": {"slow": [], "regular": [], "fast": []},
            "pd": {"slow": [], "regular": [], "fast": []},
        },
        "cv": {
            "gt": {"slow": [], "regular": [], "fast": []},
            "pd": {"slow": [], "regular": [], "fast": []},
        },
        "asymmetry": {
            "gt": {"slow": [], "regular": [], "fast": []},
            "pd": {"slow": [], "regular": [], "fast": []},
        },
        "rmserel": {"slow": {}, "regular": {}, "fast": {}},
    },
    "swing": {
        "mean": {
            "gt": {"slow": [], "regular": [], "fast": []},
            "pd": {"slow": [], "regular": [], "fast": []},
        },
        "cv": {
            "gt": {"slow": [], "regular": [], "fast": []},
            "pd": {"slow": [], "regular": [], "fast": []},
        },
        "asymmetry": {
            "gt": {"slow": [], "regular": [], "fast": []},
            "pd": {"slow": [], "regular": [], "fast": []},
        },
        "rmserel": {"slow": {}, "regular": {}, "fast": {}},
    },
    "double": {
        "mean": {
            "gt": {"slow": [], "regular": [], "fast": []},
            "pd": {"slow": [], "regular": [], "fast": []},
        },
        "cv": {
            "gt": {"slow": [], "regular": [], "fast": []},
            "pd": {"slow": [], "regular": [], "fast": []},
        },
        "asymmetry": {
            "gt": {"slow": [], "regular": [], "fast": []},
            "pd": {"slow": [], "regular": [], "fast": []},
        },
        "rmserel": {"slow": {}, "regular": {}, "fast": {}},
    },
}


for parameter in gt_pd_storage:

    for proband in gt_pd_storage[parameter]:

        for sensor in gt_pd_storage[parameter][proband]:

            for speed in ["slow", "regular", "fast"]:

                values = gt_pd_storage[parameter][proband][sensor][speed]

                statistics[parameter]["mean"][sensor][speed].append(np.mean(values))

                statistics[parameter]["cv"][sensor][speed].append(
                    (np.std(values) / np.mean(values)) * 100
                )

                # Calculate RMSE for each parameter on proband level
                if sensor == "gt":

                    for values_gt_walk, values_pd_walk in zip(
                        gt_pd_storage[parameter][proband]["gt"][speed + "_walks"],
                        gt_pd_storage[parameter][proband]["pd"][speed + "_walks"],
                    ):
                        rmse = np.mean(
                            np.sqrt(np.mean((values_gt_walk - values_pd_walk) ** 2))
                        )
                        relative_rmse = (rmse / np.mean(values_gt_walk)) * 100

                        if proband not in statistics[parameter]["rmserel"][speed]:
                            statistics[parameter]["rmserel"][speed][proband] = []

                        statistics[parameter]["rmserel"][speed][proband].append(
                            relative_rmse
                        )

                # Asymmetry
                # Extract Left / Right

                meanValue0 = np.array(values[1::2])

                meanValue1 = np.array(values[0::2])

                min_length = min(len(meanValue0), len(meanValue1))

                meanValue0 = np.nanmean(meanValue0[:min_length])

                meanValue1 = np.nanmean(meanValue1[:min_length])

                if meanValue0 > meanValue1:

                    asymmetrie = 100 * (1 - abs(meanValue1 / meanValue0))
                else:

                    asymmetrie = 100 * (1 - abs(meanValue0 / meanValue1))

                statistics[parameter]["asymmetry"][sensor][speed].append(asymmetrie)

import scipy
import os
import scipy.stats as stats

# Compare the relative RMSE across walking speeds for every gait parameter:
# Shapiro-Wilk normality check per speed, one-way ANOVA across speeds, and a
# bar plot saved to rmse_figures/.

# plt.savefig does not create missing directories.
os.makedirs("rmse_figures", exist_ok=True)

for parameter in statistics:

    # One value per proband and speed: the mean relative RMSE over its walks.
    rmserel_values = {"slow": [], "regular": [], "fast": []}

    for speed, per_proband in statistics[parameter]["rmserel"].items():
        rmserel_values_tmp = [np.mean(walk_errors) for walk_errors in per_proband.values()]
        rmserel_values[speed] = rmserel_values_tmp

        print(
            f"RMSE relative:: Parameter: {parameter}; Speed: {speed}; Probands: {len(rmserel_values_tmp)} ; Mean: {np.mean(rmserel_values_tmp):.4f}; Variance: {np.var(rmserel_values_tmp):.4f}; Std: {np.std(rmserel_values_tmp):.4f}"
        )

    for speed, data in rmserel_values.items():
        # scipy's Shapiro-Wilk test raises ValueError for fewer than 3 samples.
        if len(data) < 3:
            print(
                f"Shapiro-Wilk Test {speed}: skipped (needs at least 3 probands, got {len(data)})"
            )
            continue
        shapiro_test = stats.shapiro(data)
        print(
            f"Shapiro-Wilk Test {speed}: Statistic={shapiro_test.statistic:.4f}, p-value={shapiro_test.pvalue:.4f}"
        )

    # One-way ANOVA: does the mean relative RMSE differ between speeds?
    f_val, p_val = stats.f_oneway(
        rmserel_values["slow"], rmserel_values["regular"], rmserel_values["fast"]
    )

    print(
        f"Parameter: {parameter} ANOVA result: F-value = {f_val:.4f}, P-value = {p_val:.4f}"
    )

    plt.figure()
    plt.bar(
        list(rmserel_values.keys()),
        [np.mean(values) for values in rmserel_values.values()],
        yerr=[np.std(values) for values in rmserel_values.values()],
    )
    plt.xlabel("Speed")
    plt.ylabel("RMSE Relative")
    plt.ylim(0, 60)
    # Include the parameter so the three saved figures are distinguishable.
    plt.title(f"RMSE Relative for Different Speeds ({parameter})")

    plt.savefig(f"rmse_figures/rmse_plot_{parameter}.svg", format="svg")

    plt.show()


def standard_output(metric, parameter):
    """Print ground-truth vs prediction agreement for one statistic.

    Note the argument naming. ``metric`` selects the gait parameter in the
    global ``statistics`` dict ("stride", "swing" or "double"). ``parameter``
    selects the summary measure ("mean", "cv" or "asymmetry").

    Results are printed per speed via
    ``custom_statistics.uniform_statistics``, followed by all speeds pooled.
    A gt-vs-pd scatter plot of the pooled values is also created.
    """
    gt_by_speed = statistics[metric][parameter]["gt"]
    pd_by_speed = statistics[metric][parameter]["pd"]

    # Pooled over all speeds (concatenated in the order fast, slow, regular).
    overall_gt = np.array(gt_by_speed["fast"] + gt_by_speed["slow"] + gt_by_speed["regular"])
    overall_pd = np.array(pd_by_speed["fast"] + pd_by_speed["slow"] + pd_by_speed["regular"])

    plt.figure()
    plt.scatter(overall_gt, overall_pd)
    plt.xlabel("Ground truth")
    plt.ylabel("Prediction")
    plt.title(f"{metric} {parameter}")

    print("============================================")

    print("Error, Correlation & Statistics")

    print("============================================")

    for speed in ("slow", "regular", "fast"):
        print(f"{speed.upper()} ============================================")
        print(
            custom_statistics.uniform_statistics(
                np.array(gt_by_speed[speed]),
                np.array(pd_by_speed[speed]),
            )
        )

    print("OVERALL ============================================")

    print(custom_statistics.uniform_statistics(overall_gt, overall_pd))


# Report every gait parameter: mean, coefficient of variation and asymmetry,
# each compared between ground truth and prediction by standard_output.
_REPORT_BANNER = "================================================================================"

for _gait_parameter in ("stride", "swing", "double"):

    print(_REPORT_BANNER)
    print(_gait_parameter.upper())
    print(_REPORT_BANNER)

    standard_output(_gait_parameter, "mean")
    standard_output(_gait_parameter, "cv")

    print("asymmetrie ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

    standard_output(_gait_parameter, "asymmetry")