In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import seaborn as sns
import math
import cv2
from PIL import Image, ImageDraw, ImageFont
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
import warnings
warnings.filterwarnings('ignore')

try:
    from ultralytics import YOLO
    YOLO_AVAILABLE = True
except ImportError:
    print("Warning: ultralytics not available. Install with: pip install ultralytics")
    YOLO_AVAILABLE = False

img_size = 128
batch_size = 32
epochs = 50
# flow_from_directory assigns class indices in sorted sub-directory order; this list is
# kept alphabetical so that index i of a one-hot label corresponds to class_names[i].
class_names = ['crazing', 'inclusion', 'patches', 'pitted_surface', 'rolled-in_scale', 'scratches']
num_classes = len(class_names)
input_shape = (img_size, img_size, 3)

# Dataset / output locations. The defaults are the original machine-specific paths;
# set STEEL_DEFECT_IMAGES_DIR / STEEL_DEFECT_OUTPUT_DIR to run elsewhere without
# editing the notebook. images_dir is expected to contain one sub-directory per class.
images_dir = os.environ.get(
    "STEEL_DEFECT_IMAGES_DIR",
    r"C:\Users\anmol\OneDrive\Desktop\Steel_Surface_Defect\images",
)

output_dir = os.environ.get(
    "STEEL_DEFECT_OUTPUT_DIR",
    r"C:\Users\anmol\OneDrive\Desktop\Steel_Surface_Defect\output_images",
)
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory created: {output_dir}")

# One fixed hex colour per defect class, shared by the bounding-box outlines, the figure
# legends and the bar charts so that a class looks the same in every plot.
defect_colors = dict([
    ('crazing', '#FF6B6B'),
    ('inclusion', '#4ECDC4'),
    ('patches', '#45B7D1'),
    ('pitted_surface', '#96CEB4'),
    ('rolled-in_scale', '#FFEAA7'),
    ('scratches', '#DDA0DD'),
])

class YOLOv8PreDetector:
    """Optional YOLOv8 region proposer used to draw boxes on classified images.

    The detector is best-effort. If ultralytics is not importable, the weights fail
    to load, or inference raises, :meth:`detect_regions` falls back to a single region
    covering the whole image, so downstream plotting always has something to draw.

    NOTE(review): without ``model_path`` this loads the stock ultralytics ``yolov8n.pt``
    checkpoint. Per the ultralytics docs, that checkpoint is pretrained on COCO, not on
    steel-surface defects, so its boxes and class ids are unlikely to mark defects
    unless a fine-tuned checkpoint is supplied.
    """

    def __init__(self, model_path=None):
        """Load weights from *model_path* if that file exists, otherwise ``yolov8n.pt``.

        ``self.model`` stays ``None`` when ultralytics is unavailable or loading fails.
        """
        self.model = None
        if not YOLO_AVAILABLE:
            return
        use_custom = bool(model_path) and os.path.exists(model_path)
        if model_path and not use_custom:
            # Make the fallback to the generic checkpoint visible instead of silent.
            print(f"Warning: YOLO weights not found at {model_path}; using yolov8n.pt")
        try:
            self.model = YOLO(model_path if use_custom else 'yolov8n.pt')
            print("YOLOv8 model loaded successfully")
        except Exception as e:
            # Best-effort: keep running with the full-image fallback region.
            print(f"Error loading YOLOv8: {e}")

    @staticmethod
    def _full_image_region(image):
        """Return the fallback region list: one box spanning the whole image."""
        h, w = image.shape[:2]
        return [{'bbox': [0, 0, w, h], 'confidence': 1.0, 'class': 'unknown'}]

    def detect_regions(self, image, confidence=0.3):
        """Return candidate regions in *image* as a non-empty list of dicts.

        Args:
            image: array whose first two dimensions are (height, width). The caller in
                this notebook passes uint8 copies of the generator images.
            confidence: minimum YOLO confidence for a box to be kept.

        Returns:
            ``[{'bbox': [x, y, w, h], 'confidence': float, 'class': ...}, ...]``

            The bbox uses integer pixel coordinates: top-left corner, then width and
            height. 'class' is the YOLO class index, or the string 'unknown' for the
            full-image fallback. The fallback is returned when no model is loaded,
            when no box passes *confidence*, or when inference fails.
        """
        if self.model is None:
            return self._full_image_region(image)
        try:
            # verbose=False: ultralytics otherwise logs a line for every call, and this is
            # called once per validation image.
            results = self.model(image, conf=confidence, verbose=False)
            regions = []
            for result in results:
                if result.boxes is None:
                    continue
                for box in result.boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    cls = int(box.cls[0].cpu().numpy()) if len(box.cls) > 0 else 0
                    regions.append({
                        'bbox': [int(x1), int(y1), int(x2 - x1), int(y2 - y1)],
                        'confidence': float(box.conf[0].cpu().numpy()),
                        'class': cls,
                    })
        except Exception as e:
            print(f"Error in YOLOv8 detection: {e}")
            return self._full_image_region(image)
        return regions or self._full_image_region(image)

# Region proposer used only for drawing boxes in the visualizations. With no model_path
# this loads yolov8n.pt; ultralytics may download it on first use.
yolo_detector = YOLOv8PreDetector()

# Pixels are rescaled to [0, 1]; 20% of the images are held out as the 'validation' subset.
datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

train_generator = datagen.flow_from_directory(
    images_dir,
    target_size=(img_size, img_size),
    batch_size=batch_size,
    class_mode='categorical',
    # Pin the label order to class_names so that one-hot index i always means
    # class_names[i], instead of relying on the inferred directory ordering.
    classes=class_names,
    subset='training',
    shuffle=True,
    seed=42
)

val_generator = datagen.flow_from_directory(
    images_dir,
    target_size=(img_size, img_size),
    batch_size=batch_size,
    class_mode='categorical',
    classes=class_names,
    subset='validation',
    # Fixed order, so repeated passes (evaluate / prediction collection) see the same images.
    shuffle=False,
    seed=42
)

# ceil rather than floor: floor division skips the final partial batch, so any samples
# beyond a multiple of batch_size would never be trained on or evaluated.
steps_per_epoch = math.ceil(train_generator.samples / batch_size)
validation_steps = math.ceil(val_generator.samples / batch_size)

def seam_module(input_tensor, filters, reduction=8):
    """SEAM block: multi-dilation feature fusion followed by channel and spatial attention.

    Steps:
      1. Four parallel 3x3 convolutions with dilation rates 1-4 are concatenated and
         fused back to ``filters`` channels by a 3x3 convolution.
      2. Squeeze-and-excitation style channel attention: global average pool, then
         Dense(filters // reduction, relu), then Dense(filters, sigmoid). The result
         rescales the channels.
      3. Spatial attention: the per-pixel mean and max over channels go through a 7x7
         conv with a sigmoid. The result rescales every pixel.

    Args:
        input_tensor: 4-D feature map. The reductions over axis -1 assume
            channels-last (batch, H, W, C) — confirm if the data format ever changes.
        filters: number of output channels.
        reduction: bottleneck ratio of the channel-attention MLP. The default of 8 is
            the value the module originally hard-coded.

    Returns:
        Tensor with the spatial size of *input_tensor* (all convs use 'same' padding,
        stride 1) and ``filters`` channels.
    """
    dilated = [
        layers.Conv2D(filters, (3,3), dilation_rate=rate, padding='same', activation='relu')(input_tensor)
        for rate in (1, 2, 3, 4)
    ]
    concat = layers.Concatenate()(dilated)
    conv_fused = layers.Conv2D(filters, (3,3), padding='same', activation='relu')(concat)

    # Channel attention. max(1, ...) guards against filters < reduction, which would
    # otherwise request a zero-unit Dense bottleneck.
    gap = layers.GlobalAveragePooling2D()(conv_fused)
    dense_1 = layers.Dense(max(1, filters // reduction), activation='relu')(gap)
    dense_2 = layers.Dense(filters, activation='sigmoid')(dense_1)
    channel_attention = layers.Multiply()([conv_fused, layers.Reshape((1, 1, filters))(dense_2)])

    # Spatial attention from channel-wise mean and max maps.
    avg_pool = layers.Lambda(lambda x: tf.reduce_mean(x, axis=-1, keepdims=True))(channel_attention)
    max_pool = layers.Lambda(lambda x: tf.reduce_max(x, axis=-1, keepdims=True))(channel_attention)
    concat_spatial = layers.Concatenate(axis=-1)([avg_pool, max_pool])
    spatial_attention = layers.Conv2D(1, (7,7), padding='same', activation='sigmoid')(concat_spatial)
    spatial_out = layers.Multiply()([channel_attention, spatial_attention])
    return spatial_out

def ceam_module(current, previous, filters):
    """CEAM block: gate a resized earlier feature map with a mask derived from the current one.

    *previous* is bilinearly resized to the spatial size of *current* and projected to
    ``filters`` channels by a 1x1 conv. It is then multiplied element-wise by a sigmoid
    gate that a 3x3 conv computes from *current*.

    Args:
        current: feature map whose spatial size (H, W) must be statically known.
        previous: earlier (typically higher-resolution) feature map.
        filters: number of output channels.

    Returns:
        Tensor with the spatial size of *current* and ``filters`` channels.
    """
    # .shape works on Keras 2 and Keras 3 tensors; tf.keras.backend.int_shape is a legacy API.
    target_h, target_w = current.shape[1:3]
    # The built-in Resizing layer (bilinear, like tf.image.resize's default) replaces a
    # Lambda around tf.image.resize. It serializes as an ordinary layer config instead of
    # as Python bytecode with a captured closure.
    prev_resized = layers.Resizing(target_h, target_w)(previous)
    prev_resized = layers.Conv2D(filters, (1,1), padding='same')(prev_resized)
    guided = layers.Conv2D(filters, (3,3), padding='same', activation='sigmoid')(current)
    modulated = layers.Multiply()([prev_resized, guided])
    return modulated

def amff_block(current_input, prev_input, filters):
    """AMFF fusion: element-wise sum of three branches at the resolution of *current_input*.

    The branches are the SEAM attention output, the CEAM output (*prev_input* gated by
    *current_input*), and a 1x1 projection of *current_input* acting as a residual path.
    Each branch produces ``filters`` channels.
    """
    branches = [
        seam_module(current_input, filters),
        ceam_module(current_input, prev_input, filters),
        layers.Conv2D(filters, (1, 1), padding='same')(current_input),
    ]
    return layers.Add()(branches)

def build_amff_cnn(input_shape=(128, 128, 3), num_classes=6):
    """Build and compile the AMFF-CNN classifier (Adam, categorical cross-entropy, accuracy).

    Shapes in the comments assume the default 128x128x3 input.
    """
    inputs = layers.Input(shape=input_shape)

    # Stem stage 1 -> 64x64x32 (also the cross-level input of the first AMFF block).
    stage1 = layers.Conv2D(32, (3,3), padding='same', activation='relu')(inputs)
    stage1 = layers.MaxPooling2D()(stage1)

    # Stem stage 2 -> 32x32x64.
    stage2 = layers.Conv2D(64, (3,3), padding='same', activation='relu')(stage1)
    stage2 = layers.MaxPooling2D()(stage2)

    # AMFF at 32x32 (guided by stage1), pool to 16x16, then AMFF with 128 filters (guided by stage2).
    fused = amff_block(stage2, stage1, 64)
    fused = layers.MaxPooling2D()(fused)
    fused = amff_block(fused, stage2, 128)

    # Classification head.
    head = layers.GlobalAveragePooling2D()(fused)
    head = layers.Dense(128, activation='relu')(head)
    head = layers.Dropout(0.5)(head)
    outputs = layers.Dense(num_classes, activation='softmax')(head)

    model = models.Model(inputs, outputs)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

def build_base_cnn(input_shape=(128, 128, 3), num_classes=6):
    """Build and compile the baseline CNN: two conv+pool stages and the same head as AMFF-CNN."""
    inputs = layers.Input(shape=input_shape)
    features = inputs
    for n_filters in (32, 64):
        features = layers.Conv2D(n_filters, (3,3), activation='relu', padding='same')(features)
        features = layers.MaxPooling2D()(features)

    head = layers.GlobalAveragePooling2D()(features)
    head = layers.Dense(128, activation='relu')(head)
    head = layers.Dropout(0.5)(head)
    outputs = layers.Dense(num_classes, activation='softmax')(head)

    model = models.Model(inputs, outputs)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

def get_enhanced_predictions_with_yolo(model, generator, class_names, yolo_detector=None,
                                       max_samples=500):
    """Run *model* over *generator* and collect one prediction record per image.

    Args:
        model: object with ``predict(batch, verbose=0)`` that returns per-class scores.
        generator: indexable batch source with ``samples``, ``reset()`` and ``len()``.
            Indexing returns ``(images, one_hot_labels)`` with images scaled to [0, 1].
        class_names: names indexed by class id.
        yolo_detector: optional object with ``detect_regions(uint8_image)``. When it is
            None, every record gets an empty 'regions' list.
        max_samples: upper bound on the number of records (default 500).

    Returns:
        List of at most ``min(generator.samples, max_samples)`` dicts in generator order,
        with keys 'image', 'true_label', 'pred_label', 'confidence' (the maximum predicted
        score) and 'regions'.
    """
    generator.reset()
    limit = min(generator.samples, max_samples)
    predictions = []
    # Walk every batch until the limit is reached; no fixed batch cap, so small
    # batch sizes still yield up to `limit` records.
    for batch_idx in range(len(generator)):
        remaining = limit - len(predictions)
        if remaining <= 0:
            break
        images, labels = generator[batch_idx]
        images, labels = images[:remaining], labels[:remaining]
        # One predict() call per batch instead of one per image.
        scores = model.predict(images, verbose=0)
        for img, label, score in zip(images, labels, scores):
            regions = []
            if yolo_detector is not None:
                regions = yolo_detector.detect_regions((img * 255).astype(np.uint8))
            predictions.append({
                'image': img,
                'true_label': class_names[np.argmax(label)],
                'pred_label': class_names[np.argmax(score)],
                'confidence': np.max(score),
                'regions': regions,
            })
    return predictions

def _draw_prediction_tile(ax, pred_data):
    """Draw one prediction record on *ax*: the image, up to 3 region boxes, a title and a GT tag."""
    img = pred_data['image']
    true_label = pred_data['true_label']
    pred_label = pred_data['pred_label']
    confidence = pred_data['confidence']

    ax.imshow(img)

    color = defect_colors.get(pred_label, '#FF0000')
    # Region boxes come from detect_regions run on this same image array, so they are
    # already in its pixel coordinates. No rescaling is needed; a fixed /128 factor
    # would misplace boxes for any other img_size.
    for region in (pred_data['regions'] or [])[:3]:
        x, y, w, h = region['bbox']
        ax.add_patch(Rectangle((x, y), w, h, linewidth=2, edgecolor=color, facecolor='none'))

    # Green title for a correct prediction, red otherwise; ground truth goes in a tag at the bottom.
    title_color = 'green' if true_label == pred_label else 'red'
    ax.set_title(f"{pred_label}\n{confidence:.2f}", fontsize=10, color=title_color, fontweight='bold')
    ax.text(2, img.shape[0]-5, f"GT: {true_label}",
            fontsize=8, color='white', fontweight='bold',
            bbox=dict(boxstyle="round,pad=0.3", facecolor='black', alpha=0.7))
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def create_paper_style_visualization(predictions, model_name, batch_size=25, display_limit=10):
    """Render prediction records as grids of annotated thumbnails, saving one PNG per grid.

    Args:
        predictions: records as returned by get_enhanced_predictions_with_yolo (keys
            'image', 'true_label', 'pred_label', 'confidence', 'regions').
        model_name: used in the figure title and the output filename.
        batch_size: thumbnails per figure. The grid is 5x5 up to 25 tiles and grows as a
            square grid beyond that.
        display_limit: only tested for ``> 0``. When positive, the first figure is shown
            inline and every other figure is closed after saving. It does not limit how
            many thumbnails are drawn.

    Side effects:
        Writes ``{output_dir}/{model_name}_batch_{k}.png`` for k = 1..ceil(n / batch_size).
    """
    total_images = len(predictions)
    batches = math.ceil(total_images / batch_size)
    # A fixed 5x5 grid made plt.subplot raise once batch_size exceeded 25.
    rows = cols = max(5, math.ceil(math.sqrt(batch_size)))

    for b in range(batches):
        batch_predictions = predictions[b * batch_size:(b + 1) * batch_size]

        fig = plt.figure(figsize=(20, 16))
        fig.patch.set_facecolor('white')
        for i, pred_data in enumerate(batch_predictions):
            _draw_prediction_tile(plt.subplot(rows, cols, i + 1), pred_data)

        plt.suptitle(f'{model_name} Predictions - Batch {b+1}/{batches}',
                     fontsize=16, fontweight='bold', y=0.95)
        legend_elements = [patches.Patch(color=color, label=defect)
                           for defect, color in defect_colors.items()]
        plt.figlegend(handles=legend_elements, loc='lower center',
                      ncol=len(class_names), fontsize=10,
                      bbox_to_anchor=(0.5, 0.02))
        plt.tight_layout()
        plt.subplots_adjust(top=0.90, bottom=0.1)

        save_path = os.path.join(output_dir, f'{model_name}_batch_{b+1}.png')
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")

        if b == 0 and display_limit > 0:
            plt.show()
        else:
            plt.close(fig)

def create_area_wise_analysis(predictions, class_names, model_name):
    """Bar-chart, per predicted class, the number of records and their mean confidence.

    Despite the name, no region areas are involved: each record contributes only its
    'pred_label' and 'confidence'. Bars follow the order of *class_names* and are
    coloured from ``defect_colors``. A record whose pred_label is not in *class_names*
    raises KeyError. The figure is saved to ``{output_dir}/{model_name}_area_analysis.png``
    and then shown.
    """
    count_of = dict.fromkeys(class_names, 0)
    conf_total_of = dict.fromkeys(class_names, 0)
    for record in predictions:
        label = record['pred_label']
        count_of[label] += 1
        conf_total_of[label] += record['confidence']

    classes = list(count_of)
    counts = [count_of[c] for c in classes]
    # Classes that were never predicted get a mean confidence of 0 (and no bar label).
    avg_confidences = [conf_total_of[c] / count_of[c] if count_of[c] > 0 else 0 for c in classes]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    colors = [defect_colors[c] for c in classes]

    # (axis, values, title, y-label, y-limits, label offset above bar, label formatter)
    panels = (
        (ax1, counts, 'Defect Distribution by Type', 'Number of Detections', None, 0.5, str),
        (ax2, avg_confidences, 'Average Confidence by Defect Type', 'Average Confidence', (0, 1), 0.02,
         lambda v: f'{v:.3f}'),
    )
    for ax, values, title, ylabel, ylim, pad, fmt in panels:
        bars = ax.bar(classes, values, color=colors, alpha=0.8)
        ax.set_title(title, fontweight='bold')
        ax.set_ylabel(ylabel)
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.tick_params(axis='x', rotation=45)
        for bar, value in zip(bars, values):
            if value > 0:
                ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + pad,
                        fmt(value), ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()

    save_path = os.path.join(output_dir, f'{model_name}_area_analysis.png')
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"Saved: {save_path}")
    plt.show()

def _train_and_evaluate(builder):
    """Build a model with *builder*, fit it on train_generator, then evaluate it on val_generator.

    Returns (model, history, val_loss, val_accuracy).
    """
    model = builder(input_shape=input_shape, num_classes=num_classes)
    history = model.fit(train_generator, steps_per_epoch=steps_per_epoch,
                        validation_data=val_generator, validation_steps=validation_steps,
                        epochs=epochs, verbose=1)
    loss, acc = model.evaluate(val_generator, steps=validation_steps, verbose=0)
    return model, history, loss, acc


print("Training AMFF-CNN...")
amff_model, amff_history, amff_loss, amff_acc = _train_and_evaluate(build_amff_cnn)

print("Training Base CNN...")
base_model, base_history, base_loss, base_acc = _train_and_evaluate(build_base_cnn)

# Collect prediction records (image, labels, confidence, YOLO regions) for both models
# on the validation split.
print("Generating AMFF-CNN predictions ...")
amff_predictions = get_enhanced_predictions_with_yolo(
    model=amff_model, generator=val_generator,
    class_names=class_names, yolo_detector=yolo_detector,
)

print("Generating Base CNN predictions...")
base_predictions = get_enhanced_predictions_with_yolo(
    model=base_model, generator=val_generator,
    class_names=class_names, yolo_detector=yolo_detector,
)

# NOTE(review): the messages below say "first 10 images", but display_limit is only tested
# for > 0; the whole first grid (up to batch_size tiles) is displayed, and every grid is saved.
print("Creating AMFF-CNN visualization (displaying first 10 images only)...")
create_paper_style_visualization(predictions=amff_predictions, model_name="AMFF-CNN",
                                 batch_size=25, display_limit=10)

print("Creating Base CNN visualization (displaying first 10 images only)...")
create_paper_style_visualization(predictions=base_predictions, model_name="Base-CNN",
                                 batch_size=25, display_limit=10)

print("Creating area-wise defect analysis...")
for _records, _name in ((amff_predictions, "AMFF-CNN"), (base_predictions, "Base-CNN")):
    create_area_wise_analysis(predictions=_records, class_names=class_names, model_name=_name)

print("\n=== Performance Comparison ===")
print(f"AMFF-CNN Accuracy: {amff_acc:.4f}")
print(f"Base CNN Accuracy: {base_acc:.4f}")
print(f"Improvement: {amff_acc - base_acc:.4f}")

# One panel per metric. Each panel overlays train (solid) and validation (dashed) curves
# for both models: AMFF-CNN in blue, Base CNN in red.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
_curve_sources = (('AMFF-CNN', amff_history, 'blue'), ('Base CNN', base_history, 'red'))
_panels = (
    (ax1, 'accuracy', 'Model Accuracy Comparison', 'Accuracy'),
    (ax2, 'loss', 'Model Loss Comparison', 'Loss'),
)
for _ax, _metric, _title, _ylabel in _panels:
    for _label, _hist, _color in _curve_sources:
        _ax.plot(_hist.history[_metric], label=f'{_label} Train', color=_color)
        _ax.plot(_hist.history[f'val_{_metric}'], label=f'{_label} Val', color=_color, linestyle='--')
    _ax.set_title(_title)
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_ylabel)
    _ax.legend()
    _ax.grid(True, alpha=0.3)

plt.tight_layout()

save_path = os.path.join(output_dir, 'model_comparison.png')
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"Saved: {save_path}")
plt.show()


# Persist both trained models in the legacy HDF5 (.h5) format inside output_dir.
amff_model_path_h5 = os.path.join(output_dir, "amff_cnn_final.h5")
base_model_path_h5 = os.path.join(output_dir, "base_cnn_final.h5")
for _label, _model, _path in (("AMFF-CNN", amff_model, amff_model_path_h5),
                              ("Base-CNN", base_model, base_model_path_h5)):
    _model.save(_path)
    print(f" {_label} model saved at: {_path}")


Output directory created: C:\Users\anmol\OneDrive\Desktop\Steel_Surface_Defect\output_images
YOLOv8 model loaded successfully
Found 1152 images belonging to 6 classes.
Found 288 images belonging to 6 classes.
Training AMFF-CNN...
Epoch 1/50
[1m36/36[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m296s[0m 6s/step - accuracy: 0.1953 - loss: 1.7588 - val_accuracy: 0.3299 - val_loss: 1.6996
Epoch 2/50
[1m36/36[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m414s[0m 12s/step - accuracy: 0.4288 - loss: 1.3805 - val_accuracy: 0.5764 - val_loss: 0.9866
Epoch 3/50
[1m36/36[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m495s[0m 14s/step - accuracy: 0.6085 - loss: 0.9893 - val_accuracy: 0.7188 - val_loss: 0.7725
Epoch 4/50
[1m36/36[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m403s[0m 11s/step - accuracy: 0.7457 - loss: 0.6986 - val_accuracy: 0.8368 - val_loss: 0.5050
Epoch 5/50
[1m36/36[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m376s[0m 11s/step - accuracy: 0.8533 - loss: 0.4717 - v

In [None]:
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes given as ``[x, y, w, h]``.

    ``(x, y)`` is the minimum corner. Returns 0 when the union area is not positive,
    e.g. for two zero-size boxes.
    """
    ax, ay, aw, ah = box1
    bx, by, bw, bh = box2

    overlap_w = max(0, min(ax + aw, bx + bw) - max(ax, bx))
    overlap_h = max(0, min(ay + ah, by + bh) - max(ay, by))
    inter = overlap_w * overlap_h

    union = aw * ah + bw * bh - inter
    if union <= 0:
        return 0
    return inter / union

def calculate_detection_metrics(predictions, iou_thresholds=(0.5, 0.75)):
    """Confidence-thresholded classification accuracy, reported under AP-style keys.

    IMPORTANT: despite the key names, this is *not* detection AP/mAP, and no IoU is
    involved. For a threshold ``t`` the score is::

        100 * #(records with pred_label == true_label and confidence >= t) / #records

    That is, the percentage of all records classified correctly with confidence at
    least ``t``. The values in ``iou_thresholds`` are compared against classifier
    confidence; the parameter name is kept only for backward compatibility.

    Args:
        predictions: records with 'true_label', 'pred_label' and 'confidence' keys.
        iou_thresholds: confidence thresholds, one ``'AP{NN}'`` key each (NN = t * 100).

    Returns:
        dict with one ``'AP{NN}'`` entry per threshold, plus ``'mAP'``: the mean of the
        same score over thresholds 0.50, 0.55, ..., 0.95. All values are percentages,
        and every value is 0 when *predictions* is empty.
    """
    total = len(predictions)
    # Confidences of correctly classified records; misclassified records never count.
    correct_conf = [p['confidence'] for p in predictions if p['true_label'] == p['pred_label']]

    def score(threshold):
        if total == 0:
            return 0
        hits = sum(1 for c in correct_conf if c >= threshold)
        return hits / total * 100

    metrics = {}
    for threshold in iou_thresholds:
        # round() rather than int(): int(0.29 * 100) == 28 because of float representation.
        metrics[f'AP{round(threshold * 100)}'] = score(threshold)

    # Exact 0.05-spaced grid. np.arange with a float step places values slightly off the
    # nominal thresholds (e.g. 0.6000000000000001).
    threshold_grid = [round(0.5 + 0.05 * k, 2) for k in range(10)]
    metrics['mAP'] = np.mean([score(t) for t in threshold_grid])

    return metrics

def _report_metrics(label, predictions):
    """Compute calculate_detection_metrics for *predictions*, print them under *label*, and return them."""
    print(f"\nCalculating {label} metrics...")
    metrics = calculate_detection_metrics(predictions)
    print(f"\n=== {label} Detection Metrics ===")
    for key in ("mAP", "AP50", "AP75"):
        print(f"{key}: {metrics[key]:.1f}")
    return metrics


# NOTE: calculate_detection_metrics counts correct predictions whose confidence clears
# each threshold; despite the AP/mAP labels, these are not IoU-based detection scores.
amff_metrics = _report_metrics("AMFF-CNN", amff_predictions)
base_metrics = _report_metrics("Base CNN", base_predictions)

print("\n=== Improvement ===")
for key in ("mAP", "AP50", "AP75"):
    # The '+' format flag emits the sign itself, so a regression prints as '-3.2', not '+-3.2'.
    print(f"{key}: {amff_metrics[key] - base_metrics[key]:+.1f}")


Calculating AMFF-CNN metrics...

=== AMFF-CNN Detection Metrics ===
mAP: 94.7
AP50: 97.6
AP75: 95.1

Calculating Base CNN metrics...

=== Base CNN Detection Metrics ===
mAP: 72.4
AP50: 88.5
AP75: 75.3

=== Improvement ===
mAP: +22.4
AP50: +9.0
AP75: +19.8
