In [1]:
import yaml
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os
import random
import numpy as np
import pandas as pd
import math
from sklearn.model_selection import StratifiedKFold
from torchvision import models, transforms
from transformers import ViTForImageClassification, ViTFeatureExtractor
from torch.utils.data import DataLoader, Dataset
from torch import nn, optim
from PIL import Image
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, cohen_kappa_score, confusion_matrix
import torchvision.transforms as transforms
from sklearn.metrics import classification_report, confusion_matrix, cohen_kappa_score, roc_auc_score
from sklearn.preprocessing import label_binarize

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Загрузка настроек
with open('parametrs.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Установка сидов
seed = config['experiment']['seed']
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)


# Загружаем модель и параметры
device = torch.device(config['training']['device'] if torch.cuda.is_available() else "cpu")
batch_size = config['training']['batch_size']
learning_rate = config['training']['learning_rate']
early_stopping_patience = config['training']['patience']

root_dir = 'D:/dataset/_eyepacs/data/ochishenii_fon_512_split/test'

In [3]:
def get_transform_pipeline(config, label=None):
    transform_list = []

    # Только resize — как базовая подготовка изображения
    transform_list.append(transforms.Resize(config['dataset']['image_size']))

    # Преобразование в тензор и нормализация
    transform_list.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean=config['dataset']['mean'], std=config['dataset']['std'])
    ])

    return transforms.Compose(transform_list)


In [4]:
# Feature Extractor для ViT
class FundusDataset(torch.utils.data.Dataset):
    def __init__(self, dataframe, root_dir, config):
        self.dataframe = dataframe
        self.root_dir = root_dir
        self.config = config

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        img_name = row['image_name'] + '.jpg'
        label = int(row['true_label'])
        img_path = os.path.join(self.root_dir, str(label), img_name)
        image = Image.open(img_path).convert("RGB")
        
        transform = get_transform_pipeline(self.config, label=label)
        image = transform(image)

        return image, label, img_name


In [5]:
test_csv = config['dataset']['test_csv']
test_df = pd.read_csv(test_csv)
test_dataset = FundusDataset(test_df, root_dir=root_dir, config=config)
test_loader = DataLoader(test_dataset, batch_size=config['testing']['batch_size'], shuffle=False)


In [6]:
def evaluate_model_on_loader(model, dataloader, device, set_name="Val", num_classes=3):
    model.eval()
    all_preds = []
    all_probs = []
    all_labels = []
    image_names = []

    with torch.no_grad():
        for batch in dataloader:
            if isinstance(batch, dict):
                inputs = batch['image'].to(device)
                labels = batch['label'].to(device)
                names = batch.get('image_name', None)

            elif isinstance(batch, (tuple, list)):
                if len(batch) == 3:
                    inputs, labels, names = batch
                    names = list(names)
                elif len(batch) == 2:
                    inputs, labels = batch
                    names = None
                else:
                    raise ValueError(f"Unexpected batch length: {len(batch)}")

                inputs = inputs.to(device)
                labels = labels.to(device)

            else:
                raise ValueError("Unsupported batch format")

            outputs = model(inputs)

            # Извлекаем логиты из объекта ImageClassifierOutput (если это объект)
            if hasattr(outputs, 'logits'):
                logits = outputs.logits
            else:
                logits = outputs  # если это уже тензор

            probs = torch.softmax(logits, dim=1).detach().cpu().numpy()
            preds = np.argmax(probs, axis=1)

            all_probs.extend(probs)
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            if names is not None:
                image_names.extend(names)

    all_probs = np.array(all_probs)
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    print(f" {set_name} Classification Report:")
    print(classification_report(all_labels, all_preds, digits=4))

    print(f" Confusion Matrix:")
    print(confusion_matrix(all_labels, all_preds))

    kappa = cohen_kappa_score(all_labels, all_preds)
    print(f" Cohen’s Kappa: {kappa:.4f}")

    # ➕ AUC-ROC
    try:
        y_true_bin = label_binarize(all_labels, classes=list(range(num_classes)))
        auc = roc_auc_score(y_true_bin, all_probs, multi_class='ovr', average='macro')
        print(f" AUC-ROC (OvR, macro): {auc:.4f}")
    except Exception as e:
        print(f" AUC-ROC calculation failed: {e}")

    return all_labels, all_preds, all_probs, image_names


In [7]:
from transformers import ViTConfig, ViTForImageClassification

checkpoint_path = "checkpoints/cosin/cosin_vit_fold1.pth"
state_dict_raw = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# Если есть префикс '_orig_mod.' — убираем его
state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict_raw.items()}

# Создаём конфиг вручную с нужным числом классов
model_config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=3)
model = ViTForImageClassification(model_config)

# Загружаем веса
model.load_state_dict(state_dict)
model.to(torch.device('cpu'))
model.eval()

# Предсказания
evaluate_model_on_loader(model, test_loader, device=torch.device('cpu'), set_name="Test", num_classes=3)


 Test Classification Report:
              precision    recall  f1-score   support

           0     0.8326    0.9200    0.8741      5162
           1     0.5880    0.4501    0.5099      1722
           2     0.5000    0.0141    0.0274       142

    accuracy                         0.7865      7026
   macro avg     0.6402    0.4614    0.4705      7026
weighted avg     0.7659    0.7865    0.7677      7026

 Confusion Matrix:
[[4749  412    1]
 [ 946  775    1]
 [   9  131    2]]
 Cohen’s Kappa: 0.4029
 AUC-ROC (OvR, macro): 0.8344


(array([0, 0, 0, ..., 1, 0, 1], dtype=int64),
 array([0, 0, 0, ..., 1, 0, 1], dtype=int64),
 array([[0.8402739 , 0.12244706, 0.03727905],
        [0.8478072 , 0.10750393, 0.04468885],
        [0.80165887, 0.16279371, 0.03554737],
        ...,
        [0.23995303, 0.70612144, 0.05392552],
        [0.8422144 , 0.12094434, 0.0368412 ],
        [0.20650461, 0.679858  , 0.11363735]], dtype=float32),
 ['34740_left.jpg',
  '10048_left.jpg',
  '32767_right.jpg',
  '20023_left.jpg',
  '35446_left.jpg',
  '7746_right.jpg',
  '31885_left.jpg',
  '23249_right.jpg',
  '22589_right.jpg',
  '23425_right.jpg',
  '8227_right.jpg',
  '40653_left.jpg',
  '20933_right.jpg',
  '29755_right.jpg',
  '16289_left.jpg',
  '31309_right.jpg',
  '21358_right.jpg',
  '31471_left.jpg',
  '30478_right.jpg',
  '43987_left.jpg',
  '26576_left.jpg',
  '10988_left.jpg',
  '37043_left.jpg',
  '23704_right.jpg',
  '2315_left.jpg',
  '1915_right.jpg',
  '43340_right.jpg',
  '9652_right.jpg',
  '23439_left.jpg',
  '11125_rig

In [8]:
from transformers import ViTConfig, ViTForImageClassification

checkpoint_path = "checkpoints/cosin/cosin_vit_fold2.pth"
state_dict_raw = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# Если есть префикс '_orig_mod.' — убираем его
state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict_raw.items()}

# Создаём конфиг вручную с нужным числом классов
model_config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=3)
model = ViTForImageClassification(model_config)

# Загружаем веса
model.load_state_dict(state_dict)
model.to(torch.device('cpu'))
model.eval()

# Предсказания
evaluate_model_on_loader(model, test_loader, device=torch.device('cpu'), set_name="Test", num_classes=3)


 Test Classification Report:
              precision    recall  f1-score   support

           0     0.8269    0.9285    0.8748      5162
           1     0.6007    0.4280    0.4998      1722
           2     0.6667    0.0141    0.0276       142

    accuracy                         0.7874      7026
   macro avg     0.6981    0.4569    0.4674      7026
weighted avg     0.7682    0.7874    0.7658      7026

 Confusion Matrix:
[[4793  368    1]
 [ 985  737    0]
 [  18  122    2]]
 Cohen’s Kappa: 0.3944
 AUC-ROC (OvR, macro): 0.8326


(array([0, 0, 0, ..., 1, 0, 1], dtype=int64),
 array([0, 0, 0, ..., 1, 0, 1], dtype=int64),
 array([[0.8538205 , 0.10696697, 0.03921252],
        [0.83731   , 0.11593693, 0.04675297],
        [0.7654312 , 0.19998421, 0.03458449],
        ...,
        [0.28190532, 0.6642169 , 0.05387776],
        [0.81754744, 0.14608225, 0.03637033],
        [0.1728047 , 0.71196765, 0.11522766]], dtype=float32),
 ['34740_left.jpg',
  '10048_left.jpg',
  '32767_right.jpg',
  '20023_left.jpg',
  '35446_left.jpg',
  '7746_right.jpg',
  '31885_left.jpg',
  '23249_right.jpg',
  '22589_right.jpg',
  '23425_right.jpg',
  '8227_right.jpg',
  '40653_left.jpg',
  '20933_right.jpg',
  '29755_right.jpg',
  '16289_left.jpg',
  '31309_right.jpg',
  '21358_right.jpg',
  '31471_left.jpg',
  '30478_right.jpg',
  '43987_left.jpg',
  '26576_left.jpg',
  '10988_left.jpg',
  '37043_left.jpg',
  '23704_right.jpg',
  '2315_left.jpg',
  '1915_right.jpg',
  '43340_right.jpg',
  '9652_right.jpg',
  '23439_left.jpg',
  '11125_rig

In [9]:
from transformers import ViTConfig, ViTForImageClassification

checkpoint_path = "checkpoints/cosin/cosin_vit_fold3.pth"
state_dict_raw = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# Если есть префикс '_orig_mod.' — убираем его
state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict_raw.items()}

# Создаём конфиг вручную с нужным числом классов
model_config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=3)
model = ViTForImageClassification(model_config)

# Загружаем веса
model.load_state_dict(state_dict)
model.to(torch.device('cpu'))
model.eval()

# Предсказания
evaluate_model_on_loader(model, test_loader, device=torch.device('cpu'), set_name="Test", num_classes=3)


 Test Classification Report:
              precision    recall  f1-score   support

           0     0.8260    0.9272    0.8737      5162
           1     0.5885    0.4210    0.4909      1722
           2     0.0000    0.0000    0.0000       142

    accuracy                         0.7844      7026
   macro avg     0.4715    0.4494    0.4548      7026
weighted avg     0.7511    0.7844    0.7622      7026

 Confusion Matrix:
[[4786  376    0]
 [ 997  725    0]
 [  11  131    0]]
 Cohen’s Kappa: 0.3859
 AUC-ROC (OvR, macro): 0.8331


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


(array([0, 0, 0, ..., 1, 0, 1], dtype=int64),
 array([0, 0, 0, ..., 1, 0, 1], dtype=int64),
 array([[0.8392315 , 0.11968234, 0.0410861 ],
        [0.8374068 , 0.11965587, 0.04293735],
        [0.8394407 , 0.12876819, 0.03179113],
        ...,
        [0.2734736 , 0.6583484 , 0.06817806],
        [0.8313258 , 0.13389555, 0.03477852],
        [0.2157574 , 0.6973206 , 0.08692208]], dtype=float32),
 ['34740_left.jpg',
  '10048_left.jpg',
  '32767_right.jpg',
  '20023_left.jpg',
  '35446_left.jpg',
  '7746_right.jpg',
  '31885_left.jpg',
  '23249_right.jpg',
  '22589_right.jpg',
  '23425_right.jpg',
  '8227_right.jpg',
  '40653_left.jpg',
  '20933_right.jpg',
  '29755_right.jpg',
  '16289_left.jpg',
  '31309_right.jpg',
  '21358_right.jpg',
  '31471_left.jpg',
  '30478_right.jpg',
  '43987_left.jpg',
  '26576_left.jpg',
  '10988_left.jpg',
  '37043_left.jpg',
  '23704_right.jpg',
  '2315_left.jpg',
  '1915_right.jpg',
  '43340_right.jpg',
  '9652_right.jpg',
  '23439_left.jpg',
  '11125_rig