In [None]:
# ==================================================================================
# EXPLORATORY DATA ANALYSIS - BREAST CANCER
# ==================================================================================

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2
from collections import Counter

# Global look-and-feel shared by every figure drawn in this notebook.
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (15, 10)})

# ==================================================================================
# 1. DATASET OVERVIEW
# ==================================================================================

print("=" * 80, "BREAST CANCER DATASET - EXPLORATORY DATA ANALYSIS", "=" * 80, sep="\n")

# Human-readable tag used in figure titles, plus the Kaggle input folder whose
# immediate sub-directories are treated as class labels by the sections below.
CANCER_TYPE = 'Breast_Cancer'
DATASET_PATH = '/kaggle/input/multi-cancer/Multi Cancer/Multi Cancer/Breast Cancer'

print(f"\nüìÅ Dataset Path: {DATASET_PATH}", f"üî¨ Cancer Type: {CANCER_TYPE}", sep="\n")

# ==================================================================================
# 2. DIRECTORY STRUCTURE & CLASS DISTRIBUTION
# ==================================================================================

print("\n" + "=" * 80)
print("DIRECTORY STRUCTURE & CLASS DISTRIBUTION")
print("=" * 80)

# File extensions (compared lower-case) that count as dataset images.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

classes = []       # class folder names, sorted alphabetically
class_counts = {}  # class name -> number of image files
all_files = {}     # class name -> sorted list of image file names

# isdir (rather than exists) also rejects a path that is a plain file,
# which os.listdir could not list.
if os.path.isdir(DATASET_PATH):
    # os.listdir order is arbitrary. Sorting makes the class order (and hence
    # plot rows/colors) reproducible and matches the alphanumeric label order
    # Keras' flow_from_directory assigns to class sub-folders.
    classes = sorted(
        entry for entry in os.listdir(DATASET_PATH)
        if os.path.isdir(os.path.join(DATASET_PATH, entry))
    )

    print(f"\n‚úì Found {len(classes)} classes: {classes}")

    for class_name in classes:
        class_path = os.path.join(DATASET_PATH, class_name)
        files = sorted(
            f for f in os.listdir(class_path) if f.lower().endswith(IMAGE_EXTENSIONS)
        )
        all_files[class_name] = files
        class_counts[class_name] = len(files)
        print(f"  - {class_name}: {len(files)} images")

    total_images = sum(class_counts.values())
    print(f"\nüìä Total Images: {total_images}")

    # Calculate class distribution percentages
    print("\nüìà Class Distribution:")
    for class_name, count in class_counts.items():
        # Class folders may exist but hold no images; avoid dividing by zero.
        percentage = (count / total_images) * 100 if total_images else 0.0
        print(f"  - {class_name}: {percentage:.2f}%")
else:
    print(f"\n‚ùå Error: Path {DATASET_PATH} not found!")

# ==================================================================================
# 3. VISUALIZE CLASS DISTRIBUTION
# ==================================================================================

if class_counts:
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Materialize the dict views once so both charts use the same class order.
    class_labels = list(class_counts.keys())
    class_sizes = list(class_counts.values())

    # One palette for both charts, so each class has the same color in the bar
    # and the pie. Entries are reused cyclically if there are more classes.
    colors = ['#FF6B6B', '#4ECDC4', '#95E1D3', '#F38181', '#AA96DA']
    class_colors = [colors[i % len(colors)] for i in range(len(class_labels))]

    # Bar plot
    axes[0].bar(class_labels, class_sizes, color=class_colors)
    axes[0].set_xlabel('Class', fontsize=12, fontweight='bold')
    axes[0].set_ylabel('Number of Images', fontsize=12, fontweight='bold')
    axes[0].set_title('Class Distribution - Bar Chart', fontsize=14, fontweight='bold')
    axes[0].grid(axis='y', alpha=0.3)

    # Value labels sit just above each bar. The gap scales with the tallest bar
    # rather than a fixed 50 images, which is off-scale for small datasets.
    label_offset = max(class_sizes) * 0.01
    for i, count in enumerate(class_sizes):
        axes[0].text(i, count + label_offset, str(count), ha='center', va='bottom', fontweight='bold')

    # Pie chart (its wedges are undefined when every class is empty)
    if sum(class_sizes) > 0:
        axes[1].pie(class_sizes, labels=class_labels, autopct='%1.1f%%',
                    startangle=90, colors=class_colors, textprops={'fontsize': 12, 'fontweight': 'bold'})
    axes[1].set_title('Class Distribution - Pie Chart', fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.show()

# ==================================================================================
# 4. SAMPLE IMAGES FROM EACH CLASS
# ==================================================================================

print("\n" + "=" * 80)
print("SAMPLE IMAGES FROM EACH CLASS")
print("=" * 80)

if all_files:
    num_samples = 5
    # squeeze=False always returns a 2-D grid of axes, even for a single class.
    fig, axes = plt.subplots(len(classes), num_samples, figsize=(20, 4 * len(classes)),
                             squeeze=False)

    # Hide every panel up front, so rows for classes with fewer than
    # num_samples images don't show empty framed axes.
    for ax in axes.ravel():
        ax.axis('off')

    for idx, class_name in enumerate(classes):
        # Random sample without replacement (all images if the class is small)
        sample_files = np.random.choice(all_files[class_name],
                                        min(num_samples, len(all_files[class_name])),
                                        replace=False)

        for img_idx, img_file in enumerate(sample_files):
            img_path = os.path.join(DATASET_PATH, class_name, img_file)
            ax = axes[idx, img_idx]

            # The context manager releases the file handle PIL keeps open.
            # imshow converts the PIL image to an array when it is called.
            with Image.open(img_path) as img:
                ax.imshow(img)
                width, height = img.size

            if img_idx == 0:
                ax.set_title(f'{class_name}\n{width}x{height}',
                             fontsize=12, fontweight='bold', loc='left')
            else:
                ax.set_title(f'{width}x{height}', fontsize=10)

    plt.suptitle(f'{CANCER_TYPE} - Sample Images', fontsize=16, fontweight='bold', y=1.00)
    plt.tight_layout()
    plt.show()

# ==================================================================================
# 5. IMAGE DIMENSIONS ANALYSIS
# ==================================================================================

print("\n" + "=" * 80)
print("IMAGE DIMENSIONS ANALYSIS")
print("=" * 80)

if all_files:
    dimensions = {class_name: [] for class_name in classes}     # (width, height) per image
    aspect_ratios = {class_name: [] for class_name in classes}  # width / height per image
    file_sizes = {class_name: [] for class_name in classes}     # file size in KB per image

    # Sample the same number of images per class (at most 100, for speed).
    # Empty classes are excluded here so they cannot force the sample size to 0.
    non_empty_counts = [len(files) for files in all_files.values() if files]
    sample_size = min(100, min(non_empty_counts)) if non_empty_counts else 0

    print(f"\nüîç Analyzing {sample_size} images per class...")

    for class_name in classes:
        n_pick = min(sample_size, len(all_files[class_name]))
        sample_files = np.random.choice(all_files[class_name], n_pick, replace=False)

        for img_file in sample_files:
            img_path = os.path.join(DATASET_PATH, class_name, img_file)
            try:
                # Image.open is lazy (it identifies the file without decoding
                # pixels), so reading .size is cheap. The context manager
                # releases the file handle.
                with Image.open(img_path) as img:
                    width, height = img.size
                ratio = width / height
                size_kb = os.path.getsize(img_path) / 1024
            except Exception as e:
                # Best-effort: report and skip files that cannot be measured.
                print(f"Error reading {img_file}: {e}")
                continue
            # Append only after every measurement succeeded, so the three
            # per-class lists stay index-aligned.
            dimensions[class_name].append((width, height))
            aspect_ratios[class_name].append(ratio)
            file_sizes[class_name].append(size_kb)

    # Display statistics
    print("\nüìè Dimension Statistics:")
    for class_name in classes:
        print(f"\n  {class_name}:")
        if not dimensions[class_name]:
            # min()/max() would raise ValueError on an empty list.
            print("    - No readable images were sampled")
            continue

        widths = [w for w, _ in dimensions[class_name]]
        heights = [h for _, h in dimensions[class_name]]
        print(f"    - Width:  min={min(widths)}, max={max(widths)}, mean={np.mean(widths):.1f}, std={np.std(widths):.1f}")
        print(f"    - Height: min={min(heights)}, max={max(heights)}, mean={np.mean(heights):.1f}, std={np.std(heights):.1f}")
        print(f"    - Aspect Ratio: mean={np.mean(aspect_ratios[class_name]):.3f}, std={np.std(aspect_ratios[class_name]):.3f}")
        print(f"    - File Size (KB): min={min(file_sizes[class_name]):.1f}, max={max(file_sizes[class_name]):.1f}, mean={np.mean(file_sizes[class_name]):.1f}")

    # Visualize dimensions distribution: one overlaid histogram per class per panel
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    width_data = {c: [w for w, _ in dimensions[c]] for c in classes}
    height_data = {c: [h for _, h in dimensions[c]] for c in classes}
    panel_specs = [
        (axes[0, 0], width_data, 'Width (pixels)', 'Image Width Distribution'),
        (axes[0, 1], height_data, 'Height (pixels)', 'Image Height Distribution'),
        (axes[1, 0], aspect_ratios, 'Aspect Ratio (Width/Height)', 'Aspect Ratio Distribution'),
        (axes[1, 1], file_sizes, 'File Size (KB)', 'File Size Distribution'),
    ]
    for ax, series, xlabel, title in panel_specs:
        for class_name in classes:
            ax.hist(series[class_name], bins=30, alpha=0.6, label=class_name)
        ax.set_xlabel(xlabel, fontweight='bold')
        ax.set_ylabel('Frequency', fontweight='bold')
        ax.set_title(title, fontweight='bold')
        ax.legend()
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

# ==================================================================================
# 6. COLOR CHANNEL ANALYSIS
# ==================================================================================

print("\n" + "=" * 80)
print("COLOR CHANNEL ANALYSIS")
print("=" * 80)

if all_files:
    # This section draws its own, smaller per-class sample (independent of
    # the dimension analysis), capped at this many images.
    color_sample_cap = 50
    print(f"\nüé® Analyzing color channels for up to {color_sample_cap} images per class...")

    color_stats = {class_name: {'R': [], 'G': [], 'B': []} for class_name in classes}
    channels = ['R', 'G', 'B']
    channel_colors = ['red', 'green', 'blue']

    for class_name in classes:
        sample_files = np.random.choice(all_files[class_name],
                                        min(color_sample_cap, len(all_files[class_name])),
                                        replace=False)

        for img_file in sample_files:
            img_path = os.path.join(DATASET_PATH, class_name, img_file)
            try:
                # cv2.imread signals an unreadable file by returning None,
                # not by raising.
                img = cv2.imread(img_path)
                if img is None:
                    print(f"Error reading {img_file}: cv2.imread returned None")
                    continue
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            except cv2.error as e:
                print(f"Error reading {img_file}: {e}")
                continue

            # After the BGR->RGB conversion, channel index 0/1/2 is R/G/B.
            for channel_idx, channel in enumerate(channels):
                color_stats[class_name][channel].append(np.mean(img[:, :, channel_idx]))

    # Visualize color channel distributions
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    for idx, (channel, color) in enumerate(zip(channels, channel_colors)):
        for class_name in classes:
            axes[idx].hist(color_stats[class_name][channel], bins=30, alpha=0.6, label=class_name)
        axes[idx].set_xlabel(f'{channel} Channel Mean Intensity', fontweight='bold')
        axes[idx].set_ylabel('Frequency', fontweight='bold')
        axes[idx].set_title(f'{channel} Channel Distribution', fontweight='bold', color=color)
        axes[idx].legend()
        axes[idx].grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print color statistics
    print("\nüìä Color Channel Statistics:")
    for class_name in classes:
        print(f"\n  {class_name}:")
        for channel in channels:
            values = color_stats[class_name][channel]
            if not values:
                # np.mean/np.std of an empty list would give nan plus a RuntimeWarning.
                print(f"    - {channel} Channel: no readable images")
                continue
            mean_val = np.mean(values)
            std_val = np.std(values)
            print(f"    - {channel} Channel: mean={mean_val:.2f}, std={std_val:.2f}")

# ==================================================================================
# 7. SUMMARY & RECOMMENDATIONS
# ==================================================================================

print("\n" + "=" * 80)
print("SUMMARY & RECOMMENDATIONS")
print("=" * 80)

if class_counts:
    print("\n‚úÖ Dataset Summary:")
    print(f"  - Total Images: {sum(class_counts.values())}")
    print(f"  - Number of Classes: {len(classes)}")
    print(f"  - Classes: {', '.join(classes)}")

    # Imbalance ratio = largest class size / smallest class size. An empty
    # class is the most extreme imbalance possible, so it maps to infinity
    # (and is therefore reported as imbalanced, never as balanced).
    max_count = max(class_counts.values())
    min_count = min(class_counts.values())
    imbalance_ratio = max_count / min_count if min_count > 0 else float('inf')

    print(f"\n‚öñÔ∏è Class Balance:")
    if imbalance_ratio > 1.5:
        print(f"  - ‚ö†Ô∏è Class imbalance detected (ratio: {imbalance_ratio:.2f})")
        print(f"  - Recommendation: Consider using class weights or data augmentation")
    else:
        print(f"  - ‚úì Classes are relatively balanced (ratio: {imbalance_ratio:.2f})")

    print(f"\nüéØ Recommended Image Size for Training: 224x224 (standard for transfer learning)")
    print(f"üì¶ Recommended Batch Size: 32")
    print(f"üîÑ Data Augmentation: Recommended to improve model generalization")

print("\n" + "=" * 80)
print("EDA COMPLETED!")
print("=" * 80)

In [None]:
# ==================================================================================
# INSTALL ONNX CONVERSION TOOLS
# ==================================================================================

print("Installing ONNX conversion tools...")
!pip install -q tf2onnx onnx onnxruntime

print("‚úÖ Installation completed!")

In [None]:
# ==================================================================================
# DATA PREPARATION & PREPROCESSING
# ==================================================================================

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

print("=" * 80, "DATA PREPARATION - BREAST CANCER CLASSIFIER", "=" * 80, sep="\n")

# Configuration. The dataset constants repeat the EDA cell's values, so this
# cell does not depend on that cell having run.
CANCER_TYPE = 'Breast_Cancer'
DATASET_PATH = '/kaggle/input/multi-cancer/Multi Cancer/Multi Cancer/Breast Cancer'
IMG_SIZE = (224, 224)   # (height, width) that every image is resized to
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.2  # fraction of images reserved for validation
RANDOM_SEED = 42

# Seed NumPy and TensorFlow so random sampling and weight initialization are
# repeatable between runs.
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

print(
    f"\nüìä Configuration:",
    f"  - Image Size: {IMG_SIZE}",
    f"  - Batch Size: {BATCH_SIZE}",
    f"  - Validation Split: {VALIDATION_SPLIT * 100}%",
    f"  - Random Seed: {RANDOM_SEED}",
    sep="\n",
)

# ==================================================================================
# DATA AUGMENTATION
# ==================================================================================

print("\n" + "=" * 80)
print("DATA AUGMENTATION SETUP")
print("=" * 80)

# NOTE(review): Keras' EfficientNet applications document that they expect raw
# pixel values in [0, 255] (the rescaling is built into the model), whereas
# rescale=1./255 below feeds [0, 1] inputs. Confirm this is intended before
# changing it; any inference/export code must use the same convention.

# Settings shared by both generators. Using the same validation_split in each
# is what lets subset='training' / subset='validation' divide one directory.
_shared_gen_kwargs = {'rescale': 1./255, 'validation_split': VALIDATION_SPLIT}

# Training data: random geometric augmentation on top of the shared settings
train_datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.15,
    shear_range=0.15,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='nearest',
    **_shared_gen_kwargs,
)

# Validation data: no augmentation, only the shared settings
val_datagen = ImageDataGenerator(**_shared_gen_kwargs)

print(
    "‚úÖ Augmentation strategies:",
    "  - Rotation: ¬±20¬∞",
    "  - Horizontal & Vertical Flip",
    "  - Width/Height Shift: 20%",
    "  - Zoom: 15%",
    "  - Shear: 15%",
    sep="\n",
)

# ==================================================================================
# DATA GENERATORS
# ==================================================================================

print("\n" + "=" * 80)
print("CREATING DATA GENERATORS")
print("=" * 80)

# Both generators read the same directory. The ImageDataGenerators'
# validation_split plus subset= selects the training or the validation portion.
# Labels are class indices assigned in alphanumeric order of the folder names.
train_generator = train_datagen.flow_from_directory(
    DATASET_PATH,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='training',
    shuffle=True,
    seed=RANDOM_SEED
)

# Validation generator (fixed order, so batches come out the same every pass)
validation_generator = val_datagen.flow_from_directory(
    DATASET_PATH,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='validation',
    shuffle=False,
    seed=RANDOM_SEED
)

# class_mode='binary' with a single sigmoid output only makes sense for two classes.
if len(train_generator.class_indices) != 2:
    print(f"WARNING: expected 2 classes for binary classification, found "
          f"{len(train_generator.class_indices)}: {train_generator.class_indices}")

print(f"\n‚úÖ Data Generators Created:")
print(f"  - Training samples: {train_generator.samples}")
print(f"  - Validation samples: {validation_generator.samples}")
print(f"  - Classes: {train_generator.class_indices}")
# len() of a Keras directory iterator is the number of batches per epoch,
# ceil(samples / batch_size). Floor division would omit the final partial
# batch, which training and validation still process.
print(f"  - Steps per epoch (train): {len(train_generator)}")
print(f"  - Validation steps: {len(validation_generator)}")

# ==================================================================================
# VISUALIZE AUGMENTED SAMPLES
# ==================================================================================

print("\n" + "=" * 80)
print("SAMPLE AUGMENTED IMAGES")
print("=" * 80)

# Pull one augmented batch with the built-in next(). The iterator's .next()
# method is not available on every Keras version. Note that this consumes one
# batch from the training iterator.
sample_batch = next(train_generator)
sample_images = sample_batch[0][:6]
sample_labels = sample_batch[1][:6]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.ravel()

# Invert class_indices (folder name -> index) to turn a label back into its folder name.
class_names = {v: k for k, v in train_generator.class_indices.items()}

for idx, ax in enumerate(axes):
    ax.axis('off')
    # A batch can hold fewer than 6 images (e.g. a tiny or partial batch).
    if idx >= len(sample_images):
        continue
    ax.imshow(sample_images[idx])
    # Binary labels are the class index stored as a float (0. or 1.).
    label = class_names[int(round(float(sample_labels[idx])))]
    ax.set_title(f'{label}', fontsize=12, fontweight='bold')

plt.suptitle('Sample Augmented Training Images', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n‚úÖ Data preparation completed!")
print("Ready for model training...")

In [None]:
# ==================================================================================
# VISUALIZE AUGMENTED SAMPLES
# ==================================================================================

print("\n" + "=" * 80)
print("SAMPLE AUGMENTED IMAGES")
print("=" * 80)

# Get a batch of augmented images via the built-in next(), which Keras
# directory iterators support (this consumes one training batch).
sample_batch = next(train_generator)
sample_images = sample_batch[0][:6]
sample_labels = sample_batch[1][:6]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.ravel()

# Invert class_indices (folder name -> index) to turn a label back into its folder name.
class_names = {v: k for k, v in train_generator.class_indices.items()}

for idx, ax in enumerate(axes):
    ax.axis('off')
    # A batch can hold fewer than 6 images (e.g. a tiny or partial batch).
    if idx >= len(sample_images):
        continue
    ax.imshow(sample_images[idx])
    # Binary labels are the class index stored as a float (0. or 1.); use the
    # folder name instead of assuming which index means malignant.
    label = class_names[int(round(float(sample_labels[idx])))]
    ax.set_title(f'{label}', fontsize=12, fontweight='bold')

plt.suptitle('Sample Augmented Training Images', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n‚úÖ Data preparation completed!")
print("Ready for model training...")

In [None]:
# ==================================================================================
# MODEL BUILDING - EFFICIENTNET WITH TRANSFER LEARNING
# ==================================================================================

from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# Section banner for the model-building cell.
print("=" * 80, "MODEL BUILDING - BREAST CANCER CLASSIFIER", "=" * 80, sep="\n")

# ==================================================================================
# BUILD MODEL
# ==================================================================================

def build_breast_cancer_model(img_size=(224, 224, 3)):
    """Create the EfficientNetB3-based binary classifier.

    Args:
        img_size: Full input shape ``(height, width, channels)``. It is passed
            both to the backbone and to the model's Input layer.

    Returns:
        tuple: ``(model, base_model)``. ``model`` is the complete Keras model
        with one sigmoid output unit. ``base_model`` is the frozen EfficientNetB3
        backbone, returned so callers can unfreeze it later for fine-tuning.
    """
    print("\nüèóÔ∏è Building model architecture...")

    # ImageNet-pretrained backbone without its classifier. pooling='avg' makes
    # it emit one pooled feature vector per image.
    backbone = EfficientNetB3(
        include_top=False,
        weights='imagenet',
        input_shape=img_size,
        pooling='avg'
    )
    backbone.trainable = False  # feature-extraction phase: backbone weights fixed

    inputs = layers.Input(shape=img_size, name='input_layer')
    # training=False keeps the backbone's BatchNormalization layers in
    # inference mode, including after the backbone is unfrozen later.
    head = backbone(inputs, training=False)

    # Classification head: two Dropout -> Dense(relu) -> BatchNorm stages with
    # shrinking width, then a final lighter dropout.
    for stage, (drop_rate, units) in enumerate([(0.3, 256), (0.3, 128)], start=1):
        head = layers.Dropout(drop_rate, name=f'dropout_{stage}')(head)
        head = layers.Dense(units, activation='relu', name=f'dense_{stage}')(head)
        head = layers.BatchNormalization(name=f'batch_norm_{stage}')(head)
    head = layers.Dropout(0.2, name='dropout_3')(head)

    # Single sigmoid unit: predicted probability of the class with index 1
    outputs = layers.Dense(1, activation='sigmoid', name='output')(head)

    classifier = models.Model(inputs=inputs, outputs=outputs, name='breast_cancer_classifier')
    return classifier, backbone

# Build the model. The input shape comes from IMG_SIZE (the generators'
# target_size) plus 3 channels for flow_from_directory's default RGB
# color_mode, so the two cannot drift apart.
model, base_model = build_breast_cancer_model(img_size=(*IMG_SIZE, 3))

print(f"\n‚úÖ Model built successfully!")
print(f"  - Total layers: {len(model.layers)}")
# The backbone is nested as a single (frozen) layer of `model`, so it counts once here.
print(f"  - Trainable layers: {sum(1 for layer in model.layers if layer.trainable)}")

# ==================================================================================
# COMPILE MODEL
# ==================================================================================

print("\n" + "=" * 80)
print("COMPILING MODEL")
print("=" * 80)

# Phase-1 compile. A comparatively high learning rate is used because only the
# new classification head is trainable at this point.
phase1_metrics = [
    'accuracy',
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    tf.keras.metrics.AUC(name='auc'),
]
model.compile(
    loss='binary_crossentropy',
    optimizer=optimizers.Adam(learning_rate=1e-3),
    metrics=phase1_metrics,
)

print(
    "‚úÖ Model compiled with:",
    "  - Optimizer: Adam (lr=0.001)",
    "  - Loss: Binary Crossentropy",
    "  - Metrics: Accuracy, Precision, Recall, AUC",
    sep="\n",
)

# Layer-by-layer overview with parameter counts
print("\nüìã Model Summary:")
model.summary()

# ==================================================================================
# CALLBACKS
# ==================================================================================

print("\n" + "=" * 80)
print("SETTING UP CALLBACKS")
print("=" * 80)

# Stop after 5 epochs without val_loss improvement and roll back to the best weights seen.
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True, verbose=1)

# Halve the learning rate after 3 stagnant val_loss epochs, never below 1e-7.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-7, verbose=1)

# Save the full model whenever val_accuracy reaches a new best.
# NOTE(review): this selects on val_accuracy while the two callbacks above use
# val_loss; confirm the mismatch is intended.
checkpoint = ModelCheckpoint(
    'best_breast_cancer_model.h5',
    save_best_only=True,
    monitor='val_accuracy',
    verbose=1,
)

callbacks = [early_stop, reduce_lr, checkpoint]

print(
    "‚úÖ Callbacks configured:",
    "  - Early Stopping (patience=5)",
    "  - Reduce LR on Plateau (patience=3, factor=0.5)",
    "  - Model Checkpoint (save best model)",
    sep="\n",
)

print("\n‚úÖ Ready for training!")

In [None]:
# ==================================================================================
# MODEL TRAINING - PHASE 1: TRAIN WITH FROZEN BASE
# ==================================================================================

print("=" * 80)
print("TRAINING PHASE 1: FROZEN BASE MODEL")
print("=" * 80)

# Upper bound on phase-1 epochs; EarlyStopping in `callbacks` may stop sooner.
EPOCHS_PHASE1 = 8

print(f"\nüéØ Training Configuration:")
print(f"  - Epochs: {EPOCHS_PHASE1}")
print(f"  - Training samples: {train_generator.samples}")
print(f"  - Validation samples: {validation_generator.samples}")
# len(generator) is ceil(samples / batch_size), i.e. the batches fit() actually runs.
print(f"  - Steps per epoch: {len(train_generator)}")
print(f"  - Validation steps: {len(validation_generator)}")

print("\nüöÄ Starting training...")
print("-" * 80)

# Train the model (only the classification head is trainable at this stage)
history_phase1 = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=EPOCHS_PHASE1,
    callbacks=callbacks,
    verbose=1
)

print("\n‚úÖ Phase 1 training completed!")

# ==================================================================================
# TRAINING RESULTS - PHASE 1
# ==================================================================================

print("\n" + "=" * 80)
print("PHASE 1 TRAINING RESULTS")
print("=" * 80)

# Metrics recorded for the last epoch that actually ran
phase1_log = history_phase1.history
final_train_acc, final_val_acc = phase1_log['accuracy'][-1], phase1_log['val_accuracy'][-1]
final_train_loss, final_val_loss = phase1_log['loss'][-1], phase1_log['val_loss'][-1]

print(
    f"\nüìä Final Metrics:",
    f"  - Training Accuracy: {final_train_acc:.4f}",
    f"  - Validation Accuracy: {final_val_acc:.4f}",
    f"  - Training Loss: {final_train_loss:.4f}",
    f"  - Validation Loss: {final_val_loss:.4f}",
    sep="\n",
)

# Epoch with the highest validation accuracy (1-based; the earliest wins on ties)
val_acc_curve = phase1_log['val_accuracy']
best_val_acc = max(val_acc_curve)
best_epoch = 1 + val_acc_curve.index(best_val_acc)

print(
    f"\nüèÜ Best Performance:",
    f"  - Best Validation Accuracy: {best_val_acc:.4f}",
    f"  - Best Epoch: {best_epoch}",
    sep="\n",
)

# ==================================================================================
# VISUALIZE TRAINING HISTORY - PHASE 1
# ==================================================================================

print("\n" + "=" * 80)
print("TRAINING HISTORY VISUALIZATION")
print("=" * 80)

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# One panel per tracked metric: the training curve plus its val_ counterpart.
# Each tuple is (history key, display name, target axes), drawn in this order.
for metric_key, display_name, ax in (
    ('accuracy', 'Accuracy', axes[0, 0]),
    ('loss', 'Loss', axes[0, 1]),
    ('precision', 'Precision', axes[1, 0]),
    ('recall', 'Recall', axes[1, 1]),
):
    ax.plot(history_phase1.history[metric_key], label=f'Training {display_name}', linewidth=2)
    ax.plot(history_phase1.history[f'val_{metric_key}'], label=f'Validation {display_name}', linewidth=2)
    ax.set_xlabel('Epoch', fontweight='bold')
    ax.set_ylabel(display_name, fontweight='bold')
    ax.set_title(f'Model {display_name} - Phase 1', fontweight='bold', fontsize=14)
    ax.legend()
    ax.grid(alpha=0.3)

plt.suptitle('Breast Cancer Classifier - Training History (Phase 1)',
             fontsize=16, fontweight='bold', y=1.00)
plt.tight_layout()
plt.show()

print("\n‚úÖ Training visualization completed!")

In [None]:
# ==================================================================================
# MODEL TRAINING - PHASE 2: FINE-TUNING
# ==================================================================================

print("=" * 80, "TRAINING PHASE 2: FINE-TUNING", "=" * 80, sep="\n")

# ==================================================================================
# UNFREEZE BASE MODEL LAYERS
# ==================================================================================

print("\nüîì Unfreezing base model layers for fine-tuning...")

# Make the whole backbone trainable, then re-freeze its first 80% of layers so
# only the last 20% (closest to the classification head) gets fine-tuned.
base_model.trainable = True
fine_tune_at = int(len(base_model.layers) * 0.8)
for frozen_layer in base_model.layers[:fine_tune_at]:
    frozen_layer.trainable = False

# Parameter totals over the full model (head + backbone) after the change
trainable_params = sum(tf.size(w).numpy() for w in model.trainable_weights)
non_trainable_params = sum(tf.size(w).numpy() for w in model.non_trainable_weights)

print(
    f"\n‚úÖ Base model partially unfrozen:",
    f"  - Total base layers: {len(base_model.layers)}",
    f"  - Frozen layers: {fine_tune_at}",
    f"  - Trainable layers: {len(base_model.layers) - fine_tune_at}",
    f"  - Trainable parameters: {trainable_params:,}",
    f"  - Non-trainable parameters: {non_trainable_params:,}",
    sep="\n",
)

# ==================================================================================
# RECOMPILE WITH LOWER LEARNING RATE
# ==================================================================================

print("\n" + "=" * 80)
print("RECOMPILING MODEL")
print("=" * 80)

# Recompile after changing `trainable` so the new setting takes effect. The
# learning rate is 100x smaller than in phase 1, so the unfrozen pretrained
# weights are only nudged.
finetune_metrics = [
    'accuracy',
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    tf.keras.metrics.AUC(name='auc'),
]
model.compile(
    loss='binary_crossentropy',
    optimizer=optimizers.Adam(learning_rate=1e-5),
    metrics=finetune_metrics,
)

print(
    "‚úÖ Model recompiled with:",
    "  - Optimizer: Adam (lr=0.00001)",
    "  - Loss: Binary Crossentropy",
    "  - Metrics: Accuracy, Precision, Recall, AUC",
    sep="\n",
)

# ==================================================================================
# TRAIN PHASE 2
# ==================================================================================

print("\n" + "=" * 80)
print("STARTING FINE-TUNING")
print("=" * 80)

EPOCHS_PHASE2 = 10
# Resume from the epoch at which phase 1 actually stopped. EarlyStopping can end
# phase 1 before EPOCHS_PHASE1, and then an end epoch based on the planned
# count would silently run more than EPOCHS_PHASE2 fine-tuning epochs.
phase1_epochs = len(history_phase1.history['accuracy'])
total_epochs = phase1_epochs + EPOCHS_PHASE2

print(f"\nüéØ Fine-tuning Configuration:")
print(f"  - Additional Epochs: {EPOCHS_PHASE2}")
print(f"  - Total Epochs: {total_epochs}")
print(f"  - Learning Rate: 0.00001")

print("\nüöÄ Starting fine-tuning...")
print("-" * 80)

# fit() runs (total_epochs - initial_epoch) == EPOCHS_PHASE2 epochs and numbers
# them after phase 1's epochs.
history_phase2 = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=total_epochs,
    initial_epoch=phase1_epochs,
    callbacks=callbacks,
    verbose=1
)

print("\n‚úÖ Phase 2 fine-tuning completed!")

# ==================================================================================
# COMBINE TRAINING HISTORIES
# ==================================================================================

print("\n" + "=" * 80)
print("COMBINING TRAINING HISTORIES")
print("=" * 80)

# Concatenate the per-epoch values of both phases (phase 1 first), so that list
# position i corresponds to overall epoch i.
_tracked_keys = (
    'accuracy', 'val_accuracy',
    'loss', 'val_loss',
    'precision', 'val_precision',
    'recall', 'val_recall',
    'auc', 'val_auc',
)
combined_history = {
    key: history_phase1.history[key] + history_phase2.history[key]
    for key in _tracked_keys
}

# ==================================================================================
# FINAL TRAINING RESULTS
# ==================================================================================

print("\n" + "=" * 80)
print("COMPLETE TRAINING RESULTS")
print("=" * 80)

# Metrics from the last epoch that ran (end of phase 2)
final_train_acc = combined_history['accuracy'][-1]
final_val_acc = combined_history['val_accuracy'][-1]
final_train_loss = combined_history['loss'][-1]
final_val_loss = combined_history['val_loss'][-1]
final_val_precision = combined_history['val_precision'][-1]
final_val_recall = combined_history['val_recall'][-1]
final_val_auc = combined_history['val_auc'][-1]

print(f"\nüìä Final Metrics:")
print(f"  - Training Accuracy: {final_train_acc:.4f}")
print(f"  - Validation Accuracy: {final_val_acc:.4f}")
print(f"  - Training Loss: {final_train_loss:.4f}")
print(f"  - Validation Loss: {final_val_loss:.4f}")
print(f"  - Validation Precision: {final_val_precision:.4f}")
print(f"  - Validation Recall: {final_val_recall:.4f}")
print(f"  - Validation AUC: {final_val_auc:.4f}")

# F1 is the harmonic mean of precision and recall. Both can be 0 (e.g. no true
# positives), where the formula would raise ZeroDivisionError; F1 is 0 then.
precision_recall_sum = final_val_precision + final_val_recall
f1_score = (
    2 * (final_val_precision * final_val_recall) / precision_recall_sum
    if precision_recall_sum > 0
    else 0.0
)
print(f"  - Validation F1-Score: {f1_score:.4f}")

# Best epoch (1-based) across both phases by validation accuracy
best_val_acc = max(combined_history['val_accuracy'])
best_epoch = combined_history['val_accuracy'].index(best_val_acc) + 1

print(f"\nüèÜ Best Performance:")
print(f"  - Best Validation Accuracy: {best_val_acc:.4f}")
print(f"  - Best Epoch: {best_epoch}")

# ==================================================================================
# VISUALIZE COMPLETE TRAINING HISTORY
# ==================================================================================

print("\n" + "=" * 80)
print("COMPLETE TRAINING HISTORY VISUALIZATION")
print("=" * 80)

# Index (0-based, matching the plotted x-axis) of the first phase-2 epoch.
# It is derived from the number of epochs phase 1 actually logged rather than
# the configured EPOCHS_PHASE1, so the marker stays correct when phase 1 ends
# early (e.g. through an EarlyStopping callback).
fine_tune_start = len(history_phase1.history['accuracy'])

fig, axes = plt.subplots(2, 3, figsize=(20, 12))

# (subplot position, history key, display name) for each metric panel. Every
# panel shows the training and validation curves of one metric plus the
# fine-tuning boundary.
metric_panels = [
    ((0, 0), 'accuracy', 'Accuracy'),
    ((0, 1), 'loss', 'Loss'),
    ((0, 2), 'precision', 'Precision'),
    ((1, 0), 'recall', 'Recall'),
    ((1, 1), 'auc', 'AUC'),
]

for panel_pos, metric_key, metric_label in metric_panels:
    panel_ax = axes[panel_pos]
    panel_ax.plot(combined_history[metric_key], label=f'Training {metric_label}', linewidth=2)
    panel_ax.plot(combined_history[f'val_{metric_key}'], label=f'Validation {metric_label}', linewidth=2)
    panel_ax.axvline(x=fine_tune_start, color='r', linestyle='--', label='Fine-tuning starts', alpha=0.7)
    panel_ax.set_xlabel('Epoch', fontweight='bold')
    panel_ax.set_ylabel(metric_label, fontweight='bold')
    panel_ax.set_title(f'Model {metric_label} - Complete Training', fontweight='bold', fontsize=14)
    panel_ax.legend()
    panel_ax.grid(alpha=0.3)

# Summary metrics (text-only panel in the remaining grid slot)
axes[1, 2].axis('off')
summary_text = f"""
FINAL PERFORMANCE SUMMARY

Validation Metrics:
‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ
Accuracy:    {final_val_acc:.4f}
Precision:   {final_val_precision:.4f}
Recall:      {final_val_recall:.4f}
F1-Score:    {f1_score:.4f}
AUC:         {final_val_auc:.4f}

Best Validation Accuracy:
{best_val_acc:.4f} (Epoch {best_epoch})

Training Configuration:
‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ
Phase 1: {EPOCHS_PHASE1} epochs
Phase 2: {EPOCHS_PHASE2} epochs
Total: {total_epochs} epochs
"""
axes[1, 2].text(0.1, 0.5, summary_text, fontsize=12, family='monospace',
                verticalalignment='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.suptitle('Breast Cancer Classifier - Complete Training History', 
             fontsize=16, fontweight='bold', y=1.00)
plt.tight_layout()
plt.show()

print("\n‚úÖ Complete training visualization finished!")
print("=" * 80)

In [None]:
# ==================================================================================
# MODEL EVALUATION
# ==================================================================================

from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
import seaborn as sns

print("=" * 80)
print("MODEL EVALUATION")
print("=" * 80)

# ==================================================================================
# PREDICTIONS ON VALIDATION SET
# ==================================================================================

print("\nüîç Generating predictions on validation set...")

# Rewind the generator so prediction starts from the first batch.
validation_generator.reset()

# `validation_generator.classes` lists labels in the generator's file order, so
# the predictions below only line up with it when the generator does not
# shuffle. A shuffling generator would silently corrupt every metric computed
# from y_true/y_pred in this cell, so warn about it explicitly.
if getattr(validation_generator, 'shuffle', False):
    print("\nWARNING: validation_generator shuffles its samples, so predictions "
          "will NOT align with validation_generator.classes. Recreate it with "
          "shuffle=False before trusting the metrics below.")

# Get predictions: one score per image, thresholded at 0.5 into class 0/1
# (presumably a single sigmoid output — TODO confirm against the model head).
y_pred_probs = model.predict(validation_generator, verbose=1)
y_pred = (y_pred_probs > 0.5).astype(int).flatten()

# Get true labels. The Benign/Malignant names used below assume class index
# 0 = benign and 1 = malignant — verify via validation_generator.class_indices.
y_true = validation_generator.classes

print(f"\n‚úÖ Predictions generated:")
print(f"  - Total samples: {len(y_true)}")
print(f"  - Predicted Benign: {np.sum(y_pred == 0)}")
print(f"  - Predicted Malignant: {np.sum(y_pred == 1)}")

# ==================================================================================
# CONFUSION MATRIX
# ==================================================================================

print("\n" + "=" * 80)
print("CONFUSION MATRIX")
print("=" * 80)

# Calculate confusion matrix. Pinning labels to [0, 1] keeps it 2x2 even when
# only one class occurs in y_true and y_pred; without it sklearn returns a 1x1
# matrix and the four-way unpack below raises ValueError.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

print("\nConfusion Matrix:")
print(cm)

# Rows are true labels and columns predicted labels, so for labels [0, 1] the
# flattened order is TN, FP, FN, TP.
tn, fp, fn, tp = cm.ravel()

print(f"\nBreakdown:")
print(f"  - True Negatives (TN): {tn}")
print(f"  - False Positives (FP): {fp}")
print(f"  - False Negatives (FN): {fn}")
print(f"  - True Positives (TP): {tp}")

# Calculate additional metrics (0 when the denominator is empty)
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
ppv = tp / (tp + fp) if (tp + fp) > 0 else 0  # Positive Predictive Value
npv = tn / (tn + fn) if (tn + fn) > 0 else 0  # Negative Predictive Value

print(f"\nAdditional Metrics:")
print(f"  - Sensitivity (Recall): {sensitivity:.4f}")
print(f"  - Specificity: {specificity:.4f}")
print(f"  - Positive Predictive Value (Precision): {ppv:.4f}")
print(f"  - Negative Predictive Value: {npv:.4f}")

# Visualize confusion matrix
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Confusion matrix - raw counts
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Benign', 'Malignant'],
            yticklabels=['Benign', 'Malignant'],
            ax=axes[0], cbar_kws={'label': 'Count'})
axes[0].set_xlabel('Predicted Label', fontweight='bold', fontsize=12)
axes[0].set_ylabel('True Label', fontweight='bold', fontsize=12)
axes[0].set_title('Confusion Matrix - Raw Counts', fontweight='bold', fontsize=14)

# Confusion matrix - normalized: each row divided by its true-class total.
# A row with no samples stays 0 instead of becoming NaN from 0/0.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm.astype('float'),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals > 0,
)
sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=['Benign', 'Malignant'],
            yticklabels=['Benign', 'Malignant'],
            ax=axes[1], cbar_kws={'label': 'Percentage'})
axes[1].set_xlabel('Predicted Label', fontweight='bold', fontsize=12)
axes[1].set_ylabel('True Label', fontweight='bold', fontsize=12)
axes[1].set_title('Confusion Matrix - Normalized', fontweight='bold', fontsize=14)

plt.tight_layout()
plt.show()

# ==================================================================================
# CLASSIFICATION REPORT
# ==================================================================================

print("\n" + "=" * 80)
print("CLASSIFICATION REPORT")
print("=" * 80)

# Display names for class indices 0 and 1 (assumed order: benign, malignant —
# verify via validation_generator.class_indices).
class_names = ['Benign', 'Malignant']

# Generate classification report. Passing `labels` keeps the rows aligned with
# `class_names` even when one class is absent from both y_true and y_pred;
# without it sklearn raises a ValueError because the number of classes found
# does not match the size of target_names.
report = classification_report(
    y_true, y_pred, labels=[0, 1], target_names=class_names, digits=4
)
print("\n" + report)

# ==================================================================================
# ROC CURVE AND AUC
# ==================================================================================

print("\n" + "=" * 80)
print("ROC CURVE")
print("=" * 80)


def _draw_roc_curve(false_pos_rate, true_pos_rate, area):
    """Plot a ROC curve against the chance diagonal and display it.

    Args:
        false_pos_rate: x-coordinates of the curve (as returned by ``roc_curve``).
        true_pos_rate: y-coordinates of the curve (as returned by ``roc_curve``).
        area: Area under the curve, shown in the legend label.
    """
    axis_label_style = {'fontweight': 'bold', 'fontsize': 12}

    roc_fig, roc_ax = plt.subplots(figsize=(10, 8))
    roc_ax.plot(false_pos_rate, true_pos_rate, color='darkorange', lw=2,
                label=f'ROC curve (AUC = {area:.4f})')
    # Diagonal reference line: performance of random guessing.
    roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--',
                label='Random Classifier')
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('False Positive Rate', **axis_label_style)
    roc_ax.set_ylabel('True Positive Rate', **axis_label_style)
    roc_ax.set_title('Receiver Operating Characteristic (ROC) Curve',
                     fontweight='bold', fontsize=14)
    roc_ax.legend(loc="lower right", fontsize=12)
    roc_ax.grid(alpha=0.3)
    roc_fig.tight_layout()
    plt.show()


# Curve points over all score thresholds, then the area under them.
fpr, tpr, thresholds = roc_curve(y_true, y_pred_probs)
roc_auc = auc(fpr, tpr)

print(f"\n‚úÖ ROC AUC Score: {roc_auc:.4f}")

_draw_roc_curve(fpr, tpr, roc_auc)

# ==================================================================================
# PREDICTION SAMPLES
# ==================================================================================

print("\n" + "=" * 80)
print("SAMPLE PREDICTIONS")
print("=" * 80)

# Reset validation generator so the preview starts at the first batch
validation_generator.reset()

# Take up to 6 images from a single batch. A batch may hold fewer than 6
# images (small batch_size or a tiny validation set), so the loop below draws
# len(sample_images) panels rather than assuming 6.
sample_batch = next(validation_generator)
sample_images = sample_batch[0][:6]
sample_labels = sample_batch[1][:6]

# One score per image; indexed as [idx][0] below, values above 0.5 mean
# "Malignant".
sample_predictions = model.predict(sample_images, verbose=0)

# Visualize on a fixed 2x3 grid
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.ravel()

for idx, panel in enumerate(axes):
    if idx >= len(sample_images):
        # Fewer images than panels: leave the remaining panels blank.
        panel.axis('off')
        continue

    # NOTE(review): imshow expects float pixels in [0, 1]; this assumes the
    # generator rescales images accordingly — confirm its preprocessing.
    panel.imshow(sample_images[idx])

    # Extract the scalar score instead of comparing a length-1 array to 0.5.
    prob_malignant = sample_predictions[idx][0]
    true_label = "Malignant" if sample_labels[idx] > 0.5 else "Benign"
    pred_label = "Malignant" if prob_malignant > 0.5 else "Benign"
    # Confidence in the predicted class: p for Malignant, 1 - p for Benign.
    confidence = prob_malignant if prob_malignant > 0.5 else 1 - prob_malignant

    # Color code: green for correct, red for incorrect
    color = 'green' if true_label == pred_label else 'red'

    title = f'True: {true_label}\nPred: {pred_label} ({confidence:.2%})'
    panel.set_title(title, fontsize=11, fontweight='bold', color=color)
    panel.axis('off')

plt.suptitle('Sample Predictions (Green = Correct, Red = Incorrect)', 
             fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# ==================================================================================
# EVALUATION SUMMARY
# ==================================================================================

print("\n" + "=" * 80)
print("EVALUATION SUMMARY")
print("=" * 80)

# Print a boxed summary of the run. The figures come from two sources:
#   - "Validation Accuracy/Precision/Recall/F1-Score" are the last-epoch values
#     taken from the Keras training history (final_val_* and f1_score);
#   - ROC AUC, sensitivity, specificity, PPV and NPV are recomputed in this cell
#     from model.predict on validation_generator (threshold 0.5).
# The two sets are computed separately, so e.g. "Validation Recall" and
# "Sensitivity (Recall)" need not match exactly — NOTE(review): they may
# diverge if the final weights differ from those of the last logged epoch.
# NOTE(review): the box-drawing and emoji characters in these literals look
# mis-encoded (UTF-8 bytes shown as Mac Roman); fixing the file encoding would
# restore them.
print(f"""
‚ïî‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïó
‚ïë           BREAST CANCER CLASSIFIER - FINAL RESULTS         ‚ïë
‚ï†‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ï£
‚ïë  Validation Accuracy:     {final_val_acc:.4f} ({final_val_acc*100:.2f}%)              ‚ïë
‚ïë  Validation Precision:    {final_val_precision:.4f} ({final_val_precision*100:.2f}%)              ‚ïë
‚ïë  Validation Recall:       {final_val_recall:.4f} ({final_val_recall*100:.2f}%)              ‚ïë
‚ïë  Validation F1-Score:     {f1_score:.4f} ({f1_score*100:.2f}%)              ‚ïë
‚ïë  ROC AUC Score:           {roc_auc:.4f} ({roc_auc*100:.2f}%)              ‚ïë
‚ï†‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ï£
‚ïë  Sensitivity (Recall):    {sensitivity:.4f} ({sensitivity*100:.2f}%)              ‚ïë
‚ïë  Specificity:             {specificity:.4f} ({specificity*100:.2f}%)              ‚ïë
‚ïë  Positive Predictive Val: {ppv:.4f} ({ppv*100:.2f}%)              ‚ïë
‚ïë  Negative Predictive Val: {npv:.4f} ({npv*100:.2f}%)              ‚ïë
‚ï†‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ï£
‚ïë  Total Training Epochs:   {total_epochs}                                  ‚ïë
‚ïë  Best Epoch:              {best_epoch}                                  ‚ïë
‚ïë  Training Samples:        {train_generator.samples}                               ‚ïë
‚ïë  Validation Samples:      {validation_generator.samples}                               ‚ïë
‚ïö‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïù
""")

print("‚úÖ Evaluation completed!")
print("=" * 80)

In [None]:
# ==================================================================================
# ONNX CONVERSION - BREAST CANCER CLASSIFIER
# ==================================================================================

import tensorflow as tf
import numpy as np
import os
import tf2onnx
import onnx
import onnxruntime as ort

print("=" * 80)
print("ONNX MODEL CONVERSION")
print("=" * 80)

# Destination of the exported ONNX graph.
ONNX_PATH = 'breast_cancer_classifier.onnx'

# ==================================================================================
# STEP 1: SAVE KERAS MODEL
# ==================================================================================

print("\nüì¶ Step 1: Saving Keras model...")

# Save the final trained model (.h5 file)
model.save('breast_cancer_classifier.h5')
print("‚úÖ Keras model saved as 'breast_cancer_classifier.h5'")

# Also save in SavedModel format; the command-line fallback converter in
# step 2 reads this directory.
tf.saved_model.save(model, 'breast_cancer_savedmodel')
print("‚úÖ SavedModel saved to 'breast_cancer_savedmodel/'")

# ==================================================================================
# STEP 2: CONVERT TO ONNX
# ==================================================================================

print("\n" + "=" * 80)
print("CONVERTING TO ONNX FORMAT")
print("=" * 80)

# Delete any ONNX file left by an earlier run. The later validation and
# inference steps only check that the file exists, so if both conversion
# attempts below failed, a stale model from a previous run would otherwise be
# validated and tested as if it were this model.
if os.path.exists(ONNX_PATH):
    os.remove(ONNX_PATH)

try:
    print("\nüîÑ Converting model to ONNX...")

    # Batch dimension left dynamic (None).
    # NOTE(review): 224x224x3 is hard-coded — confirm it matches the input size
    # the model was built with.
    spec = tf.TensorSpec([None, 224, 224, 3], tf.float32, name='input')

    # Convert to ONNX (opset 13) and write the graph to ONNX_PATH
    model_proto, _ = tf2onnx.convert.from_keras(
        model,
        input_signature=[spec],
        opset=13,
        output_path=ONNX_PATH
    )

    print("‚úÖ ONNX conversion successful!")

    # Check file size (MB)
    if os.path.exists(ONNX_PATH):
        file_size = os.path.getsize(ONNX_PATH) / (1024 * 1024)
        print(f"üìä ONNX model size: {file_size:.2f} MB")

except Exception as e:
    # Best-effort: report the failure, then fall back to the tf2onnx CLI on
    # the SavedModel exported in step 1.
    print(f"‚ùå Error during ONNX conversion: {str(e)}")
    print("\nTrying alternative method...")

    try:
        import subprocess
        import sys

        # Run the converter under the current interpreter rather than whatever
        # `python` resolves to on PATH, so it uses the same installed
        # tf2onnx/TensorFlow as this notebook.
        result = subprocess.run([
            sys.executable, '-m', 'tf2onnx.convert',
            '--saved-model', 'breast_cancer_savedmodel',
            '--output', ONNX_PATH,
            '--opset', '13'
        ], capture_output=True, text=True)

        print(result.stdout)
        if result.returncode == 0:
            print("‚úÖ Alternative method successful!")
        else:
            print(result.stderr)
    except Exception as e2:
        print(f"‚ùå Alternative method failed: {str(e2)}")

# ==================================================================================
# STEP 3: VALIDATE ONNX MODEL
# ==================================================================================

def onnx_default_opset_version(opset_imports):
    """Return the version of the default ONNX operator set, or None if absent.

    ``ModelProto.opset_import`` can list several operator-set domains (for
    example ``ai.onnx.ml`` next to the default one) in no guaranteed order, so
    reading ``opset_import[0]`` may report the wrong domain. The default domain
    is spelled either ``''`` or ``'ai.onnx'``.

    Args:
        opset_imports: Iterable of entries exposing ``domain`` and ``version``
            attributes (e.g. ``onnx_model.opset_import``).

    Returns:
        The default-domain opset version, or None if no entry uses it.
    """
    for entry in opset_imports:
        if entry.domain in ('', 'ai.onnx'):
            return entry.version
    return None


def onnx_elem_type_name(elem_type):
    """Return the symbolic name of an ONNX tensor element type (e.g. 'FLOAT').

    Falls back to the raw integer code, as a string, for codes the installed
    onnx package does not know.
    """
    try:
        return onnx.TensorProto.DataType.Name(elem_type)
    except ValueError:
        return str(elem_type)


print("\n" + "=" * 80)
print("VALIDATING ONNX MODEL")
print("=" * 80)

if os.path.exists('breast_cancer_classifier.onnx'):
    try:
        print("\nüîç Loading ONNX model...")

        # Load ONNX model
        onnx_model = onnx.load('breast_cancer_classifier.onnx')

        # Check model (raises if the graph is invalid)
        onnx.checker.check_model(onnx_model)
        print("‚úÖ ONNX model is valid!")

        # Print model info
        print(f"\nüìã Model Information:")
        print(f"  - IR Version: {onnx_model.ir_version}")
        print(f"  - Producer: {onnx_model.producer_name}")
        print(f"  - Opset Version: {onnx_default_opset_version(onnx_model.opset_import)}")

        # Get input/output info. Dimensions without a positive static size
        # are shown as 'dynamic'.
        print(f"\nüì• Input Information:")
        for input_tensor in onnx_model.graph.input:
            print(f"  - Name: {input_tensor.name}")
            shape = [dim.dim_value if dim.dim_value > 0 else 'dynamic' 
                    for dim in input_tensor.type.tensor_type.shape.dim]
            print(f"  - Shape: {shape}")
            print(f"  - Type: {onnx_elem_type_name(input_tensor.type.tensor_type.elem_type)}")

        print(f"\nüì§ Output Information:")
        for output_tensor in onnx_model.graph.output:
            print(f"  - Name: {output_tensor.name}")
            shape = [dim.dim_value if dim.dim_value > 0 else 'dynamic' 
                    for dim in output_tensor.type.tensor_type.shape.dim]
            print(f"  - Shape: {shape}")

    except Exception as e:
        print(f"‚ùå Error validating ONNX model: {str(e)}")
else:
    print("\nONNX file 'breast_cancer_classifier.onnx' not found - skipping validation.")

# ==================================================================================
# STEP 4: TEST ONNX INFERENCE
# ==================================================================================

# Largest absolute difference between Keras and ONNX Runtime outputs that is
# still treated as a match.
ONNX_MATCH_TOLERANCE = 0.01

print("\n" + "=" * 80)
print("TESTING ONNX INFERENCE")
print("=" * 80)

if os.path.exists('breast_cancer_classifier.onnx'):
    try:
        print("\nüß™ Running inference test...")

        # Create ONNX Runtime session
        ort_session = ort.InferenceSession('breast_cancer_classifier.onnx')

        # Use the graph's first input and output (the converted model is
        # assumed to have exactly one of each).
        input_name = ort_session.get_inputs()[0].name
        output_name = ort_session.get_outputs()[0].name

        print(f"  - Input name: {input_name}")
        print(f"  - Output name: {output_name}")

        # Random float32 test input, batch of one
        test_input = np.random.random((1, 224, 224, 3)).astype(np.float32)

        # ort_session.run returns a list with one array per requested output
        onnx_output = ort_session.run([output_name], {input_name: test_input})
        keras_output = model.predict(test_input, verbose=0)

        # Compare every output element, not only the first one.
        max_abs_diff = float(np.max(np.abs(np.asarray(keras_output) - np.asarray(onnx_output[0]))))

        print(f"\nüìä Inference Comparison:")
        print(f"  - Keras output: {keras_output[0][0]:.6f}")
        print(f"  - ONNX output:  {onnx_output[0][0][0]:.6f}")
        print(f"  - Difference:   {max_abs_diff:.8f}")

        if max_abs_diff < ONNX_MATCH_TOLERANCE:
            print("\n‚úÖ SUCCESS! ONNX model matches Keras model!")
        else:
            # Reached only when the outputs disagree beyond the tolerance.
            print(f"\n‚ö†Ô∏è Warning: ONNX and Keras outputs differ by more than {ONNX_MATCH_TOLERANCE}")

        # Test with real validation image
        print("\nüñºÔ∏è Testing with real validation image...")
        validation_generator.reset()
        # First image of the first batch, kept 4-D (a batch of one). Cast to
        # float32 explicitly: ONNX Runtime rejects inputs whose dtype differs
        # from the graph input (assumed float32, as used for the test input above).
        real_image = np.asarray(next(validation_generator)[0][0:1], dtype=np.float32)

        keras_pred = model.predict(real_image, verbose=0)
        onnx_pred = ort_session.run([output_name], {input_name: real_image})

        keras_prob = float(keras_pred[0][0])
        onnx_prob = float(onnx_pred[0][0][0])

        print(f"\nüìä Real Image Test:")
        print(f"  - Keras prediction: {keras_prob:.6f} ({'Malignant' if keras_prob > 0.5 else 'Benign'})")
        print(f"  - ONNX prediction:  {onnx_prob:.6f} ({'Malignant' if onnx_prob > 0.5 else 'Benign'})")
        print(f"  - Match: {'‚úÖ Yes' if abs(keras_prob - onnx_prob) < ONNX_MATCH_TOLERANCE else '‚ùå No'}")

    except Exception as e:
        print(f"‚ùå Error during inference test: {str(e)}")
else:
    print("\nONNX file 'breast_cancer_classifier.onnx' not found - skipping inference test.")