In [None]:
import torch
import torch.nn as nn
from contextlib import contextmanager
import time

# Timing helper
@contextmanager
def timed_block(label, loop=1):
    """Time the body of a ``with`` statement and print the elapsed wall-clock time.

    Args:
        label: Text printed in front of the measured time.
        loop: Number of iterations the body performs. The printed value is the
            total elapsed time divided by ``loop`` (a per-iteration average).
            Must be positive.

    Raises:
        ValueError: if ``loop`` is not positive. This is checked before timing
            starts, so it can never replace an exception raised by the body.

    The time is printed from a ``finally`` clause, so it is reported even when
    the body raises.
    """
    if loop <= 0:
        raise ValueError(f"loop must be positive, got {loop!r}")
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        print(f"{label} : {elapsed/loop:.6f}s")

In [None]:
# Model definitions
# Single linear layer
class LinearModel(nn.Module):
    """A single fully connected layer: ``y = x @ W.T + b``.

    Args:
        in_features: Size of the last dimension of the input (default 2048).
        out_features: Size of the last dimension of the output (default 2048).

    The defaults give a 2048 -> 2048 layer. The sizes are parameters so the
    layer can be sized the same way ``PositionWiseFFN`` is.
    """

    def __init__(self, in_features=2048, out_features=2048):
        super().__init__()
        # Parameters use torch's default floating dtype (float32 unless changed globally).
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the linear layer to the last dimension of ``x``."""
        return self.linear(x)

In [None]:
# Position-wise feed-forward network
class PositionWiseFFN(nn.Module):
    """Two-layer feed-forward block: ``dense1 -> ReLU -> dense2``.

    ``nn.Linear`` acts on the last dimension, so for a
    ``(batch, seq, ffn_num_input)`` input every position is transformed
    independently.

    Args:
        ffn_num_input: Input feature size of ``dense1``.
        ffn_num_hiddens: Output size of ``dense1`` and input size of ``dense2``.
        ffn_num_outputs: Output feature size of ``dense2``.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, ffn_num_input=2048, ffn_num_hiddens=2048, ffn_num_outputs=2048, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        """Return ``dense2(relu(dense1(X)))``."""
        return self.dense2(self.relu(self.dense1(X)))

    def createdata(self, batch=128, seq=512, hidden=None, dtype=torch.float32, device=None):
        """Create a standard-normal random tensor of shape ``(batch, seq, hidden)``.

        Args:
            batch: Size of the first dimension.
            seq: Size of the second dimension.
            hidden: Size of the last dimension. ``None`` (the default) uses this
                model's input size (``dense1.in_features``), so the tensor always
                matches what ``forward`` expects.
            dtype: dtype of the returned tensor.
            device: Device to allocate on; ``None`` uses torch's default device.
        """
        if hidden is None:
            hidden = self.dense1.in_features
        return torch.randn((batch, seq, hidden), dtype=dtype, device=device)

In [None]:
# Inference benchmark: time each stage of building PositionWiseFFN and running a forward pass.

# --- Configuration ---
dtype = torch.float32              # dtype of the model weights (via model.half()) and of the input data
is_compression = True              # apply dynamic quantization to the model's nn.Linear layers
# With its default qconfig, quantize_dynamic only swaps nn.Linear for torch.qint8 or
# torch.float16. torch.quint8 targets Embedding/EmbeddingBag, which this model does not
# have, so quint8 would silently leave the model un-quantized.
compression_dtype = torch.qint8
device = torch.device("cpu")
num_epochs = 1                     # timed forward passes; the printed time is the per-pass average

# Dynamic-quantized Linear kernels expect float32 activations and run on CPU only,
# so reject configurations that cannot work before doing any expensive setup.
assert not (dtype == torch.float16 and is_compression), "semi-precision and compression can't be used together."
assert not (is_compression and device.type != "cpu"), "dynamic quantization only runs on CPU."


with timed_block("total"):
    with timed_block("create model"):
        model = PositionWiseFFN(ffn_num_input=2048, ffn_num_hiddens=2048, ffn_num_outputs=2048)

    with timed_block("model.half"):
        if dtype == torch.float16:
            model.half()

    with timed_block("model.quantization"):
        if is_compression:
            model = torch.quantization.quantize_dynamic(model, dtype=compression_dtype)
    # Printed outside the timed block so console I/O does not inflate the quantization time.
    print(model)
    # Fail loudly if compression was requested but no Linear layer was actually swapped.
    if is_compression and any(type(m) is nn.Linear for m in model.modules()):
        raise RuntimeError(f"quantize_dynamic left nn.Linear layers un-quantized for {compression_dtype}")

    with timed_block("model.to"):
        model.to(device)

    with timed_block("model.eval"):
        model.eval()

    with timed_block("create data"):
        data = model.createdata(batch=128, seq=512, hidden=2048, dtype=dtype)

    with timed_block("data.to"):
        data = data.to(device)

    # Inference: inference_mode disables autograd tracking, so no graph is recorded
    # and no activations are kept for a backward pass that never happens.
    with timed_block("inference avg time", num_epochs):
        with torch.inference_mode():
            for _ in range(num_epochs):
                outputs = model(data)

print("Inference {} {} Done!".format(device, compression_dtype if is_compression else dtype))
