In [1]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
import argparse
import pickle as pkl
import time
from copy import deepcopy
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm

In [2]:
import os
import pandas as pd
from torchvision.io import read_image
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from random import shuffle
import torch.nn.functional as F
from torchvision.transforms import v2
from torch import nn



In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None, end_maxpool = False):
        super(ResidualBlock, self).__init__()
        if(downsample is not None):
            self.conv1 = nn.Sequential(
                            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same',bias=False),
                            nn.BatchNorm2d(out_channels),
                            nn.ReLU(inplace=False),
                            nn.MaxPool2d(kernel_size=2, stride=2)
                            )  # Changed inplace to False
        else:
            self.conv1 = nn.Sequential(
                            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same',bias=False),
                            nn.BatchNorm2d(out_channels),
                            nn.Hardtanh(min_val=0.0, max_val=1.0, inplace=False)
                            )
        self.conv2 = nn.Sequential(
                        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1,bias=False),
                        nn.BatchNorm2d(out_channels),
                        nn.Hardtanh(min_val=0.0, max_val=1.0, inplace=False))  # Changed inplace to False
        self.downsample = downsample
        self.relu = nn.Hardtanh(min_val=0.0, max_val=1.0, inplace=False)  # Changed inplace to False
        self.out_channels = out_channels
        self.end_maxpool = end_maxpool

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample:
            residual = self.downsample(x)
        out = out + residual
        if self.end_maxpool:
            out = F.relu(out, inplace=False)
        else:
            out = F.hardtanh(out, inplace=False, min_val=0.0, max_val=1.0)   # Use non-in-place ReLU
        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes = 2, in_chanels = 10):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Sequential(
                        nn.Conv2d(in_chanels, 64, kernel_size = 7, stride = 2, padding = 3),
                        nn.BatchNorm2d(64),
                        nn.ReLU(inplace=False))
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
        self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
        self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
        self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
        self.layer3 = self._make_layer(block, 512, layers[3], stride = 2, end_maxpool = True)
        self.avgpool = nn.MaxPool2d(7, stride=1)
        self.fc = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1, end_maxpool = False):
        downsample = None
        if stride != 1 or self.inplanes != planes:

            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=1, padding='same',bias=False),
                nn.BatchNorm2d(planes),
                nn.ReLU(inplace=False),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for i in range(1, blocks):
            if i == blocks-1 and end_maxpool:
                layers.append(block(self.inplanes, planes, end_maxpool = True))
            else:
                layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = F.hardtanh(x, min_val=0.0, max_val=1.0)
        x = self.fc2(x)
        return x

In [4]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None, end_maxpool = False):
        super(ResidualBlock, self).__init__()
        if(downsample is not None):
            self.conv1 = nn.Sequential(
                            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same'),
                            nn.BatchNorm2d(out_channels),
                            nn.ReLU(inplace=False),
                            nn.MaxPool2d(kernel_size=2, stride=2)
                            )  # Changed inplace to False
        else:
            self.conv1 = nn.Sequential(
                            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same'),
                            nn.BatchNorm2d(out_channels),
                            nn.Hardtanh(min_val=0.0, max_val=1.0, inplace=False)
                            )
        self.conv2 = nn.Sequential(
                        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                        nn.BatchNorm2d(out_channels),
                        nn.Hardtanh(min_val=0.0, max_val=1.0, inplace=False))  # Changed inplace to False
        self.downsample = downsample
        self.relu = nn.Hardtanh(min_val=0.0, max_val=1.0, inplace=False)  # Changed inplace to False
        self.out_channels = out_channels
        self.end_maxpool = end_maxpool

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample:
            residual = self.downsample(x)
        out = out + residual
        if self.end_maxpool:
            out = F.relu(out, inplace=False)
        else:
            out = F.hardtanh(out, inplace=False, min_val=0.0, max_val=1.0)   # Use non-in-place ReLU
        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes = 2, in_channels = 5):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Sequential(
                        nn.Conv2d(in_channels, 64, kernel_size = 7, stride = 2, padding = 3),
                        nn.BatchNorm2d(64),
                        nn.ReLU(inplace=False))
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
        self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
        self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
        self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
        self.layer3 = self._make_layer(block, 512, layers[3], stride = 2, end_maxpool = True)
        self.avgpool = nn.MaxPool2d(7, stride=1)
        self.fc = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(0.1)

    def _make_layer(self, block, planes, blocks, stride=1, end_maxpool = False):
        downsample = None
        if stride != 1 or self.inplanes != planes:

            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=1, padding='same'),
                nn.BatchNorm2d(planes),
                nn.ReLU(inplace=False),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for i in range(1, blocks):
            if i == blocks-1 and end_maxpool:
                layers.append(block(self.inplanes, planes, end_maxpool = True))
            else:
                layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = F.hardtanh(x,min_val=0, max_val=1)
        # x = self.dropout(x)
        x = self.fc2(x)
        return x# F.hardtanh(x)

In [5]:
# model_resnet = ResNet(ResidualBlock, [5, 6, 6, 4], in_channels = 6, num_classes=101).to("cuda")

In [6]:
model_resnet = ResNet(ResidualBlock, [5, 6, 6, 4], num_classes = 101, in_channels=6)

In [7]:
model_resnet.load_state_dict(torch.load("best_resnet_nCaltech101_ReLU1_ReLUmaxpool_EST__FC2_corrected_exp_copy1.pt", weights_only=True))
model_resnet.eval().to("cuda")
print(model_resnet)

  return self.fget.__get__(instance, owner)()


ResNet(
  (conv1): Sequential(
    (0): Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer0): Sequential(
    (0): ResidualBlock(
      (conv1): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Hardtanh(min_val=0.0, max_val=1.0)
      )
      (conv2): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Hardtanh(min_val=0.0, max_val=1.0)
      )
      (relu): Hardtanh(min_val=0.0, max_val=1.0)
    )
    (1): ResidualBlock(
      (conv1): Sequential(
        (0): Conv2d(64, 64, kernel_s

In [8]:
temp = torch.rand((1,6,224,224))
model_resnet(temp.to("cuda"))

tensor([[ 3.2163,  0.4308, -1.1207, -3.5953,  1.6711,  1.5722,  1.7449,  0.7376,
         -3.8387,  0.9331,  1.3603, -4.8497, -2.3277,  0.3880, -3.4415, -5.0047,
         -0.3977, -6.0568,  1.5379, -2.9769, -0.9855,  0.7568,  2.3783,  1.0269,
         -2.4769, -0.4099,  1.4334, -2.0948, -3.3150,  1.1027, -0.8569,  0.0149,
         -1.6380, -3.6980, -0.1839,  2.6007,  1.1846,  0.1530, -1.5050, -1.8727,
          0.6083, -2.7449, -0.2699, -2.9457, -4.5308,  1.4536, -2.6619,  5.2084,
         -5.5468, -2.3732, -5.0009, -1.0081,  1.5027, -1.6777, -3.3897,  0.4584,
          1.3667, -1.3254, -1.9281, -0.7660, -1.4549, -1.1950,  1.8507, -0.5560,
          0.9874, -2.6832,  0.6909, -2.3641, -3.1094,  0.6242, -1.7520, -3.2381,
          2.4555, -2.7448,  0.7174, -1.1286, -2.0757, -2.7847,  1.0638, -2.4839,
         -2.5655, -2.7873,  2.5571, -6.8863, -2.7477, -1.5022,  1.7753, -2.4863,
          0.7694,  3.8319, -1.4426, -2.2564,  6.5023, -3.5000,  0.1364, -2.3715,
         -4.0597,  3.9817,  

In [9]:
from SNN.Check_ReLU1 import fuse_conv_and_bn
from SNN.Check_ReLU1 import SpikingConv2D
from SNN.Check_ReLU1 import MaxMinPool2D
from SNN.Check_ReLU1 import LayerSNN_ReLU1
from SNN.Check_ReLU1 import SpikingDense_positive_ReLU1, SpikingDense_pos_all, SpikingDense

In [10]:
class ResNet_ttfs(nn.Module):
    def __init__(self, model : ResNet, in_channels = 5):
        super(ResNet_ttfs, self).__init__()
        model.eval()
        robustness_params={
            'noise':0.0,
            'time_bits':0,
            'weight_bits': 0,
            'latency_quantiles':0.0
        }
        model.to('cuda')
        conv_fused = fuse_conv_and_bn(model.conv1[0], model.conv1[1], device = 'cuda')
        self.conv_first = SpikingConv2D(64, "temp1", device = 'cuda', padding=(3,3), stride=2, kernel_size=(7,7),robustness_params=robustness_params, kernels=conv_fused.weight.data, biases= conv_fused.bias.data)
        max_vect = torch.tensor([1]*in_channels)
        tmin, tmax, max_vect, scalar = self.conv_first.set_params(0.0,1.0,max_vect)
        self.tmin_post_pool = tmax
        self.pool = MaxMinPool2D(3, tmax.data,2,1).to("cuda")
        self.layer0SNN = LayerSNN_ReLU1(model.layer0, 64, 64, 5,device = 'cuda')
        tmax_prev = tmax
        tmin, tmax, max_vect, scalar = self.layer0SNN.set_params(tmin, tmax, max_vect, in_scalar=scalar)
        self.layer1SNN = LayerSNN_ReLU1(model.layer1, 64, 128, 6,device = 'cuda')
        tmin, tmax, max_vect, scalar = self.layer1SNN.set_params(tmin, tmax, max_vect, in_scalar=scalar)
        self.layer2SNN = LayerSNN_ReLU1(model.layer2, 128, 256, 6,device = 'cuda')
        tmin, tmax, max_vect, scalar = self.layer2SNN.set_params(tmin, tmax, max_vect, in_scalar=scalar)
        self.layer3SNN = LayerSNN_ReLU1(model.layer3, 256, 512, 4,device = 'cuda',end_maxpool=True)
        tmin, tmax, max_vect, scalar = self.layer3SNN.set_params(tmin, tmax, max_vect, in_scalar=scalar)
        self.pool2 = MaxMinPool2D(7, tmax.data,1,0).to("cuda")
        # temp = torch.ones((512,2,2))
        # max_vect = ((temp.T)*max_vect[:512]).T
        # max_vect = max_vect.view(-1)

        max_vect = (torch.ones([512, 1, 1]).to("cuda").T*max_vect).T.contiguous().view(512).to("cuda")

        weights = model.fc.weight.detach().clone()
        biases = model.fc.bias.detach().clone()
        self.layer_fc = SpikingDense_positive_ReLU1(256, 512,"test",robustness_params=robustness_params,device = 'cuda',weights=weights, biases=biases)
        
        tmin, tmax, max_vect, scalar = self.layer_fc.set_params(tmin, tmax, max_vect, in_scalar=scalar)
        self.layer_fc2 = SpikingDense_pos_all(101,256, '',model.fc2.weight, model.fc2.bias,robustness_params=robustness_params,device = 'cuda')
        tmin, tmax, max_vect, scalar = self.layer_fc2.set_params(tmin, tmax, max_vect, in_scalar=scalar)
        self.tmin, self.tmax = tmin, tmax
        self.scalar = scalar
    
    def forward(self, tj):
        spike_sum, v_sqare_sum = 0,0.0

        x, spike, v  = self.conv_first(tj)

        spike_sum += spike
        v_sqare_sum +=v
        x, spike = self.pool(x)

        spike_sum += spike
        # print(x.shape)
        x = x
        x, spike1, v1 = self.layer0SNN(x)

        x, spike2, v2 = self.layer1SNN(x)

        x, spike3, v3 = self.layer2SNN(x)

        x, spike4, v4 = self.layer3SNN(x)

        # print(x[1].max(),x[1].min())
        x, spike5 = self.pool2(x)
        x = x.contiguous().view(x.size(0), -1)
        x, spike6, v5 = self.layer_fc(x)

        x, spike7, v6 = self.layer_fc2(x)

        spike_sum += spike1+spike2+spike3+spike4+spike5+spike6+spike7
        v_sqare_sum += v1+v2+v3+v4+v5+v6
        return x, spike_sum, v_sqare_sum

In [11]:
model_ttfs = ResNet_ttfs(model_resnet,6)

  self.t_max = torch.tensor(t_min + self.B_n*max_V+eps_V, dtype=torch.float64, requires_grad=False)
  self.t_max = torch.tensor(max(t_min + self.B_n*max_V+eps_V, minimal_t_max), dtype=torch.float64, requires_grad=False)
  self.t_min = torch.tensor(t_min, dtype=torch.float64, requires_grad=False)
  self.t_min = torch.tensor(t_min, dtype=torch.float64, requires_grad=False)
  self.t_max = torch.tensor(max(t_min + self.B_n*max_V, minimal_t_max), dtype=torch.float64, requires_grad=False)
  self.t_min = torch.tensor(t_min, dtype=torch.float64, requires_grad=False).to(self.device)
  self.t_max = torch.tensor(t_min + self.B_n*max_V, dtype=torch.float64, requires_grad=False).to(self.device)
  self.t_min = torch.tensor(t_min, dtype=torch.float64, requires_grad=False).to(self.device)
  self.t_max = torch.tensor(max(t_min + self.B_n*max_V, minimal_t_max), dtype=torch.float64, requires_grad=False).to(self.device)
  self.t_min_prev = torch.tensor(t_min_prev, dtype=torch.float64, requires_grad=False)

In [12]:
temp = torch.rand((1,6,224,224))
temp_ttfs = 1 - temp
# print(model_ttfs.forward(temp_ttfs))
out, spike_sum, v_sqare_sum = model_ttfs.forward(temp_ttfs.to("cuda"))

print(spike_sum, v_sqare_sum)
print(((model_ttfs.tmax - out)*model_ttfs.scalar)[0,:101] - ((model_ttfs.tmax - out)*model_ttfs.scalar)[0,101:])

tensor(6456915, device='cuda:0') tensor(529491.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor([ 2.8947,  0.4039, -1.2676, -3.5619,  1.7669,  1.6346,  1.9091,  0.6477,
        -3.7858,  1.2738,  1.4795, -4.9124, -2.0185,  0.3367, -3.1000, -4.8390,
        -0.5208, -6.1394,  1.3058, -2.6729, -0.7281,  0.7290,  2.5128,  1.0366,
        -2.2848, -0.3446,  1.4807, -2.1247, -3.3326,  1.3377, -1.1378, -0.0867,
        -1.5823, -3.7456, -0.1256,  2.8010,  1.3228,  0.2430, -1.4812, -1.9609,
         0.2256, -2.5879, -0.2733, -2.8508, -4.3779,  1.3920, -2.8964,  5.1185,
        -5.5012, -2.6086, -4.6218, -0.8687,  1.5102, -1.6151, -3.3670,  0.6029,
         1.3008, -1.2377, -1.7246, -0.7928, -1.5728, -1.3493,  1.9646, -0.5245,
         0.5710, -2.3229,  0.7286, -2.2595, -3.1004,  0.4798, -1.8245, -3.3902,
         2.5551, -3.0262,  0.5963, -1.1345, -2.2159, -2.8213,  0.9346, -2.4465,
        -2.7852, -2.7342,  2.3237, -6.7722, -2.9644, -1.7615,  1.5442, -2.1599,
         0.7501,  3.8596, 

In [13]:
model_resnet(temp.to("cuda"))

tensor([[ 2.9002,  0.4078, -1.2630, -3.5515,  1.7679,  1.6281,  1.9090,  0.6418,
         -3.7841,  1.2735,  1.4803, -4.9193, -2.0207,  0.3394, -3.1032, -4.8344,
         -0.5181, -6.1425,  1.3027, -2.6710, -0.7210,  0.7268,  2.5181,  1.0448,
         -2.2818, -0.3474,  1.4786, -2.1328, -3.3226,  1.3367, -1.1439, -0.0896,
         -1.5788, -3.7460, -0.1333,  2.8032,  1.3185,  0.2451, -1.4808, -1.9630,
          0.2196, -2.5879, -0.2691, -2.8526, -4.3784,  1.3905, -2.8941,  5.1162,
         -5.4992, -2.6100, -4.6183, -0.8585,  1.5119, -1.6161, -3.3759,  0.5940,
          1.2981, -1.2374, -1.7209, -0.7953, -1.5695, -1.3563,  1.9699, -0.5235,
          0.5704, -2.3181,  0.7195, -2.2651, -3.1022,  0.4874, -1.8184, -3.3926,
          2.5557, -3.0322,  0.5854, -1.1347, -2.2217, -2.8287,  0.9368, -2.4488,
         -2.7861, -2.7304,  2.3245, -6.7750, -2.9689, -1.7672,  1.5368, -2.1581,
          0.7554,  3.8485, -1.4995, -2.0207,  6.1667, -3.7898,  0.2828, -2.4903,
         -3.6815,  4.1028,  

In [14]:
from torchvision.transforms import v2
torch.manual_seed(19)

transforms = v2.Compose([
    # v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomRotation(60),
    v2.ToDtype(torch.float32)
])


class NCaltech101ImageDataset(Dataset):
    def __init__(self, img_dir_file, transform=None, target_transform=None):
        self.images = np.load(img_dir_file + '_x.npy')
        self.labels = np.load(img_dir_file + '_y.npy')
        self.transform = transform
        self.transform = transform
        self.stage = 0

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        # if self.transform:
        #     image = self.transform(image)
        # if self.target_transform:
        #     label = self.target_transform(label)
        label_temp = np.zeros((101,))
        label_temp[label] = 1
        if self.stage == 0:
            return self.transform(torch.tensor(image)), torch.tensor(label_temp)
        else:
            return torch.tensor(image), torch.tensor(label_temp)
    
    def set_stage(self, stage):
        self.stage = stage

In [15]:
from torch.utils.data import DataLoader
# training_data = NCarsImageDataset("./Datasety/nCars_train_EST_exp_", transform=transforms)
data = NCaltech101ImageDataset("./Datasety/Ncaltech101_EST_exp_corr", transform=transforms)


In [16]:
data.set_stage(1)
spike_acc = 0
MAE_sum = 0
MSE_sum = 0
TP = 0
for i in tqdm(range(len(data))):#10
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            temp_data, temp_label = deepcopy(data[i])
            temp_ttfs = (1 - temp_data).unsqueeze(0).to("cuda")
            temp_data = temp_data.to("cuda")
            # print(temp_data.shape)
            # print(model_ttfs.forward(temp_ttfs))
            out, spike_sum, v_sqare_sum = model_ttfs.forward(temp_ttfs)
            spike_acc+=spike_sum
            output = ((model_ttfs.tmax - out)*model_ttfs.scalar)[0,:101] - ((model_ttfs.tmax - out)*model_ttfs.scalar)[0,101:]
            output_true = model_resnet(temp_data.unsqueeze(0))
            MAE = torch.sum((output-output_true).abs())/101.0
            MSE = torch.sqrt(torch.sum((output-output_true)*(output-output_true)))
            MAE_sum += MAE
            MSE_sum += MSE
            TP += (torch.argmax(temp_label)==torch.argmax(output)) 
            del temp_data, temp_ttfs
    if i%100==0:
        torch.cuda.empty_cache()
print(spike_acc/len(data))
print("Mean MAE:", MAE_sum/len(data))
print("Mean MSE:", MSE_sum/(len(data)*101.0))
print("accuracy:",float(TP)/len(data))

100%|██████████| 8709/8709 [1:11:43<00:00,  2.02it/s]

tensor(6672495.5000, device='cuda:0')
Mean MAE: tensor(0.0088, device='cuda:0')
Mean MSE: tensor(0.0011, device='cuda:0')
accuracy: 0.7449764611321621





In [19]:
print(TP)

tensor(6488, device='cuda:0')


In [23]:
i = 2201

with torch.no_grad():
    with torch.cuda.amp.autocast():
        temp_data, temp_label = deepcopy(data[i])
        temp_ttfs = (1 - temp_data).unsqueeze(0).to("cuda")
        temp_data = temp_data.to("cuda")
        # print(temp_data.shape)
        # print(model_ttfs.forward(temp_ttfs))
        out, spike_sum, v_sqare_sum = model_ttfs.forward(temp_ttfs)
        spike_acc+=spike_sum
        output = ((model_ttfs.tmax - out)*model_ttfs.scalar)[0,:101] - ((model_ttfs.tmax - out)*model_ttfs.scalar)[0,101:]
        output_true = model_resnet(temp_data.unsqueeze(0))
        MAE = torch.sum((output-output_true).abs())/101.0
        MSE = torch.sqrt(torch.sum((output-output_true)*(output-output_true)))
        MAE_sum += MAE
        MSE_sum += MSE
        TP += (torch.argmax(temp_label)==torch.argmax(output)) 
        print(torch.argmax(temp_label), torch.argmax(output))
        del temp_data, temp_ttfs

tensor(16) tensor(16, device='cuda:0')
