<a href="https://colab.research.google.com/github/Hippopotamus0308/torch2-test/blob/feat-basic-test/torch2_test_compile_mode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip3 install --pre torch --force-reinstall --index-url https://download.pytorch.org/whl/nightly/cpu

Looking in indexes: https://download.pytorch.org/whl/nightly/cpu, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch
  Using cached https://download.pytorch.org/whl/nightly/cpu/torch-2.0.0.dev20230105%2Bcpu-cp38-cp38-linux_x86_64.whl (194.4 MB)
Collecting networkx
  Using cached https://download.pytorch.org/whl/nightly/networkx-3.0rc1-py3-none-any.whl (2.0 MB)
Collecting typing-extensions
  Using cached https://download.pytorch.org/whl/nightly/typing_extensions-4.4.0-py3-none-any.whl (26 kB)
Collecting sympy
  Using cached https://download.pytorch.org/whl/nightly/sympy-1.11.1-py3-none-any.whl (6.5 MB)
Collecting mpmath>=0.19
  Using cached https://download.pytorch.org/whl/nightly/mpmath-1.2.1-py3-none-any.whl (532 kB)
Installing collected packages: mpmath, typing-extensions, sympy, networkx, torch
  Attempting uninstall: mpmath
    Found existing installation: mpmath 1.2.1
    Uninstalling mpmath-1.2.1:
      Successfully uninstalled mpmath-1.2.1
  Attempting uninst

In [7]:
import torch
import numpy as np
import torch._dynamo
from typing import List
import time

def timed(fn):
    start = time.time()
    result = fn()
    end = time.time()
    time_cnt = end - start
    #print(f"{printer}, time: {time_cnt}")
    return result, time_cnt


def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32),
        torch.randint(1000, (b,)),
    )

In [8]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 87.7MB/s]


In [9]:
opt_model_default = torch.compile(model, mode="default")
opt_model_reduce_overhead = torch.compile(model, mode="reduce-overhead")
opt_model_max_autotune = torch.compile(model, mode="max-autotune")

In [18]:
def test(cnt):
  time_no_opt = []
  time_default = []
  time_reduce_overhead = []
  time_max_autotune = []

  ## warm up
  for i in range(5):
    model(generate_data(cnt)[0])
    opt_model_default(generate_data(cnt)[0])
    opt_model_reduce_overhead(generate_data(cnt)[0])
    opt_model_max_autotune(generate_data(cnt)[0])

  for i in range(10):
    _, time1 = timed(lambda:model(generate_data(cnt)[0]))
    _, time2 = timed(lambda:opt_model_default(generate_data(cnt)[0]))
    _, time3 = timed(lambda:opt_model_reduce_overhead(generate_data(cnt)[0]))
    _, time4 = timed(lambda:opt_model_max_autotune(generate_data(cnt)[0]))
    time_no_opt.append(time1)
    time_default.append(time2)
    time_reduce_overhead.append(time3)
    time_max_autotune.append(time4)   

  no_opt_median_time = np.median(time_no_opt)
  default_opt_median_time = np.median(time_default)
  ro_median_time = np.median(time_reduce_overhead)
  ma_median_time = np.median(time_max_autotune)

  no_opt_mean_time = np.mean(time_no_opt)
  default_opt_mean_time = np.mean(time_default)
  ro_mean_time = np.mean(time_reduce_overhead)
  ma_mean_time = np.mean(time_max_autotune)

  print("-------------Median Time---------------")
  print(f"no opt median time: {no_opt_median_time}")
  print(f"mode = defualt: {default_opt_median_time}")
  print(f"mode = reduce overhead: {ro_median_time}")
  print(f"mode = max autotune: {ma_median_time}")

  print("-------------Mean Time---------------")
  print(f"no opt mean time: {no_opt_mean_time}")
  print(f"mode = defualt: {default_opt_mean_time}")
  print(f"mode = reduce overhead: {ro_mean_time}")
  print(f"mode = max autotune: {ma_mean_time}")  

In [19]:
test(1)

-------------Median Time---------------
no opt median time: 0.03913414478302002
mode = defualt: 0.04544544219970703
mode = reduce overhead: 0.04502689838409424
mode = max autotune: 0.044536471366882324
-------------Mean Time---------------
no opt mean time: 0.0397219181060791
mode = defualt: 0.04786083698272705
mode = reduce overhead: 0.04614570140838623
mode = max autotune: 0.044898605346679686


In [20]:
test(8)

-------------Median Time---------------
no opt median time: 0.2674351930618286
mode = defualt: 0.3092167377471924
mode = reduce overhead: 0.3070477247238159
mode = max autotune: 0.3095734119415283
-------------Mean Time---------------
no opt mean time: 0.26920514106750487
mode = defualt: 0.3109787702560425
mode = reduce overhead: 0.3109633684158325
mode = max autotune: 0.31057398319244384


In [21]:
test(32)

-------------Median Time---------------
no opt median time: 0.9859806299209595
mode = defualt: 1.176600456237793
mode = reduce overhead: 1.1663029193878174
mode = max autotune: 1.1761353015899658
-------------Mean Time---------------
no opt mean time: 0.9861690998077393
mode = defualt: 1.1880727529525756
mode = reduce overhead: 1.1696454524993896
mode = max autotune: 1.2120518207550048


In [22]:
test(64)

-------------Median Time---------------
no opt median time: 1.9487462043762207
mode = defualt: 2.295313000679016
mode = reduce overhead: 2.332722306251526
mode = max autotune: 2.3192657232284546
-------------Mean Time---------------
no opt mean time: 1.9546403408050537
mode = defualt: 2.319568729400635
mode = reduce overhead: 2.3309523105621337
mode = max autotune: 2.3204650402069094


In [23]:
test(128)

-------------Median Time---------------
no opt median time: 3.7803475856781006
mode = defualt: 4.5058698654174805
mode = reduce overhead: 4.5549890995025635
mode = max autotune: 4.49209189414978
-------------Mean Time---------------
no opt mean time: 3.7929439306259156
mode = defualt: 4.61926236152649
mode = reduce overhead: 4.547304439544678
mode = max autotune: 4.498769235610962
