PyTorch Mobile Performance Recipes

Introduction

Performance (that is, latency) is crucial to most, if not all, applications and use cases of ML model inference on mobile devices.

Today, PyTorch executes models on the CPU backend, pending the availability of other hardware backends such as GPU, DSP, and NPU.

In this recipe, you will learn:

  • How to optimize your model to help decrease execution time (higher performance, lower latency) on mobile devices
  • How to benchmark (to check whether the optimizations helped your use case)

๋ชจ๋ธ ์ค€๋น„

๋ชจ๋ฐ”์ผ ๊ธฐ๊ธฐ์—์„œ ์‹คํ–‰ ์‹œ๊ฐ„์„ ์ค„์ด๋Š”๋ฐ ๋„์›€์ด ๋ (์„ฑ๋Šฅ์€ ๋†’์ด๊ณ , ์ง€์—ฐ์‹œ๊ฐ„์€ ์ค„์ด๋Š”) ๋ชจ๋ธ์˜ ์ตœ์ ํ™”๋ฅผ ์œ„ํ•œ ์ค€๋น„๋ถ€ํ„ฐ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค.

์„ค์ •

์ฒซ๋ฒˆ์งธ๋กœ ์ ์–ด๋„ ๋ฒ„์ „์ด 1.5.0 ์ด์ƒ์ธ PyTorch๋ฅผ conda๋‚˜ pip์œผ๋กœ ์„ค์น˜ํ•ฉ๋‹ˆ๋‹ค.

conda install pytorch torchvision -c pytorch

or

pip install torch torchvision
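A quick way to confirm that the installed version meets the 1.5.0 minimum is sketched below. meets_minimum is a hypothetical helper, not a PyTorch API; it compares only the major and minor numbers and ignores local build tags such as +cpu or +cu118.

```python
import torch


def meets_minimum(version: str, minimum: tuple) -> bool:
    """Return True if a PyTorch version string is at least `minimum` (major, minor)."""
    # Drop local build tags such as "+cu118" or "+cpu" before comparing.
    release = version.split("+")[0]
    major, minor = (int(part) for part in release.split(".")[:2])
    return (major, minor) >= minimum


print(torch.__version__, meets_minimum(torch.__version__, (1, 5)))
```

The same helper can be reused for the PyTorch 1.13 requirement of the FlatBuffer format described later.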

๋ชจ๋ธ ์ฝ”๋“œ:

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

class AnnotatedConvBnReLUModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
        self.relu = torch.nn.ReLU(inplace=True)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = x.contiguous(memory_format=torch.channels_last)
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

model = AnnotatedConvBnReLUModel()
# Conv-bn fusion and post-training quantization below expect the model in eval mode
model.eval()

torch.quantization.QuantStub and torch.quantization.DeQuantStub are no-op stubs used in the quantization step; they mark where tensors enter and leave the quantized part of the model.
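Before prepare() and convert() have run, the stubs really are no-ops, which can be confirmed directly on a random input:

```python
import torch

x = torch.rand(1, 3, 8, 8)
quant = torch.quantization.QuantStub()
dequant = torch.quantization.DeQuantStub()

# Before prepare()/convert() both stubs simply return their input unchanged.
y = dequant(quant(x))
print(torch.equal(x, y), y.dtype)
```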

1. Fuse operators using torch.quantization.fuse_modules

Do not be confused by the fact that fuse_modules lives in the quantization package: fuse_modules works with any torch.nn.Module.

torch.quantization.fuse_modules fuses a list of modules into a single module. It fuses only the following sequences of modules:

  • Convolution, Batch normalization
  • Convolution, Batch normalization, ReLU
  • Convolution, ReLU
  • Linear, ReLU

This script fuses the Convolution, Batch Normalization, and ReLU modules of the model declared above. For inference, Convolution and Batch Normalization can only be folded together when the model is in eval mode.

torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']], inplace=True)
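A small sketch of what fusion does, assuming a model like the one above in eval mode. After fuse_modules, the conv entry holds the fused module and the bn and relu entries are replaced by Identity, while the outputs stay numerically the same. ConvBnReLU is a hypothetical stand-in for the model declared earlier, and its BatchNorm statistics are randomized only so that the folding has a visible effect.

```python
import torch


class ConvBnReLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False)
        self.bn = torch.nn.BatchNorm2d(5)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


model = ConvBnReLU().eval()  # inference-time conv-bn folding requires eval mode
# Non-trivial BatchNorm statistics so the folding actually changes the conv weights.
with torch.no_grad():
    model.bn.running_mean.uniform_(-1, 1)
    model.bn.running_var.uniform_(0.5, 2.0)
    model.bn.weight.uniform_(0.5, 1.5)
    model.bn.bias.uniform_(-0.5, 0.5)

# inplace=False (the default) fuses a deep copy and leaves `model` untouched.
fused = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])

x = torch.rand(1, 3, 16, 16)
with torch.no_grad():
    ref, out = model(x), fused(x)
print(type(fused.conv).__name__, type(fused.bn).__name__, type(fused.relu).__name__)
print(torch.allclose(ref, out, atol=1e-5))
```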

2. Quantize your model

You can find more about PyTorch quantization in the dedicated tutorial.

Quantizing the model not only moves computation to int8 but also reduces the size of the model on disk, since an int8 value takes 1 byte instead of the 4 bytes of a float32 value. That size reduction helps reduce disk reads when the model is first loaded and decreases the amount of RAM used. Both of these resources can be crucial for the performance of mobile applications. This code performs quantization using a stub for the model calibration function; you can find more about calibration here.

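model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.quantization.prepare(model, inplace=True)

# Model calibration
def calibrate(model, calibration_data):
    # Stub: run representative inputs through the model so the observers
    # inserted by prepare() record activation ranges
    with torch.no_grad():
        for batch in calibration_data:
            model(batch)

# Pass representative input batches here; the empty list keeps this a stub
calibrate(model, [])
torch.quantization.convert(model, inplace=True)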
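To see the size reduction in practice, the sketch below quantizes a single 64-channel Conv2d with the same prepare, calibrate, and convert flow and compares serialized state_dict sizes. It assumes the engine selection shown (qnnpack if this build supports it, otherwise fbgemm), uses random tensors in place of real calibration data, and QuantConvModel is a hypothetical example model.

```python
import io

import torch


class QuantConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(64, 64, 3)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))


def serialized_size(module):
    """Number of bytes torch.save writes for the module's state_dict."""
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    return len(buffer.getvalue())


# Prefer qnnpack (the mobile engine) when available, otherwise fall back to fbgemm.
engine = 'qnnpack' if 'qnnpack' in torch.backends.quantized.supported_engines else 'fbgemm'
torch.backends.quantized.engine = engine

model = QuantConvModel().eval()
float_size = serialized_size(model)

model.qconfig = torch.quantization.get_default_qconfig(engine)
torch.quantization.prepare(model, inplace=True)
with torch.no_grad():
    for _ in range(4):  # random data stands in for a real calibration set
        model(torch.rand(1, 64, 16, 16))
torch.quantization.convert(model, inplace=True)
quant_size = serialized_size(model)

with torch.no_grad():
    out = model(torch.rand(1, 64, 16, 16))
print(f"float: {float_size} bytes, int8: {quant_size} bytes, output: {tuple(out.shape)}")
```

Most of the float model's bytes are the 64 * 64 * 3 * 3 conv weights, so storing them as int8 accounts for most of the saving.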

3. Use torch.utils.mobile_optimizer

The Torch mobile_optimizer package performs several optimizations on a scripted model that help conv2d and linear operations. It pre-packs the model weights in an optimized format and fuses the ops above with relu when relu is the next operation.

First, script the model produced by the previous step:

torchscript_model = torch.jit.script(model)

Next, call optimize_for_mobile and save the model to disk.

torchscript_model_optimized = optimize_for_mobile(torchscript_model)
torch.jit.save(torchscript_model_optimized, "model.pt")
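Because optimize_for_mobile rewrites the graph, it is worth confirming that the optimized module still produces the same results as the scripted one. A minimal check on a small float Conv2d + ReLU model; ConvReLU is a hypothetical example, and the tolerance is an assumption, since prepacked kernels may round slightly differently.

```python
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile


class ConvReLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


scripted = torch.jit.script(ConvReLU().eval())
optimized = optimize_for_mobile(scripted)

x = torch.rand(1, 3, 32, 32)
with torch.no_grad():
    ref, out = scripted(x), optimized(x)
print(torch.allclose(ref, out, atol=1e-4))
# Inspect the rewritten graph to see which ops the optimizer produced on this build.
print(optimized.graph)
```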

4. Prefer the Channels Last tensor memory format

The Channels Last (NHWC) memory format was introduced in PyTorch 1.4.0. It is supported only for four-dimensional tensors. This memory format gives better memory locality for most operators, especially convolution. Our measurements showed a 3x speedup for the MobileNetV2 model compared with the default Channels First (NCHW) format.
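To make the layout difference concrete, the snippet below compares strides for the same 1x3x224x224 tensor in both formats. The logical shape does not change; only the order of elements in memory does, so that the channel values of each pixel become adjacent in NHWC.

```python
import torch

x = torch.rand(1, 3, 224, 224)  # default NCHW (channels first)
x_cl = x.contiguous(memory_format=torch.channels_last)

print(x.shape, x.stride())        # same logical shape ...
print(x_cl.shape, x_cl.stride())  # ... but the channel dimension now has stride 1
print(x_cl.is_contiguous(memory_format=torch.channels_last))
```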

At the time this recipe was written, the PyTorch Android Java API did not support inputs in the Channels Last memory format. It can still be used at the TorchScript model level, however, by converting the model input to this memory format:

def forward(self, x):
    x = x.contiguous(memory_format=torch.channels_last)
    ...

์ด ๋ณ€ํ™˜์€ ์ž…๋ ฅ์ด Channels Last ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์ด๋ฉด ๋น„์šฉ์ด ๋“ค์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๊ฒฐ๊ตญ์—๋Š” ๋ชจ๋“  ์—ฐ์‚ฐ์ž๊ฐ€ Channels Last ๋ฉ”๋ชจ๋ฆฌ ํ˜•์‹์„ ์œ ์ง€ํ•˜๋ฉด์„œ ์ž‘์—…์„ ํ•ฉ๋‹ˆ๋‹ค.

5. Android - Reusing tensors for forward

This part of the recipe is Android-only.

Memory is a critical resource for Android performance, especially on older devices. Tensors can require a significant amount of memory. For example, a standard computer vision tensor contains 1*3*224*224 elements; assuming the data type is float (4 bytes per element), it needs 588 KB (602,112 bytes) of memory.
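The 588 KB figure is the element count times 4 bytes per float32 element, which can be checked directly:

```python
import torch

t = torch.empty(1, 3, 224, 224, dtype=torch.float32)
num_bytes = t.numel() * t.element_size()  # 150528 elements * 4 bytes each
print(t.numel(), num_bytes, num_bytes / 1024)  # 150528 602112 588.0
```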

FloatBuffer buffer = Tensor.allocateFloatBuffer(1*3*224*224);
Tensor tensor = Tensor.fromBlob(buffer, new long[]{1, 3, 224, 224});

์—ฌ๊ธฐ์—์„  ๋„ค์ดํ‹ฐ๋ธŒ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ java.nio.FloatBuffer ๋กœ ํ• ๋‹นํ•˜๊ณ  ์ €์žฅ์†Œ๊ฐ€ ํ• ๋‹น๋œ ๋ฒ„ํผ์˜ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๊ฐ€๋ฆฌํ‚ฌ org.pytorch.Tensor ๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค.

๋Œ€๋ถ€๋ถ„์˜ ์‚ฌ์šฉ ์‚ฌ๋ก€์—์„œ ๋ชจ๋ธ ์ˆœ๋ฐฉํ–ฅ ์ „๋‹ฌ์„ ๋‹จ ํ•œ ๋ฒˆ๋งŒ ํ•˜์ง€ ์•Š๊ณ , ์ผ์ •ํ•œ ๋นˆ๋„๋กœ ํ˜น์€ ๊ฐ€๋Šฅํ•œ ํ•œ ๋นจ๋ฆฌ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค.

๋งŒ์•ฝ ๋ชจ๋“  ๋ชจ๋“ˆ ์ˆœ๋ฐฉํ–ฅ ์ „๋‹ฌ์„ ์œ„ํ•ด ๋ฉ”๋ชจ๋ฆฌ ํ• ๋‹น์„ ์ƒˆ๋กœ ํ•œ๋‹ค๋ฉด - ๊ทธ๊ฑด ์ตœ์ ํ™”๊ฐ€ ์•„๋‹™๋‹ˆ๋‹ค. ๋Œ€์‹ ์—, ์ด์ „ ๋‹จ๊ณ„์—์„œ ํ• ๋‹นํ•œ ๋™์ผํ•œ ๋ฉ”๋ชจ๋ฆฌ์— ์ƒˆ ๋ฐ์ดํ„ฐ๋ฅผ ์ฑ„์šฐ๊ณ  ๋ชจ๋“ˆ ์ˆœ๋ฐฉํ–ฅ ์ „๋‹ฌ์„ ๋™์ผํ•œ ํ…์„œ ๊ฐ์ฒด์—์„œ ๋‹ค์‹œ ์‹คํ–‰ํ•จ์œผ๋กœ์จ ๋™์ผํ•œ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ์žฌ์‚ฌ์šฉ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์ฝ”๋“œ๊ฐ€ ์–ด๋–ค ์‹์œผ๋กœ ๊ตฌ์„ฑ์ด ๋˜์–ด ์žˆ๋Š”์ง€๋Š” pytorch android application example ์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

protected AnalysisResult analyzeImage(ImageProxy image, int rotationDegrees) {
  if (mModule == null) {
    // Load the module and allocate the input buffer and tensor only once
    mModule = Module.load(moduleFileAbsoluteFilePath);
    mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * 224 * 224);
    mInputTensor = Tensor.fromBlob(mInputTensorBuffer, new long[]{1, 3, 224, 224});
  }

  // Refill the existing buffer with the new camera frame (no new allocation)
  TensorImageUtils.imageYUV420CenterCropToFloatBuffer(
      image.getImage(), rotationDegrees,
      224, 224,
      TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
      TensorImageUtils.TORCHVISION_NORM_STD_RGB,
      mInputTensorBuffer, 0);

  Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();
  // ... post-process outputTensor and return an AnalysisResult (omitted in this excerpt)
}

The member fields mModule, mInputTensorBuffer, and mInputTensor are initialized only once, and the buffer is refilled using org.pytorch.torchvision.TensorImageUtils.imageYUV420CenterCropToFloatBuffer.
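The same reuse pattern can be sketched in Python for experimentation on a desktop. This is only an analog of the Java code above, not part of the Android API; fill_frame is a hypothetical stand-in for imageYUV420CenterCropToFloatBuffer, and the small model is a placeholder.

```python
import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()

# Allocate the input tensor once, like mInputTensor in the Java example.
input_tensor = torch.empty(1, 3, 224, 224)
initial_ptr = input_tensor.data_ptr()


def fill_frame(dst):
    """Hypothetical frame source: writes new pixel values into dst in place."""
    dst.uniform_(0.0, 1.0)


with torch.no_grad():
    for _ in range(3):
        fill_frame(input_tensor)      # refill the existing memory
        output = model(input_tensor)  # forward on the same tensor object

print(input_tensor.data_ptr() == initial_ptr)  # True: no new input allocation
```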

6. Load time optimization

Available since PyTorch 1.13

PyTorch Mobile also supports a FlatBuffer-based file format that loads faster. Both FlatBuffer- and Pickle-based model files can be loaded with the same _load_for_lite_interpreter (Python) or _load_for_mobile (C++) API.

To use the FlatBuffer format, instead of creating the model file with model._save_for_lite_interpreter('path/to/file.ptl'), save it with:

model._save_for_lite_interpreter('path/to/file.ptl', _use_flatbuffer=True)

The extra argument _use_flatbuffer produces a FlatBuffer file instead of a zip file. The resulting file loads faster.
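The difference between the two formats can be checked directly. The sketch below assumes PyTorch 1.13 or later with the lite interpreter available in the Python build; it saves a tiny scripted model (TinyNet is a hypothetical example) both ways, confirms that only the default output is a zip archive, and loads both files through the same _load_for_lite_interpreter call.

```python
import os
import tempfile
import zipfile

import torch
from torch.jit import mobile


class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return torch.relu(self.fc(x))


scripted = torch.jit.script(TinyNet().eval())
x = torch.rand(2, 8)

with tempfile.TemporaryDirectory() as tmpdir:
    ptl_path = os.path.join(tmpdir, "tiny.ptl")
    ff_path = os.path.join(tmpdir, "tiny.ff")
    scripted._save_for_lite_interpreter(ptl_path)                       # zip-based
    scripted._save_for_lite_interpreter(ff_path, _use_flatbuffer=True)  # FlatBuffer

    ptl_is_zip = zipfile.is_zipfile(ptl_path)
    ff_is_zip = zipfile.is_zipfile(ff_path)

    # The same loader handles both formats.
    out_ptl = mobile._load_for_lite_interpreter(ptl_path)(x)
    out_ff = mobile._load_for_lite_interpreter(ff_path)(x)

print(ptl_is_zip, ff_is_zip, torch.allclose(out_ptl, out_ff))
```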

For example, take DeepLabV3 with a ResNet-50 backbone (deeplabv3_resnet50) and run the following script:

import timeit

import torch
from torch.jit import mobile

model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True)
model.eval()
jit_model = torch.jit.script(model)

# Save the same scripted model in both formats
jit_model._save_for_lite_interpreter('/tmp/jit_model.ptl')
jit_model._save_for_lite_interpreter('/tmp/jit_model.ff', _use_flatbuffer=True)

# timeit returns the total time in seconds for `number` loads
print('Load ptl file:')
print(timeit.timeit(lambda: mobile._load_for_lite_interpreter('/tmp/jit_model.ptl'), number=20))
print('Load flatbuffer file:')
print(timeit.timeit(lambda: mobile._load_for_lite_interpreter('/tmp/jit_model.ff'), number=20))

๋‹ค์Œ๊ณผ ๊ฐ™์€ ๊ฒฐ๊ณผ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

Load ptl file:
0.5387594579999999
Load flatbuffer file:
0.038842832999999466

Speedups on actual mobile devices will be smaller than the roughly 14x seen here, but you can still expect load times to shrink by 3x to 6x.

Reasons to avoid a FlatBuffer-based mobile model

However, the FlatBuffer format has some limitations you may want to consider:

  • It is available only in PyTorch 1.13 or later, so client devices built with earlier PyTorch versions may be unable to load it.
  • The FlatBuffer library imposes a 4 GB limit on file size, so it is not suitable for large models.

Benchmarking

The best way to benchmark (to check whether the optimizations helped your use case) is to measure the particular use case you want to optimize, because performance behavior can vary across environments.

The PyTorch distribution provides a way to benchmark a naked binary that runs the model forward. This approach gives more stable measurements than testing inside the application.
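As a rough host-side illustration of the warm-up-then-measure pattern, here is a minimal Python timing loop. It is not the on-device binary described next: benchmark is a hypothetical helper, the warm-up and iteration counts are arbitrary assumptions, and numbers measured this way on a desktop will not match what a phone reports.

```python
import time

import torch


def benchmark(module, example_input, warmup=5, iters=20):
    """Warm-up runs, then timed main runs; returns (microseconds per iter, iters per second)."""
    with torch.no_grad():
        for _ in range(warmup):
            module(example_input)
        start = time.perf_counter()
        for _ in range(iters):
            module(example_input)
        elapsed = time.perf_counter() - start
    us_per_iter = elapsed / iters * 1e6
    return us_per_iter, 1e6 / us_per_iter


model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
us, ips = benchmark(model, torch.rand(1, 3, 224, 224))
print(f"Microseconds per iter: {us:.0f}. Iters per second: {ips:.2f}")
```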

Android - Benchmarking setup

This part of the recipe is Android-only.

For benchmarking, you first need to build the benchmark binary:

# Run from the root of your PyTorch source directory
rm -rf build_android
BUILD_PYTORCH_MOBILE=1 ANDROID_ABI=arm64-v8a ./scripts/build_android.sh -DBUILD_BINARY=ON

You should now have an arm64 binary at build_android/bin/speed_benchmark_torch. It takes --model=<path-to-model>, --input_dims="1,3,224,224" for the input dimensions, and --input_type="float" for the input type.

Once your Android device is connected, push the speed_benchmark_torch binary and your model to the phone:

adb push build_android/bin/speed_benchmark_torch /data/local/tmp
adb push <path-to-scripted-model> /data/local/tmp

์ด์ œ ๋ชจ๋ธ์„ ๋ฒค์น˜๋งˆํ‚นํ•  ์ค€๋น„๊ฐ€ ๋˜์—ˆ์Šต๋‹ˆ๋‹ค:

adb shell "/data/local/tmp/speed_benchmark_torch --model=/data/local/tmp/model.pt --input_dims=1,3,224,224 --input_type=float"
----- output -----
Starting benchmark.
Running warmup runs.
Main runs.
Main run finished. Microseconds per iter: 121318. Iters per second: 8.24281
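When benchmarking many models or configurations, it can help to parse the summary line automatically. A minimal sketch, assuming the output keeps exactly the format shown above; parse_benchmark_output is a hypothetical helper.

```python
import re

BENCH_LINE = re.compile(
    r"Microseconds per iter: (\d+(?:\.\d+)?)\. Iters per second: (\d+(?:\.\d+)?)"
)


def parse_benchmark_output(text):
    """Extract (microseconds per iter, iters per second) from speed_benchmark_torch output."""
    match = BENCH_LINE.search(text)
    if match is None:
        raise ValueError("no 'Main run finished' line found")
    return float(match.group(1)), float(match.group(2))


sample = "Main run finished. Microseconds per iter: 121318. Iters per second: 8.24281"
us, ips = parse_benchmark_output(sample)
print(us, ips)  # 121318.0 8.24281
```

The two numbers are reciprocal views of the same measurement: 1,000,000 / 121318 is about 8.2428 iterations per second.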

iOS - Benchmarking setup

For iOS, we use TestApp as the benchmarking tool.

First, apply the optimize_for_mobile method in the Python script located at TestApp/benchmark/trace_model.py. Simply modify the code as shown below.

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
torchscript_model_optimized = optimize_for_mobile(traced_script_module)
torch.jit.save(torchscript_model_optimized, "model.pt")

Now run python trace_model.py. If everything works, the optimized model should be generated inside the benchmark directory.

Next, build the PyTorch libraries from source.

BUILD_PYTORCH_MOBILE=1 IOS_ARCH=arm64 ./scripts/build_ios.sh

์ด์ œ ์ตœ์ ํ™”๋œ ๋ชจ๋ธ๊ณผ PyTorch๊ฐ€ ์ค€๋น„๋˜์—ˆ๊ธฐ์— XCode ํ”„๋กœ์ ํŠธ๋ฅผ ๋งŒ๋“ค๊ณ  ๋ฒค์น˜๋งˆํ‚นํ•  ์‹œ๊ฐ„์ž…๋‹ˆ๋‹ค. ์ด๋ฅผ ์œ„ํ•ด XCode ํ”„๋กœ์ ํŠธ๋ฅผ ์„ค์ •ํ•˜๋Š” ๋ฌด๊ฑฐ์šด ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๋ฃจ๋น„ ์Šคํฌ๋ฆฝํŠธ setup.rb ๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

ruby setup.rb

Now open TestApp.xcodeproj, plug in your iPhone, and you are ready to go. Below is an example result from an iPhone X.

TestApp[2121:722447] Main runs
TestApp[2121:722447] Main run finished. Milliseconds per iter: 28.767
TestApp[2121:722447] Iters per second: : 34.762
TestApp[2121:722447] Done.