In [3]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
import cv2

# Configuration
DATA_DIR = "screw_dataset"
SEED = 42
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

def advanced_preprocessing(image):
    """
    Prétraitement pour révéler les détails subtils
    """
    # Conversion en niveau de gris avec conservation des détails
    gray = cv2.cvtColor(image.numpy(), cv2.COLOR_RGB2GRAY)
    
    # Amélioration du contraste adaptatif
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    enhanced = clahe.apply(gray)
    
    # Détection de contours adaptative
    edges = cv2.Canny(enhanced, 50, 150)
    
    # Reconstruction de l'image avec mise en valeur des détails
    reconstructed = cv2.dilate(edges, None)
    reconstructed = cv2.normalize(reconstructed, None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
    
    return tf.convert_to_tensor(reconstructed, dtype=tf.float32)

def sophisticated_augmentation(image):
    """
    Augmentations simulant des défauts microscopiques
    """
    transforms = [
        # Variations subtiles de géométrie
        lambda img: tf.image.rot90(img, k=random.randint(0, 3)),
        
        # Simulations de rayures et imperfections
        lambda img: tf.image.adjust_contrast(img, random.uniform(0.8, 1.5)),
        
        # Variations locales de luminosité
        lambda img: tf.image.random_brightness(img, max_delta=0.1),
        
        # Simulation de variations de texture
        lambda img: tf.image.adjust_saturation(img, random.uniform(0.5, 1.5))
    ]
    
    # Combinaison aléatoire de transformations
    selected_transforms = random.sample(transforms, random.randint(1, 3))
    
    transformed = image
    for transform in selected_transforms:
        transformed = transform(transformed)
    
    return tf.clip_by_value(transformed, 0.0, 1.0)

def create_model(input_shape=(224, 224, 3)):
    """
    Modèle de détection de défauts basé sur un réseau convolutif profond
    """
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.BatchNormalization(),
        layers.MaxPooling2D(2, 2),
        
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(2, 2),
        
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(2, 2),
        
        layers.Conv2D(256, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(2, 2),
        
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.4),
        
        layers.Dense(1, activation='sigmoid')
    ])
    
    return model

def focal_loss(gamma=2., alpha=.25):
    """Focal loss pour gérer le déséquilibre des classes"""
    def focal_loss_fixed(y_true, y_pred):
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        return -tf.reduce_sum(alpha * tf.pow(1. - pt_1, gamma) * tf.math.log(pt_1)) \
               -tf.reduce_sum((1 - alpha) * tf.pow(pt_0, gamma) * tf.math.log(1. - pt_0))
    return focal_loss_fixed

# Chargement et préparation des données
def load_and_prepare_data():
    # Charger les images
    dataset = image_dataset_from_directory(
        DATA_DIR,
        labels='inferred',
        label_mode='binary',
        batch_size=BATCH_SIZE,
        image_size=IMG_SIZE,
        shuffle=False
    )

    # Extraction des images et labels
    all_images = []
    all_labels = []
    for images, labels in dataset:
        all_images.append(images.numpy())
        all_labels.append(labels.numpy())

    X = np.concatenate(all_images, axis=0) / 255.0  # Normalisation 
    y = np.concatenate(all_labels, axis=0)

    # Mélange des données
    indices = np.arange(X.shape[0])
    np.random.seed(SEED)
    np.random.shuffle(indices)
    X_shuffled = X[indices]
    y_shuffled = y[indices]

    # Division train/validation/test
    from sklearn.model_selection import train_test_split
    X_train, X_temp, y_train, y_temp = train_test_split(
        X_shuffled, y_shuffled, train_size=0.7, stratify=y_shuffled, random_state=SEED
    )

    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, train_size=0.5, stratify=y_temp, random_state=SEED
    )

    return X_train, y_train, X_val, y_val, X_test, y_test

# Augmentation de données
def data_augmentation(X_train, y_train):
    # Trouver les indices des images "bad" et "good"
    bad_indices = np.where(y_train[:, 0] == 0)[0]  # Assurez-vous d'indexer correctement si y_train est 2D
    good_indices = np.where(y_train[:, 0] == 1)[0]
    
    X_bad = X_train[bad_indices]
    X_good = X_train[good_indices]

    # Augmenter les images "bad"
    X_bad_augmented = []
    y_bad_augmented = []

    target_count = len(X_good)
    
    while len(X_bad_augmented) < target_count:
        # Choisir une image source
        img_source = X_bad[len(X_bad_augmented) % len(X_bad)]
        
        # Appliquer augmentation
        transformed_img = sophisticated_augmentation(tf.convert_to_tensor(img_source))
        
        X_bad_augmented.append(transformed_img.numpy())
        y_bad_augmented.append([0])  # Ajouter comme tableau 2D
    
    # Combinaison et mélange
    X_balanced = np.concatenate([X_good, X_bad_augmented])
    y_balanced = np.concatenate([y_train[good_indices], y_bad_augmented])
    
    indices = np.random.permutation(len(X_balanced))
    return X_balanced[indices], y_balanced[indices]

# Entraînement du modèle
def train_model():
    # Charger les données
    X_train, y_train, X_val, y_val, X_test, y_test = load_and_prepare_data()
    
    # Augmentation des données
    X_train_augmented, y_train_augmented = data_augmentation(X_train, y_train)
    
    # Création du modèle
    model = create_model()
    
    # Compilation
    model.compile(
        optimizer='adam',
        loss=focal_loss(),
        metrics=['accuracy', 'precision', 'recall']
    )
    
    # Callbacks
    callbacks = [
        EarlyStopping(patience=15, restore_best_weights=True),
        ReduceLROnPlateau(patience=10, factor=0.5),
        ModelCheckpoint('best_screw_model.h5', save_best_only=True)
    ]
    
    # Entraînement
    history = model.fit(
        X_train_augmented, y_train_augmented,
        validation_data=(X_val, y_val),
        epochs=50,
        batch_size=BATCH_SIZE,
        callbacks=callbacks,
        class_weight={0: 5.0, 1: 1.0}
    )
    
    # Évaluation
    test_loss, test_accuracy, test_precision, test_recall = model.evaluate(X_test, y_test)
    print("\nRésultats sur le jeu de test:")
    print(f"Perte: {test_loss}")
    print(f"Accuracy: {test_accuracy}")
    print(f"Précision: {test_precision}")
    print(f"Recall: {test_recall}")
    
    # Visualisation des métriques
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(history.history['accuracy'], label='Accuracy Entraînement')
    plt.plot(history.history['val_accuracy'], label='Accuracy Validation')
    plt.title('Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'], label='Perte Entraînement')
    plt.plot(history.history['val_loss'], label='Perte Validation')
    plt.title('Perte')
    plt.legend()

    plt.tight_layout()
    plt.show()
    
    return model, history

# Exécution
if __name__ == "__main__":
    model, history = train_model()

Found 1152 files belonging to 2 classes.


2025-03-19 22:32:19.524897: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 1/50


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m38/38[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 486ms/step - accuracy: 0.5073 - loss: nan - precision: 0.6117 - recall: 0.1439 - val_accuracy: 0.2486 - val_loss: nan - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010
Epoch 2/50
[1m38/38[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 482ms/step - accuracy: 0.4997 - loss: nan - precision: 0.0000e+00 - recall: 0.0000e+00 - val_accuracy: 0.2486 - val_loss: nan - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010
Epoch 3/50
[1m38/38[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 472ms/step - accuracy: 0.4896 - loss: nan - precision: 0.0000e+00 - recall: 0.0000e+00 - val_accuracy: 0.2486 - val_loss: nan - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010
Epoch 4/50
[1m 5/38[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m15s[0m 483ms/step - accuracy: 0.4795 - loss: nan - precision: 0.0000e+00 - recall: 0.0000e+00

KeyboardInterrupt: 