In [1]:
# Authors: z.shen1@tue.nl, f.corradi@tue.nl
# Training SpikeVision for DVS 128 dataset
import os
from datetime import datetime
import numpy as np
import gc
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from timm.scheduler.step_lr import StepLRScheduler
from torch.utils.data import DataLoader, ConcatDataset, random_split
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Output locations: `folder_src` is the root for all results and `folder`
# is the sub-directory for the dataset used in this notebook.
folder_src = "./results/"
dataset = "DVS128"
folder = folder_src + f"{dataset}/"

# makedirs creates the parent directory as well and, with exist_ok=True,
# tolerates re-running the cell. Unlike the previous bare `except: pass`,
# genuine failures (e.g. missing permissions) are still reported.
os.makedirs(folder, exist_ok=True)

from SV import SpikeVision

from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from torch.amp import GradScaler, autocast
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scaler = GradScaler()
time = datetime.now().strftime("%Y%m%d-%H%M%S")




In [3]:
# Hyper-parameters

# --- data ---
batch_size = 16          # samples per DataLoader batch
time_steps = 8           # frames kept per sample by the collate function
in_channels = 2          # passed to the model as input_channels
image_size = 128         # used for both image_size_h and image_size_w
dataset_classes = 11     # passed to the model as num_classes

# --- model ---
layers = 2               # passed to the model as depths
embed_dim = 256          # passed to the model as embed_dims
pooling_state = "1111"
train_threshold = False
# One value each for threshold_head, threshold_conv and threshold_scre.
threshold = [128 / 128] * 3

# --- optimisation ---
num_epochs = 200
criteria = nn.CrossEntropyLoss()
# The same criterion object is used for training and for evaluation.
train_loss_fn = test_loss_fn = criteria

# --- weight quantisation / checkpoint loading ---
# train() starts quantising at epoch `precision_epochs`; with the value equal
# to num_epochs that epoch is never reached.
precision_epochs = num_epochs
precision_bits = 8
load = False             # when True, train() first loads weights from `path`

In [4]:
# Accumulating
event_per_map = 15000


def integrate_fixed_events(events, H, W, events_per_map = event_per_map):
    """Accumulate an event stream into frames that each hold a fixed event count.

    Consecutive groups of `events_per_map` events are summed into one
    two-channel count map; the last frame holds whatever events remain.

    Args:
        events: mapping with the keys 't', 'x', 'y' and 'p', each an
            array-like with one entry per event. 'x' and 'y' are used as
            integer pixel indices.
        H: frame height (size of the y axis).
        W: frame width (size of the x axis).
        events_per_map: number of events accumulated into each frame.

    Returns:
        float32 array of shape [num_maps, 2, H, W] with
        num_maps = ceil(len(events['t']) / events_per_map). Channel 1 counts
        events whose polarity equals 1, channel 0 counts all other events.

    Raises:
        ValueError: if `events_per_map` is smaller than 1.
    """
    if events_per_map < 1:
        raise ValueError(f"events_per_map must be >= 1, got {events_per_map}")

    t, x, y, p = (np.asarray(events[key]) for key in ('t', 'x', 'y', 'p'))
    total_events = len(t)
    num_maps = int(np.ceil(total_events / events_per_map))
    frames = np.zeros([num_maps, 2, H, W], dtype=np.float32)
    if total_events == 0:
        return frames

    # Event j belongs to frame j // events_per_map.
    frame_index = np.arange(total_events) // events_per_map
    channel = (p[:total_events] == 1).astype(np.intp)
    rows = y[:total_events].astype(np.intp)
    cols = x[:total_events].astype(np.intp)

    # np.add.at is an unbuffered scatter-add, so several events landing on the
    # same pixel of the same frame are all counted (plain fancy-index `+=`
    # would count them once).
    np.add.at(frames, (frame_index, channel, rows, cols), 1)
    return frames

In [5]:
# Build the official train and test splits with identical settings: same
# root folder, frame representation and event-integration function. Only the
# `train` flag differs (train split first, then test split).
dataset_train, dataset_test = (
    DVS128Gesture(
        "../data/DVS128/",
        train=is_train_split,
        data_type="frame",
        custom_integrate_function=integrate_fixed_events,
    )
    for is_train_split in (True, False)
)

The directory [../data/DVS128/integrate_fixed_events] already exists.
The directory [../data/DVS128/integrate_fixed_events] already exists.


In [6]:
# Split dataset
# The official train and test splits are merged and re-split 80/20 at random.
# NOTE(review): a random re-split may place recordings of the same subject on
# both sides — confirm this is intended.
combined_dataset = ConcatDataset([dataset_train, dataset_test])
train_size = int(0.8 * len(combined_dataset))
test_size = len(combined_dataset) - train_size

# Fixed seed so that the split, and therefore every reported accuracy, can be
# reproduced after a kernel restart. 25 is the value printed by the recorded
# run; previously the seed was drawn from unseeded NumPy state.
seed = 25
print(seed)
generator = torch.Generator()
generator.manual_seed(seed)
dataset_train, dataset_test = random_split(combined_dataset, [train_size, test_size], generator=generator)

25


In [7]:
# Create dataloaders
def custom_collate_fn(batch):
    """Collate samples after forcing each one to exactly `time_steps` frames.

    Samples with fewer frames are zero-padded at the end of their first
    (time) dimension; samples with more frames are truncated. The padding
    specification is derived from the sample's rank, so it works for any
    number of trailing dimensions rather than only for 4-D samples.

    Args:
        batch: iterable of (data, label) pairs. `data` is a NumPy array or a
            torch tensor whose first dimension is time; NumPy arrays are
            converted to float tensors.

    Returns:
        The result of torch's default_collate on the (data, label) pairs.
    """
    max_timesteps = time_steps
    padded_batch = []

    for data, label in batch:
        if isinstance(data, np.ndarray):
            data = torch.from_numpy(data).float()

        missing = max_timesteps - data.size(0)
        if missing > 0:
            # F.pad takes (before, after) pairs starting from the LAST
            # dimension, so leave every trailing dimension untouched and
            # append `missing` zero frames to dimension 0.
            pad_spec = (0, 0) * (data.dim() - 1) + (0, missing)
            data = torch.nn.functional.pad(data, pad=pad_spec, mode='constant', value=0)
        else:
            data = data[:max_timesteps]

        padded_batch.append((data, label))

    return torch.utils.data.dataloader.default_collate(padded_batch)
# Both loaders use the same batch size, collate function and worker settings;
# only the training loader reshuffles its samples.
loader_train, loader_test = (
    DataLoader(
        split,
        batch_size=batch_size,
        shuffle=reshuffle,
        collate_fn=custom_collate_fn,
        num_workers=8,
        pin_memory=False,
    )
    for split, reshuffle in ((dataset_train, True), (dataset_test, False))
)

In [8]:
# Define model
# Build SpikeVision from the hyper-parameter cell and move it to `device`.
model = SpikeVision(
    # input description
    dataset=dataset,
    input_channels=in_channels,
    image_size_h=image_size,
    image_size_w=image_size,
    num_classes=dataset_classes,
    # architecture
    embed_dims=embed_dim,
    depths=layers,
    pooling_state=pooling_state,
    # firing thresholds, in the order stored in `threshold`
    threshold_head=threshold[0],
    threshold_conv=threshold[1],
    threshold_scre=threshold[2],
    train_threshold=train_threshold,
).to(device)

In [9]:
# Only parameters that require gradients are handed to the optimizer.
base_params = list(filter(lambda param: param.requires_grad, model.parameters()))

In [10]:
# Per-epoch loss histories filled in by train().
# NOTE(review): train_eval_loss_list is never appended to, because test()
# does not return the train-eval loss.
train_loss_list = []
train_eval_loss_list = []
test_loss_list = []
def train(model, train_loader, test_loader, optimizer, scheduler=None, num_epochs=50, precision_epochs=num_epochs, low_precision=8, load=False, path=None):
    """Train `model` with mixed precision and checkpoint the best test accuracy.

    Each epoch runs one pass over `train_loader`, evaluates via `test()`,
    saves a checkpoint whenever the test accuracy matches or beats the best so
    far, appends the epoch losses to `train_loss_list` / `test_loss_list`
    and advances the learning-rate scheduler.

    Args:
        model: network to train; it is expected to already be on `device`.
        train_loader: loader yielding (images, labels) training batches.
        test_loader: loader yielding (images, labels) evaluation batches.
        optimizer: optimizer updating the model parameters.
        scheduler: optional epoch-based scheduler exposing `step(epoch)`
            (a timm scheduler in this notebook); skipped when None.
        num_epochs: number of epochs to run.
        precision_epochs: first epoch after which the weights are quantised
            at the end of every epoch.
        low_precision: bit width used for that quantisation.
        load: when True, weights are loaded from `path` before training.
        path: checkpoint file; when None no checkpoint is written.
    """
    if load:
        model.load_state_dict(torch.load(path, map_location=device))

    acc_state = 0
    for epoch in tqdm(range(num_epochs)):
        train_acc = 0
        train_loss_sum = 0

        if epoch == 0:
            print(f"{len(train_loader)} batches in one epoch.")

        # test() leaves the model in eval mode at the end of every epoch.
        model.train()
        for images, labels in tqdm(train_loader, leave=False):
            images = images.to(device)
            labels = labels.to(device)

            # Clear gradients for every batch. Zeroing only once per epoch
            # made each optimizer step use the sum of all gradients seen so
            # far in that epoch.
            optimizer.zero_grad()

            # device.type is "cuda"/"cpu" even for indexed devices ("cuda:0").
            with autocast(device_type=device.type, dtype=torch.float16):
                outputs = model(images)
                loss = train_loss_fn(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss_sum += loss.item()
            train_acc += (outputs.argmax(dim=1) == labels).sum().item()

        if epoch >= precision_epochs:
            if epoch == precision_epochs:
                if path is not None:
                    # Keep the last full-precision weights under the original
                    # name, then continue under a "low_precision_" prefixed
                    # file in the same directory.
                    torch.save(model.state_dict(), path)
                    path = os.path.join(os.path.dirname(path), "low_precision_" + os.path.basename(path))
                for param_group in optimizer.param_groups:
                    param_group['lr'] = 1e-3
            # The `low_precision` parameter (a bit width) shadows the
            # module-level low_precision() function inside this body, so the
            # function has to be fetched from the module namespace.
            quantize_state_dict = globals()["low_precision"]
            model.load_state_dict(quantize_state_dict(model.state_dict(), precision=low_precision))
            model.to(device)

        test_acc, _train_eval_acc, test_loss = test(model, test_loader, train_loader)
        if test_acc >= acc_state:
            acc_state = test_acc
            if path is not None:
                torch.save(model.state_dict(), path)
                print("Checkpoint saved.")

        train_loss = train_loss_sum / len(train_loader)
        train_loss_list.append(train_loss)
        test_loss_list.append(test_loss)
        if scheduler is not None:
            # timm schedulers are stepped at the end of an epoch with the
            # index of the NEXT epoch; passing `epoch` delays the schedule
            # by one epoch.
            scheduler.step(epoch + 1)

        print(f"Highest test accuracy: {acc_state}")
        print(f"Epoch: {epoch:3d}, Train loss: {train_loss:.4f}, Train accuracy: {train_acc / len(train_loader.dataset):.4f}")

def evaluate(model, loader):
    """Compute accuracy and mean per-sample loss of `model` over `loader`.

    The model is switched to eval mode and gradients are disabled for the
    whole pass. Each batch loss is weighted by its batch size before
    averaging, so the returned loss is a true per-sample mean and is on the
    same scale as the training loss. This assumes `test_loss_fn` returns the
    mean over the batch, which holds for the default nn.CrossEntropyLoss()
    configured in this notebook.

    Args:
        model: network to evaluate.
        loader: iterable of (images, labels) batches.

    Returns:
        (accuracy, avg_loss) over all samples seen; (0.0, 0.0) when the
        loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    # No graph is needed for evaluation; this avoids storing activations.
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            if images.dim() == 3:
                images = images.unsqueeze(1)

            outputs = model(images)
            batch_count = labels.size(0)
            total_loss += test_loss_fn(outputs, labels).item() * batch_count
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_samples += batch_count

    if total_samples == 0:
        return 0.0, 0.0
    return total_correct / total_samples, total_loss / total_samples

def test(model, test_loader, train_loader):
    """Evaluate `model` on the test loader and then on the train loader.

    Prints loss and accuracy for both passes.

    Returns:
        (test_acc, train_eval_acc, test_loss); the train-eval loss is only
        printed, not returned.
    """
    model.eval()
    # The generator is consumed in order: test loader first, train loader second.
    (test_acc, test_loss), (train_eval_acc, train_eval_loss) = (
        evaluate(model, current_loader) for current_loader in (test_loader, train_loader)
    )

    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
    print(f'Train eval Loss: {train_eval_loss:.4f}, Train eval Acc: {train_eval_acc:.4f}')
    return test_acc, train_eval_acc, test_loss

def precision_transfer(x, precision_bits=8, threshold=1.0, max_value=0.5, min_value=-0.5):
    """Snap every element of `x` to the nearest level of a fixed grid.

    The grid has 2**precision_bits levels. It starts at `min_value`, ends one
    step (of the [min_value, max_value] linspace) below `max_value`, and its
    levels are rounded to `precision_bits - 2` decimals. Values outside the
    grid map to the closest end level; ties go to the lower grid index.

    The nearest-level search is done with NumPy broadcasting in chunks
    instead of one Python call per element, which gives the same result
    without a Python-level loop over every weight.

    Args:
        x: array-like of values to quantise.
        precision_bits: number of bits; the grid has 2**precision_bits levels.
        threshold: unused; kept for interface compatibility.
        max_value: upper bound used to build the grid.
        min_value: lowest grid level.

    Returns:
        Array with the shape of `x` holding the selected grid levels
        (same dtype as the grid, i.e. NumPy's default float).
    """
    levels = 2 ** precision_bits
    step = np.diff(np.linspace(min_value, max_value, num=levels)[0:2])
    top_level = max_value - step[0]
    grid = np.round(np.linspace(min_value, top_level, num=levels), precision_bits - 2)

    values = np.asarray(x)
    x_flat = values.flatten()
    quantized = np.empty(x_flat.shape, dtype=grid.dtype)

    # Bound the size of the temporary (chunk x levels) distance matrix.
    chunk = max(1, (1 << 22) // grid.size)
    for start in range(0, x_flat.size, chunk):
        block = x_flat[start:start + chunk]
        nearest = np.abs(grid[np.newaxis, :] - block[:, np.newaxis]).argmin(axis=1)
        quantized[start:start + chunk] = grid[nearest]

    return quantized.reshape(values.shape)

def low_precision(state_dict, precision=8, threshold=1.0, max_value=0.5, min_value=-0.5):
    """Quantise the floating-point tensors of a state dict in place.

    Every floating-point entry whose key does not contain "threshold" is
    passed through precision_transfer() and stored back as a torch tensor
    with the entry's original dtype and device, so the result can be fed
    directly to `load_state_dict`. Entries whose key contains "threshold",
    non-tensor entries and integer/bool tensors (e.g. step counters) are left
    untouched.

    Args:
        state_dict: mapping of names to tensors; modified in place.
        precision: bit width forwarded as `precision_bits`.
        threshold: forwarded to precision_transfer().
        max_value: forwarded to precision_transfer().
        min_value: forwarded to precision_transfer().

    Returns:
        The same `state_dict` object.
    """
    for key in list(state_dict.keys()):
        if "threshold" in key:
            continue
        value = state_dict[key]
        if not torch.is_tensor(value) or not torch.is_floating_point(value):
            continue

        quantized = precision_transfer(
            value.detach().cpu().numpy(),
            precision_bits=precision,
            threshold=threshold,
            max_value=max_value,
            min_value=min_value,
        )
        # precision_transfer returns a NumPy array; convert it back so that
        # dtype and device match the tensor it replaces.
        state_dict[key] = torch.as_tensor(quantized, dtype=value.dtype, device=value.device)

    return state_dict


In [11]:
def create_scheduler(args, optimizer):
    """Build a StepLRScheduler from the settings stored on `args`.

    Args:
        args: object providing `epochs`, `decay_epochs`, `decay_rate`,
            `warmup_lr` and `warmup_epochs`. Optional attributes:
            `lr_noise` (a fraction of the epochs, or a list/tuple of
            fractions), `lr_noise_pct` (default 0.67), `lr_noise_std`
            (default 1.0) and `seed` (default 42).
        optimizer: optimizer whose learning rate is scheduled.

    Returns:
        (lr_scheduler, num_epochs) where num_epochs is `args.epochs`.
    """
    num_epochs = args.epochs

    # `lr_noise` is given as fraction(s) of the run; convert to epoch indices.
    lr_noise = getattr(args, 'lr_noise', None)
    if lr_noise is None:
        noise_range = None
    elif isinstance(lr_noise, (list, tuple)):
        noise_range = [n * num_epochs for n in lr_noise]
        if len(noise_range) == 1:
            noise_range = noise_range[0]
    else:
        noise_range = lr_noise * num_epochs

    lr_scheduler = StepLRScheduler(
        optimizer,
        decay_t=args.decay_epochs,
        decay_rate=args.decay_rate,
        warmup_lr_init=args.warmup_lr,
        warmup_t=args.warmup_epochs,
        noise_range_t=noise_range,
        noise_pct=getattr(args, 'lr_noise_pct', 0.67),
        noise_std=getattr(args, 'lr_noise_std', 1.),
        noise_seed=getattr(args, 'seed', 42),
    )

    return lr_scheduler, num_epochs

In [12]:
# Define optimizer
learning_rate = 1e-3
cooldown = 10

# Checkpoint file of this run: dataset, model size and session timestamp
# are encoded in the name.
path = folder + f"model_{dataset}_embeddim_{embed_dim}_depth_{layers}_{time}.pth"


class Args:
    """Plain attribute container with the settings handed to create_scheduler()."""

    def __init__(self):
        # Read by create_scheduler(): run length, step decay, warm-up, LR noise.
        self.epochs = num_epochs
        self.decay_epochs = 20
        self.decay_rate = 0.9
        self.warmup_epochs = 20
        self.warmup_lr = 3e-4
        self.lr_noise = [0.6, 0.9]
        self.lr_noise_pct = 0.67
        self.lr_noise_std = 1.0
        self.seed = 42
        # Not read by create_scheduler() in this notebook.
        self.sched = 'step'
        self.min_lr = 1e-5
        self.cooldown_epochs = cooldown
        self.patience_epochs = 5


args = Args()

# A single parameter group holding every trainable parameter.
optimizer = torch.optim.Adam(base_params, lr=learning_rate, weight_decay=0)

# create_scheduler returns args.epochs as its second value, which rebinds
# num_epochs to the same number it already had.
scheduler, num_epochs = create_scheduler(args, optimizer)

In [13]:
# Launch training with the settings defined in the cells above.
# NOTE(review): the recorded run stopped with "Input type (torch.cuda.HalfTensor)
# and weight type (torch.FloatTensor)", i.e. a weight that is still on the CPU.
# Presumably some SpikeVision sub-module is not registered and is therefore not
# moved by model.to(device) — confirm in SV.py.
train(
    model,
    loader_train,
    loader_test,
    optimizer,
    scheduler,
    num_epochs=num_epochs,
    precision_epochs=precision_epochs,
    low_precision=precision_bits,
    load=load,
    path=path,
)

  0%|                                                                                                                                                                                                       | 0/200 [00:00<?, ?it/s]

74 batches in one epoch.



0it [00:00, ?it/s][A
  0%|                                                                                                                                                                                                       | 0/200 [00:00<?, ?it/s]


RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.FloatTensor) should be the same