<a href="https://colab.research.google.com/github/NickOsipov/notebooks/blob/main/comparison_torch_pickle_onnx.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Пожалуйста, делайте копию на свой Google Drive!

# Сравнение методов сериализации

In [1]:
!pip install onnx onnxruntime

Collecting onnx
  Downloading onnx-1.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.9 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.22.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.6 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnx-1.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.6/17.6 MB[0m [31m49.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnxruntime-1.22.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.5/16.5 MB[0m [31m67.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time
import onnx
import onnxruntime
import pickle
from tqdm import tqdm
from torch.utils.data import Subset
import random
import numpy as np
from scipy import stats
import warnings
warnings.filterwarnings("ignore")

In [3]:
class ComplexCNN(nn.Module):
    def __init__(self):
        super(ComplexCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [4]:
# Загрузка и подготовка данных (оставлено без изменений)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
train_set = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)

num_samples = 1000
indices = random.sample(range(len(train_set)), num_samples)
limited_set = Subset(train_set, indices)
train_loader = torch.utils.data.DataLoader(limited_set, batch_size=4, shuffle=True)

100%|██████████| 9.91M/9.91M [00:00<00:00, 35.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.38MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.95MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.2MB/s]


In [5]:
# Обучение модели
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ComplexCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
n_epoch = 2
sep = "-" * 60

for epoch in range(n_epoch):
    print(sep)
    print(f"Epoch: {epoch}")

    losses = []

    for data in tqdm(train_loader):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        losses.append(loss)

    epoch_loss = torch.mean(torch.tensor(losses))
    print(f"\nLoss: {epoch_loss}")
else:
    print(sep)
    print("Обучение завершено")

------------------------------------------------------------
Epoch: 0


100%|██████████| 250/250 [00:21<00:00, 11.54it/s]



Loss: 2.307921886444092
------------------------------------------------------------
Epoch: 1


100%|██████████| 250/250 [00:16<00:00, 14.96it/s]


Loss: 2.3026809692382812
------------------------------------------------------------
Обучение завершено





In [6]:
# Сохранение модели различными способами

# 1. PyTorch
torch.save(model.state_dict(), "cnn_pytorch.pth")
print("Модель сохранена в формате PyTorch")

# 2. Pickle
with open("cnn_pickle.pkl", "wb") as f:
    pickle.dump(model, f)
print("Модель сохранена в формате Pickle")

# 3. ONNX
dummy_input = torch.randn(1, 1, 28, 28).to(device)
torch.onnx.export(model, dummy_input, "cnn.onnx", verbose=True)
print("Модель сохранена в формате ONNX")

Модель сохранена в формате PyTorch
Модель сохранена в формате Pickle
Модель сохранена в формате ONNX


In [7]:
# Загрузка моделей

def print_load_time(start_time, model):
    print("--------------------------------------------------------")
    print(f"Модель: {model}")
    print(f"Время десереализации: {time.time() - start_time:.6f} секунд")

# 1. PyTorch
start_time = time.time()
pytorch_model = ComplexCNN().to(device)
pytorch_model.load_state_dict(torch.load("cnn_pytorch.pth"))
pytorch_model.eval()
print_load_time(start_time, "PyTorch")

# 2. Pickle
start_time = time.time()
with open("cnn_pickle.pkl", "rb") as f:
    pickle_model = pickle.load(f)
pickle_model.eval()
print_load_time(start_time, "Pickle")

# 3. ONNX
start_time = time.time()
onnx_model = onnx.load("cnn.onnx")
onnx.checker.check_model(onnx_model)
ort_session = onnxruntime.InferenceSession("cnn.onnx")
print_load_time(start_time, "ONNX")

print("\nВсе модели загружены")

--------------------------------------------------------
Модель: PyTorch
Время десереализации: 0.033961 секунд
--------------------------------------------------------
Модель: Pickle
Время десереализации: 0.011964 секунд
--------------------------------------------------------
Модель: ONNX
Время десереализации: 0.038799 секунд

Все модели загружены


In [8]:
def measure_inference_time(model_func, input_tensor, num_iterations=1000):
    times = []
    for _ in range(num_iterations):
        start_time = time.time()
        _ = model_func(input_tensor)
        end_time = time.time()
        times.append(end_time - start_time)
    return times

def bootstrap_analysis(times, num_bootstrap=1000, confidence=0.95):
    means = []
    for _ in range(num_bootstrap):
        sample = np.random.choice(times, size=len(times), replace=True)
        means.append(np.mean(sample))

    mean = np.mean(means)
    ci_lower, ci_upper = np.percentile(means, [(1-confidence)/2 * 100, (1+confidence)/2 * 100])
    return mean, ci_lower, ci_upper

In [9]:
# Подготовка входных данных
input_tensor = torch.randn(1, 1, 28, 28).to(device)
onnx_input = {ort_session.get_inputs()[0].name: input_tensor.cpu().numpy()}

# Измерение времени инференса
original_times = measure_inference_time(model, input_tensor)
pytorch_times = measure_inference_time(pytorch_model, input_tensor)
pickle_times = measure_inference_time(pickle_model, input_tensor)
onnx_times = measure_inference_time(lambda x: ort_session.run(None, onnx_input), onnx_input)

models = ['Original', 'PyTorch', 'Pickle', 'ONNX']
times_list = [original_times, pytorch_times, pickle_times, onnx_times]

In [10]:
max_name_length = max(len(name) for name in models)
for model_name, times_ in zip(models, times_list):
    print(f"Среднее время инференса {model_name:<{max_name_length}}: {np.mean(times_):.6f} секунд")

Среднее время инференса Original: 0.006360 секунд
Среднее время инференса PyTorch : 0.006606 секунд
Среднее время инференса Pickle  : 0.006873 секунд
Среднее время инференса ONNX    : 0.003747 секунд


In [11]:
# Выполнение bootstrap-анализа
sep_1 = "=" * 60
sep_2 = "-" * 60

print(sep_1)
print("Результаты bootstrap-анализа (95% доверительный интервал):")
print(sep_2)
for model_name, times_ in zip(models, times_list):
    mean, ci_lower, ci_upper = bootstrap_analysis(times_)
    print(f"{model_name:<{max_name_length}}: {mean:.6f} секунд ({ci_lower:.6f} - {ci_upper:.6f})")

print()
print(sep_1)
print("Сравнение производительности:")
print(sep_2)
for model_name, times_ in zip(models[1:], times_list[1:]):
    speedup = np.mean(original_times) / np.mean(times_)
    print(f"Ускорение {model_name:<{max_name_length}}: {speedup:.2f}x")

# Статистический тест (t-test) для сравнения с оригинальной моделью
print()
print(sep_1)
print("Статистическая значимость (p-value):")
print(sep_2)
for model_name, times_ in zip(models[1:], times_list[1:]):
    t_stat, p_value = stats.ttest_ind(original_times, times_)
    print(f"{model_name:<{max_name_length}} vs Original: p-value = {p_value}")

Результаты bootstrap-анализа (95% доверительный интервал):
------------------------------------------------------------
Original: 0.006360 секунд (0.006314 - 0.006411)
PyTorch : 0.006608 секунд (0.006549 - 0.006671)
Pickle  : 0.006884 секунд (0.006636 - 0.007149)
ONNX    : 0.003746 секунд (0.003705 - 0.003791)

Сравнение производительности:
------------------------------------------------------------
Ускорение PyTorch : 0.96x
Ускорение Pickle  : 0.93x
Ускорение ONNX    : 1.70x

Статистическая значимость (p-value):
------------------------------------------------------------
PyTorch  vs Original: p-value = 1.9029182946460803e-09
Pickle   vs Original: p-value = 0.00010891357510157145
ONNX     vs Original: p-value = 0.0
