<a href="https://colab.research.google.com/github/AlonResearch/SNN-for-MI-EEG/blob/main/MI3_SNNforMIeeg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [27]:
import numpy as np
import torch
import torch.utils.data as da
from torch import nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import scipy.io as scio
!pip install spikingjelly
from spikingjelly.activation_based import ann2snn
import matplotlib.pyplot as plt
import graphviz




In [28]:

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
!nvidia-smi

Using cpu device
/bin/bash: line 1: nvidia-smi: command not found


In [29]:
# @title Data loader and other functions
#Defining functions

def data_loader(data, label, batch=64, shuffle=True, drop=False):
    """
    Preprocess the data to fit model.
    Feed data into data_loader.
    input:
        data (float): samples*length*ch (samples*ch*length); NumPy array or torch tensor.
            When axis 1 is at least as long as axis 2 the array is taken to be
            samples*length*ch and swapped to samples*ch*length (i.e. assumes length > ch).
        label (int): samples, ie.: [0, 1, 1, 0, ..., 2]. Any shape; it is flattened.
        batch (int): batch size
        shuffle (bool): shuffle data before input into decoder
        drop (bool): drop the last samples if True
    output:
        DataLoader yielding (inputs, targets): inputs float32 of shape batch*1*ch*length,
        targets int64 of shape batch, both already moved to the global `device`.
    """
    # as_tensor accepts NumPy arrays and tensors alike (no copy-construct warning for tensors)
    targets = torch.as_tensor(label).flatten().long().to(device)
    inputs = torch.as_tensor(data)
    if inputs.shape[1] >= inputs.shape[2]:
        inputs = inputs.transpose(1, 2)
    # Add the singleton "image channel" axis expected by the Conv2d-based models
    inputs = inputs.unsqueeze(1).float().contiguous().to(device)
    dataset = da.TensorDataset(inputs, targets)
    return da.DataLoader(dataset=dataset, batch_size=batch, shuffle=shuffle, drop_last=drop)


def val_snn(Dec, test_loader, T=None):
    """
    Measure classification accuracy of `Dec` on `test_loader`.

    input:
        Dec (nn.Module): network to score; switched to eval mode and moved to `device`.
        test_loader: iterable of (inputs, targets) batches.
        T (int | None): None scores a single forward pass per batch (ANN mode).
            Otherwise every module exposing `reset()` is reset at the start of each
            batch, the batch is fed T times, and the outputs are summed over time.
    output:
        float accuracy when T is None; otherwise a length-T array whose entry t is the
        accuracy obtained from the outputs summed over the first t+1 steps.
    """
    Dec.eval().to(device)
    step_hits = None if T is None else np.zeros(T)
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            labels = targets.to(device)
            if step_hits is None:
                hits += (Dec(inputs).argmax(dim=1) == labels).float().sum().item()
            else:
                # Clear stateful (e.g. membrane potential) modules before each new batch
                for module in Dec.modules():
                    if hasattr(module, 'reset'):
                        module.reset()
                for step in range(T):
                    if step:
                        summed += Dec(inputs)
                    else:
                        summed = Dec(inputs)
                    step_hits[step] += (summed.argmax(dim=1) == labels).float().sum().item()
            seen += targets.shape[0]
    if step_hits is None:
        return hits / seen
    return step_hits / seen


def anntosnn(cnn_model, train_x, train_y, test_x, test_y, batch=64, T=None):
    """
    Convert a trained ANN to an SNN with SpikingJelly's max-norm converter and score it.

    The training split is used as the converter's calibration data; the converted
    network is then evaluated on the test split via `val_snn(..., T=T)`.
    output:
        whatever `val_snn` returns (a per-timestep accuracy array when T is given).
    """
    calibration_loader = data_loader(train_x, train_y, batch=batch)
    evaluation_loader = data_loader(test_x, test_y, batch=batch)

    print('---------------------------------------------')
    print('Converting using MaxNorm')
    max_norm_converter = ann2snn.Converter(mode='max', dataloader=calibration_loader)
    converted = max_norm_converter(cnn_model)
    return val_snn(converted, evaluation_loader, T=T)


# Let cuDNN benchmark and pick the fastest convolution algorithms for the (fixed) input
# shapes. Faster on GPU, but may make results non-bit-reproducible across runs.
torch.backends.cudnn.benchmark = True

def initialize_weights(model):
    """
    Re-initialise every Conv2d / BatchNorm2d / GroupNorm inside `model`.

    Conv2d weights get Kaiming-normal init (fan_out, ReLU gain); affine norm layers get
    weight=1, bias=0. Norm layers built with affine=False have no weight/bias and are
    skipped. Other layer types (including nn.Linear) are left untouched.
    Note: `model.apply(initialize_weights)` also works, but since this function already
    walks `model.modules()` it re-initialises sub-trees repeatedly; calling
    `initialize_weights(model)` once is sufficient.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            # affine=False norms have weight/bias set to None
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


def _materialize_lazy_layers(cnn_model, train_loader):
    """Run one eval-mode, no-grad forward pass on a single training sample so that layers
    created lazily inside forward() exist before the optimizer collects parameters()."""
    sample_inputs = train_loader.dataset[0][0].unsqueeze(0)
    was_training = cnn_model.training
    cnn_model.eval()  # eval: BatchNorm running stats are not updated, dropout is off
    with torch.no_grad():
        cnn_model(sample_inputs)
    cnn_model.train(was_training)


def _train_one_epoch(cnn_model, train_loader, criterion, optimizer, epoch):
    """One optimisation pass over train_loader; returns the epoch accuracy rounded to 4 decimals."""
    cnn_model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = cnn_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        print(batch_idx, len(train_loader), 'Epoch: %d | ANN: trainLoss: %.4f | trainAcc: %.4f%% (%d/%d)'
              % (epoch, train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    return round(correct / total, 4)


def _evaluate_one_epoch(cnn_model, test_loader, criterion, epoch):
    """Score the model on test_loader without gradients; returns accuracy rounded to 4 decimals."""
    cnn_model.eval()
    val_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            outputs = cnn_model(inputs)
            val_loss += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            print(batch_idx, len(test_loader), 'Epoch: %d | ANN: testLoss: %.4f | testAcc: %.4f%% (%d/%d)'
                  % (epoch, val_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    return round(correct / total, 4)


def train_ann(cnn_model, train_x, train_y, test_x, test_y, ep=500, batch=64, lr=0.01):
    """
    Train `cnn_model` with Adam + cosine-annealed learning rate and cross-entropy loss,
    evaluating on the test split after every epoch.

    input:
        cnn_model (nn.Module): model to train; trained in place and also returned.
        train_x, test_x (float): samples*length*ch (samples*ch*length).
        train_y, test_y (int): samples, ie.: [0, 1, 1, 0, ..., 2].
        ep (int): total train and test epoch (must be >= 1)
        batch (int): batch size
        lr (float): initial Adam learning rate, annealed by CosineAnnealingLR over `ep` epochs
    output:
        train acc (0-d np.ndarray): last-epoch training accuracy, rounded to 4 decimals
        test acc (0-d np.ndarray): last-epoch test accuracy, rounded to 4 decimals
        cnn_model: the trained model (same object as the argument)
    raises:
        ValueError: if ep < 1 or either split is empty.
    """
    if ep < 1:
        raise ValueError(f"ep must be >= 1, got {ep}")

    # Define data loader
    train_loader = data_loader(train_x, train_y, batch=batch)
    test_loader = data_loader(test_x, test_y, batch=batch)
    if len(train_loader.dataset) == 0 or len(test_loader.dataset) == 0:
        raise ValueError("train and test splits must both contain at least one sample")

    # The optimizer snapshots parameters() at construction: layers a model creates on its
    # first forward (e.g. LENet_FCL.fc) must exist by now or they would never be trained.
    _materialize_lazy_layers(cnn_model, train_loader)

    # Define training configuration
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cnn_model.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=ep)

    last_train_acc = last_test_acc = None
    for epoch in range(ep):
        print('\n')
        last_train_acc = _train_one_epoch(cnn_model, train_loader, criterion, optimizer, epoch)
        lr_scheduler.step()
        last_test_acc = _evaluate_one_epoch(cnn_model, test_loader, criterion, epoch)

    return np.asarray(last_train_acc), np.asarray(last_test_acc), cnn_model


In [30]:
# @title Model definitions
# Model 2a

class LENet(nn.Module):
    """
        LENet: multi-scale temporal convolutions, a grouped spatial convolution,
        a depthwise-separable fusion convolution and a 1x1-conv classification head.
    input:
         data shape as: batch_size*1*channel*length, e.g. 64*1*22*1000 for BCI IV-2a
         (channel is expected to equal `channel_count`, the spatial kernel height)
    output:
        class scores of shape batch_size*classes_num
    """

    @staticmethod
    def _temporal_block(kernel_length):
        """ZeroPad -> (1, k) conv with 8 maps -> BatchNorm; the asymmetric pad keeps the time length unchanged."""
        half = kernel_length // 2
        return nn.Sequential(
            nn.ZeroPad2d((half, half - 1, 0, 0)),
            nn.Conv2d(1, 8, kernel_size=(1, kernel_length), bias=False),
            nn.BatchNorm2d(8),
        )

    def __init__(self, classes_num=3, channel_count=22, drop_out = 0.5):
        super().__init__()
        self.drop_out = drop_out

        # Three parallel temporal branches with kernel lengths 64 / 32 / 16
        self.block_TCB_1 = self._temporal_block(64)
        self.block_TCB_2 = self._temporal_block(32)
        self.block_TCB_3 = self._temporal_block(16)

        # 1x1 conv mixing the 3 x 8 = 24 concatenated temporal maps
        self.TCB_fusion = nn.Sequential(
            nn.Conv2d(24, 24, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(24),
        )

        # Grouped spatial conv spanning all channels, then average-pool time by 4
        self.SCB = nn.Sequential(
            nn.Conv2d(24, 16, kernel_size=(channel_count, 1), groups=8, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(self.drop_out),
        )

        # Depthwise (1, 16) conv + pointwise 1x1 conv, then average-pool time by 8
        self.FFCB = nn.Sequential(
            nn.ZeroPad2d((7, 8, 0, 0)),
            nn.Conv2d(16, 16, kernel_size=(1, 16), groups=16, bias=False),
            nn.Conv2d(16, 16, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(self.drop_out),
        )

        # 1x1 conv to class scores, globally averaged and flattened
        self.CCB = nn.Sequential(
            nn.Conv2d(16, classes_num, kernel_size=(1, 1), bias=False),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )

    def forward(self, x):
        multi_scale = torch.cat(
            [branch(x) for branch in (self.block_TCB_1, self.block_TCB_2, self.block_TCB_3)],
            dim=1,
        )
        out = self.TCB_fusion(multi_scale)
        for stage in (self.SCB, self.FFCB, self.CCB):
            out = stage(out)
        return out

class LENet_FCL(nn.Module):
    """
        LENet variant whose classification head is Flatten + a fully-connected layer
        instead of LENet's 1x1-conv/global-pool head.
    input:
         data shape as: batch_size*1*channel*length (channel expected to equal channel_count)
    args:
        classes_num (int): number of output classes.
        channel_count (int): spatial conv kernel height (number of EEG channels).
        drop_out (float): dropout probability after SCB and FFCB.
        data_length (int | None): time samples per trial. When given, `fc` is built in
            __init__, so its parameters are registered before any optimizer is created
            and a saved state_dict can be loaded immediately. When None (default, the
            original behaviour) `fc` is created on the first forward pass; run one
            forward pass BEFORE passing `parameters()` to an optimizer, otherwise the
            head is missing from the optimizer and is never trained.
    output:
        class scores of shape batch_size*classes_num
    """

    def __init__(self, classes_num=3, channel_count=60, drop_out=0.5, data_length=None):
        super(LENet_FCL, self).__init__()
        self.drop_out = drop_out

        # Keep all the convolutional layers the same as LENet
        self.block_TCB_1 = nn.Sequential(
            nn.ZeroPad2d((32, 31, 0, 0)),
            nn.Conv2d(
                in_channels=1,
                out_channels=8,
                kernel_size=(1, 64),
                bias=False,
            ),
            nn.BatchNorm2d(8)
        )
        self.block_TCB_2 = nn.Sequential(
            nn.ZeroPad2d((16, 15, 0, 0)),
            nn.Conv2d(
                in_channels=1,
                out_channels=8,
                kernel_size=(1, 32),
                bias=False,
            ),
            nn.BatchNorm2d(8)
        )
        self.block_TCB_3 = nn.Sequential(
            nn.ZeroPad2d((8, 7, 0, 0)),
            nn.Conv2d(
                in_channels=1,
                out_channels=8,
                kernel_size=(1, 16),
                bias=False,
            ),
            nn.BatchNorm2d(8)
        )

        self.TCB_fusion = nn.Sequential(
            nn.Conv2d(
                in_channels=24,
                out_channels=24,
                kernel_size=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(24)
        )

        self.SCB = nn.Sequential(
            nn.Conv2d(
                in_channels=24,
                out_channels=16,
                kernel_size=(channel_count, 1),
                groups=8,
                bias=False
            ),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(self.drop_out)
        )

        self.FFCB = nn.Sequential(
            nn.ZeroPad2d((7, 8, 0, 0)),
            nn.Conv2d(
                in_channels=16,
                out_channels=16,
                kernel_size=(1, 16),
                groups=16,
                bias=False
            ),
            nn.Conv2d(
                in_channels=16,
                out_channels=16,
                kernel_size=(1, 1),
                bias=False
            ),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(self.drop_out)
        )

        self.flatten = nn.Flatten()
        self.fc = None
        self.classes_num = classes_num

        if data_length is not None:
            in_features = self._fc_in_features(data_length)
            if in_features <= 0:
                raise ValueError(
                    f"data_length={data_length} is too short; at least 32 time samples are needed")
            self._build_fc(in_features)

    @staticmethod
    def _fc_in_features(data_length):
        """Flattened size after FFCB: 16 maps x 1 row x (length // 4 // 8).
        The TCB/FFCB paddings preserve length; SCB pools by 4 and FFCB by 8.
        Assumes the input height equals channel_count so SCB leaves a single row."""
        return 16 * ((data_length // 4) // 8)

    def _build_fc(self, in_features, device=None):
        """Create the Linear head (optionally on `device`) and give it Kaiming-normal weights and zero bias."""
        fc = nn.Linear(in_features, self.classes_num)
        if device is not None:
            fc = fc.to(device)
        nn.init.kaiming_normal_(fc.weight, mode="fan_out", nonlinearity="relu")
        if fc.bias is not None:
            nn.init.constant_(fc.bias, 0)
        self.fc = fc  # assigning a Module registers it as a submodule

    def forward(self, x):
        x1 = self.block_TCB_1(x)
        x2 = self.block_TCB_2(x)
        x3 = self.block_TCB_3(x)
        x4 = torch.cat([x1, x2, x3], dim=1)
        x = self.TCB_fusion(x4)
        x = self.SCB(x)
        x = self.FFCB(x)

        x = self.flatten(x)

        # Lazy path (data_length not given): build the head from the observed feature size.
        if self.fc is None:
            self._build_fc(x.shape[1], x.device)

        return self.fc(x)


In [39]:
# @title OLD data visualization
def get_layer_str(layer):
    """
    Build the multi-line Graphviz node label for one layer.

    The first line is the layer's class name; following lines list the key
    hyper-parameters for known torch.nn layer types and for SpikingJelly neuron
    nodes. Any other layer is labelled with its class name only.
    """
    def _conv2d(l):
        extra = [
            f"In: {l.in_channels}, Out: {l.out_channels}",
            f"KS: {l.kernel_size}, Stride: {l.stride}",
            f"Pad: {l.padding}, Groups: {l.groups}",
        ]
        if l.bias is None:
            extra.append("Bias: False")
        return extra

    def _linear(l):
        extra = [f"In: {l.in_features}, Out: {l.out_features}"]
        if l.bias is None:
            extra.append("Bias: False")
        return extra

    def _spiking_neuron(l):
        # SpikingJelly neurons (e.g. IFNode, LIFNode): show whichever parameters exist
        extra = []
        if hasattr(l, 'tau'):
            extra.append(f"tau: {l.tau:.2f}")
        if hasattr(l, 'v_threshold'):
            extra.append(f"v_th: {l.v_threshold:.2f}")
        if hasattr(l, 'v_reset'):
            extra.append("v_reset: " + ("None" if l.v_reset is None else f"{l.v_reset:.2f}"))
        if hasattr(l, 'surrogate_function'):
            surrogate = l.surrogate_function
            extra.append(f"Surrogate: {getattr(type(surrogate), '__name__', str(surrogate))}")
        return extra

    def _name_only(l):
        return []

    # Checked in order with isinstance, so the first matching rule wins.
    type_rules = (
        (nn.Conv2d, _conv2d),
        (nn.BatchNorm2d, lambda l: [f"Features: {l.num_features}"]),
        ((nn.ReLU, nn.ReLU6), _name_only),
        (nn.AvgPool2d, lambda l: [f"KS: {l.kernel_size}, Stride: {l.stride}", f"Pad: {l.padding}"]),
        (nn.AdaptiveAvgPool2d, lambda l: [f"Out: {l.output_size}"]),
        (nn.Dropout, lambda l: [f"P: {l.p}"]),
        (nn.ZeroPad2d, lambda l: [f"Pad: {l.padding}"]),
        (nn.Flatten, _name_only),
        (nn.Linear, _linear),
    )
    for layer_types, describe in type_rules:
        if isinstance(layer, layer_types):
            return "\n".join([type(layer).__name__, *describe(layer)])

    is_neuron = 'spikingjelly.activation_based.neuron' in str(type(layer))
    extra_lines = _spiking_neuron(layer) if is_neuron else []
    return "\n".join([type(layer).__name__, *extra_lines])


def _block_layers(block):
    """Ordered child layers of one model block.

    Iterating `children()` matches iterating an nn.Sequential, and also works for a plain
    nn.Module container (which is not iterable), e.g. one rebuilt by a graph-rewriting
    converter.
    """
    return list(block.children())


def _add_layer_cluster(dot, cluster_name, label, id_prefix, layers, entry_node_id):
    """Draw `layers` as a chain of boxes inside a filled Graphviz cluster.

    The first layer is linked from `entry_node_id`, and each following layer from its
    predecessor. Node ids are f"{id_prefix}_layer_{i}". Returns the id of the last node
    drawn, or `entry_node_id` unchanged when `layers` is empty.
    """
    last_node_id = entry_node_id
    with dot.subgraph(name=cluster_name) as c:
        c.attr(label=label, style='filled', color='lightgray', margin='10')
        for layer_idx, layer in enumerate(layers):
            layer_id = f"{id_prefix}_layer_{layer_idx}"
            c.node(layer_id, get_layer_str(layer), shape='box', style='filled', fillcolor='white')
            if layer_idx == 0:
                dot.edge(last_node_id, layer_id)  # edge entering the cluster goes on the parent graph
            else:
                c.edge(last_node_id, layer_id)
            last_node_id = layer_id
    return last_node_id


def _classification_head(model):
    """Return (cluster label, layers) for the model's classification head.

    LENet exposes `CCB`; LENet_FCL exposes `flatten` + `fc` (`fc` stays None until its
    first forward pass, in which case only Flatten is shown). A model with an `FCL`
    block is also accepted.
    """
    if hasattr(model, 'CCB'):
        return "CCB", _block_layers(model.CCB)
    if hasattr(model, 'FCL'):
        return "FCL", _block_layers(model.FCL)
    if hasattr(model, 'fc'):
        head = [m for m in (getattr(model, 'flatten', None), model.fc) if m is not None]
        return "FCL", head
    print("Warning: Model has no CCB, FCL or fc attribute for the classification block.")
    return "ClassificationUnknown", []


def visualize_network_architecture(model, title="Network Architecture", input_shape_str="(B, 1, C, L)"):
    """
    Generate a top-to-bottom Graphviz Digraph of a LENet or LENet_FCL model.

    Works for the ANN and for a converted SNN, as long as the model exposes the
    `block_TCB_1..3`, `TCB_fusion`, `SCB` and `FFCB` blocks plus a classification head
    (`CCB`, or `flatten`/`fc`).
    input:
        model: the network to draw.
        title: graph title shown at the top.
        input_shape_str: text shown in the input node.
    output:
        graphviz.Digraph
    """
    dot = graphviz.Digraph(comment=title, graph_attr={'splines': 'ortho'})
    dot.attr(label=title, fontsize='20', labelloc='t')
    dot.attr(rankdir='TB')  # Top to Bottom graph

    input_node_id = "input"
    dot.node(input_node_id, f"Input\n{input_shape_str}", shape='ellipse', color='lightblue', style='filled')

    # Three parallel temporal branches, each fed directly by the input
    tcb_output_node_ids = [
        _add_layer_cluster(dot, f'cluster_tcb_{i}', block_name, block_name,
                           _block_layers(getattr(model, block_name)), input_node_id)
        for i, block_name in enumerate(['block_TCB_1', 'block_TCB_2', 'block_TCB_3'], start=1)
    ]

    cat_node_id = "concatenate_tcb_outputs"
    dot.node(cat_node_id, "torch.cat\n(dim=1)", shape='invhouse', color='orange', style='filled')
    for tcb_out_node_id in tcb_output_node_ids:
        dot.edge(tcb_out_node_id, cat_node_id)

    # Sequential trunk after the concatenation
    last_node_id = cat_node_id
    for block_name in ('TCB_fusion', 'SCB', 'FFCB'):
        last_node_id = _add_layer_cluster(dot, f'cluster_{block_name.lower()}', block_name, block_name,
                                          _block_layers(getattr(model, block_name)), last_node_id)

    head_name, head_layers = _classification_head(model)
    if head_layers:
        last_node_id = _add_layer_cluster(dot, f'cluster_{head_name.lower()}', head_name, head_name,
                                          head_layers, last_node_id)

    output_node_id = "output"
    dot.node(output_node_id, "Output\n(B, num_classes)", shape='ellipse', color='lightgreen', style='filled')
    dot.edge(last_node_id, output_node_id)

    return dot
# @title Network Architecture Visualizations
# This cell needs `channel_count` / `data_length` (data-loading cell) and
# `cnn_model` / `snn_model` / `TIME_STEPS` (training cell). Run those first.
_required_names = ('channel_count', 'data_length', 'cnn_model', 'snn_model', 'TIME_STEPS')
_missing_names = [name for name in _required_names if name not in globals()]
if _missing_names:
    raise NameError("Run the data-loading and training cells before this one; missing: "
                    + ", ".join(_missing_names))


def save_diagram(graph, filename_base, label):
    """Write `graph` as <base>.svg (vector editing), <base>.gv (DOT source) and <base>.png."""
    graph.render(filename_base, format='svg', cleanup=True)
    print(f"{label} diagram saved as {filename_base}.svg")
    with open(f"{filename_base}.gv", "w", encoding="utf-8") as f:
        f.write(graph.source)
    print(f"{label} diagram DOT source saved as {filename_base}.gv")
    graph.render(filename_base, format='png', cleanup=True)
    print(f"{label} diagram saved as {filename_base}.png")


print("--- Generating Network Architecture Visualizations ---")

input_shape_viz_str = f"(Batch, 1, {channel_count}, {data_length})"
arch_name = type(cnn_model).__name__  # "LENet" or "LENet_FCL"
diagram_specs = (
    ("CNN", cnn_model, f"CNN {arch_name} Architecture (Channels: {channel_count})", 'cnn_lenet_architecture'),
    ("SNN", snn_model, f"SNN {arch_name} Architecture (Converted, T={TIME_STEPS})", 'snn_lenet_architecture'),
)
for label, net, title, filename_base in diagram_specs:
    print(f"\nGenerating {label} architecture diagram...")
    graph = visualize_network_architecture(net, title=title, input_shape_str=input_shape_viz_str)
    display(graph)
    save_diagram(graph, filename_base, label)

print("\nVisualizations generated and saved.")

In [32]:
# @title Loading the data
# Load one subject's epoched EEG trials and labels from a MATLAB .mat file.
# Local alternative used earlier (different dataset, BCI IV-2a):
#   scio.loadmat('Datasets/BCICIV_2a_gdf/Derivatives/A01T.mat')
DATA_PATH = '/content/sub-011_eeg.mat'  # Google Colab upload location

mat_contents = scio.loadmat(DATA_PATH)
all_data = mat_contents['all_data']
all_label = mat_contents['all_label']

# Dataset summary; the indexing assumes all_data is samples*ch*length
channel_count = all_data.shape[1]
data_length = all_data.shape[2]
num_classes = np.unique(all_label.flatten()).size
rest_count, elbow_count, hand_count = (np.sum(all_label == class_id) for class_id in range(3))

print(f"Channel count: {channel_count}")
print(f"Data shape: {all_data.shape}")
print(f"Label shape: {all_label.shape}")
print(f"Class distribution: Rest: {rest_count}, Elbow: {elbow_count}, Hand: {hand_count}")

datasetX = torch.tensor(all_data, dtype=torch.float32)
datasetY = torch.tensor(all_label, dtype=torch.int64)

Channel count: 62
Data shape: (965, 62, 360)
Label shape: (965, 1)
Class distribution: Rest: 365, Elbow: 300, Hand: 300


In [33]:
# @title LENet to SNN Conversion Framework execution

# Hyperparameters
EPOCHS = 100
BATCH_SIZE = 64
TIME_STEPS = 100  # T for SNN
TEST_SIZE = 0.2
DROP_OUT = 0.25
SEED = 0  # seeds the split, weight init, dropout masks and DataLoader shuffling

torch.manual_seed(SEED)
np.random.seed(SEED)

# Split the data
print(f"{100 - (TEST_SIZE * 100)}% of the dataset is used for training and {TEST_SIZE * 100}% is used for testing.")
train_data, test_data, train_label, test_label = train_test_split(datasetX, datasetY, test_size=TEST_SIZE, shuffle=True,
                                                                  random_state=SEED)

# Initialize model (class count comes from the loaded labels instead of being hard-coded)
cnn_model = LENet(classes_num=num_classes, channel_count=channel_count, drop_out=DROP_OUT).to(device)
cnn_model.apply(initialize_weights)

# Train CNN model
train_acc, test_acc, cnn_model = train_ann(cnn_model, train_data, train_label, test_data, test_label,
                                           ep=EPOCHS, batch=BATCH_SIZE)

# Convert ONCE (max-norm calibration on the training split) and score that same SNN,
# so `max_norm_acc` describes the `snn_model` analysed in the following cells.
print('---------------------------------------------')
print('Converting using MaxNorm')
snn_model = ann2snn.Converter(mode='max', dataloader=data_loader(train_data, train_label, batch=BATCH_SIZE))(cnn_model)
max_norm_acc = val_snn(snn_model, data_loader(test_data, test_label, batch=BATCH_SIZE), T=TIME_STEPS)

print('\n')
print('ANN accuracy: Test: %.4f%%' % (test_acc * 100))
print('SNN accuracy: max_norm: %.4f%%' % (max_norm_acc[-1] * 100))

80.0% of the dataset is used for training and 20.0% is used for testing.


0 13 Epoch: 0 | ANN: trainLoss: 1.4497 | trainAcc: 35.9375% (23/64)
1 13 Epoch: 0 | ANN: trainLoss: 1.5929 | trainAcc: 31.2500% (40/128)
2 13 Epoch: 0 | ANN: trainLoss: 1.5350 | trainAcc: 30.2083% (58/192)
3 13 Epoch: 0 | ANN: trainLoss: 1.4437 | trainAcc: 30.4688% (78/256)
4 13 Epoch: 0 | ANN: trainLoss: 1.4172 | trainAcc: 30.3125% (97/320)
5 13 Epoch: 0 | ANN: trainLoss: 1.3775 | trainAcc: 31.2500% (120/384)
6 13 Epoch: 0 | ANN: trainLoss: 1.3429 | trainAcc: 33.0357% (148/448)
7 13 Epoch: 0 | ANN: trainLoss: 1.2965 | trainAcc: 34.9609% (179/512)
8 13 Epoch: 0 | ANN: trainLoss: 1.2731 | trainAcc: 34.5486% (199/576)
9 13 Epoch: 0 | ANN: trainLoss: 1.2462 | trainAcc: 35.6250% (228/640)
10 13 Epoch: 0 | ANN: trainLoss: 1.2322 | trainAcc: 35.7955% (252/704)
11 13 Epoch: 0 | ANN: trainLoss: 1.2202 | trainAcc: 35.8073% (275/768)
12 13 Epoch: 0 | ANN: trainLoss: 1.2046 | trainAcc: 35.8808% (277/772)
0 4 Epoch: 0 | ANN

KeyboardInterrupt: 

In [None]:

# Evaluate models and visualize results
CLASS_NAMES = ['Rest', 'Elbow', 'Hand']
class_ids = list(range(len(CLASS_NAMES)))
eval_loader = data_loader(test_data, test_label, batch=BATCH_SIZE, shuffle=False, drop=False)

# Get CNN predictions
cnn_model.eval()
cnn_predictions = []
true_labels = []
with torch.no_grad():
    for inputs, targets in eval_loader:
        cnn_predictions.extend(cnn_model(inputs).max(1).indices.cpu().numpy())
        true_labels.extend(targets.cpu().numpy())

# Get SNN predictions: reset neuron state per batch, then sum outputs over TIME_STEPS.
# eval() disables dropout, matching what val_snn does before scoring.
snn_model.eval()
snn_predictions = []
with torch.no_grad():
    for inputs, _ in eval_loader:
        for m in snn_model.modules():
            if hasattr(m, 'reset'):
                m.reset()
        outputs = snn_model(inputs)
        for step in range(TIME_STEPS - 1):
            outputs += snn_model(inputs)
        snn_predictions.extend(outputs.max(1).indices.cpu().numpy())

# Fixed `labels` keeps rows/columns aligned with CLASS_NAMES even if a class is absent
cm_cnn = confusion_matrix(true_labels, cnn_predictions, labels=class_ids)
cm_snn = confusion_matrix(true_labels, snn_predictions, labels=class_ids)

for model_name, cm in (('ANN', cm_cnn), ('SNN', cm_snn)):
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=CLASS_NAMES)
    disp.plot()
    disp.ax_.set_title(f'Confusion Matrix - {model_name}')
    plt.show()

# Overall and per-class accuracy, computed from the same predictions as the matrices above
# (per-class accuracy = diagonal / row total; NaN for a class with no test samples).
for report_idx, (model_name, cm) in enumerate((('ANN', cm_cnn), ('SNN', cm_snn))):
    class_totals = cm.sum(axis=1)
    per_class_acc = np.divide(cm.diagonal(), class_totals,
                              out=np.full(len(CLASS_NAMES), np.nan), where=class_totals > 0)
    overall_acc = cm.trace() / cm.sum()
    print(('\n' if report_idx else '') + '%s Accuracy: %.2f%%' % (model_name, overall_acc * 100))
    for class_name, class_accuracy in zip(CLASS_NAMES, per_class_acc):
        print(f"        {class_name}: %.2f%%" % (class_accuracy * 100))




In [None]:
# @title LENet_FCL to SNN Conversion Framework execution

# Hyperparameters
EPOCHS = 100
BATCH_SIZE = 64
TIME_STEPS = 100  # T for SNN
TEST_SIZE = 0.2
DROP_OUT = 0.2
SEED = 0  # seeds the split, weight init, dropout masks and DataLoader shuffling

torch.manual_seed(SEED)
np.random.seed(SEED)

# Split the data
print(f"{100 - (TEST_SIZE * 100)}% of the dataset is used for training and {TEST_SIZE * 100}% is used for testing.")
train_data, test_data, train_label, test_label = train_test_split(datasetX, datasetY, test_size=TEST_SIZE, shuffle=True,
                                                                  random_state=SEED)

# Initialize model (class count comes from the loaded labels instead of being hard-coded)
cnn_model = LENet_FCL(classes_num=num_classes, channel_count=channel_count, drop_out=DROP_OUT).to(device)
cnn_model.apply(initialize_weights)

# LENet_FCL creates its `fc` head lazily on the first forward pass. Run one dry pass now so
# the head's parameters exist before train_ann builds its optimizer; otherwise they would
# be missing from the optimizer and never trained. eval() + no_grad leaves BN stats untouched.
sample_inputs, _ = next(iter(data_loader(train_data[:1], train_label[:1], batch=1, shuffle=False)))
cnn_model.eval()
with torch.no_grad():
    cnn_model(sample_inputs)

# Train CNN model
train_acc, test_acc, cnn_model = train_ann(cnn_model, train_data, train_label, test_data, test_label,
                                           ep=EPOCHS, batch=BATCH_SIZE)

# Convert ONCE (max-norm calibration on the training split) and score that same SNN,
# so `max_norm_acc` describes the `snn_model` analysed in the following cells.
print('---------------------------------------------')
print('Converting using MaxNorm')
snn_model = ann2snn.Converter(mode='max', dataloader=data_loader(train_data, train_label, batch=BATCH_SIZE))(cnn_model)
max_norm_acc = val_snn(snn_model, data_loader(test_data, test_label, batch=BATCH_SIZE), T=TIME_STEPS)

print('\n')
print('ANN accuracy: Test: %.4f%%' % (test_acc * 100))
print('SNN accuracy: max_norm: %.4f%%' % (max_norm_acc[-1] * 100))

In [None]:

# Evaluate models and visualize results
CLASS_NAMES = ['Rest', 'Elbow', 'Hand']
class_ids = list(range(len(CLASS_NAMES)))
eval_loader = data_loader(test_data, test_label, batch=BATCH_SIZE, shuffle=False, drop=False)

# Get CNN predictions
cnn_model.eval()
cnn_predictions = []
true_labels = []
with torch.no_grad():
    for inputs, targets in eval_loader:
        cnn_predictions.extend(cnn_model(inputs).max(1).indices.cpu().numpy())
        true_labels.extend(targets.cpu().numpy())

# Get SNN predictions: reset neuron state per batch, then sum outputs over TIME_STEPS.
# eval() disables dropout, matching what val_snn does before scoring.
snn_model.eval()
snn_predictions = []
with torch.no_grad():
    for inputs, _ in eval_loader:
        for m in snn_model.modules():
            if hasattr(m, 'reset'):
                m.reset()
        outputs = snn_model(inputs)
        for step in range(TIME_STEPS - 1):
            outputs += snn_model(inputs)
        snn_predictions.extend(outputs.max(1).indices.cpu().numpy())

# Fixed `labels` keeps rows/columns aligned with CLASS_NAMES even if a class is absent
cm_cnn = confusion_matrix(true_labels, cnn_predictions, labels=class_ids)
cm_snn = confusion_matrix(true_labels, snn_predictions, labels=class_ids)

for model_name, cm in (('ANN', cm_cnn), ('SNN', cm_snn)):
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=CLASS_NAMES)
    disp.plot()
    disp.ax_.set_title(f'Confusion Matrix - {model_name}')
    plt.show()

# Overall and per-class accuracy, computed from the same predictions as the matrices above
# (per-class accuracy = diagonal / row total; NaN for a class with no test samples).
for report_idx, (model_name, cm) in enumerate((('ANN', cm_cnn), ('SNN', cm_snn))):
    class_totals = cm.sum(axis=1)
    per_class_acc = np.divide(cm.diagonal(), class_totals,
                              out=np.full(len(CLASS_NAMES), np.nan), where=class_totals > 0)
    overall_acc = cm.trace() / cm.sum()
    print(('\n' if report_idx else '') + '%s Accuracy: %.2f%%' % (model_name, overall_acc * 100))
    for class_name, class_accuracy in zip(CLASS_NAMES, per_class_acc):
        print(f"        {class_name}: %.2f%%" % (class_accuracy * 100))


