# BreastNet++ - Efficient Breast Cancer Classifier (Optimized for GTX 1650)

In [None]:
!pip install -q efficientnet_pytorch albumentations timm

In [4]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from efficientnet_pytorch import EfficientNet
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from tqdm import tqdm

# CBAM Attention Module
class CBAM(nn.Module):
    """CBAM-style attention: a channel gate followed by a spatial gate.

    The channel branch squeezes every feature map with global average pooling
    and feeds the result through a two-layer 1x1-conv bottleneck ending in a
    sigmoid (note: only average pooling is used here, no max-pooled path).
    The spatial branch stacks the channel-wise max and mean maps and reduces
    them to a single sigmoid mask with one convolution.

    Args:
        channels: Number of channels of the incoming feature map.
        reduction: Ratio by which the channel bottleneck is narrowed.
        kernel_size: Kernel size of the spatial convolution. Padding is
            ``kernel_size // 2``, so odd sizes keep the spatial dimensions.
    """

    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        # Clamp the bottleneck to at least one channel: when channels is
        # smaller than reduction the integer division would yield a
        # zero-width hidden layer.
        hidden = max(channels // reduction, 1)
        # Layer positions inside both Sequentials are kept as-is so that
        # previously saved state_dict keys still match.
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1, bias=False),
            nn.Sigmoid(),
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` re-weighted first per channel, then per spatial position."""
        # Channel gate: one weight per channel, broadcast over H and W.
        x = x * self.channel_attention(x)

        # Spatial gate: describe every position by its max and mean across
        # channels (in that order), then squash to a single-channel mask.
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        mean_out = torch.mean(x, dim=1, keepdim=True)
        spatial_mask = self.spatial_attention(torch.cat([max_out, mean_out], dim=1))
        return x * spatial_mask

# EfficientNet-B0 + CBAM model
class BreastNetPP(nn.Module):
    """Binary classifier: pretrained EfficientNet-B0 features, CBAM, small MLP head.

    ``forward`` returns one sigmoid-activated value per sample, i.e. an output
    of shape ``(batch, 1)`` coming out of the final ``Linear(128, 1)``.
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are registered in a fixed order with fixed names so the
        # state_dict layout stays compatible with saved checkpoints.
        self.backbone = EfficientNet.from_pretrained('efficientnet-b0')
        # 1280 is the channel width this head assumes for the backbone's
        # extracted feature map.
        self.cbam = CBAM(1280)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout1 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(1280, 128)
        self.relu = nn.ReLU()
        self.dropout2 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a batch of images to per-sample sigmoid outputs."""
        feature_map = self.cbam(self.backbone.extract_features(x))
        pooled = self.pool(feature_map)
        # Collapse everything after the batch dimension into one vector.
        flat = pooled.view(pooled.shape[0], -1)
        hidden = self.relu(self.fc1(self.dropout1(flat)))
        score = self.fc2(self.dropout2(hidden))
        return self.sigmoid(score)


  check_for_updates()


In [5]:
class CustomDataset(Dataset):
    """Two-class image dataset read from ``root_dir/benign`` and ``root_dir/malignant``.

    Labels follow the position in ``self.classes``: 0 for ``benign`` and
    1 for ``malignant``.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable invoked as ``transform(image=img)`` whose
            result is indexed with ``['image']``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['benign', 'malignant']
        self.image_paths = []
        self.labels = []

        for label, cls in enumerate(self.classes):
            cls_path = os.path.join(root_dir, cls)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> file mapping reproducible across runs and machines.
            for img_name in sorted(os.listdir(cls_path)):
                self.image_paths.append(os.path.join(cls_path, img_name))
                self.labels.append(label)

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; the label is a float32 scalar tensor.

        Raises:
            OSError: If the file at this index cannot be decoded as an image.
        """
        path = self.image_paths[idx]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising;
            # fail here with the offending path instead of inside cvtColor.
            raise OSError(f"Could not read image file: {path}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Compare against None explicitly: a transform object may define its
        # own truthiness (e.g. an empty pipeline evaluating as False).
        if self.transform is not None:
            image = self.transform(image=image)['image']

        return image, torch.tensor(self.labels[idx], dtype=torch.float32)


In [6]:
# Settings shared by both pipelines: square input size and per-channel
# normalisation statistics (these are the values commonly used for
# ImageNet-pretrained backbones).
_INPUT_SIZE = 160
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Training pipeline: resize, light random augmentation, normalise, to tensor.
train_transform = A.Compose(
    [
        A.Resize(_INPUT_SIZE, _INPUT_SIZE),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ToTensorV2(),
    ]
)

# Validation/test pipeline: same preprocessing without the random augmentations.
val_transform = A.Compose(
    [
        A.Resize(_INPUT_SIZE, _INPUT_SIZE),
        A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ToTensorV2(),
    ]
)


In [7]:
# One dataset per split; only the training split uses the augmenting transform.
train_dataset = CustomDataset(root_dir="./Preprocessed/train", transform=train_transform)
val_dataset = CustomDataset(root_dir="./Preprocessed/val", transform=val_transform)
test_dataset = CustomDataset(root_dir="./Preprocessed/test", transform=val_transform)

# Batches of 8; shuffling is enabled for training only so that validation and
# test batches are visited in a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=8, shuffle=False)


In [8]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = BreastNetPP()
model = model.to(device)

# The model already applies a sigmoid, so plain binary cross-entropy on
# probabilities is used.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)


Loaded pretrained weights for efficientnet-b0


In [7]:
def train(model, loader, optimizer, criterion):
    """Run one training epoch and return the mean of the per-batch losses.

    Args:
        model: Network to optimise. Batches are moved to the device of its
            parameters (CPU if it has none).
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, labels)``.

    Returns:
        The average of the per-batch loss values seen during the epoch.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    # Follow the model's own device instead of a notebook-level global, so the
    # function does not depend on which cells were executed beforehand.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    running_loss = 0.0
    # Count batches while iterating rather than calling len(loader), which is
    # not available for every kind of loader.
    num_batches = 0
    loop = tqdm(loader, desc="Training", leave=False)
    for imgs, labels in loop:
        imgs = imgs.to(target_device)
        # Add a trailing dimension so the labels line up with per-sample
        # outputs of shape (batch, 1).
        labels = labels.to(target_device).unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the sum and the progress bar.
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1
        loop.set_postfix(loss=batch_loss)

    if num_batches == 0:
        raise ValueError("loader yielded no batches; cannot compute a mean training loss")
    return running_loss / num_batches

def evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Predictions are obtained by thresholding the model output at 0.5.

    Args:
        model: Network to evaluate. Batches are moved to the device of its
            parameters (CPU if it has none).
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.

    Returns:
        A ``(mean_batch_loss, accuracy)`` tuple, where accuracy is the
        fraction of samples whose thresholded output equals the label.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    # Follow the model's own device instead of a notebook-level global, so the
    # function does not depend on which cells were executed beforehand.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    total, correct = 0, 0
    running_loss = 0.0
    # Count batches while iterating rather than calling len(loader), which is
    # not available for every kind of loader.
    num_batches = 0
    loop = tqdm(loader, desc="Evaluating", leave=False)
    with torch.no_grad():
        for imgs, labels in loop:
            imgs = imgs.to(target_device)
            # Add a trailing dimension so the labels line up with per-sample
            # outputs of shape (batch, 1).
            labels = labels.to(target_device).unsqueeze(1)

            outputs = model(imgs)
            # Read the scalar once and reuse it for the sum and the progress bar.
            batch_loss = criterion(outputs, labels).item()
            running_loss += batch_loss
            num_batches += 1

            preds = (outputs > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            loop.set_postfix(loss=batch_loss)

    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute loss or accuracy")
    return running_loss / num_batches, correct / total


In [None]:
# Every epoch writes a full checkpoint (weights, optimizer state and
# validation metrics) into this directory.
save_dir = "saved_models"
os.makedirs(save_dir, exist_ok=True)

num_epochs = 10
for epoch in range(num_epochs):
    print(f"\n🔁 Epoch {epoch+1}/{num_epochs}")
    # Hand cached, currently unused GPU memory back before the epoch starts.
    torch.cuda.empty_cache()

    train_loss = train(model, train_loader, optimizer, criterion)
    val_metrics = evaluate(model, val_loader, criterion)
    val_loss, val_acc = val_metrics

    print(
        f"✅ Epoch {epoch+1} | Train Loss: {train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
    )

    # 'epoch' stores the 1-based number of the epoch that has just finished.
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
        'val_accuracy': val_acc,
    }
    torch.save(checkpoint, f"{save_dir}/breastnetpp_epoch_{epoch+1}.pth")



🔁 Epoch 1/10


                                                                             

✅ Epoch 1 | Train Loss: 0.2934 | Val Loss: 0.2460 | Val Acc: 0.8976

🔁 Epoch 2/10


                                                                             

✅ Epoch 2 | Train Loss: 0.2518 | Val Loss: 0.2343 | Val Acc: 0.9038

🔁 Epoch 3/10


Training:  32%|███▏      | 8043/24977 [18:25<39:31,  7.14it/s, loss=0.205]  

In [8]:
# Checkpoint file the resume cell should continue from. That cell starts a
# fresh run when this value is falsy or the file does not exist.
resume_checkpoint_path = "saved_models/breastnetpp_epoch_4.pth"

In [None]:
import os
import torch

# Directory receiving the per-epoch checkpoints and the best-model weights.
save_dir = "saved_models"
os.makedirs(save_dir, exist_ok=True)

# NOTE: `resume_checkpoint_path` has to be defined before this cell runs.
# Point it at a file such as "saved_models/breastnetpp_epoch_6.pth" to resume,
# or set it to None to start from scratch.
start_epoch = 0
best_val_acc = 0.0

# Restore model and optimizer state when a usable checkpoint path was given.
if not resume_checkpoint_path or not os.path.exists(resume_checkpoint_path):
    print("⏳ No checkpoint found or resume path not set. Starting fresh...")
else:
    checkpoint = torch.load(resume_checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The stored value is the 1-based number of the last finished epoch, which
    # equals the 0-based index of the next epoch to run.
    start_epoch = checkpoint['epoch']
    # Seed the best-accuracy tracker with the resumed checkpoint's own score.
    best_val_acc = checkpoint.get('val_accuracy', 0.0)
    print(f"🔄 Resumed training from epoch {start_epoch}")

num_epochs = 10
for epoch in range(start_epoch, num_epochs):
    print(f"\n🔁 Epoch {epoch+1}/{num_epochs}")
    # Hand cached, currently unused GPU memory back before the epoch starts.
    torch.cuda.empty_cache()

    train_loss = train(model, train_loader, optimizer, criterion)
    val_metrics = evaluate(model, val_loader, criterion)
    val_loss, val_acc = val_metrics

    print(
        f"✅ Epoch {epoch+1} | Train Loss: {train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
    )

    # Always persist the full state of the epoch that just finished.
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
        'val_accuracy': val_acc,
    }
    torch.save(checkpoint, f"{save_dir}/breastnetpp_epoch_{epoch+1}.pth")

    # Nothing more to do unless this epoch beats the best accuracy so far.
    if val_acc <= best_val_acc:
        continue

    # New best: store the bare weights (no optimizer state) separately.
    best_val_acc = val_acc
    torch.save(model.state_dict(), f"{save_dir}/breastnetpp_best_model.pth")
    print(f"💾 Best model updated at epoch {epoch+1} with Val Acc: {val_acc:.4f}")


🔄 Resumed training from epoch 4

🔁 Epoch 5/10


Training:   8%|▊         | 1943/24977 [04:56<58:51,  6.52it/s, loss=0.0747]  

In [9]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import torch
import numpy as np

def test_model(model, test_loader, checkpoint_path):
    """Load weights from ``checkpoint_path`` and print test-set metrics.

    Prints accuracy, precision, recall and F1 (positive class = label 1),
    followed by the classification report and the confusion matrix.
    Predictions are the model outputs thresholded at 0.5.

    Args:
        model: Network whose weights are replaced by the checkpoint's. It is
            moved to the module-level ``device`` and put in eval mode.
        test_loader: Iterable yielding ``(images, labels)`` batches with
            binary 0/1 labels.
        checkpoint_path: Path to either a training checkpoint (a dict holding
            ``'model_state_dict'``) or a file containing a bare state_dict.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Per-epoch checkpoints wrap the weights in a dict, whereas the best-model
    # file stores the bare state_dict; accept both layouts.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for imgs, labels in test_loader:
            outputs = model(imgs.to(device))
            preds = outputs > 0.5
            # Flatten per batch so the (batch, 1) outputs and the labels end
            # up as matching 1-D arrays.
            all_preds.append(preds.cpu().numpy().reshape(-1))
            all_labels.append(labels.cpu().numpy().reshape(-1))

    if not all_preds:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    # Integer class ids (0 = benign, 1 = malignant) for the metric functions.
    y_pred = np.concatenate(all_preds).astype(int)
    y_true = np.concatenate(all_labels).astype(int)

    # Evaluation metrics
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    print("📊 Evaluation Results:")
    print(f"✅ Accuracy  : {acc:.4f}")
    print(f"✅ Precision : {prec:.4f}")
    print(f"✅ Recall    : {rec:.4f}")
    print(f"✅ F1-Score  : {f1:.4f}")
    print("\nClassification Report:")
    # Pin the label set so both rows are reported (and target_names stays
    # valid) even when one class is absent from the data or the predictions.
    print(classification_report(
        y_true,
        y_pred,
        labels=[0, 1],
        target_names=["Benign", "Malignant"],
        zero_division=0,
    ))
    print("Confusion Matrix:")
    print(confusion_matrix(y_true, y_pred, labels=[0, 1]))

In [10]:
# Load the checkpoint saved after epoch 5 and print its metrics on the test loader.
test_model(model, test_loader, "saved_models/breastnetpp_epoch_5.pth")


📊 Evaluation Results:
✅ Accuracy  : 0.9072
✅ Precision : 0.8201
✅ Recall    : 0.8625
✅ F1-Score  : 0.8407

Classification Report:
              precision    recall  f1-score   support

      Benign       0.94      0.92      0.93     39748
   Malignant       0.82      0.86      0.84     15758

    accuracy                           0.91     55506
   macro avg       0.88      0.89      0.89     55506
weighted avg       0.91      0.91      0.91     55506

Confusion Matrix:
[[36766  2982]
 [ 2167 13591]]
