In [18]:
import nbformat as nbf
import autogendianet as autogen
from nbformat.v4 import new_notebook, new_code_cell
from datetime import datetime

# Create the (initially empty) notebook object that the generated cells are appended to.
nb = new_notebook()

# Take the current date/time and format it as YYYYMMDDHHMM (e.g. 202307311453);
# the stamp becomes part of the output file name.
current_time = datetime.now()
formatted_time = current_time.strftime("%Y%m%d%H%M")

filename = 'auto_mnist_3layer_' + formatted_time + '.ipynb'

# When truthy, the following cells print the intermediate topology structures.
test_mode = 0
# device_setting: one line of source, embedded into the generated parameter cell,
# that decides where the SNN runs. `auto` picks CUDA when available, `user` forces the CPU.
auto = 'device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")'
user = 'device = "cpu"'
device_setting = auto

# Parameters of this script, used to generate the Dianet topology and configuration.
# attrb_num is passed to autogen.create_input_site, label_num to autogen.create_dianet_neuron_array.
attrb_num = 256
label_num = 10

# Synapse limits (weight / bias upper and lower bounds); they are embedded as the
# default arguments of the generated forward().
w_up = 3
w_dwn = -3
b_up = 3
b_dwn = -3

# preprocess_dataset: this snippet is embedded at the top of the generated dataset cell.
# It takes too many parameters to be templated, so it is simply written out verbatim.
# NOTE(review): inside the snippet, train_loader is first built from the full training set
# and then rebuilt after random_split, so the first DataLoader is overwritten unused.
preprocess_dataset = '''
train_db = datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,)),
                   ]))
train_loader = torch.utils.data.DataLoader(
    train_db,
    batch_size=batch_size, shuffle=True)

test_db = datasets.MNIST('../data', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))]))
test_loader = torch.utils.data.DataLoader(test_db,
    batch_size=10000, shuffle=True)
train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
#print('db1:', len(train_db), 'db2:', len(val_db))
train_loader = torch.utils.data.DataLoader(
    train_db,
    batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(
    val_db,
    batch_size=batch_size, shuffle=True)
'''

#____________________________________________cell body_______________________________________
# Source of the generated "parameters" cell: neuron, training and project settings.
# The only interpolated value is {device_setting}, the device-selection line chosen above.
# The leading newline of the string is stripped with [1:] when the cell is added to the notebook.
parameter_cell = f'''
# snn-Dianet parameters
# 获取当前时间
current_time = datetime.now()
print('Start time: ' + str(current_time))

# Neuron
step_num = 4
thresh = 1.0 # neuronal threshold
lens = 0.5 # hyper-parameters of approximate function

# Training
initial_lr = 0.1
epoch_num = 150

# Set the project
batch_size = 800
dtype = torch.float
{device_setting}
print("Message: This project will run on " + str(device) + ". ")
'''
#____________________________________________________________________________________________



In [19]:
# Generate the input topology for `attrb_num` input attributes.
# NOTE(review): the structure of the two return values is defined by autogendianet — confirm there.
input_site_array, input_site_width = autogen.create_input_site(attrb_num)

#_______________test the function_______________
# Debug output, shown only when test_mode is truthy.
if test_mode:
    print(input_site_array)
    print(input_site_width)

In [20]:
# Create the Dianet neuron topology from the input sites and the number of labels.
# Later cells treat row 0 as the pure input row and the last row as the output layer.
dianet_neuron_array, dianet_layer_width = autogen.create_dianet_neuron_array(input_site_array, label_num)

#_______________test the function_______________
# Debug output, shown only when test_mode is truthy.
if test_mode:
    print(dianet_neuron_array)
    print(dianet_layer_width)

In [21]:
# Create the per-neuron feature list of the Dianet.
# Later cells read the 'left_i', 'tap_i' and 'right_i' entries of every feature
# (None means that input is absent).
dianet_neuron_feature = autogen.create_dianet_neuron_feature(input_site_array, dianet_neuron_array)

#_______________test the function_______________
# Debug output, shown only when test_mode is truthy.
if test_mode:
    print(dianet_neuron_feature)

In [22]:
#_______________test the function_______________
# Print the generated topology (debug aid, only when test_mode is truthy).
if test_mode:
    autogen.print_dianet_topolog(input_site_array, dianet_neuron_array)

This is the topolog: 
                             0     1     2     3     4     5     6     7     8     9    10    11    12    13    14    15    16    17    18    19    20    21    22
                          0     1     2     3     4     5     6     7     8     9    10    11    12    13    14    15    16    17    18    19    20    21    22    23
                       0     1     2     3     4     5     6     7     8     9    10    11    12    13    14    15    16    17    18    19    20    21    22    23    24
                    0     1     2     3     4     5     6     7     8     9    10    11    12    13    14    15    16    17    18    19    20    21    22    23    24    25
                 0     1     2     3     4     5     6     7     8     9    10    11    12    13    14    15    16    17    18    19    20    21    22    23    24    25    26
              0     1     2     3     4     5     6     7     8     9    10    11    12    13    14    15    16    17    18    19    

In [23]:
#_______________test the function_______________
# Print the per-neuron feature list (debug aid, only when test_mode is truthy).
if test_mode:
    autogen.print_dianet_neuron_feature(dianet_neuron_feature)

In [24]:
# Source of the generated notebook's first code cell: every import the generated code needs.
# NOTE(review): "improt" is a typo for "import"; the name is also referenced where the
# notebook is assembled, so it can only be renamed in both places at once.

#____________________________________________cell body_______________________________________
improt_cell = '''
# Imports

    #避免出现vscode下plt工作不正常
import os 
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"   

import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR

    # 解除打印tensor尺寸的限制
torch.set_printoptions(profile="full")
# torch.set_printoptions(profile="default")

import matplotlib.pyplot as plt
import numpy as np
import itertools
import pandas as pd

from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder

from datetime import datetime

'''
#____________________________________________________________________________________________



In [25]:
#____________________________________________cell body_______________________________________
# Source of the generated dataset cell; it only wraps the verbatim `preprocess_dataset` snippet.
organize_dataset_cell = f'''
# Organize dataset

{preprocess_dataset}


'''
#____________________________________________________________________________________________



In [26]:
def gen_upper_3_layer_evtmask(this_layer_num, upper_3_layer_num):
    """Build the 0/1 connection mask between this layer and the layer three rows above it.

    Every neuron of this layer gets one row. Row ``i`` starts as a window of four
    ones at columns ``i .. i+3`` inside ``this_layer_num + 3`` columns; the same
    number of columns is then dropped from both ends so that each row ends up
    with ``upper_3_layer_num`` columns.

    Args:
        this_layer_num: number of neurons in this layer (number of mask rows).
        upper_3_layer_num: number of neurons in the layer three rows above
            (number of mask columns).

    Returns:
        A list of ``this_layer_num`` rows, each a list of 0/1 ints.

    Raises:
        ValueError: if ``upper_3_layer_num - this_layer_num`` is not one of
            -3, -1, 1 or 3. (Previously this case printed a message and then
            failed with an UnboundLocalError.)
    """
    # Width difference -> number of columns dropped from EACH side.
    # The pattern labels are kept from the original notes.
    trim_by_width_delta = {
        -3: 3,  # ex-ex-ex-ex/max: drop 3 columns per side
        -1: 2,  # ex-ex-max-cp:    drop 2 columns per side
        1: 1,   # ex-max-cp-cp:    drop 1 column per side
        3: 0,   # max/cp-cp-cp-cp: nothing to drop
    }

    width_delta = upper_3_layer_num - this_layer_num
    if width_delta not in trim_by_width_delta:
        raise ValueError(
            "unsupported layer widths for the 3-layer mask: "
            f"this_layer_num={this_layer_num}, upper_3_layer_num={upper_3_layer_num} "
            f"(difference {width_delta} is not one of -3, -1, 1, 3)"
        )

    trim = trim_by_width_delta[width_delta]
    full_width = this_layer_num + 3

    upper_3_layer_evtmask = []
    for i in range(this_layer_num):
        row = [0] * i + [1] * 4 + [0] * (this_layer_num - i - 1)
        upper_3_layer_evtmask.append(row[trim:full_width - trim])

    return upper_3_layer_evtmask

# Manual check of gen_upper_3_layer_evtmask: print one mask row per line (only in test mode).
if test_mode:
    this_layer_num = 9  # change this value to change the number of rows of the mask
    upper_3_layer_num = 12  # change this value to change the number of columns of the mask
    generated_list = gen_upper_3_layer_evtmask(this_layer_num, upper_3_layer_num)
    for row in generated_list:
        print(row)


[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]


In [27]:
def gen_upper_2_layer_evtmask(this_layer_num, upper_2_layer_num):
    """Build the 0/1 connection mask between this layer and the layer two rows above it.

    Every neuron of this layer gets one row. Row ``i`` starts as a window of three
    ones at columns ``i .. i+2`` inside ``this_layer_num + 2`` columns; the same
    number of columns is then dropped from both ends so that each row ends up
    with ``upper_2_layer_num`` columns.

    Args:
        this_layer_num: number of neurons in this layer (number of mask rows).
        upper_2_layer_num: number of neurons in the layer two rows above
            (number of mask columns).

    Returns:
        A list of ``this_layer_num`` rows, each a list of 0/1 ints.

    Raises:
        ValueError: if ``upper_2_layer_num - this_layer_num`` is not one of
            -2, 0 or 2. (Previously this case printed a message and then
            failed with an UnboundLocalError.)
    """
    # Width difference -> number of columns dropped from EACH side.
    trim_by_width_delta = {
        -2: 2,  # drop 2 columns per side
        0: 1,   # drop 1 column per side
        2: 0,   # nothing to drop
    }

    width_delta = upper_2_layer_num - this_layer_num
    if width_delta not in trim_by_width_delta:
        raise ValueError(
            "unsupported layer widths for the 2-layer mask: "
            f"this_layer_num={this_layer_num}, upper_2_layer_num={upper_2_layer_num} "
            f"(difference {width_delta} is not one of -2, 0, 2)"
        )

    trim = trim_by_width_delta[width_delta]
    full_width = this_layer_num + 2

    upper_2_layer_evtmask = []
    for i in range(this_layer_num):
        row = [0] * i + [1] * 3 + [0] * (this_layer_num - i - 1)
        upper_2_layer_evtmask.append(row[trim:full_width - trim])

    return upper_2_layer_evtmask

# Manual check of gen_upper_2_layer_evtmask: print one mask row per line (only in test mode).
if test_mode:
    this_layer_num = 11  # change this value to change the number of rows of the mask
    upper_2_layer_num = 11  # change this value to change the number of columns of the mask
    generated_list = gen_upper_2_layer_evtmask(this_layer_num, upper_2_layer_num)
    for row in generated_list:
        print(row)

[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1]


In [28]:
# Number of hidden layers, taken from the topology: every row except row 0
# (pure input) and the last row (output layer).
hid_layer_num = len(dianet_neuron_array) - 2


# Generate the Linear / BatchNorm1d / mask definitions _________________________________________________
create_fc_nrm_msk = ''  # generated source text, embedded verbatim into the class template
tab_pad = 4 * ' ' * 2   # two indentation levels (8 spaces): body of the generated __init__

# Start at neuron row 1: row 0 is the pure input row and is handled differently.
for i in range(1, len(dianet_neuron_array)):
    this_layer_neuron_num = len(dianet_neuron_array[i])
    upper_layer_neuron_num = len(dianet_neuron_array[i-1])
    # Neurons of this row that have an extra input tap.
    input_site_num = sum(
        1 for neuron_feature in dianet_neuron_feature[i] if neuron_feature['tap_i'] is not None
    )

    # Total inputs of this row: neurons of the row above plus this row's extra taps.
    fc_in_num = upper_layer_neuron_num + input_site_num

    # 2-D mask for this row: one entry per neuron, each entry holding one flag per input.
    fc_msk_array = []
    # Column where the current neuron's run of ones starts. Every Dianet mask row has the
    # shape [zeros, 2-3 consecutive ones, zeros].
    base_idx = 0
    for neuron_feature in dianet_neuron_feature[i]:
        temp_fc_msk_unit = [0] * fc_in_num
        for input_key in ('left_i', 'tap_i', 'right_i'):
            if neuron_feature[input_key] is not None:
                temp_fc_msk_unit[base_idx] = 1
                base_idx += 1

        # Compensation: this neuron's right input and the next neuron's left input both come
        # from the same neuron of the row above, so step the base column back by one.
        base_idx -= 1

        fc_msk_array.append(temp_fc_msk_unit)

    create_fc_nrm_msk += f'{tab_pad}self.fc{i} = nn.Linear({fc_in_num}, {this_layer_neuron_num})\n'
    create_fc_nrm_msk += f'{tab_pad}self.bn{i} = nn.BatchNorm1d(num_features={this_layer_neuron_num})\n'
    create_fc_nrm_msk += f'{tab_pad}self.msk{i} = np.array({fc_msk_array})\n\n'

    if i >= 3:
        # Extra "evt" branch fed by earlier layers: row 3 reads the layer two rows up,
        # deeper rows read the layers three and two rows up (concatenated in that order).
        # NOTE(review): the generated self.msk{i}evt is not referenced by the generated
        # forward pass in this notebook — confirm whether it should mask fc{i}evt.weight.
        if i == 3:
            upper_2_layer_evtmask = gen_upper_2_layer_evtmask(this_layer_neuron_num, len(dianet_neuron_array[i-2]))
            evt_in_num = len(dianet_neuron_array[i-2])
            evt_mask = upper_2_layer_evtmask
        else:
            upper_3_layer_evtmask = gen_upper_3_layer_evtmask(this_layer_neuron_num, len(dianet_neuron_array[i-3]))
            upper_2_layer_evtmask = gen_upper_2_layer_evtmask(this_layer_neuron_num, len(dianet_neuron_array[i-2]))
            evt_in_num = len(dianet_neuron_array[i-3]) + len(dianet_neuron_array[i-2])
            evt_mask = [
                upper_3_layer_row + upper_2_layer_row
                for upper_3_layer_row, upper_2_layer_row in zip(upper_3_layer_evtmask, upper_2_layer_evtmask)
            ]

        create_fc_nrm_msk += f'{tab_pad}self.fc{i}evt = nn.Linear({evt_in_num}, {this_layer_neuron_num}, bias=False)\n'
        create_fc_nrm_msk += f'{tab_pad}self.bn{i}evt = nn.BatchNorm1d(num_features={this_layer_neuron_num})\n'
        create_fc_nrm_msk += f'{tab_pad}self.msk{i}evt = np.array({evt_mask})\n\n'


create_fc_nrm_msk += '#auto_gen'


#create_neuron_layer = ''
#tab_pad = 4 * ' ' * 2   # 生成格式化空格, 虽然叫tab, 实际是填4个半角空格
#for i in range(1, hid_layer_num + 1):
#    create_neuron_layer += (f'{tab_pad}self.hid{i} = snn.Leaky(threshold={threshold}, beta=beta)\n')
#
#create_neuron_layer += f'{tab_pad}self.out = snn.Leaky(threshold={threshold}, beta=beta)\n'
#create_neuron_layer += '#auto_gen'


# Initialise the membrane (mem) and spike (spk) state of every layer _____________________________________
# The text is emitted into the generated forward(); the conv state shapes are hard-coded to
# match the conv/pool stack of the class template (assumes 28x28 inputs — TODO confirm).
tab_pad = 4 * ' ' * 2

state_init_lines = [
    f'{tab_pad}batch_size = pre_x.size(0)\n',
    f'{tab_pad}mem_conv1 = spk_conv1 = torch.zeros(batch_size, 16, 26, 26, device=device)\n',
    f'{tab_pad}mem_conv2 = spk_conv2 = torch.zeros(batch_size, 16, 9, 9, device=device)\n',
]

# One state line per hidden layer, sized by that layer's neuron count.
for layer_idx in range(1, hid_layer_num + 1):
    layer_width = len(dianet_neuron_array[layer_idx])
    state_init_lines.append(
        f'{tab_pad}mem_hid{layer_idx} = spk_hid{layer_idx} = torch.zeros(batch_size, {layer_width}, device=device)\n'
    )

# Output layer state plus the accumulator that forward() averages over the time steps.
state_init_lines.append(
    f'{tab_pad}mem_out = spk_out = spksum_out = torch.zeros(batch_size, {len(dianet_neuron_array[-1])}, device=device)\n\n'
)
state_init_lines.append('#auto_gen')

init_neuron_mem_spk = ''.join(state_init_lines)


# Create the per-layer spike / membrane record lists ___________________________________________________
tab_pad = 4 * ' ' * 2

# Hidden layers first (hid1 .. hidN), the output layer last.
recorded_layers = [f'hid{layer_idx}' for layer_idx in range(1, hid_layer_num + 1)] + ['out']

create_log_rec = ''
for layer_tag in recorded_layers:
    create_log_rec += f'{tab_pad}self.spk_{layer_tag}_rec = []\n'
    create_log_rec += f'{tab_pad}self.mem_{layer_tag}_rec = []\n'
create_log_rec += '#auto_gen'


# Main per-time-step body: propagation through every layer ______________________________________________
def _netlist_source(next_layer_features):
    """Return the comma-separated tensor slices that are concatenated into the next layer's input.

    Walks the features of the NEXT layer: the first neuron's left input (only while the
    list is still empty), then every tap input and every right input, in neuron order.
    List items ``[layer, index]`` become slices of ``spk_hid{layer}_pth``; int items
    become slices of the flattened input ``x``.
    """
    netlist_array = []
    for neuron_feature in next_layer_features:
        if neuron_feature['left_i'] is not None and len(netlist_array) == 0:
            netlist_array.append(neuron_feature['left_i'][0])
        if neuron_feature['tap_i'] is not None:
            netlist_array.append(neuron_feature['tap_i'])
        if neuron_feature['right_i'] is not None:
            netlist_array.append(neuron_feature['right_i'][0])

    slices = []
    for item in netlist_array:
        if isinstance(item, list):
            slices.append(f'spk_hid{item[0]}_pth[:, {item[1]}:{item[1]+1}]')
        elif isinstance(item, int):
            slices.append(f'x[:, {item}:{item+1}]')
        else:
            print('Detected an illegal item!')
    return ', '.join(slices)


def _evt_term(layer_idx):
    """Return the source text of the extra `evt` input added to a layer's current.

    Mirrors how the fc{i}evt layers are generated: none below row 3, a single
    source (two rows up) for row 3, and the concatenation of the layers three
    and two rows up for deeper rows.
    """
    if layer_idx < 3:
        return ''
    if layer_idx == 3:
        return f' + self.bn{layer_idx}evt(self.fc{layer_idx}evt(spk_hid{layer_idx-2}))'
    return (
        f' + self.bn{layer_idx}evt(self.fc{layer_idx}evt('
        f'torch.cat((spk_hid{layer_idx-3}, spk_hid{layer_idx-2}), 1)))'
    )


dianet_iterate = ''
tab_pad = 4 * ' ' * 3   # three indentation levels: body of the time-step loop in forward()

# The first hidden layer reads the leading columns of x: the input sites plus its own taps.
hid1_tap_num = sum(
    1 for neuron_feature in dianet_neuron_feature[1] if neuron_feature['tap_i'] is not None
)
spk_to_hid1 = len(input_site_array[0]) + hid1_tap_num
dianet_iterate += f'{tab_pad}spk_to_hid1 = x[:,0:{spk_to_hid1}]\n\n'

last_layer_idx = len(dianet_neuron_array) - 1
for i in range(1, last_layer_idx + 1):
    # Re-apply the connectivity mask to the weights before every use.
    # NOTE(review): only self.msk{i} is applied; the generated self.msk{i}evt masks are never
    # multiplied into the fc{i}evt weights — confirm whether that is intended.
    # Clamping of weights/biases to [w_dwn, w_up] / [b_dwn, b_up] is currently not emitted.
    dianet_iterate += f'{tab_pad}self.fc{i}.weight.data *= torch.from_numpy(self.msk{i}).float().to(device)\n'

    if i == last_layer_idx:
        # Output layer.
        # NOTE(review): spk_to_out is only emitted by the last hidden layer, so a topology
        # without hidden layers would still generate a reference to an undefined name.
        dianet_iterate += f'{tab_pad}cur_to_out = self.bn{i}(self.fc{i}(spk_to_out)){_evt_term(i)}\n'
        dianet_iterate += f'{tab_pad}mem_out, spk_out = mem_update(cur_to_out, mem_out, spk_out)\n\n'
        continue

    is_last_hidden = (i == last_layer_idx - 1)

    # Fix: the evt term now follows the layer index for EVERY layer. Previously the last
    # hidden layer and the output layer always emitted the two-source torch.cat form,
    # which references non-existent layers/modules when i <= 3.
    dianet_iterate += f'{tab_pad}cur_to_hid{i} = self.bn{i}(self.fc{i}(spk_to_hid{i})){_evt_term(i)}\n'
    dianet_iterate += f'{tab_pad}mem_hid{i}, spk_hid{i} = mem_update(cur_to_hid{i}, mem_hid{i}, spk_hid{i})\n'

    if is_last_hidden and i >= 3:
        # Only the last hidden layer adds the patched spikes of the layer two rows up;
        # the slicing depends on which of the two layers is wider.
        cur_width = len(dianet_neuron_array[i])
        skip_width = len(dianet_neuron_array[i-2])
        if cur_width > skip_width:
            hid_with_patch = f'{tab_pad}spk_hid{i}_pth = torch.cat((spk_hid{i}[:, 0:1], spk_hid{i-2}_pth + spk_hid{i}[:,1:{cur_width-1}], spk_hid{i}[:, {cur_width-1}:{cur_width}]), dim=1)\n'
        elif cur_width == skip_width:
            hid_with_patch = f'{tab_pad}spk_hid{i}_pth = spk_hid{i-2}_pth + spk_hid{i}\n'
        else:
            hid_with_patch = f'{tab_pad}spk_hid{i}_pth = spk_hid{i-2}_pth[:,1:{skip_width-1}] + spk_hid{i}\n'
    else:
        # All other hidden layers pass their spikes on unchanged.
        hid_with_patch = f'{tab_pad}spk_hid{i}_pth = spk_hid{i}\n'
    dianet_iterate += hid_with_patch

    # Assemble the input of the next layer from this step's spikes and the input taps.
    next_input_name = 'spk_to_out' if is_last_hidden else f'spk_to_hid{i+1}'
    netlist_str = _netlist_source(dianet_neuron_feature[i+1])
    dianet_iterate += f'{tab_pad}{next_input_name} = torch.cat(({netlist_str}), 1)\n\n'

dianet_iterate += '#auto_gen'



# Source lines that append this step's spikes / membrane values to the record lists.
tab_pad = 4 * ' ' * 3

# Hidden layers first (hid1 .. hidN), the output layer last.
recorded_layers = [f'hid{layer_idx}' for layer_idx in range(1, hid_layer_num + 1)] + ['out']

save_spk_mem_rec = ''
for layer_tag in recorded_layers:
    save_spk_mem_rec += f'{tab_pad}self.spk_{layer_tag}_rec.append(spk_{layer_tag})\n'
    save_spk_mem_rec += f'{tab_pad}self.mem_{layer_tag}_rec.append(mem_{layer_tag})\n'
save_spk_mem_rec += '#auto_gen'



# Source lines of the print statements used by the generated print_spk_info() / print_mem_info().
tab_pad = 4 * ' ' * 4

# Hidden layers first (hid1 .. hidN), the output layer last.
logged_layers = [f'hid{layer_idx}' for layer_idx in range(1, hid_layer_num + 1)] + ['out']

spk_log_lines = []
mem_log_lines = []
for layer_tag in logged_layers:
    spk_log_lines.append(
        f'{tab_pad}print(" spk_{layer_tag}:\t|", [float(num) for num in self.spk_{layer_tag}_rec[step][batch].tolist()])\n'
    )
    mem_log_lines.append(
        f'{tab_pad}print(" mem_{layer_tag}:\t|", [float(num) for num in self.mem_{layer_tag}_rec[step][batch].tolist()])\n'
    )

print_spk_log = ''.join(spk_log_lines) + '#auto_gen'
print_mem_log = ''.join(mem_log_lines) + '#auto_gen'

#____________________________________________cell body_______________________________________
# Source template of the generated model cell: the surrogate-gradient activation (Actfun),
# the membrane update helper and the Dianet class. The placeholders are filled with the
# strings generated above:
#   {create_fc_nrm_msk}    -> layer / batch-norm / mask definitions in __init__
#   {w_dwn} ... {b_up}     -> default arguments of forward()
#   {init_neuron_mem_spk}  -> initial mem/spk state
#   {create_log_rec}       -> record lists
#   {dianet_iterate}       -> per-time-step propagation through all layers
#   {save_spk_mem_rec}     -> recording of this step's state
#   {print_spk_log} / {print_mem_log} -> bodies of the two print helpers
# Doubled braces ({{batch}}, {{step}}) survive as single braces in the generated f-strings.
create_dianet_cell = f'''

class Actfun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.gt(thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        #temp = abs(input - thresh) < lens
        #temp=1/(1+torch.exp(-(input-0)))
        temp = torch.exp( -(input - thresh) **2/(2 * lens ** 2) ) / ((2 * lens * 3.141592653589793) ** 0.5) 
        return grad_input * temp.float()

actfun = Actfun.apply
# membrane potential update
def mem_update(x, mem, spike):
    mem = mem *(1. - spike) + x
    spike = actfun(mem) # actfun : approximation firing function
    return mem, spike


# Create Dianet

class Dianet(nn.Module):
    def __init__(self):
        super(Dianet, self).__init__()
        
        #mnist所需的卷积操作, 压缩输入维度
        self.conv1 = nn.Conv2d(1, 16, 3, 1, 0)
        self.bnconv1 = nn.BatchNorm2d(16)
        self.maxpool1 = nn.MaxPool2d(2,stride=2)
        self.conv2 = nn.Conv2d(16, 16, 5)
        self.bnconv2 = nn.BatchNorm2d(16)
        self.maxpool2 = nn.MaxPool2d(2,stride=2)
        
        #上一层有几个神经元+输入抽头+跳层输入, mask里的每一项就有几个
        #本层有几个神经元, mask里就有几项
{create_fc_nrm_msk}
        

    
    def forward(self, pre_x, w_dwn = {w_dwn}, w_up = {w_up}, b_dwn = {b_dwn}, b_up = {b_up}):

        # Initialize hidden states at t=0 时间相关的神经元, 每轮模拟开始时初始化LIF状态
{init_neuron_mem_spk}


        # Record the final layer 收集脉冲和膜电位信息
{create_log_rec}



        for step in range(step_num):
        # input → hid1    
            spk_to_conv1 = pre_x > torch.rand(pre_x.size(), device=device)
            cur_to_conv1 = self.bnconv1(self.conv1(spk_to_conv1.float()))
            mem_conv1, spk_conv1 = mem_update(cur_to_conv1, mem_conv1, spk_conv1)
            spk_to_conv2 = self.maxpool1(spk_conv1)
            cur_to_conv2 = self.bnconv2(self.conv2(spk_to_conv2))
            mem_conv2, spk_conv2 = mem_update(cur_to_conv2, mem_conv2, spk_conv2)
            convout = self.maxpool2(spk_conv2) 
            x = convout.view(convout.size()[0], -1)

            #本质上是模拟输入, 并非真实的脉冲
{dianet_iterate}
            
            #保存信息
{save_spk_mem_rec}
            
            
            spksum_out += mem_out
        spk_out_avg = spksum_out / step_num
        return spk_out_avg






    def print_spk_info(self):
        print("Spike at each time step:")
        for batch in range(0, 1):
            print(f"Batch num {{batch}}:")
            for step in range(step_num):
                print(f"Time step {{step}}:")
{print_spk_log}



    def print_mem_info(self):
        print("Mem at each time step:")
        for batch in range(0, 1):
            print(f"Batch num {{batch}}:")
            for step in range(step_num):
                print(f"Time step {{step}}:")
{print_mem_log}

                
                
        
'''
#____________________________________________________________________________________________



In [29]:

#____________________________________________cell body_______________________________________
# Source of the generated cell that instantiates the network on `device` and creates
# the loss function (`criteon`) and the Adam optimizer used by the training cell.
load_snn_cell = '''
# Load the network onto device
net = Dianet().to(device)

criteon = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=initial_lr, weight_decay=0.00001,betas=(0.9, 0.99))
#optimizer = optim.Adam(net.parameters(), lr=initial_lr,weight_decay=0.00001,betas=(0.9, 0.99))
'''
#____________________________________________________________________________________________




In [30]:

#____________________________________________cell body_______________________________________
# Source of the generated training cell: step-wise learning-rate decay, one optimisation
# pass over train_loader per epoch, then a validation pass over val_loader.
# Fixes relative to the previous template:
#   * adjust_learning_rate() now uses its own argument (it used to read the global `epoch`)
#     and its docstring matches the 0.5 decay factor;
#   * the progress percentage uses 100. instead of 20.;
#   * validation runs under torch.no_grad();
#   * test_loss_hist records the loss of each validation batch instead of the running sum;
#   * the average loss divides by the number of batches, because nn.CrossEntropyLoss()
#     already averages over the samples of a batch;
#   * the closing time stamp is labelled "End time".
# Backslash escapes are doubled so that they reach the generated source as escapes.
training_cell = '''
# Training

def adjust_learning_rate(epoch):
    """Sets the learning rate to the initial LR halved every 30 epochs"""
    lr = initial_lr * (0.5 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

loss_hist = []
test_loss_hist = []

for epoch in range(epoch_num):
    adjust_learning_rate(epoch)
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        net.train()
        logits = net(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_hist.append(loss.item())

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))

    test_loss = 0
    correct = 0
    net.eval()
    with torch.no_grad():
        for data, target in val_loader:
            data = data.to(device)
            target = target.to(device)
            logits = net(data)
            batch_loss = criteon(logits, target).item()
            test_loss += batch_loss
            test_loss_hist.append(batch_loss)

            pred = logits.max(1)[1]
            correct += pred.eq(target).sum().item()

    test_loss /= len(val_loader)
    print('\\nVAL set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(
        test_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))


# 获取当前时间
current_time = datetime.now()
print('End time: ' + str(current_time))


'''
#____________________________________________________________________________________________



In [31]:
#____________________________________________cell body_______________________________________
# Source of the generated plotting cell: draws loss_hist and test_loss_hist on one figure.
# NOTE(review): test_loss_hist is filled from val_loader by the training cell, so the
# "Test Loss" legend entry actually shows validation loss.
draw_loss_cell = '''
# Plot Loss
fig = plt.figure(facecolor="w", figsize=(10, 5))
plt.plot(torch.tensor(loss_hist).cpu())
plt.plot(torch.tensor(test_loss_hist).cpu())
plt.title("Loss Curves")
plt.legend(["Train Loss", "Test Loss"])
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()
'''
#____________________________________________________________________________________________



In [32]:
#____________________________________________cell body_______________________________________
# Source of the generated evaluation cell: runs the trained network over test_loader and
# prints the average loss and the accuracy.
# Fixes relative to the previous template:
#   * the network is explicitly switched to eval mode and evaluated under torch.no_grad(),
#     so the cell no longer depends on the mode left behind by the training cell;
#   * the average loss divides by the number of batches, because nn.CrossEntropyLoss()
#     already averages over the samples of a batch;
#   * the summary is labelled "TEST set" (it used to say "VAL set");
#   * the unused `total`, the duplicated `correct = 0` and stale commented-out lines that
#     referenced undefined names were dropped.
run_testset_cell = '''
# Accuracy
test_loss = 0
correct = 0

net.eval()
with torch.no_grad():
    for data, target in test_loader:
        data = data.to(device)
        target = target.to(device)
        logits = net(data)
        test_loss += criteon(logits, target).item()

        pred = logits.max(1)[1]
        correct += pred.eq(target).sum().item()

        #net.print_spk_info()
        net.print_mem_info()


test_loss /= len(test_loader)
print('\\nTEST set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

'''
#____________________________________________________________________________________________



In [33]:
# Source of the final generated cell: prints weight and bias of every fc layer.
# Start at neuron row 1: row 0 is the pure input row and has no fc layer.
print_layer_info = ''.join(
    f'print(net.fc{layer_idx}.weight.data)\nprint(net.fc{layer_idx}.bias.data)\n\n'
    for layer_idx in range(1, len(dianet_neuron_array))
)

#____________________________________________cell body_______________________________________
# Two newlines of padding on either side of the print statements.
other_cell = '\n\n' + print_layer_info + '\n\n'

#____________________________________________________________________________________________



In [None]:
# Source of the generated checkpoint cell: saves the state_dict and the full pickled model,
# both prefixed with a YYYYMMDDHHMM time stamp taken when the generated cell runs.
# NOTE(review): this cell carries no execution count in the saved notebook, yet
# save_model_cell is used when the notebook is assembled — make sure it is run first.
save_model_cell = f'''

# Save the model

# 获取当前时间, 转换为指定的形式(如202307311453)
current_time = datetime.now()
formatted_datetime = current_time.strftime("%Y%m%d%H%M")

torch.save(net.state_dict(), formatted_datetime + '_snndianet_model_checkpoint.pth')
torch.save(net, formatted_datetime + '_snndianet_full_model_checkpoint.pth')

'''



In [34]:
# Cell sources in the order they appear in the generated notebook.
generated_cell_sources = (
    improt_cell,            # packages required by the program
    parameter_cell,         # configuration of the program
    organize_dataset_cell,  # dataset handling
    create_dianet_cell,     # the automatically generated Dianet
    load_snn_cell,          # load the model onto the device
    training_cell,          # model training
    save_model_cell,        # save the trained model
    draw_loss_cell,         # loss plot
    run_testset_cell,       # evaluation on the test set
    other_cell,             # remaining output (layer weights / biases)
)

# Every source string starts with a newline (right after its opening quotes); [1:] drops it.
for cell_source in generated_cell_sources:
    nb.cells.append(new_code_cell(source=cell_source[1:]))

# Save the notebook as an .ipynb file.
nbf.write(nb, filename)