In [None]:
# Mount Google Drive (Colab only) so the dataset and log directories under
# /content/drive/MyDrive referenced in parse_args() are reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install spikingjelly -q

In [None]:
import os
import time
import datetime
import errno
import math
from collections import defaultdict, deque

import torch
import torch.nn as nn
import torch.utils.data
import torch.distributed as dist
from torch import amp
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data.dataloader import default_collate

from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode
from spikingjelly.clock_driven import layer, functional


In [None]:
# import torch
# import torch.nn as nn
# from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode
# from spikingjelly.clock_driven import layer

def conv3x3(in_channels, out_channels):
    """Spiking 3x3 conv stage: Conv2d(3x3, padding 1, no bias) -> BatchNorm2d -> parametric LIF.

    Conv and BN are wrapped in ``layer.SeqToANNContainer`` so these 2-D layers can be
    applied to a multi-step ``[T, N, C, H, W]`` sequence; the multi-step parametric
    LIF node (init_tau=2.0, detached reset) then produces the spikes.
    """
    conv_bn = layer.SeqToANNContainer(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
        nn.BatchNorm2d(out_channels),
    )
    neuron = MultiStepParametricLIFNode(init_tau=2.0, detach_reset=True)
    return nn.Sequential(conv_bn, neuron)

def conv1x1(in_channels, out_channels):
    """Spiking 1x1 conv stage: Conv2d(1x1, no bias) -> BatchNorm2d -> parametric LIF.

    Used as a channel projection; spatial size is unchanged. Conv and BN are wrapped in
    ``layer.SeqToANNContainer`` for multi-step ``[T, N, C, H, W]`` input.
    """
    conv_bn = layer.SeqToANNContainer(
        nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
        nn.BatchNorm2d(out_channels),
    )
    neuron = MultiStepParametricLIFNode(init_tau=2.0, detach_reset=True)
    return nn.Sequential(conv_bn, neuron)

class SEWBlock(nn.Module):
    """Spike-element-wise (SEW) residual block.

    Two ``conv3x3`` spiking stages (in -> mid -> in channels) whose spike output ``out``
    is merged with the block input ``x`` according to ``connect_f``:

    - ``'ADD'``:  ``out + x``
    - ``'AND'``:  ``out * x``
    - ``'IAND'``: ``x * (1 - out)``

    Any other value (including the default ``None``) raises ``NotImplementedError``
    when ``forward`` is called.
    """

    def __init__(self, in_channels, mid_channels, connect_f=None):
        super().__init__()
        self.connect_f = connect_f
        first_stage = conv3x3(in_channels, mid_channels)
        second_stage = conv3x3(mid_channels, in_channels)
        self.conv = nn.Sequential(first_stage, second_stage)

    def forward(self, x: torch.Tensor):
        out = self.conv(x)
        mode = self.connect_f
        if mode == 'ADD':
            out += x
            return out
        if mode == 'AND':
            out *= x
            return out
        if mode == 'IAND':
            return x * (1. - out)
        raise NotImplementedError(self.connect_f)

class PlainBlock(nn.Module):
    """Two stacked ``conv3x3`` spiking stages (in -> mid -> in channels) with no shortcut."""

    def __init__(self, in_channels, mid_channels):
        super().__init__()
        stages = [conv3x3(in_channels, mid_channels), conv3x3(mid_channels, in_channels)]
        self.conv = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor):
        out = self.conv(x)
        return out

class BasicBlock(nn.Module):
    """Spiking-ResNet basic block with the spiking neuron applied after the residual sum.

    Computes ``sn(x + BN(Conv3x3(conv3x3(x))))``: the second conv has no neuron of its
    own, so the shortcut is added to its BatchNorm output and the sum is then passed
    through a multi-step parametric LIF node.
    """

    def __init__(self, in_channels, mid_channels):
        super().__init__()
        first_stage = conv3x3(in_channels, mid_channels)
        second_stage = layer.SeqToANNContainer(
            nn.Conv2d(mid_channels, in_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(in_channels),
        )
        self.conv = nn.Sequential(first_stage, second_stage)
        self.sn = MultiStepParametricLIFNode(init_tau=2.0, detach_reset=True)

    def forward(self, x: torch.Tensor):
        residual = self.conv(x)
        return self.sn(x + residual)


class ResNetN(nn.Module):
    """Configurable spiking ResNet-style classifier built from a list of stage dicts.

    Each entry of ``layer_list`` is a dict with:

    - ``channels``: output channels of the stage.
    - ``mid_channels`` (optional): hidden width inside each block; defaults to ``channels``.
    - ``up_kernel_size``: 1 or 3, the kernel of the spiking projection conv inserted when
      the channel count changes (only read in that case).
    - ``num_blocks`` + ``block_type`` (optional): stack ``num_blocks`` blocks of type
      ``'sew'``, ``'plain'`` or ``'basic'``.
    - ``k_pool`` (optional): kernel and stride of a max-pool appended after the blocks.

    Args:
        layer_list: stage configuration as described above.
        num_classes: output size of the final linear layer.
        connect_f: SEW connection function (``'ADD'``, ``'AND'``, ``'IAND'``); only used
            by ``'sew'`` blocks.
        in_channels: channels of each input frame (default 2).
        input_size: ``(H, W)`` of each input frame, used to size the linear head
            (default ``(128, 128)``).

    Raises:
        NotImplementedError: for an unknown ``up_kernel_size`` or ``block_type``.
    """

    def __init__(self, layer_list, num_classes, connect_f=None, in_channels=2, input_size=(128, 128)):
        super(ResNetN, self).__init__()
        block_builders = {
            'sew': lambda c, m: SEWBlock(c, m, connect_f),
            'plain': PlainBlock,
            'basic': BasicBlock,
        }
        conv = []

        for cfg_dict in layer_list:
            channels = cfg_dict['channels']
            mid_channels = cfg_dict.get('mid_channels', channels)

            # Project to the stage width when it differs from the incoming width.
            if in_channels != channels:
                up_kernel_size = cfg_dict['up_kernel_size']
                if up_kernel_size == 3:
                    conv.append(conv3x3(in_channels, channels))
                elif up_kernel_size == 1:
                    conv.append(conv1x1(in_channels, channels))
                else:
                    raise NotImplementedError(f"up_kernel_size={up_kernel_size!r}")
            in_channels = channels

            if 'num_blocks' in cfg_dict:
                block_type = cfg_dict['block_type']
                if block_type not in block_builders:
                    raise NotImplementedError(f"block_type={block_type!r}")
                build = block_builders[block_type]
                conv.extend(build(in_channels, mid_channels) for _ in range(cfg_dict['num_blocks']))

            if 'k_pool' in cfg_dict:
                k_pool = cfg_dict['k_pool']
                conv.append(layer.SeqToANNContainer(nn.MaxPool2d(k_pool, k_pool)))

        # [T, N, C, H, W] -> [T, N, C*H*W]
        conv.append(nn.Flatten(2))
        self.conv = nn.Sequential(*conv)

        # Infer the flattened feature size by running a dummy single-channel frame through
        # the pooling layers only; the convs used here (3x3/pad 1 and 1x1, stride 1) keep
        # the spatial size unchanged.
        with torch.no_grad():
            x = torch.zeros([1, 1, *input_size])
            for m in self.conv.modules():
                if isinstance(m, nn.MaxPool2d):
                    x = m(x)
            out_features = x.numel() * in_channels

        self.out = nn.Linear(out_features, num_classes, bias=True)

    def forward(self, x: torch.Tensor):
        """Map ``[N, T, C, H, W]`` frames to logits ``[N, num_classes]``.

        Features are averaged over the T time steps before the linear head.
        """
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, C, H, W]
        x = self.conv(x)  # [T, N, features]
        return self.out(x.mean(0))

def SEWResNet(connect_f, num_classes=5):
    """Build a 7-stage SEW ResNet.

    Each stage is one 32-channel SEW block followed by a 2x2 max-pool. The first stage
    also gets a 1x1 spiking projection from the input channels to 32.

    Args:
        connect_f: SEW connection function (``'ADD'``, ``'AND'`` or ``'IAND'``).
        num_classes: number of output classes (default 5, the previously hard-coded value).
    """
    layer_list = [
        {'channels': 32, 'up_kernel_size': 1, 'mid_channels': 32,
         'num_blocks': 1, 'block_type': 'sew', 'k_pool': 2}
        for _ in range(7)
    ]
    return ResNetN(layer_list, num_classes, connect_f)

def PlainNet(*args, num_classes=11, **kwargs):
    """Build a 7-stage plain (no shortcut) spiking network.

    Each stage is one 32-channel ``PlainBlock`` followed by a 2x2 max-pool.

    Args:
        *args, **kwargs: accepted and ignored, e.g. a ``connect_f`` passed by a generic
            model factory.
        num_classes: number of output classes (default 11, the previously hard-coded value).
    """
    layer_list = [
        {'channels': 32, 'up_kernel_size': 1, 'mid_channels': 32,
         'num_blocks': 1, 'block_type': 'plain', 'k_pool': 2}
        for _ in range(7)
    ]
    return ResNetN(layer_list, num_classes)

def SpikingResNet(*args, num_classes=11, **kwargs):
    """Build a 7-stage Spiking ResNet.

    Each stage is one 32-channel ``BasicBlock`` followed by a 2x2 max-pool.

    Args:
        *args, **kwargs: accepted and ignored, e.g. a ``connect_f`` passed by a generic
            model factory.
        num_classes: number of output classes (default 11, the previously hard-coded value).
    """
    layer_list = [
        {'channels': 32, 'up_kernel_size': 1, 'mid_channels': 32,
         'num_blocks': 1, 'block_type': 'basic', 'k_pool': 2}
        for _ in range(7)
    ]
    return ResNetN(layer_list, num_classes)

In [None]:
# from collections import defaultdict, deque
# import datetime
# import time
# import torch
# import torch.distributed as dist

# import errno
# import os


class SmoothedValue(object):
    """Running statistic over a sliding window plus a global (all-time) weighted mean.

    ``update(value, n)`` appends ``value`` to a window holding the last ``window_size``
    values, and adds ``value * n`` to ``total`` and ``n`` to ``count``. ``str()``
    renders ``fmt``, which may reference ``median``, ``avg``, ``global_avg``, ``max``
    and ``value``.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """All-reduce ``count`` and ``total`` across ranks; no-op when not distributed.

        Warning: does not synchronize the deque!
        """
        if is_dist_avail_and_initialized():
            packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(packed)
            summed_count, summed_total = packed.tolist()
            self.count = int(summed_count)
            self.total = summed_total

    @property
    def median(self):
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        fields = {
            'median': self.median,
            'avg': self.avg,
            'global_avg': self.global_avg,
            'max': self.max,
            'value': self.value,
        }
        return self.fmt.format(**fields)


class MetricLogger(object):
    """Collection of named ``SmoothedValue`` meters with formatted progress logging.

    Meters are created on first access through ``self.meters`` (a
    ``defaultdict(SmoothedValue)``) and existing meters are also reachable as
    attributes, e.g. ``logger.loss``.
    """
    def __init__(self, delimiter="\t"):
        # delimiter: separator placed between meters / log fields when printing.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Append one value per keyword argument to the meter of that name (n=1).

        Tensors are converted with ``.item()``; every value must then be an int or
        float (asserted).
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Fallback attribute lookup: existing meters first, then the instance dict.

        Python only calls this when normal lookup fails, so real attributes take
        precedence over meters with the same name.
        """
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        """Render every meter as ``"name: <meter>"`` joined by ``self.delimiter``."""
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize the global count/total of every meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register (or replace) a meter, e.g. one with a custom window or format."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield each item of ``iterable`` while printing periodic progress lines.

        A line is printed after iterations 0, ``print_freq``, ``2 * print_freq``, ...
        with the position, ETA, all meters, mean iteration time and mean data-wait
        time (plus peak CUDA memory in MB when CUDA is available). The total elapsed
        time is printed when the iterable is exhausted. ``iterable`` must support
        ``len()``.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        # iter_time: full iteration including the caller's loop body;
        # data_time: time spent waiting for the next item from ``iterable``.
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the iteration counter to the width of len(iterable).
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            # Control returns here after the caller has processed ``obj``.
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {}'.format(header, total_time_str))


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: scores of shape ``[batch, num_classes]``.
        target: class indices of shape ``[batch]``.
        topk: values of k to evaluate.

    Returns:
        A list with one 0-d float32 tensor per k, each in ``[0, 100]``.

    A k larger than the number of classes is clamped to ``num_classes``: with valid
    targets every sample's class is then in the top-k, giving 100%. (Previously
    ``torch.topk`` raised for such k, e.g. top-5 on a 3-class problem.)
    """
    with torch.no_grad():
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # [maxk, batch]
        correct = pred.eq(target[None])

        res = []
        for k in topk:
            # Slicing past maxk rows simply takes all rows.
            correct_k = correct[:k].flatten().sum(dtype=torch.float32)
            res.append(correct_k * (100.0 / batch_size))
        return res


def mkdir(path):
    """Create ``path`` and any missing parents; do nothing if it is already a directory.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory. The previous
            ``EEXIST`` check silently ignored this, deferring the failure to the
            first write inside the "directory".
        OSError: for other failures such as missing permissions.
    """
    os.makedirs(path, exist_ok=True)


def setup_for_distributed(is_master):
    """Replace the built-in ``print`` so only the master process prints.

    On non-master processes output is suppressed unless the call passes
    ``force=True``; the ``force`` keyword is always stripped before forwarding.
    """
    import builtins
    original_print = builtins.print

    def print(*args, **kwargs):
        if kwargs.pop('force', False) or is_master:
            original_print(*args, **kwargs)

    builtins.print = print


def is_dist_avail_and_initialized():
    """True only if torch.distributed is compiled in and a process group is initialized."""
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Number of processes in the default group; 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1


def get_rank():
    """Rank of this process in the default group; 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0


def is_main_process():
    """True on rank 0 (and always True when not running distributed)."""
    return not get_rank()


def save_on_master(*args, **kwargs):
    """Call ``torch.save(*args, **kwargs)`` on the main process only; no-op elsewhere."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """Configure ``torch.distributed`` from the environment and record the result on ``args``.

    The rank, world size and GPU come from one of:

    - ``RANK``/``WORLD_SIZE``/``LOCAL_RANK`` environment variables;
    - ``SLURM_PROCID``, with ``args.world_size`` taken from ``args``;
    - a pre-set ``args.rank``, with ``args.world_size``/``args.gpu`` taken from ``args``.

    If none of these is present, this prints a notice, sets ``args.distributed = False``
    and returns. Otherwise it sets ``args.distributed = True``, selects ``args.gpu``,
    initializes an NCCL process group from ``args.dist_url`` and silences ``print`` on
    non-zero ranks.
    """
    env = os.environ
    if 'RANK' in env and 'WORLD_SIZE' in env:
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
    elif 'SLURM_PROCID' in env:
        args.rank = int(env['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    elif not hasattr(args, "rank"):
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    setup_for_distributed(args.rank == 0)




In [None]:
def parse_args():
    """Return the hard-coded training configuration for this notebook.

    A ``types.SimpleNamespace`` is used instead of an ad-hoc class so that
    ``print(args)`` shows every field. The old class instance printed only
    ``<...Args object at 0x...>``. Attribute reads and writes (e.g.
    ``args.distributed = False``) work exactly as before.
    """
    import types

    return types.SimpleNamespace(
        model="SEWResNet",
        train_data_path="/content/drive/MyDrive/npz_events_best",
        test_data_path="/content/drive/MyDrive/npz_events_test_best",
        device="cuda",
        batch_size=16,
        epochs=25,
        workers=4,
        lr=5e-4,
        momentum=0.9,
        weight_decay=1e-4,
        lr_step_size=64,
        lr_gamma=0.1,
        print_freq=64,
        output_dir="/content/drive/MyDrive/logs",
        resume="",
        start_epoch=0,
        sync_bn=False,
        test_only=False,
        amp=True,
        world_size=1,
        dist_url="env://",
        tb=True,
        adam=True,
        connect_f="ADD",
        T_train=12,
    )


In [None]:
# Global seeding for reproducibility. ``_seed_`` is the single source of truth:
# the Python, torch (CPU and all CUDA devices) and NumPy RNGs are all seeded from it
# (previously ``random.seed`` used a separate literal).
_seed_ = 2020
import random
random.seed(_seed_)

torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
torch.cuda.manual_seed_all(_seed_)
# Deterministic cuDNN kernels and no autotuning: slower, but repeatable across runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import numpy as np
np.random.seed(_seed_)

def count_trainable_params(model):
    """Return the total element count of all parameters with ``requires_grad=True``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

class SpikingjellyDataset:
    """Event dataset stored as one ``.npz`` file per sample, one sub-directory per class.

    Expected layout::

        dataset_path/
            <class_a>/*.npz
            <class_b>/*.npz

    Labels ``0..K-1`` are assigned to the sub-directories in sorted order. Top-level
    entries that are not directories are ignored and do not consume a label;
    previously they shifted every later label. Each ``.npz`` must contain arrays
    ``x``, ``y``, ``t``, ``p`` (one value per event) and a single-element ``f``.
    """

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        # Filter to directories *before* enumerating so labels stay contiguous.
        self.classes = sorted(
            entry for entry in os.listdir(dataset_path)
            if os.path.isdir(os.path.join(dataset_path, entry))
        )
        self.files = []  # (file_path, label) tuples
        for class_label, class_folder in enumerate(self.classes):
            class_path = os.path.join(dataset_path, class_folder)
            for file in sorted(os.listdir(class_path)):
                if file.endswith(".npz"):
                    self.files.append((os.path.join(class_path, file), class_label))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(events, label, f)``.

        ``events`` is a float32 tensor of shape ``[num_events, 4]`` with columns
        ``(x, y, t, p)``. ``f`` is the Python scalar stored under key ``f``.

        Raises:
            ValueError: if any of the keys ``x, y, t, p, f`` is missing.
        """
        file_path, label = self.files[idx]
        # NOTE(review): allow_pickle=True lets a crafted .npz run arbitrary code when
        # loaded; only use trusted data. Kept in case ``f`` is stored as an object array.
        # The context manager closes the underlying file handle.
        with np.load(file_path, allow_pickle=True) as data:
            required_keys = {"x", "y", "t", "p", "f"}
            missing = required_keys - set(data.files)
            if missing:
                raise ValueError(f"Missing keys in {file_path}: {sorted(missing)}")

            x = data["x"].astype(np.float32)
            y = data["y"].astype(np.float32)
            t = data["t"].astype(np.float32)
            p = data["p"].astype(np.float32)
            folder_name = data["f"].item()

        events = np.stack([x, y, t, p], axis=1)  # Shape: (num_events, 4)
        return torch.from_numpy(events), label, folder_name

# Loader Class for Batch Processing
class Loader:
    """Wrap a ``DataLoader`` that batches with ``collate_events`` and moves batch tensors to ``device``.

    Args:
        dataset: map-style dataset yielding ``(events, label, f)`` samples.
        args: object providing ``workers`` (number of DataLoader worker processes).
        device: target device for every tensor in a batch.
        distributed: if truthy, sample with ``DistributedSampler``.
        batch_size: samples per batch.
        drop_last: drop the final incomplete batch. Defaults to ``to_train``: training
            drops it, while evaluation keeps every sample. Previously the default was
            ``True`` even for evaluation, silently skipping test samples.
        to_train: random sampling when True, sequential when False (non-distributed only).
    """

    def __init__(self, dataset, args, device, distributed, batch_size, drop_last=None, to_train=True):
        self.device = device
        if drop_last is None:
            drop_last = to_train

        if distributed:
            self.sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        elif to_train:
            self.sampler = torch.utils.data.RandomSampler(dataset)
        else:
            self.sampler = torch.utils.data.SequentialSampler(dataset)

        self.loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=self.sampler,
                                                  num_workers=args.workers, pin_memory=True,
                                                  collate_fn=collate_events, drop_last=drop_last)

    def __iter__(self):
        for data in self.loader:
            # Only tensors can be moved. Other collated fields (e.g. a list of
            # non-numeric ``f`` values from default_collate) are passed through unchanged.
            yield [d.to(self.device) if isinstance(d, torch.Tensor) else d for d in data]

    def __len__(self):
        return len(self.loader)

# Collate function to handle batching of events
def collate_events(data):
    """Collate ``(events, label, f)`` samples into one flat event tensor.

    Each sample's ``events`` (assumed ``[num_events, 4]`` float32, as produced by
    ``SpikingjellyDataset``) gets a fifth float32 column holding the sample's position
    in the batch. The results are concatenated along dim 0. Labels and ``f`` values
    are collated with ``default_collate``.

    Returns:
        ``(events [total_events, 5], labels, folder_names)``
    """
    labels = [sample[1] for sample in data]
    folder_names = [sample[2] for sample in data]
    stamped = []
    for batch_idx, sample in enumerate(data):
        sample_events = sample[0]
        batch_column = batch_idx * torch.ones((len(sample_events), 1), dtype=torch.float32)
        stamped.append(torch.cat([sample_events, batch_column], 1))
    return torch.cat(stamped, 0), default_collate(labels), default_collate(folder_names)

class QuantizationLayerVoxGrid(nn.Module):
    """Convert a flat event list into per-sample, per-polarity voxel grids of event counts.

    Time is split into ``C`` bins containing an equal *number* of events, not equal
    durations. Each sample's events are sorted by timestamp, and consecutive runs of
    ``n_events // C`` events go to bins ``0..C-1``. Any remainder (the latest events)
    is dropped.

    Args:
        dim: ``(C, H, W)``, the number of time bins and the spatial grid size.
        mode: ``"BT"`` gives output ``[B, C, 2, H, W]``; ``"TB"`` gives ``[C, B, 2, H, W]``.
    """

    def __init__(self, dim, mode):
        nn.Module.__init__(self)
        self.dim = dim
        self.mode = mode

    def forward(self, events):
        """Quantize ``events`` (``[num_events, 5]``, columns ``x, y, t, p, b``) into a voxel grid.

        ``b`` is the sample index within the batch. The batch size is read from the
        last row, so rows are assumed to be grouped by ascending ``b`` (as
        ``collate_events`` produces). ``x``, ``y`` and ``p`` are assumed integer-valued
        with ``0 <= x < W``, ``0 <= y < H`` and ``p`` in {0, 1}. TODO: confirm for the data.
        The input tensor is not modified, and samples with no events yield all-zero grids.

        Raises:
            ValueError: if ``mode`` is neither ``"TB"`` nor ``"BT"``.
        """
        B = int(1 + events[-1, -1].item())
        C, H, W = self.dim
        num_voxels = int(2 * np.prod(self.dim) * B)
        vox = events[0].new_full([num_voxels, ], fill_value=0)

        x, y, t, p, b = events.T
        x = x.to(torch.int64)
        y = y.to(torch.int64)
        p = p.to(torch.int64)
        b = b.to(torch.int64)
        # ``events.T`` yields views. Clone ``t`` so the normalization below does not
        # write back into the caller's ``events`` tensor.
        t = t.clone()

        # One boolean row mask per sample, computed once and reused below.
        sample_masks = [events[:, -1] == bi for bi in range(B)]

        unit_len = []  # events per time bin, per sample
        t_idx = []     # per-sample row order sorted by timestamp
        for mask in sample_masks:
            n_events = int(mask.sum())
            if n_events > 0:
                t_max = t[mask].max()
                # Skip when max is 0 to avoid 0/0 = NaN. Only the resulting sort
                # order is used for binning.
                if t_max != 0:
                    t[mask] /= t_max
            unit_len.append(n_events // C)
            _, order = torch.sort(t[mask])
            t_idx.append(order)

        # Flat index of each event in a [B, 2, C, H, W] grid for bin 0; bin i adds W*H*i.
        idx_before_bins = x + W * y + W * H * C * p + W * H * C * 2 * b

        for i_bin in range(C):
            values = torch.zeros_like(t)
            for bi, mask in enumerate(sample_masks):
                start = i_bin * unit_len[bi]
                bin_idx = t_idx[bi][start: start + unit_len[bi]]
                bin_values = values[mask]
                bin_values[bin_idx] = 1
                values[mask] = bin_values
            # Events outside this bin carry value 0, so accumulating them is a no-op.
            vox.put_((idx_before_bins + W * H * i_bin).long(), values, accumulate=True)

        vox = vox.view(-1, 2, C, H, W)
        if self.mode == "TB":
            return vox.permute(2, 0, 1, 3, 4).contiguous()
        if self.mode == "BT":
            return vox.permute(0, 2, 1, 3, 4).contiguous()
        raise ValueError(f"Unsupported mode {self.mode!r}; expected 'TB' or 'BT'")

def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, quantizer, print_freq, scaler=None, T_train=None):
    """Train ``model`` for one pass over ``data_loader``.

    Each batch ``(events, target, f)`` is turned into frames by ``quantizer``,
    optionally sub-sampled in time, and used for one optimizer step. The spiking
    network's state is reset after every batch.

    Args:
        model: network mapping ``[N, T, C, H, W]`` frames to logits.
        criterion: loss function ``criterion(output, target)``.
        optimizer: optimizer over ``model``'s parameters.
        data_loader: iterable of ``(events, target, f)`` batches. It must support
            ``len()``, which is used for progress/ETA logging.
        device: device (or device string) for frames and targets. Its type also
            selects the autocast backend.
        epoch: epoch index, used in the log header.
        quantizer: callable mapping a batch's event tensor to frames.
        print_freq: print a progress line every ``print_freq`` iterations.
        scaler: optional ``torch.amp.GradScaler``. When given, the forward pass and
            loss run under ``torch.amp.autocast`` and gradients are scaled.
        T_train: if truthy, keep a random, time-ordered subset of ``T_train`` of the
            frame steps (dim 1) for each batch.

    Returns:
        ``(loss, acc1, acc5)`` epoch averages. The loss is averaged per batch; the
        accuracies (percent) are weighted by batch size.

    Raises:
        ValueError: if the loss is NaN. This is checked before backward/step, so the
            NaN batch never updates the weights.
    """
    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', SmoothedValue(window_size=10, fmt='{value}'))
    header = 'Epoch: [{}]'.format(epoch)
    # Match autocast to the actual device instead of hard-coding "cuda".
    autocast_device_type = torch.device(device).type

    for event, target, _foldername in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image = quantizer(event)
        image, target = image.to(device), target.to(device)
        image = image.float()  # [N, T, C, H, W]

        if T_train:
            sec_list = np.random.choice(image.shape[1], T_train, replace=False)
            sec_list.sort()  # keep the sampled steps in temporal order
            image = image[:, sec_list]

        if scaler is not None:
            with torch.amp.autocast(device_type=autocast_device_type):
                output = model(image)
                loss = criterion(output, target)
        else:
            output = model(image)
            loss = criterion(output, target)

        loss_s = loss.item()
        if math.isnan(loss_s):
            functional.reset_net(model)
            raise ValueError('loss is Nan')

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Reset the spiking neurons' state so the next batch starts fresh.
        functional.reset_net(model)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss_s, lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    return metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg



def evaluate(model, criterion, data_loader, device, quantizer, print_freq=100, header='Test:'):
    """Evaluate ``model`` on ``data_loader`` without gradients.

    Progress is printed through ``MetricLogger.log_every`` every ``print_freq``
    iterations under ``header``. Previously both parameters were accepted but
    ignored. The spiking network's state is reset after every batch.

    Args:
        model: network mapping ``[N, T, C, H, W]`` frames to logits.
        criterion: loss function ``criterion(output, target)``.
        data_loader: iterable of ``(events, target, f)`` batches supporting ``len()``.
        device: device for frames and targets.
        quantizer: callable mapping a batch's event tensor to frames.
        print_freq: print a progress line every ``print_freq`` iterations.
        header: prefix of the progress lines.

    Returns:
        ``(loss, acc1, acc5)``. The loss is averaged per batch; the accuracies
        (percent) are weighted by batch size.
    """
    model.eval()
    metric_logger = MetricLogger(delimiter="  ")
    with torch.no_grad():
        for event, target, _foldername in metric_logger.log_every(data_loader, print_freq, header):
            image = quantizer(event)
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            image = image.float()
            output = model(image)
            loss = criterion(output, target)
            functional.reset_net(model)

            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    loss, acc1, acc5 = metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg
    return loss, acc1, acc5


def main(args):
    """Train ``args.model`` on the event dataset and evaluate after every epoch.

    The configuration comes from ``args`` (see ``parse_args``): data paths, batch size,
    optimizer (Adam or SGD), step LR schedule, AMP, time-step subsampling
    (``T_train``) and the SEW connection function. An output directory named after the
    main hyper-parameters is created under ``args.output_dir``; nothing is written to
    it yet.

    TODO: checkpointing, resuming (``args.resume``), TensorBoard logging (``args.tb``)
    and ``args.test_only`` are not implemented.

    Returns:
        The best test top-1 accuracy (percent) seen over all epochs.

    Raises:
        ValueError: if ``args.model`` is not a known model name.
    """
    max_test_acc1 = 0.
    test_acc5_at_max_test_acc1 = 0.

    init_distributed_mode(args)
    print(args)

    # Directory name encodes the main hyper-parameters of the run.
    output_dir = os.path.join(args.output_dir, f'{args.model}_b{args.batch_size}')
    if args.T_train:
        output_dir += f'_Ttrain{args.T_train}'
    if args.weight_decay:
        output_dir += f'_wd{args.weight_decay}'
    output_dir += f'_steplr{args.lr_step_size}_{args.lr_gamma}'
    output_dir += '_adam' if args.adam else '_sgd'
    if args.connect_f:
        output_dir += f'_cnf_{args.connect_f}'
    output_dir = os.path.join(output_dir, f'lr{args.lr}')
    mkdir(output_dir)  # also creates the parent directory

    device = torch.device(args.device)

    dataset_train = SpikingjellyDataset(args.train_data_path)
    dataset_test = SpikingjellyDataset(args.test_data_path)
    print(f"dataset_train {len(dataset_train)} , dataset_test {len(dataset_test)}")

    # Use the configured values; these were previously hard-coded to False / 16.
    distributed = args.distributed
    batch_size = args.batch_size
    data_loader = Loader(dataset=dataset_train, args=args, device=device, distributed=distributed,
                         batch_size=batch_size, drop_last=True)
    # Keep the final partial batch so every test sample is evaluated.
    data_loader_test = Loader(dataset=dataset_test, args=args, device=device, distributed=distributed,
                              batch_size=batch_size, drop_last=False, to_train=False)

    quantizer = QuantizationLayerVoxGrid(dim=(16, 128, 128), mode="BT")

    print("Creating model")
    model_factories = {'SEWResNet': SEWResNet, 'PlainNet': PlainNet, 'SpikingResNet': SpikingResNet}
    if args.model not in model_factories:
        raise ValueError(f"Unknown model {args.model!r}; expected one of {sorted(model_factories)}")
    model = model_factories[args.model](args.connect_f)
    model.to(device)

    num_params = count_trainable_params(model)
    print(f"Total Trainable Parameters: {num_params:,}")
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    if args.adam:
        optimizer = torch.optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    scaler = torch.amp.GradScaler(device.type) if args.amp else None

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Re-seed the DistributedSampler so each epoch is shuffled differently.
            data_loader.sampler.set_epoch(epoch)
        train_loss, train_acc1, train_acc5 = train_one_epoch(
            model, criterion, optimizer, data_loader, device, epoch, quantizer,
            args.print_freq, scaler, args.T_train)
        lr_scheduler.step()

        test_loss, test_acc1, test_acc5 = evaluate(
            model, criterion, data_loader_test, quantizer=quantizer, device=device,
            print_freq=args.print_freq, header='Test:')
        print(f"Epoch {epoch} ,Train Loss {train_loss} , Train_acc1 = {train_acc1} , Train acc5 = {train_acc5} , test_loss = {test_loss} , test_acc_1 ={test_acc1}")

        if test_acc1 > max_test_acc1:
            max_test_acc1 = test_acc1
            test_acc5_at_max_test_acc1 = test_acc5

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str), 'max_test_acc1', max_test_acc1,
          'test_acc5_at_max_test_acc1', test_acc5_at_max_test_acc1)
    print(output_dir)

    return max_test_acc1



# Entry point. In a notebook kernel __name__ is "__main__", so this runs when the cell
# executes. ``args`` is left defined at module (notebook) level after the run.
if __name__ == "__main__":
    args = parse_args()
    main(args)



Not using distributed mode
<__main__.parse_args.<locals>.Args object at 0x796e091edc10>
dataset_train 8600 , dataset_test 1252
Creating model
Total Trainable Parameters: 130,228
Start training
Epoch 0 ,Train Loss 1.1465527811529916 , Train_acc1 = 55.24906890130354 , Train acc5 = 100.0 , test_loss = 1.115483219233843 , test_acc_1 =55.12820512820513
Epoch 1 ,Train Loss 0.8607932464592506 , Train_acc1 = 68.45903165735568 , Train acc5 = 100.0 , test_loss = 1.177893294690129 , test_acc_1 =54.006410256410255
Epoch 2 ,Train Loss 0.7235828130826826 , Train_acc1 = 74.39478584729981 , Train acc5 = 100.0 , test_loss = 0.6894715288892771 , test_acc_1 =75.64102564102564
Epoch 3 ,Train Loss 0.6296454407760329 , Train_acc1 = 78.0842644320298 , Train acc5 = 100.0 , test_loss = 0.6348970941721629 , test_acc_1 =77.00320512820512
