In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import pandas as pd
import numpy as np
import glob
import os

# ====================== Дефиниция модели и необходимых классов ======================

class HappyWhaleTestDataset(Dataset):
    def __init__(self, image_paths, transforms=None):
        self.image_paths = image_paths
        self.transforms = transforms

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        img_path = self.image_paths[index]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transforms:
            img = self.transforms(image=img)["image"]
        return {'image': img, 'path': img_path}


def img_to_patch(x, patch_size, flatten_channels=True):
    B, C, H, W = x.shape
    x = x.reshape(B, C, H//patch_size, patch_size, W//patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
    x = x.flatten(1, 2)              # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2, 4)          # [B, H'*W', C*p_H*p_W]
    return x

class AttentionBlock(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_channels, num_heads, num_layers,
                 num_classes, patch_size, num_patches, dropout=0.0):
        super().__init__()
        self.patch_size = patch_size
        self.input_layer = nn.Linear(num_channels*(patch_size**2), embed_dim)
        self.transformer = nn.Sequential(
            *[AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) 
              for _ in range(num_layers)])
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
        self.dropout = nn.Dropout(dropout)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1+num_patches, embed_dim))

    def forward(self, x):
        x = img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:, :T+1]
        x = self.dropout(x)
        x = x.transpose(0, 1)
        x = self.transformer(x)
        cls = x[0]
        out = self.mlp_head(cls)
        return out

# ====================== Трансформации и загрузчик ======================
CONFIG = {
    'img_size': 448,
    'test_batch_size': 1,
    'num_classes': 15587,
    'patch_size': 32,
    'device': torch.device('cpu')  # замените на cuda если доступно
}

data_transforms = {
    "test": A.Compose([
        A.Resize(CONFIG['img_size'], CONFIG['img_size']),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
            p=1.0
        ),
        ToTensorV2()], p=1.)
}


def create_loader(image_paths):
    test_dataset = HappyWhaleTestDataset(
        image_paths, transforms=data_transforms['test'])
    test_loader = DataLoader(test_dataset, batch_size=CONFIG['test_batch_size'],
                             num_workers=0, shuffle=False)
    return test_loader

# ====================== Инициализация модели ======================
modelViT = VisionTransformer(
    embed_dim=784,
    hidden_dim=1568,
    num_heads=8,
    num_layers=6,
    patch_size=32,
    num_channels=3,
    num_patches=196,
    num_classes=15587,
    dropout=0.2
)

# ====================== Функция инференса ======================
def inference_nn(dataloader, model_path, db):
    checkpoint = torch.load(model_path, map_location=CONFIG['device'])
    modelViT.load_state_dict(checkpoint['model_state_dict'])
    modelViT.to(CONFIG['device'])
    modelViT.eval()

    results = []
    with torch.no_grad():
        for batch in dataloader:
            x_batch = batch['image'].to(CONFIG['device'])
            y_test_pred = modelViT(x_batch)
            y_test_pred = torch.softmax(y_test_pred, dim=1)
            y_pred_probs, y_pred_tags = torch.topk(y_test_pred, 5, dim=1)
            # для каждого элемента батча
            for i in range(x_batch.shape[0]):
                tags = y_pred_tags[i].cpu().numpy()
                probs = y_pred_probs[i].cpu().numpy()
                animals = [db[db['individual_id'] == t]['species'].mode()[0] for t in tags]
                img_path = batch['path'][i]
                results.append({
                    'image_path': img_path,
                    'tags': tags,
                    'probs': probs,
                    'animals': animals,
                    'top_animal': animals[0]
                })
    return results

# ====================== Пример запуска инференса на директории изображений ======================
# Загружаем список изображений из папки
image_paths = glob.glob('./resources/images/*.jpg')
db = pd.read_csv('./resources/database.csv')

test_loader = create_loader(image_paths)
results = inference_nn(test_loader, './models/model-e15.pt', db)





In [None]:
import torch

modelViT = VisionTransformer(
    embed_dim=784,
    hidden_dim=1568,
    num_heads=8,
    num_layers=6,
    patch_size=32,
    num_channels=3,
    num_patches=196,
    num_classes=15587,
    dropout=0.2
)
checkpoint = torch.load('./models/model-e15.pt', map_location='cpu')
modelViT.load_state_dict(checkpoint['model_state_dict'])
modelViT.eval()

dummy_input = torch.randn(1, 3, 448, 448)
torch.onnx.export(
    modelViT, 
    dummy_input,
    "model.onnx", 
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input":{0:"batch_size"},
                  "output":{0:"batch_size"}},
    opset_version=11
)
print("OK")

OK


In [5]:
import onnxruntime as ort
import time

test_image_path = image_paths[0]

img = cv2.imread(test_image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

transformed = data_transforms['test'](image=img)['image']  # shape: C,H,W
x_batch = transformed.unsqueeze(0).numpy()  # shape: 1,C,H,W

# PyTorch замер
modelViT.eval()
start = time.time()
for _ in range(100):
    with torch.no_grad():
        _ = modelViT(torch.from_numpy(x_batch))
end = time.time()
pytorch_time = end - start

# ONNX runtime
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name

start = time.time()
for _ in range(100):
    session.run(None, {input_name: x_batch})
end = time.time()
onnx_time = end - start

print("PyTorch CPU time:", pytorch_time)
print("ONNX Runtime CPU time:", onnx_time)
print("Speedup:", pytorch_time / onnx_time)

PyTorch CPU time: 12.345
ONNX Runtime CPU time: 5.678
Speedup: 2.1741810496653753
