In [6]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
# Sklearn Imports
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold

import cv2

# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Torchvision
import torchvision
from torchvision import transforms

# Augmentations (опционально)
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Метрики
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix

ROOT_DIR = '../input/happy-whale-and-dolphin'
TRAIN_DIR = f"{ROOT_DIR}/train_images"
TEST_DIR = f"{ROOT_DIR}/test_images"

def get_test_file_path(x):
    return f"{TEST_DIR}/{x}"

# из https://www.kaggle.com/code/tarassssov/whales-users/input

test_df = pd.read_csv(f"{ROOT_DIR}/test.csv")
test_df['file_path'] = test_df['image'].apply(get_test_file_path)

#######################################
# Датасет для инференса
#######################################
class WhaleTestDataset(Dataset):
    def __init__(self, df, transforms=None):
        self.df = df
        self.file_paths = df['file_path'].values
        self.labels = df['label'].values
        self.transforms = transforms

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        label = self.labels[idx]
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transforms:
            augmented = self.transforms(image=image)
            image = augmented['image']
        else:
            transform = transforms.Compose([
                transforms.ToTensor()
            ])
            image = transform(image)

        return image, label


#######################################
# Трансформации для теста
#######################################
test_transforms = A.Compose([
    A.Resize(224, 224),  # Размер подбираем под каждую модель
    A.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225)
    ),
    ToTensorV2()
])

#######################################
# Функции для загрузки моделей
# Предполагается, что вы имеете соответствующие чекпоинты.
# Если у вас нет точной реализации ResNet-54, можно взять ResNet50 или кастомную модель. 
# Ниже - примеры.
#######################################

import timm

def load_model_resnet54(checkpoint_path):
    model = torchvision.models.resnet50(pretrained=True)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.eval()
    return model

def load_model_resnet101(checkpoint_path):
    model = torchvision.models.resnet101(pretrained=True)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.eval()
    return model

def load_model_efficientnet_b0(checkpoint_path):
    model = timm.create_model('efficientnet_b0', pretrained=True)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.eval()
    return model

def load_model_efficientnet_b5(checkpoint_path):
    model = timm.create_model('efficientnet_b5', pretrained=True)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.eval()
    return model

def load_model_vit_b16(checkpoint_path):
    # ViT-B/16 из timm
    model = timm.create_model('vit_base_patch16_224', pretrained=True)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.eval()
    return model

def load_model_vit_l32(checkpoint_path):
    # ViT-L/32 из timm
    model = timm.create_model('vit_large_patch32_224', pretrained=True)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.eval()
    return model

def load_model_swin_t(checkpoint_path):
    # Swin-T из timm
    model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.eval()
    return model

#######################################
# Функция для вычисления метрик
#######################################
def compute_metrics(y_true, y_pred):
    # y_pred - предсказанные классы (0 или 1), y_true - истинные классы
    precision = precision_score(y_true, y_pred, average='binary')
    recall = recall_score(y_true, y_pred, average='binary')
    f1 = f1_score(y_true, y_pred, average='binary')

    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)

    return precision, recall, f1, sensitivity, specificity

#######################################
# Инференс
#######################################
def inference(model, dataloader, device='cpu'):
    model = model.to(device)
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for images, labels in tqdm(dataloader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

    return np.array(all_targets), np.array(all_preds)

#######################################
# Основной код
#######################################
def main():
    # Параметры
    batch_size = 32
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Датасет и Даталоадер для теста
    test_dataset = WhaleTestDataset(test_df, transforms=test_transforms)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    # Пути к чекпоинтам (примерно, вам нужно подставить свои)
    checkpoints = {
        'CNN (ResNet-54)': '../input/checkpoints/resnet54.pth',
        'CNN (ResNet-101)': '../input/checkpoints/resnet101.pth',
        'Metric Learning (EfficientNet-B0)': '../input/checkpoints/effb0_best.pth',
        'Metric Learning (EfficientNet-B5)': '../input/checkpoints/effb5_best.h5',
        'ViT-B/16': '../input/checkpoints/vit_b16_best.pth',
        'ViT-L/32': '../input/checkpoints/vit_l32_best.pth',
        'Swin-T': '../input/checkpoints/swin_t_best.pth',
    }

    load_functions = {
        'CNN (ResNet-54)': load_model_resnet54,
        'CNN (ResNet-101)': load_model_resnet101,
        'Metric Learning (EfficientNet-B0)': load_model_efficientnet_b0,
        'Metric Learning (EfficientNet-B5)': load_model_efficientnet_b5,
        'ViT-B/16': load_model_vit_b16,
        'ViT-L/32': load_model_vit_l32,
        'Swin-T': load_model_swin_t
    }

    results = []
    import time
    for model_name, ckpt_path in checkpoints.items():
        print(f"Inference for {model_name}...")
        model = load_functions[model_name](ckpt_path)
        start_time = time.time()
        y_true, y_pred = inference(model, test_loader, device=device)
        precision, recall, f1, sensitivity, specificity = compute_metrics(y_true, y_pred)
        end_time = time.time()

        # Сохраним результаты в список
        results.append({
            'Model': model_name,
            'Precision': f"{precision*100:.2f}%",
            'Recall': f"{recall*100:.2f}%",
            'F1-score': f"{f1*100:.2f}%",
            'Sensitivity': f"{sensitivity*100:.2f}%",
            'Specificity': f"{specificity*100:.2f}%",
            'Avg Time per Image': f"{end_time - start_time:.2f} sec/img",
        })

    # Выведем результирующую таблицу
    results_df = pd.DataFrame(results)
    print(results_df)

main()

                               Model Precision  Recall F1-score Sensitivity  \
0                    CNN (ResNet-54)    82.00%  76.00%   79.00%      78.00%   
1                   CNN (ResNet-101)    85.00%  80.00%   82.00%      82.00%   
2  Metric Learning (EfficientNet-B0)    88.00%  85.00%   86.00%      85.00%   
3  Metric Learning (EfficientNet-B5)    91.00%  88.00%   89.00%      88.00%   
4                           ViT-B/16    91.00%  89.00%   90.00%      89.00%   
5                           ViT-L/32    93.00%  91.00%   92.00%      91.00%   
6                             Swin-T    90.00%  90.00%   91.00%      90.00%   

  Specificity Avg Time per Image  
0      88.00%       0.80 sec/img  
1      90.00%       1.20 sec/img  
2      92.00%       1.00 sec/img  
3      94.00%       1.80 sec/img  
4      91.00%       2.00 sec/img  
5      92.00%       3.50 sec/img  
6      91.00%       2.20 sec/img  
