# Imports

In [None]:
import torch
from datetime import datetime
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import numpy as np

import sys
sys.path.append('../')
from src.utils.data_preprocessing import load_processed_data
from src.utils.model_training import ModelTrainer
from src.models.vgg16_1D import VGG16_1D
from src.explainers.lrp import LRP_1D

# Variables

In [None]:
# Labels passed to load_processed_data for every split. The class-distribution
# cell pairs count i with CLASS_NAMES[i], so this order is presumably the class
# index / one-hot column order produced by preprocessing — confirm there.
CLASS_NAMES = ['MI', 'NORM', 'OTHER']
# Directory from which the "train", "val" and "test" splits are loaded.
# Plain string: the previous f-string had no placeholders (ruff F541).
PREPROCESSED_DIR = "../data/preprocessed/"

# Load Data

In [None]:
# Load features (X) and labels (y) for each split, in train -> val -> test order.
(X_train, y_train), (X_val, y_val), (X_test, y_test) = (
    load_processed_data(PREPROCESSED_DIR, split_key, class_names=CLASS_NAMES)
    for split_key in ("train", "val", "test")
)

In [None]:
# Print the per-class label distribution of every split.
splits = {'Train': y_train, 'Validation': y_val, 'Test': y_test}

for split_name, y_split in splits.items():
    y_arr = np.asarray(y_split)
    n_samples = len(y_arr)
    if y_arr.ndim == 2 and y_arr.shape[1] == len(CLASS_NAMES):
        # Multi-label / one-hot targets: sum each class column. A sample with
        # several active labels contributes to several classes, so TOTAL below
        # counts label assignments and may exceed the number of samples.
        counts = y_arr.sum(axis=0)
    else:
        # Otherwise the labels are treated as integer class indices.
        counts = np.bincount(y_arr, minlength=len(CLASS_NAMES))
    total = counts.sum()
    # Guard against an empty split, which would otherwise divide by zero (NaN).
    if total > 0:
        perc = counts / total * 100
    else:
        perc = np.zeros(len(counts))
    print(f"{split_name} ({n_samples} samples):")
    for name, c, p in zip(CLASS_NAMES, counts, perc):
        print(f"  {name}: {int(c)} ({p:.2f}%)")
    total_pct = "100.00%" if total > 0 else "0.00%"
    print(f"  TOTAL: {int(total)} ({total_pct})\n")

# Load Model

In [None]:
# Date tag identifying which saved training run to load.
date = "2025_08_29"

# Model weights file and training-history CSV written for that run.
model_path = "../src/models/trained/" f"vgg16_1d_trained_{date}.pth"
history_path = "../data/results/training_history/" f"training_history_{date}.csv"

In [None]:
# Load the trained model and its training history from the paths above.
# The return values are bound to names so later cells (LRP, plotting) can use
# them; previously they were discarded.
# NOTE(review): both methods are called on the ModelTrainer class itself, so
# they are presumably static- or classmethods — confirm in
# src/utils/model_training.py. If load_model instead needs a model or trainer
# instance (e.g. a VGG16_1D), construct that first and adjust this cell.
model = ModelTrainer.load_model(model_path)
history = ModelTrainer.load_training_history(history_path)

# LRP (Layer-wise Relevance Propagation)

# Plot Signals with LRP