In [1]:
%pip install tensorflow numpy matplotlib seaborn scikit-learn kagglehub

Note: you may need to restart the kernel to use updated packages.


## Импорты

In [2]:
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

import tensorflow as tf
import kagglehub
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

## Загрузка датасета

In [3]:
path = kagglehub.dataset_download("sanikamal/horses-or-humans-dataset")

base_dir = Path(path)

train_dir = base_dir / "horse-or-human" / "train"
validation_dir = base_dir / "horse-or-human" / "validation"

In [4]:
IMG_HEIGHT = 150
IMG_WIDTH = 150
BATCH_SIZE = 32

def create_generator(directory, augment=False, shuffle=True):
    if augment:
        datagen = ImageDataGenerator(
            rescale=1./255,
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            fill_mode='nearest'
        )
    else:
        datagen = ImageDataGenerator(rescale=1./255)
    
    generator = datagen.flow_from_directory(
        directory,
        target_size=(IMG_HEIGHT, IMG_WIDTH),
        batch_size=BATCH_SIZE,
        class_mode='binary',
        shuffle=shuffle
    )
    
    return generator

train_generator = create_generator(train_dir, augment=False, shuffle=True)
validation_generator = create_generator(validation_dir, augment=False, shuffle=False)

if validation_generator is None and train_generator is not None:
    
    all_images = []
    all_labels = []
    
    for class_dir in train_dir.iterdir():
        if class_dir.is_dir():
            class_name = class_dir.name
            class_label = 0 if 'horse' in class_name.lower() else 1
            
            for img_file in class_dir.glob("*.*"):
                if img_file.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']:
                    all_images.append(str(img_file))
                    all_labels.append(class_label)
    
    train_files, val_files, train_labels, val_labels = train_test_split(
        all_images, all_labels, test_size=0.2, random_state=42, stratify=all_labels
    )
    
    train_datagen = ImageDataGenerator(rescale=1./255)
    val_datagen = ImageDataGenerator(rescale=1./255)

    validation_generator = train_generator

if train_generator is not None:
    class_names = list(train_generator.class_indices.keys())


Found 1027 images belonging to 2 classes.
Found 256 images belonging to 2 classes.


## CNN

In [5]:
def create_cnn_model(input_shape=(150, 150, 3)):
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.3),
        layers.Dense(1, activation='sigmoid')
    ])
    
    return model

model = create_cnn_model()
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)

## Обучение модели

In [6]:
callbacks = [
    EarlyStopping(
        monitor='val_auc',
        patience=10,
        mode='max',
        restore_best_weights=True,
        verbose=1
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-7,
        verbose=1
    )
]

train_generator_augmented = create_generator(train_dir, augment=True, shuffle=True)


train_steps = max(1, train_generator_augmented.samples // BATCH_SIZE)
val_steps = max(1, validation_generator.samples // BATCH_SIZE) if validation_generator else 1

EPOCHS = 10

history = model.fit(
    train_generator_augmented,
    steps_per_epoch=train_steps,
    epochs=EPOCHS,
    validation_data=validation_generator,
    validation_steps=val_steps,
    callbacks=callbacks,
    verbose=1
)

Found 1027 images belonging to 2 classes.
Epoch 1/10
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m53s[0m 1s/step - accuracy: 0.7296 - auc: 0.8135 - loss: 0.5981 - val_accuracy: 0.5000 - val_auc: 0.9583 - val_loss: 0.7733 - learning_rate: 1.0000e-04
Epoch 2/10
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 74ms/step - accuracy: 0.8125 - auc: 0.9444 - loss: 0.3521 - val_accuracy: 0.5000 - val_auc: 0.9545 - val_loss: 0.7819 - learning_rate: 1.0000e-04
Epoch 3/10
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 1s/step - accuracy: 0.8714 - auc: 0.9469 - loss: 0.3038 - val_accuracy: 0.5000 - val_auc: 0.9012 - val_loss: 1.3372 - learning_rate: 1.0000e-04
Epoch 4/10
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 73ms/step - accuracy: 0.9062 - auc: 0.9792 - loss: 0.2716 - val_accuracy: 0.5000 - val_auc: 0.9092 - val_loss: 1.3418 - learning_rate: 1.0000e-04
Epoch 5/10
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 

## Результаты ROC-AUC

In [7]:
def evaluate_model(model, generator):
    
    y_pred_proba = model.predict(generator, verbose=0)
    
    y_true = generator.classes[:len(y_pred_proba)]
    
    roc_auc = roc_auc_score(y_true, y_pred_proba)
    
    print(f"ROC-AUC Score: {roc_auc:.4f}")

    return {
        'roc_auc': roc_auc,
    }

results = evaluate_model(model, validation_generator)

ROC-AUC Score: 0.9633
