In [None]:
import os
import torch
import PIL.Image
import matplotlib.pyplot as plt
import torchvision.transforms.v2

### Изображение

In [None]:
# Load the cat image.
# Convert to 3-channel RGB because the preprocessing normalizes exactly three channels;
# convert() also reads the pixels eagerly, so the file can be closed by the `with` block.
with PIL.Image.open("image.jpg") as source_image:
    image = source_image.convert("RGB")
image

In [None]:
# We will experiment with efficientnet_b3. Prepare the image for use with this model.
# The inference transforms ... perform the following preprocessing operations:
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
transform = torchvision.transforms.v2.Compose([ # Accepts PIL.Image
    # The images are resized to resize_size=[320] using interpolation=InterpolationMode.BICUBIC
    torchvision.transforms.v2.Resize(320, interpolation = torchvision.transforms.v2.InterpolationMode.BICUBIC),
    # followed by a central crop of crop_size=[300].
    torchvision.transforms.v2.CenterCrop(300),
    # Finally the values are first rescaled to [0.0, 1.0]
    torchvision.transforms.v2.ToImage(),
    torchvision.transforms.v2.ToDtype(torch.float32, scale = True),
    # and then normalized using mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
    torchvision.transforms.v2.Normalize(NORMALIZE_MEAN, NORMALIZE_STD)
])
tensor = transform(image)

# `tensor` is normalized, so its values fall outside [0, 1]; imshow would clip them and
# show distorted colours. Undo the normalization for the preview only (`tensor` is unchanged).
display_tensor = (
    tensor * torch.tensor(NORMALIZE_STD).view(-1, 1, 1)
    + torch.tensor(NORMALIZE_MEAN).view(-1, 1, 1)
).clamp(0, 1)
plt.axis('off')
plt.imshow(display_tensor.permute(1, 2, 0))
plt.show()

In [None]:
# Build a batch of 64 identical images by repeating the preprocessed image
# along a new leading (batch) dimension.
input = tensor.unsqueeze(0).repeat(64, 1, 1, 1)
print(input.shape)

### Модель

In [None]:
# https://github.com/pytorch/vision/issues/7744
def get_state_dict(self, *args, **kwargs):
    """Replacement for ``WeightsEnum.get_state_dict`` that skips hash checking.

    Works around https://github.com/pytorch/vision/issues/7744 by forwarding the
    call to ``torch.hub.load_state_dict_from_url`` without ``check_hash``.

    Args:
        self: the weights enum member; only its ``url`` attribute is used.
        *args: forwarded unchanged to ``torch.hub.load_state_dict_from_url``.
        **kwargs: forwarded unchanged, except that ``check_hash`` is dropped.

    Returns:
        Whatever ``torch.hub.load_state_dict_from_url`` returns.
    """
    # Default of None: callers that don't pass check_hash must not hit a KeyError.
    kwargs.pop("check_hash", None)
    return torch.hub.load_state_dict_from_url(self.url, *args, **kwargs)
# Patch the weights loader (see the linked issue) before the pretrained model is built below.
torchvision.models._api.WeightsEnum.get_state_dict = get_state_dict

# Default pretrained weights for efficientnet_b3; the model is put in eval mode for inference.
weights = torchvision.models.get_model_weights("efficientnet_b3").DEFAULT
model = torchvision.models.get_model(
    "efficientnet_b3",
    weights = weights,
).eval()

In [None]:
# Reference prediction: scores of the first image in the batch.
# no_grad matters here: `output` stays alive for the later comparisons, and without it
# it would also keep the autograd graph (with saved activations) of the whole batch.
with torch.no_grad():
    output = model(input)[0]
print(output.argmax())
print(output[0:7])

### Производительность

In [None]:
def test(device: "str | torch.device") -> torch.Tensor:
    """Run one forward pass of the global ``model`` on the global ``input`` batch.

    ``model.to`` moves the module in place, so the model stays on ``device``
    after the call. ``input`` is copied to ``device`` on every call (the global
    itself is not reassigned), so on a GPU each timed call includes the
    host-to-device copy of the batch.

    Args:
        device: target device, e.g. ``'cpu'`` or ``'cuda'``.

    Returns:
        The model output tensor, located on ``device``.
    """
    target = torch.device(device)
    with torch.no_grad():
        result = model.to(target)(input.to(target))
    if target.type == "cuda":
        # CUDA kernels run asynchronously; wait for them so that %timeit measures
        # the forward pass itself rather than just the kernel launches.
        torch.cuda.synchronize(target)
    return result

In [None]:
# Time the forward pass of the 64-image batch on the CPU (10 timeit repeats).
%timeit -r 10 test('cpu')

In [None]:
# Same measurement on the GPU (100 timeit repeats); requires a CUDA device.
# test() moves `model` onto the GPU in place; the next cell moves it back to the CPU.
%timeit -r 100 test('cuda')

In [None]:
# test() may have left the model on the GPU; bring the model and the batch back
# to the CPU for the JIT and ONNX sections below.
model = model.cpu()
input = input.cpu()

### JIT

https://pytorch.org/docs/stable/jit.html \
https://pytorch.org/tutorials/advanced/cpp_export.html

#### Tracing

In [None]:
traced_model = torch.jit.trace(model, input)
# Compare the traced module against the eager reference `output`; no_grad avoids
# recording an autograd graph for the 64-image check run.
with torch.no_grad():
    traced_matches = torch.allclose(traced_model(input)[0], output)
traced_matches

In [None]:
# Make sure the output directory exists; saving into a missing directory fails.
os.makedirs("models", exist_ok = True)
traced_model.save("models/traced.pt")

#### Scripting

##### Problem

In [None]:
class TestModule(torch.nn.Module):
    """Toy module with data-dependent control flow.

    Returns the row-wise maximum when all elements of ``input`` sum to a positive
    value, and the row-wise minimum otherwise.
    """

    def forward(self, input):
        if input.sum() > 0:
            return input.max(dim = 1).values
        return input.min(dim = 1).values

# Positive total (10): takes the max branch.
input1 = torch.tensor([ [ 1., 2. ], [ 3., 4. ] ])
# Negative total (-4): takes the min branch.
input2 = torch.tensor([ [ 1., 2. ], [ -3., -4. ] ])

In [None]:
# Eager execution evaluates the Python `if` on every call, so each input takes its own branch.
test_module = TestModule()
print(*(test_module(x) for x in (input1, input2)), sep = "\n")

In [None]:
# Tracing records only the operations executed for the example input. input1 has a
# positive sum, so only the max branch is recorded and both calls return row maxima.
traced_test_module_1 = torch.jit.trace(test_module, input1)
print(*(traced_test_module_1(x) for x in (input1, input2)), sep = "\n")

In [None]:
# input2 sums to a negative value, so this trace records only the min branch and
# both calls return row minima.
traced_test_module_2 = torch.jit.trace(test_module, input2)
print(*(traced_test_module_2(x) for x in (input1, input2)), sep = "\n")

##### Solution

In [None]:
# Scripting compiles the Python source, including the `if`, so the branch is
# chosen per input at run time, matching eager execution.
scripted_test_module = torch.jit.script(test_module)
print(*(scripted_test_module(x) for x in (input1, input2)), sep = "\n")

In [None]:
scripted_model = torch.jit.script(model)
# Compare the scripted module against the eager reference `output`; no_grad avoids
# recording an autograd graph for the 64-image check run.
with torch.no_grad():
    scripted_matches = torch.allclose(scripted_model(input)[0], output)
scripted_matches

In [None]:
# Serialize the scripted module as a TorchScript archive.
torch.jit.save(scripted_model, "models/scripted.pt")

[onednn_fusion](https://pytorch.org/docs/stable/generated/torch.jit.enable_onednn_fusion.html#torch.jit.enable_onednn_fusion) \
[torch.jit.freeze()](https://pytorch.org/docs/stable/generated/torch.jit.freeze.html#torch.jit.freeze) \
[torch.jit.optimize_for_inference](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html) \
torch.jit.optimized_execution() - Не задокументировано

### ONNX

https://pytorch.org/docs/stable/onnx.html

#### TorchScript

https://pytorch.org/docs/stable/onnx_torchscript.html

In [None]:
# Export the eager model; traced_model or scripted_model could be passed instead.
# The leading (batch) dimension of both the input and the output is marked dynamic.
torch.onnx.export(
    model,
    input,
    "models/torchscript.onnx",
    input_names = [ "batch" ],
    output_names = [ "scores" ],
    dynamic_axes = { name: { 0: 'batch_size' } for name in ("batch", "scores") },
)

#### TorchDynamo

https://pytorch.org/docs/stable/onnx_dynamo.html

In [None]:
# Experimental: this export did not seem to work for efficientnet_b3 when the notebook
# was written. Report a failure instead of raising, so "Run All" can still finish.
try:
    torchdynamo_model = torch.onnx.dynamo_export(
        model,
        input,
        export_options = torch.onnx.ExportOptions(dynamic_shapes = True)
    )
    torchdynamo_model.save("models/torchdynamo.onnx")
except Exception as error:
    print(f"TorchDynamo ONNX export failed: {type(error).__name__}: {error}")