In [None]:
import numpy as np
import torch
import random
from src.dataset.eeg_dataset import EEGDataset
from pathlib import Path
import pandas as pd
import os
from src.utils import Utils
from src.analysis.interpretability import compute_gradcam_1d

In [None]:
# Sampling rate of the EEG recordings (Hz); all window/step lengths are multiples of it.
sampling_rate = 128

# Candidate window/step lengths in samples: 1 s, 2 s, 5 s, 10 s, 30 s and 60 s.
window_durations_s = [1, 2, 5, 10, 30, 60]
window_sizes = [seconds * sampling_rate for seconds in window_durations_s]
step_sizes = [seconds * sampling_rate for seconds in window_durations_s]

window_size = window_sizes[3]  # 10 s windows
step_size = step_sizes[3]      # 10 s step
preprocessing = True
feature_selection = False
depth_of_anesthesia = True
# Half the window length; passed to Utils as `for_majority`
# (presumably the label-majority threshold per window — confirm in Utils).
for_majority = window_size // 2
strategy = "rawEEG"
training_ids = [1, 2, 3, 4, 6, 7, 12]
random_seed = 42

# `random_seed` was previously only handed to Utils; seed every RNG used in this
# notebook as well so that re-running from a fresh kernel is reproducible.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

base_path = Path.cwd()
data_path = base_path / "data"

In [None]:
# Each split lives in its own HDF5 file named after the windowing configuration;
# all of them store the windows under the same key.
loaded_splits = {}
for split_name in ("training", "test", "validation"):
    split_file = data_path / f"{split_name}_data_{strategy}_{window_size}_{step_size}.h5"
    loaded_splits[split_name] = pd.read_hdf(split_file, key='eeg_window')

training_data = loaded_splits["training"]
test_data = loaded_splits["test"]
validation_data = loaded_splits["validation"]

In [None]:
# Directory holding trained models and the validation/test result tables.
ml_models_dir = base_path / "doA_classification" / "ml_models"

utils = Utils(
    for_majority=for_majority,
    window_size=window_size,
    step_size=step_size,
    random_seed=random_seed,
    preprocessing=preprocessing,
    sampling_rate=sampling_rate,
    results_validation_csv_path=ml_models_dir / "validation_results_df.csv",
    results_test_csv_path=ml_models_dir / "test_results_df.csv",
    model_dir=ml_models_dir,
)

# Depth-of-anesthesia targets; like "sleep", they are labels and must not be used as features.
doa_targets = ["cr", "sspl", "burst_suppression"]
exclude_columns = ["Start", "End", "sleep"]
if depth_of_anesthesia:
    exclude_columns += doa_targets
    labels_to_process = ["sleep", *doa_targets]
    labels = ["sleep", *doa_targets]
else:
    labels_to_process = ["sleep"]
    labels = ["sleep"]

# Feature columns: every column that is neither a label nor one of the excluded non-feature columns.
features = training_data.drop(columns=exclude_columns, errors="ignore").columns

In [None]:
# Preprocess the data once per target label and keep every split/loader in a dict
# keyed by label.
preprocessed_data_dict = {}

# Stacking the per-window EEG arrays does not depend on the label, so do it once
# instead of once per loop iteration.
X_train_stacked = np.vstack(training_data['eeg_window'].values)
X_val_stacked = np.vstack(validation_data['eeg_window'].values)
X_test_stacked = np.vstack(test_data['eeg_window'].values)

# A hard-coded "mps" fails on machines without Apple's MPS backend; fall back to CPU.
preprocess_device = "mps" if torch.backends.mps.is_available() else "cpu"

for label in labels_to_process:
    print(f"Processing {label}...")
    # Copies keep each label's inputs independent (as the per-iteration vstack
    # did before) in case preprocess_data modifies its arrays in place.
    (
        X,
        y,
        X_val,
        y_val,
        X_test,
        y_test,
        train_loader_nn,
        val_loader_nn,
        test_loader_nn,
        input_size,
    ) = utils.preprocess_data(
        X=X_train_stacked.copy(),
        y=training_data[label],
        X_val=X_val_stacked.copy(),
        y_val=validation_data[label],
        X_test=X_test_stacked.copy(),
        y_test=test_data[label],
        batch_size=16,
        device=preprocess_device,
        strategy=strategy,
        classification_type=label,
        scaling_nn=False,
        imbalanced=True,
    )

    preprocessed_data_dict[label] = {
        "X": X,
        "y": y,
        "X_val": X_val,
        "y_val": y_val,
        "X_test": X_test,
        "y_test": y_test,
        "train_loader_nn": train_loader_nn,
        "val_loader_nn": val_loader_nn,
        "test_loader_nn": test_loader_nn,
        "input_size": input_size,
    }

print("Processing completed for all labels.")

In [None]:
from src.models.UCR import UCRResNet

task = "burst_suppression"
model_name = "UCRResNet"
# Range of test-window indices to scan and the score above which a window is reported.
# NOTE(review): 0.9 assumes the single model output is a probability (e.g. a final
# sigmoid inside UCRResNet) — confirm.
scan_start, scan_stop = 780, 859
score_threshold = 0.9

model = UCRResNet(
    input_shape=1,
    n_feature_maps=64,
    nb_classes=1
)
model_filename = (
    base_path / "doA_classification" / "ml_models" /
    f"rawEEG_{task}_ws{window_size}_ss{step_size}_"
    f"majority{for_majority}_type{model_name}_"
    f"preproc{preprocessing}_randomseed{random_seed}_all"
)

# Load the checkpoint once (it used to be reloaded on every iteration) and map it
# straight to the CPU, where inference runs; mapping to "mps" only works on Apple machines.
model.load_state_dict(torch.load(f"{model_filename}.pt", map_location="cpu"))
model.eval()
model.to("cpu")

# NOTE(review): windows are taken from the "sleep" entry; the preprocessing cell passes
# the same raw test windows for every label — confirm preprocess_data keeps X_test
# label-independent.
scan_windows = preprocessed_data_dict["sleep"]["X_test"]

high_score_idx = []
with torch.no_grad():  # inference only: no autograd graph needed
    for idx in range(scan_start, scan_stop):
        visu_x = scan_windows[idx]
        # Add a batch axis, plus a channel axis when the window is 1-D.
        input_tensor = torch.Tensor(visu_x).unsqueeze(0)
        if input_tensor.dim() == 2:
            input_tensor = input_tensor.unsqueeze(1)
        output = model(input_tensor)
        if output.item() > score_threshold:
            print(idx)
            high_score_idx.append(idx)
    

In [None]:
# Presumably found with 10 s windows (window_sizes[3]).
# Earlier candidate windows showing burst suppression: 227, 229, 232, 234.
idx = 827

# Print this test window's label for each task, in the order sleep, sspl, cr, burst_suppression.
for label_name in ("sleep", "sspl", "cr", "burst_suppression"):
    print(preprocessed_data_dict[label_name]["y_test"][idx])

visu_x = preprocessed_data_dict["sleep"]["X_test"][idx]

In [None]:
# Sanity check of the selected window's shape before the STFT / Grad-CAM cells
# (TODO confirm it is 1-D, i.e. (window_size,)).
print(np.shape(visu_x))

In [None]:
import matplotlib.pyplot as plt

# STFT parameters (in samples) and the frequency band to display (in Hz).
stft_ws = 64
stft_hop = 16
min_freq = 0.5
max_freq = 47

window = torch.hann_window(stft_ws)
# torch.as_tensor converts the window without the copy-construct warning that
# torch.tensor(torch.tensor(...)) raises; dtype is preserved as before.
stft_result = torch.stft(
    torch.as_tensor(visu_x),
    n_fft=stft_ws,
    hop_length=stft_hop,
    win_length=stft_ws,
    return_complex=True,
    window=window,
)
print(stft_result.shape)

# Take the log of the *magnitude* to get a dB spectrogram. log10 of the complex
# values (as before) yields dB in the real part and a scaled phase in the imaginary part.
magnitude_db = 20 * torch.log10(stft_result.abs() + 1e-6)
phase = torch.angle(stft_result)

# Frequency (Hz) of each STFT row, derived from the configured sampling rate.
freqs = np.fft.rfftfreq(stft_ws, d=1 / sampling_rate)
print(freqs)
print(freqs.shape)

# Band selection via a mask: same rows as the argmax slice for ascending freqs,
# but it does not collapse to an empty slice when max_freq >= Nyquist.
band = (freqs >= min_freq) & (freqs <= max_freq)
if not band.any():
    raise ValueError(f"No STFT bins between {min_freq} Hz and {max_freq} Hz")
band_freqs = freqs[band]
band_rows = torch.as_tensor(np.flatnonzero(band))
# Frequency is the second-to-last STFT axis (rows for a 1-D input).
cropped_mag_db = magnitude_db.index_select(-2, band_rows)
cropped_phase = phase.index_select(-2, band_rows)
# Stacked like view_as_real: [..., 0] = magnitude (dB), [..., 1] = phase (rad).
cropped_stft = torch.stack((cropped_mag_db, cropped_phase), dim=-1)
print(cropped_stft.shape)

# y axis in Hz with low frequencies at the bottom; x axis in STFT frames.
extent = (0, cropped_mag_db.shape[-1], float(band_freqs[0]), float(band_freqs[-1]))

fig, (ax_mag, ax_phase) = plt.subplots(2, 1, sharex=True)

im_mag = ax_mag.imshow(cropped_mag_db, cmap='viridis', aspect='auto', origin='lower', extent=extent)
ax_mag.set_title('STFT magnitude (dB)')
ax_mag.set_ylabel('Frequency (Hz)')
fig.colorbar(im_mag, ax=ax_mag, label='dB')

# Phase is cyclic, so use a cyclic colormap.
im_phase = ax_phase.imshow(cropped_phase, cmap='twilight', aspect='auto', origin='lower', extent=extent)
ax_phase.set_title('STFT phase (rad)')
ax_phase.set_xlabel('STFT frame')
ax_phase.set_ylabel('Frequency (Hz)')
fig.colorbar(im_phase, ax=ax_phase, label='rad')

plt.tight_layout()
plt.show()


In [None]:
# The pytorch_grad_cam names are not used in this cell (Grad-CAM comes from
# compute_gradcam_1d); the imports are kept in case later cells rely on them.
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, BinaryClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import matplotlib.pyplot as plt
import torch
import numpy as np
from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize
from src.models.UCR import UCRResNet


def _cam_per_segment(cam, n_samples):
    """Return one CAM value per line segment of an ``n_samples``-long signal.

    ``cam`` is flattened and, when its length differs from ``n_samples``,
    linearly resampled onto the sample grid. Segment ``i`` joins samples ``i``
    and ``i + 1`` and gets the mean of their CAM values, so the result has
    exactly ``n_samples - 1`` entries — one per segment drawn by LineCollection.
    """
    cam = np.asarray(cam, dtype=float).reshape(-1)
    if cam.size != n_samples:
        cam = np.interp(
            np.linspace(0.0, 1.0, n_samples),
            np.linspace(0.0, 1.0, cam.size),
            cam,
        )
    return 0.5 * (cam[:-1] + cam[1:])


task = "burst_suppression"
model_name = "UCRResNet"
model = UCRResNet(
    input_shape=1,
    n_feature_maps=64,
    nb_classes=1
)
model_filename = (
    base_path / "doA_classification" / "ml_models" /
    f"rawEEG_{task}_ws{window_size}_ss{step_size}_"
    f"majority{for_majority}_type{model_name}_"
    f"preproc{preprocessing}_randomseed{random_seed}_all"
)

# Grad-CAM runs on the CPU, so map the checkpoint there directly; mapping to
# "mps" first only works on machines with Apple's MPS backend.
model.load_state_dict(torch.load(f"{model_filename}.pt", map_location="cpu"))
model.eval()
model.to("cpu")

# Add a batch axis, plus a channel axis when the window is 1-D
# (a 1-D visu_x becomes shape (1, 1, len(visu_x))).
input_tensor = torch.Tensor(visu_x).unsqueeze(0)
if input_tensor.dim() == 2:
    input_tensor = input_tensor.unsqueeze(1)
print(input_tensor.shape)

# Plain forward pass for reporting. The previous debug output.backward() accumulated
# gradients into the model parameters and printed the whole input-gradient tensor.
with torch.no_grad():
    output = model(input_tensor)
print("Model output:", output)

target_layer = model.conv9  # or any deeper layer
input_tensor.requires_grad_(True)
grayscale_cam = compute_gradcam_1d(model, input_tensor, target_layer, target_class=0)
print(grayscale_cam)

# Accept either a NumPy array or a (possibly grad-tracking) tensor from the helper.
if isinstance(grayscale_cam, torch.Tensor):
    grayscale_cam = grayscale_cam.detach().cpu().numpy()
grayscale_cam_1d = np.asarray(grayscale_cam, dtype=float).flatten()

signal = np.asarray(visu_x, dtype=float).reshape(-1)
# LineCollection draws len(signal) - 1 segments, so it needs exactly that many colour values.
segment_cam = _cam_per_segment(grayscale_cam_1d, signal.size)

fig, ax = plt.subplots(figsize=(10, 5))

# One segment between each pair of consecutive samples so each can be coloured individually.
points = np.column_stack([np.arange(signal.size), signal]).reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)

lc = LineCollection(
    segments,
    cmap='viridis',
    norm=Normalize(vmin=grayscale_cam_1d.min(), vmax=grayscale_cam_1d.max()),
)
lc.set_array(segment_cam)
lc.set_linewidth(2)
line = ax.add_collection(lc)

ax.set_xlim(0, signal.size)
ax.set_ylim(signal.min(), signal.max())
ax.set_title('EEG with GradCAM Overlay')
ax.set_xlabel('Time')
ax.set_ylabel('Amplitude')

cbar = fig.colorbar(line, ax=ax, label='Intensity')

plt.tight_layout()
plt.show()
