<a href="https://colab.research.google.com/github/SinghAnsh07/DataScience/blob/main/AI_Project_Fine_Tuned_VIT(Cancer_detection).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [81]:
# Extract to Colab's fast local storage
!unzip -q "/content/drive/MyDrive/BreaKHis_v1.zip" -d /content/

# Verify extraction
!ls /content/BreaKHis_v1/histology_slides/breast


replace /content/BreaKHis_v1/histology_slides/breast/malignant/lobular_carcinoma.stat.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/BreaKHis_v1/histology_slides/breast/count_files.sh? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/BreaKHis_v1/histology_slides/breast/malignant/mucinous_carcinoma.stat.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/BreaKHis_v1/histology_slides/breast/malignant/papillary_carcinoma.stat.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/BreaKHis_v1/histology_slides/breast/malignant/ductal_carcinoma.stat.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/BreaKHis_v1/histology_slides/breast/benign/phyllodes_tumor.stat.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/BreaKHis_v1/histology_slides/breast/benign/process_db_stat.py? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/BreaKHis_v1/histology_slides/breast/README.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /conte

In [82]:
!pip install timm -q
print("✓ Installation complete!")


✓ Installation complete!


In [91]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import numpy as np
import random
import timm

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

print("✓ All libraries imported successfully!")


✓ All libraries imported successfully!


In [92]:
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print("✓ GPU is ready for training!")
else:
    print("⚠️ Warning: No GPU detected. Training will be slow.")
    print("Enable GPU: Runtime → Change runtime type → T4 GPU")

# Configuration
data_root = '/content/BreaKHis_v1/histology_slides/breast'
epochs = 50
batch_size = 32
learning_rate = 1e-4

print(f"\nConfiguration:")
print(f"  Dataset path: {data_root}")
print(f"  Epochs: {epochs}")
print(f"  Batch size: {batch_size}")
print(f"  Learning rate: {learning_rate}")


Using device: cuda
GPU: Tesla T4
✓ GPU is ready for training!

Configuration:
  Dataset path: /content/BreaKHis_v1/histology_slides/breast
  Epochs: 50
  Batch size: 32
  Learning rate: 0.0001


In [96]:
data_root = '/content/BreaKHis_v1/histology_slides/breast'


In [97]:
# Verify dataset structure
print("Verifying dataset...")
print(f"\nChecking path: {data_root}")

benign_path = os.path.join(data_root, "benign")
malignant_path = os.path.join(data_root, "malignant")

if os.path.exists(benign_path) and os.path.exists(malignant_path):
    print("✓ Dataset folders found!")
    print(f"  - Benign folder: {benign_path}")
    print(f"  - Malignant folder: {malignant_path}")
else:
    print("❌ Dataset folders not found!")
    print("Check your data_root path")


Verifying dataset...

Checking path: /content/BreaKHis_v1/histology_slides/breast
✓ Dataset folders found!
  - Benign folder: /content/BreaKHis_v1/histology_slides/breast/benign
  - Malignant folder: /content/BreaKHis_v1/histology_slides/breast/malignant


In [98]:
# Dataset Class
class BCDataset(Dataset):
    def __init__(self, files, labels, transform=None):
        self.files = files
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = Image.open(self.files[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img, label

print("✓ Dataset class created!")


✓ Dataset class created!


In [99]:
# Data Augmentation for Training
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=30),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.8, 1.2)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Data Transformation for Validation (no augmentation)
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("✓ Data transformations defined!")


✓ Data transformations defined!


In [100]:
def load_data(magnification='100X'):
    """Load and prepare BreakHis dataset"""
    benign_path = os.path.join(data_root, "benign")
    malignant_path = os.path.join(data_root, "malignant")

    filenames, labels = [], []

    # Load images
    for path, label in [(benign_path, 0), (malignant_path, 1)]:
        for root, _, files in os.walk(path):
            if magnification in root:
                for file in files:
                    if file.lower().endswith('.png'):
                        filenames.append(os.path.join(root, file))
                        labels.append(label)

    print(f"\n{'='*60}")
    print(f"Dataset Loading Summary")
    print(f"{'='*60}")
    print(f"Magnification: {magnification}")
    print(f"Total images: {len(filenames)}")
    print(f"  - Benign: {labels.count(0)}")
    print(f"  - Malignant: {labels.count(1)}")

    # Shuffle and split (80% train, 20% validation)
    combined = list(zip(filenames, labels))
    random.shuffle(combined)
    filenames, labels = zip(*combined)

    split_idx = int(0.8 * len(filenames))
    train_files, val_files = filenames[:split_idx], filenames[split_idx:]
    train_labels, val_labels = labels[:split_idx], labels[split_idx:]

    print(f"  - Training set: {len(train_files)}")
    print(f"  - Validation set: {len(val_files)}")
    print(f"{'='*60}\n")

    # Create datasets
    train_dataset = BCDataset(train_files, train_labels, train_transform)
    val_dataset = BCDataset(val_files, val_labels, val_transform)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                             shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                           shuffle=False, num_workers=2, pin_memory=True)

    return train_loader, val_loader

print("✓ Data loading function created!")


✓ Data loading function created!


In [101]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm import tqdm
import numpy as np
import random
import timm
from datetime import datetime

# Set seeds
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}\n")

# Settings
data_root = '/content/BreaKHis_v1/histology_slides/breast'
epochs = 50
batch_size = 64
learning_rate = 3e-4

# All magnifications to train on
magnifications = ['40X', '100X', '200X', '400X']

# Store results
all_results = {}

# ============================================================
# DATASET CLASS
# ============================================================
class BCDataset(Dataset):
    def __init__(self, files, labels, transform=None):
        self.files = files
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = Image.open(self.files[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(self.labels[idx], dtype=torch.float32)

# ============================================================
# DATA AUGMENTATION
# ============================================================
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# ============================================================
# LOAD DATA FUNCTION
# ============================================================
def load_data(magnification):
    benign_path = os.path.join(data_root, "benign")
    malignant_path = os.path.join(data_root, "malignant")

    filenames, labels = [], []
    for path, label in [(benign_path, 0), (malignant_path, 1)]:
        for root, _, files in os.walk(path):
            if magnification in root:
                for file in files:
                    if file.lower().endswith('.png'):
                        filenames.append(os.path.join(root, file))
                        labels.append(label)

    # If no images found for this magnification, return None for loaders
    if not filenames:
        print(f"WARNING: No images found for magnification {magnification}. Returning empty loaders.")
        return None, None, 0, 0, 0 # Return None for loaders, and 0 for counts

    # Shuffle and split
    combined = list(zip(filenames, labels))
    random.shuffle(combined)
    filenames, labels = zip(*combined)

    split = int(0.8 * len(filenames))
    train_files, val_files = filenames[:split], filenames[split:]
    train_labels, val_labels = labels[:split], labels[split:]

    # Create data loaders
    train_dataset = BCDataset(train_files, train_labels, train_transform)
    val_dataset = BCDataset(val_files, val_labels, val_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                             shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                           shuffle=False, num_workers=2, pin_memory=True)

    return train_loader, val_loader, len(filenames), labels.count(0), labels.count(1)

# ============================================================
# TRAINING FUNCTION
# ============================================================
def train_magnification(magnification):
    print(f"\n{'='*70}")
    print(f"TRAINING: {magnification} MAGNIFICATION")
    print(f"{'='*70}\n")

    # Load data
    print(f"Loading {magnification} dataset...")
    train_loader, val_loader, total, benign, malignant = load_data(magnification)

    if train_loader is None: # Check if load_data returned None, indicating no images found
        print(f"Skipping training for {magnification} due to no images found.")
        # Return 0 accuracy and None for confusion matrix if skipped
        return 0.0, None

    print(f"Total images: {total}")
    print(f"  Benign: {benign} | Malignant: {malignant}")
    print(f"  Train batches: {len(train_loader)} | Val batches: {len(val_loader)}\n")

    # Create model
    print(f"Creating ViT-Small model for {magnification}...")
    model = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=1)
    model.to(device)

    # Training setup
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

    # Training loop
    print(f"Starting training for {epochs} epochs...\n")
    best_accuracy = 0.0

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0

        for images, labels in tqdm(train_loader, desc=f"{magnification} Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation phase
        model.eval()
        correct = 0
        total_samples = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = torch.sigmoid(model(images).squeeze())
                predicted = (outputs > 0.5).float()
                total_samples += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total_samples
        avg_loss = train_loss / len(train_loader)

        scheduler.step()

        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f} | Accuracy: {accuracy*100:.2f}%", end='')

        # Save best model
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            model_path = f'/content/drive/MyDrive/best_breakhis_{magnification}.pth'
            torch.save(model.state_dict(), model_path)
            print(f" \u2713 BEST!")
        else:
            print()

    # Final evaluation
    print(f"\n{'─'*70}")
    print(f"FINAL EVALUATION - {magnification}")
    print(f"{'─'*70}")

    # Check if a model was saved (i.e., best_accuracy > 0) to prevent loading an inexistent file
    if best_accuracy > 0:
        model.load_state_dict(torch.load(f'/content/drive/MyDrive/best_breakhis_{magnification}.pth'))
        model.eval()

        all_preds = []
        all_labels = []

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                outputs = torch.sigmoid(model(images).squeeze())
                predicted = (outputs > 0.5).float()
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.numpy())

        # Classification report
        print(classification_report(all_labels, all_preds,
                                   target_names=['Benign', 'Malignant'],
                                   digits=4))

        # Confusion matrix
        cm = confusion_matrix(all_labels, all_preds)
        print(f"\nConfusion Matrix:")
        print(f"                Predicted")
        print(f"              Benign  Malignant")
        print(f"Actual Benign   {cm[0][0]:4d}      {cm[0][1]:4d}")
        print(f"    Malignant   {cm[1][0]:4d}      {cm[1][1]:4d}")

        print(f"\n\u2713 {magnification} Training Complete!")
        print(f"Best Accuracy: {best_accuracy*100:.2f}%")
        print(f"Model saved: /content/drive/MyDrive/best_breakhis_{magnification}.pth")

        return best_accuracy, cm
    else:
        print(f"\n\u26A0 No valid model was trained for {magnification} as no data was found or accuracy was 0.")
        return 0.0, None

# ============================================================
# TRAIN ALL MAGNIFICATIONS
# ============================================================
print(f"\n{'#'*70}")
print(f"# BREAKHIS BREAST CANCER CLASSIFICATION - ALL MAGNIFICATIONS")
print(f"# Training on: 40X, 100X, 200X, 400X")
print(f"# Target: 99% accuracy on each magnification")
print(f"{'#'*70}")

start_time = datetime.now()

for mag in magnifications:
    accuracy, cm = train_magnification(mag)
    if cm is not None:
        all_results[mag] = {
            'accuracy': accuracy,
            'confusion_matrix': cm.tolist() # Convert numpy array to list for JSON compatibility
        }
    else:
        all_results[mag] = {
            'accuracy': 0.0, # Indicate skipped training with 0 accuracy
            'confusion_matrix': 'Skipped - No data found'
        }

end_time = datetime.now()
training_duration = end_time - start_time

# ============================================================
# FINAL SUMMARY REPORT
# ============================================================
print(f"\n\n{'='*70}")
print(f"COMPLETE TRAINING SUMMARY - ALL MAGNIFICATIONS")
print(f"{'='*70}\n")

print(f"Training Duration: {training_duration}\n")

print(f"{'Magnification':<15} {'Accuracy':<15} {'Status'}")
print(f"{'-'*70}")

for mag in magnifications:
    if mag in all_results and all_results[mag]['confusion_matrix'] != 'Skipped - No data found':
        accuracy = all_results[mag]['accuracy'] * 100
        status = "\u2713 EXCELLENT" if accuracy >= 99 else "\u2713 GOOD" if accuracy >= 97 else "\u26A0 NEEDS IMPROVEMENT"
        print(f"{mag:<15} {accuracy:>6.2f}%{'':<8} {status}")
    else:
        print(f"{mag:<15} {'SKIPPED':<15} {'\u26A0 No Data'}")

print(f"\n{'-'*70}")

# Average accuracy (only for magnifications that were actually trained)
actual_accuracies = [all_results[mag]['accuracy'] * 100 for mag in magnifications if all_results[mag]['confusion_matrix'] != 'Skipped - No data found']
if actual_accuracies:
    avg_accuracy = np.mean(actual_accuracies)
    print(f"{'Average':<15} {avg_accuracy:>6.2f}%")
else:
    print(f"{'Average':<15} {'N/A':>6s}%")

print(f"\n{'='*70}")
print(f"MODELS SAVED TO GOOGLE DRIVE:")
print(f"{'='*70}")
for mag in magnifications:
    if mag in all_results and all_results[mag]['confusion_matrix'] != 'Skipped - No data found':
        print(f"  • best_breakhis_{mag}.pth")
    else:
        print(f"  • No model saved for {mag} (skipped)")

print(f"\n{'='*70}")
print(f"\u2713 ALL MAGNIFICATIONS TRAINING COMPLETE!")
print(f"Target (Gella 2024): 99.99% per magnification")
print(f"Your Results: See table above")
print(f"{'='*70}\n")

Device: cuda


######################################################################
# BREAKHIS BREAST CANCER CLASSIFICATION - ALL MAGNIFICATIONS
# Training on: 40X, 100X, 200X, 400X
# Target: 99% accuracy on each magnification
######################################################################

TRAINING: 40X MAGNIFICATION

Loading 40X dataset...


ValueError: not enough values to unpack (expected 2, got 0)

In [None]:
print("Checking contents of /content/drive/MyDrive/:")
!ls -F /content/drive/MyDrive/

Please make sure that `BreaKHis_v1.zip` is listed in the output above. If it's not, please upload the file to your Google Drive and then re-run the relevant cells starting from the extraction step.

# Task
Train a breast cancer classification model for all magnifications (40X, 100X, 200X, 400X), saving the best performing model for each magnification to Google Drive, then verify the saved models, and finally summarize the training results and model locations.

## Execute Model Training

### Subtask:
Run the cell that trains the breast cancer classification model across all specified magnifications (40X, 100X, 200X, 400X). This cell will handle data loading, model creation, training, validation, and saving the best model for each magnification to Google Drive.


**Reasoning**:
The user wants to execute the code cell that starts with the comment '# ============================================================' and contains the full training loop for all magnifications. I have identified this as cell `ublpmuSIHx5r` in the provided notebook.



In [None]:
# ============================================================
# BREAKHIS TRAINING - ALL MAGNIFICATIONS
# Trains on 40X, 100X, 200X, and 400X
# Target: 99% accuracy on each magnification
# ============================================================

# Install library
!pip install timm -q

# Import libraries
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm import tqdm
import numpy as np
import random
import timm
from datetime import datetime

# Set seeds
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}\n")

# Settings
data_root = '/content/BreaKHis_v1/histology_slides/breast'
epochs = 50
batch_size = 64
learning_rate = 3e-4

# All magnifications to train on
magnifications = ['40X', '100X', '200X', '400X']

# Store results
all_results = {}

# ============================================================
# DATASET CLASS
# ============================================================
class BCDataset(Dataset):
    def __init__(self, files, labels, transform=None):
        self.files = files
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = Image.open(self.files[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(self.labels[idx], dtype=torch.float32)

# ============================================================
# DATA AUGMENTATION
# ============================================================
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# ============================================================
# LOAD DATA FUNCTION
# ============================================================
def load_data(magnification):
    benign_path = os.path.join(data_root, "benign")
    malignant_path = os.path.join(data_root, "malignant")

    filenames, labels = [], []
    for path, label in [(benign_path, 0), (malignant_path, 1)]:
        for root, _, files in os.walk(path):
            if magnification in root:
                for file in files:
                    if file.lower().endswith('.png'):
                        filenames.append(os.path.join(root, file))
                        labels.append(label)

    # Shuffle and split
    combined = list(zip(filenames, labels))
    random.shuffle(combined)
    filenames, labels = zip(*combined)

    split = int(0.8 * len(filenames))
    train_files, val_files = filenames[:split], filenames[split:]
    train_labels, val_labels = labels[:split], labels[split:]

    # Create data loaders
    train_dataset = BCDataset(train_files, train_labels, train_transform)
    val_dataset = BCDataset(val_files, val_labels, val_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                             shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                           shuffle=False, num_workers=2, pin_memory=True)

    return train_loader, val_loader, len(filenames), labels.count(0), labels.count(1)

# ============================================================
# TRAINING FUNCTION
# ============================================================
def train_magnification(magnification):
    print(f"\n{'='*70}")
    print(f"TRAINING: {magnification} MAGNIFICATION")
    print(f"{'='*70}\n")

    # Load data
    print(f"Loading {magnification} dataset...")
    train_loader, val_loader, total, benign, malignant = load_data(magnification)

    print(f"Total images: {total}")
    print(f"  Benign: {benign} | Malignant: {malignant}")
    print(f"  Train batches: {len(train_loader)} | Val batches: {len(val_loader)}\n")

    # Create model
    print(f"Creating ViT-Small model for {magnification}...")
    model = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=1)
    model.to(device)

    # Training setup
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

    # Training loop
    print(f"Starting training for {epochs} epochs...\n")
    best_accuracy = 0.0

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0

        for images, labels in tqdm(train_loader, desc=f"{magnification} Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation phase
        model.eval()
        correct = 0
        total_samples = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = torch.sigmoid(model(images).squeeze())
                predicted = (outputs > 0.5).float()
                total_samples += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total_samples
        avg_loss = train_loss / len(train_loader)

        scheduler.step()

        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f} | Accuracy: {accuracy*100:.2f}%", end='')

        # Save best model
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            model_path = f'/content/drive/MyDrive/best_breakhis_{magnification}.pth'
            torch.save(model.state_dict(), model_path)
            print(f" \u2713 BEST!")
        else:
            print()

    # Final evaluation
    print(f"\n{'─'*70}")
    print(f"FINAL EVALUATION - {magnification}")
    print(f"{'─'*70}")

    model.load_state_dict(torch.load(f'/content/drive/MyDrive/best_breakhis_{magnification}.pth'))
    model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            outputs = torch.sigmoid(model(images).squeeze())
            predicted = (outputs > 0.5).float()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())

    # Classification report
    print(classification_report(all_labels, all_preds,
                               target_names=['Benign', 'Malignant'],
                               digits=4))

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    print(f"\nConfusion Matrix:")
    print(f"                Predicted")
    print(f"              Benign  Malignant")
    print(f"Actual Benign   {cm[0][0]:4d}      {cm[0][1]:4d}")
    print(f"    Malignant   {cm[1][0]:4d}      {cm[1][1]:4d}")

    print(f"\n\u2713 {magnification} Training Complete!")
    print(f"Best Accuracy: {best_accuracy*100:.2f}%")
    print(f"Model saved: /content/drive/MyDrive/best_breakhis_{magnification}.pth")

    return best_accuracy, cm

# ============================================================
# TRAIN ALL MAGNIFICATIONS
# ============================================================
print(f"\n{'#'*70}")
print(f"# BREAKHIS BREAST CANCER CLASSIFICATION - ALL MAGNIFICATIONS")
print(f"# Training on: 40X, 100X, 200X, 400X")
print(f"# Target: 99% accuracy on each magnification")
print(f"{'#'*70}")

start_time = datetime.now()

for mag in magnifications:
    accuracy, cm = train_magnification(mag)
    all_results[mag] = {
        'accuracy': accuracy,
        'confusion_matrix': cm
    }

end_time = datetime.now()
training_duration = end_time - start_time

# ============================================================
# FINAL SUMMARY REPORT
# ============================================================
print(f"\n\n{'='*70}")
print(f"COMPLETE TRAINING SUMMARY - ALL MAGNIFICATIONS")
print(f"{'='*70}\n")

print(f"Training Duration: {training_duration}\n")

print(f"{'Magnification':<15} {'Accuracy':<15} {'Status'}")
print(f"{'-'*70}")

for mag in magnifications:
    accuracy = all_results[mag]['accuracy'] * 100
    status = "\u2713 EXCELLENT" if accuracy >= 99 else "\u2713 GOOD" if accuracy >= 97 else "\u26A0 NEEDS IMPROVEMENT"
    print(f"{mag:<15} {accuracy:>6.2f}%{'':<8} {status}")

print(f"\n{'-'*70}")

# Average accuracy
avg_accuracy = np.mean([all_results[mag]['accuracy'] * 100 for mag in magnifications])
print(f"{'Average':<15} {avg_accuracy:>6.2f}%")

print(f"\n{'='*70}")
print(f"MODELS SAVED TO GOOGLE DRIVE:")
print(f"{'='*70}")
for mag in magnifications:
    print(f"  • best_breakhis_{mag}.pth")

print(f"\n{'='*70}")
print(f"\u2713 ALL MAGNIFICATIONS TRAINING COMPLETE!")
print(f"Target (Gella 2024): 99.99% per magnification")
print(f"Your Results: See table above")
print(f"{'='*70}\n")