# Breast Cancer Detection from Multi-Cancer Dataset

This notebook trains a baseline image classifier (EfficientNetB0 transfer learning) on the breast cancer images of the Multi-Cancer dataset on Kaggle.

Dataset: https://www.kaggle.com/datasets/obulisainaren/multi-cancer/data


## 1. Setup and Install Dependencies


In [None]:
# Install required packages
# Each line is a separate shell command (IPython `!` magic) that runs pip once
# per package; `-q` asks pip for quiet output.
!pip install -q kaggle
!pip install -q tensorflow
!pip install -q pillow
!pip install -q matplotlib
!pip install -q scikit-learn
!pip install -q seaborn


## 2. Setup Kaggle API Credentials

**Instructions:**
1. Go to your Kaggle account settings: https://www.kaggle.com/settings
2. Scroll to 'API' section and click 'Create New API Token'
3. This will download `kaggle.json`
4. Run the next cell and select the downloaded `kaggle.json` in the upload prompt


In [None]:
# Upload kaggle.json file
# Uses the Google Colab helper module, so this cell only works inside Colab.
# NOTE(review): the credentials cell below copies `kaggle.json` from the working
# directory, so the upload is presumably stored there under that name — confirm.
from google.colab import files
files.upload()


In [None]:
# Setup Kaggle credentials
# 1. create the ~/.kaggle config directory (no error if it already exists),
# 2. copy the token from the working directory into it,
# 3. restrict the token file to owner read/write (mode 600).
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json


## 3. Download Dataset


In [None]:
# Download the dataset
!kaggle datasets download -d obulisainaren/multi-cancer
!unzip -q multi-cancer.zip -d dataset
!ls dataset


## 4. Import Libraries


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from PIL import Image
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

# Report the installed TensorFlow version, then the GPU devices it can see
# (an empty list means training will run on CPU).
tf_version = tf.__version__
print(f"TensorFlow version: {tf_version}")
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"GPU available: {gpu_devices}")


## 5. Explore Dataset Structure


In [None]:
# Explore the dataset structure
dataset_path = 'dataset'

# Print an indented directory tree of the extracted dataset.
# The walk variables are deliberately not called `files`: that name is bound to
# the `google.colab.files` module by the upload cell and would be overwritten.
for dirpath, dirnames, filenames in os.walk(dataset_path):
    # Sorting in place makes the traversal (and the printed tree) reproducible.
    dirnames.sort()
    filenames.sort()

    # Depth below dataset_path: 0 for the root itself, 1 for its children, ...
    # relpath keeps this correct even if dataset_path has a trailing separator
    # or is itself a nested path.
    rel_path = os.path.relpath(dirpath, dataset_path)
    level = 0 if rel_path == os.curdir else rel_path.count(os.sep) + 1

    indent = ' ' * 2 * level
    print(f"{indent}{os.path.basename(dirpath)}/")
    if level < 2:  # Only list files for the first 2 levels
        subindent = ' ' * 2 * (level + 1)
        for filename in filenames[:5]:  # Show first 5 files
            print(f"{subindent}{filename}")
        if len(filenames) > 5:
            print(f"{subindent}... and {len(filenames) - 5} more files")


## 6. Data Preparation


In [None]:
# Define paths - adjust these based on actual dataset structure
# Common structures: dataset/train/breast or dataset/breast/train

# Try to find breast cancer images
def find_breast_cancer_path(base_path):
    """Return the first directory under ``base_path`` whose path mentions "breast".

    Only the part of the path *below* ``base_path`` is inspected, so a parent
    folder that happens to contain "breast" (e.g. ``~/breast_project/dataset``)
    does not make the search return ``base_path`` itself.

    Args:
        base_path: Directory to search recursively.

    Returns:
        The path of the first matching directory (matching is
        case-insensitive), or ``None`` when nothing matches.
    """
    for root, dirs, _ in os.walk(base_path):
        # Visit sub-directories in sorted order so the "first" match is
        # reproducible instead of depending on file-system listing order.
        dirs.sort()
        relative = os.path.relpath(root, base_path)
        if 'breast' in relative.lower():
            print(f"Found breast cancer data at: {root}")
            return root
    return None

breast_path = find_breast_cancer_path(dataset_path)

# Count images in each category
if breast_path:
    # Sorted so the report is printed in a stable order.
    categories = sorted(os.listdir(breast_path))
    print(f"\nCategories found: {categories}")

    for category in categories:
        cat_path = os.path.join(breast_path, category)
        if os.path.isdir(cat_path):
            # Compare extensions case-insensitively so files such as
            # "IMG_01.JPG" are counted as well.
            num_images = sum(
                1 for f in os.listdir(cat_path)
                if f.lower().endswith(('.jpg', '.jpeg', '.png'))
            )
            print(f"{category}: {num_images} images")
else:
    print("Breast cancer path not found. Please check dataset structure.")


In [None]:
# Training hyper-parameters
IMG_SIZE = 224          # images are resized to IMG_SIZE x IMG_SIZE pixels
BATCH_SIZE = 32         # images per generator batch
EPOCHS = 20             # upper bound on training epochs
LEARNING_RATE = 0.001   # Adam learning rate used for the first compile

# Root folder handed to the data generators: the breast cancer folder found
# above, or the whole extracted dataset when that search returned nothing.
# Adjust this path if the exploration output shows a different layout.
data_dir = breast_path or 'dataset'


## 7. Visualize Sample Images


In [None]:
# Visualize some sample images
def plot_sample_images(data_dir, num_samples=9):
    """Show the first image of up to ``num_samples`` category folders in a grid.

    Args:
        data_dir: Directory whose sub-directories are the image categories.
        num_samples: Maximum number of categories (one image each) to display.
    """
    # Sorted so the same categories and images are shown on every run.
    categories = sorted(
        d for d in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, d))
    )[:num_samples]
    if not categories:
        print(f"No category folders found in: {data_dir}")
        return

    # Size the grid from the number of panels; a fixed 3x3 layout raises an
    # error as soon as more than nine categories are requested.
    n_cols = min(3, len(categories))
    n_rows = -(-len(categories) // n_cols)  # ceiling division
    plt.figure(figsize=(4 * n_cols, 4 * n_rows))

    for idx, category in enumerate(categories):
        cat_path = os.path.join(data_dir, category)
        # Case-insensitive match so ".JPG" / ".PNG" files are found too.
        images = sorted(
            f for f in os.listdir(cat_path)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        )
        if not images:
            continue

        plt.subplot(n_rows, n_cols, idx + 1)
        # The context manager closes the file handle once the pixels are drawn.
        with Image.open(os.path.join(cat_path, images[0])) as img:
            plt.imshow(img)
        plt.title(category)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

if breast_path:
    plot_sample_images(breast_path)


## 8. Create Data Generators


In [None]:
# Both generators reserve the same fraction of every class folder for validation
VALIDATION_FRACTION = 0.2  # 80-20 split

# Options shared by the training and validation flows
flow_options = dict(
    directory=data_dir,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)

# Training images: pixel rescaling plus random geometric augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    validation_split=VALIDATION_FRACTION,
)

# Validation images: rescaling only, no augmentation
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_FRACTION,
)

# Shuffled batches for training
train_generator = train_datagen.flow_from_directory(
    subset='training',
    shuffle=True,
    **flow_options,
)

# Fixed order for validation so predictions line up with `.classes`
validation_generator = val_datagen.flow_from_directory(
    subset='validation',
    shuffle=False,
    **flow_options,
)

# Summarise what the generators found
class_indices = train_generator.class_indices
print(f"\nClass indices: {class_indices}")
print(f"Number of training samples: {train_generator.samples}")
print(f"Number of validation samples: {validation_generator.samples}")
print(f"Number of classes: {len(class_indices)}")

## 9. Build the Model

In [None]:
# Build model using transfer learning with EfficientNetB0
def create_model(num_classes):
    """Build an EfficientNetB0 transfer-learning classifier.

    The data generators and inference helpers in this notebook feed images
    scaled to [0, 1]. Keras' EfficientNet models do their own input
    normalisation and expect raw pixel values in [0, 255], so the model starts
    by multiplying its input by 255 to undo that scaling.

    The backbone is called with ``training=False`` so its BatchNormalization
    layers keep using their pretrained statistics, even after the backbone is
    unfrozen for fine-tuning.

    Args:
        num_classes: Number of output classes (size of the softmax layer).

    Returns:
        A ``(model, base_model)`` tuple: the full, uncompiled classifier and
        the frozen EfficientNetB0 backbone it wraps.
    """
    # Load pre-trained EfficientNetB0 without its ImageNet classification head
    base_model = EfficientNetB0(
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        include_top=False,
        weights='imagenet'
    )

    # Freeze base model initially; only the new head is trained at first
    base_model.trainable = False

    inputs = keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    # Map the [0, 1] inputs back to the [0, 255] range EfficientNet expects
    x = layers.Rescaling(255.0)(inputs)
    x = base_model(x, training=False)

    # Classification head
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = keras.Model(inputs, outputs)

    return model, base_model

# Create the model
num_classes = len(train_generator.class_indices)
model, base_model = create_model(num_classes)

# Compile model
# The metrics are named explicitly. Keras auto-generates unique names, so a
# second Precision()/Recall() created in the same session (for example when
# this cell is re-run) would be logged as "precision_1"/"recall_1" and the
# history['precision'] / history['recall'] lookups in the plotting cell would
# raise KeyError.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss='categorical_crossentropy',
    metrics=[
        'accuracy',
        keras.metrics.Precision(name='precision'),
        keras.metrics.Recall(name='recall'),
    ]
)

# Print model summary
model.summary()

## 10. Setup Callbacks


In [None]:
# Callbacks
# Both callbacks watch the validation loss and log when they trigger.
shared_callback_options = {'monitor': 'val_loss', 'verbose': 1}

# Stop after 5 epochs without improvement and roll back to the best weights
early_stopping = EarlyStopping(
    patience=5,
    restore_best_weights=True,
    **shared_callback_options,
)

# Halve the learning rate after 3 stagnant epochs, never going below 1e-7
reduce_lr = ReduceLROnPlateau(
    factor=0.5,
    patience=3,
    min_lr=1e-7,
    **shared_callback_options,
)

callbacks = [early_stopping, reduce_lr]


## 11. Train the Model


In [None]:
# Train the model for at most EPOCHS epochs; the callbacks may stop it earlier
# or lower the learning rate. The returned History object feeds the plots below.
history = model.fit(
    x=train_generator,
    validation_data=validation_generator,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1,
)


## 12. Plot Training History


In [None]:
# Plot training history
def plot_history(history):
    """Plot training and validation curves for accuracy, loss, precision and recall.

    Metric keys are resolved leniently: Keras may log a metric under a
    numbered name such as ``precision_1`` when metric objects were created more
    than once in the session. A metric that was not recorded gets an empty,
    labelled panel instead of raising ``KeyError``.

    Args:
        history: Object returned by ``model.fit``; only its ``history``
            dictionary (metric name -> list of per-epoch values) is read.
    """
    logs = history.history

    def resolve_key(metric):
        """Return the training-log key for ``metric`` or None if it is absent."""
        if metric in logs:
            return metric
        prefix = metric + '_'
        for key in logs:
            # Accept only "<metric>_<number>"; "val_..." keys never match
            # because they do not start with the metric name.
            if key.startswith(prefix) and key[len(prefix):].isdigit():
                return key
        return None

    panels = [
        ('accuracy', 'Accuracy'),
        ('loss', 'Loss'),
        ('precision', 'Precision'),
        ('recall', 'Recall'),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    for ax, (metric, label) in zip(axes.flat, panels):
        key = resolve_key(metric)
        if key is None:
            ax.set_title(f'Model {label} (not recorded)')
            ax.axis('off')
            continue

        ax.plot(logs[key], label=f'Train {label}')
        val_key = f'val_{key}'
        if val_key in logs:  # absent when fit() ran without validation data
            ax.plot(logs[val_key], label=f'Val {label}')
        ax.set_title(f'Model {label}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(label)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

plot_history(history)


## 13. Evaluate the Model


In [None]:
# Evaluate on validation set
# evaluate() returns the loss followed by the metrics in the order they were
# passed to compile(): accuracy, precision, recall.
val_loss, val_accuracy, val_precision, val_recall = model.evaluate(validation_generator)

print(f"\n=== Model Evaluation ===")
print(f"Validation Loss: {val_loss:.4f}")
print(f"Validation Accuracy: {val_accuracy:.4f}")
print(f"Validation Precision: {val_precision:.4f}")
print(f"Validation Recall: {val_recall:.4f}")

# Calculate F1 Score (harmonic mean of precision and recall).
# When both are 0 (no prediction passed the metric threshold) the formula
# would divide by zero, so F1 is reported as 0 in that case.
precision_recall_sum = val_precision + val_recall
if precision_recall_sum > 0:
    f1_score = 2 * (val_precision * val_recall) / precision_recall_sum
else:
    f1_score = 0.0
print(f"Validation F1-Score: {f1_score:.4f}")


## 14. Confusion Matrix and Classification Report


In [None]:
# Get predictions
# reset() rewinds the generator so the predictions are produced in the same
# order as validation_generator.classes (the generator is not shuffled).
validation_generator.reset()
y_pred = model.predict(validation_generator)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = validation_generator.classes

# Get class names, ordered by their integer index so position i is class i
class_indices = train_generator.class_indices
class_names = sorted(class_indices, key=class_indices.get)
label_ids = list(range(len(class_names)))

# Classification report
# Passing `labels` keeps the report aligned with class_names even when a class
# never occurs in y_true / y_pred_classes; zero_division=0 reports 0 for such
# classes instead of emitting warnings.
print("\n=== Classification Report ===")
print(classification_report(
    y_true,
    y_pred_classes,
    labels=label_ids,
    target_names=class_names,
    zero_division=0,
))

# Confusion matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(y_true, y_pred_classes, labels=label_ids)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.show()


## 15. Save the Model


In [None]:
# Save the model
# The HDF5 file can be reloaded later with keras.models.load_model().
model.save('breast_cancer_detection_model.h5')
print("Model saved successfully!")

# Save model in SavedModel format (recommended for deployment)
# Keras 3 rejects extension-less paths in model.save() and offers
# model.export() for writing a SavedModel; older tf.keras versions without
# export() write a SavedModel directly from model.save().
if hasattr(model, 'export'):
    model.export('breast_cancer_detection_model_savedmodel')
else:
    model.save('breast_cancer_detection_model_savedmodel')
print("Model saved in SavedModel format!")


## 16. Inference Function


In [None]:
# Function to predict on new images
def predict_image(image_path, model, class_names):
    """Predict the class of a single image and plot the result.

    The image is converted to 3-channel RGB, resized to IMG_SIZE x IMG_SIZE
    and scaled to [0, 1], matching the rescaling applied by the data
    generators.

    Args:
        image_path: Path of the image file to classify.
        model: Model whose ``predict`` returns one probability per class.
        class_names: Class labels ordered by class index.

    Returns:
        A ``(predicted_class, confidence)`` tuple: the label with the highest
        probability and that probability as a float.
    """
    # Load and preprocess image. convert('RGB') makes grayscale, palette and
    # RGBA files match the model's 3-channel input; the context manager closes
    # the file handle as soon as the pixels have been copied.
    with Image.open(image_path) as raw_image:
        img = raw_image.convert('RGB').resize((IMG_SIZE, IMG_SIZE))
    img_array = np.asarray(img, dtype=np.float32) / 255.0
    img_array = np.expand_dims(img_array, axis=0)  # add the batch dimension

    # Make prediction
    predictions = model.predict(img_array)
    predicted_class_idx = int(np.argmax(predictions[0]))
    predicted_class = class_names[predicted_class_idx]
    confidence = float(predictions[0][predicted_class_idx])

    # Display results
    plt.figure(figsize=(10, 4))

    # Show image
    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.title(f"Predicted: {predicted_class}\nConfidence: {confidence:.2%}")
    plt.axis('off')

    # Show probability distribution
    plt.subplot(1, 2, 2)
    plt.barh(class_names, predictions[0])
    plt.xlabel('Probability')
    plt.title('Class Probabilities')
    plt.tight_layout()
    plt.show()

    return predicted_class, confidence

# Example usage (uncomment to test with an image)
# test_image_path = 'path/to/your/test/image.jpg'
# predicted_class, confidence = predict_image(test_image_path, model, class_names)


## 17. Test on Random Dataset Images


In [None]:
# Test on random images drawn from the dataset folders
def test_random_images(data_dir, model, class_names, num_images=6):
    """Predict ``num_images`` randomly chosen images and plot true vs. predicted labels.

    Images are sampled from every class folder under ``data_dir`` (training
    and validation files alike), so this is a visual sanity check rather than
    a held-out evaluation.

    Args:
        data_dir: Directory whose sub-directories are the class folders.
        model: Model whose ``predict`` returns one probability per class.
        class_names: Class labels ordered by class index.
        num_images: Number of images to draw and display.
    """
    # Map each class folder to its image files; folders without images are
    # left out so they can never be drawn.
    images_by_category = {}
    for entry in sorted(os.listdir(data_dir)):
        cat_path = os.path.join(data_dir, entry)
        if not os.path.isdir(cat_path):
            continue
        image_names = [
            f for f in os.listdir(cat_path)
            if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        ]
        if image_names:
            images_by_category[entry] = image_names

    if not images_by_category or num_images < 1:
        print(f"No images to display from: {data_dir}")
        return

    categories = list(images_by_category)

    # Size the grid from num_images: the number of panels no longer depends on
    # how many class folders exist, and any count fits.
    n_cols = min(3, num_images)
    n_rows = -(-num_images // n_cols)  # ceiling division
    plt.figure(figsize=(5 * n_cols, 5 * n_rows))

    for i in range(num_images):
        # Pick random category and image
        category = categories[np.random.randint(len(categories))]
        candidates = images_by_category[category]
        img_name = candidates[np.random.randint(len(candidates))]
        img_path = os.path.join(data_dir, category, img_name)

        # Load and preprocess. convert('RGB') guarantees the 3 channels the
        # model expects; the context manager closes the file handle.
        with Image.open(img_path) as raw_image:
            img = raw_image.convert('RGB')
        img_resized = img.resize((IMG_SIZE, IMG_SIZE))
        img_array = np.asarray(img_resized, dtype=np.float32) / 255.0
        img_array = np.expand_dims(img_array, axis=0)

        # Predict
        predictions = model.predict(img_array, verbose=0)
        predicted_idx = int(np.argmax(predictions[0]))
        predicted_class = class_names[predicted_idx]
        confidence = predictions[0][predicted_idx]

        # Plot
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(img)
        plt.title(f"True: {category}\nPred: {predicted_class}\nConf: {confidence:.2%}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Test on random images
test_random_images(data_dir, model, class_names)


In [None]:
# Unfreeze base model for fine-tuning
base_model.trainable = True

# Freeze early layers, unfreeze later layers
for layer in base_model.layers[:-20]:
    layer.trainable = False

# Keep BatchNormalization layers frozen inside the unfrozen window as well:
# letting them update on a small dataset overwrites the pretrained statistics
# and usually hurts accuracy.
for layer in base_model.layers[-20:]:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False

# Recompile with lower learning rate
# Metrics are named explicitly so the training history keeps the keys
# "precision"/"recall" instead of auto-numbered names such as "precision_1".
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE/10),
    loss='categorical_crossentropy',
    metrics=[
        'accuracy',
        keras.metrics.Precision(name='precision'),
        keras.metrics.Recall(name='recall'),
    ]
)

print("Model ready for fine-tuning")
print(f"Total layers: {len(model.layers)}")
print(f"Trainable layers: {sum([1 for layer in model.layers if layer.trainable])}")
# The counts above treat the whole backbone as a single layer, so also report
# how many layers inside the backbone will actually be updated.
trainable_base_layers = sum(1 for layer in base_model.layers if layer.trainable)
print(f"Trainable base-model layers: {trainable_base_layers} of {len(base_model.layers)}")


In [None]:
# Fine-tune the model (uncomment to run)
# history_fine = model.fit(
#     train_generator,
#     epochs=10,
#     validation_data=validation_generator,
#     callbacks=callbacks,
#     verbose=1
# )


## Next Steps & Improvements

### Model Improvements:
- **Different Architectures**: Try ResNet50, DenseNet121, or Vision Transformers
- **Hyperparameter Tuning**: Experiment with learning rates, batch sizes, and dropout rates
- **Data Augmentation**: Add more augmentation techniques (color jitter, elastic transforms)
- **Ensemble Methods**: Combine multiple models for better predictions
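- **Class Balancing**: If the dataset is imbalanced, use class weights or oversample the minority class (SMOTE targets tabular features, so it only applies to extracted feature vectors, not raw images)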

### Fine-tuning:
- Uncomment the fine-tuning section above to further improve accuracy
- Experiment with unfreezing different numbers of layers

### Deployment Options:
1. **TensorFlow Lite**: Convert for mobile deployment
2. **Web API**: Create REST API using Flask/FastAPI
3. **Streamlit App**: Build interactive web interface
4. **Docker**: Containerize for easy deployment

### Model Optimization:
- **Quantization**: Reduce model size for faster inference
- **Mixed Precision Training**: Speed up training with FP16
- **Cross-Validation**: Implement K-fold cross-validation

### Evaluation:
- **ROC-AUC Curves**: Better evaluation for medical classification
- **Grad-CAM**: Visualize what the model is looking at
- **Error Analysis**: Analyze misclassified samples
