In [None]:
import os
import time
import warnings
from tqdm.auto import tqdm

#----------------------------------------------------------#

import numpy as np
import pandas as pd

#----------------------------------------------------------#

import seaborn as sns
import matplotlib.pyplot as plt

#----------------------------------------------------------#

from sklearn.metrics import (confusion_matrix,
                             classification_report,
                             ConfusionMatrixDisplay)

#-----------------------torch imports-----------------------------------#

import torch
import torch.nn as nn
#import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchsummary
#-----------------------spikingjelly imports-----------------------------------#

from spikingjelly.activation_based import neuron, surrogate, functional,layer

# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')
print(f'PyTorch Version: {torch.__version__}')

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Current Device: {device}.')

In [None]:
# Raise the default figure resolution so result graphs rendered in the
# notebook are sharp enough to download and reuse.
# NOTE(review): 'figure.dpi' applies to every figure rendered after this
# cell, so large figures become very large bitmaps; if only saved files need
# the resolution, 'savefig.dpi' may be the better knob - confirm intent.
import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 1000

In [None]:
# Vars

# Dataset root. Override with the BRAINT_ROOT environment variable so the
# notebook runs on other machines; the default keeps the original location.
root = os.environ.get('BRAINT_ROOT', 'C:/Users/sandi/BrainT/')
# Sub-folder names of the two splits; also used as keys by data_transforms().
train, test = 'Training', 'Testing'
target_size = (128, 128)  # (height, width) every image is resized to
batch_size = 16
# NOTE(review): the network's last layer emits 4 outputs while this says 2;
# confirm against the folders ImageFolder actually finds under `root`.
num_classes = 2


learning_rate = 0.001 # Learning rate
mom=0.9 #momentum
epochs = 40

# Per-epoch history, appended to by the training loop and plotted at the end.
train_accuracies = []
train_losses = []

test_accuracies = []
test_losses = []


In [None]:
# functions

def accuracy(y_hat, y_true):
    """Return the fraction of positions where `y_hat` equals `y_true`.

    Args:
        y_hat: tensor of predicted class indices.
        y_true: tensor of reference class indices, compared element-wise.

    Returns:
        Number of matching elements divided by ``len(y_true)``, as a float.
    """
    hits = torch.eq(y_hat, y_true).sum().item()
    total = len(y_true)
    return hits / total

In [None]:
# Image Preprocess
def data_transforms(type = None, target_size = target_size):
    """Build the torchvision preprocessing pipeline for one dataset split.

    Args:
        type: split selector; must equal the module-level `train` or `test`
            value. (The name shadows the builtin but is kept for callers
            that pass it by keyword.)
        target_size: size passed to `transforms.Resize`.

    Returns:
        A `transforms.Compose`. Both splits convert to tensor, resize and
        normalize with the same mean/std; the training split additionally
        applies brightness jitter, a small translation and a rotation.

    Raises:
        ValueError: if `type` is neither `train` nor `test`.
    """
    if type == train:
        type_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(target_size, antialias = True),
            # Augmentation, training split only.
            transforms.ColorJitter(brightness = (0.85, 1.15)),
            transforms.RandomAffine(degrees = 0, translate = (0.002, 0.002)),
            transforms.RandomRotation(degrees = 10),
            transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
        ])
    elif type == test:
        type_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(target_size, antialias = True),
            transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]),
        ])
    else:
        # Previously an unknown split fell through and failed on `return`
        # with an UnboundLocalError; fail early with an actionable message.
        raise ValueError(
            f'Unknown transform type {type!r}; expected {train!r} or {test!r}.'
        )

    return type_transforms


# Training split: augmented pipeline, reshuffled every epoch.
train_data = ImageFolder(os.path.join(root, train), transform = data_transforms(type = train))
train_loader = DataLoader(train_data, batch_size = batch_size, shuffle = True)

# Test split: deterministic pipeline, fixed order.
test_data = ImageFolder(os.path.join(root, test), transform = data_transforms(type = test))
test_loader = DataLoader(test_data, batch_size = batch_size, shuffle = False)

print(f'Batch size: {batch_size}')
# Report the class count each dataset actually discovered rather than the
# hand-maintained `num_classes` constant, which can drift from the data.
print(f'Found {len(train_data)} validated image filenames belonging to {len(train_data.classes)} classes.')
print(f'Found {len(test_data)} validated image filenames belonging to {len(test_data.classes)} classes.')

In [None]:
# Class names of the training set; the plotting cell below indexes this list
# with the integer labels to turn them back into names.
train_data.classes

In [None]:
# Preview one training batch (up to 16 images) in a 4x4 grid with class names.
data_iter = iter(train_loader)
images, labels = next(data_iter)

# Undo the Normalize() step so the pixels are displayable again.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
images = images.numpy().transpose((0, 2, 3, 1))  # NCHW -> NHWC for imshow
images = np.clip(images * std + mean, 0, 1)

num_images = min(len(images), 16)
rows = 4
fig, axes = plt.subplots(rows, 4, figsize = (15, 4 * rows))

for i, ax in enumerate(axes.flat):
    # Cells beyond the batch size stay empty; every cell hides its axes.
    if i < num_images:
        ax.imshow(images[i])
        ax.set_title(f'Label: {train_data.classes[labels[i]]}', fontsize = 15, fontweight = 'bold')
    ax.axis('off')

plt.tight_layout(pad = 1)
plt.savefig('test.jpeg', dpi = 500)
plt.show()

In [None]:
from spikingjelly.activation_based import layer, surrogate, neuron

class Medical_SNN(nn.Module):
    """Convolutional spiking network for image classification.

    Three Conv -> BatchNorm -> IFNode -> MaxPool stages are followed by two
    fully connected layers, each with its own IFNode. The static input image
    is repeated for `T` time steps and the output is averaged over time.

    Args:
        T: number of simulation time steps (must be >= 1).
        use_cupy: if True, switch the network to SpikingJelly's 'cupy' backend.
        num_classes: width of the output layer. Defaults to 4, the value that
            was previously hard-coded, so existing callers and checkpoints
            are unaffected.
    """

    def __init__(self, T:int, use_cupy=False, num_classes:int=4):
        super().__init__()
        if T < 1:
            # T == 0 would produce an empty sequence and a NaN time-average.
            raise ValueError(f'T must be a positive integer, got {T}.')
        self.T = T

        # NOTE: keep the layer order; state_dict keys are positional
        # (conv_fc.0, conv_fc.1, ...) and saved checkpoints depend on them.
        self.conv_fc = nn.Sequential(
            # Each conv (kernel 11, padding 1) shrinks H and W by 8.
            layer.Conv2d(3, 32, kernel_size=11, padding=1, bias=False),
            layer.BatchNorm2d(32),
            neuron.IFNode(surrogate_function=surrogate.Sigmoid(),detach_reset=False),
            layer.MaxPool2d(2, 2),  # 128x128 input: 120 -> 60

            layer.Conv2d(32, 64, kernel_size=11, padding=1, bias=False),
            layer.BatchNorm2d(64),
            neuron.IFNode(surrogate_function=surrogate.Sigmoid(),detach_reset=False),
            layer.MaxPool2d(3, 3),  # 128x128 input: 52 -> 17

            layer.Conv2d(64, 128, kernel_size=11, padding=1, bias=False),
            layer.BatchNorm2d(128),
            neuron.IFNode(surrogate_function=surrogate.Sigmoid(),detach_reset=False),
            layer.MaxPool2d(2, 2),  # 128x128 input: 9 -> 4

            layer.Flatten(),
            # Lazy layer: in_features is inferred on the first forward pass
            # (128 * 4 * 4 = 2048 for a 128x128 input), so any optimizer must
            # be created only after one forward pass has run.
            nn.LazyLinear(out_features=256,bias=False),
            neuron.IFNode(surrogate_function=surrogate.Sigmoid(),detach_reset=False),

            layer.Linear(256, num_classes, bias=False),
            neuron.IFNode(surrogate_function=surrogate.Sigmoid(),detach_reset=False),
        )

        # Multi-step mode: layers consume the whole [T, N, ...] sequence at once.
        functional.set_step_mode(self, step_mode='m')

        if use_cupy:
            functional.set_backend(self, backend='cupy')

    def forward(self, x: torch.Tensor):
        """Run `T` time steps on a static batch and return the time-averaged output.

        Args:
            x: input batch of shape [N, C, H, W].

        Returns:
            Tensor of shape [N, num_classes]: the output-layer activity
            averaged over the time dimension. Neuron state is NOT reset here;
            callers reset the network between batches.
        """
        # x.shape = [N, C, H, W]
        x_seq = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]
        x_seq = self.conv_fc(x_seq)
        fr = x_seq.mean(0)
        return fr
    

In [None]:
# Set the number of simulation time steps (important for SNNs).

# Seed before the model is built: weights are drawn from the global RNG at
# construction time, so seeding only at training time leaves the initial
# weights different on every run.
torch.manual_seed(0)
net=Medical_SNN(T=4)

In [None]:
# Move the network to the device selected at the top of the notebook
# (GPU when available, CPU otherwise). A bare .cuda() call fails on
# machines without CUDA even though the rest of the code honours `device`.
model = net.to(device)

In [None]:
from torchinfo import summary


# torchinfo pushes a dummy batch through the network. That forward pass
# leaves membrane potentials stored in the spiking neurons, so reset the
# network afterwards; otherwise the first training batch starts from the
# dummy batch's state instead of a clean one.
model_stats = summary(net, input_size=(batch_size, 3, *target_size))
functional.reset_net(net)
model_stats  # keep the summary as the cell's displayed value

In [None]:
# Print the module tree of the network.
print(model)

In [None]:
#------------------------- Loss and optimizer ----------
# Cross-entropy on the network output, SGD with momentum, and a cosine
# learning-rate schedule whose period spans the configured number of epochs.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=model.parameters(),
    lr=learning_rate,
    momentum=mom,
)
lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=epochs)

In [None]:
# Training
#
# Each epoch: one pass over train_loader with SGD updates, one pass over
# test_loader without gradients, then logging and checkpointing. Neuron state
# is reset after every batch so batches do not influence each other.
torch.manual_seed(0)
max_test_acc = -1
out_dir = os.path.join('')
for epoch in tqdm(range(epochs)):
    start_time = time.time()

    # Learning rate used for this epoch's updates. Read it before the
    # scheduler advances so the log line reports the rate actually applied.
    lr = optimizer.param_groups[0]['lr']

    # Accumulate counts over samples (not an average of per-batch averages),
    # so a smaller final batch does not get the same weight as a full one.
    train_correct = 0
    train_loss_sum = 0.0
    train_seen = 0

    model.train()

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        y_hat = model(inputs)
        loss = criterion(y_hat, labels)
        loss.backward()
        optimizer.step()

        # argmax is unaffected by softmax, so take it on the raw output.
        predictions = torch.argmax(y_hat, dim=1)
        n_batch = labels.size(0)
        train_correct += torch.eq(predictions, labels).sum().item()
        # criterion returns the batch mean; scale back to a per-sample sum.
        train_loss_sum += loss.item() * n_batch
        train_seen += n_batch
        functional.reset_net(model)

    train_accuracy = train_correct / train_seen
    train_accuracies.append(train_accuracy)

    train_loss = train_loss_sum / train_seen
    train_losses.append(train_loss)

    lr_scheduler.step()

    test_correct = 0
    test_loss_sum = 0.0
    test_seen = 0

    model.eval()

    with torch.inference_mode():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            y_hat = model(inputs)
            loss = criterion(y_hat, labels)

            predictions = torch.argmax(y_hat, dim=1)
            n_batch = labels.size(0)
            test_correct += torch.eq(predictions, labels).sum().item()
            test_loss_sum += loss.item() * n_batch
            test_seen += n_batch
            functional.reset_net(model)

    test_accuracy = test_correct / test_seen
    test_accuracies.append(test_accuracy)

    test_loss = test_loss_sum / test_seen
    test_losses.append(test_loss)

    elapsed_time = time.time() - start_time

    print(f'Epoch {epoch + 1}/{epochs} | Learning Rate: {lr}')
    print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}')
    print(f'Val Loss: {test_loss:.4f}, Val Accuracy: {test_accuracy:.4f}')
    print(f'Elapsed Time: {elapsed_time:.2f} seconds\n')

    # Track the best test accuracy seen so far.
    save_max = False
    if test_accuracy > max_test_acc:
        max_test_acc = test_accuracy
        save_max = True

    checkpoint = {
        'net': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
        'max_test_acc': max_test_acc
    }

    if save_max:
        torch.save(checkpoint, os.path.join(out_dir, 'BrainT_Mymodel2_detach_false_checkpoint_SGD_0p001_m0p9_max.pth'))

    # NOTE(review): this "latest" file name mentions vgg13 and spells "flase",
    # unlike the "max" file above - it looks copied from another experiment.
    # Left unchanged because other code may load it by this exact name.
    torch.save(checkpoint, os.path.join(out_dir, 'BrainT_vgg13_Detach_flase_checkpoint_SGD_0p001_m0p9_latest.pth'))


    # Early stop once the test accuracy reaches the target.
    if test_accuracy >= 0.99:
        print('\nDesired Accuracy Achieved!')
        break

In [None]:
# Final evaluation on the test split: overall accuracy plus a per-class report.
true_labels = []
predicted_labels = []

# To score a saved checkpoint instead of the in-memory weights, load it ONCE
# here (not inside the batch loop). Checkpoints written by the training cell
# are dicts, so the weights live under the 'net' key:
# model.load_state_dict(torch.load('<checkpoint>.pth', map_location=device)['net'])

# Do not rely on the mode/state the previous cell happened to leave behind:
# switch to inference behaviour and start from clean neuron state.
model.eval()
functional.reset_net(model)

with torch.inference_mode():
    for inputs, labels in tqdm(test_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        y_hat = model(inputs)
        # argmax is unaffected by softmax, so take it on the raw output.
        predictions = torch.argmax(y_hat, dim=1)
        true_labels.extend(labels.cpu().tolist())
        predicted_labels.extend(predictions.cpu().tolist())
        functional.reset_net(model)
print(f'Test Accuracy Score: {accuracy(torch.tensor(predicted_labels), torch.tensor(true_labels))*100:.2f} %')

class_labels = list(test_data.classes)

# Pass the label indices explicitly so the rows line up with `class_labels`
# even when some class never occurs in the labels or the predictions;
# without `labels`, a count mismatch makes classification_report fail.
report = classification_report(
    true_labels,
    predicted_labels,
    labels = list(range(len(class_labels))),
    target_names = class_labels,
)
print(f'Classification Report (Test) --> \n\n' + f'{report}')

In [None]:
# Training curves: accuracy on the left panel, loss on the right.
_, ax = plt.subplots(ncols=2, figsize=(15, 6))

# (axis, train series, validation series, title, y-axis label) per panel
panels = (
    (ax[0], train_accuracies, test_accuracies, 'Model Accuracy', 'Accuracy'),
    (ax[1], train_losses, test_losses, 'Model Loss', 'Loss'),
)

for panel, train_curve, val_curve, panel_title, y_label in panels:
    panel.plot(train_curve, marker='o', color='blue', markersize=7)
    panel.plot(val_curve, marker='x', color='red', markersize=7)
    panel.set_title(panel_title)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.legend(['Train', 'Validation'])
    panel.grid(alpha=0.2)

plt.savefig('BrainT_Mymodel2_detach_false_checkpoint_SGD_0p001_m0p9_loss.png', dpi=1000)
plt.show()