In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import prune


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# model = LeNet()
model = LeNet().to(device=device)
parameters_to_prune = (
 (model.conv1, 'weight'),
 (model.conv2, 'weight'),
 (model.fc1, 'weight'),
 (model.fc2, 'weight'),
 (model.fc3, 'weight'),
)
print(parameters_to_prune)
prune.global_unstructured(
 parameters_to_prune,
 pruning_method=prune.L1Unstructured,
    amount=0.2,
)
# 计算卷积层和整个模型的稀疏度
# 其实调用的是 Tensor.numel 内内函数，返回输入张量中元素的总数
print(
 "Sparsity in conv1.weight: {:.2f}%".format(
 100. * float(torch.sum(model.conv1.weight == 0))
 / float(model.conv1.weight.nelement())
 )
)
zero_number = 0
total_bumber = 0
for name, m in model.named_modules():
    if isinstance(m,torch.nn.Linear) or isinstance(m,torch.nn.Conv2d):
        zero_number = zero_number + torch.sum(m.weight==0)
        total_bumber = total_bumber + m.weight.numel()

print(
 "Global sparsity: {:.2f}%".format(
 100. * (zero_number/total_bumber)
 )
)

((Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1)), 'weight'), (Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1)), 'weight'), (Linear(in_features=400, out_features=120, bias=True), 'weight'), (Linear(in_features=120, out_features=84, bias=True), 'weight'), (Linear(in_features=84, out_features=10, bias=True), 'weight'))
Sparsity in conv1.weight: 1.85%
Global sparsity: 20.00%


In [22]:
for name, m in model.named_modules():
    if isinstance(m,torch.nn.Linear) or isinstance(m,torch.nn.Conv2d):
        if hasattr(m,"weight_mask"):
            print(m)
            m.weight.data = m.weight_orig

prune.global_unstructured(
 parameters_to_prune,
 pruning_method=prune.L1Unstructured,
    amount=0.3,
)
# 计算卷积层和整个模型的稀疏度
# 其实调用的是 Tensor.numel 内内函数，返回输入张量中元素的总数
print(
 "Sparsity in conv1.weight: {:.2f}%".format(
 100. * float(torch.sum(model.conv1.weight == 0))
 / float(model.conv1.weight.nelement())
 )
)
zero_number = 0
total_bumber = 0
for name, m in model.named_modules():
    if isinstance(m,torch.nn.Linear) or isinstance(m,torch.nn.Conv2d):
        zero_number = zero_number + torch.sum(m.weight==0)
        total_bumber = total_bumber + m.weight.numel()

print(
 "Global sparsity: {:.2f}%".format(
 100. * (zero_number/total_bumber)
 )
)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
Linear(in_features=400, out_features=120, bias=True)
Linear(in_features=120, out_features=84, bias=True)
Linear(in_features=84, out_features=10, bias=True)
Sparsity in conv1.weight: 5.56%
Global sparsity: 44.00%


In [1]:
import torch
from spike_quan_layer_IntegerOnly import PowerNormQuan,QInferAttention,QAttention,PowerNormInfer,LsqQuan
from PowerNorm import MaskPowerNorm

PN = MaskPowerNorm(num_features=4)
x = torch.rand((2,3,4))*4
x_quan_fn = LsqQuan(bit=4,per_channel=False)
x_quan_fn.init_from(x)
PN_quan_fn = LsqQuan(bit=4,per_channel=False)

x_quan = x_quan_fn(x)

print(x_quan_fn)
print(x_quan)

PNQuan = PowerNormQuan(m=PN,quan_fn=PN_quan_fn)
output1 = PNQuan(x_quan)

output1 = PNQuan(x_quan)

PNQuanInfer = PowerNormInfer(m=PNQuan,last_act_quan=x_quan_fn)
print(PNQuanInfer.N1,PNQuanInfer.M1)
output2 = PNQuanInfer(x_quan/x_quan_fn.s)

print(output1/PNQuan.quan_fn.s)
print(output2)


LsqQuan(thd_pos=7, thd_neg=-8, s=1.620495319366455, per_channel=False)
tensor([[[1.6205, 3.2410, 1.6205, 3.2410],
         [1.6205, 3.2410, 0.0000, 3.2410],
         [1.6205, 1.6205, 1.6205, 3.2410]],

        [[3.2410, 1.6205, 0.0000, 3.2410],
         [0.0000, 0.0000, 3.2410, 1.6205],
         [3.2410, 1.6205, 3.2410, 1.6205]]], grad_fn=<MulBackward0>)
Parameter containing:
tensor([7, 7, 7, 7], dtype=torch.int32) Parameter containing:
tensor([135, 135, 135, 135], dtype=torch.int32)
tensor([[[1., 2., 1., 2.],
         [1., 2., 0., 2.],
         [1., 1., 1., 2.]],

        [[2., 1., 0., 2.],
         [0., 0., 2., 1.],
         [2., 1., 2., 1.]]], grad_fn=<DivBackward0>)
tensor([[[1., 2., 1., 2.],
         [1., 2., 0., 2.],
         [1., 1., 1., 2.]],

        [[2., 1., 0., 2.],
         [0., 0., 2., 1.],
         [2., 1., 2., 1.]]], grad_fn=<TransposeBackward0>)


  from .autonotebook import tqdm as notebook_tqdm


In [16]:
import torch
from spike_quan_layer_IntegerOnly import QInferAttention,QAttention,LsqQuan

torch.set_default_dtype(torch.double)
torch.set_default_tensor_type(torch.DoubleTensor)

while(1):
    x = torch.rand((1,3,4))*4
    quan_x = LsqQuan(bit=4,per_channel=False)
    quan_x.init_from(x)
    x = quan_x(x)

    quan_qkv_weight = LsqQuan(bit=4,per_channel=False)
    quan_proj_weight = LsqQuan(bit=4,per_channel=False)
    quan_q = LsqQuan(bit=4,per_channel=False)
    quan_k = LsqQuan(bit=4,per_channel=False)
    quan_v = LsqQuan(bit=4,per_channel=False)
    quan_proj = LsqQuan(bit=4,per_channel=False)
    attn_quan = LsqQuan(bit=4,per_channel=False)
    after_attn_quan = LsqQuan(bit=4,per_channel=False)

    QAtten = QAttention(dim=4,quan_qkv_weight=quan_qkv_weight,quan_proj_weight=quan_proj_weight,quan_q=quan_q,quan_k=quan_k,quan_v=quan_v,quan_proj=quan_proj,attn_quan=attn_quan,after_attn_quan=after_attn_quan,num_heads=2)

    # initialize quantize attention
    output = QAtten(x)
    # print("quan_qkv_weight",QAtten.quan_qkv_weight)
    # print("quan_proj_weight",QAtten.quan_proj_weight)
    # print("quan_q",QAtten.quan_q)
    # print("quan_k",QAtten.quan_k)
    # print("quan_v",QAtten.quan_v)
    # print("quan_proj",QAtten.quan_proj)
    # print("attn_quan",QAtten.attn_quan)
    # print("after_attn_quan",QAtten.after_attn_quan)

    QAttenInfer = QInferAttention(m=QAtten, last_act_quan=quan_x)

    # Compare output:
    output1 = QAtten(x)/QAtten.quan_proj.s
    output2 = QAttenInfer(x/quan_x.s)

    if torch.abs(QAtten.q - QAttenInfer.q).sum() >= 1:
        print("quan_x",quan_x)
        print("quan_qkv_weight",QAtten.quan_qkv_weight)
        print("quan_proj_weight",QAtten.quan_proj_weight)
        print("scale",QAtten.scale)
        print("quan_q",QAtten.quan_q)
        print("quan_k",QAtten.quan_k)
        print("quan_v",QAtten.quan_v)
        print("quan_proj",QAtten.quan_proj)
        print("attn_quan",QAtten.attn_quan)
        print("after_attn_quan",QAtten.after_attn_quan)
        print(torch.abs(QAtten.q - QAttenInfer.q).sum())
        print("=====================QAtten.q, QAttenInfer.q======================")
        print(QAttenInfer.neuron_q_M,QAttenInfer.neuron_q_N)
        print(QAtten.q1/quan_q.s)
        print("==================================================================")
        print(QAttenInfer.q1)
        print(QAttenInfer.q1*quan_x.s*QAtten.quan_qkv_weight.s*QAtten.scale/quan_q.s)
        print(QAttenInfer.q1*QAttenInfer.neuron_q_M/(2**QAttenInfer.neuron_q_N))
        print("==================================================================")
        print(QAtten.q)
        print(QAttenInfer.q)
        print("==================================================================")
        print(torch.abs(QAtten.k - QAttenInfer.k).sum())
        print(torch.abs(QAtten.v - QAttenInfer.v).sum())
        print(torch.abs(QAtten.attn - QAttenInfer.attn).sum())
        print(torch.abs(output1 - output2).sum())
        break



quan_x LsqQuan(thd_pos=7, thd_neg=-8, s=1.4090669638576874, per_channel=False)
quan_qkv_weight LsqQuan(thd_pos=7, thd_neg=-8, s=0.19660905948470428, per_channel=False)
quan_proj_weight LsqQuan(thd_pos=7, thd_neg=-8, s=0.17160493360573897, per_channel=False)
scale 0.7071067811865476
quan_q LsqQuan(thd_pos=7, thd_neg=-8, s=1.0474673340955825, per_channel=False)
quan_k LsqQuan(thd_pos=7, thd_neg=-8, s=0.9033111870689389, per_channel=False)
quan_v LsqQuan(thd_pos=7, thd_neg=-8, s=0.8982452689186882, per_channel=False)
quan_proj LsqQuan(thd_pos=7, thd_neg=-8, s=0.30580500484626777, per_channel=False)
attn_quan LsqQuan(thd_pos=7, thd_neg=-8, s=0.3828836410068537, per_channel=False)
after_attn_quan LsqQuan(thd_pos=7, thd_neg=-8, s=0.3365168127666457, per_channel=False)
tensor(2., grad_fn=<SumBackward0>)
Parameter containing:
tensor(192) Parameter containing:
tensor(10)
tensor([[[[ 1.1221, -1.4961],
          [ 0.7481, -1.3091],
          [ 1.1221, -0.1870]],

         [[ 0.9351,  1.1221],
   

In [4]:
import torch
from spike_quan_layer_IntegerOnly import AdditionInfer,AdditionQuan,LsqQuan

x1 = torch.rand((1,16,16))*4
x2 = torch.rand((1,16,16))*4

x1_quan_fn = LsqQuan(bit=4,per_channel=False)
x2_quan_fn = LsqQuan(bit=4,per_channel=False)

x1_quan_fn.init_from(x1)
x2_quan_fn.init_from(x2)

x1_quan = x1_quan_fn(x1)
x2_quan = x2_quan_fn(x2)

AddQuan = AdditionQuan(quan_a_fn=LsqQuan(bit=4,per_channel=False))

output1 = AddQuan(x1_quan,x2_quan)
output1 = AddQuan(x1_quan,x2_quan)/AddQuan.quan_a_fn.s

AddInfer = AdditionInfer(m=AddQuan,quan_input1_fn=x1_quan_fn,quan_input2_fn=x2_quan_fn)
output2 = AddInfer(x1_quan/x1_quan_fn.s,x2_quan/x2_quan_fn.s)

print(output1==output2)



tensor([[[True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True, True,
          True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, 

In [2]:
from spike_quan_layer import MyLayerNorm
import torch

LN = torch.nn.LayerNorm(normalized_shape=4,elementwise_affine=True,eps=1e-6).eval()
MyLN = MyLayerNorm(dim=4,eps=1e-6).eval()

LN.weight.data = torch.rand(LN.weight.data.shape)
LN.bias.data = torch.rand(LN.bias.data.shape)
MyLN.weight.data = LN.weight.data
MyLN.bias.data = LN.bias.data
print(LN.weight)
print(MyLN.weight)
print(LN.bias)
print(MyLN.bias)

x = torch.rand((1,4,4))
print(x)
output = LN(x)
output1 = MyLN(x)

print(output)
print(output1)



Parameter containing:
tensor([0.2620, 0.8055, 0.6904, 0.4581], requires_grad=True)
Parameter containing:
tensor([0.2620, 0.8055, 0.6904, 0.4581], requires_grad=True)
Parameter containing:
tensor([0.0708, 0.6759, 0.6418, 0.8788], requires_grad=True)
Parameter containing:
tensor([0.0708, 0.6759, 0.6418, 0.8788], requires_grad=True)
tensor([[[0.2103, 0.5787, 0.7762, 0.5156],
         [0.0825, 0.4825, 0.6976, 0.4493],
         [0.9010, 0.0231, 0.0402, 0.4332],
         [0.5531, 0.5815, 0.2986, 0.5273]]])
tensor([[[-0.3289,  0.9079,  1.5119,  0.8684],
         [-0.3386,  0.8745,  1.4838,  0.9230],
         [ 0.4742, -0.0577,  0.0459,  0.9860],
         [ 0.2178,  1.3318, -0.5364,  1.0305]]],
       grad_fn=<NativeLayerNormBackward0>)
tensor([[[-0.2754,  0.8768,  1.3953,  0.8698],
         [-0.2838,  0.8479,  1.3710,  0.9171],
         [ 0.4202,  0.0406,  0.1257,  0.9716],
         [ 0.1981,  1.2439, -0.3786,  1.0102]]], grad_fn=<AddBackward0>)


In [5]:
import torch
from thop import profile
from thop import clever_format
from models_vit import vit_small_patch16
from torchvision import models

vit_small = vit_small_patch16()
vit_small.eval()

input = torch.randn(1, 3, 224, 224)  # 随机生成一个输入张量，这个尺寸应该与模型输入的尺寸相匹配
flops, params = profile(vit_small, inputs=(input,))

# 将结果转换为更易于阅读的格式
flops, params = clever_format([flops, params], '%.3f')

print(f"运算量：{flops}, 参数量：{params}")

vit_small = models.resnet34()
vit_small.eval()

input = torch.randn(1, 3, 224, 224)  # 随机生成一个输入张量，这个尺寸应该与模型输入的尺寸相匹配
flops, params = profile(vit_small, inputs=(input,))

# 将结果转换为更易于阅读的格式
flops, params = clever_format([flops, params], '%.3f')

print(f"运算量：{flops}, 参数量：{params}")



self.global_pool False
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
运算量：4.249G, 参数量：21.975M
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.module

In [6]:
import torch
import torchvision
from models_vit import vit_small_patch16

vit_small = vit_small_patch16()
vit_small.eval()

print(vit_small)

input_shape = {}
output_shape = {}

def layer_hook(module, inp, out):
    global input_shape, output_shape
    input_shape[module] = inp[0].shape
    output_shape[module] = out.shape

for name,module in list(vit_small.named_modules()):
    if isinstance(module,torch.nn.Conv2d):
        module.register_forward_hook(layer_hook)
    if isinstance(module,torch.nn.Linear):
        module.register_forward_hook(layer_hook)

input = torch.rand((1,3,224,224))

output = vit_small(input)
Head = 6

idx = 0
for name,module in list(vit_small.named_modules()):
    print(name)
    if isinstance(module,torch.nn.Linear) and name.count("qkv") > 0:
        for i in range(3):
            file = open(f"forEyeriss/{str(idx).rjust(2,'0')}.yaml","w+")
            file.write("{{include_text('../problem_base.yaml')}}\n")
            file.write("problem:\n")
            file.write("  <<<: *problem_base\n")
            inputShape = input_shape[module]
            outputShape = output_shape[module]
            file.write(f"  instance: {{C: {inputShape[-1]}, M: {outputShape[-1]//3}, P: {outputShape[-2]}}}\n")
            idx = idx + 1
            file.close()
        file = open(f"forEyeriss/{str(idx).rjust(2,'0')}.yaml","w+")
        file.write("{{include_text('../problem_base.yaml')}}\n")
        file.write("problem:\n")
        file.write("  <<<: *problem_base\n")
        inputShape = input_shape[module]
        outputShape = output_shape[module]
        file.write(f"  instance: {{M: {inputShape[-1]//Head}, C: {outputShape[-2]}, P: {Head}}}\n")
        idx = idx + 1
        file.close()

        file = open(f"forEyeriss/{str(idx).rjust(2,'0')}.yaml","w+")
        file.write("{{include_text('../problem_base.yaml')}}\n")
        file.write("problem:\n")
        file.write("  <<<: *problem_base\n")
        inputShape = input_shape[module]
        outputShape = output_shape[module]
        file.write(f"  instance: {{M: {outputShape[-2]}, C: {outputShape[-2]}, P: {Head}}}\n")
        idx = idx + 1
        file.close()
    elif isinstance(module,torch.nn.Linear):
        file = open(f"forEyeriss/{str(idx).rjust(2,'0')}.yaml","w+")
        file.write("{{include_text('../problem_base.yaml')}}\n")
        file.write("problem:\n")
        file.write("  <<<: *problem_base\n")
        inputShape = input_shape[module]
        outputShape = output_shape[module]
        file.write(f"  instance: {{C: {inputShape[-1]}, M: {outputShape[-1]}, P: {outputShape[-2]}}}\n")
        idx = idx + 1
        file.close()
    elif isinstance(module,torch.nn.Conv2d):
        file = open(f"forEyeriss/{str(idx).rjust(2,'0')}.yaml","w+")
        file.write("{{include_text('../problem_base.yaml')}}\n")
        file.write("problem:\n")
        file.write("  <<<: *problem_base\n")
        outputShape = output_shape[module]
        file.write(f"  instance: {{C: {module.in_channels}, M: {module.out_channels}, P: {outputShape[-1]}, Q: {outputShape[-2]}, R: {module.kernel_size[0]}, S: {module.kernel_size[1]}, HStride: {module.stride[0]}, WStride: {module.stride[0]}}}\n")
        idx = idx + 1
        file.close()
    # if isinstance(module,torch.nn.Linear):
    #     file = open(f"forEyeriss/{str(idx).rjust(2,'0')}.yaml","w+")
    #     file.write("{{include_text('../problem_base.yaml')}}\n")
    #     file.write("problem:\n")
    #     file.write("  <<<: *problem_base\n")
    #     inputShape = input_shape[module]
    #     outputShape = output_shape[module]
    #     file.write(f"  instance: {{C: {inputShape[-1]}, M: {outputShape[-1]}}}\n")
    #     idx = idx + 1
    #     file.close()
    

self.global_pool False
VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
  

In [None]:
import torch
from copy import deepcopy
import torchvision
from models_vit import vit_small_patch16

vit_small = vit_small_patch16()

communication_traffic = 0.0

def hook_input_feature_size(module,input,output):
    global communication_traffic
    if module.in_features == 384 and module.out_features == 384:
    communication_traffic = communication_traffic + torch.prod(torch.tensor(input[0].shape))*8
    print(module,communication_traffic)

for module in vit_small.modules():
    if isinstance(module,torch.nn.Linear) or isinstance(module,torch.nn.Conv2d) or isinstance(module,torch.nn.MaxPool2d) or isinstance(module,torch.nn.AdaptiveAvgPool2d):
        module.register_forward_hook(hook_input_feature_size)

x = torch.rand(1,3,224,224)

out = vit_small(x)
print("communication_traffic",(communication_traffic/(1024*1024*8)).item(),"MB")


self.global_pool False
Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16)) tensor(1204224.)
Linear(in_features=384, out_features=1152, bias=True) tensor(1809408.)
Linear(in_features=384, out_features=384, bias=True) tensor(2414592.)
Linear(in_features=384, out_features=1536, bias=True) tensor(3019776.)
Linear(in_features=1536, out_features=384, bias=True) tensor(5440512.)
Linear(in_features=384, out_features=1152, bias=True) tensor(6045696.)
Linear(in_features=384, out_features=384, bias=True) tensor(6650880.)
Linear(in_features=384, out_features=1536, bias=True) tensor(7256064.)
Linear(in_features=1536, out_features=384, bias=True) tensor(9676800.)
Linear(in_features=384, out_features=1152, bias=True) tensor(10281984.)
Linear(in_features=384, out_features=384, bias=True) tensor(10887168.)
Linear(in_features=384, out_features=1536, bias=True) tensor(11492352.)
Linear(in_features=1536, out_features=384, bias=True) tensor(13913088.)
Linear(in_features=384, out_features=1152, bias=True)

In [None]:
import torch
from spike_quan_layer import spiking_softmax

x_list = []
x_accu = 0.0
T = 4
for i in range(T):
    x_list.append(torch.rand(2,12))
    x_accu = x_accu + x_list[-1]
ssoftmax = spiking_softmax(T=4)

y1 = torch.nn.functional.softmax(dim=-1)
y2 = 0.0
for i in range(T):
    y2 = y2 + ssoftmax(x_list[i])


In [None]:
import torch

QANN = torch.load("QANN_PATCH_BeforeNorm.pth")
SNN = torch.load("SNN_PATCH_BeforeNorm.pth").reshape(24,4,128,56,56)

print(QANN.abs().mean())


print(SNN.sum(dim=0).abs().mean())



In [None]:
import torch
from spike_quan_layer import DyHT, MyQuan


x = torch.rand(size=(1,8,1)).cuda()
mydyht = DyHT(C=1, init_alpha=0.25).cuda()
myquan = MyQuan(level=8, sym=True).cuda()
myquan.init_state = myquan.batch_init
myquan.s = mydyht.gamma/myquan.pos_max
print(mydyht.gamma, mydyht.alpha, myquan.s)
y1 = mydyht(x)
y2 = myquan(x * mydyht.gamma * mydyht.alpha)

print(y1,y2)


  self.alpha = nn.Parameter(torch.tensor(torch.ones(1) * init_alpha))
  self.gamma = nn.Parameter(torch.tensor(torch.ones(C)))


level 8
Parameter containing:
tensor([1.], device='cuda:0', requires_grad=True) Parameter containing:
tensor([0.2500], device='cuda:0', requires_grad=True) Parameter containing:
tensor(1., device='cuda:0', requires_grad=True)
myquan:s_scale= tensor(1., device='cuda:0', grad_fn=<AddBackward0>)
myquan:x/s_scale= tensor([[[0.7803],
         [0.9474],
         [1.4505],
         [0.7960],
         [1.4146],
         [0.7544],
         [1.4963],
         [0.6601]]], device='cuda:0', grad_fn=<AddBackward0>)
tensor([[[0.0701],
         [0.1118],
         [0.2376],
         [0.0740],
         [0.2286],
         [0.0636],
         [0.2491],
         [0.0400]]], device='cuda:0', grad_fn=<MulBackward0>) tensor([[[0.],
         [0.],
         [1.],
         [0.],
         [1.],
         [0.],
         [1.],
         [0.]]], device='cuda:0', grad_fn=<MulBackward0>)
