In [2]:
#!c1.4
import os
import zipfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.callbacks import EarlyStopping
from keras.layers import (
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    GlobalAveragePooling2D,
    MaxPool2D,
)
from keras.models import Model, Sequential
from keras.preprocessing import image
from keras.utils import to_categorical
from PIL import Image, ImageChops, ImageEnhance
from scipy import stats
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import ResNet50, Xception
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.regularizers import l2

import pandas as pd
import matplotlib.pyplot as plt

from data_prepare.dataset_tools import extract_zip_with_cleanup, prepare_and_save_data, create_data_generators
from data_prepare.plots import plot_history, confusion_matrix_plot, roc_plot, precision_recall_plot

In [17]:
models_dir = "models"

In [8]:
def init_data():
    image_archive_path = "data/celeb/v1/"
    fake_images_path, real_images_path = extract_zip_with_cleanup(image_archive_path)
    print(fake_images_path, real_images_path)
    train_dir, val_dir, test_dir = prepare_and_save_data(real_images_path, fake_images_path, output_dir="data/dataset/x")
    return create_data_generators(train_dir, val_dir, test_dir)

In [9]:
train_gen, val_gen, test_gen = init_data()

data/celeb/v1/fake data/celeb/v1/real
Found 9073 images belonging to 2 classes.
Found 1944 images belonging to 2 classes.
Found 1946 images belonging to 2 classes.

Class indices: {'fake': 0, 'real': 1}
Train samples: 9073
Val samples: 1944
Test samples: 1946


In [3]:
def build_xception_model(input_shape=(299, 299, 3), num_classes=1):
    """
    Строит модель на основе Xception для обнаружения дипфейков

    Параметры:
        input_shape: размер входного изображения (по умолчанию 299x299 для Xception)
        num_classes: 1 для бинарной классификации (sigmoid), 2 для softmax
    """

    base_model = Xception(
        weights="imagenet", include_top=False, input_shape=input_shape
    )

    base_model.trainable = False

    model = Sequential(
        [
            base_model,
            GlobalAveragePooling2D(),
            BatchNormalization(),
            Dense(256, activation="relu"),
            Dropout(0.5),
            Dense(num_classes, activation="sigmoid" if num_classes == 1 else "softmax"),
        ]
    )

    model.compile(
        optimizer=Adam(learning_rate=1e-4),
        loss="binary_crossentropy" if num_classes == 1 else "categorical_crossentropy",
        metrics=[
            tf.keras.metrics.Precision(name="precision"),
            tf.keras.metrics.Recall(name="recall"),
            tf.keras.metrics.AUC(name="auc"),
        ],
    )

    return model, base_model

In [15]:
def train_model(train_generator, val_generator, input_size=(224, 224), unfreeze_layers=0, initial_epochs=5, total_epochs=15):
    """Обучение с подбором LR и визуализацией"""
    model, base_model = build_xception_model(
        input_shape=(*input_size, 3)
    )

    model_path = os.path.join(models_dir, f'best_model_{input_size[0]}x{input_size[1]}_unfreeze{unfreeze_layers}.h5')
    callbacks = [
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.2,
            patience=3,
            min_lr=1e-7,
            verbose=1
        ),
        EarlyStopping(
            monitor='val_auc',
            patience=5,
            mode='max',
            restore_best_weights=True
        ),
        ModelCheckpoint(
            model_path,
            monitor='val_auc',
            save_best_only=True,
            mode='max'
        )
    ]
    
    history = model.fit(
        train_generator,
        epochs=initial_epochs,
        validation_data=val_generator,
        callbacks=callbacks
    )
    
    if not fine_tune:
        return model, history
    
    base_model.trainable = True
    
    for layer in base_model.layers[:int(len(base_model.layers)*unfreeze_layers)]:
        layer.trainable = False
    
    model.compile(
        optimizer=Adam(learning_rate=1e-5),
        loss="binary_crossentropy",
        metrics=[
            tf.keras.metrics.Precision(name="precision"),
            tf.keras.metrics.Recall(name="recall"),
            tf.keras.metrics.AUC(name="auc"),
        ]
    )
    
    fine_tune_history = model.fit(
        train_generator,
        epochs=initial_epochs + fine_tune_epochs,
        initial_epoch=history.epoch[-1] + 1,
        validation_data=val_generator,
        callbacks=[
            EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True)
        ]
    )
    
    return model, history

In [5]:
def evaluate_model(model, test_gen, input_size, unfreeze_layers, history):
    """Оценка и визуализация результатов с уникальными именами файлов"""
    # Создаем папку для результатов, если её нет
    os.makedirs("evaluation_plots", exist_ok=True)
    
    # 1. Генерируем уникальный префикс для файлов
    prefix = f"size_{input_size[0]}x{input_size[1]}_unfreeze_{unfreeze_layers}"
    
    # 2. Confusion Matrix
    y_pred = model.predict(test_gen)
    y_pred_classes = (y_pred > 0.5).astype(int)
    y_true = test_gen.labels
    
    cm = confusion_matrix(y_true, y_pred_classes)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Fake', 'Real'],
                yticklabels=['Fake', 'Real'])
    plt.title(f'Confusion Matrix\n{prefix}')
    plt.savefig(f'evaluation_plots/{prefix}_confusion_matrix.png', bbox_inches='tight')
    plt.close()
    
    
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(history.history['auc'], label='Train AUC')
    plt.plot(history.history['val_auc'], label='Val AUC')
    plt.title('AUC Curve')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'], label='Train Loss')
    plt.plot(history.history['val_loss'], label='Val Loss')
    plt.title('Loss Curve')
    plt.legend()
    
    plt.suptitle(f'Learning Curves - {prefix}')
    plt.savefig(f'evaluation_plots/{prefix}_learning_curves.png', bbox_inches='tight')
    plt.close()
    
    # 4. График LR (если использовался ReduceLROnPlateau)
    if 'lr' in history.history:
        plt.figure(figsize=(8, 4))
        plt.plot(history.history['lr'])
        plt.title('Learning Rate Schedule')
        plt.ylabel('Learning Rate')
        plt.xlabel('Epoch')
        plt.savefig(f'evaluation_plots/{prefix}_lr_schedule.png', bbox_inches='tight')
        plt.close()
    
    return {
        'accuracy': model.evaluate(test_gen, verbose=0)[1],
        'auc': model.evaluate(test_gen, verbose=0)[2]
    }

In [18]:
input_sizes = [(156, 159), (224, 224), (299, 299)]
unfreeze_options = [0, 30, 100]

results = []
for size in input_sizes:
    for unfreeze in unfreeze_options:
        print(f"\nTraining with size={size}, unfreeze={unfreeze}")
        
        # Адаптируем генераторы под новый размер
        train_gen.image_shape = size + (3,)
        val_gen.image_shape = size + (3,)
        
        model, history = train_model(
            train_gen,
            val_gen,
            input_size=size,
            unfreeze_layers=unfreeze
        )
        
        metrics = evaluate_model(
            model, 
            test_gen,
            input_size=size,
            unfreeze_layers=unfreeze,
            history=history
        )
        
        results.append({
            'input_size': f"{size[0]}x{size[1]}",
            'unfreeze_layers': unfreeze,
            'val_auc': max(history.history['val_auc']),
            'test_auc': metrics['auc'],
            'plots_folder': 'evaluation_plots'
        })


Training with size=(156, 159), unfreeze=0


  self._warn_if_super_not_called()


ValueError: could not broadcast input array from shape (299,299,3) into shape (156,159,3)