<a href="https://colab.research.google.com/github/AlonResearch/SNN-for-MI-EEG/blob/main/MI3_SNNforMIeeg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import torch
import torch.utils.data as da
from torch import nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import scipy.io as scio
#!pip install spikingjelly
from spikingjelly.activation_based import ann2snn
#!apt install -y graphviz graphviz-dev
#!pip install nnviz
from nnviz import drawing, inspection


In [2]:

# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")
!nvidia-smi

Using cpu device
/bin/bash: line 1: nvidia-smi: command not found


In [5]:
# @title Data loader and other functions
#Defining functions

def data_loader(data, label, batch=64, shuffle=True, drop=False):
    """
    Preprocess the data to fit model.
    Feed data into data_loader.
    input:
        data (float): 3-D NumPy array or tensor, samples*length*ch or samples*ch*length.
            When axis 1 is not shorter than axis 2 the last two axes are swapped,
            so the result is laid out as samples*ch*length whenever ch < length.
        label (int): samples, ie.: [0, 1, 1, 0, ..., 2] (any shape, it is flattened).
        batch (int): batch size
        shuffle (bool): shuffle data before input into decoder
        drop (bool): drop the last samples if True
    output:
        data loader yielding (inputs, targets) on the global `device`:
        inputs are float32 with shape batch*1*ch*length, targets are int64.
    raises:
        ValueError: if `data` is not 3-dimensional.
    """
    # as_tensor accepts NumPy arrays and tensors alike, so both layouts work for
    # both input types (the no-swap path used to fail for NumPy arrays).
    label = torch.as_tensor(label).flatten().long().to(device)
    data = torch.as_tensor(data)
    if data.dim() != 3:
        raise ValueError(
            "data must be 3-dimensional (samples, ch, length), got shape {}".format(tuple(data.shape))
        )
    if data.shape[1] >= data.shape[2]:
        data = data.swapaxes(1, 2).contiguous()
    # Insert the singleton "image channel" axis expected by the Conv2d front end.
    data = data.unsqueeze(1).float().to(device)
    dataset = da.TensorDataset(data, label)
    return da.DataLoader(dataset=dataset, batch_size=batch, shuffle=shuffle, drop_last=drop)


def val_snn(Dec, test_loader, T=None):
    """
    Measure classification accuracy of a model on a data loader.
    input:
        Dec (nn.Module): model to evaluate; it is switched to eval mode and moved
            to the global `device`.
        test_loader: iterable of (inputs, targets) batches.
        T (int or None): number of simulation steps. With None the model is run
            once per batch. Otherwise the model is run T times per batch, the
            outputs are summed step by step, and accuracy is recorded after
            every step.
    output:
        float accuracy when T is None, else a numpy array of length T whose
        entry t is the accuracy after t+1 accumulated steps.
    """
    Dec.eval().to(device)
    correct = 0
    total = 0
    corrects = None if T is None else np.zeros(T)
    with torch.no_grad():
        for inputs, targets in test_loader:
            # Move both tensors so loaders that yield CPU batches also work on GPU.
            inputs = inputs.to(device)
            targets = targets.to(device)
            if T is None:
                outputs = Dec(inputs)
                correct += (outputs.argmax(dim=1) == targets).float().sum().item()
            else:
                # Clear state of every sub-module that exposes reset() before a new
                # batch (presumably the converted spiking neurons -- confirm).
                for m in Dec.modules():
                    if hasattr(m, 'reset'):
                        m.reset()
                outputs = None
                for t in range(T):
                    step_out = Dec(inputs)
                    # Out-of-place sum: never mutates a tensor the model handed back.
                    outputs = step_out if outputs is None else outputs + step_out
                    corrects[t] += (outputs.argmax(dim=1) == targets).float().sum().item()
            total += targets.shape[0]
    return correct / total if T is None else corrects / total


def anntosnn(cnn_model, train_x, train_y, test_x, test_y, batch=64, T=None):
    """
    Convert a trained ANN to an SNN with spikingjelly's 'max' mode converter and score it.
    input:
        cnn_model: trained ANN handed to the converter.
        train_x, train_y: data wrapped in the loader given to the converter.
        test_x, test_y: data the converted model is evaluated on.
        batch (int): batch size of both loaders.
        T (int or None): forwarded to val_snn.
    output:
        whatever val_snn returns for the converted model.
    """
    calibration_loader = data_loader(train_x, train_y, batch=batch)
    evaluation_loader = data_loader(test_x, test_y, batch=batch)

    print('---------------------------------------------')
    print('Converting using MaxNorm')
    converter = ann2snn.Converter(mode='max', dataloader=calibration_loader)
    return val_snn(converter(cnn_model), evaluation_loader, T=T)


# Enable cuDNN's convolution auto-tuner (module-level side effect; only matters on CUDA).
torch.backends.cudnn.benchmark = True

def initialize_weights(model):
    """
    Re-initialise `model` and all of its sub-modules in place.
    input:
        model (nn.Module): Conv2d weights get Kaiming-normal init (fan_out, relu);
            BatchNorm2d / GroupNorm get weight 1 and bias 0. Other layers are left alone.
    output:
        None
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            # Layers built with affine=False have weight/bias set to None.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


def train_ann(cnn_model, train_x, train_y, test_x, test_y, ep=500, batch=64):
    """
    Train an ANN with Adam and a cosine-annealed learning rate, testing after every epoch.
    input:
        cnn_model (nn.Module): model to train (updated in place).
        train_x, test_x (float): samples*length*ch (samples*ch*length).
        train_y, test_y (int): samples, ie.: [0, 1, 1, 0, ..., 2].
        ep (int): total train and test epoch, must be >= 1
        batch (int): batch size
    output:
        train acc, test acc (0-d numpy arrays holding the last epoch's accuracy,
        rounded to 4 decimals), and the trained cnn_model
    raises:
        ValueError: if ep < 1 (there would be no epoch to report).
    """
    if ep < 1:
        raise ValueError("ep must be >= 1, got {}".format(ep))

    # Define data loader
    train_loader = data_loader(train_x, train_y, batch=batch)
    test_loader = data_loader(test_x, test_y, batch=batch)

    # Warm-up forward pass BEFORE the optimizer is built: models that create
    # layers lazily inside forward() (LENet_FCL builds `fc` that way) would
    # otherwise hand the optimizer an incomplete parameter list and the late
    # layer would never be trained. eval() + no_grad keeps BatchNorm statistics
    # and gradients untouched.
    warmup_batch = next(iter(train_loader), None)
    if warmup_batch is not None:
        cnn_model.eval()
        with torch.no_grad():
            cnn_model(warmup_batch[0])

    # Define training configuration
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cnn_model.parameters(), lr=0.01)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=ep)

    train_acc = []
    test_acc = []
    for epoch in range(ep):
        # Train ANN
        cnn_model.train()
        train_loss = 0
        correct = 0
        total = 0
        print('\n')
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = cnn_model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            # Running averages over the batches seen so far in this epoch.
            print(batch_idx, len(train_loader), 'Epoch: %d | ANN: trainLoss: %.4f | trainAcc: %.4f%% (%d/%d)'
                  % (epoch, train_loss / (batch_idx + 1), 100. * correct / total, correct, total))

        # The scheduler advances once per epoch (T_max=ep above).
        lr_scheduler.step()
        train_acc.append(round(correct / total, 4))

        # Test ANN
        cnn_model.eval()
        val_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(test_loader):
                outputs = cnn_model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                print(batch_idx, len(test_loader), 'Epoch: %d | ANN: testLoss: %.4f | testAcc: %.4f%% (%d/%d)'
                      % (epoch, val_loss / (batch_idx + 1), 100. * correct / total, correct, total))

        test_acc.append(round(correct / total, 4))

    # Only the final epoch's accuracies are reported.
    train_acc = np.asarray(train_acc[-1])
    test_acc = np.asarray(test_acc[-1])
    return train_acc, test_acc, cnn_model


In [4]:
# @title Model definitions
# Model 2a

class LENet(nn.Module):
    """
        LENet Model
    input:
         data shape as: batch_size*1*channel*length (64*1*22*1000) BCI IV-2a
         batch_size: 64
         channel: 22
         length: 1000
    output:
        classes_num scores per sample (batch_size*classes_num)
    """

    def __init__(self, classes_num=3, channel_count=22, drop_out=0.5):
        super().__init__()
        self.drop_out = drop_out

        def temporal_branch(kernel_len):
            # Temporal Convolution block: pad the time axis by kernel_len - 1 in
            # total so the (1, kernel_len) convolution keeps the input length.
            left = kernel_len // 2
            return nn.Sequential(
                nn.ZeroPad2d((left, left - 1, 0, 0)),
                nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(1, kernel_len), bias=False),
                nn.BatchNorm2d(8),
            )

        # Three parallel temporal branches with kernel lengths 64, 32 and 16.
        self.block_TCB_1 = temporal_branch(64)
        self.block_TCB_2 = temporal_branch(32)
        self.block_TCB_3 = temporal_branch(16)

        # Temporal Convolution block fusion: 1x1 mixing of the 3 * 8 = 24 branch maps.
        self.TCB_fusion = nn.Sequential(
            nn.Conv2d(in_channels=24, out_channels=24, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(24),
        )

        # Spatial Convolution block: one grouped kernel spanning all EEG channels.
        self.SCB = nn.Sequential(
            nn.Conv2d(in_channels=24, out_channels=16, kernel_size=(channel_count, 1),
                      groups=8, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(self.drop_out),
        )

        # Feature Fusion Convolution block: depthwise (1, 16) then pointwise (1, 1).
        self.FFCB = nn.Sequential(
            nn.ZeroPad2d((7, 8, 0, 0)),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(1, 16),
                      groups=16, bias=False),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(self.drop_out),
        )

        # Classification Convolution block: 1x1 conv to class maps, then global average.
        self.CCB = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=classes_num, kernel_size=(1, 1), bias=False),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )

    def forward(self, x):
        branch_outputs = [self.block_TCB_1(x), self.block_TCB_2(x), self.block_TCB_3(x)]
        fused = self.TCB_fusion(torch.cat(branch_outputs, dim=1))
        features = self.FFCB(self.SCB(fused))
        return self.CCB(features)

class LENet_FCL(nn.Module):
    """
        LENet variant whose classifier is a fully connected layer.
    input:
         data shape as: batch_size*1*channel*length
         classes_num (int): number of output classes
         channel_count (int): number of EEG channels (height of the spatial kernel)
         drop_out (float): dropout probability used in SCB and FFCB
         data_length (int or None): number of time samples per trial. When given,
             `fc` is built in the constructor, so its parameters are visible to
             optimizers, `state_dict` and `.to()` from the start. When None
             (default, original behaviour) `fc` is created during the first
             forward pass; an optimizer built before that pass will not contain
             its parameters.
    output:
        classes_num scores per sample (batch_size*classes_num)
    raises:
        ValueError: if data_length is given but too short to survive the two
            average-pooling stages (needs data_length >= 32).
    """

    def __init__(self, classes_num=3, channel_count=60, drop_out=0.5, data_length=None):
        super(LENet_FCL, self).__init__()
        self.drop_out = drop_out

        # Keep all the convolutional layers the same as in LENet
        self.block_TCB_1 = nn.Sequential(
            nn.ZeroPad2d((32, 31, 0, 0)),
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(1, 64), bias=False),
            nn.BatchNorm2d(8)
        )
        self.block_TCB_2 = nn.Sequential(
            nn.ZeroPad2d((16, 15, 0, 0)),
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(1, 32), bias=False),
            nn.BatchNorm2d(8)
        )
        self.block_TCB_3 = nn.Sequential(
            nn.ZeroPad2d((8, 7, 0, 0)),
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(1, 16), bias=False),
            nn.BatchNorm2d(8)
        )

        self.TCB_fusion = nn.Sequential(
            nn.Conv2d(in_channels=24, out_channels=24, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(24)
        )

        self.SCB = nn.Sequential(
            nn.Conv2d(in_channels=24, out_channels=16, kernel_size=(channel_count, 1),
                      groups=8, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(self.drop_out)
        )

        self.FFCB = nn.Sequential(
            nn.ZeroPad2d((7, 8, 0, 0)),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(1, 16),
                      groups=16, bias=False),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(self.drop_out)
        )

        self.flatten = nn.Flatten()
        self.classes_num = classes_num
        # Stays None until the input length is known (here or in forward()).
        self.fc = None
        if data_length is not None:
            # The padded convolutions keep the time length and the spatial conv
            # collapses the channel axis to 1, so only AvgPool (1, 4) and (1, 8)
            # shrink the feature map: 16 maps * 1 * floor(floor(L / 4) / 8).
            pooled_length = (data_length // 4) // 8
            if pooled_length < 1:
                raise ValueError(
                    "data_length must be at least 32, got {}".format(data_length)
                )
            self.fc = self._build_fc(16 * pooled_length)

    def _build_fc(self, in_features):
        # Create and initialise the classification layer for a known feature size.
        fc = nn.Linear(in_features, self.classes_num)
        nn.init.kaiming_normal_(fc.weight, mode="fan_out", nonlinearity="relu")
        if fc.bias is not None:
            nn.init.constant_(fc.bias, 0)
        return fc

    def forward(self, x):
        x1 = self.block_TCB_1(x)
        x2 = self.block_TCB_2(x)
        x3 = self.block_TCB_3(x)
        x4 = torch.cat([x1, x2, x3], dim=1)
        x = self.TCB_fusion(x4)
        x = self.SCB(x)
        x = self.FFCB(x)

        # Flatten the output
        x = self.flatten(x)

        # Create the FC layer on first forward pass if it doesn't exist
        if self.fc is None:
            self.fc = self._build_fc(x.shape[1]).to(x.device)

        # Apply the FC layer
        x = self.fc(x)
        return x


In [3]:
# @title Loading the data
# Loading the data

# Getting real samples
# Locally load the dataset
#file = scio.loadmat('Datasets\BCICIV_2a_gdf\Derivatives\A01T.mat')

# Google Colab load the dataset
file = scio.loadmat('/content/sub-011_eeg90hz.mat')

all_data, all_label = file['all_data'], file['all_label']

# Derive the dataset dimensions (assuming data is samples*ch*length)
data_length = all_data.shape[2]
channel_count = all_data.shape[1]
num_classes = np.unique(all_label).size

# Print data information
print(f"Channel count: {channel_count}")
print(f"Data shape: {all_data.shape}")
print(f"Label shape: {all_label.shape}")
rest_count, elbow_count, hand_count = (np.sum(all_label == class_id) for class_id in (0, 1, 2))
print(f"Class distribution: Rest: {rest_count}, Elbow: {elbow_count}, Hand: {hand_count}")

# Tensor copies used by the train/test split in the later cells.
datasetY = torch.tensor(all_label, dtype=torch.int64)
datasetX = torch.tensor(all_data, dtype=torch.float32)

Channel count: 62
Data shape: (965, 62, 360)
Label shape: (965, 1)
Class distribution: Rest: 365, Elbow: 300, Hand: 300


In [8]:
# @title LENet to SNN Conversion Framework execution

# Hyperparameters
EPOCHS = 1
BATCH_SIZE = 64
TIME_STEPS = 100  # T for SNN
TEST_SIZE = 0.2
DROP_OUT = 0.5

# Split the data
print(f"{100 - (TEST_SIZE * 100)}% of the dataset is used for training and {TEST_SIZE * 100}% is used for testing.")
train_data, test_data, train_label, test_label = train_test_split(datasetX, datasetY, test_size=TEST_SIZE, shuffle=True,
                                                                  random_state=0)

# Initialize model
# The class count comes from the loaded labels (num_classes) instead of a literal.
cnn_model = LENet(classes_num=num_classes, channel_count=channel_count, drop_out=DROP_OUT).to(device)
# initialize_weights already walks every sub-module, so one direct call is enough;
# going through .apply() re-initialised each layer once per ancestor module.
initialize_weights(cnn_model)

# Train CNN model
train_acc, test_acc, cnn_model = train_ann(cnn_model, train_data, train_label, test_data, test_label,
                                           ep=EPOCHS, batch=BATCH_SIZE)
# anntosnn only returns accuracies (one per time step), not the converted model ...
max_norm_acc = anntosnn(cnn_model, train_data, train_label, test_data, test_label,
                        batch=BATCH_SIZE, T=TIME_STEPS)
# ... so convert once more to keep an SNN instance for the later cells.
snn_model = ann2snn.Converter(mode='max', dataloader=data_loader(train_data, train_label, batch=BATCH_SIZE))(cnn_model)

print('\n')
print('ANN accuracy: Test: %.4f%%' % (test_acc * 100))
# max_norm_acc[-1] is the accuracy after all TIME_STEPS accumulated steps.
print('SNN accuracy: max_norm: %.4f%%' % (max_norm_acc[-1] * 100))

80.0% of the dataset is used for training and 20.0% is used for testing.


0 13 Epoch: 0 | ANN: trainLoss: 1.3034 | trainAcc: 32.8125% (21/64)
1 13 Epoch: 0 | ANN: trainLoss: 1.2398 | trainAcc: 32.8125% (42/128)
2 13 Epoch: 0 | ANN: trainLoss: 1.2214 | trainAcc: 34.3750% (66/192)
3 13 Epoch: 0 | ANN: trainLoss: 1.1984 | trainAcc: 34.7656% (89/256)
4 13 Epoch: 0 | ANN: trainLoss: 1.1832 | trainAcc: 34.3750% (110/320)
5 13 Epoch: 0 | ANN: trainLoss: 1.1689 | trainAcc: 35.4167% (136/384)
6 13 Epoch: 0 | ANN: trainLoss: 1.1537 | trainAcc: 37.0536% (166/448)
7 13 Epoch: 0 | ANN: trainLoss: 1.1448 | trainAcc: 36.7188% (188/512)
8 13 Epoch: 0 | ANN: trainLoss: 1.1396 | trainAcc: 36.6319% (211/576)
9 13 Epoch: 0 | ANN: trainLoss: 1.1294 | trainAcc: 36.7188% (235/640)
10 13 Epoch: 0 | ANN: trainLoss: 1.1142 | trainAcc: 38.2102% (269/704)
11 13 Epoch: 0 | ANN: trainLoss: 1.1019 | trainAcc: 39.5833% (304/768)
12 13 Epoch: 0 | ANN: trainLoss: 1.0815 | trainAcc: 39.7668% (307/772)
0 4 Epoch: 0 | AN

100%|██████████| 13/13 [00:12<00:00,  1.07it/s]
100%|██████████| 13/13 [00:11<00:00,  1.12it/s]



ANN accuracy: Test: 40.9300%
SNN accuracy: max_norm: 40.9326%





In [13]:
# @title NN Visualization

# This block should be placed after the training and SNN conversion
# of the model you wish to visualize. For example, after the
# "LENet to SNN Conversion Framework execution" cell or
# after the "LENet_FCL to SNN Conversion Framework execution" cell.
# The `cnn_model` and `snn_model` variables from that preceding cell
# will be used for visualization.

# Install nnviz if you haven't already (uncomment the line below if needed)
# !pip install nnviz

# Import necessary nnviz modules
import copy

from nnviz import drawing, inspection
import torch # Ensure torch is imported

print("Starting Neural Network Visualization...")

# --- CNN Model Visualization ---
if 'cnn_model' in locals() and isinstance(cnn_model, torch.nn.Module):
    model_to_visualize_cnn = cnn_model
    cnn_model_name = type(model_to_visualize_cnn).__name__
    print(f"\nVisualizing CNN model: {cnn_model_name}")

    try:
        # Inspect a CPU *copy* (safer for fx tracing). nn.Module.to() moves a
        # module in place, so calling it on cnn_model itself would leave the
        # live model on the CPU while later cells still feed it `device` inputs.
        cnn_model_cpu = copy.deepcopy(model_to_visualize_cnn).to('cpu')

        # Create an inspector
        # TorchFxInspector uses torch.fx to trace the model graph
        cnn_inspector = inspection.TorchFxInspector()

        # Inspect the CNN model
        # For models like LENet_FCL, training (as done in your script) ensures that
        # dynamically created layers (e.g., self.fc) are initialized before inspection.
        print(f"Inspecting {cnn_model_name} on CPU...")
        cnn_graph = cnn_inspector.inspect(cnn_model_cpu)

        # Create a drawer for saving the visualization (here a PNG file)
        # You can change the output format by changing the extension (e.g., .pdf, .svg)
        cnn_viz_filename = f"{cnn_model_name}_cnn_architecture.png"
        cnn_drawer = drawing.GraphvizDrawer(cnn_viz_filename)

        # Draw the graph and save it to the file
        cnn_drawer.draw(cnn_graph)
        print(f"CNN model visualization saved to: {cnn_viz_filename}")
        print(
            f"Note: If the output file is empty or shows an error, ensure Graphviz "
            f"(specifically the 'dot' command) is installed and accessible in your system's PATH."
        )

    except Exception as e:
        # Visualization is best-effort: report the problem and keep the notebook running.
        print(
            f"An error occurred during CNN model ({cnn_model_name}) visualization: {e}"
        )
        print("Troubleshooting tips:")
        print("- Ensure 'nnviz' is installed.")
        print(
            "- Ensure 'graphviz' (dot executable) is installed and in your system's PATH."
        )
        print(
            "- The model structure might contain operations not traceable by torch.fx. "
            "Check nnviz documentation for advanced usage or alternative inspectors if needed."
        )
else:
    print(
        "\nCNN model ('cnn_model') not found in the current scope, or it's not a torch.nn.Module. "
        "Skipping CNN visualization."
    )
print("\nNeural Network Visualization process complete.")


Starting Neural Network Visualization...

Visualizing CNN model: LENet
Inspecting LENet on CPU...
CNN model visualization saved to: LENet_cnn_architecture.png
Note: If the output file is empty or shows an error, ensure Graphviz (specifically the 'dot' command) is installed and accessible in your system's PATH.
LENet(
  (block_TCB_1): Module(
    (0): ZeroPad2d((32, 31, 0, 0))
    (1): Conv2d(1, 8, kernel_size=(1, 64), stride=(1, 1))
  )
  (block_TCB_2): Module(
    (0): ZeroPad2d((16, 15, 0, 0))
    (1): Conv2d(1, 8, kernel_size=(1, 32), stride=(1, 1))
  )
  (block_TCB_3): Module(
    (0): ZeroPad2d((8, 7, 0, 0))
    (1): Conv2d(1, 8, kernel_size=(1, 16), stride=(1, 1))
  )
  (TCB_fusion): Module(
    (0): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))
  )
  (SCB): Module(
    (0): Conv2d(24, 16, kernel_size=(62, 1), stride=(1, 1), groups=8)
    (3): AvgPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0)
    (4): Dropout(p=0.5, inplace=False)
  )
  (FFCB): Module(
    (0): ZeroPad2

In [None]:

# Evaluate models and visualize results
# `plt` is used for the plot titles below but no earlier cell imports it.
import matplotlib.pyplot as plt

# One unshuffled loader shared by both models, so predictions line up with true_labels.
eval_loader = data_loader(test_data, test_label, batch=BATCH_SIZE, shuffle=False, drop=False)

# Get CNN predictions
# Make sure the model sits on the same device as the loader's tensors.
cnn_model.to(device).eval()
cnn_predictions = []
true_labels = []
with torch.no_grad():
    for inputs, targets in eval_loader:
        outputs = cnn_model(inputs)
        _, predicted = outputs.max(1)
        cnn_predictions.extend(predicted.cpu().numpy())
        true_labels.extend(targets.cpu().numpy())

# Get SNN predictions
# Same eval-mode/device setup that val_snn applies when it measures max_norm_acc.
snn_model.to(device).eval()
snn_predictions = []
with torch.no_grad():
    for inputs, targets in eval_loader:
        # Reset every sub-module exposing reset() before each new batch.
        for m in snn_model.modules():
            if hasattr(m, 'reset'):
                m.reset()
        # Sum the outputs of TIME_STEPS forward passes, then take the arg-max.
        for t in range(TIME_STEPS):
            if t == 0:
                outputs = snn_model(inputs)
            else:
                outputs += snn_model(inputs)
        _, predicted = outputs.max(1)
        snn_predictions.extend(predicted.cpu().numpy())

# Confusion Matrix for ANN
cm_cnn = confusion_matrix(true_labels, cnn_predictions)
disp_ann = ConfusionMatrixDisplay(confusion_matrix=cm_cnn)
disp_ann.plot()
plt.title('Confusion Matrix - ANN')
plt.show()

# Confusion Matrix for SNN
cm_snn = confusion_matrix(true_labels, snn_predictions)
disp_snn = ConfusionMatrixDisplay(confusion_matrix=cm_snn)
disp_snn.plot()
plt.title('Confusion Matrix - SNN')
plt.show()

class_names = ['Rest', 'Elbow', 'Hand']
true_array = np.array(true_labels)

# Calculate and print accuracy for each class (ANN)
print('ANN Accuracy: %.2f%%' % (test_acc * 100))  # Using test_acc for overall accuracy, formatted to 2 decimal places
ann_array = np.array(cnn_predictions)
for class_idx in range(len(class_names)):
    class_mask = true_array == class_idx
    # A class absent from the test split is reported as nan instead of dividing by zero.
    class_accuracy = np.mean(ann_array[class_mask] == class_idx) if class_mask.any() else float('nan')
    print(f"        {class_names[class_idx]}: %.2f%%" % (class_accuracy * 100))  # Formatted to 2 decimal places with %

# Calculate and print accuracy for each class (SNN)
print('\nSNN Accuracy: %.2f%%' % (max_norm_acc[-1] * 100))  # Using max_norm_acc[-1] for overall accuracy, formatted to 2 decimal places
snn_array = np.array(snn_predictions)
for class_idx in range(len(class_names)):
    class_mask = true_array == class_idx
    class_accuracy = np.mean(snn_array[class_mask] == class_idx) if class_mask.any() else float('nan')
    print(f"        {class_names[class_idx]}: %.2f%%" % (class_accuracy * 100))  # Formatted to 2 decimal places with %




In [None]:
"""
LENet_FCL to SNN Conversion Framework execution
"""
# Hyperparameters
EPOCHS = 100
BATCH_SIZE = 64
TIME_STEPS = 100  # T for SNN
TEST_SIZE = 0.2
DROP_OUT = 0.2


# Split the data
print(f"{100 - (TEST_SIZE * 100)}% of the dataset is used for training and {TEST_SIZE * 100}% is used for testing.")
train_data, test_data, train_label, test_label = train_test_split(datasetX, datasetY, test_size=TEST_SIZE, shuffle=True,
                                                                  random_state=0)

# Initialize model
# The class count comes from the loaded labels (num_classes) instead of a literal.
cnn_model = LENet_FCL(classes_num=num_classes, channel_count=channel_count, drop_out=DROP_OUT).to(device)
# initialize_weights already walks every sub-module, so one direct call is enough.
initialize_weights(cnn_model)

# LENet_FCL creates its `fc` layer inside forward(). Run one sample through the
# model now so the layer exists before train_ann builds its optimizer; otherwise
# the optimizer never receives fc's parameters and the classifier stays at its
# random initialisation. eval() + no_grad leaves BatchNorm statistics untouched.
cnn_model.eval()
with torch.no_grad():
    probe_inputs, _ = next(iter(data_loader(train_data[:1], train_label[:1], batch=1, shuffle=False)))
    cnn_model(probe_inputs)

# Train CNN model
train_acc, test_acc, cnn_model = train_ann(cnn_model, train_data, train_label, test_data, test_label,
                                           ep=EPOCHS, batch=BATCH_SIZE)
# anntosnn only returns accuracies (one per time step), not the converted model ...
max_norm_acc = anntosnn(cnn_model, train_data, train_label, test_data, test_label,
                        batch=BATCH_SIZE, T=TIME_STEPS)
# ... so convert once more to keep an SNN instance for the later cells.
snn_model = ann2snn.Converter(mode='max', dataloader=data_loader(train_data, train_label, batch=BATCH_SIZE))(cnn_model)

print('\n')
print('ANN accuracy: Test: %.4f%%' % (test_acc * 100))
# max_norm_acc[-1] is the accuracy after all TIME_STEPS accumulated steps.
print('SNN accuracy: max_norm: %.4f%%' % (max_norm_acc[-1] * 100))

In [None]:

# Evaluate models and visualize results
# `plt` is used for the plot titles below but no earlier cell imports it.
import matplotlib.pyplot as plt

# One unshuffled loader shared by both models, so predictions line up with true_labels.
eval_loader = data_loader(test_data, test_label, batch=BATCH_SIZE, shuffle=False, drop=False)

# Get CNN predictions
# Make sure the model sits on the same device as the loader's tensors.
cnn_model.to(device).eval()
CNN = []
true_labels = []
with torch.no_grad():
    for inputs, targets in eval_loader:
        outputs = cnn_model(inputs)
        _, predicted = outputs.max(1)
        CNN.extend(predicted.cpu().numpy())
        true_labels.extend(targets.cpu().numpy())

# Get SNN predictions
# Same eval-mode/device setup that val_snn applies when it measures max_norm_acc.
snn_model.to(device).eval()
snn_predictions = []
with torch.no_grad():
    for inputs, targets in eval_loader:
        # Reset every sub-module exposing reset() before each new batch.
        for m in snn_model.modules():
            if hasattr(m, 'reset'):
                m.reset()
        # Sum the outputs of TIME_STEPS forward passes, then take the arg-max.
        for t in range(TIME_STEPS):
            if t == 0:
                outputs = snn_model(inputs)
            else:
                outputs += snn_model(inputs)
        _, predicted = outputs.max(1)
        snn_predictions.extend(predicted.cpu().numpy())

# Confusion Matrix for ANN
cm_cnn = confusion_matrix(true_labels, CNN)
disp_ann = ConfusionMatrixDisplay(confusion_matrix=cm_cnn)
disp_ann.plot()
plt.title('Confusion Matrix - ANN')
plt.show()

# Confusion Matrix for SNN
cm_snn = confusion_matrix(true_labels, snn_predictions)
disp_snn = ConfusionMatrixDisplay(confusion_matrix=cm_snn)
disp_snn.plot()
plt.title('Confusion Matrix - SNN')
plt.show()

class_names = ['Rest', 'Elbow', 'Hand']
true_array = np.array(true_labels)

# Calculate and print accuracy for each class (CNN)
print('CNN Accuracy: %.2f%%' % (test_acc * 100))  # Using test_acc for overall accuracy, formatted to 2 decimal places
cnn_array = np.array(CNN)
for class_idx in range(len(class_names)):
    class_mask = true_array == class_idx
    # A class absent from the test split is reported as nan instead of dividing by zero.
    class_accuracy = np.mean(cnn_array[class_mask] == class_idx) if class_mask.any() else float('nan')
    print(f"        {class_names[class_idx]}: %.2f%%" % (class_accuracy * 100))  # Formatted to 2 decimal places with %

# Calculate and print accuracy for each class (SNN)
print('\nSNN Accuracy: %.2f%%' % (max_norm_acc[-1] * 100))  # Using max_norm_acc[-1] for overall accuracy, formatted to 2 decimal places
snn_array = np.array(snn_predictions)
for class_idx in range(len(class_names)):
    class_mask = true_array == class_idx
    class_accuracy = np.mean(snn_array[class_mask] == class_idx) if class_mask.any() else float('nan')
    print(f"        {class_names[class_idx]}: %.2f%%" % (class_accuracy * 100))  # Formatted to 2 decimal places with %


