In [None]:
import os
import torch
import models, train_classify
import datetime
import os
import time
import warnings
from tv_ref_classify import presets, transforms, utils
import torch
import torch.utils.data
import torchvision
from tv_ref_classify.sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
import random
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import sys
import argparse
from sklearn.model_selection import StratifiedKFold

try:
    from torchvision import prototype
except ImportError:
    prototype = None

import json

# Public model constructors exported by the local `models` package:
# lowercase, non-dunder attribute names whose values are callable, in sorted order.
model_names = sorted(
    attr_name
    for attr_name, attr_value in vars(models).items()
    if attr_name.islower() and not attr_name.startswith("__") and callable(attr_value)
)


def set_deterministic(_seed_: int = 2020, disable_uda=False):
    """Seed every RNG used in training and make cuDNN/cuBLAS behave deterministically.

    Args:
        _seed_: seed handed to Python's ``random``, NumPy's global RNG and torch
            (``torch.manual_seed`` plus ``torch.cuda.manual_seed_all`` for every GPU).
        disable_uda: when True, skip ``torch.use_deterministic_algorithms`` and leave
            ``CUBLAS_WORKSPACE_CONFIG`` untouched (only seeding and the cuDNN flags apply).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(_seed_)

    # Deterministic cuDNN kernels, and no benchmark-mode autotuning (which may pick different kernels per run).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if disable_uda:
        return

    # cuBLAS needs a fixed workspace config to be deterministic. ":4096:8" costs roughly 24MiB of
    # extra GPU memory; ":16:8" is the lower-memory alternative that may limit performance.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    # warn_only=True: ops lacking a deterministic implementation warn instead of raising.
    torch.use_deterministic_algorithms(True, warn_only= True)


def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy's and Python's global RNGs inside a worker.

    The worker's torch seed (``torch.initial_seed()``) is reduced into the 32-bit range that
    ``np.random.seed`` accepts. Both NumPy- and ``random``-based augmentations are then
    reproducible. ``worker_id`` is unused.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    np.random.seed(derived_seed)
    random.seed(derived_seed)


In [9]:
def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``type=bool`` is a trap in argparse: ``bool("False")`` is ``True``, so any non-empty string
    would switch the option on. This converter maps the usual spellings explicitly.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser(add_help=True):
    """Build the argparse parser holding every training hyper-parameter used by this notebook.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser`` (set False when embedding as a parent parser).

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    parser.add_argument("--datadir", default='/local/data/acxyle/Datasets', type=str)
    parser.add_argument("--dataset", default='C2k', type=str)

    parser.add_argument("--arch", default="resnet18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument("-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size")
    parser.add_argument("--epochs", default=300, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)")
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=1e-2, type=float, help="initial learning rate")

    parser.add_argument('--neuron', type=str, default='LIF')
    parser.add_argument('--surrogate', type=str, default='ATan')
    parser.add_argument('--T', type=int, default=4)
    parser.add_argument('--cupy', action="store_true", help="set the neurons to use cupy backend")

    # `_str_to_bool` so that `--data-fold-training False` really disables the option.
    parser.add_argument('--data-fold-training', type=_str_to_bool, default=True)
    parser.add_argument('--data-fold-number', type=int, default=5 )
    # 0-based index of the fold used for validation; must be a valid index into the
    # --data-fold-number folds (a single fold is run per invocation).
    parser.add_argument('--data-fold-index', type=int, default=0)

    parser.add_argument("--add_info", type=str, default='')

    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument("--wd", "--weight-decay", default=0., type=float, metavar="W", help="weight decay (default: 0.)", dest="weight_decay")
    parser.add_argument("--norm-weight-decay", default=None, type=float, help="weight decay for Normalization layers (default: None, same value as --wd)")

    parser.add_argument("--label-smoothing", default=0.1, type=float, help="label smoothing (default: 0.1)", dest="label_smoothing")

    parser.add_argument("--mixup-alpha", default=0.2, type=float, help="mixup alpha (default: 0.2)")
    parser.add_argument("--cutmix-alpha", default=0.2, type=float, help="cutmix alpha (default: 0.2)")
    parser.add_argument("--lr-scheduler", default="cosa", type=str, help="the lr scheduler (default: cosa)")
    parser.add_argument("--lr-warmup-epochs", default=5, type=int, help="the number of epochs to warmup (default: 5)")
    parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)")
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--resume", default=None, type=str, help="path of checkpoint. If set to 'latest', it will try to load the latest checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")

    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )

    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )
    parser.add_argument("--auto-augment", default='ta_wide', type=str, help="auto augment policy (default: ta_wide)")
    parser.add_argument("--random-erase", default=0.1, type=float, help="random erasing probability (default: 0.1)")

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
    parser.add_argument(
        "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
    )
    parser.add_argument(
        "--model-ema-steps",
        type=int,
        default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)",
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )
    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=232, type=int, help="the resize size used for validation (default: 232)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=176, type=int, help="the random crop size used for training (default: 176)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
    parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    parser.add_argument(
        "--ra-reps", default=4, type=int, help="number of repetitions for Repeated Augmentation (default: 4)"
    )

    # Prototype models only
    parser.add_argument(
        "--prototype",
        dest="prototype",
        help="Use prototype model builders instead those from main area",
        action="store_true",
    )
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    parser.add_argument("--seed", default=2020, type=int, help="the random seed")

    parser.add_argument("--print-logdir", action="store_true", help="print the dirs for tensorboard logs and pt files and exit")
    parser.add_argument("--disable-pinmemory", action="store_true", help="not use pin memory in dataloader, which can help reduce memory consumption")
    parser.add_argument("--disable-amp", action="store_true", help="not use automatic mixed precision training")
    parser.add_argument("--local_rank", type=int, help="args for DDP, which should not be set by user")
    parser.add_argument("--disable-uda", action="store_true", help="not set 'torch.use_deterministic_algorithms(True)', which can avoid the error raised by some functions that do not have a deterministic implementation")

    return parser

In [19]:
# Keep only the script name and drop every extra argument Jupyter passed to the kernel,
# so argparse parses an empty command line and falls back to the defaults.
sys.argv = sys.argv[:1]
args = get_args_parser().parse_args()

# Full dataset directory, derived from the two path options.
args.data_path = os.path.join(args.datadir, args.dataset)
# Single-process default; presumably revised by utils.init_distributed_mode later — TODO confirm.
args.distributed = False

In [14]:
# Echo the resolved configuration so the run's hyper-parameters are recorded in the cell output.
print(f"{args}")

Namespace(datadir='/local/data/acxyle/Datasets', dataset='C2k', arch='resnet18', device='cuda', batch_size=32, epochs=300, workers=16, opt='sgd', lr=0.01, neuron='LIF', surrogate='ATan', T=4, cupy=False, data_fold_training=True, data_fold_number=5, data_fold_index=0, add_info='', momentum=0.9, weight_decay=0.0, norm_weight_decay=None, label_smoothing=0.1, mixup_alpha=0.2, cutmix_alpha=0.2, lr_scheduler='cosa', lr_warmup_epochs=5, lr_warmup_method='linear', lr_warmup_decay=0.01, lr_step_size=30, lr_gamma=0.1, resume=None, start_epoch=0, sync_bn=False, pretrained=False, auto_augment='ta_wide', random_erase=0.1, world_size=1, dist_url='env://', model_ema=False, model_ema_steps=32, model_ema_decay=0.99998, interpolation='bilinear', val_resize_size=232, val_crop_size=224, train_crop_size=176, clip_grad_norm=None, ra_sampler=False, ra_reps=4, prototype=False, weights=None, seed=2020, print_logdir=False, disable_pinmemory=False, disable_amp=False, local_rank=None, disable_uda=False, data_path

In [None]:
def load_fused_data_webdataset(args, ):
    """Build the k-fold train/validation datasets and samplers; same behavior as ``load_fused_data_normal``.

    Despite the name, this loader contains no WebDataset-specific logic. It reads an
    ``ImageFolder`` exactly like ``load_fused_data_normal``. It therefore delegates to that
    function, so the two loaders cannot drift apart when one of them is fixed.
    TODO: implement a real WebDataset-backed pipeline here if one is needed.

    Returns:
        ``(dataset, dataset_test, train_sampler, test_sampler)`` as produced by ``load_fused_data_normal``.
    """
    return load_fused_data_normal(args)
        

In [None]:
def load_fused_data_normal(args, ):
    """Build the train/validation datasets and samplers for one stratified k-fold experiment.

    All images under ``args.data_path`` (an ImageFolder tree) are split with ``StratifiedKFold``:
    ``args.data_fold_number`` folds, shuffled with the fixed ``random_state=42``, so the folds do
    not change with ``args.seed``. Fold ``args.data_fold_index`` supplies the validation indices
    and the remaining folds supply the training indices. Nothing is cached on disk; the split is
    recomputed on every call.

    Train and validation share the same files but use different preprocessing. The training
    view uses ``presets.ClassificationPresetTrain``, configured with the train crop size,
    auto-augment policy and random-erase probability. The validation view uses
    ``presets.ClassificationPresetEval``, configured with the val resize/crop sizes.

    Returns:
        ``(dataset, dataset_test, train_sampler, test_sampler)``: the training ``Subset``, the
        validation ``Subset``, and their samplers. When ``args.distributed`` is set these are
        ``RASampler``/``DistributedSampler``; otherwise a ``RandomSampler`` seeded with
        ``args.seed`` and a ``SequentialSampler``.

    Raises:
        ValueError: if ``args.data_fold_index`` is not a valid index into the folds.
    """
    interpolation = InterpolationMode(args.interpolation)

    print("Loading training data")
    full_train = torchvision.datasets.ImageFolder(
        args.data_path,
        presets.ClassificationPresetTrain(
            crop_size=args.train_crop_size,
            interpolation=interpolation,
            auto_augment_policy=getattr(args, "auto_augment", None),
            random_erase_prob=getattr(args, "random_erase", 0.0),
        ),
    )

    # Stratify on the labels of the training scan itself. ImageFolder orders samples
    # deterministically, so the validation scan below yields the same index -> file mapping.
    # StratifiedKFold only needs the number of samples from X, so a zeros placeholder suffices.
    targets = full_train.targets
    skf = StratifiedKFold(n_splits=args.data_fold_number, shuffle=True, random_state=42)
    folds = [
        (train_idx.tolist(), val_idx.tolist())
        for train_idx, val_idx in skf.split(np.zeros(len(targets)), targets)
    ]

    fold_index = args.data_fold_index
    # Explicit check (an `assert` disappears under `python -O`); Python-style negative indices stay accepted.
    if not -len(folds) <= fold_index < len(folds):
        raise ValueError(
            f"data_fold_index={fold_index} is out of range for {len(folds)} folds"
        )
    train_data_indices, val_data_indices = folds[fold_index]

    dataset = torch.utils.data.Subset(full_train, train_data_indices)     # training data

    print("Loading validation data")
    full_val = torchvision.datasets.ImageFolder(
        args.data_path,
        presets.ClassificationPresetEval(
            crop_size=args.val_crop_size,
            resize_size=args.val_resize_size,
            interpolation=interpolation,
        ),
    )
    dataset_test = torch.utils.data.Subset(full_val, val_data_indices)     # val dataset

    print("Creating data loaders")
    loader_g = torch.Generator()
    loader_g.manual_seed(args.seed)

    if args.distributed:
        if getattr(args, "ra_sampler", False):
            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps, seed=args.seed)
        else:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, seed=args.seed)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset, generator=loader_g)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler
        

In [None]:
def load_data(args):
    """Entry point used by the training cells to build datasets and samplers for the current fold.

    Returns whatever ``load_fused_data_normal`` returns:
    ``(dataset, dataset_test, train_sampler, test_sampler)``.
    """
    loader_fn = load_fused_data_normal
    return loader_fn(args)

In [20]:
# `load_fused_data` is not defined in this notebook, so on a fresh kernel the previous call raised
# NameError; `load_data` is the loader entry point the rest of the notebook uses.
dataset, dataset_test, train_sampler, test_sampler = load_data(args)

Loading training data
Loading validation data
Creating data loaders


In [15]:
def cal_acc1_acc5( output, target):
    """Return ``(acc1, acc5)``: top-1 and top-5 accuracy of ``output`` against ``target``.

    Thin wrapper over ``utils.accuracy`` so the accuracy definition can be swapped in one place.
    Callers call ``.item()`` on both values, so they are presumably scalar tensors — TODO confirm.
    """
    top1, top5 = utils.accuracy(output, target, topk=(1, 5))
    return top1, top5

In [None]:
def preprocess_train_sample( args, x: torch.Tensor):
    """Hook applied to each training batch before the forward pass; identity by default (``args`` unused)."""
    processed = x
    return processed

def preprocess_test_sample( args, x: torch.Tensor):
    """Hook applied to each evaluation batch before the forward pass; identity by default (``args`` unused)."""
    processed = x
    return processed
    
def process_model_output( args, y: torch.Tensor):
    """Hook applied to raw model output before the loss/accuracy; identity by default (``args`` unused)."""
    result = y
    return result
    

def load_model( args, num_classes):
    """Instantiate ``args.arch`` from the local ``models`` package with ``num_classes`` outputs.

    Args:
        args: namespace; only ``args.arch`` is read here.
        num_classes: forwarded as the ``num_classes`` keyword of the model constructor.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``args.arch`` is not one of ``model_names``. Raising here avoids falling
            through to ``return model`` with ``model`` never assigned (an UnboundLocalError).
    """
    if args.arch not in model_names:
        raise ValueError(
            f"Unknown architecture {args.arch!r}; available: {', '.join(model_names)}"
        )
    # NOTE(review): args.pretrained / args.weights are parsed but not consulted here.
    return models.__dict__[args.arch](num_classes=num_classes)

In [23]:
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
    """Run one training epoch over ``data_loader`` and return the epoch's averaged metrics.

    Args:
        model: network being trained (switched to train mode here).
        criterion: loss, called as ``criterion(output, target)``.
        optimizer: stepped once per batch (directly, or through ``scaler``).
        data_loader: iterable of ``(image, target)`` batches.
        device: where each batch is moved before the forward pass.
        epoch: current epoch index; used in the log header and to reset the EMA during warmup.
        args: namespace; reads ``clip_grad_norm``, ``model_ema_steps`` and ``lr_warmup_epochs``
            and is forwarded to the ``preprocess_train_sample`` / ``process_model_output`` hooks.
        model_ema: optional EMA wrapper, updated every ``args.model_ema_steps`` batches.
        scaler: optional AMP gradient scaler. When given, the forward pass runs under CUDA
            autocast and backward/step go through the scaler.

    Returns:
        ``(train_loss, train_acc1, train_acc5)``: global averages from the metric logger after
        ``synchronize_between_processes``.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = f"Epoch: [{epoch}]"
    # print_freq=-1 presumably suppresses per-iteration logging in utils.MetricLogger — TODO confirm.
    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, -1, header)):
        # Timer starts after the batch is fetched, so img/s excludes data-loading time.
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        # Autocast is active only when a GradScaler was supplied (i.e. AMP is enabled).
        with torch.amp.autocast('cuda', enabled=scaler is not None):
            image = preprocess_train_sample(args, image)
            output = process_model_output(args, model(image))
            loss = criterion(output, target)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            if args.clip_grad_norm is not None:
                # we should unscale the gradients of optimizer's assigned params if do gradient clipping
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

        if model_ema and i % args.model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < args.lr_warmup_epochs:
                # Reset ema buffer to keep copying weights during warmup period
                model_ema.n_averaged.fill_(0)

        acc1, acc5 = cal_acc1_acc5(output, target)
        batch_size = target.shape[0]
        # loss is averaged per batch; acc1/acc5 are weighted by batch size.
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    train_loss, train_acc1, train_acc5 = metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg
    print(f'Train: train_acc1={train_acc1:.3f}, train_acc5={train_acc5:.3f}, train_loss={train_loss:.6f}, samples/s={metric_logger.meters["img/s"]}')
    return train_loss, train_acc1, train_acc5

In [24]:
def evaluate( args, model, criterion, data_loader, device, log_suffix=""):
    """Evaluate ``model`` on ``data_loader`` under ``torch.inference_mode`` and return averaged metrics.

    Args:
        args: namespace forwarded to the ``preprocess_test_sample`` / ``process_model_output`` hooks.
        model: network (or EMA wrapper) to evaluate; switched to eval mode here.
        criterion: loss, called as ``criterion(output, target)``.
        data_loader: iterable of ``(image, target)`` batches.
        device: where each batch is moved.
        log_suffix: appended to the ``"Test: "`` log header (e.g. ``"EMA"``).

    Returns:
        ``(test_loss, test_acc1, test_acc5)``: global averages from the metric logger.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    start_time = time.time()
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, -1, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            image = preprocess_test_sample(args, image)
            output = process_model_output(args, model(image))
            loss = criterion(output, target)

            acc1, acc5 = cal_acc1_acc5(output, target)
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = target.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
    # gather the stats from all processes

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        # utils.is_main_process() also works in single-process runs; torch.distributed.get_rank()
        # raises when no default process group has been initialised.
        and utils.is_main_process()
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    test_loss, test_acc1, test_acc5 = metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg
    print(f'Test: test_acc1={test_acc1:.3f}, test_acc5={test_acc5:.3f}, test_loss={test_loss:.6f}, samples/s={num_processed_samples / (time.time() - start_time):.3f}')
    return test_loss, test_acc1, test_acc5

In [25]:
def set_optimizer( args, parameters):
    """Create the optimizer selected by ``args.opt`` (case-insensitive).

    Supported names:
        * ``sgd...``: SGD with ``args.momentum``. Nesterov is enabled when the name contains
          ``"nesterov"`` (e.g. ``sgd_nesterov``).
        * ``rmsprop``: RMSprop with ``eps=0.0316`` and ``alpha=0.9``.
        * ``adamw``: AdamW.
    All of them use ``args.lr`` and ``args.weight_decay``.

    Args:
        args: namespace with ``opt``, ``lr``, ``momentum`` and ``weight_decay``.
        parameters: parameters or parameter-group dicts, passed straight to the optimizer.

    Returns:
        The constructed ``torch.optim.Optimizer``.

    Raises:
        RuntimeError: for an unsupported optimizer name. Previously ``None`` was returned, which
            only failed later inside the scheduler/training code.
    """
    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        return torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    if opt_name == "rmsprop":
        return torch.optim.RMSprop(
            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
        )
    if opt_name == "adamw":
        return torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

def set_lr_scheduler( args, optimizer):
    """Create the (epoch-based) learning-rate scheduler described by ``args``.

    Main schedulers (``args.lr_scheduler``, lower-cased in place on ``args``):
        * ``step``: StepLR(step_size=args.lr_step_size, gamma=args.lr_gamma)
        * ``cosa``: CosineAnnealingLR(T_max=args.epochs - args.lr_warmup_epochs)
        * ``exp``: ExponentialLR(gamma=args.lr_gamma)

    If ``args.lr_warmup_epochs > 0``, a warmup scheduler runs first (``linear`` -> LinearLR,
    ``constant`` -> ConstantLR, both using ``args.lr_warmup_decay``; case-insensitive). It is
    chained to the main scheduler with SequentialLR at milestone ``args.lr_warmup_epochs``.

    Returns:
        The scheduler to ``.step()`` once per epoch.

    Raises:
        RuntimeError: for an unsupported scheduler or warmup method. Previously ``None`` was
            returned or passed into SequentialLR, which failed later with an obscure error.
    """
    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "step":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosa":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs
        )
    elif args.lr_scheduler == "exp":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only step, cosa and exp are supported."
        )

    if args.lr_warmup_epochs <= 0:
        return main_lr_scheduler

    warmup_method = args.lr_warmup_method.lower()
    if warmup_method == "linear":
        warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
        )
    elif warmup_method == "constant":
        warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
        )
    else:
        raise RuntimeError(
            f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
        )

    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
    )

# training loop

In [26]:
# Seed every RNG and enable deterministic kernels before anything else consumes randomness.
set_deterministic(args.seed, args.disable_uda)

# Project helper; presumably fills in args.distributed / args.gpu from the environment — TODO confirm.
utils.init_distributed_mode(args)
print(args)

device = torch.device(args.device)
# Built after init_distributed_mode because the loader picks its samplers based on args.distributed.
dataset, dataset_test, train_sampler, test_sampler = load_data(args)

Not using distributed mode
Namespace(datadir='/local/data/acxyle/Datasets', dataset='C2k', arch='resnet18', device='cuda', batch_size=32, epochs=300, workers=16, opt='sgd', lr=0.01, neuron='LIF', surrogate='ATan', T=4, cupy=False, data_fold_training=True, data_fold_number=5, data_fold_index=0, add_info='', momentum=0.9, weight_decay=0.0, norm_weight_decay=None, label_smoothing=0.1, mixup_alpha=0.2, cutmix_alpha=0.2, lr_scheduler='cosa', lr_warmup_epochs=5, lr_warmup_method='linear', lr_warmup_decay=0.01, lr_step_size=30, lr_gamma=0.1, resume=None, start_epoch=0, sync_bn=False, pretrained=False, auto_augment='ta_wide', random_erase=0.1, world_size=1, dist_url='env://', model_ema=False, model_ema_steps=32, model_ema_decay=0.99998, interpolation='bilinear', val_resize_size=232, val_crop_size=224, train_crop_size=176, clip_grad_norm=None, ra_sampler=False, ra_reps=4, prototype=False, weights=None, seed=2020, print_logdir=False, disable_pinmemory=False, disable_amp=False, local_rank=None, d

In [27]:
# Build the training / validation DataLoaders. When MixUp and/or CutMix is enabled, one of the
# enabled transforms is chosen at random per batch inside collate_fn, after default_collate.
collate_fn = None
# `dataset` is a torch.utils.data.Subset here (no `.classes`), so fall back to the wrapped ImageFolder.
num_classes = len(dataset.classes) if hasattr(dataset, 'classes') else len(dataset.dataset.classes)
mixup_transforms = []
if args.mixup_alpha > 0.0:
    # nn.CrossEntropyLoss accepts class-probability targets only from PyTorch 1.10 on.
    if torch.__version__ < torch.torch_version.TorchVersion('1.10.0'):
        # TODO implement a CrossEntropyLoss that supports class-probability targets.
        raise NotImplementedError("CrossEntropyLoss in pytorch < 1.10.0 does not support for probabilities for each class. "
                                  "Set mixup_alpha=0. to avoid such a problem or update your pytorch.")
    mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
if args.cutmix_alpha > 0.0:
    mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
if mixup_transforms:
    mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
    collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.workers,
    pin_memory=not args.disable_pinmemory,
    collate_fn=collate_fn,
    worker_init_fn=seed_worker
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=not args.disable_pinmemory,
    worker_init_fn=seed_worker
)


In [28]:
print("Creating model")
# nn.Module.to() moves parameters/buffers in place and returns the same module object.
model = load_model(args, num_classes).to(device)
print(model)

Creating model
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(i

In [29]:
# Convert BatchNorm to SyncBatchNorm before the model is wrapped in DDP below.
if args.distributed and args.sync_bn:
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

# Optional separate weight decay for normalization layers. The zip pairs the first group with
# norm_weight_decay, so split_normalization_params (a private torchvision helper) is assumed to
# return [norm_params, other_params]; empty groups are dropped.
if args.norm_weight_decay is None:
    parameters = model.parameters()
else:
    param_groups = torchvision.ops._utils.split_normalization_params(model)
    wd_groups = [args.norm_weight_decay, args.weight_decay]
    parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]

optimizer = set_optimizer(args, parameters)

# A GradScaler is created unless AMP is disabled; train_one_epoch enables autocast only when it gets one.
if args.disable_amp:
    scaler = None
else:
    scaler = torch.amp.GradScaler('cuda')

lr_scheduler = set_lr_scheduler(args, optimizer)


# Keep a handle to the unwrapped module for state_dict() / EMA, whether or not DDP wraps it.
model_without_ddp = model
if args.distributed:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    model_without_ddp = model.module

# EMA: rescale the per-update weight (1 - decay) by world_size * batch_size * model_ema_steps / epochs,
# capped at 1.0 (same adjustment as torchvision's classification reference script).
model_ema = None
if args.model_ema:
    adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
    alpha = 1.0 - args.model_ema_decay
    alpha = min(1.0, alpha * adjust)
    model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)


# Best-accuracy trackers exist only on the main process, which is the only one that reads them.
if utils.is_main_process():

    max_test_acc1 = -1.
    if model_ema:
        max_ema_test_acc1 = -1.

In [30]:
for epoch in range(args.start_epoch, args.epochs):
    start_time = time.time()
    if args.distributed:
        # Distributed samplers reshuffle per epoch only when told the epoch number.
        train_sampler.set_epoch(epoch)

    train_loss, train_acc1, train_acc5 = train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)

    # The schedulers from set_lr_scheduler are stepped once per epoch.
    lr_scheduler.step()
    test_loss, test_acc1, test_acc5 = evaluate(args, model, criterion, data_loader_test, device=device)

    if model_ema:
        ema_test_loss, ema_test_acc1, ema_test_acc5 = evaluate(args, model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")

    if utils.is_main_process():
        save_max_test_acc1 = False
        save_max_ema_test_acc1 = False

        if test_acc1 > max_test_acc1:
            max_test_acc1 = test_acc1
            save_max_test_acc1 = True

        checkpoint = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "epoch": epoch,
            "args": args,
            "max_test_acc1": max_test_acc1,
        }
        if model_ema:
            if ema_test_acc1 > max_ema_test_acc1:
                max_ema_test_acc1 = ema_test_acc1
                save_max_ema_test_acc1 = True
            checkpoint["model_ema"] = model_ema.state_dict()
            checkpoint["max_ema_test_acc1"] = max_ema_test_acc1
        if scaler:
            checkpoint["scaler"] = scaler.state_dict()
        # TODO(review): `checkpoint` and the save_max_* flags are never written to disk (no torch.save
        # and no output-directory argument), so nothing from this run is persisted.

    # Estimated finish time: this epoch's wall time times the epochs still to run
    # (epoch + 1 .. args.epochs - 1), i.e. args.epochs - epoch - 1 of them.
    epoch_seconds = time.time() - start_time
    remaining_epochs = args.epochs - epoch - 1
    eta = datetime.datetime.now() + datetime.timedelta(seconds=epoch_seconds * remaining_epochs)
    print(f'ETA={eta.strftime("%Y-%m-%d %H:%M:%S")}\n')
    print(args)



Epoch: [0] Total time: 0:00:22
Train: train_acc1=0.040, train_acc5=0.184, train_loss=7.969404, samples/s=1116.9598736958237
Test:  Total time: 0:00:03
Test: test_acc1=0.019, test_acc5=0.153, test_loss=7.942706, samples/s=4287.072
escape time=2025-07-05 23:02:57

Namespace(datadir='/local/data/acxyle/Datasets', dataset='C2k', arch='resnet18', device='cuda', batch_size=32, epochs=300, workers=16, opt='sgd', lr=0.01, neuron='LIF', surrogate='ATan', T=4, cupy=False, data_fold_training=True, data_fold_number=5, data_fold_index=0, add_info='', momentum=0.9, weight_decay=0.0, norm_weight_decay=None, label_smoothing=0.1, mixup_alpha=0.2, cutmix_alpha=0.2, lr_scheduler='cosa', lr_warmup_epochs=5, lr_warmup_method='linear', lr_warmup_decay=0.01, lr_step_size=30, lr_gamma=0.1, resume=None, start_epoch=0, sync_bn=False, pretrained=False, auto_augment='ta_wide', random_erase=0.1, world_size=1, dist_url='env://', model_ema=False, model_ema_steps=32, model_ema_decay=0.99998, interpolation='bilinear',



Test:  Total time: 0:00:03
Test: test_acc1=0.115, test_acc5=0.678, test_loss=7.610261, samples/s=4418.254
escape time=2025-07-05 22:59:17

Namespace(datadir='/local/data/acxyle/Datasets', dataset='C2k', arch='resnet18', device='cuda', batch_size=32, epochs=300, workers=16, opt='sgd', lr=0.01, neuron='LIF', surrogate='ATan', T=4, cupy=False, data_fold_training=True, data_fold_number=5, data_fold_index=0, add_info='', momentum=0.9, weight_decay=0.0, norm_weight_decay=None, label_smoothing=0.1, mixup_alpha=0.2, cutmix_alpha=0.2, lr_scheduler='cosa', lr_warmup_epochs=5, lr_warmup_method='linear', lr_warmup_decay=0.01, lr_step_size=30, lr_gamma=0.1, resume=None, start_epoch=0, sync_bn=False, pretrained=False, auto_augment='ta_wide', random_erase=0.1, world_size=1, dist_url='env://', model_ema=False, model_ema_steps=32, model_ema_decay=0.99998, interpolation='bilinear', val_resize_size=232, val_crop_size=224, train_crop_size=176, clip_grad_norm=None, ra_sampler=False, ra_reps=4, prototype=Fa

Exception in thread Thread-175:
Traceback (most recent call last):
  File "/local/data/acxyle/anaconda3/envs/sp/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/local/data/acxyle/anaconda3/envs/sp/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/local/data/acxyle/anaconda3/envs/sp/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/local/data/acxyle/anaconda3/envs/sp/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 59, in _pin_memory_loop
    do_one_step()
  File "/local/data/acxyle/anaconda3/envs/sp/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 35, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/local/data/acxyle/anaconda3/envs/sp/lib/python3.9/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/local/data/acxyle/an

KeyboardInterrupt: 

    s.connect(address)
FileNotFoundError: [Errno 2] No such file or directory
