# MRI vs Breast Histopathology Classification - Part 1: Data Preparation

**Objective**: Build a binary classifier to distinguish between MRI brain scans and breast histopathology images with >95% accuracy.

## Part 1: Dataset Organization and Preprocessing
- Consolidate MRI data from its 4 class folders (three tumor types plus no_tumor) into a single MRI folder
- Extract and organize breast histopathology data into BreastHisto folder
- Create train/validation/test splits
- Preprocess images during organization (grayscale conversion, re-encoded as PNG)

**Dataset Overview**:
- **MRI**: Brain MRI scans from 4 classes (glioma, meningioma, no_tumor, pituitary), pooled into a single MRI class
- **BreastHisto**: Breast histopathology patches (IDC+ and IDC-)
- **Target**: Binary classification between the two modalities

---

## 1. Environment Setup and Imports

In [None]:
# Core libraries
import os
import sys
import shutil
import random
import warnings
from pathlib import Path
from collections import Counter
import numpy as np
import pandas as pd
from tqdm import tqdm

# Image processing
from PIL import Image
import cv2

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Silence third-party warnings to keep notebook output readable
warnings.filterwarnings('ignore')

# One seed shared by Python's `random` module and NumPy's global RNG
SEED = 42
for _seed_rng in (random.seed, np.random.seed):
    _seed_rng(SEED)

print("✅ Libraries imported successfully")
print(f"Random seed set to: {SEED}")

## 2. Project Structure and Path Configuration

In [None]:
# Project layout: paths are resolved relative to the parent of the current
# working directory (the notebook is presumably run from a sub-folder).
project_root = Path.cwd().parent
data_dir = project_root / 'data'
raw_data_dir, processed_data_dir = data_dir / 'raw', data_dir / 'processed'

# Where the original datasets were unpacked
tumor_source = project_root / 'tumor' / 'tumor'  # MRI data
breasthisto_source = project_root / 'archive_2'  # Breast histopathology data

# Make sure the output folders exist
for _required_dir in (raw_data_dir, processed_data_dir):
    _required_dir.mkdir(parents=True, exist_ok=True)

print("📁 Project Structure:")
for _label, _path in (
    ("Project root", project_root),
    ("Data directory", data_dir),
    ("Raw data", raw_data_dir),
    ("Processed data", processed_data_dir),
):
    print(f"  {_label}: {_path}")

print(f"\n📂 Source Data:")
print(f"  MRI source: {tumor_source}")
print(f"  BreastHisto source: {breasthisto_source}")

# Report whether the expected source folders are present
print(f"\n🔍 Source Directory Verification:")
print(f"  MRI source exists: {tumor_source.exists()}")
print(f"  BreastHisto source exists: {breasthisto_source.exists()}")

if tumor_source.exists():
    mri_classes = list(tumor_source.glob('*'))
    _mri_class_names = [entry.name for entry in mri_classes if entry.is_dir()]
    print(f"  MRI classes found: {_mri_class_names}")

if breasthisto_source.exists():
    breasthisto_folders = list(breasthisto_source.glob('*'))
    _n_folders = sum(1 for entry in breasthisto_folders if entry.is_dir())
    print(f"  BreastHisto folders: {_n_folders} patient folders")

## 3. MRI Data Organization

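Consolidate the MRI images from all 4 class folders (glioma, meningioma, no_tumor, pituitary) into a single MRI folder, converting each image to grayscale PNG. At most 2,000 images are kept per class.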

In [None]:
def organize_mri_data(source_dir, target_dir, max_samples_per_class=None, convert_to_grayscale=True):
    """
    Organize MRI data from multiple tumor classes into a single MRI folder.

    Images from each class sub-folder of ``source_dir`` are written to
    ``target_dir`` as ``mri_<class_name>_<index>`` files. When
    ``convert_to_grayscale`` is True each image is decoded, converted to
    single-channel grayscale and saved as PNG; otherwise the file is copied
    unchanged and keeps its original (lower-cased) extension.

    Args:
        source_dir: Path to tumor source directory (one sub-folder per class)
        target_dir: Path to target MRI directory (created if missing)
        max_samples_per_class: Maximum samples per class (for balancing);
            ``None`` or 0 keeps every image
        convert_to_grayscale: Whether to convert images to grayscale during processing

    Returns:
        Tuple ``(class_counts, total_copied)``: ``class_counts`` maps each class
        whose folder was found to the number of images written.
    """
    print("🧠 Organizing MRI Data with Grayscale Conversion...")
    print("=" * 50)

    # Create target directory
    target_dir.mkdir(parents=True, exist_ok=True)

    tumor_classes = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']
    # Compared against the lower-cased suffix, so '.JPG', '.Jpg', ... all match exactly once.
    image_extensions = {'.jpg', '.jpeg', '.png', '.tiff', '.tif', '.bmp'}

    total_copied = 0
    class_counts = {}

    for class_name in tumor_classes:
        class_dir = source_dir / class_name

        if not class_dir.is_dir():
            print(f"⚠️  {class_name} directory not found: {class_dir}")
            continue

        print(f"\n📂 Processing {class_name}...")

        # One directory scan filtered by lower-cased suffix. Globbing '*.jpg'
        # and '*.JPG' separately returns every file twice on case-insensitive
        # filesystems (Windows/macOS), duplicating images under different
        # output names. Sorting makes the seeded shuffle below independent of
        # the filesystem's listing order.
        image_files = sorted(
            p for p in class_dir.iterdir()
            if p.is_file() and p.suffix.lower() in image_extensions
        )

        print(f"  Found {len(image_files)} images")

        # Limit samples if specified (uses the globally seeded `random`)
        if max_samples_per_class and len(image_files) > max_samples_per_class:
            random.shuffle(image_files)
            image_files = image_files[:max_samples_per_class]
            print(f"  Limited to {max_samples_per_class} images")

        # Process and save images
        copied_count = 0
        for i, img_path in enumerate(tqdm(image_files, desc=f"Processing {class_name}")):
            try:
                if convert_to_grayscale:
                    # Grayscale output is always re-encoded as PNG for consistency
                    target_path = target_dir / f"mri_{class_name}_{i:04d}.png"
                    img = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
                    if img is None:
                        # Fallback to PIL if cv2 cannot decode the file
                        with Image.open(img_path) as pil_img:
                            img = np.array(pil_img.convert('L'))
                    # cv2.imwrite signals failure through its return value, not an exception
                    if not cv2.imwrite(str(target_path), img):
                        raise OSError(f"cv2.imwrite failed for {target_path}")
                else:
                    # Copy as-is, keeping the real extension so the name matches the contents
                    target_path = target_dir / f"mri_{class_name}_{i:04d}{img_path.suffix.lower()}"
                    shutil.copy2(img_path, target_path)

                copied_count += 1

            except Exception as e:
                print(f"    ⚠️  Error processing {img_path.name}: {e}")

        class_counts[class_name] = copied_count
        total_copied += copied_count
        mode_note = " to grayscale" if convert_to_grayscale else ""
        print(f"  ✅ Processed {copied_count} images{mode_note}")

    print(f"\n📊 MRI Data Organization Summary:")
    for class_name, count in class_counts.items():
        percentage = (count / total_copied * 100) if total_copied > 0 else 0
        print(f"  {class_name}: {count:,} images ({percentage:.1f}%)")
    print(f"  Total MRI images: {total_copied:,}")
    if convert_to_grayscale:
        print(f"  ✅ All images converted to grayscale PNG format")

    return class_counts, total_copied

# Organize MRI data with grayscale conversion
mri_target_dir = raw_data_dir / 'MRI'

if not tumor_source.exists():
    print(f"❌ MRI source directory not found: {tumor_source}")
    mri_class_counts, mri_total = {}, 0
else:
    mri_class_counts, mri_total = organize_mri_data(
        tumor_source,
        mri_target_dir,
        max_samples_per_class=2000,  # per-class cap
        convert_to_grayscale=True,   # store single-channel PNGs
    )

## 4. Breast Histopathology Data Organization

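Collect breast histopathology patches from the per-patient folders in `archive_2` into a single BreastHisto folder. In each patient folder, sub-folder `0` is IDC-negative and `1` is IDC-positive. Each patch is converted to grayscale PNG, and at most 8,000 images are collected in total.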

In [None]:
def organize_breasthisto_data(source_dir, target_dir, max_samples=None, convert_to_grayscale=True):
    """
    Organize breast histopathology data from the archive_2 structure into one folder.

    Patient folders (names made of decimal digits) are read from
    ``source_dir / 'IDC_regular_ps50_idx5'`` if that exists, otherwise from
    ``source_dir`` itself. Inside each patient folder, sub-folder ``0`` is
    labelled IDC_negative and ``1`` IDC_positive. Every image is written to
    ``target_dir`` as ``breasthisto_<class>_<patient>_<stem>``: re-encoded as a
    grayscale PNG when ``convert_to_grayscale`` is True, otherwise copied
    unchanged with its original (lower-cased) extension.

    Patients are processed in ascending numeric order and images in sorted
    filename order. When ``max_samples`` is set, the quota is therefore filled
    deterministically from the lowest-numbered patients first.

    Args:
        source_dir: Path to archive_2 directory
        target_dir: Path to target BreastHisto directory (created if missing)
        max_samples: Maximum total samples to copy; ``None`` or 0 means no limit
        convert_to_grayscale: Whether to convert images to grayscale during processing

    Returns:
        Tuple ``(class_counts, total_copied)`` where ``class_counts`` has the
        keys 'IDC_negative' and 'IDC_positive'.
    """
    print("🔬 Organizing Breast Histopathology Data with Grayscale Conversion...")
    print("=" * 50)

    # Create target directory
    target_dir.mkdir(parents=True, exist_ok=True)

    # Compared against the lower-cased suffix so each file matches exactly once
    image_extensions = {'.jpg', '.jpeg', '.png', '.tiff', '.tif', '.bmp'}
    # Patient sub-folder name -> class label
    class_folders = {'0': 'IDC_negative', '1': 'IDC_positive'}
    total_copied = 0
    class_counts = {'IDC_negative': 0, 'IDC_positive': 0}

    # Look for the main data directory
    data_root = source_dir / 'IDC_regular_ps50_idx5'
    if not data_root.exists():
        # Alternative: look for patient folders directly in source_dir
        data_root = source_dir

    print(f"📂 Scanning directory: {data_root}")

    # Numeric sort: raw listing order is filesystem-dependent, which made the
    # max_samples subset differ between machines.
    patient_folders = sorted(
        (f for f in data_root.iterdir() if f.is_dir() and f.name.isdecimal()),
        key=lambda f: int(f.name),
    )
    print(f"Found {len(patient_folders)} patient folders")

    if max_samples:
        print(f"Will collect maximum {max_samples} samples")

    # Process each patient folder
    for patient_folder in tqdm(patient_folders, desc="Processing patients"):
        if max_samples and total_copied >= max_samples:
            break

        # Each patient has class 0 (IDC-) and class 1 (IDC+) folders
        for class_folder, class_name in class_folders.items():
            class_path = patient_folder / class_folder
            if not class_path.is_dir():
                continue

            # Single directory scan; globbing '*.png' and '*.PNG' separately
            # lists every file twice on case-insensitive filesystems.
            image_files = sorted(
                p for p in class_path.iterdir()
                if p.is_file() and p.suffix.lower() in image_extensions
            )

            for img_path in image_files:
                if max_samples and total_copied >= max_samples:
                    break

                try:
                    base_name = f"breasthisto_{class_name}_{patient_folder.name}_{img_path.stem}"

                    if convert_to_grayscale:
                        # Grayscale output is always re-encoded as PNG for consistency
                        target_path = target_dir / f"{base_name}.png"
                        img = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
                        if img is None:
                            # Fallback to PIL if cv2 cannot decode the file
                            with Image.open(img_path) as pil_img:
                                img = np.array(pil_img.convert('L'))
                        # cv2.imwrite signals failure through its return value
                        if not cv2.imwrite(str(target_path), img):
                            raise OSError(f"cv2.imwrite failed for {target_path}")
                    else:
                        # Copy as-is, keeping the real extension
                        target_path = target_dir / f"{base_name}{img_path.suffix.lower()}"
                        shutil.copy2(img_path, target_path)

                    class_counts[class_name] += 1
                    total_copied += 1

                except Exception as e:
                    print(f"    ⚠️  Error processing {img_path.name}: {e}")

    print(f"\n📊 Breast Histopathology Data Organization Summary:")
    for class_name, count in class_counts.items():
        percentage = (count / total_copied * 100) if total_copied > 0 else 0
        print(f"  {class_name}: {count:,} images ({percentage:.1f}%)")
    print(f"  Total BreastHisto images: {total_copied:,}")
    if convert_to_grayscale:
        print(f"  ✅ All images converted to grayscale PNG format")

    return class_counts, total_copied

# Organize Breast Histopathology data with grayscale conversion
breasthisto_target_dir = raw_data_dir / 'BreastHisto'

if not breasthisto_source.exists():
    print(f"❌ BreastHisto source directory not found: {breasthisto_source}")
    breasthisto_class_counts, breasthisto_total = {}, 0
else:
    breasthisto_class_counts, breasthisto_total = organize_breasthisto_data(
        breasthisto_source,
        breasthisto_target_dir,
        max_samples=8000,          # total cap, intended to balance with the MRI data
        convert_to_grayscale=True, # store single-channel PNGs
    )

## 5. Dataset Statistics and Verification

In [None]:
def analyze_organized_data(raw_data_dir):
    """
    Count the images in each organized class folder and print summary statistics.

    Images are recognized by file suffix, case-insensitively
    (.jpg/.jpeg/.png/.tiff/.tif/.bmp), and every file is counted exactly once.

    Args:
        raw_data_dir: Path containing one sub-folder per class ('MRI', 'BreastHisto').

    Returns:
        Tuple ``(class_stats, total_images)``: a dict mapping each class name to
        its image count (0 when the folder is missing) and the grand total.
    """
    print("📊 Dataset Analysis")
    print("=" * 50)

    classes = ['MRI', 'BreastHisto']
    image_extensions = {'.jpg', '.jpeg', '.png', '.tiff', '.tif', '.bmp'}
    total_images = 0
    class_stats = {}

    for class_name in classes:
        class_dir = raw_data_dir / class_name

        if not class_dir.is_dir():
            print(f"\n❌ {class_name}: Directory not found")
            class_stats[class_name] = 0
            continue

        # One pass with a lower-cased suffix test. Globbing '*.png' and '*.PNG'
        # separately double-counts every file on case-insensitive filesystems
        # and misses mixed-case suffixes such as '.Png'.
        image_files = sorted(
            p for p in class_dir.iterdir()
            if p.is_file() and p.suffix.lower() in image_extensions
        )

        count = len(image_files)
        class_stats[class_name] = count
        total_images += count

        print(f"\n📂 {class_name}:")
        print(f"  Images: {count:,}")
        print(f"  Directory: {class_dir}")

        # Sample a few file names as a quick sanity check
        sample_images = random.sample(image_files, min(3, len(image_files)))
        print(f"  Sample files: {[img.name for img in sample_images]}")

    print(f"\n📈 Overall Statistics:")
    print(f"  Total images: {total_images:,}")

    if total_images > 0:
        for class_name, count in class_stats.items():
            percentage = (count / total_images) * 100
            print(f"  {class_name}: {count:,} ({percentage:.1f}%)")

    return class_stats, total_images

# Analyze organized data
final_class_stats, final_total = analyze_organized_data(raw_data_dir)

# Compare the two class sizes; a ratio above 2:1 is flagged as imbalanced
print(f"\n⚖️  Data Balance Analysis:")
if final_total <= 0:
    print(f"  ❌ No data found")
else:
    mri_count = final_class_stats.get('MRI', 0)
    breasthisto_count = final_class_stats.get('BreastHisto', 0)

    if min(mri_count, breasthisto_count) <= 0:
        print(f"  ❌ One or both classes are missing")
    else:
        ratio = max(mri_count, breasthisto_count) / min(mri_count, breasthisto_count)
        print(f"  Class ratio: {ratio:.2f}:1")
        print(
            f"  ✅ Classes are reasonably balanced"
            if ratio <= 2.0
            else f"  ⚠️  Classes are imbalanced - consider balancing strategies"
        )

## 6. Data Splitting (Train/Validation/Test)

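Split each class independently into train/validation/test sets (70/15/15 by default), so every split keeps the same MRI-to-BreastHisto ratio as the full dataset. This stratifies by class; it does not rebalance the classes.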

In [None]:
import shutil
from tqdm import tqdm
import random
from pathlib import Path

def _list_split_candidates(class_dir):
    """Return the image files directly inside *class_dir*: resolved, de-duplicated and sorted."""
    image_extensions = {'.jpg', '.jpeg', '.png', '.tiff', '.tif', '.bmp'}
    found = {
        p.resolve() for p in class_dir.iterdir()
        if p.is_file() and p.suffix.lower() in image_extensions
    }
    # Sorting is essential for reproducibility: the iteration order of a set of
    # Paths follows string hashing, which is randomized per interpreter run
    # (PYTHONHASHSEED), so shuffling list(set) produced different splits on
    # every run even with a fixed seed.
    return sorted(found)


def _print_split_summary(processed_data_dir, split_stats, splits, classes):
    """Print per-split class counts and a cross-split filename overlap check read from disk."""
    print(f"\n📊 Final Data Split Summary:")
    print("=" * 50)

    files_on_disk = {}
    for split in splits:
        print(f"\n{split.upper()} SET:")

        # Collect the names actually written, for the global overlap check
        files_on_disk[split] = set()
        for class_name in classes:
            split_dir = processed_data_dir / split / class_name
            if split_dir.exists():
                files_on_disk[split].update(f.name for f in split_dir.iterdir() if f.is_file())

        counts = {c: split_stats.get(c, {}).get(split, 0) for c in classes}
        split_total = sum(counts.values())
        for class_name, count in counts.items():
            print(f"  {class_name}: {count:,} images")
        print(f"  Total: {split_total:,} images")

        if split_total > 0:
            print("  Class distribution:")
            for class_name, count in counts.items():
                print(f"    {class_name}: {(count / split_total) * 100:.1f}%")

    print(f"\n🔍 Final Overlap Verification:")
    print("=" * 30)

    pairs = [('Train-Val', 'train', 'val'), ('Train-Test', 'train', 'test'), ('Val-Test', 'val', 'test')]
    overlaps = {label: files_on_disk[a] & files_on_disk[b] for label, a, b in pairs}
    for label, common in overlaps.items():
        print(f"{label} overlap: {len(common)} files")

    if any(overlaps.values()):
        print("❌ OVERLAP STILL DETECTED!")
        for label, common in overlaps.items():
            if common:
                print(f"  {label} overlaps: {sorted(common)[:5]}...")
    else:
        print("✅ SUCCESS: No overlaps detected between splits!")


def create_data_splits(raw_data_dir, processed_data_dir, split_ratios=(0.7, 0.15, 0.15), random_seed=42):
    """
    Create train/validation/test splits from organized raw data, ensuring NO overlap.

    Each class folder under ``raw_data_dir`` is split independently (stratified
    by class), and its files are copied into
    ``processed_data_dir/<split>/<class_name>/``. ``processed_data_dir`` is
    DELETED and re-created first, so any previous splits are lost.

    Args:
        raw_data_dir: Path to raw organized data (one sub-folder per class)
        processed_data_dir: Path to processed data directory
        split_ratios: Tuple of (train, val, test) ratios; must be non-negative
            and sum to 1. The test split receives every image left after the
            train and val slices, which absorbs rounding remainders.
        random_seed: Random seed for reproducible splits. A private
            ``random.Random`` instance is used, so the global ``random``
            state is not modified.

    Returns:
        Dict mapping class name -> {split name -> number of files copied}.

    Raises:
        ValueError: if ``split_ratios`` is malformed. This is checked before
            anything on disk is deleted.
    """
    # Validate before touching the filesystem, because the output dir is removed below
    train_ratio, val_ratio, test_ratio = split_ratios
    if min(split_ratios) < 0 or abs(sum(split_ratios) - 1.0) > 1e-6:
        raise ValueError(f"split_ratios must be non-negative and sum to 1, got {split_ratios!r}")

    print("✂️  Creating Data Splits")
    print("=" * 50)

    # Private RNG: reproducible splits without reseeding the global `random` module
    rng = random.Random(random_seed)
    print(f"🎲 Using random seed: {random_seed}")

    # Clean up the destination directory to ensure a fresh start
    if processed_data_dir.exists():
        print(f"🧹 Cleaning existing processed data directory: {processed_data_dir}")
        shutil.rmtree(processed_data_dir)

    print(f"✨ Creating new processed data directory: {processed_data_dir}")
    processed_data_dir.mkdir(parents=True, exist_ok=True)

    print(f"Split ratios - Train: {train_ratio:.1%}, Val: {val_ratio:.1%}, Test: {test_ratio:.1%}")

    splits = ['train', 'val', 'test']
    classes = ['MRI', 'BreastHisto']
    for split in splits:
        for class_name in classes:
            (processed_data_dir / split / class_name).mkdir(parents=True, exist_ok=True)

    split_stats = {}
    # File names already assigned to a class, so no image ends up in two classes
    all_processed_files = set()

    for class_name in classes:
        raw_class_dir = raw_data_dir / class_name

        if not raw_class_dir.exists():
            print(f"⚠️  {class_name} directory not found: {raw_class_dir}")
            continue

        print(f"\n📂 Splitting {class_name} data...")

        image_files = _list_split_candidates(raw_class_dir)
        print(f"  Total unique images found: {len(image_files)}")

        global_overlaps = {f.name for f in image_files} & all_processed_files
        if global_overlaps:
            print(f"  ⚠️  Warning: {len(global_overlaps)} files already processed in another class")
            # Remove duplicates to prevent cross-class contamination
            image_files = [f for f in image_files if f.name not in all_processed_files]
            print(f"  After deduplication: {len(image_files)} images")

        all_processed_files.update(f.name for f in image_files)

        if not image_files:
            print(f"  ⚠️  No unique images found for {class_name}")
            continue

        rng.shuffle(image_files)

        n_total = len(image_files)
        n_train = int(n_total * train_ratio)
        n_val = int(n_total * val_ratio)

        # Contiguous slices of one shuffled list cannot overlap; test takes the remainder
        split_data = {
            'train': image_files[:n_train],
            'val': image_files[n_train:n_train + n_val],
            'test': image_files[n_train + n_val:],
        }
        print(f"  Split sizes - Train: {len(split_data['train'])}, Val: {len(split_data['val'])}, Test: {len(split_data['test'])}")

        # Defensive check on file names (names, not paths, are what land on disk)
        train_set, val_set, test_set = ({f.name for f in split_data[s]} for s in splits)
        train_val_overlap = train_set & val_set
        train_test_overlap = train_set & test_set
        val_test_overlap = val_set & test_set

        if train_val_overlap or train_test_overlap or val_test_overlap:
            print(f"  ❌ ERROR: Overlap detected within {class_name}!")
            print(f"    Train-Val: {len(train_val_overlap)}, Train-Test: {len(train_test_overlap)}, Val-Test: {len(val_test_overlap)}")
            continue
        print(f"  ✅ No overlaps detected within {class_name}")

        class_split_stats = {}
        for split, files in split_data.items():
            split_dir = processed_data_dir / split / class_name
            copied_count = 0

            print(f"  📁 Copying {len(files)} files to {split}...")
            for img_file in tqdm(files, desc=f"    {split}", leave=False):
                target_path = split_dir / img_file.name

                # Double-check that target doesn't already exist
                if target_path.exists():
                    print(f"    ⚠️  Target already exists: {target_path.name}")
                    continue

                try:
                    shutil.copy2(img_file, target_path)
                except Exception as e:
                    print(f"    ❌ Error copying {img_file.name}: {e}")
                    continue
                copied_count += 1

            class_split_stats[split] = copied_count
            print(f"  ✅ Successfully copied {copied_count} files to {split}/{class_name}")

        split_stats[class_name] = class_split_stats

    _print_split_summary(processed_data_dir, split_stats, splits, classes)
    return split_stats


# Build the train/val/test folders only when the raw data step found images
if final_total <= 0:
    print("❌ No data available for splitting")
    split_statistics = {}
else:
    split_statistics = create_data_splits(
        raw_data_dir,
        processed_data_dir,
        split_ratios=(0.7, 0.15, 0.15),
        random_seed=42,  # fixed seed for repeatable splits
    )

### 6.1. Data Split Verification

To ensure model integrity, we must verify that there is absolutely no overlap (data leakage) between the training, validation, and test sets. We do this by collecting all filenames from each split and checking for common elements between them. Note that this compares filenames only, so the same image saved under two different names would not be detected.

In [None]:
def verify_data_splits_no_overlap(processed_data_dir):
    """
    Verify that no file name appears in more than one of the train/val/test splits.

    Only regular, non-hidden files are compared. Sub-directories and OS
    metadata files such as ``.DS_Store`` can exist in every split folder and
    would otherwise be reported as leakage.

    Args:
        processed_data_dir: Path holding ``<split>/<class_name>/`` folders.

    Returns:
        True if every expected folder exists and no overlap was found,
        False otherwise.
    """
    print("🛡️ Verifying Data Splits for Overlap...")
    print("=" * 50)

    splits = ['train', 'val', 'test']
    classes = ['MRI', 'BreastHisto']

    # Collect all file names for each split (across both classes)
    file_sets = {split: set() for split in splits}

    all_dirs_found = True
    for split in splits:
        for class_name in classes:
            split_dir = processed_data_dir / split / class_name
            if not split_dir.exists():
                print(f"⚠️ Directory not found: {split_dir}")
                all_dirs_found = False
                continue

            file_sets[split].update(
                f.name for f in split_dir.iterdir()
                if f.is_file() and not f.name.startswith('.')
            )

    if not all_dirs_found:
        print("❌ Verification aborted due to missing directories.")
        return False

    # Check for overlaps
    train_val_overlap = file_sets['train'] & file_sets['val']
    train_test_overlap = file_sets['train'] & file_sets['test']
    val_test_overlap = file_sets['val'] & file_sets['test']

    print(f"  Train set size: {len(file_sets['train']):,}")
    print(f"  Validation set size: {len(file_sets['val']):,}")
    print(f"  Test set size: {len(file_sets['test']):,}")
    print("-" * 20)
    print(f"  Train-Validation Overlap: {len(train_val_overlap)} files")
    print(f"  Train-Test Overlap: {len(train_test_overlap)} files")
    print(f"  Validation-Test Overlap: {len(val_test_overlap)} files")

    if not (train_val_overlap or train_test_overlap or val_test_overlap):
        print("\n✅ SUCCESS: No overlap found between data splits. Data is properly separated.")
        return True

    print("\n❌ FAILURE: Overlap detected between data splits! This will cause data leakage.")
    if train_val_overlap:
        print(f"   - Overlapping files (train/val): {sorted(train_val_overlap)[:5]}...")
    if train_test_overlap:
        print(f"   - Overlapping files (train/test): {sorted(train_test_overlap)[:5]}...")
    if val_test_overlap:
        print(f"   - Overlapping files (val/test): {sorted(val_test_overlap)[:5]}...")
    return False

# Run the leakage check only when the splitting step produced statistics
if not split_statistics:
    print("⚠️ Cannot verify splits as they were not created.")
else:
    verify_data_splits_no_overlap(processed_data_dir)

## 7. Sample Visualization and Quality Check

In [None]:
def visualize_split_samples(processed_data_dir, split_name, n_samples=4):
    """
    Visualize sample images from a specific split (train, val, or test).

    Draws a grid with one row per class ('MRI', 'BreastHisto') and
    ``n_samples`` columns of randomly chosen ``.png``/``.jpg`` images from
    ``processed_data_dir/<split_name>/<class_name>/``. If a class has fewer
    images than ``n_samples``, the available ones are shown and the remaining
    slots are marked 'Not enough images'.

    Args:
        processed_data_dir: Path holding ``<split>/<class_name>/`` folders.
        split_name: Name of the split folder to sample from.
        n_samples: Number of images (columns) per class row.
    """
    print(f"\n🖼️  Sample Image Visualization for '{split_name.upper()}' Set")
    print("=" * 50)

    classes = ['MRI', 'BreastHisto']

    # squeeze=False keeps `axes` 2-D even when n_samples == 1, so the
    # axes[row, col] indexing below is always valid.
    fig, axes = plt.subplots(len(classes), n_samples, figsize=(15, 8), squeeze=False)
    fig.suptitle(f'Sample Images from {split_name.upper()} Set', fontsize=16)

    def _placeholder(ax, message, title):
        # Text-only panel used when there is no image to show in this slot
        ax.text(0.5, 0.5, message, ha='center', va='center', transform=ax.transAxes)
        ax.set_title(title)

    for class_idx, class_name in enumerate(classes):
        data_dir = processed_data_dir / split_name / class_name

        if not data_dir.exists():
            for img_idx in range(n_samples):
                _placeholder(axes[class_idx, img_idx], 'Directory not found', f"{class_name} (Not Found)")
            continue

        image_files = list(data_dir.glob('*.png')) + list(data_dir.glob('*.jpg'))
        sample_images = random.sample(image_files, min(n_samples, len(image_files)))

        for img_idx in range(n_samples):
            ax = axes[class_idx, img_idx]
            if img_idx >= len(sample_images):
                _placeholder(ax, 'Not enough images', f"{class_name} (Insufficient)")
                continue

            img_path = sample_images[img_idx]
            try:
                # Close the file handle; convert() returns an independent in-memory copy
                with Image.open(img_path) as img:
                    rgb = img.convert('RGB')
                ax.imshow(rgb)
                ax.axis('off')
                ax.set_title(f"{class_name}\n{rgb.size[0]}x{rgb.size[1]}", fontsize=10)
            except Exception as e:
                _placeholder(ax, 'Error', f"{class_name} (Error)")
                print(f"⚠️ Error loading {img_path}: {e}")

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

# Visualize samples for all splits if data exists
if split_statistics:
    for split in ['train', 'val', 'test']:
        visualize_split_samples(processed_data_dir, split, n_samples=4)
else:
    print("⚠️ No split data available for visualization")

# Print size/format/mode statistics for up to 50 random training images per class
print(f"\n📏 Image Statistics Analysis (from Train Set):")
if split_statistics:
    for class_name in ['MRI', 'BreastHisto']:
        train_dir = processed_data_dir / 'train' / class_name
        if not train_dir.exists():
            continue

        print(f"\n{class_name}:")
        image_files = list(train_dir.glob('*.png')) + list(train_dir.glob('*.jpg'))
        sample_size = min(50, len(image_files))
        if sample_size == 0:
            print("  No images found for analysis")
            continue

        sample_images = random.sample(image_files, sample_size)
        sizes, formats, modes = [], [], []
        unreadable = []
        for img_path in sample_images:
            try:
                with Image.open(img_path) as img:
                    sizes.append(img.size)
                    formats.append(img.format)
                    modes.append(img.mode)
            except Exception as e:
                # Keep going, but record it: an unreadable training image is
                # itself a data-quality finding and must not be hidden.
                unreadable.append((img_path.name, e))

        if sizes:
            widths, heights = [s[0] for s in sizes], [s[1] for s in sizes]
            print(f"  Sample size: {len(sizes)} images")
            print(f"  Width range: {min(widths)} - {max(widths)} px")
            print(f"  Height range: {min(heights)} - {max(heights)} px")
            print(f"  Common formats: {Counter(formats).most_common(2)}")
            print(f"  Common modes: {Counter(modes).most_common(2)}")
        if unreadable:
            print(f"  ⚠️  Unreadable images: {len(unreadable)} of {sample_size} sampled")
            for bad_name, err in unreadable[:3]:
                print(f"    {bad_name}: {err}")

## 8. Data Preparation Summary and Next Steps

In [None]:
def generate_data_preparation_report(split_statistics, final_class_stats, processed_dir=None):
    """
    Print a summary of the data-preparation results and return it as a dict.

    Args:
        split_statistics: Mapping ``{class_name: {split_name: count}}`` with the
            number of images per class in each of ``'train'``, ``'val'`` and
            ``'test'``. If empty/falsy, the split breakdown is skipped.
        final_class_stats: Mapping ``{class_name: count}`` of the total number
            of images collected per class.
        processed_dir: Root of the processed dataset, expected to contain
            ``{train,val,test}/{MRI,BreastHisto}``. Defaults to the
            notebook-level ``processed_data_dir``.

    Returns:
        dict with keys ``total_images``, ``class_stats``, ``split_stats``,
        ``data_sufficient``, ``dirs_complete`` and ``ready_for_training``.
        A dict is returned on every path, including when no data was
        collected, so callers can always use ``report.get(...)``.
    """
    expected_classes = ('MRI', 'BreastHisto')
    splits = ('train', 'val', 'test')
    min_required = 1000  # Minimum images per class for good training

    if processed_dir is None:
        processed_dir = processed_data_dir  # notebook-level path
    processed_dir = Path(processed_dir)

    print("📋 Data Preparation Summary Report")
    print("=" * 60)

    # Overall statistics
    total_images = sum(final_class_stats.values())
    report = {
        'total_images': total_images,
        'class_stats': final_class_stats,
        'split_stats': split_statistics,
        'data_sufficient': False,
        'dirs_complete': False,
        'ready_for_training': False,
    }

    print(f"\n📊 OVERALL DATA STATISTICS:")
    print(f"  Total images collected: {total_images:,}")
    for class_name, count in final_class_stats.items():
        percentage = (count / total_images * 100) if total_images > 0 else 0
        print(f"  {class_name}: {count:,} images ({percentage:.1f}%)")

    # Split statistics
    if split_statistics:
        print(f"\n📈 DATA SPLIT BREAKDOWN:")
        for split in splits:
            counts = {
                class_name: split_statistics.get(class_name, {}).get(split, 0)
                for class_name in expected_classes
            }
            split_total = sum(counts.values())

            print(f"\n  {split.upper()} SET:")
            for class_name, count in counts.items():
                print(f"    {class_name}: {count:,} images")
            print(f"    Total: {split_total:,} images")

            # Class balance within this split
            if split_total > 0:
                for class_name, count in counts.items():
                    print(f"      {class_name}: {count / split_total * 100:.1f}%")

    # Data quality assessment
    print(f"\n✅ DATA QUALITY ASSESSMENT:")
    if total_images > 0:
        print(f"  ✅ Data collection: SUCCESS")
    else:
        print(f"  ❌ Data collection: FAILED")
        return report  # nothing else is meaningful without data

    # Class balance
    mri_count = final_class_stats.get('MRI', 0)
    breasthisto_count = final_class_stats.get('BreastHisto', 0)
    if mri_count > 0 and breasthisto_count > 0:
        ratio = max(mri_count, breasthisto_count) / min(mri_count, breasthisto_count)
        if ratio <= 2.0:
            print(f"  ✅ Class balance: GOOD (ratio {ratio:.2f}:1)")
        else:
            print(f"  ⚠️  Class balance: IMBALANCED (ratio {ratio:.2f}:1)")
    else:
        print(f"  ❌ Class balance: MISSING CLASSES")

    # Minimum data requirements. Expected classes absent from final_class_stats
    # count as 0 images so a missing class can never be reported as sufficient.
    per_class_counts = {**dict.fromkeys(expected_classes, 0), **final_class_stats}
    data_sufficient = all(count >= min_required for count in per_class_counts.values())
    if data_sufficient:
        print(f"  ✅ Data sufficiency: ADEQUATE (>={min_required} per class)")
    else:
        print(f"  ⚠️  Data sufficiency: LIMITED (<{min_required} for at least one class)")

    # Directory structure
    expected_dirs = [
        processed_dir / split / class_name
        for split in splits
        for class_name in expected_classes
    ]
    all_dirs_exist = all(dir_path.exists() for dir_path in expected_dirs)
    if all_dirs_exist:
        print(f"  ✅ Directory structure: COMPLETE")
    else:
        print(f"  ❌ Directory structure: INCOMPLETE")

    # Recommendations
    print(f"\n💡 RECOMMENDATIONS FOR NEXT STEPS:")
    if data_sufficient and all_dirs_exist:
        print(f"  🚀 Ready to proceed to Part 2: Model Architecture and Training")
        print(f"  📝 Consider implementing data augmentation to increase effective dataset size")
        print(f"  🎯 Target: Build custom CNN achieving >95% accuracy")
    else:
        if not data_sufficient:
            print(f"  📈 Consider collecting more data or reducing train/val/test requirements")
        if not all_dirs_exist:
            print(f"  🔧 Fix directory structure before proceeding")

    # File paths for next parts
    print(f"\n📁 KEY PATHS FOR NEXT PARTS:")
    print(f"  Processed data: {processed_dir}")
    print(f"  Train data: {processed_dir / 'train'}")
    print(f"  Validation data: {processed_dir / 'val'}")
    print(f"  Test data: {processed_dir / 'test'}")

    report.update(
        data_sufficient=data_sufficient,
        dirs_complete=all_dirs_exist,
        ready_for_training=data_sufficient and all_dirs_exist,
    )
    return report

# Generate final report. `or {}` guards against a report function that returns
# None (e.g. when no data was collected) so the .get() below cannot crash.
preparation_report = generate_data_preparation_report(split_statistics, final_class_stats) or {}
ready_for_training = bool(preparation_report.get('ready_for_training', False))

# Save preparation metadata for next parts
import json

metadata = {
    'preparation_date': pd.Timestamp.now().isoformat(),
    'total_images': final_total,
    'class_statistics': final_class_stats,
    'split_statistics': split_statistics,
    'processed_data_path': str(processed_data_dir),
    'ready_for_training': ready_for_training,
}

metadata_path = processed_data_dir / 'preparation_metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)

print(f"\n💾 Preparation metadata saved to: {metadata_path}")
print(f"\n" + "=" * 60)
print("PART 1 COMPLETED: DATA PREPARATION")
print("=" * 60)
# Only announce success when the report confirms the data is usable.
if ready_for_training:
    print(f"✅ Data organization and splitting completed successfully!")
    print(f"🚀 Ready to proceed to Part 2: Model Architecture and Training")
else:
    print("⚠️ Data preparation finished with issues - review the report above before training.")

## 9. Part 2 Preview: Custom Shallow CNN Architecture

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary

# Imports for the visualization part
import random
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image

print(f"PyTorch Version: {torch.__version__}")

# Run on the GPU whenever CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

### Model Design

For this binary classification task, a complex, deep architecture like ResNet or VGG is likely unnecessary and could lead to overfitting given the distinct nature of the two image classes. A custom, shallow CNN is a more efficient and interpretable starting point.

- **Architecture**: A simple CNN with three convolutional blocks. Each block consists of a `Conv2d` layer (followed by ReLU) to learn features, `BatchNorm2d` to stabilize learning, and `MaxPool2d` to downsample and provide some translation invariance.
- **Loss Function**: Binary cross-entropy (`nn.BCELoss`, applied to the model's sigmoid output) is the standard choice for a two-class classification problem.
- **Evaluation**: We will monitor accuracy during training. For a comprehensive evaluation on the test set, we will use precision, recall, F1-score, and a confusion matrix to understand the model's performance on each class.

In [None]:
class ShallowCNN(nn.Module):
    """
    Shallow CNN for binary image classification (e.g. MRI vs. BreastHisto).

    Three blocks of Conv2d(3x3, 'same') -> ReLU -> BatchNorm2d -> MaxPool2d(2)
    with 32, 64 and 128 channels, followed by a fully connected head
    (Linear -> ReLU -> Dropout(0.5) -> Linear -> Sigmoid).

    Args:
        input_channels: Number of input image channels (1 for grayscale).
        input_size: Side length in pixels of the square input images. Each of
            the three 2x2 max-pools halves the spatial size (rounding down), so
            the flattened feature size is ``128 * (input_size // 8) ** 2``.
            The default of 128 gives 16x16 feature maps, matching 128x128 inputs.

    Raises:
        ValueError: If ``input_size`` is smaller than 8 (the feature maps
            would vanish after three poolings).

    Forward output:
        Tensor of shape ``(N, 1)`` holding probabilities in [0, 1] (the head
        ends in Sigmoid), so it pairs with ``nn.BCELoss``.
    """
    def __init__(self, input_channels=1, input_size=128):
        super().__init__()
        if input_size < 8:
            raise ValueError(
                f"input_size must be >= 8 (got {input_size}): "
                "three 2x2 max-pools need at least 8 pixels per side"
            )
        # Spatial side length after three floor-halving max-pools.
        feature_side = input_size // 8

        self.features = nn.Sequential(
            # Convolutional Block 1: (N, C, S, S) -> (N, 32, S/2, S/2)
            nn.Conv2d(in_channels=input_channels, out_channels=32, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Convolutional Block 2: -> (N, 64, S/4, S/4)
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Convolutional Block 3: -> (N, 128, S/8, S/8)
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),  # -> (N, 128 * feature_side * feature_side)
            nn.Linear(128 * feature_side * feature_side, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # probability output for binary classification
        )

    def forward(self, x):
        """Map a batch ``(N, input_channels, input_size, input_size)`` to probabilities ``(N, 1)``."""
        x = self.features(x)
        x = self.classifier(x)
        return x

### Focus: Breast Histopathology Images

Breast histopathology images consist of small 50x50 pixel patches. Within this dataset, the key distinction is between patches with Invasive Ductal Carcinoma (IDC+) and those without (IDC-). Note that IDC status is not a prediction target in this project: both subtypes are pooled into the single BreastHisto class.
- **IDC- (Negative)**: Generally shows more uniform, organized tissue structures with lower cellular density.
- **IDC+ (Positive)**: Often characterized by a higher density of cancer cells, irregular shapes, and darker-staining nuclei.

In [None]:
# Instantiate the CNN on the selected device. Inputs are assumed to be resized
# to 128x128 grayscale, i.e. a PyTorch (C, H, W) shape of (1, 128, 128).
INPUT_SHAPE_PYTORCH = (1, 128, 128)
model = ShallowCNN(input_channels=INPUT_SHAPE_PYTORCH[0])
model = model.to(device)

# Layer-by-layer summary for this input shape.
print(f"\nModel created for input shape: {INPUT_SHAPE_PYTORCH}")
summary(model, input_size=INPUT_SHAPE_PYTORCH)

# Loss and optimizer (the PyTorch counterpart of Keras' model.compile).
# BCELoss fits because the model's final Sigmoid already outputs probabilities.
LEARNING_RATE = 0.001
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("\n".join([
    "\n--- Optimizer and Loss Function ---",
    f"Optimizer: {type(optimizer).__name__}",
    f"Loss Function: {type(criterion).__name__}",
    "---------------------------------",
]))

In [42]:

# 1. Imports for Data Handling and Training
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pathlib import Path
from tqdm import tqdm
import time

# 2. Configuration and Paths
# Locate the processed dataset: first <parent of cwd>/data/processed (the same
# project_root convention as the path-configuration cell), then ./data/processed.
try:
    project_root = Path.cwd().parent
    processed_data_dir = project_root / 'data' / 'processed'
    if not processed_data_dir.exists():
        processed_data_dir = Path.cwd() / 'data' / 'processed'
except OSError:
    # Path.cwd() raises FileNotFoundError (an OSError) if the working directory
    # no longer exists; fall back to a relative path.
    processed_data_dir = Path('./data/processed')

train_dir = processed_data_dir / 'train'
val_dir = processed_data_dir / 'val'

# Model & Training Hyperparameters
BATCH_SIZE = 64
NUM_EPOCHS = 15
LEARNING_RATE = 0.001
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"--- Configuration ---")
print(f"Using device: {DEVICE}")
print(f"Training data path: {train_dir}")
print(f"Validation data path: {val_dir}")
print(f"Batch Size: {BATCH_SIZE}, Epochs: {NUM_EPOCHS}, LR: {LEARNING_RATE}")
print("-" * 21)


# 3. Data Loading and Transformations
# The model expects 128x128 single-channel images.
data_transforms = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),  # PIL image -> float tensor [1, H, W] scaled to [0, 1]
    transforms.Normalize(mean=[0.5], std=[0.5]),  # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
])

# DataLoader settings shared by both splits.
NUM_WORKERS = 4
loader_kwargs = {
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    # Keep worker processes alive between epochs instead of re-spawning them
    # for every training/validation pass (only valid when num_workers > 0).
    'persistent_workers': NUM_WORKERS > 0,
    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
    'pin_memory': DEVICE == "cuda",
}

# ImageFolder assigns one class index per sub-folder (train/<class>, val/<class>).
try:
    train_dataset = datasets.ImageFolder(root=train_dir, transform=data_transforms)
    val_dataset = datasets.ImageFolder(root=val_dir, transform=data_transforms)

    # Labels come from folder names, so both splits must share the same mapping;
    # otherwise validation labels would silently mean something different.
    if val_dataset.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f"Class mismatch between splits: train {train_dataset.class_to_idx} "
            f"vs val {val_dataset.class_to_idx}"
        )

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    print("\n✅ Data loaded successfully!")
    print(f"Found classes: {train_dataset.classes} -> {train_dataset.class_to_idx}")
    print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

except FileNotFoundError:
    # ImageFolder raises FileNotFoundError for a missing root or when no valid image files are found.
    print(f"\n❌ ERROR: Data directories not found. Please check the 'processed_data_dir' path.")
    print(f"   - Searched for train data at: {train_dir}")
    print(f"   - Searched for validation data at: {val_dir}")
    train_loader, val_loader = None, None

# 4. Initialize Model, Loss Function, and Optimizer
# A fresh, untrained model for the run below (grayscale input -> 1 channel).
model = ShallowCNN(input_channels=1)
model = model.to(DEVICE)
criterion = nn.BCELoss()  # the model ends in Sigmoid, so it already outputs probabilities
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)


# 5. Training and Validation Loop
def _run_epoch(model, criterion, loader, device, optimizer=None, desc="", progress=True):
    """
    Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Trains (forward, backward, optimizer step) when ``optimizer`` is given;
    otherwise evaluates with gradients disabled. Both metrics are averaged
    over the samples actually seen, which stays correct for ``drop_last=True``.
    An empty loader yields ``(nan, nan)``.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is the same as eval()
    total_loss, correct, seen = 0.0, 0, 0
    batches = tqdm(loader, desc=desc) if progress else loader

    with torch.set_grad_enabled(training):
        for inputs, labels in batches:
            inputs = inputs.to(device)
            # BCELoss needs float targets shaped like the (N, 1) model output.
            labels = labels.float().view(-1, 1).to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = inputs.size(0)
            # Assumes outputs are probabilities (as with ShallowCNN's final Sigmoid).
            preds = (outputs > 0.5).float()
            total_loss += loss.item() * batch_size
            correct += int((preds == labels).sum().item())
            seen += batch_size

    if seen == 0:
        return float('nan'), float('nan')
    return total_loss / seen, correct / seen


def run_training(model, criterion, optimizer, train_loader, val_loader, num_epochs,
                 *, device=None, progress=True):
    """
    Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Args:
        model: Module producing one probability per sample, shape ``(N, 1)``
            (e.g. ``ShallowCNN``).
        criterion: Loss on ``(outputs, float targets of shape (N, 1))``,
            e.g. ``nn.BCELoss()``.
        optimizer: Optimizer over ``model``'s parameters.
        train_loader: Loader yielding ``(inputs, integer_labels)`` batches.
        val_loader: Same as ``train_loader``, used for validation.
        num_epochs: Number of training epochs.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).
        progress: Show tqdm progress bars; set False for quiet/scripted runs.

    Returns:
        ``(model, history)``, where ``history`` holds per-epoch float lists
        under ``'train_loss'``, ``'train_acc'``, ``'val_loss'`` and
        ``'val_acc'``. Returns ``(None, {})`` without training if either
        loader is ``None``.
    """
    # `is None` rather than truthiness: bool(DataLoader) calls len(), which
    # raises TypeError for iterable-style datasets.
    if train_loader is None or val_loader is None:
        print("\nCannot start training because data loaders are not initialized.")
        return None, {}

    if device is None:
        # Send batches wherever the model's parameters already live.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    start_time = time.time()
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(num_epochs):
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

        # --- Training Phase ---
        train_loss, train_acc = _run_epoch(
            model, criterion, train_loader, device,
            optimizer=optimizer, desc="Training", progress=progress,
        )
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")

        # --- Validation Phase ---
        val_loss, val_acc = _run_epoch(
            model, criterion, val_loader, device,
            desc="Validating", progress=progress,
        )
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f}")

    total_time = time.time() - start_time
    print(f"\n🎉 Training Finished! Total time: {total_time // 60:.0f}m {total_time % 60:.0f}s")

    return model, history

# 6. Start the Training Process
# run_training returns (None, {}) without training when the loaders could not be built.
trained_model, training_history = run_training(
    model, criterion, optimizer, train_loader, val_loader, num_epochs=NUM_EPOCHS,
)



--- Configuration ---
Using device: cuda
Training data path: c:\Users\Ammad\Documents\Projects\Personal\Brain\data\processed\train
Validation data path: c:\Users\Ammad\Documents\Projects\Personal\Brain\data\processed\val
Batch Size: 64, Epochs: 15, LR: 0.001
---------------------

✅ Data loaded successfully!
Found classes: ['BreastHisto', 'MRI'] -> {'BreastHisto': 0, 'MRI': 1}
Training samples: 7857, Validation samples: 1683

--- Epoch 1/15 ---


Training: 100%|██████████| 123/123 [00:31<00:00,  3.92it/s]


Train Loss: 0.0321 | Train Acc: 0.9948


Validating: 100%|██████████| 27/27 [00:19<00:00,  1.37it/s]


Val Loss:   0.0000 | Val Acc:   1.0000

--- Epoch 2/15 ---


Training:  41%|████▏     | 51/123 [00:17<00:24,  2.89it/s]


KeyboardInterrupt: 

In [None]:
def visualize_breasthisto_subtypes(processed_data_dir, n_samples=5):
    """
    Plot random IDC-positive and IDC-negative BreastHisto patches from the train split.

    Row 0 shows IDC-positive samples and row 1 IDC-negative samples. Each image
    is titled with its width x height. The subtype is read from the file name,
    which must contain ``'IDC_positive'`` or ``'IDC_negative'`` (assumed to
    match the naming used when the BreastHisto folder was built; verify).

    Args:
        processed_data_dir: Root of the processed dataset (str or Path) that
            contains ``train/BreastHisto``.
        n_samples: Number of images to show per subtype (must be >= 1).

    Returns:
        None. Prints a message and returns early if the directory is missing,
        ``n_samples`` is invalid, or either subtype has fewer than
        ``n_samples`` images.
    """
    breasthisto_dir = Path(processed_data_dir) / 'train' / 'BreastHisto'

    if not breasthisto_dir.exists():
        print(f"❌ Directory not found: {breasthisto_dir}")
        print("Please ensure your data is in the correct directory structure.")
        return
    if n_samples < 1:
        print(f"⚠️ n_samples must be at least 1 (got {n_samples}).")
        return

    # Accept both .png and .jpg patches; sort so seeded sampling is reproducible.
    all_files = sorted(list(breasthisto_dir.glob('*.png')) + list(breasthisto_dir.glob('*.jpg')))
    positive_files = [f for f in all_files if 'IDC_positive' in f.name]
    negative_files = [f for f in all_files if 'IDC_negative' in f.name]

    if len(positive_files) < n_samples or len(negative_files) < n_samples:
        print(f"⚠️ Not enough positive or negative samples to display.")
        print(f"Found {len(positive_files)} positive and {len(negative_files)} negative samples.")
        return

    rows = (
        ('IDC Positive', random.sample(positive_files, n_samples)),
        ('IDC Negative', random.sample(negative_files, n_samples)),
    )

    # squeeze=False keeps `axes` 2-D even when n_samples == 1.
    fig, axes = plt.subplots(2, n_samples, figsize=(15, 6), squeeze=False)
    fig.suptitle('Breast Histopathology Subtypes (Train Set)', fontsize=16)

    for row, (label, paths) in enumerate(rows):
        for col, img_path in enumerate(paths):
            ax = axes[row, col]
            # The context manager closes the file handle; imshow converts the
            # PIL image to an array immediately, so the image is not needed afterwards.
            with Image.open(img_path) as img:
                width, height = img.size
                ax.imshow(img)
            ax.set_title(f'{label}\n{width}x{height}')
            ax.axis('off')

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

# --- Example Usage for Visualization ---
# Uses the real `processed_data_dir` and `split_statistics` defined by the earlier
# cells. Do not rebind them here: later cells rely on the actual data path and
# split statistics.
try:
    if split_statistics:
        visualize_breasthisto_subtypes(processed_data_dir, n_samples=5)
    else:
        print("⚠️ No split data available for visualization")
except Exception as e:
    # Best-effort preview: report the problem without aborting the notebook run.
    print(f"\nCould not run visualization. Error: {e}")
    print("Please update 'processed_data_dir' to your actual data path.")