In [98]:
import torch
from Models import modelpool
from Preprocess import datapool
from torch import nn
from spiking_layer_ours import SPIKE_layer
import numpy as np
from copy import deepcopy
from utils_my import add_dimension
from utils_my import replace_maxpool2d_by_avgpool2d, replace_layer_by_tdlayer
from slerp import slerp
from collections import OrderedDict

def isActivation(name):
    """Report whether ``name`` contains the substring ``relu`` (case-insensitive).

    Args:
        name: A string, typically a module class name.

    Returns:
        ``True`` if ``'relu'`` occurs in the lower-cased ``name``, else ``False``.
    """
    return 'relu' in name.lower()

def replace_activation_by_spike(model, thresholds, thresholds1, n_steps, counter=0):
    """Recursively swap every ReLU-like child of ``model`` for a ``SPIKE_layer``.

    A child counts as an activation when its class name contains ``relu``
    (see ``isActivation``). Both threshold arguments are deep-copied on entry,
    so the caller's objects are never written to.

    Args:
        model: Module whose ``_modules`` mapping is rewritten in place.
        thresholds: 2-D threshold table indexed as ``[layer, column]``; the
            working copy has columns ``[:n_steps]`` and ``[n_steps:]`` filled.
        thresholds1: Table read as ``[layer, 0]`` and ``[layer, 1]``.
        n_steps: Number of time steps; also the divisor applied to the
            values read from ``thresholds1``.
        counter: Index of the next activation to replace (threaded through
            the recursion).

    Returns:
        Tuple ``(model, counter)`` with the same ``model`` object and the
        updated activation counter.
    """
    local_thresholds = deepcopy(thresholds)
    source_thresholds = deepcopy(thresholds1)
    outer_cols = slice(n_steps, None)
    inner_cols = slice(0, n_steps)

    children = model._modules
    for child_name, child in children.items():
        # Descend first so activations nested inside containers are handled.
        if hasattr(child, "_modules"):
            children[child_name], counter = replace_activation_by_spike(
                child, local_thresholds, source_thresholds, n_steps, counter)
        if not isActivation(child.__class__.__name__.lower()):
            continue
        # Column 1 of thresholds1 feeds the [n_steps:] half, column 0 the [:n_steps] half.
        local_thresholds[counter, outer_cols] = source_thresholds[counter, 1] / n_steps
        local_thresholds[counter, inner_cols] = source_thresholds[counter, 0] / n_steps
        children[child_name] = SPIKE_layer(local_thresholds[counter, outer_cols],
                                           local_thresholds[counter, inner_cols])
        counter += 1
    return model, counter


def interpolate_state_dicts(state_dict_1, state_dict_2, weight,
                            bias_norm=False, use_slerp=False):
    """Interpolate between two state dicts with matching keys.

    Three modes, checked in this order:

    * ``use_slerp``: every entry whose name does not contain ``"batches"`` is
      replaced by ``slerp(weight, a, b)``; the remaining entries are copied
      from ``state_dict_1``.
    * ``bias_norm``: entries named ``*weight*`` / ``*res_scale*`` are blended
      linearly; entries named ``*bias*`` are blended with the coefficients
      raised to the power ``height``, where ``height`` is the number of
      ``weight`` entries seen so far.
    * otherwise: plain ``(1 - weight) * a + weight * b`` for every key,
      returned as a new plain ``dict``.

    Args:
        state_dict_1: State dict at ``weight == 0``.
        state_dict_2: State dict at ``weight == 1``.
        weight: Interpolation coefficient.
        bias_norm: Select the depth-scaled bias blending mode.
        use_slerp: Select spherical interpolation (takes precedence).

    Returns:
        The interpolated state dict. Inputs are not modified.
    """
    if use_slerp:
        model_state = deepcopy(state_dict_1)
        for p_name in model_state:
            if "batches" not in p_name:
                model_state[p_name] = slerp(weight, state_dict_1[p_name], state_dict_2[p_name])
        return model_state
    elif bias_norm:
        model_state = deepcopy(state_dict_1)
        height = 0
        for p_name in model_state:
            if "batches" in p_name:
                continue
            # NOTE(review): entries matching none of the names below (e.g.
            # running_mean / running_var) are zeroed and left at zero --
            # confirm this is intended before relying on this branch.
            model_state[p_name].zero_()
            # Keyword `alpha` replaces the deprecated add_(scalar, tensor) overload.
            if "weight" in p_name:
                model_state[p_name].add_(state_dict_1[p_name], alpha=1.0 - weight)
                model_state[p_name].add_(state_dict_2[p_name], alpha=weight)
                height += 1
            if "bias" in p_name:
                model_state[p_name].add_(state_dict_1[p_name], alpha=(1.0 - weight)**height)
                model_state[p_name].add_(state_dict_2[p_name], alpha=weight**height)
            if "res_scale" in p_name:
                model_state[p_name].add_(state_dict_1[p_name], alpha=1.0 - weight)
                model_state[p_name].add_(state_dict_2[p_name], alpha=weight)
        return model_state
    else:
        return {key: (1 - weight) * state_dict_1[key] +
                weight * state_dict_2[key] for key in state_dict_1.keys()}
        

def interpolate_multi_state_dicts(sd_s, weight_s, use_slerp=False):
    """Merge several state dicts into one using per-model weights.

    Linear mode computes ``sum_i weight_s[i] * sd_s[i][key]`` for every key
    of the first state dict. Slerp mode folds the models in one at a time:
    model ``i`` is interpolated into the running result with coefficient
    ``weight_s[i] / (accumulated_weight + weight_s[i])``.

    Args:
        sd_s: Sequence of state dicts sharing the keys of ``sd_s[0]``.
        weight_s: Sequence of weights, one per state dict.
        use_slerp: Use pairwise spherical interpolation instead of a
            weighted sum.

    Returns:
        The merged state dict (a deep copy of ``sd_s[0]`` with its values
        replaced). The input state dicts are left untouched.
    """
    merged = deepcopy(sd_s[0])

    if not use_slerp:
        for key in sd_s[0].keys():
            blended = weight_s[0] * sd_s[0][key]
            for idx in range(1, len(sd_s)):
                blended += weight_s[idx] * sd_s[idx][key]
            merged[key] = blended
        return merged

    accumulated_weight = weight_s[0]
    for idx in range(1, len(sd_s)):
        incoming_weight = weight_s[idx]
        # Fraction of the incoming model relative to everything merged so far.
        fraction = incoming_weight / (accumulated_weight + incoming_weight)
        merged = interpolate_state_dicts(merged, sd_s[idx], fraction, use_slerp=use_slerp)
        accumulated_weight += incoming_weight
    return merged
    

def validate_snn(model, loader, n_steps, thresholds, device, verbose=1):
    """Compute top-1 accuracy (in percent) of a spiking model on ``loader``.

    Each batch is expanded with ``add_dimension(data, n_steps)`` and passed to
    ``model(data, thresholds, L=0, t=n_steps)``; the output is averaged over
    dim 1 (presumably the time-step axis -- confirm against the model) before
    taking the arg-max.

    Args:
        model: Model to evaluate; moved to ``device`` for the run and back to
            the CPU afterwards.
        loader: Iterable of ``(data, target)`` batches.
        n_steps: Number of time steps.
        thresholds: Threshold tensor forwarded to the model.
        device: Device used for evaluation.
        verbose: Print the accuracy unless this is ``0``.

    Returns:
        Accuracy as a float percentage.
    """
    model.to(device)
    model.eval()
    total = 0
    correct = 0
    # Inference only: do not build autograd graphs for each batch.
    with torch.no_grad():
        for data, target in loader:
            data = add_dimension(data, n_steps)
            data, target = data.to(device), target.to(device)
            output = model(data, thresholds, L=0, t=n_steps)
            output = torch.mean(output, dim=1)
            total += target.size(0)
            correct += (output.argmax(1) == target).sum().item()
    acc = 100 * correct / total
    if verbose != 0:
        print('Accuracy of the network on the test images: %f' % acc)
    model.to('cpu')
    return acc

def validate_snn_ensemble(models, loader, n_steps, thresholds, device, verbose=1):
    """Compute top-1 accuracy (in percent) of an output-averaging ensemble.

    For every batch the per-model outputs, each averaged over dim 1, are
    summed and the arg-max of the sum is taken as the prediction. The
    accumulator takes its shape from the first model's output, so the number
    of classes is no longer fixed to 10.

    Args:
        models: Non-empty collection of spiking models; each is moved to
            ``device`` for the run and back to the CPU afterwards.
        loader: Iterable of ``(data, target)`` batches.
        n_steps: Number of time steps.
        thresholds: Threshold tensor forwarded to every model.
        device: Device used for evaluation.
        verbose: Print the accuracy unless this is ``0``.

    Returns:
        Ensemble accuracy as a float percentage.

    Raises:
        ValueError: If ``models`` is empty.
    """
    models = list(models)
    if not models:
        raise ValueError("validate_snn_ensemble requires at least one model")
    for model in models:
        model.to(device)
        model.eval()
    total = 0
    correct = 0
    # Inference only: do not build autograd graphs for each batch.
    with torch.no_grad():
        for data, target in loader:
            data = add_dimension(data, n_steps)
            data, target = data.to(device), target.to(device)
            output = None
            for model in models:
                logits = torch.mean(model(data, thresholds, L=0, t=n_steps), dim=1)
                if output is None:
                    # Size the accumulator from the actual output shape.
                    output = torch.zeros(logits.shape, device=device)
                output += logits
            total += target.size(0)
            correct += (output.argmax(1) == target).sum().item()
    acc = 100 * correct / total
    if verbose != 0:
        print('Accuracy of the network on the test images: %f' % acc)
    for model in models:
        model.to('cpu')
    return acc
    
def validate_ann(model, loader, device, verbose=1):
    """Compute top-1 accuracy (in percent) of a conventional model on ``loader``.

    Args:
        model: Model called as ``model(data)``; moved to ``device`` for the
            run and back to the CPU afterwards.
        loader: Iterable of ``(data, target)`` batches.
        device: Device used for evaluation.
        verbose: Print the accuracy unless this is ``0``.

    Returns:
        Accuracy as a float percentage.
    """
    model.to(device)
    model.eval()
    total = 0
    correct = 0
    # Inference only: do not build autograd graphs for each batch.
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total += target.size(0)
            correct += (output.argmax(1) == target).sum().item()
    acc = 100 * correct / total
    if verbose != 0:
        print('Accuracy of the network on the test images: %f' % acc)
    model.to('cpu')
    return acc


def ann_to_snn(model, thresholds, thresholds1, n_steps):
    """Convert ``model`` to its spiking form in three passes.

    The passes are, in order: activation replacement
    (``replace_activation_by_spike``, which rewrites ``model`` in place),
    ``replace_maxpool2d_by_avgpool2d`` and ``replace_layer_by_tdlayer``.

    Args:
        model: Network to convert.
        thresholds: Threshold table forwarded to the activation replacement.
        thresholds1: Source threshold values forwarded likewise.
        n_steps: Number of time steps.

    Returns:
        The converted model as returned by the last pass.
    """
    spiking_model = replace_activation_by_spike(model, thresholds, thresholds1, n_steps)[0]
    return replace_layer_by_tdlayer(replace_maxpool2d_by_avgpool2d(spiking_model))

def snn_to_ann(model, use_maxpooling=True):
    """Return a copy of ``model`` with spiking wrappers turned back into ANN layers.

    Children are matched by class name:

    * ``SPIKE_layer`` becomes ``nn.ReLU()``.
    * ``tdLayer`` is unwrapped to ``module.layer.module``; when
      ``use_maxpooling`` is true and the wrapped layer is ``nn.AvgPool2d`` it
      is replaced by an ``nn.MaxPool2d`` with the same kernel size, stride
      and padding.
    * ``Flatten`` becomes ``nn.Flatten()``.

    Args:
        model: Module to convert; it is deep-copied and left unmodified.
        use_maxpooling: Whether to turn wrapped average pooling back into
            max pooling. The flag is applied at every nesting depth.

    Returns:
        The converted copy.
    """
    new_model = deepcopy(model)
    for name, module in new_model._modules.items():
        if hasattr(module, "_modules"):
            # Propagate the flag so nested containers honour the caller's choice.
            new_model._modules[name] = snn_to_ann(module, use_maxpooling=use_maxpooling)
        class_name = module.__class__.__name__
        if class_name == 'SPIKE_layer':
            new_model._modules[name] = nn.ReLU()
        elif class_name == 'tdLayer':
            wrapped = module.layer.module
            if use_maxpooling and isinstance(wrapped, nn.AvgPool2d):
                wrapped = nn.MaxPool2d(kernel_size=wrapped.kernel_size,
                                       stride=wrapped.stride,
                                       padding=wrapped.padding)
            new_model._modules[name] = wrapped
        elif class_name == 'Flatten':
            new_model._modules[name] = nn.Flatten()
    return new_model

def tranfer_bn_stats_from_ann_to_snn(ann_model, snn_model):
    """Return a copy of ``snn_model`` carrying the ANN's BatchNorm2d statistics.

    ``nn.BatchNorm2d`` layers of both models are paired in module traversal
    order (pairing stops at the shorter list). For every pair the
    ``running_mean``, ``running_var`` and ``num_batches_tracked`` buffers are
    copied from the ANN layer.

    The buffers are cloned rather than shared, so later in-place updates of
    the returned model (e.g. a running-stats reset) cannot alter ``ann_model``.

    Args:
        ann_model: Source model; not modified.
        snn_model: Target model; deep-copied and not modified.

    Returns:
        The updated deep copy of ``snn_model``.
    """
    new_snn_model = deepcopy(snn_model)
    ann_bn_layers = [layer for layer in ann_model.modules() if isinstance(layer, nn.BatchNorm2d)]
    snn_bn_layers = [layer for layer in new_snn_model.modules() if isinstance(layer, nn.BatchNorm2d)]
    for ann_bn_layer, snn_bn_layer in zip(ann_bn_layers, snn_bn_layers):
        for stat_name in ('running_mean', 'running_var', 'num_batches_tracked'):
            value = getattr(ann_bn_layer, stat_name)
            # Buffers are None when the layer does not track running stats.
            setattr(snn_bn_layer, stat_name,
                    None if value is None else value.detach().clone())
    return new_snn_model


def bn_calibration_init(m):
    """Prepare a module for re-estimating its batch-norm running statistics.

    Modules without a truthy ``track_running_stats`` attribute are left
    untouched. For the others the running statistics are reset, the module is
    flagged as training so that forward passes update the statistics, and
    ``momentum`` is set to ``None`` (cumulative moving average).
    """
    tracks_stats = getattr(m, 'track_running_stats', False)
    if not tracks_stats:
        return
    # Start from fresh statistics.
    m.reset_running_stats()
    # Only this module is flagged; children are not touched here.
    m.training = True
    # momentum=None selects the cumulative moving average.
    m.momentum = None

def reset_bn_stats(model, device, bn_loader, n_steps, layerwise=False, ann=False,
                   thresholds=None):
    """Re-estimate the batch-norm running statistics of ``model`` from ``bn_loader``.

    All modules that track running stats are reset and switched to cumulative
    averaging (``bn_calibration_init``); the whole model is then put in train
    mode and the loader is run through it under ``torch.no_grad()``. On exit
    the model is in eval mode and on the CPU.

    Args:
        model: Model to recalibrate (modified in place).
        device: Device used for the forward passes.
        bn_loader: Iterable of ``(data, target)`` batches; targets are ignored.
        n_steps: Number of time steps, forwarded to the model as ``t``.
        layerwise: If true, run one full pass per value ``L`` in ``1..13`` and
            switch the ``L``-th ``nn.BatchNorm2d`` layer to eval mode after
            its pass; otherwise run a single pass with ``L=0``.
        ann: If true, skip ``add_dimension`` on the inputs.
        thresholds: Threshold tensor forwarded to the model. When omitted,
            the module-level ``thresholds`` variable is used, which is what
            this function previously always did.

    Raises:
        NameError: If ``thresholds`` is omitted and no module-level
            ``thresholds`` exists.
    """
    if thresholds is None:
        # Backward compatibility: fall back to the notebook-level global.
        try:
            thresholds = globals()['thresholds']
        except KeyError:
            raise NameError("name 'thresholds' is not defined") from None
    # Reset batch norm statistics
    model.to(device)
    # get batchnorm layer names
    bn_layer_names = [name for name, layer in model.named_modules() if isinstance(layer, (nn.BatchNorm2d))]
    for m in model.modules():
        bn_calibration_init(m)
    model.train()
    with torch.no_grad():
        # NOTE(review): 13 passes is hard-coded -- presumably the number of
        # BatchNorm layers in VGG16; confirm before using other architectures.
        L_range = range(1, 14) if layerwise else [0]
        for L in L_range:
            for data, _ in bn_loader:
                if not ann:
                    data = add_dimension(data, n_steps)
                data = data.to(device)
                model(data, thresholds=thresholds, L=L, t=n_steps)
            # set current bn layer in eval mode
            if layerwise:
                layer = dict(model.named_modules())[bn_layer_names[L-1]]
                layer.eval()
    model.eval()
    model.to('cpu')

def get_device(item):
    """Return the device of a module or of a state dict.

    Args:
        item: An ``nn.Module`` (device of its first parameter) or a state
            dict (device of its first value). Any ``dict`` is accepted,
            including ``OrderedDict`` and the plain ``dict`` returned by the
            linear branch of ``interpolate_state_dicts``.

    Returns:
        The ``torch.device``, or ``None`` for any other type of ``item``.
    """
    if isinstance(item, nn.Module):
        return next(item.parameters()).device
    if isinstance(item, dict):
        return next(iter(item.values())).device
    return None

In [99]:
# Experiment configuration, used as a plain namespace (never instantiated).
class args:
    model = 'vgg16'        # architecture name passed to modelpool
    dataset = 'cifar10'    # dataset name passed to modelpool / datapool
    batch_size = 128       # batch size passed to datapool
    t = 1                  # number of time steps (becomes n_steps)
# Notebook-wide globals used by the cells below.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Only referenced by the commented-out fine-tuning cell in this notebook.
criterion = nn.CrossEntropyLoss()
n_steps = args.t
# NOTE(review): the third positional argument (0) is presumably num_workers -- confirm in datapool.
train_loader, test_loader = datapool(args.dataset, args.batch_size, 0, shuffle=True)

Files already downloaded and verified


Files already downloaded and verified


In [100]:
# Version suffixes (_v2.._v15, _v22.._v36) of the fine-tuned checkpoints loaded later.
v_idx_s = list(range(2, 16)) + list(range(22, 37))
# v_idx_s = [2]
# +1 accounts for the checkpoint without a version suffix.
num_models = len(v_idx_s) + 1

In [102]:
# Build num_models identical networks and initialise all of them from one checkpoint.
model_s = [modelpool(args.model, args.dataset) for _ in range(num_models)]
ckpt_model_idx = 3
sd = torch.load(f'saved_models/cifar10_vgg16_{ckpt_model_idx}.pth', map_location='cpu', weights_only=True)
for model in model_s:
    model.load_state_dict(sd)

# Number of ReLU layers, counted from the printed model description.
num_relu = str(model_s[0]).count('ReLU')
prefix = 'cifar10_vgg16_3_lr_1e3/'
threshold_all_noaug1 = np.load(f'{prefix}cifar10_vgg16_{ckpt_model_idx}_threshold_all_noaug{n_steps}.npy')
threshold_pos_all_noaug1 = np.load(f'{prefix}cifar10_vgg16_{ckpt_model_idx}_threshold_pos_all_noaug{n_steps}.npy')
# One row per ReLU, 2 * n_steps columns; passed to ann_to_snn and to every model call.
thresholds = torch.zeros(num_relu,2*n_steps)
# NOTE(review): this re-reads the same file as threshold_all_noaug1 above.
thresholds1 = torch.Tensor(np.load(f'{prefix}cifar10_vgg16_{ckpt_model_idx}_threshold_all_noaug{n_steps}.npy'))

In [103]:
# Baseline: accuracy of the source ANN before conversion.
validate_ann(model_s[0], test_loader, device)

Accuracy of the network on the test images: 95.880000


95.88

In [104]:
# Convert every network to its spiking form, then check the first one
# before any fine-tuned SNN weights are loaded.
model_s = [ann_to_snn(model, thresholds, thresholds1, n_steps) for model in model_s]
validate_snn(model_s[0], test_loader, n_steps, thresholds, device)

Accuracy of the network on the test images: 45.690000


45.69

In [76]:
# for case study
# model_s[0].load_state_dict(torch.load(f'cifar10_vgg16_3_updated_snn1_1_v2.pth', map_location='cpu'))
# validate_snn(model_s[0], test_loader, n_steps, thresholds, device)

In [106]:
# model_s = [ann_to_snn(model, thresholds, thresholds1, n_steps) for model in model_s]
# Load the fine-tuned SNN state dicts: the un-suffixed checkpoint first,
# followed by one checkpoint per version in v_idx_s.
sd_s = [torch.load(f'{prefix}cifar10_vgg16_{ckpt_model_idx}_updated_snn1_1.pth', weights_only=True, map_location='cpu')] + \
        [torch.load(f'{prefix}cifar10_vgg16_{ckpt_model_idx}_updated_snn1_1_v{v}.pth', weights_only=True, map_location='cpu') for v in v_idx_s]
# model_s[i] receives sd_s[i].
for model, sd in zip(model_s, sd_s):
    model.load_state_dict(sd)

In [87]:
# acc_s = [validate_snn(model, test_loader, n_steps, thresholds, device) for model in model_s]
# acc_ensemble = validate_snn_ensemble(model_s, test_loader, n_steps, thresholds, device)
# print(acc_s, acc_ensemble)

Accuracy of the network on the test images: 88.810000
Accuracy of the network on the test images: 88.620000
Accuracy of the network on the test images: 89.890000
[88.81, 88.62] 89.89


In [89]:
# Recalibrate the BN statistics of the first model on the training set (modifies
# model_s[0] in place) and re-evaluate it.
reset_bn_stats(model_s[0], device, train_loader, n_steps, layerwise=False, ann=False)
validate_snn(model_s[0], test_loader, n_steps, thresholds, device)

Accuracy of the network on the test images: 88.580000


88.58

In [107]:
# acc_list = np.array(acc_s + [acc_ensemble])
# np.save('cifar10_vgg16_3_ft_acc_list.npy', acc_list)

# Reload cached accuracies instead of re-evaluating: the last entry is the
# ensemble accuracy, the preceding ones are the per-model accuracies.
acc_list = np.load('cifar10_vgg16_3_ft_acc_list.npy')
acc_s = acc_list[:-1]
acc_ensemble = acc_list[-1]
print(acc_s, acc_ensemble)

[92.44 92.42 92.39 92.09 92.4  92.09 92.18 92.13 92.38 92.3  92.43 92.61
 92.42 92.17 92.46 92.42 92.27 92.47 92.47 92.33 92.37 92.2  92.25 92.25
 92.38 92.25 92.12 92.4  92.51 92.23] 94.36


In [114]:
# Display the per-model accuracies.
acc_s

array([92.44, 92.42, 92.39, 92.09, 92.4 , 92.09, 92.18, 92.13, 92.38,
       92.3 , 92.43, 92.61, 92.42, 92.17, 92.46, 92.42, 92.27, 92.47,
       92.47, 92.33, 92.37, 92.2 , 92.25, 92.25, 92.38, 92.25, 92.12,
       92.4 , 92.51, 92.23])

In [113]:
# Accuracies of the models selected for merging.
# NOTE(review): merged_indices is only defined in a later cell, so this cell
# depends on out-of-order execution.
acc_s[merged_indices]

array([92.61, 92.51, 92.47, 92.47, 92.46, 92.44, 92.43, 92.42, 92.42,
       92.42])

#### merging

##### uniform soup

In [None]:
# Uniform soup: merge the selected checkpoints with equal weights, once
# linearly and once with slerp, and evaluate each before and after a BN reset.
# bn_loader, _ = datapool(args.dataset, 512, 0, shuffle=False)
model_mid = modelpool(args.model, args.dataset)
model_mid = ann_to_snn(model_mid, thresholds, thresholds1, n_steps)
# NOTE(review): this assignment is overwritten a few lines below.
merged_indices = range(0, num_models)

# Pick the (up to) 30 most accurate models, best first.
sorted_indices = np.argsort(acc_s)[::-1]
merged_indices = sorted_indices[:30]
# merged_indices = [0, 2]

merged_sds = [sd_s[i] for i in merged_indices]
# NOTE(review): this rebinds the global num_models set earlier in the notebook.
num_models = len(merged_sds)
merged_weights = [1/num_models]*num_models

# Linear average of the weights.
sd_mid = interpolate_multi_state_dicts(merged_sds, merged_weights, use_slerp=False)
model_mid.load_state_dict(sd_mid)
model_mid = model_mid.to(device)
acc_mid = validate_snn(model_mid, test_loader, n_steps, thresholds, device)
print(f'linear acc: {acc_mid}')
reset_bn_stats(model_mid, device, train_loader, n_steps, layerwise=False, ann=False)
acc_mid_reset = validate_snn(model_mid, test_loader, n_steps, thresholds, device)
print(f'linear acc reset: {acc_mid_reset}')

# Same procedure with sequential slerp merging.
sd_mid = interpolate_multi_state_dicts(merged_sds, merged_weights, use_slerp=True)
model_mid.load_state_dict(sd_mid)
model_mid = model_mid.to(device)
acc_mid = validate_snn(model_mid, test_loader, n_steps, thresholds, device)
print(f'slerp acc: {acc_mid}')
reset_bn_stats(model_mid, device, train_loader, n_steps, layerwise=False, ann=False)
acc_mid_reset = validate_snn(model_mid, test_loader, n_steps, thresholds, device)
print(f'slerp acc reset: {acc_mid_reset}')

Accuracy of the network on the test images: 92.380000
linear acc: 92.38
Accuracy of the network on the test images: 92.090000
linear acc reset: 92.09
Accuracy of the network on the test images: 92.530000
slerp acc: 92.53
Accuracy of the network on the test images: 92.320000
slerp acc reset: 92.32


In [60]:
# Recalibrate BN statistics of the current merged model and re-evaluate.
reset_bn_stats(model_mid, device, train_loader, n_steps, layerwise=False, ann=False)
acc_mid_reset = validate_snn(model_mid, test_loader, n_steps, thresholds, device)

Accuracy of the network on the test images: 89.980000


In [45]:
# Start from a copy of the first fine-tuned SNN, load the slerp-merged
# weights, then evaluate only after a BN reset.
model_mid = deepcopy(model_s[0])
sd_mid = interpolate_multi_state_dicts(merged_sds, merged_weights, use_slerp=True)
model_mid.load_state_dict(sd_mid)
model_mid = model_mid.to(device)
# acc_mid = validate_snn(model_mid, test_loader, n_steps, thresholds, device)
# print(f'slerp acc: {acc_mid}')
reset_bn_stats(model_mid, device, train_loader, n_steps, layerwise=False, ann=False)
acc_mid_reset = validate_snn(model_mid, test_loader, n_steps, thresholds, device)
print(f'slerp acc reset: {acc_mid_reset}')

Accuracy of the network on the test images: 92.360000
slerp acc reset: 92.36


##### greedy soup

1. sort all models according to their test acc (descending)
2. add the next best model to the soup
3. keep the model in the soup only if the new soup has a better test acc than the current soup
4. repeat steps 2 and 3 until every model has been tried

In [12]:
def grid_search(model_1, model_2, grid, use_slerp=False):
    """Search ``grid`` for the interpolation coefficient with the best accuracy.

    For every ``t`` in ``grid`` the two models' state dicts are interpolated,
    loaded into the notebook-level ``model_mid`` and scored with
    ``validate_snn`` on ``test_loader``. Relies on the globals ``model_mid``,
    ``test_loader``, ``n_steps``, ``thresholds`` and ``device``.

    Note that ``model_mid`` is left holding the weights of the *last* grid
    point; load the returned state dict to obtain the best model.

    Args:
        model_1: Model at ``t == 0``.
        model_2: Model at ``t == 1``.
        grid: Iterable of interpolation coefficients to try.
        use_slerp: Forwarded to ``interpolate_state_dicts``.

    Returns:
        Tuple ``(sd_best, acc_best, t_best)``: the state dict, accuracy and
        coefficient of the best grid point. For an empty grid this is
        ``(None, -1, None)``.
    """
    sd_best = None
    t_best = None
    acc_best = -1
    for t in grid:
        sd_mid = interpolate_state_dicts(model_1.state_dict(), model_2.state_dict(), t, use_slerp=use_slerp)
        model_mid.load_state_dict(sd_mid)
        acc_mid = validate_snn(model_mid, test_loader, n_steps, thresholds, device, verbose=0)
        if acc_mid > acc_best:
            # Remember the winning state dict, not just its score.
            acc_best = acc_mid
            t_best = t
            sd_best = sd_mid
    return sd_best, acc_best, t_best
        

# Greedy soup: try the models from most to least accurate and keep each one
# only if merging it improves the accuracy of the running soup.
# 1: sort indices by acc_s in descending order
sorted_indices = np.argsort(acc_s)[::-1]
acc_cur = -1
model_mid = None
num_merged = 0
merged_indices = []
use_slerp = True
print(f"use_slerp: {use_slerp}")
for i in sorted_indices:
    print(f"test model {i}")
    if model_mid is None:
        # The best single model seeds the soup.
        model_mid = deepcopy(model_s[i])
        acc_cur = acc_s[i]
        print(f"add model {i}, current acc: {acc_cur}" + "\n")
        num_merged += 1
        merged_indices.append(i)
    else:
        # Snapshot the soup so a rejected candidate can be rolled back.
        sd_mid_ori = deepcopy(model_mid.state_dict())
        # Candidate gets weight 1/(k+1) against a soup of k models.
        t = 1 / (num_merged + 1)
        sd_mid = interpolate_state_dicts(model_mid.state_dict(), model_s[i].state_dict(), t, use_slerp=use_slerp)
        model_mid.load_state_dict(sd_mid)
        acc_mid = validate_snn(model_mid, test_loader, n_steps, thresholds, device, verbose=0)
        print(f'acc: {acc_mid}')
        # reset_bn_stats(model_mid, device, train_loader, n_steps, layerwise=False, ann=False)
        # acc_mid_reset = validate_snn(model_mid, test_loader, n_steps, thresholds, device, verbose=0)
        # print(f'reset acc: {acc_mid_reset}')

        # if acc_mid > acc_mid_reset:
        #     model_mid.load_state_dict(sd_mid)
        #     acc_best = acc_mid
        # else:
        #     acc_best = acc_mid_reset
        if acc_mid > acc_cur: # use acc_mid for comparison or acc_best
            acc_cur = acc_mid
            print(f"add model {i}, current acc: {acc_cur}" + "\n")
            num_merged += 1
            merged_indices.append(i)
        else:
            print('\n')
            # Candidate rejected: restore the previous soup weights.
            model_mid.load_state_dict(sd_mid_ori)

use_slerp: {use_slerp}
test model 11
add model 11, current acc: 92.61

test model 0


acc: 92.19
test model 10
acc: 92.25
test model 12
acc: 92.18
test model 1
acc: 92.31
test model 4
acc: 92.11
test model 2
acc: 92.38
test model 8
acc: 92.37
test model 9
acc: 92.25
test model 6
acc: 92.13
test model 13
acc: 92.1
test model 7
acc: 92.25
test model 5
acc: 92.13
test model 3
acc: 92.31


In [51]:
# Repeat the BN recalibration ten times on the same model to see how much the
# accuracy varies from run to run.
for _ in range(10):
    reset_bn_stats(model_mid, device, train_loader, n_steps, layerwise=False, ann=False)
    validate_snn(model_mid, test_loader, n_steps, thresholds, device)

Accuracy of the network on the test images: 92.350000
Accuracy of the network on the test images: 92.360000
Accuracy of the network on the test images: 92.530000
Accuracy of the network on the test images: 92.180000
Accuracy of the network on the test images: 92.220000
Accuracy of the network on the test images: 92.540000
Accuracy of the network on the test images: 92.460000
Accuracy of the network on the test images: 92.360000
Accuracy of the network on the test images: 92.360000
Accuracy of the network on the test images: 92.370000


#### grid search on t

In [14]:
# Display the model indices ordered by descending accuracy.
sorted_indices

array([11,  0, 10, 12,  1,  4,  2,  8,  9,  6, 13,  7,  5,  3])

In [13]:
# Display the per-model accuracies.
acc_s

[92.44,
 92.42,
 92.39,
 92.09,
 92.4,
 92.09,
 92.18,
 92.13,
 92.38,
 92.3,
 92.43,
 92.61,
 92.42,
 92.17]

In [61]:
# two models
# Sweep the mixing weight t for a slerp merge of models 0 and 1 and record
# the test accuracy at every grid point.
merged_indices = [0, 1]
merged_sds = [sd_s[i] for i in merged_indices]
t_s = np.linspace(0.52, 0.53, 20)
acc_t_s = []
for t in t_s:
    # t is the weight of the first model, 1 - t that of the second.
    merged_weights = [t, 1-t]
    sd_mid = interpolate_multi_state_dicts(merged_sds, merged_weights, use_slerp=True)
    # sd_mid = interpolate_state_dicts(merged_sds[0], merged_sds[1], 0.5, use_slerp=True)
    model_mid.load_state_dict(sd_mid)
    model_mid = model_mid.to(device)
    acc_mid = validate_snn(model_mid, test_loader, n_steps, thresholds, device)
    acc_t_s.append(acc_mid)
# Map each grid point to its accuracy.
t_acc_dict = {t: acc for t, acc in zip(t_s, acc_t_s)}

Accuracy of the network on the test images: 90.010000
Accuracy of the network on the test images: 89.870000
Accuracy of the network on the test images: 89.810000
Accuracy of the network on the test images: 89.930000
Accuracy of the network on the test images: 90.310000
Accuracy of the network on the test images: 89.900000
Accuracy of the network on the test images: 90.160000
Accuracy of the network on the test images: 89.950000
Accuracy of the network on the test images: 90.100000
Accuracy of the network on the test images: 89.940000
Accuracy of the network on the test images: 89.790000
Accuracy of the network on the test images: 89.980000
Accuracy of the network on the test images: 90.270000
Accuracy of the network on the test images: 90.130000
Accuracy of the network on the test images: 89.920000
Accuracy of the network on the test images: 90.060000
Accuracy of the network on the test images: 89.710000
Accuracy of the network on the test images: 90.000000
Accuracy of the network on t

In [42]:
# Display the accuracy recorded for each mixing weight.
t_acc_dict

{0.4: 92.03,
 0.4105263157894737: 92.22,
 0.4210526315789474: 92.17,
 0.43157894736842106: 92.22,
 0.4421052631578948: 92.16,
 0.45263157894736844: 92.24,
 0.4631578947368421: 92.31,
 0.4736842105263158: 92.2,
 0.4842105263157895: 92.02,
 0.49473684210526314: 92.25,
 0.5052631578947369: 92.19,
 0.5157894736842106: 92.13,
 0.5263157894736842: 92.16,
 0.5368421052631579: 92.48,
 0.5473684210526315: 92.04,
 0.5578947368421052: 92.16,
 0.5684210526315789: 92.18,
 0.5789473684210527: 92.27,
 0.5894736842105264: 92.47,
 0.6: 92.15}

In [46]:
# Equal-weight slerp merge of the currently selected state dicts (merged_sds).
t = 0.5
merged_weights = [t, 1-t]
sd_mid = interpolate_multi_state_dicts(merged_sds, merged_weights, use_slerp=True)
# sd_mid = interpolate_state_dicts(merged_sds[0], merged_sds[1], 0.5, use_slerp=True)
model_mid.load_state_dict(sd_mid)
model_mid = model_mid.to(device)
acc_mid = validate_snn(model_mid, test_loader, n_steps, thresholds, device)

Accuracy of the network on the test images: 91.890000


#### reset bn

In [47]:
# Reload the merged weights and recalibrate BN ten times, collecting the
# resulting test accuracies.
reset_acc_s = []
# bn_loader, _ = datapool(args.dataset, 128, 0, shuffle=True)
for _ in range(10):
    bn_loader = train_loader
    # Restore the merged weights so every trial starts from the same state.
    model_mid.load_state_dict(sd_mid)
    reset_bn_stats(model_mid, device, bn_loader, n_steps, layerwise=False)
    reset_acc_mid = validate_snn(model_mid, test_loader, n_steps, thresholds, device)
    # print("Training acc", validate_snn(model_mid, bn_loader, n_steps, thresholds, device))
    reset_acc_s.append(reset_acc_mid)

Accuracy of the network on the test images: 92.310000
Accuracy of the network on the test images: 92.200000
Accuracy of the network on the test images: 92.360000
Accuracy of the network on the test images: 92.280000
Accuracy of the network on the test images: 92.140000
Accuracy of the network on the test images: 92.240000
Accuracy of the network on the test images: 92.120000
Accuracy of the network on the test images: 91.870000
Accuracy of the network on the test images: 92.560000
Accuracy of the network on the test images: 92.580000


In [38]:
# Best accuracy among the BN-reset trials.
max(reset_acc_s)

92.68

#### try fine-tuning to tune bn statistics

In [79]:
# def train_snn(train_dataloader, test_dataloader, model, epochs, device, loss_fn, lr=0.1, wd=5e-4):
#     model = model.to(device)
#     optimizer = torch.optim.SGD(model.parameters(), lr=lr,momentum=0.9,weight_decay=wd) 
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs,verbose=True)
#     for epoch in range(epochs):
#         model.train()
#         epoch_loss = 0
#         length = 0
#         total = 0
#         correct = 0
#         for img, label in train_dataloader:
#             img = add_dimension(img,n_steps)
#             img = img.cuda(device)
            
#             labels = label.cuda(device)
#             outputs = model(img,thresholds,L=0,t=n_steps) 
#             outputs = torch.sum(outputs,1)
#             optimizer.zero_grad()
#             loss = loss_fn(outputs/n_steps, labels)

#             with torch.autograd.set_detect_anomaly(True):
#                 loss.backward()
#                 optimizer.step()
                
#             epoch_loss += loss.item()*img.shape[0]
#             length += len(label)
#             _, predicted = torch.max(outputs.data, 1)

#             #print(predicted)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()
#             if total%(256*8) == 0:
#                 print('Epoch:%d, Accuracy of the snn network on the %d train images: %f, loss:%f'%(epoch,total,100 * correct / total,epoch_loss/total))
#         print('Epoch:%d, Accuracy of the snn network on the %d train images: %f, loss:%f'%(epoch,total,100 * correct / total,epoch_loss/total))
#         scheduler.step()
#         validate_snn(model, test_dataloader, n_steps, thresholds, device)
# model_mid.load_state_dict(sd_mid) 
# train_snn(train_loader, test_loader, model_mid, 5, device, criterion, lr=1e-5, wd=5e-4)

In [119]:
# Load a saved threshold array for inspection; `th` is not used by any other cell shown here.
th = np.load('cifar10_vgg16_0_threshold_all_noaug2.npy')

In [32]:
# Load a logged test accuracy saved as a .npy file.
acc = np.load('logs/cifar10_vgg16_3_updated_snn1_test_acc_1_v22.npy')

In [33]:
# Display the loaded accuracy value.
acc

array(91.95)