In [87]:
import torch
from Models import modelpool
from Preprocess import datapool
from torch import nn
from spiking_layer_ours import SPIKE_layer
import numpy as np
from copy import deepcopy
from utils_my import add_dimension
from utils_my import replace_maxpool2d_by_avgpool2d, replace_layer_by_tdlayer

def isActivation(name):
    """Return True when `name` (e.g. a module class name) contains "relu", ignoring case."""
    return 'relu' in name.lower()

def replace_activation_by_spike(model, thresholds, thresholds1, n_steps, counter=0):
    """Recursively replace every ReLU-like submodule of `model` with a SPIKE_layer.

    Submodules are visited depth-first in registration order. `counter` indexes the
    activation being replaced and is threaded through the recursion, so the k-th
    activation encountered uses row k of the threshold tensors.

    Args:
        model: module whose `_modules` entries are reassigned in place.
        thresholds: 2-D tensor with 2*n_steps columns per activation row (the setup
            cell builds it as torch.zeros(num_relu, 2*n_steps)). It is deep-copied and
            not modified. Each replaced row is fully overwritten in the copy, so its
            incoming values are irrelevant.
        thresholds1: tensor read at [counter, 0] and [counter, 1] for each activation
            (per the original notes: an "inner" and an "out" threshold sum).
        n_steps: number of time steps; both threshold sums are divided by it.
        counter: index of the next activation row to use (0 for a top-level call).

    Returns:
        (model, counter): the same module object, and the input counter plus the
        number of activations replaced.
    """
    # Copied at every recursion level: the caller's tensors are never written.
    thresholds_new = deepcopy(thresholds)
    thresholds_new1 = deepcopy(thresholds1)
    
    for name, module in model._modules.items():
        # Recurse first. The child call works on its own deepcopy of thresholds_new, so
        # rows it fills are not visible in this level's copy. This level only writes rows
        # at the (larger) counter returned by the child, so it never needs them.
        if hasattr(module,"_modules"):
            model._modules[name], counter = replace_activation_by_spike(module, thresholds_new, thresholds_new1, n_steps, counter)
        # Class name is lowercased here (isActivation lowercases again).
        if isActivation(module.__class__.__name__.lower()):
            thresholds_new[counter, n_steps:] = thresholds_new1[counter,1] / n_steps  # second half: "out" sum / n_steps
            thresholds_new[counter, :n_steps] = thresholds_new1[counter,0] / n_steps  # first half: "inner" sum / n_steps
            # Note the argument order: second half (columns n_steps:) first, then the
            # first half (columns :n_steps); both are slices of this level's copy.
            model._modules[name] = SPIKE_layer(thresholds_new[counter, n_steps:], thresholds_new[counter, 0:n_steps])
            counter += 1
    return model, counter


def interpolate_state_dicts(state_dict_1, state_dict_2, weight,
                            bias_norm=False):
    """Interpolate two state dicts with coefficient `weight` (0 -> sd1, 1 -> sd2).

    Without `bias_norm`, every entry becomes (1 - weight) * sd1[key] + weight * sd2[key]
    and a new plain dict is returned.

    With `bias_norm`, a deep copy of state_dict_1 is filled in key order:
      * keys containing "batches" (such as BatchNorm's num_batches_tracked) keep
        state_dict_1's value;
      * keys containing "weight" are linearly interpolated and increment a running
        counter `height`;
      * keys containing "bias" use coefficients (1 - weight)**height and weight**height,
        where height is the number of "weight" keys seen so far;
      * every other key (e.g. res_scale, BN running_mean / running_var) is linearly
        interpolated.

    Args:
        state_dict_1, state_dict_2: mappings of tensors; state_dict_2 must contain
            every key of state_dict_1.
        weight: interpolation coefficient.
        bias_norm: enable the bias rescaling scheme described above.

    Returns:
        A new mapping of interpolated tensors; the inputs are not modified.
    """
    if not bias_norm:
        return {key: (1 - weight) * state_dict_1[key] +
                weight * state_dict_2[key] for key in state_dict_1.keys()}

    model_state = deepcopy(state_dict_1)
    height = 0
    for p_name in model_state:
        if "batches" in p_name:
            # Integer counters: interpolation is meaningless, keep state_dict_1's value.
            continue
        if "weight" in p_name:
            height += 1
            coef_1, coef_2 = 1.0 - weight, weight
        elif "bias" in p_name:
            coef_1, coef_2 = (1.0 - weight) ** height, weight ** height
        else:
            # Linear interpolation for everything else (res_scale, BN running stats, ...).
            # These entries were previously zeroed and never refilled, leaving e.g.
            # running_var == 0.
            coef_1, coef_2 = 1.0 - weight, weight
        merged = model_state[p_name]
        merged.zero_()
        # Keyword `alpha` replaces the deprecated add_(Number, Tensor) overload.
        merged.add_(state_dict_1[p_name], alpha=coef_1)
        merged.add_(state_dict_2[p_name], alpha=coef_2)
    return model_state

def interpolate_multi_state_dicts(sd_s, weight_s):
    """Weighted sum of several state dicts: sum_i weight_s[i] * sd_s[i][key] per key.

    Args:
        sd_s: non-empty sequence of state dicts; each must contain all keys of sd_s[0].
        weight_s: one coefficient per state dict (e.g. [1/n]*n for a uniform average).

    Returns:
        A copy of sd_s[0] (same mapping type) whose values are the weighted sums.
        The inputs are not modified.

    Raises:
        ValueError: if sd_s is empty or len(weight_s) != len(sd_s).
    """
    if len(sd_s) == 0:
        raise ValueError("sd_s must contain at least one state dict")
    if len(weight_s) != len(sd_s):
        raise ValueError(
            f"expected one weight per state dict ({len(sd_s)}), got {len(weight_s)}")
    # deepcopy keeps the container type (and any attributes set on it); every value
    # is overwritten below.
    sd_interpolated = deepcopy(sd_s[0])
    for key in sd_s[0].keys():
        # The first product is a fresh tensor, so the in-place += never touches inputs.
        acc = weight_s[0] * sd_s[0][key]
        for sd, w in zip(sd_s[1:], weight_s[1:]):
            acc += w * sd[key]
        sd_interpolated[key] = acc
    return sd_interpolated
    

def validate_snn(model, loader, n_steps, thresholds, device):
    """Evaluate an SNN's top-1 accuracy on `loader`, print it and return it.

    Each input batch is expanded with add_dimension(data, n_steps) and passed to
    model(data, thresholds, L=0, t=n_steps). The output is averaged over dim 1
    (presumably the time-step axis — confirm against the model's forward) before argmax.

    Args:
        model: converted SNN; it is switched to eval mode.
        loader: iterable of (data, target) batches.
        n_steps: number of simulation time steps.
        thresholds: threshold tensor forwarded to the model.
        device: device the batches are moved to.

    Returns:
        Accuracy in percent.
    """
    model.eval()
    total = 0
    correct = 0
    # Inference only: without no_grad each batch may build and hold an autograd graph.
    with torch.no_grad():
        for data, target in loader:
            data = add_dimension(data, n_steps)
            data, target = data.to(device), target.to(device)
            output = model(data, thresholds, L=0, t=n_steps)
            output = torch.mean(output, dim=1)
            total += target.size(0)
            correct += (output.argmax(1) == target).sum().item()
    print('Accuracy of the network on the test images: %f' % (100 * correct / total))
    acc = 100 * correct / total
    return acc
def validate_snn_ensemble(models, loader, n_steps, thresholds, device):
    """Evaluate an ensemble of SNNs whose time-averaged outputs are summed.

    For every batch, each model is called as model(data, thresholds, L=0, t=n_steps).
    Its output is averaged over dim 1 (presumably time — confirm), and the per-model
    averages are summed before argmax. The number of classes comes from the models'
    outputs instead of being hard-coded.

    Args:
        models: non-empty sequence of SNNs; each is switched to eval mode.
        loader: iterable of (data, target) batches.
        n_steps: number of simulation time steps.
        thresholds: threshold tensor forwarded to every model.
        device: device the batches are moved to.

    Returns:
        Ensemble accuracy in percent.

    Raises:
        ValueError: if `models` is empty (the old code silently predicted class 0).
    """
    if len(models) == 0:
        raise ValueError("models must contain at least one model")
    for model in models:
        model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data = add_dimension(data, n_steps)
            data, target = data.to(device), target.to(device)
            # Previously accumulated into torch.zeros(batch, 10), which broke for
            # datasets with a class count other than 10.
            output = sum(torch.mean(model(data, thresholds, L=0, t=n_steps), dim=1)
                         for model in models)
            total += target.size(0)
            correct += (output.argmax(1) == target).sum().item()
    print('Accuracy of the network on the test images: %f' % (100 * correct / total))
    acc = 100 * correct / total
    return acc
    
def validate_ann(model, loader, device):
    """Evaluate an ANN's top-1 accuracy on `loader`, print it and return it.

    Args:
        model: classifier called as model(data); it is switched to eval mode.
        loader: iterable of (data, target) batches.
        device: device the batches are moved to.

    Returns:
        Accuracy in percent.
    """
    model.eval()
    total = 0
    correct = 0
    # Inference only: without no_grad each batch may build and hold an autograd graph.
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total += target.size(0)
            correct += (output.argmax(1) == target).sum().item()
    print('Accuracy of the network on the test images: %f' % (100 * correct / total))
    acc = 100 * correct / total
    return acc



def ann_to_snn(model, thresholds, thresholds1, n_steps):
    """Convert an ANN into the SNN used throughout this notebook.

    Three rewrites are applied in order:
      1. ReLU-like activations -> SPIKE_layer, via replace_activation_by_spike,
         with thresholds derived from `thresholds1`.
      2. replace_maxpool2d_by_avgpool2d (MaxPool2d -> AvgPool2d, per its name).
      3. replace_layer_by_tdlayer (presumably wraps layers for time-step input — confirm).

    Returns:
        The converted model. The activation replacement mutates `model` in place.
    """
    converted, _n_replaced = replace_activation_by_spike(model, thresholds, thresholds1, n_steps)
    converted = replace_maxpool2d_by_avgpool2d(converted)
    return replace_layer_by_tdlayer(converted)



In [88]:
class args:
    """Experiment configuration: a plain attribute namespace standing in for argparse."""
    model = 'vgg16'        # architecture key passed to modelpool
    dataset = 'cifar10'    # dataset key passed to datapool / modelpool
    batch_size = 128
    t = 1                  # number of SNN time steps (becomes n_steps)


device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
criterion = nn.CrossEntropyLoss()
n_steps = args.t
# NOTE(review): meaning of datapool's third positional argument (0) is not visible here — confirm.
train_loader, test_loader = datapool(args.dataset, args.batch_size, 0, shuffle=True)

Files already downloaded and verified


Files already downloaded and verified


In [89]:
# Build three copies of the ANN and load the same pretrained checkpoint into each.
model_s = [modelpool(args.model, args.dataset).to(device) for _ in range(3)]
ckpt_model_idx = 3
# weights_only=True: the checkpoint is a plain state dict, so there is no need to unpickle
# arbitrary objects (the default load presumably produced the warning echoed below).
# map_location lets a CUDA-saved checkpoint load on a CPU-only machine.
sd = torch.load(f'saved_models/cifar10_vgg16_{ckpt_model_idx}.pth',
                map_location=device, weights_only=True)
for model in model_s:
    model.load_state_dict(sd)

# One threshold row per ReLU in the network.
num_relu = str(model_s[0]).count('ReLU')
threshold_all_noaug1 = np.load(f'cifar10_vgg16_{ckpt_model_idx}_threshold_all_noaug{n_steps}.npy')
threshold_pos_all_noaug1 = np.load(f'cifar10_vgg16_{ckpt_model_idx}_threshold_pos_all_noaug{n_steps}.npy')  # not referenced in the cells shown here
thresholds = torch.zeros(num_relu, 2 * n_steps)
# Reuse the array loaded above instead of reading the same file a second time
# (float32 copy, same as the previous torch.Tensor(...) conversion).
thresholds1 = torch.tensor(threshold_all_noaug1, dtype=torch.float32)

  sd = torch.load(f'saved_models/cifar10_vgg16_{ckpt_model_idx}.pth')


In [90]:
# Baseline: accuracy of the original ANN, before conversion to an SNN.
validate_ann(model=model_s[0], loader=test_loader, device=device)

Accuracy of the network on the test images: 95.880000


95.88

In [91]:
# Convert every ANN to an SNN; ann_to_snn rewrites each model's submodules in place.
# NOTE(review): probably not idempotent — re-running this cell would convert the
# already-converted models again; confirm before re-executing.
model_s = list(map(lambda ann: ann_to_snn(ann, thresholds, thresholds1, n_steps), model_s))
validate_snn(model=model_s[0], loader=test_loader, n_steps=n_steps, thresholds=thresholds, device=device)

Accuracy of the network on the test images: 45.690000


45.69

In [47]:
# model_s must already hold the converted SNNs (ann_to_snn) for these SNN state dicts to load.
# NOTE(review): these fine-tuned checkpoints belong to base model 0, while the thresholds
# above come from ckpt_model_idx (currently 3). This cell's execution count (47) also
# predates cells 87-91, so the accuracies reported by the next cell may reflect an earlier
# kernel state — confirm before relying on them.
snn_ckpt_paths = [
    'cifar10_vgg16_0_updated_snn1_1.pth',
    'cifar10_vgg16_0_updated_snn1_1_v2.pth',
    'cifar10_vgg16_0_updated_snn1_1_v3.pth',
]
# weights_only=True avoids unpickling arbitrary objects (the default load presumably
# produced the warning echoed below). map_location targets the current device.
sd_s = [torch.load(path, map_location=device, weights_only=True) for path in snn_ckpt_paths]
for model, sd in zip(model_s, sd_s):
    model.load_state_dict(sd)

  sd_s = [torch.load('cifar10_vgg16_0_updated_snn1_1.pth'),
  torch.load('cifar10_vgg16_0_updated_snn1_1_v2.pth'),
  torch.load('cifar10_vgg16_0_updated_snn1_1_v3.pth')]


In [48]:
# Accuracy of each fine-tuned SNN, then of the ensemble that sums their averaged outputs.
acc_s = []
for snn_model in model_s:
    acc_s.append(validate_snn(snn_model, test_loader, n_steps, thresholds, device))
acc_ensemble = validate_snn_ensemble(model_s, test_loader, n_steps, thresholds, device)
print(acc_s, acc_ensemble)

Accuracy of the network on the test images: 92.330000
Accuracy of the network on the test images: 92.210000
Accuracy of the network on the test images: 92.040000
Accuracy of the network on the test images: 93.860000
[92.33, 92.21, 92.04] 93.86


In [73]:
# Alternative BN-calibration loader: datapool(args.dataset, 512, 0, shuffle=False)[0]
bn_loader = train_loader
# Build a fresh SNN and give it the uniform average of a subset of the fine-tuned weights.
model_mid = ann_to_snn(modelpool(args.model, args.dataset), thresholds, thresholds1, n_steps)
merged_indices = [0, 2]
merged_sds = list(map(sd_s.__getitem__, merged_indices))
num_models = len(merged_sds)
uniform_weights = [1 / num_models for _ in range(num_models)]
sd_mid = interpolate_multi_state_dicts(merged_sds, uniform_weights)
model_mid.load_state_dict(sd_mid)
model_mid = model_mid.to(device)
acc_mid = validate_snn(model_mid, test_loader, n_steps, thresholds, device)

Accuracy of the network on the test images: 92.500000


In [72]:
def bn_calibration_init(m):
    """Prepare a normalization layer for re-estimating its running statistics.

    Modules without a truthy `track_running_stats` attribute are left untouched.
    Otherwise the running statistics are reset, the module is put in training mode so
    forward passes update them, and momentum is set to None, which selects a cumulative
    moving average.
    """
    if not getattr(m, 'track_running_stats', False):
        return
    m.reset_running_stats()
    m.training = True
    m.momentum = None

def reset_bn_stats(model, device, bn_loader, n_steps):
    """Recompute BatchNorm running statistics of `model` from a pass over `bn_loader`.

    Every module with a truthy `track_running_stats` is reset by bn_calibration_init:
    running stats are cleared and momentum is set to None (cumulative average). The
    model then runs in train mode, without gradients, over every batch, each expanded
    with add_dimension(data, n_steps).

    Afterwards the model is left in eval mode, and each reset module's original
    `momentum` is restored. Later training therefore uses its usual moving average
    instead of silently keeping the cumulative one. Both happen even if a forward
    pass raises.

    Args:
        model: network to recalibrate (modified in place).
        device: device the batches are moved to.
        bn_loader: iterable of (data, label) batches; labels are ignored.
        n_steps: number of time steps passed to add_dimension.
    """
    # bn_calibration_init overwrites momentum with None; remember the originals.
    saved_momenta = [(m, getattr(m, 'momentum', None)) for m in model.modules()
                     if getattr(m, 'track_running_stats', False)]
    for m in model.modules():
        bn_calibration_init(m)
    model.train()
    try:
        with torch.no_grad():
            for data, _ in bn_loader:
                data = add_dimension(data, n_steps)
                data = data.to(device)
                # NOTE(review): evaluation calls model(data, thresholds, L=0, t=n_steps);
                # this relies on forward's defaults — confirm they match that setup.
                model(data)
    finally:
        for m, momentum in saved_momenta:
            m.momentum = momentum
        model.eval()

# Restart from the merged weights, re-estimate BN statistics on bn_loader, then re-evaluate.
model_mid.load_state_dict(sd_mid)
reset_bn_stats(model=model_mid, device=device, bn_loader=bn_loader, n_steps=n_steps)
reset_acc_mid = validate_snn(model_mid, test_loader, n_steps=n_steps, thresholds=thresholds, device=device)

Accuracy of the network on the test images: 92.200000


Next, try fine-tuning the merged model to tune its batch-norm (BN) statistics.

In [79]:
# def train_snn(train_dataloader, test_dataloader, model, epochs, device, loss_fn, lr=0.1, wd=5e-4):
#     model = model.to(device)
#     optimizer = torch.optim.SGD(model.parameters(), lr=lr,momentum=0.9,weight_decay=wd) 
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs,verbose=True)
#     for epoch in range(epochs):
#         model.train()
#         epoch_loss = 0
#         length = 0
#         total = 0
#         correct = 0
#         for img, label in train_dataloader:
#             img = add_dimension(img,n_steps)
#             img = img.cuda(device)
            
#             labels = label.cuda(device)
#             outputs = model(img,thresholds,L=0,t=n_steps) 
#             outputs = torch.sum(outputs,1)
#             optimizer.zero_grad()
#             loss = loss_fn(outputs/n_steps, labels)

#             with torch.autograd.set_detect_anomaly(True):
#                 loss.backward()
#                 optimizer.step()
                
#             epoch_loss += loss.item()*img.shape[0]
#             length += len(label)
#             _, predicted = torch.max(outputs.data, 1)

#             #print(predicted)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()
#             if total%(256*8) == 0:
#                 print('Epoch:%d, Accuracy of the snn network on the %d train images: %f, loss:%f'%(epoch,total,100 * correct / total,epoch_loss/total))
#         print('Epoch:%d, Accuracy of the snn network on the %d train images: %f, loss:%f'%(epoch,total,100 * correct / total,epoch_loss/total))
#         scheduler.step()
#         validate_snn(model, test_dataloader, n_steps, thresholds, device)
# model_mid.load_state_dict(sd_mid) 
# train_snn(train_loader, test_loader, model_mid, 5, device, criterion, lr=1e-5, wd=5e-4)

In [119]:
# Thresholds of base model 0 for 2 time steps (file pattern ..._threshold_all_noaug{n_steps}.npy).
th = np.load(file='cifar10_vgg16_0_threshold_all_noaug2.npy')

In [120]:
# Saved test-accuracy record, presumably from fine-tuning the SNN of base model 0.
acc = np.load(file='cifar10_vgg16_0_updated_snn1_test_acc_1.npy')