In [1]:
import os, sys
module_path = os.path.abspath(os.path.join('..'))
sys.path.append(module_path)

from Weighted_VP_model import *
sys.path.append(os.path.abspath('../Weighted_VP_model'))

from vpnet import *
from vpnet.vp_functions import *
from spike_classification import *

import numpy as np
import matplotlib.pyplot as plt

In [2]:
from data_handling import *

# --- Configuration, data loading and model restoration -----------------------
# NOTE(review): `torch` is not imported explicitly anywhere in the visible cells;
# it is presumably re-exported by one of the star imports above -- confirm.
dtype = torch.float64
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Windowing parameters; they are encoded in the dataset file name below.
window_size_ = 15
overlapping_size_ = 11

dataSet = NeurographyDataset()
path = f'window_{window_size_}_overlap_{overlapping_size_}_corrected.pkl'
# NOTE(review): `full_path` is built but never used in the visible code -- the
# loader below receives the bare file name `path`. Presumably the loader resolves
# the data directory itself; confirm, then either pass `full_path` or drop it.
full_path = os.path.join('../data', path)
dataSet.load_samples_and_labels_from_file(path)

# Dictionary of loaders/arrays; later cells read 'val_loader' and 'val_timestamps' from it.
dataloaders = dataSet.random_split_undersampling()

# Network dimensions are derived from the first sample and the first one-hot label.
n_channels, n_in = dataSet.samples[0].shape
n_out = len(dataSet.binary_labels_onehot[0])
hidden1 = 3
weight_num = 2
# Initial values handed to VPNet as `affin + weight` (two fixed numbers followed by
# `weight_num` random numbers drawn uniformly from [-4, 4)).
# NOTE(review): presumably these are only starting values that the checkpoint
# loaded below overwrites -- confirm against VPNet.
affin = torch.tensor([6 / n_in, -0.3606]).tolist()
#affin = torch.tensor([6 / n_in, -0.3606]).tolist()  #semioptimal
weight = ((torch.rand(weight_num)-0.5)*8).tolist()

model = VPNet(n_in, n_channels, hidden1, VPTypes.FEATURES, affin + weight, WeightedHermiteSystem(n_in, hidden1, weight_num), [hidden1], n_out, device=device, dtype=dtype)
# The checkpoint name hard-codes window 15 / overlap 11 / hidden 3 / nweight 2, so it
# must be kept in sync by hand with the variables above. The spelling "widnow" is part
# of the file name on disk; do not "correct" it without renaming the file.
model.load_state_dict(torch.load('trained_models/widnow_15_overlapping_11_hidden_3_nweight_2_id_1', weights_only=True))


  all_spike['track'] = all_spike['track'].replace(replacement_dict)


val timestamps shape (403573, 15)
val_samples shape torch.Size([403573, 1, 15])
class1 count 25000
class0 count 75000
Balanced shapes:  torch.Size([100000, 1, 15]) torch.Size([100000, 2]) torch.Size([100000, 4])
class1 count 25000
class0 count 75000
Balanced shapes:  torch.Size([100000, 1, 15]) torch.Size([100000, 2]) torch.Size([100000, 4])
Dataloaders are ready


<All keys matched successfully>

In [3]:
# --- Evaluate the restored model on the validation loader --------------------
# Passed to `test`; presumably the probability threshold for predicting class 1 -- confirm.
decision_boundary = 0.5
# Per-class rescaling weights for the cross-entropy loss (class 0, class 1).
# NOTE(review): the tensor is created with torch's default dtype while the model
# was built with float64 -- confirm that the loss handles the mix as intended.
class_weights = torch.tensor([0.003, 0.997]).to(device)
weighted_criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
# VPLoss wraps the weighted criterion; the meaning of 0.1 is defined by VPLoss
# (presumably a penalty/regularisation weight -- confirm).
criterion = VPLoss(weighted_criterion, 0.1)
# `test_labels`, `test_predictions` and `test_probabilities` are reused by the plotting cells below.
val_accuracy, val_loss, test_labels, test_predictions, test_probabilities = test(model, dataloaders['val_loader'], criterion, decision_boundary)
compute_metrics(test_labels, test_predictions)


Label 0:
  True Positives (TP): 0.0
  False Negatives (FN): 0.0
  False Positives (FP): 37789.0
  True Negatives (TN): 365013.0

Label 1:
  True Positives (TP): 252.0
  False Negatives (FN): 11.0
  False Positives (FP): 0.0
  True Negatives (TN): 0.0

Label 2:
  True Positives (TP): 250.0
  False Negatives (FN): 5.0
  False Positives (FP): 0.0
  True Negatives (TN): 0.0

Label 3:
  True Positives (TP): 240.0
  False Negatives (FN): 13.0
  False Positives (FP): 0.0
  True Negatives (TN): 0.0

Shape of all_multiple_labels: (403573,)
Val accuracy: 90.63%, loss: 16.8474
           MODEL METRICS          
Precision : 0.0193
Recall    : 0.9624
F1-Score  : 0.0378
       CONFUSION MATRIX           
              Predicted
          365013    37789
Actual    29    742
ROC-AUC   : 0.9343


In [4]:
# Gather every validation window on the CPU. Each batch must unpack into exactly
# three items; only the first one (the samples) is kept. The singleton dimension 1
# is squeezed away so that each row is one window.
# NOTE(review): this is a separate pass over the loader from the one made by `test`;
# the rows only line up with `test_probabilities` if the loader does not shuffle -- confirm.
all_samples = torch.cat(
    [x.cpu() for x, _labels, _multiple in dataloaders['val_loader']]
).squeeze(1)

In [5]:
# Sanity check: (number of validation windows, samples per window).
all_samples.shape

torch.Size([403573, 15])

In [6]:
# Total number of timestamp entries once all validation windows are flattened.
dataloaders['val_timestamps'].flatten().shape

(6053595,)

In [7]:
# Labels, predictions, probabilities and samples should share the same leading (window) dimension.
print(test_labels.shape, test_predictions.shape, test_probabilities.shape, all_samples.shape)

torch.Size([403573]) torch.Size([403573]) torch.Size([403573, 2]) torch.Size([403573, 15])


In [8]:
import matplotlib.colors as mcolors
from matplotlib.lines import Line2D
# Plot the first validation windows end to end, colouring every window by the model's
# predicted probability for class 1 (10 % wide categories).
# NOTE: consecutive windows overlap, so shared samples are drawn more than once.

# Flatten the samples back into continuous data
flattened_samples = all_samples.flatten()

windows_to_plot = 5                     # windows shown per figure
n_figures = 1                           # number of consecutive figures to draw
window_size = all_samples.shape[1]      # number of points in each window
total_windows = all_samples.shape[0]    # number of available windows

# Probability of the class of interest (second column)
probabilities = test_probabilities[:, 1]

# One colour per probability category. np.digitize puts p == 1.0 one bin past the
# last category, so the index is clipped to keep it inside the 10 legend entries.
n_bins = 10
cmap = plt.get_cmap('tab10')  # 10 distinct colours
color_bins = np.clip(np.digitize(probabilities * 100, bins=np.arange(0, 101, 10)) - 1, 0, n_bins - 1)

# Legend: 0-9%, 10-19%, ..., 80-89%, 90-100% (the last category includes 100 %)
percentage_labels = [f'{i*10}-{(i+1)*10-1}%' for i in range(n_bins - 1)] + ['90-100%']
legend_elements = [Line2D([0], [0], marker='o', color='w', label=label,
                          markerfacecolor=cmap(i), markersize=10) for i, label in enumerate(percentage_labels)]

# One figure per chunk of `windows_to_plot` windows
for i in range(0, min(n_figures * windows_to_plot, total_windows), windows_to_plot):
    last = min(i + windows_to_plot, total_windows)   # guards a short final chunk
    start = i * window_size
    stop = last * window_size

    # Every point inherits the colour of the window it belongs to
    point_colors = cmap(np.repeat(color_bins[i:last], window_size))

    plt.figure(figsize=(10, 4))
    plt.scatter(np.arange(start, stop),
                flattened_samples[start:stop],
                c=point_colors,
                alpha=0.6, s=10)

    # Add legend for probability categories
    plt.legend(handles=legend_elements, title="Probability (%)", bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.title(f"Scatter Plot of Windows {i} to {last - 1}")
    plt.xlabel('Sample index')
    plt.ylabel('Amplitude')
    plt.tight_layout()
    plt.show()


In [9]:
# Plot consecutive validation windows (coloured by predicted probability for class 1)
# and mark every annotated spike that falls inside a window with a dashed vertical line
# coloured by its track.
# NOTE: consecutive windows overlap, so shared samples (and spikes) are drawn more than once.
df = dataSet.all_differentiated_spikes
# Flatten the samples back into continuous data
flattened_samples = all_samples.flatten()

windows_to_plot = 15                    # windows shown per figure
n_figures = 10                          # number of consecutive figures to draw
window_size = all_samples.shape[1]      # number of points in each window
total_windows = all_samples.shape[0]    # number of available windows

# Probability of the class of interest (second column)
probabilities = test_probabilities[:, 1]

# One colour per probability category. np.digitize puts p == 1.0 one bin past the
# last category, so the index is clipped to keep it inside the 10 legend entries.
n_bins = 10
cmap = plt.get_cmap('tab10')  # 10 distinct colours
color_bins = np.clip(np.digitize(probabilities * 100, bins=np.arange(0, 101, 10)) - 1, 0, n_bins - 1)

# Define track colors
track_colors = {0: 'red', 1: 'green', 2: 'blue', 3: 'purple'}

# Legend: 0-9%, 10-19%, ..., 80-89%, 90-100% (the last category includes 100 %)
percentage_labels = [f'{i*10}-{(i+1)*10-1}%' for i in range(n_bins - 1)] + ['90-100%']
legend_elements = [Line2D([0], [0], marker='o', color='w', label=label,
                          markerfacecolor=cmap(i), markersize=10) for i, label in enumerate(percentage_labels)]

# Add track legend elements
track_legend_elements = [Line2D([0], [0], color=track_colors[i], lw=2, label=f'Track {i}') for i in track_colors]

# Pull the spike table into arrays once instead of iterating the DataFrame row by row
# for every window of every figure.
spike_ts = df['ts'].to_numpy()
spike_tracks = df['track'].to_numpy()
val_timestamps = dataloaders['val_timestamps']

# One figure per chunk of `windows_to_plot` windows
for i in range(0, min(n_figures * windows_to_plot, total_windows), windows_to_plot):
    last = min(i + windows_to_plot, total_windows)   # guards a short final chunk
    start = i * window_size
    stop = last * window_size

    plt.figure(figsize=(10, 4))

    # Every point inherits the colour of the window it belongs to
    plt.scatter(np.arange(start, stop),
                flattened_samples[start:stop],
                c=cmap(np.repeat(color_bins[i:last], window_size)),
                alpha=0.6, s=10)

    # Vertical lines for the spikes whose timestamp lies inside a plotted window
    for j, window_ts in enumerate(val_timestamps[i:last]):
        inside = np.flatnonzero((spike_ts >= window_ts.min()) & (spike_ts <= window_ts.max()))
        for k in inside:
            # Position of the sample closest in time to the spike within this window
            offset = np.argmin(np.abs(window_ts - spike_ts[k]))
            global_index = start + j * window_size + offset
            plt.axvline(x=global_index, color=track_colors[int(spike_tracks[k])], linestyle='--', lw=2)

    # Add legends
    plt.legend(handles=legend_elements + track_legend_elements, title="Probability (%) and Track", bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.title(f"Scatter Plot of Windows {i} to {last - 1}")
    plt.xlabel('Sample index')
    plt.ylabel('Amplitude')
    plt.tight_layout()
    plt.show(block=True)

In [10]:
# List the first validation windows that contain at least one annotated spike timestamp
# (exact equality between a window timestamp and a value of df['ts']).
df = dataSet.all_differentiated_spikes
max_matches = 10  # stop reporting after this many matching windows

# One vectorised membership test over all windows instead of one np.isin call per window
window_has_spike = np.isin(np.asarray(dataloaders['val_timestamps']), df['ts'].values).any(axis=1)
matching_window_indices = np.flatnonzero(window_has_spike)[:max_matches]

for i in matching_window_indices:
    print(i)

counter = len(matching_window_indices)
# Index of the FIRST matching window (None when no window matches). Previously this
# variable was overwritten on every match and ended up holding the last reported index.
first_matching_window_index = int(matching_window_indices[0]) if counter else None

5947
5948
5949
5950
5971
5972
5973
5974
15947
15948


In [11]:
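# NOTE: window overlap is not handled here -- consecutive windows are plotted end to end, so samples shared by overlapping windows appear more than once.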
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Plot `windows_to_plot` validation windows centred on a window known to contain a spike.
# Windows are coloured by predicted probability for class 1 (4 categories) and annotated
# spikes are marked with dashed vertical lines coloured by track.
df = dataSet.all_differentiated_spikes
flattened_samples = all_samples.flatten()
print(flattened_samples.shape)

windows_to_plot = 100                   # number of windows in the figure
window_size = all_samples.shape[1]      # number of points in each window
total_windows = all_samples.shape[0]    # number of available windows

# Probability of the class of interest (second column)
probabilities = test_probabilities[:, 1]

# Four colour categories. np.digitize puts p == 1.0 one bin past the last category,
# so the index is clipped to keep it inside the 4 legend entries.
n_bins = 4
cmap = plt.get_cmap('tab10', n_bins)
color_bins = np.clip(np.digitize(probabilities * 100, bins=np.linspace(0, 100, n_bins + 1)) - 1, 0, n_bins - 1)

# Define track colors
track_colors = {0: 'red', 1: 'green', 2: 'blue', 3: 'purple'}  # Ensure to include all track values

# Legend: 0-24%, 25-49%, 50-74%, 75-100% (the last category includes 100 %)
percentage_labels = [f'{i * 25}-{(i + 1) * 25 - 1}%' for i in range(n_bins - 1)] + ['75-100%']
legend_elements = [Line2D([0], [0], marker='o', color='w', label=label,
                          markerfacecolor=cmap(i), markersize=10) for i, label in enumerate(percentage_labels)]

# Add track legend elements
track_legend_elements = [Line2D([0], [0], color=track_colors[i], lw=2, label=f'Track {i}') for i in track_colors]

# Window to centre the plot on (taken from the list of matching windows printed above;
# 5947 is another window from that list).
first_matching_index = 15947

# Range of windows to draw, clamped to the available windows
start_index = max(0, first_matching_index - windows_to_plot // 2)
end_index = min(start_index + windows_to_plot, total_windows)

plt.figure(figsize=(18, 6))

# All windows in one scatter call; every point inherits the colour of its window
start = start_index * window_size
stop = end_index * window_size
plt.scatter(np.arange(start, stop),
            flattened_samples[start:stop],
            c=cmap(np.repeat(color_bins[start_index:end_index], window_size)),
            alpha=0.6, s=10)

# Pull the spike table into arrays once instead of iterating the DataFrame for every window
spike_ts = df['ts'].to_numpy()
spike_tracks = df['track'].to_numpy()
val_timestamps = dataloaders['val_timestamps']

# Vertical lines for the spikes whose timestamp lies inside a plotted window
for i in range(start_index, end_index):
    window_ts = val_timestamps[i]
    inside = np.flatnonzero((spike_ts >= window_ts.min()) & (spike_ts <= window_ts.max()))
    for k in inside:
        # Position of the sample closest in time to the spike within this window
        offset = np.argmin(np.abs(window_ts - spike_ts[k]))
        global_index = i * window_size + offset
        plt.axvline(x=global_index, color=track_colors[int(spike_tracks[k])], linestyle='--', lw=2)

# Add legends
plt.legend(handles=legend_elements + track_legend_elements, title="Probability (%) and Track", bbox_to_anchor=(1.05, 1), loc='upper left')

plt.title(f"Scatter Plot of Windows Around Index {first_matching_index}")
plt.xlabel('Sample index')
plt.ylabel('Amplitude')
plt.tight_layout()
plt.show(block=True)
plt.close()


torch.Size([6053595])


In [12]:
def reconstruct_original_sequence(overlapping_windows, window_size, overlapping):
    """Rebuild a 1-D sequence from equally spaced, overlapping windows.

    Window ``i`` is assumed to start at position ``i * (window_size - overlapping)``
    of the original sequence. Positions covered by several windows receive the
    average of all values written to them.

    Args:
        overlapping_windows: array-like of shape (num_windows, window_size).
        window_size: number of points in each window.
        overlapping: number of points shared by two consecutive windows; must be
            smaller than ``window_size``.

    Returns:
        1-D float ``numpy.ndarray`` of length
        ``(num_windows - 1) * (window_size - overlapping) + window_size``,
        or an empty array when no windows are given.

    Raises:
        ValueError: if ``overlapping >= window_size`` (the windows would not
            advance) or if the windows are not ``window_size`` wide.
    """
    # Ensure the input is a NumPy array
    overlapping_windows = np.asarray(overlapping_windows)

    num_windows = len(overlapping_windows)
    stride = window_size - overlapping
    if stride <= 0:
        raise ValueError(
            f"overlapping ({overlapping}) must be smaller than window_size ({window_size})"
        )

    # Nothing to reconstruct; without this guard the length formula below would
    # yield `overlapping` spurious zeros.
    if num_windows == 0:
        return np.zeros(0)

    if overlapping_windows.ndim >= 2 and overlapping_windows.shape[1] != window_size:
        raise ValueError(
            f"expected windows of width {window_size}, got {overlapping_windows.shape[1]}"
        )

    original_length = (num_windows - 1) * stride + window_size

    # Running sum of the values written to each position, and how many windows wrote there
    reconstructed_sequence = np.zeros(original_length)
    count_array = np.zeros(original_length)

    # Accumulate window by window (in window order, which fixes the summation order)
    for i, window in enumerate(overlapping_windows):
        start_index = i * stride
        end_index = start_index + window_size
        reconstructed_sequence[start_index:end_index] += window
        count_array[start_index:end_index] += 1

    # Turn sums into averages; positions no window covers (possible only when the
    # stride exceeds the window size) keep their zero.
    non_zero_count_mask = count_array > 0
    reconstructed_sequence[non_zero_count_mask] /= count_array[non_zero_count_mask]

    # Replace NaN/inf values that may have come from the input windows
    reconstructed_sequence = np.nan_to_num(reconstructed_sequence)

    return reconstructed_sequence
# Rebuild the continuous validation signal and its timestamps. The windowing
# parameters come from the setup cell (`window_size_`, `overlapping_size_`) instead of
# repeating the literals, so they cannot drift apart from the loaded dataset.
# NOTE(review): this presumes the validation windows are consecutive, equally strided
# cuts of one recording -- confirm against the dataset split.
original_val_samples = reconstruct_original_sequence(all_samples, window_size_, overlapping_size_)
original_val_ts = reconstruct_original_sequence(dataloaders['val_timestamps'], window_size_, overlapping_size_)
original_val_samples.shape, original_val_ts.shape

((1614303,), (1614303,))

In [13]:
# Time span covered by the reconstructed validation signal
start_ts, end_ts = original_val_ts[0], original_val_ts[-1]

# Row positions bounding that span in the spike table: the first row with
# ts >= start_ts and one past the last row with ts <= end_ts.
# (Binary search -- presumes df['ts'] is sorted in ascending order.)
spike_ts_column = df['ts']
start_index = spike_ts_column.searchsorted(start_ts, side='left')
end_index = spike_ts_column.searchsorted(end_ts, side='right')

# Spikes that fall inside the validation time range
df_val_range = df.iloc[start_index:end_index]

In [14]:
# Locate the window timestamp closest to the first annotated spike of the validation range
first_ts_df_val_range = df_val_range['ts'].iloc[0]

# Convert dataloaders['val_timestamps'] to a NumPy array if it isn't one already
val_timestamps_array = np.array(dataloaders['val_timestamps'])

# np.argmin on a multi-dimensional array returns an index into the FLATTENED array
closest_index = np.argmin(np.abs(val_timestamps_array - first_ts_df_val_range))

# Translate the flat index back into (window, position-in-window) before indexing;
# using the flat index directly as a row index would select an unrelated window.
closest_position = np.unravel_index(closest_index, val_timestamps_array.shape)
closest_window_index = closest_position[0]

# Retrieve the closest timestamp value for reference
closest_ts = val_timestamps_array[closest_position]
closest_index


80609

In [37]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import matplotlib.ticker as mticker

# Plot a stretch of the reconstructed validation signal against time. Each sample is
# coloured by the MEAN class-1 probability of all windows that contain its timestamp,
# and annotated spikes are marked with circles above and below the trace.
# Assuming dataloaders, original_val_ts, original_val_samples, df_val_range and
# first_ts_df_val_range are already defined by the cells above.
df = dataSet.all_differentiated_spikes
ts_windows_array = np.array(dataloaders['val_timestamps'])
probabilities = test_probabilities[:, 1]
probabilities_np = probabilities.numpy()

sample_size = 3000  # Number of reconstructed samples to plot per figure
custom_colors = ['#00008B', '#ADD8E6', '#FFFF00', '#FFA500', '#FF0000']  # Adjust colors as needed
missing_color = '#808080'  # samples whose timestamp is found in no window

# Adjust bins for custom ranges
bins = [0, 25, 50, 70, 90, 100]  # Custom bin edges

# Set up labels for the new ranges
percentage_labels = ['0-25%', '25-50%', '50-70%', '70-90%', '90-100%']
legend_elements = [Line2D([0], [0], marker='o', color='w', label=label,
                          markerfacecolor=custom_colors[i], markersize=10) for i, label in enumerate(percentage_labels)]
missing_legend_element = Line2D([0], [0], marker='o', color='w', label='no matching window',
                                markerfacecolor=missing_color, markersize=10)
track_colors = { 1: 'black', 2: 'yellow', 3: 'green'}
track_legend_elements = [Line2D([0], [0], color=track_colors[i], lw=2, label=f'Track {i}') for i in track_colors]

# Sort all window timestamps once so that the windows containing a given timestamp can
# be found by binary search instead of scanning the whole array for every sample.
points_per_window = ts_windows_array.shape[1]
flat_window_ts = ts_windows_array.ravel()
ts_order = np.argsort(flat_window_ts, kind='stable')
sorted_window_ts = flat_window_ts[ts_order]

for start_index in range(21300, 21300+sample_size*1, sample_size):
    end_index = start_index + sample_size
    chunk_ts = original_val_ts[start_index:end_index]
    chunk_samples = original_val_samples[start_index:end_index]

    # Entries of the sorted array equal to each timestamp occupy [first_hit, past_hit)
    first_hit = np.searchsorted(sorted_window_ts, chunk_ts, side='left')
    past_hit = np.searchsorted(sorted_window_ts, chunk_ts, side='right')

    # Mean probability over the windows that contain the timestamp (NaN when none does)
    mean_probabilities = np.full(len(chunk_ts), np.nan)
    for k, (lo, hi) in enumerate(zip(first_hit, past_hit)):
        if hi > lo:
            window_indices = np.unique(ts_order[lo:hi] // points_per_window)
            mean_probabilities[k] = np.mean(probabilities_np[window_indices])

    # Create a DataFrame to store the results
    df_orig = pd.DataFrame({
        'Timestamp': chunk_ts,
        'Samples': chunk_samples,
        'Probability': mean_probabilities
    })

    # Plotting
    plt.figure(figsize=(20, 6))

    # Colour category per sample. np.digitize places both p == 1.0 and NaN one bin past
    # the last category, which used to index past the end of `custom_colors`; the index
    # is clipped and NaN samples get their own colour.
    probability_values = df_orig['Probability'].to_numpy()
    has_missing = bool(np.isnan(probability_values).any())
    color_bins = np.clip(np.digitize(probability_values * 100, bins=bins) - 1, 0, len(custom_colors) - 1)
    point_colors = [missing_color if np.isnan(p) else custom_colors[b]
                    for p, b in zip(probability_values, color_bins)]

    plt.scatter(df_orig['Timestamp'],
                df_orig['Samples'],
                c=point_colors,
                alpha=0.6, s=10)

    # Spikes annotated inside the plotted time range
    window_timestamps = df_orig['Timestamp'].values
    print("Plotting range:", window_timestamps.min(), window_timestamps.max())
    print("First spike timestamp in df_val_range:", first_ts_df_val_range)
    filtered_df_val_range = df_val_range[(df_val_range['ts'] >= window_timestamps.min()) &
                                         (df_val_range['ts'] <= window_timestamps.max())]

    # Marker rows slightly outside the signal range (also correct when the maximum is
    # negative or the minimum is positive)
    sample_max = df_orig['Samples'].max()
    sample_min = df_orig['Samples'].min()
    y_top = sample_max + 0.1 * abs(sample_max)
    y_bottom = sample_min - 0.1 * abs(sample_min)

    if not filtered_df_val_range.empty:
        spike_x = filtered_df_val_range['ts'].to_numpy()
        # Default to black if track color is not found
        spike_colors = [track_colors.get(int(track), 'black') for track in filtered_df_val_range['track']]
        # One circle per spike at the top and at the bottom of the trace
        for y_level in (y_top, y_bottom):
            plt.scatter(spike_x, np.full(len(spike_x), y_level), c=spike_colors,
                        s=100, alpha=0.8, edgecolor='black', linewidth=0.5)

    # Add legends
    legend_handles = legend_elements + ([missing_legend_element] if has_missing else []) + track_legend_elements
    plt.legend(handles=legend_handles, title="Probability (%) and Track", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xlabel('Timestamp')
    plt.ylabel('Amplitude')
    plt.gca().xaxis.set_major_locator(mticker.MaxNLocator(nbins=30))  # Increase the number of x-ticks
    plt.gca().xaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{x:.8f}'))
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show(block=True)
    plt.close()


Plotting range: 2487.817794082538 2488.117694082538
First spike timestamp in df_val_range: 2487.838394082541


In [26]:
import matplotlib.pyplot as plt

# Quick look at the beginning of the reconstructed validation signal.
# Relies on `original_val_samples` and `original_val_ts` from the reconstruction cell.
df = dataSet.all_differentiated_spikes
num_values_to_plot = 100  # how many leading values to show

# Both sequences must be long enough for the requested number of values
if min(len(original_val_samples), len(original_val_ts)) < num_values_to_plot:
    raise ValueError("Not enough data to plot the specified number of values.")

# Leading slice of timestamps and samples
head = slice(None, num_values_to_plot)
timestamps_to_plot = original_val_ts[head]
samples_to_plot = original_val_samples[head]

# Scatter plot on a fresh figure (any current figure is closed first)
plt.close()
fig, ax = plt.subplots(figsize=(16, 6))
ax.scatter(timestamps_to_plot, samples_to_plot, color='blue', alpha=0.6)
ax.set_title('Scatter Plot of Original Value Samples')
ax.set_xlabel('Timestamps')
ax.set_ylabel('Samples')
ax.grid(True)
plt.show()