In [4]:
# test the model
import torch
import torch.nn as nn
import torch.optim as optim
import snntorch as snn
# Load the trained model from the checkpoint and evaluate it on the test set.
# Requires the cell that defines DSRN (and the `dataset` / `test` objects,
# defined elsewhere) to have been run first.
imsize = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DSRN indexes imsize[0] for its size arithmetic, so hand it an (H, W) pair.
model = DSRN(nb_blocks=64, in_ch=1, out_ch=8, imsize=(imsize, imsize), device=device, thresh=4).to(device)
# Loss and optimizer (the optimizer is not used for evaluation; it is kept so
# other cells that may reference it still find it)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.05, betas=(0.9, 0.999))
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load('model.pth', map_location=device))

test_losses = []
pred_batches = []
label_batches = []
model.eval()
with torch.no_grad():
    # assumes `test` yields (raw_inputs, label_tensor) batches and
    # dataset.str2image returns a (B, H, W) tensor - TODO confirm
    for x, y in test:
        x = dataset.str2image(x, imsize, imsize)
        x = x.unsqueeze(1).to(device)
        y = y.to(device)
        # Sum the model output over the time axis -> per-class spike counts.
        spike_counts = model(x).sum(1)
        _, predicted = torch.max(spike_counts, 1)
        pred_batches.append(predicted)
        label_batches.append(y)
        batch_loss = criterion(spike_counts, y)
        test_losses.append(batch_loss.item())

# Concatenate once at the end; growing a tensor with torch.cat inside the loop
# is quadratic (and starting from a float tensor silently promoted the labels).
pred_results = torch.cat(pred_batches) if pred_batches else torch.empty(0, dtype=torch.long, device=device)
true_labels = torch.cat(label_batches) if label_batches else torch.empty(0, dtype=torch.long, device=device)



TypeError: DSRN.__init__() missing 1 required positional argument: 'thresh'

In [None]:
# confusion matrix
from sklearn.metrics import confusion_matrix, accuracy_score
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Convert results to integer NumPy arrays under new names, so the tensors
# produced by the evaluation cell stay untouched and this cell is re-runnable.
y_true = true_labels.cpu().numpy().astype(int)
y_pred = pred_results.cpu().numpy().astype(int)

# confusion_matrix sizes its axes by the union of true AND predicted classes.
# Pass that same union as `labels` and as tick labels so rows/columns and
# ticks always line up, even if a class is only ever predicted.
classes = np.union1d(y_true, y_pred)
cm = confusion_matrix(y_true, y_pred, labels=classes)

fig, ax = plt.subplots()
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes, ax=ax)
ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')
plt.show()
print('Accuracy: ', accuracy_score(y_true, y_pred)*100, '%')

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import snntorch as snn
import math

#convert the residual shrinkage block to pytorch
class SoftThresholdLayer(nn.Module):
    """Element-wise soft thresholding (shrinkage): sign(x) * max(|x| - t, 0).

    Entries whose magnitude is at most ``t`` become zero; all other entries
    are moved towards zero by ``t``.

    Args:
        threshold_init: initial value of the learnable ``threshold``
            parameter. It is converted to float, so integer values such as
            ``1`` are accepted (``nn.Parameter`` needs a floating-point tensor).
    """
    def __init__(self, threshold_init=1.0):
        super(SoftThresholdLayer, self).__init__()
        # Registered as a parameter so existing state_dicts keep loading.
        self.threshold = nn.Parameter(torch.tensor(float(threshold_init)))

    def forward(self, x, thres=None):
        """Apply soft thresholding to ``x``.

        Args:
            x: input tensor.
            thres: threshold (tensor or scalar) broadcastable against ``x``.
                When omitted, the layer's own ``threshold`` parameter is used.

        Returns:
            sign(x) * relu(|x| - thres), broadcast over ``x`` and ``thres``.
        """
        if thres is None:
            thres = self.threshold
        return torch.sign(x) * torch.relu(torch.abs(x) - thres)

class ResnetBlock(nn.Module):
    """One spiking conv stage: Conv2d -> BatchNorm2d -> snn.Leaky.

    Each call advances the Leaky neuron by one step, taking and returning its
    membrane state explicitly.

    Args:
        in_ch, out_ch: input / output channels of the convolution.
        cellsize: input spatial size, used only to precompute ``conv_s``.
        device: accepted for signature compatibility; not used here.
        betta: Leaky decay rate (beta).
        thrs: Leaky firing threshold.
        kernel_size, stride, padding: convolution parameters.

    Attributes:
        conv_s: convolution output size for a ``cellsize`` input, i.e.
            floor((cellsize + 2*padding - kernel_size) / stride) + 1.
    """

    def __init__(self, in_ch, out_ch, cellsize, device, betta = 0.85, thrs = 1, kernel_size = 3, stride=1, padding=1):
        super(ResnetBlock, self).__init__()

        self.Conv2d = nn.Conv2d(in_ch, out_ch, kernel_size, stride=stride, padding=padding)
        padded_span = cellsize + 2 * padding - kernel_size
        self.conv_s = padded_span // stride + 1
        self.BN = nn.BatchNorm2d(out_ch)
        self.IF = snn.Leaky(beta=betta, threshold=thrs)

    def meminit(self):
        """Return a fresh membrane state for this block's Leaky neuron."""
        return self.IF.init_leaky()

    def forward(self, x, mem):
        """Run one step; return (neuron output, updated membrane state)."""
        normalized = self.BN(self.Conv2d(x))
        spk, new_mem = self.IF(normalized, mem)
        return spk, new_mem


class SRU(nn.Module):
    """Spiking residual shrinkage unit.

    Four spiking ``ResnetBlock``s produce a residual. A squeeze-and-excitation
    branch turns the residual's per-channel average into a per-channel
    threshold, and the residual is soft-thresholded with it. The result is
    then added to the input (``catORadd == 0``) or concatenated with it along
    the last axis (any other value).

    Args:
        catORadd: 0 to add the residual to the input, otherwise concatenate.
        in_ch, out_ch: channels of the first / remaining ResnetBlocks. Both
            the add and the last-axis concat need the residual channel count
            to equal the input's, i.e. ``in_ch == out_ch``.
        imsize: spatial size of the incoming feature map (assumed square,
            since ``Avgpool2d`` uses a single int kernel).
        device: stored and passed to the ResnetBlocks (which do not use it).
        betta, thrs: Leaky decay and firing threshold for every ResnetBlock.
        kernel_size, stride, padding: stored and used only for ``self.ks``.
            The ResnetBlocks are built with their own defaults (3, 1, 1), so
            these values do not change the computation.

    Attributes:
        out_s: actual [H, W] output size of ``forward``.
    """
    def __init__(self, catORadd ,in_ch, out_ch, imsize, device, betta = 0.85, thrs = 1, kernel_size = 3, stride=1, padding=1):
        super(SRU, self).__init__()
        self.action = catORadd # 0 for add, anything else for cat
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_ch = in_ch
        self.padding = padding
        self.betta = betta
        self.th = thrs
        self.imsize = imsize
        self.device = device
        # Nominal conv output size for this unit's own kernel/stride/padding.
        # Not the real output size: the ResnetBlocks below use their defaults.
        self.ks = (self.imsize - self.kernel_size + 2*self.padding)//(self.stride) + 1

        self.Res1 = ResnetBlock(self.in_ch, self.out_ch, self.imsize, self.device, betta=self.betta, thrs=self.th)
        self.Res1_s = self.Res1.conv_s

        self.Res2 = ResnetBlock(self.out_ch, self.out_ch, self.Res1_s, self.device, betta=self.betta, thrs=self.th)
        self.Res2_s = self.Res2.conv_s

        self.Res3 = ResnetBlock(self.out_ch, self.out_ch, self.Res2_s, self.device, betta=self.betta, thrs=self.th)
        self.Res3_s = self.Res3.conv_s

        self.Res4 = ResnetBlock(self.out_ch, self.out_ch, self.Res3_s, self.device, betta=self.betta, thrs=self.th)
        self.Res4_s = self.Res4.conv_s

        # Real output size of forward(). The input x keeps `imsize` and the
        # residual ends up Res4_s wide. The add needs both sizes to match; the
        # concat along the last (width) axis sums the two widths.
        if self.action == 0:
            self.out_s = [self.imsize, self.imsize]
        else:
            self.out_s = [self.imsize, self.imsize + self.Res4_s]

        # Squeeze-and-excitation branch. The pool collapses each Res4_s x Res4_s
        # map to 1x1, so FC1/FC2 (1 -> 4 -> 1) act on one scalar per channel.
        self.Avgpool2d = nn.AvgPool2d(kernel_size= self.Res4_s)
        self.FC1 = nn.Linear(1,4)
        self.FC2 = nn.Linear(4,1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        self.soft_threshold = SoftThresholdLayer(threshold_init=1.0)

    def meminit(self):
        """Return fresh membrane states for the four ResnetBlocks, as a list."""
        mem1 = self.Res1.meminit()
        mem2 = self.Res2.meminit()
        mem3 = self.Res3.meminit()
        mem4 = self.Res4.meminit()
        return [mem1, mem2, mem3, mem4]

    def forward(self, x, mem):
        """Run one time step of the unit.

        Args:
            x: input feature map; presumably (B, C, H, W) - TODO confirm.
            mem: list of four membrane states as returned by ``meminit``.
                It is updated in place and also returned.

        Returns:
            (out, mem), where ``out`` is ``x + residual`` (action 0) or
            ``x`` concatenated with ``residual`` along the last axis.
        """
        residual = x
        residual, mem[0] = self.Res1(residual, mem[0])
        residual, mem[1] = self.Res2(residual, mem[1])
        residual, mem[2] = self.Res3(residual, mem[2])
        residual, mem[3] = self.Res4(residual, mem[3])

        # Squeeze and excitation: per-channel mean scaled by a learned gate
        # in (0, 1) gives the per-channel shrinkage threshold.
        scales = self.Avgpool2d(residual)
        feedback = self.relu(self.FC1(scales))
        feedback = self.sigmoid(self.FC2(feedback))
        thres = torch.multiply(scales, feedback)
        residual = self.soft_threshold(residual, thres)

        if self.action == 0:
            out = torch.add(x , residual) # sum layer
        else:
            out = torch.cat((x, residual), 3)  # concat layer
        return out, mem

class DSRN(nn.Module):
    """Deep spiking residual shrinkage network for ``class_size``-way (10) classification.

    Each time step computes:

        Conv2d1 -> FC1 -> Leaky -> conv2d2 -> FC2 -> Leaky
        -> SRBU1 (cat) -> maxpool1 -> SRBU2 (cat) -> maxpool2
        -> SRBU3 (add) -> SRBU4 (add) -> BN1 -> flatten -> Leaky
        -> FC3 -> FC4 -> Leaky

    FC1/FC2 are Linear layers over the last (width) axis of the conv output.
    The same input ``x`` is presented at every one of ``nb_blocks`` steps, and
    the per-step IFC4 outputs are stacked along a time axis.

    Args:
        nb_blocks: number of time steps simulated per forward call.
        in_ch: number of input channels.
        out_ch: channel count used throughout the network.
        imsize: input spatial size, either an int (square input) or an
            (H, W) sequence. Only the first entry drives the size arithmetic,
            so inputs are assumed square.
        device: stored; used for the empty result when ``nb_blocks == 0``.
        thresh: firing threshold shared by every Leaky neuron.
        strides, kernelsize, padding: parameters of the two input
            convolutions (also passed to SRBU1).
    """
    def __init__(self, nb_blocks, in_ch, out_ch,imsize,device, thresh, strides=2, kernelsize=5, padding=2):
        super(DSRN, self).__init__()
        # Accept a bare int for square inputs; the size maths indexes imsize[0].
        if isinstance(imsize, int):
            imsize = (imsize, imsize)
        self.nb_blocks = nb_blocks
        self.out_ch = out_ch
        self.kernel_size = kernelsize
        self.padding = padding
        self.stride = strides
        self.in_ch = in_ch
        self.imsize = imsize
        self.betta = 0.10
        self.th = thresh
        self.device = device
        self.class_size = 10

        # Input encoder: two strided convs, each followed by a Linear over the
        # width axis and a Leaky neuron with learnable beta.
        self.Conv2d1 = nn.Conv2d(self.in_ch, self.out_ch, self.kernel_size, stride= self.stride, padding= self.padding)
        self.conv1_s = (self.imsize[0] - self.kernel_size + 2*self.padding)//(self.stride) + 1
        self.FC1 = nn.Linear(self.conv1_s, self.conv1_s)
        self.IFC1 = snn.Leaky(beta = self.betta, threshold= self.th, learn_beta=True)

        self.conv2d2 = nn.Conv2d(self.out_ch, self.out_ch, self.kernel_size, stride= self.stride, padding= self.padding)
        self.conv2_s = (self.conv1_s - self.kernel_size + 2*self.padding)//(self.stride) + 1
        self.FC2 = nn.Linear(self.conv2_s, self.conv2_s)
        self.IFC2 = snn.Leaky(beta = self.betta, threshold= self.th, learn_beta=True)

        # Shrinkage units. SRBU1/SRBU2 concatenate along the width axis, and
        # the (1, 2) max-pools halve the width again.
        self.SRBU1 = SRU(1, self.out_ch, self.out_ch, self.conv2_s, self.device,
                         betta=self.betta, thrs=self.th, kernel_size= self.kernel_size, stride=self.stride, padding = self.padding)
        self.SRBU1_s = self.SRBU1.out_s
        self.maxpool1 = nn.MaxPool2d((1,2),(1,2),0)
        self.maxpool1_s = [(self.SRBU1_s[0]-1)//1 + 1, (self.SRBU1_s[1]-2)//2 + 1]

        self.SRBU2 = SRU(1, self.out_ch, self.out_ch, self.conv2_s, self.device,
                         betta=self.betta, thrs=self.th)
        self.SRBU2_s = self.SRBU2.out_s
        self.maxpool2 = nn.MaxPool2d((1,2),(1,2),0)
        # Size after pooling SRBU2's output with the (1, 2) max-pool.
        self.maxpool2_s = [(self.SRBU2_s[0]-1)//1 + 1, (self.SRBU2_s[1]-2)//2 + 1]

        self.SRBU3 = SRU(0, self.out_ch, self.out_ch, self.conv2_s, self.device,
                         betta=self.betta, thrs=self.th)
        self.SRBU3_s = self.SRBU3.out_s
        # Size a (3, 2, 1) max-pool would give. That pool is not applied in
        # forward(), and this value is not used there.
        self.maxpool3_s = [(self.maxpool2_s[0] -1)//2 + 1, (self.maxpool2_s[1] -1)//2 + 1]

        self.SRBU4 = SRU(0, self.out_ch, self.out_ch, self.conv2_s, self.device,
                         betta=self.betta, thrs=self.th)
        self.SRBU4_s = self.SRBU4.out_s

        # Classifier head.
        self.BN1 = nn.BatchNorm2d(self.out_ch)
        self.flatten = nn.Flatten(1)
        self.flatten_s = self.out_ch*self.SRBU4_s[0]*self.SRBU4_s[1]

        self.IFC3 = snn.Leaky(beta = self.betta, threshold= self.th)
        self.FC3 = nn.Linear(self.flatten_s, int(math.sqrt(self.flatten_s)))
        self.FC4 = nn.Linear(int(math.sqrt(self.flatten_s)), self.class_size)
        self.IFC4 = snn.Leaky(beta= self.betta, threshold= self.th)

    def forward(self, x):
        """Simulate ``nb_blocks`` time steps on a static input.

        Args:
            x: input batch; presumably (B, in_ch, H, W) - TODO confirm.

        Returns:
            IFC4 outputs stacked over time, shape (B, nb_blocks, class_size).
        """
        memc1 = self.IFC1.init_leaky()
        memc2 = self.IFC2.init_leaky()
        memc3 = self.IFC3.init_leaky()
        memc4 = self.IFC4.init_leaky()
        mem1 = self.SRBU1.meminit()
        mem2 = self.SRBU2.meminit()
        mem3 = self.SRBU3.meminit()
        mem4 = self.SRBU4.meminit()

        # Conv2d1 and FC1 see the same x at every step, so compute them once.
        encoded = self.FC1(self.Conv2d1(x))

        step_outputs = []
        for _ in range(self.nb_blocks):
            next_l, memc1 = self.IFC1(encoded, memc1)

            next_l = self.conv2d2(next_l)
            next_l = self.FC2(next_l)
            next_l, memc2 = self.IFC2(next_l, memc2)

            next_l, mem1 = self.SRBU1(next_l, mem1)
            next_l = self.maxpool1(next_l)

            next_l, mem2 = self.SRBU2(next_l, mem2)
            next_l = self.maxpool2(next_l)

            next_l, mem3 = self.SRBU3(next_l, mem3)
            next_l, mem4 = self.SRBU4(next_l, mem4)

            next_l = self.BN1(next_l)
            next_l = self.flatten(next_l)
            next_l, memc3 = self.IFC3(next_l, memc3)
            next_l = self.FC3(next_l)
            next_l = self.FC4(next_l)
            next_l, memc4 = self.IFC4(next_l, memc4)
            step_outputs.append(next_l)

        if not step_outputs:
            return torch.empty(x.shape[0], 0, self.class_size, device=self.device)
        # Stack once at the end; torch.cat inside the loop is quadratic.
        return torch.stack(step_outputs, dim=1)