In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

# Seed PyTorch's global RNG so that weight initialisation, batch shuffling and
# the sampling noise are repeatable under "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# Load MNIST (downloaded to ./data on first use); ToTensor() turns every image
# into a float tensor scaled to [0, 1], which the BCE reconstruction loss needs.
train_dataset = MNIST(root='./data', train=True, transform=ToTensor(), download=True)
test_dataset = MNIST(root='./data', train=False, transform=ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Dimensionality of the latent space
latent_dim = 2

# Encoder
class Encoder(nn.Module):
    """Map a batch of images to the parameters of a diagonal Gaussian posterior.

    The input is flattened from dimension 1 onwards and must therefore hold
    784 values per sample (e.g. 1 x 28 x 28). A single 256-unit ReLU layer
    feeds two linear heads of width ``latent_dim``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=784, out_features=256)
        self.fc_mean = nn.Linear(in_features=256, out_features=latent_dim)
        self.fc_log_var = nn.Linear(in_features=256, out_features=latent_dim)

    def forward(self, x):
        """Return ``(z_mean, z_log_var)``, each of shape ``(batch, latent_dim)``."""
        hidden = nn.functional.relu(self.fc1(x.flatten(start_dim=1)))
        return self.fc_mean(hidden), self.fc_log_var(hidden)

# Decoder
class Decoder(nn.Module):
    """Map latent vectors back to images.

    A 256-unit ReLU layer is followed by a 784-unit sigmoid layer, so every
    output value lies in (0, 1); the result is reshaped to
    ``(batch, 1, 28, 28)``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=latent_dim, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=784)

    def forward(self, z):
        """Decode ``z`` (last dimension ``latent_dim``) into image tensors."""
        hidden = nn.functional.relu(self.fc1(z))
        pixels = self.fc2(hidden).sigmoid()
        return pixels.view(-1, 1, 28, 28)

# VAE model
class VAE(nn.Module):
    """Variational autoencoder: Encoder -> reparameterised sample -> Decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterize(self, z_mean, z_log_var):
        """Sample ``z = mean + std * eps`` with ``eps ~ N(0, I)``.

        ``z_log_var`` is the log-variance, hence ``std = exp(0.5 * z_log_var)``.
        Noise is drawn on every call, whatever the train/eval mode.
        """
        std = (0.5 * z_log_var).exp()
        noise = torch.randn_like(z_mean)
        return z_mean + std * noise

    def forward(self, x):
        """Return ``(x_recon, z_mean, z_log_var)`` for the input batch ``x``."""
        mean, log_var = self.encoder(x)
        latent = self.reparameterize(mean, log_var)
        return self.decoder(latent), mean, log_var

# Instantiate the model and the optimizer that updates all of its parameters
model = VAE()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

# VAE loss function
def vae_loss(x, x_recon, z_mean, z_log_var):
    """Return the summed VAE loss for one batch.

    The loss is the binary cross-entropy between ``x_recon`` and ``x`` plus
    the closed-form KL divergence between N(z_mean, exp(z_log_var)) and the
    standard normal prior. Both terms are summed (not averaged) over every
    element of the batch, so the result scales with batch size.
    """
    # Reconstruction term: x_recon holds probabilities, x the targets in [0, 1].
    bce = nn.functional.binary_cross_entropy(input=x_recon, target=x, reduction='sum')
    # KL term: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
    kl_terms = 1 + z_log_var - z_mean ** 2 - torch.exp(z_log_var)
    kld = -0.5 * kl_terms.sum()
    return bce + kld

# Train the VAE for one epoch
def train(model, optimizer, train_loader):
    """Run one optimisation pass over ``train_loader``.

    Puts ``model`` in training mode, takes one optimizer step per batch on
    ``vae_loss`` and returns the accumulated loss divided by the number of
    samples in ``train_loader.dataset``. Labels in each batch are ignored.
    """
    model.train()
    running_total = 0
    for images, _ in train_loader:
        optimizer.zero_grad()
        reconstruction, mu, log_var = model(images)
        batch_loss = vae_loss(images, reconstruction, mu, log_var)
        batch_loss.backward()
        optimizer.step()
        running_total += batch_loss.item()
    sample_count = len(train_loader.dataset)
    return running_total / sample_count

# Evaluate the VAE
def test(model, test_loader):
    """Return the average per-sample ``vae_loss`` over ``test_loader``.

    Puts ``model`` in eval mode and disables gradient tracking; no parameters
    are updated. Labels in each batch are ignored.
    """
    model.eval()
    with torch.no_grad():
        # model(batch) yields (x_recon, z_mean, z_log_var), which is exactly
        # the argument order vae_loss expects after the input batch.
        total = sum(
            vae_loss(batch, *model(batch)).item()
            for batch, _ in test_loader
        )
    return total / len(test_loader.dataset)

# Train for a fixed number of epochs, reporting the per-sample train and test loss
num_epochs = 10
for epoch in range(num_epochs):
    # Tuple elements are evaluated left to right: train first, then evaluate.
    train_loss, test_loss = train(model, optimizer, train_loader), test(model, test_loader)
    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')

# Generate 16 images by decoding latent vectors drawn from the standard-normal
# prior (no test data is involved here).
with torch.no_grad():
    samples = model.decoder(torch.randn(16, latent_dim)).cpu()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 17281744.03it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 484304.50it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4370380.66it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 8799320.45it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch 1/10, Train Loss: 196.7230, Test Loss: 175.2559
Epoch 2/10, Train Loss: 170.4104, Test Loss: 166.4851
Epoch 3/10, Train Loss: 165.2447, Test Loss: 163.1827
Epoch 4/10, Train Loss: 162.8319, Test Loss: 161.4415
Epoch 5/10, Train Loss: 161.1966, Test Loss: 160.1352
Epoch 6/10, Train Loss: 159.8952, Test Loss: 158.8164
Epoch 7/10, Train Loss: 158.7448, Test Loss: 157.8169
Epoch 8/10, Train Loss: 157.7564, Test Loss: 156.8630
Epoch 9/10, Train Loss: 156.7967, Test Loss: 156.0210
Epoch 10/10, Train Loss: 155.9720, Test Loss: 155.7152


In [6]:
!pip install onnx onnxruntime
import onnx
import onnxruntime
example_input = torch.randn(1, 1, 28, 28)

# Экспорт модели в формат ONNX
onnx_model_path = 'vae_model.onnx'
torch.onnx.export(model, example_input, onnx_model_path)

# Создание сессии onnxruntime
session = onnxruntime.InferenceSession(onnx_model_path)

# Получение имени входного и выходного тензоров модели
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Преобразование входных данных в формат, поддерживаемый onnxruntime
example_input_np = example_input.numpy()
example_input_ort = {input_name: example_input_np}

# Выполнение модели в onnxruntime
output = session.run([output_name], example_input_ort)

# Вывод результата
print(output)

[array([[[[0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
          0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
          0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
          0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
          0.00000000e+00, 8.94069672e-08, 0.00000000e+00,
          0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
          0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
          0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
          0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
          0.00000000e+00],
         [0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
          0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
          5.96046448e-08, 1.49011612e-07, 0.00000000e+00,
          0.00000000e+00, 1.19209290e-07, 2.98023224e-07,
          3.24845314e-06, 2.68220901e-06, 7.21216202e-06,
          4.02331352e-05, 1.67250633e-04, 3.42935324e-04,
          4.34786081e-04, 1.20112300e-03, 3.91572714e-04,
          1.24275684e-04, 2.29477882e-06, 0.

In [7]:
!pip install onnx2torch
from onnx2torch.converter import convert
torch_model = convert("vae_model.onnx")
output = torch_model(dummy_input)
print(output)

Collecting onnx2torch
  Downloading onnx2torch-1.5.14-py3-none-any.whl (80 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/80.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m80.1/80.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.8.0->onnx2torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.8.0->onnx2torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.8.0->onnx2torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.8.0->onnx2torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3

NotImplementedError: Converter is not implemented (OperationDescription(domain='', operation_type='RandomNormalLike', version=1))