# Short tests to figure out how to use torch.compile

## basic usage

Based on the official PyTorch [`torch.compile` tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).

```python

In [5]:
import torch

# Seed torch's RNG so the random tensors created below are reproducible.
torch.manual_seed(0)
# Enable the `graph_code` logging artifact so the code of captured graphs is printed.
torch._logging.set_logs(graph_code=True)
# Shared 3x3 input for the first compile demos.
random_input = torch.randn(3, 3)

In [6]:
def foo(x, y):
    """Return ``sin(x) + cos(y)``, computed elementwise with torch."""
    return torch.sin(x) + torch.cos(y)


# Functional form: wrap the existing function `foo` with torch.compile, then
# call the wrapper once on the shared input.
opt_foo1 = torch.compile(foo)
print(opt_foo1(random_input, random_input))


@torch.compile
def opt_foo2(x, y):
    """Decorator form of torch.compile: returns ``sin(x) + cos(y)`` elementwise."""
    return torch.sin(x) + torch.cos(y)


# Call the decorator-compiled function on the same input as `opt_foo1`.
print(opt_foo2(random_input, random_input))

In [5]:
def foo3(x):
    """Return ``relu(x + 1) * 2`` (add, ReLU, scale) for the timing comparison."""
    return torch.nn.functional.relu(x + 1) * 2


# Compiled counterpart of `foo3`, used for the eager-vs-compiled timings.
opt_foo3 = torch.compile(foo3)


def timed(fn):
    """Run ``fn()`` once and return ``(result, elapsed_seconds)``.

    When CUDA is available the call is bracketed by CUDA events and followed by
    a device synchronization, which gives the most accurate measurement for GPU
    work. Without CUDA the function falls back to a wall-clock timer so it can
    still be used on CPU-only machines instead of failing.

    Args:
        fn: Zero-argument callable to execute and time.

    Returns:
        Tuple ``(result, seconds)`` where ``result`` is whatever ``fn()``
        returned and ``seconds`` is the elapsed time as a float.
    """
    import time

    if not torch.cuda.is_available():
        # CPU-only fallback: CUDA events cannot be used without a device.
        begin = time.perf_counter()
        result = fn()
        return result, time.perf_counter() - begin

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    # Wait for the recorded events to complete before reading the elapsed time.
    torch.cuda.synchronize()
    # elapsed_time() reports milliseconds; convert to seconds.
    return result, start.elapsed_time(end) / 1000


# 4096x4096 random matrix created on the host and moved to the GPU
# (this line requires a CUDA device).
inp = torch.randn(4096, 4096).cuda()
# NOTE: this is the first call of `opt_foo3`, so its timing presumably includes
# compilation overhead — compare with the repeated runs in the next cell.
print("compile:", timed(lambda: opt_foo3(inp))[1])
print("eager:", timed(lambda: foo3(inp))[1])

In [6]:
import numpy as np

# Silence graph-code logging for now so the benchmark output is not flooded.
torch._logging.set_logs(graph_code=False)

# Ten timed eager runs, each reported as soon as it finishes.
eager_times = []
for i in range(10):
    eager_times.append(timed(lambda: foo3(inp))[1])
    print(f"eager time {i}: {eager_times[-1]}")
print("~" * 10)

# Ten timed runs of the compiled function, reported the same way.
compile_times = []
for i in range(10):
    compile_times.append(timed(lambda: opt_foo3(inp))[1])
    print(f"compile time {i}: {compile_times[-1]}")
print("~" * 10)

# Compare the medians; the cell asserts that the compiled version is faster.
eager_med, compile_med = np.median(eager_times), np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)

In [7]:
def f1(x, y):
    """Return ``-y`` when the sum of ``x`` is negative, otherwise ``y`` unchanged.

    The branch depends on the *values* in ``x`` (data-dependent control flow).
    """
    return -y if x.sum() < 0 else y


# Test that `fn1` and `fn2` return the same result, given the same arguments `args`.
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)


# Two random 5x5 inputs; `inp1` decides which branch of `f1` is taken.
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

# torch.jit.trace records the operations executed for the example inputs, so
# only the branch of `f1` taken for `inp1` is expected to end up in the trace.
traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
# Negating `inp1` flips the sign of `x.sum()`; a trace that baked in a single
# branch would disagree with eager `f1` on this input.
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))

# Same two checks against the torch.compile version of `f1`.
compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)

In [8]:
import traceback as tb

# Turn graph-code logging back on (it was switched off for the benchmark loops).
torch._logging.set_logs(graph_code=True)


def f2(x, y):
    """Return ``x + y``.

    Called below with a tensor ``x`` and a plain Python int ``y``, through both
    ``torch.jit.script`` and ``torch.compile``.
    """
    return x + y


# Mixed argument types: a tensor and a plain Python int.
inp1 = torch.randn(5, 5)
inp2 = 3

# Script `f2` and call it with the int argument. Any error raised by the call
# is the point of the demonstration, so it is caught and its traceback printed.
# `except Exception` (rather than a bare `except`) keeps KeyboardInterrupt and
# SystemExit from being swallowed.
script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except Exception:
    tb.print_exc()

# The torch.compile version is checked against eager `f2` on the same inputs.
compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)

In [1]:
import os
import sys

# If the kernel was started inside a directory named "testing", move to its
# parent and add that directory to sys.path — presumably so the `lossless`
# package imported below can be found (TODO confirm the repo layout).
if os.path.basename(os.getcwd()) == "testing":
    os.chdir(os.path.dirname(os.getcwd()))
    sys.path.append(os.getcwd())

import torch
from lossless.component.coolchic import CoolChicEncoder
from lossless.configs.config import args, str_args
from lossless.training.manager import ImageEncoderManager
from lossless.util.command_line_args_loading import load_args
from lossless.util.image_loading import load_image_as_tensor
from lossless.util.logger import TrainingLogger
from lossless.util.parsecli import get_coolchic_param_from_args

# Enable autograd anomaly detection globally. This is a debugging aid; the
# PyTorch docs note that it slows execution down.
torch.autograd.set_detect_anomaly(True)
# Set the float32 matmul precision mode to "high".
torch.set_float32_matmul_precision("high")

# ==========================================================================================
# LOAD COMMAND LINE ARGS AND IMAGE
# ==========================================================================================
command_line_args = load_args()
# Select the input path at `image_index` from `args["input"]`.
im_path = args["input"][command_line_args.image_index]
# Load the image onto the first CUDA device. The second return value,
# `colorspace_bitdepths`, is passed to ImageEncoderManager below.
im_tensor, colorspace_bitdepths = load_image_as_tensor(
    im_path, device="cuda:0", color_space=command_line_args.color_space
)
# im_tensor = im_tensor[:,:64, :64] # for faster testing

# ==========================================================================================
# LOAD PRESETS, COOLCHIC PARAMETERS
# ==========================================================================================
# Manager built from the configured preset name and the bit depths returned by
# the image loader.
image_encoder_manager = ImageEncoderManager(
    preset_name=args["preset"], colorspace_bitdepths=colorspace_bitdepths
)

# Encoder parameters derived from the config plus command-line overrides.
# `image_size` uses dimensions 2 and 3 of `im_tensor`
# (assumes a 4-D tensor, presumably N x C x H x W — TODO confirm).
encoder_param = get_coolchic_param_from_args(
    args,
    "lossless",
    image_size=(im_tensor.shape[2], im_tensor.shape[3]),
    use_image_arm=command_line_args.use_image_arm,
    encoder_gain=command_line_args.encoder_gain,
)
# Build the encoder and move it to the same device the image was loaded on.
coolchic = CoolChicEncoder(param=encoder_param)
coolchic.to_device("cuda:0")
# ==========================================================================================
# SETUP LOGGER
# ==========================================================================================
# Derive the log names from the image path: the parent directory name is used
# as the dataset name and the file name up to its first "." as the image name.
# os.path is used instead of splitting on "/" so a path without a directory
# component no longer raises IndexError and the platform's separators are
# honoured.
dataset_name = os.path.basename(os.path.dirname(im_path))
logger = TrainingLogger(
    log_folder_path=args["LOG_PATH"],
    image_name=f"{dataset_name}_" + os.path.basename(im_path).split(".")[0],
    # Runs with fewer than 1000 iterations are logged in debug mode.
    debug_mode=image_encoder_manager.n_itr < 1000,
    experiment_name=command_line_args.experiment_name,
)
# ==========================================================================================
# COMPILE AND RUN
# ==========================================================================================

# Graph-code logging is currently disabled in this cell; uncomment the
# set_logs call below to print logs to the notebook output.
import logging

# Note: logs appear only when the compiled graph runs at least once
# torch._logging.set_logs()

# Compilation is currently disabled, so the forward pass below is not compiled.
# NOTE(review): `torch.compile(module)` returns a wrapped module that must be
# assigned and called, whereas the commented-out call below is a method on
# `coolchic` — confirm which one is intended before re-enabling.
# coolchic.compile()

# Single forward pass with gradient tracking disabled.
with torch.no_grad():
    out = coolchic(
        image=im_tensor,
        quantizer_noise_type="kumaraswamy",
        quantizer_type="softround",
        soft_round_temperature=torch.tensor(0.3),
        noise_parameter=torch.tensor(0.25),
    )

Converting image to YCoCg color space


In [15]:
from torch.export import export, ExportedProgram

# Example inputs for export: the image tensor plus an all-zero tensor of shape
# (1, 6, H, W) on the same device, where H and W are dimensions 2 and 3 of
# `im_tensor`.
# NOTE(review): the meaning of the size-6 dimension is not visible here —
# confirm against the signature of `coolchic.image_arm`.
example_input = (
    im_tensor,
    torch.zeros((1,6,im_tensor.shape[2], im_tensor.shape[3]), device=im_tensor.device),
)
# Export `coolchic.image_arm` with torch.export using the example inputs above.
exported_program: ExportedProgram = export(coolchic.image_arm, args=example_input)

In [16]:
from torch.fx import passes

# Render the exported graph to an SVG file.
# FxGraphDrawer is documented to take a torch.fx.GraphModule, so pass the
# ExportedProgram's underlying graph module rather than the program object,
# and label the drawing after the module that was actually exported.
tmp = passes.graph_drawer.FxGraphDrawer(
    exported_program.graph_module, "coolchic_image_arm"
)
# create_svg() returns bytes, hence the binary write mode.
with open("a.svg", "wb") as f:
    f.write(tmp.get_dot_graph().create_svg())