<a href="https://colab.research.google.com/github/06unoh/model_optimization/blob/main/model_optimization_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install onnx onnxscript onnxruntime

# ONNX Runtime C++ 바이너리 다운로드
!wget https://github.com/microsoft/onnxruntime/releases/download/v1.23.2/onnxruntime-linux-x64-1.23.2.tgz
!tar -xvf onnxruntime-linux-x64-1.23.2.tgz

Collecting onnx
  Downloading onnx-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.0 kB)
Collecting onnxscript
  Downloading onnxscript-0.5.6-py3-none-any.whl.metadata (13 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting onnx_ir<2,>=0.1.12 (from onnxscript)
  Downloading onnx_ir-0.1.12-py3-none-any.whl.metadata (3.2 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnx-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (18.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m60.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnxscript-0.5.6-py3-none-any.whl (683 kB)
[2K   [90m━━━━━━━━━━━━━━━━

In [2]:
import os
import torch
from torch import nn
import numpy as np
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

In [3]:
# 기존 학습 SwinIR Class
def window_partition(x, window_size):
  B, H, W, C=x.shape
  x=x.view(B, H//window_size, window_size, W//window_size, window_size, -1)
  windows=x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  return windows

def window_reverse(window, window_size, H, W):
  B=int(window.shape[0]*window_size*window_size/(H*W))
  x=window.view(B, H//window_size, W//window_size, window_size, window_size, -1)
  x=x.permute(0,1,3,2,4,5).contiguous().view(B, H, W, -1)
  return x

def get_attn_mask(window_size, shift_size, H, W, device):
  img_mask=torch.zeros((1, H, W, 1), device=device)
  cnt=0

  for h in (slice(-window_size),slice(-window_size,-shift_size),slice(-shift_size,None)):
    for w in (slice(-window_size),slice(-window_size,-shift_size),slice(-shift_size,None)):
      img_mask[:, h, w, :]=cnt
      cnt+=1
  img_window=window_partition(img_mask, window_size)
  img_window=img_window.view(img_window.shape[0], -1)
  attn_mask=img_window.unsqueeze(1)-img_window.unsqueeze(2)
  attn_mask=attn_mask.masked_fill(attn_mask!=0, float(-100.0)).masked_fill(attn_mask==0, float(0.0))
  return attn_mask


class WindowAttention(nn.Module):
  def __init__(self, dim, num_heads, window_size):
    super().__init__()
    self.dim=dim
    self.n_heads=num_heads
    self.window_size=window_size
    self.scale=(dim/self.n_heads)**-0.5

    self.qkv=nn.Linear(dim, 3*dim)
    self.proj=nn.Linear(dim, dim)

  def forward(self, x, mask=None):
    B_, N, C=x.shape
    qkv=self.qkv(x).view(B_, N, 3, self.n_heads,C//self.n_heads)
    q, k, v=qkv.permute(2,0,3,1,4).contiguous()
    attn=(q@k.transpose(-2, -1))*self.scale

    if mask is not None:
      nW=mask.shape[0]
      attn=attn.view(B_//nW, nW, self.n_heads, N, N)
      attn=attn+mask.unsqueeze(1).unsqueeze(0)
      attn=attn.view(B_,self.n_heads,N,N)
    attn=attn.softmax(dim=-1)
    attn=(attn@v).permute(0,2,1,3).contiguous().view(B_, N, C)
    return self.proj(attn)

class DropPath(nn.Module):
  def __init__(self, drop_prob):
    super().__init__()
    self.drop_prob=drop_prob

  def forward(self, x):
    if self.drop_prob==0 or not self.training:
      return x

    keep_drop=1-self.drop_prob
    B=x.shape[0]
    shape=(B,)+(1,)*(x.ndim-1)
    random_tensor=torch.rand(shape, dtype=x.dtype, device=x.device)+keep_drop
    random_tensor.floor_()
    return x.div(keep_drop)*random_tensor

class SwinTFBlock(nn.Module):
  def __init__(self, dim, num_heads, window_size, shift_size ,drop_prob=0.1):
    super().__init__()
    self.window_size=window_size
    self.shift_size=shift_size

    self.norm1=nn.LayerNorm(dim)
    self.attn=WindowAttention(dim, num_heads, window_size)
    self.drop_path1=DropPath(drop_prob)

    self.norm2=nn.LayerNorm(dim)
    self.mlp=nn.Sequential(
        nn.Linear(dim, dim*4),
        nn.GELU(),
        nn.Linear(dim*4, dim)
    )
    self.drop_path2=DropPath(drop_prob)

  def forward(self, x, H, W):
    B, N, C=x.shape
    shortcut=x
    x=self.norm1(x).view(x.shape[0], H, W, -1)

    if self.shift_size>0:
      x=torch.roll(x, shifts=(-self.shift_size, -self.shift_size),dims=(1,2))

    x_windows=window_partition(x, self.window_size)
    x_windows=x_windows.view(-1, self.window_size*self.window_size, C)    #(B_, N, C)

    attn_mask=get_attn_mask(self.window_size, self.shift_size, H, W, x.device) if self.shift_size>0 else None
    x_attn=self.attn(x_windows, mask=attn_mask)   #(B_,N,C)

    x_attn=x_attn.view(-1,self.window_size, self.window_size, C)
    x=window_reverse(x_attn, self.window_size, H, W)   # (B, H, W ,C)

    if self.shift_size>0:
      x=torch.roll(x, shifts=(self.shift_size, self.shift_size),dims=(1,2))

    x=x.view(-1, H*W, C)
    x=shortcut+self.drop_path1(x)
    x=x+self.drop_path2(self.mlp(self.norm2(x)))
    return x

class RSTB(nn.Module):
  def __init__(self, dim, num_heads, window_size, drop_prob, depth):
    super().__init__()
    self.blocks=nn.ModuleList([
        SwinTFBlock(
            dim,
            num_heads,
            window_size,
            shift_size=0 if i%2==0 else window_size//2,
            drop_prob=0.1)
        for i in range(depth)
    ])
    self.conv=nn.Conv2d(dim, dim, 3, 1, 1)

  def forward(self, x):
    B, H, W, C=x.shape
    shortcut=x

    x=x.view(B,H*W,C)
    for blk in self.blocks:
      x=blk(x, H, W)
    x=x.view(B, H, W, C).permute(0,3,1,2).contiguous()
    x=self.conv(x)
    x=x.permute(0,2,3,1).contiguous()
    x=shortcut+x
    return x

class SwinIR(nn.Module):
  def __init__(self,img_dim=3, embed_dim=256, num_heads=8, window_size=8, drop_prob=0.1, depth=4, depths=3):
    super().__init__()
    self.conv_first=nn.Conv2d(img_dim, embed_dim, 3, 1, 1)
    self.layers=nn.ModuleList([
        RSTB(embed_dim, num_heads, window_size, drop_prob, depth)
        for _ in range(depths)
    ])
    self.norm=nn.LayerNorm(embed_dim)
    self.conv_after_body=nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)

    self.upsample=nn.Sequential(
        nn.Conv2d(embed_dim, embed_dim*4, 3, 1, 1),
        nn.PixelShuffle(2),
        nn.Conv2d(embed_dim, embed_dim*4, 3, 1, 1),
        nn.PixelShuffle(2),
        nn.Conv2d(embed_dim, img_dim, 3, 1, 1)
    )

  def forward(self, x):
    B, C, H, W=x.shape

    x=self.conv_first(x)
    x=x.permute(0,2,3,1).contiguous()

    for layer in self.layers:   # 좋은데 기본 트랜스포머 구조 (B,L,C)
      x=layer(x)

    x=self.norm(x)
    x=x.permute(0,3,1,2).contiguous()
    x=self.conv_after_body(x)
    return self.upsample(x)

In [4]:
# ONNX Formatting
model=SwinIR()
checkpoint=torch.load('swinir_best.pth', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'], strict=True)
model.eval()

dummy=torch.randn(1, 3, 64, 64)

torch.onnx.export(
    model,
    dummy,
    'swinir_x4.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        "input": {0: "batch", 2: "height", 3: "width"},
        "output": {0: "batch", 2: "height", 3: "width"},
    }
)


  torch.onnx.export(


[torch.onnx] Obtain model graph for `SwinIR([...]` with `torch.export.export(..., strict=False)`...


  B=int(window.shape[0]*window_size*window_size/(H*W))


[torch.onnx] Obtain model graph for `SwinIR([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 168 of general pattern rewrite rules.


ONNXProgram(
    model=
        <
            ir_version=10,
            opset_imports={'': 20},
            producer_name='pytorch',
            producer_version='2.9.0+cu126',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"input"<FLOAT,[s77,3,s53,s0]>
            ),
            outputs=(
                %"output"<FLOAT,[1,3,4*s53,4*s0]>
            ),
            initializers=(
                %"conv_first.weight"<FLOAT,[256,3,3,3]>{TorchTensor(...)},
                %"conv_first.bias"<FLOAT,[256]>{TorchTensor(...)},
                %"layers.0.blocks.0.norm1.weight"<FLOAT,[256]>{TorchTensor(...)},
                %"layers.0.blocks.0.norm1.bias"<FLOAT,[256]>{TorchTensor(...)},
                %"layers.0.blocks.0.attn.qkv.bias"<FLOAT,[768]>{TorchTensor(...)},
                %"layers.0.blocks.0.attn.proj.bias"<FLOAT,[256]>{TorchTensor(...)},
                %"layers.0.blocks.0.norm2.w

In [6]:
# 경량화
ORIGINAL_MODEL='swinir_best.pth'
FP32_MODEL='swinir_x4.onnx'
FP32_DATA='swinir_x4.onnx.data'
INT8_MODEL='swinir_x4_int8.onnx'
INT8_DATA='swinir_x4_int8.onnx.data'


model_onnx=onnx.load(FP32_MODEL)
onnx.checker.check_model(model_onnx)

# quantize_dynamic(
#     model_input=FP32_MODEL,
#     model_output=INT8_MODEL,
#     weight_type=QuantType.QInt8
# )
quantize_dynamic(
    model_input=FP32_MODEL,
    model_output=INT8_MODEL,
    weight_type=QuantType.QInt8,
    op_types_to_quantize=["MatMul", "Gemm"],  # Conv는 빼고 양자화
)

def safe_size(path):
  return os.path.getsize(path) if os.path.exists(path) else 0

origin_size=safe_size(ORIGINAL_MODEL)/(1024*1024)
fp32_size=(safe_size(FP32_MODEL)+safe_size(FP32_DATA))/(1024*1024)
int8_size=(safe_size(INT8_MODEL)+safe_size(INT8_DATA))/(1024*1024)

print("======= Result of Model size =======")
print(f"Original File Size: {origin_size:.2f} MB")
print(f"FP32: {fp32_size:.2f} MB")
print(f"INT8: {int8_size:.2f} MB")
print(f"압축율: {(1-int8_size/fp32_size)*100:.2f}%")



Original File Size: 189.85 MB
FP32: 64.17 MB
INT8: 37.10 MB
압축율: 42.19%


In [7]:
%%writefile engine.h
#pragma once
#include <string>
#include <vector>
#include <onnxruntime_cxx_api.h>

class InferenceEngine {
  public:
    InferenceEngine(const std::string& model_path, int threads);
    std::vector<float> infer(const std::vector<float>& input);

  private:
    Ort::Env env;
    Ort::Session session;
    Ort::MemoryInfo mem_info;
    std::vector<int64_t> input_shape;
};

Writing engine.h


In [8]:
%%writefile engine.cpp
#include "engine.h"

InferenceEngine::InferenceEngine(
    const std::string& model_path,
    int threads
) : env(ORT_LOGGING_LEVEL_WARNING, "Engine"),
    session(nullptr),
    mem_info(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault))
{
    Ort::SessionOptions opts;
    opts.SetIntraOpNumThreads(threads);
    opts.SetInterOpNumThreads(threads);
    opts.SetGraphOptimizationLevel(
        GraphOptimizationLevel::ORT_ENABLE_ALL);

    session = Ort::Session(env, model_path.c_str(), opts);
    input_shape = {1, 3, 64, 64};
}

std::vector<float> InferenceEngine::infer(const std::vector<float>& input) {
    auto input_tensor = Ort::Value::CreateTensor<float>(
        mem_info,
        const_cast<float*>(input.data()),
        input.size(),
        input_shape.data(),
        input_shape.size()
    );

    const char* input_names[] = {"input"};
    const char* output_names[] = {"output"};

    auto outputs = session.Run(
        Ort::RunOptions{nullptr},
        input_names, &input_tensor, 1,
        output_names, 1
    );

    float* out = outputs[0].GetTensorMutableData<float>();
    size_t out_size = outputs[0]
        .GetTensorTypeAndShapeInfo()
        .GetElementCount();

    return std::vector<float>(out, out + out_size);
}


Writing engine.cpp


In [9]:
%%writefile main.cpp
#include <iostream>
#include <chrono>
#include "engine.h"

double benchmark(InferenceEngine& engine, const char* tag) {
    std::vector<float> input(1*3*64*64, 0.5f);

    // 워밍업
    engine.infer(input);

    auto t0 = std::chrono::high_resolution_clock::now();
    auto out = engine.infer(input);
    auto t1 = std::chrono::high_resolution_clock::now();

    std::chrono::duration<double, std::milli> ms = t1 - t0;

    std::cout << "[" << tag << "] Time: "
              << ms.count()
              << " ms / First: "
              << out[0] << std::endl;

    return ms.count();
}

int main() {
    InferenceEngine fp32("swinir_x4.onnx", 4);
    InferenceEngine int8("swinir_x4_int8.onnx", 4);

    double t_fp32 = benchmark(fp32, "FP32");
    double t_int8 = benchmark(int8, "INT8");

    std::cout << "Result: " << "Faster " << (1-t_int8/t_fp32)*100 << "% than Before";
    return 0;
}


Writing main.cpp


In [10]:
!g++ main.cpp engine.cpp \
  -I /content/onnxruntime-linux-x64-1.23.2/include \
  -L /content/onnxruntime-linux-x64-1.23.2/lib \
  -lonnxruntime \
  -std=c++17 -O3 -march=native -o app

!LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/content/onnxruntime-linux-x64-1.23.2/lib ./app


[FP32] Time: 85827.6 ms / First: 0.486768
[INT8] Time: 6320.48 ms / First: 0.486675
Result: Faster 92.6358% than Before