In [1]:
from networks.LUTDeiT import LUT_DeiT, LUT_Distilled_DeiT, Attention2, create_target

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from timm import create_model
from operations.amm_linear import LUT_Linear
from qat.export.utils import replace_module_by_name, fetch_module_by_name
from torchinfo import summary
from thop import profile, clever_format
import torch.nn as nn


In [3]:
import time
import torch
import torch_tensorrt
import numpy as np
import torch.backends.cudnn as cudnn
cudnn.benchmark = True

def benchmark(model, input_shape=(1024, 3, 512, 512), dtype='fp32', nwarmup=50, nruns=1000, cuda=False):
    """Measure the average forward-pass latency and throughput of ``model``.

    A random tensor of shape ``input_shape`` is passed to the model
    ``nwarmup`` times without timing and then ``nruns`` times with timing.
    A running average is printed every 10 timed iterations, followed by the
    input shape and the average throughput.

    Args:
        model: Callable invoked as ``model(input_data)``. It is neither moved
            nor cast here, so it must already match the device and dtype of
            the generated input.
        input_shape: Shape of the random input. ``input_shape[0]`` is used as
            the number of images per batch in the throughput figure.
        dtype: ``'fp16'`` casts the input to half precision; any other value
            leaves the input as produced by ``torch.randn``.
        nwarmup: Number of untimed warm-up calls.
        nruns: Number of timed calls.
        cuda: When true, the input is moved to the ``"cuda"`` device.

    Returns:
        None. All results are printed.
    """
    input_data = torch.randn(input_shape)
    if cuda:
        input_data = input_data.to("cuda")
    if dtype == 'fp16':
        input_data = input_data.half()

    # CUDA kernels are launched asynchronously, so GPU runs must be
    # synchronised before the clock is read. CPU runs have nothing to wait
    # for, and torch.cuda.synchronize() raises when CUDA is unavailable.
    needs_sync = input_data.is_cuda

    print("Warm up ...")
    with torch.no_grad():
        for _ in range(nwarmup):
            model(input_data)
    if needs_sync:
        torch.cuda.synchronize()

    print("Start timing ...")
    timings = []
    with torch.no_grad():
        for i in range(1, nruns+1):
            # perf_counter is monotonic and has a finer resolution than time.time().
            start_time = time.perf_counter()
            model(input_data)
            if needs_sync:
                torch.cuda.synchronize()
            end_time = time.perf_counter()
            timings.append(end_time - start_time)
            if i%10==0:
                print('Iteration %d/%d, avg batch time %.2f ms'%(i, nruns, np.mean(timings)*1000))

    print("Input shape:", input_data.size())
    print('Average throughput: %.2f images/second'%(input_shape[0]/np.mean(timings)))

In [27]:
# Backbone under test. 'resmlp_12_224.fb_in1k' is a disabled alternative.
model_name = 'deit3_small_patch16_224.fb_in22k_ft_in1k'

model = create_model(model_name=model_name, pretrained=False)
model.eval()

subvec_len = 32

# Sub-module path inside each block -> number of codebooks for its LUT replacement.
# Disabled candidates kept for reference:
#   "attn.q_linear": 384 // subvec_len, "attn.k_linear": 384 // subvec_len,
#   "mlp.fc2": 1536 // subvec_len,
#   "linear_tokens": 196 // 14, "mlp_channels.fc1": 384 // subvec_len,
#   "mlp_channels.fc2": 1536 // subvec_len
ncodebooks = {
    "attn.qkv": 384 // subvec_len,
    "mlp.fc1": 384 // subvec_len,
}

# Swap the selected linear layers of the first 12 blocks for LUT_Linear layers
# with matching in/out features and bias setting (PQLinear is a disabled alternative).
# The original weights and bias are NOT copied into the replacement; the
# rearrange-and-copy step that would do so is disabled.
for block_idx in range(12):
    layer = model.blocks[block_idx]
    for name, n_books in ncodebooks.items():
        module = fetch_module_by_name(layer, name)
        has_bias = module.bias is not None
        amm_linear = LUT_Linear(
            n_books,
            module.in_features,
            module.out_features,
            has_bias,
            k=16,
        )
        replace_module_by_name(layer, name, amm_linear)

# Same object as `model`: the replacement above was done in place.
model_compressed = model

In [20]:
def count_lut_linear(module: nn.Module, input: torch.Tensor, output: torch.Tensor):
    """Op counter for a LUT-based linear layer, used as a ``thop`` custom op.

    Adds to ``module.total_ops`` the sum of:
      * multiplies: ``ncodebooks * rows * k * subvec_len``
      * adds:       ``ncodebooks * rows * k * (subvec_len - 1)``
      * LUT adds:   ``ncodebooks * rows * out_features``
      * bias adds:  ``rows * out_features`` (only when the layer has a bias)

    ``rows`` is the number of input vectors, i.e. the product of every
    dimension of ``input[0]`` except the last one. Using only ``shape[0]``
    would count a single token per sample for ``(batch, tokens, features)``
    inputs and undercount the layer by a factor of ``tokens``.

    Args:
        module: Layer exposing ``ncodebooks``, ``k``, ``subvec_len``,
            ``out_features``, ``bias`` and the ``total_ops`` accumulator.
        input: Forward inputs of the layer; the tensor is read from ``input[0]``.
        output: Forward output of the layer (unused).
    """
    # Parameter counting (currently disabled):
    # n_params = module.ncodebooks * module.k * module.subvec_len  # centroids
    # n_params += module.ncodebooks * module.k * module.out_features/(32/8)  # luts
    # if module.bias is not None:
    #     n_params += module.out_features  # bias
    # module.total_params[0] = torch.DoubleTensor([n_params])

    # Every leading dimension indexes an independent input vector.
    n_rows = 1
    for dim in input[0].shape[:-1]:
        n_rows *= int(dim)

    n_mults = module.ncodebooks * n_rows * module.k * module.subvec_len
    n_adds = module.ncodebooks * n_rows * module.k * (module.subvec_len - 1)
    n_lut_adds = module.ncodebooks * n_rows * module.out_features
    if module.bias is not None:
        n_adds += n_rows * module.out_features
    module.total_ops += torch.DoubleTensor([n_mults + n_adds + n_lut_adds])

# Layer type -> op counter, passed to profile() through its `custom_ops` argument.
custom_ops = {
    LUT_Linear: count_lut_linear,
}


In [None]:
# Dummy single-image batch on the GPU; later profiling cells reuse `input`.
input = torch.randn(1, 3, 224, 224).to("cuda")

# Profile the LUT-compressed model, counting LUT_Linear layers via custom_ops,
# then turn the raw counts into human-readable strings.
flops, params = clever_format(
    list(profile(model_compressed.cuda(), inputs=(input, ), custom_ops=custom_ops)),
    "%.3f",
)
print(f"FLOPS: {flops}, Params: {params}")

In [None]:
# Reference numbers: the unmodified pretrained model for `model_name`,
# profiled without custom op counters on the `input` tensor created earlier.
model = create_model(model_name=model_name, pretrained=True)
flops, params = clever_format(
    list(profile(model.cuda(), inputs=(input, ))),
    "%.3f",
)
print(f"FLOPS: {flops}, Params: {params}")

In [None]:
# Layer-by-layer summary of the compressed model for a single 224x224 RGB image.
summary(model_compressed, input_size=(1, 3, 224, 224))

In [None]:
# Compile the LUT-compressed model with Torch-TensorRT for a fixed 1x3x224x224 input.
# Disabled alternatives: torch.compile(model_compressed), and compiling the
# plain `model` with the same input spec into `trt_ts_module`.
model_compressed.eval()
trt_model = torch_tensorrt.compile(
    model_compressed.cuda(),
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
    truncate_long_and_double=True,
    # torch.float selects FP32 here; use {torch_tensorrt.dtype.half} to run in FP16.
    enabled_precisions={torch.float},
)

In [30]:
# Time the TensorRT-compiled compressed model on the GPU with FP32 input.
# Disabled variants: the eager `model_compressed` without cuda=True, and
# `trt_model` with dtype="fp16".
benchmark(trt_model, cuda=True, nruns=100, input_shape=(1, 3, 224, 224))

Warm up ...
Start timing ...
Iteration 10/100, avg batch time 63.39 ms
Iteration 20/100, avg batch time 59.57 ms
Iteration 30/100, avg batch time 58.94 ms
Iteration 40/100, avg batch time 57.68 ms
Iteration 50/100, avg batch time 57.82 ms
Iteration 60/100, avg batch time 59.55 ms
Iteration 70/100, avg batch time 59.55 ms
Iteration 80/100, avg batch time 59.06 ms
Iteration 90/100, avg batch time 57.62 ms
Iteration 100/100, avg batch time 58.59 ms
Input shape: torch.Size([1, 3, 224, 224])
Average throughput: 17.07 images/second


In [7]:
# Pretrained baseline. Note the checkpoint tag is "fb_in1k", not the
# "fb_in22k_ft_in1k" tag held in `model_name`.
# Disabled follow-ups: model.half() and summary(model, (1, 3, 224, 224)).
model = create_model("deit3_small_patch16_224.fb_in1k", pretrained=True)

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/deit3_small_patch16_224.fb_in1k)
INFO:timm.models._hub:[timm/deit3_small_patch16_224.fb_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


In [12]:
# Re-create the pretrained baseline and profile it with thop's default counters,
# reusing the `input` tensor defined in an earlier cell.
model = create_model("deit3_small_patch16_224.fb_in1k", pretrained=True)
flops, params = clever_format(
    list(profile(model.cuda(), inputs=(input, ))),
    "%.3f",
)
print(f"FLOPS: {flops}, Params: {params}")

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/deit3_small_patch16_224.fb_in1k)
INFO:timm.models._hub:[timm/deit3_small_patch16_224.fb_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
FLOPS: 4.249G, Params: 21.975M


In [11]:
# Compile the baseline model with Torch-TensorRT for a fixed 1x3x224x224 input.
# torch.compile(model) is a disabled alternative.
model.eval()
trt_model1 = torch_tensorrt.compile(
    model.cuda(),
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
    truncate_long_and_double=True,
    # torch.float selects FP32 here; use {torch_tensorrt.dtype.half} to run in FP16.
    enabled_precisions={torch.float},
)



In [35]:
# Time the eager baseline model. With the default cuda=False the random input stays on the CPU.
# NOTE(review): an earlier cell calls model.cuda(); if that ran first, this call
# mixes a GPU model with a CPU input - confirm the intended device.
# Disabled variants: `trt_model1` with cuda=True, with and without dtype="fp16".
benchmark(model, nruns=100, input_shape=(1, 3, 224, 224))

Warm up ...
Start timing ...
Iteration 10/100, avg batch time 50.39 ms
Iteration 20/100, avg batch time 50.11 ms
Iteration 30/100, avg batch time 49.69 ms
Iteration 40/100, avg batch time 49.91 ms
Iteration 50/100, avg batch time 50.29 ms
Iteration 60/100, avg batch time 51.29 ms
Iteration 70/100, avg batch time 51.20 ms
Iteration 80/100, avg batch time 51.13 ms
Iteration 90/100, avg batch time 51.07 ms
Iteration 100/100, avg batch time 50.98 ms
Input shape: torch.Size([1, 3, 224, 224])
Average throughput: 19.61 images/second
