# In [None]:
# RSNA Breast Cancer Detection - Kaggle Ready Notebook (ImageFolder Version)

# =========================================================
# 1️⃣ Imports
# =========================================================
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

# Select the compute device: use CUDA when PyTorch can see a GPU, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# =========================================================
# 2️⃣ Dataset Paths
# =========================================================
# Root folder of the competition data, followed by one sub-folder per split.
# Only the training split is required; "val" and "test" are optional and are
# skipped further down when the folder is missing.
DATA_PATH = "/kaggle/input/rsna-breast-cancer-detection"
TRAIN_DIR, VAL_DIR, TEST_DIR = (
    os.path.join(DATA_PATH, split_name) for split_name in ("train", "val", "test")
)

# Report which split folders are actually present before anything is loaded.
print("Train exists:", os.path.exists(TRAIN_DIR))
print("Val exists:", os.path.exists(VAL_DIR))
print("Test exists:", os.path.exists(TEST_DIR))

# =========================================================
# 3️⃣ Transforms & Datasets
# =========================================================
# Every image is resized to 224x224 and converted to a float tensor; the same
# preprocessing is shared by all splits.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# The training split is mandatory; ImageFolder derives class labels from the
# sub-folder names.
train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=transform)

# Validation and test splits are loaded only when their folder exists.
val_dataset = None
if os.path.exists(VAL_DIR):
    val_dataset = datasets.ImageFolder(VAL_DIR, transform=transform)

test_dataset = None
if os.path.exists(TEST_DIR):
    test_dataset = datasets.ImageFolder(TEST_DIR, transform=transform)

# Only the training loader shuffles; evaluation loaders keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

val_loader = None
if val_dataset:
    val_loader = DataLoader(val_dataset, batch_size=16)

test_loader = None
if test_dataset:
    test_loader = DataLoader(test_dataset, batch_size=16)

print("Classes:", train_dataset.classes)
print("Number of training samples:", len(train_dataset))

# =========================================================
# 4️⃣ Model Definition (ResNet18)
# =========================================================
class BreastCancerCNN(nn.Module):
    """ResNet-18 backbone whose final layer is replaced by a single-logit head.

    The output is an unnormalised logit: train it with
    ``nn.BCEWithLogitsLoss`` and apply ``torch.sigmoid`` to get a probability.

    Args:
        pretrained: When true (the default, matching the previous behaviour),
            initialise the backbone from torchvision's pretrained weights.
            Pass ``False`` to skip the weight download, e.g. when a saved
            ``state_dict`` is loaded right afterwards.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # torchvision 0.13 deprecated the boolean ``pretrained`` flag in favour
        # of a ``weights`` enum. Use the enum when it exists and fall back to
        # the legacy flag on older installs so both keep working.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            backbone = models.resnet18(pretrained=pretrained)
        else:
            backbone = models.resnet18(
                weights=weights_enum.DEFAULT if pretrained else None
            )

        # Swap the original classification layer for one output unit.
        backbone.fc = nn.Linear(backbone.fc.in_features, 1)  # binary output

        # Attribute name is kept as ``model`` so existing checkpoints and
        # callers that reach into ``.model.layer4`` continue to work.
        self.model = backbone

    def forward(self, x):
        """Return the raw logits, shape ``(batch, 1)``, for image batch ``x``."""
        return self.model(x)

# Build the classifier, then move its parameters onto the selected device.
model = BreastCancerCNN()
model = model.to(device)

# =========================================================
# 5️⃣ Loss & Optimizer
# =========================================================
# The network outputs raw logits, so use the loss that takes logits directly.
criterion = nn.BCEWithLogitsLoss()
# Adam over every model parameter with a fixed learning rate of 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# =========================================================
# 6️⃣ Training Loop
# =========================================================
# Number of full passes over the training data.
epochs = 15  # Increase as needed

for epoch in range(epochs):
    # Training mode enables dropout / batch-norm statistic updates.
    model.train()
    running_loss, hits, seen = 0, 0, 0

    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        # The loss needs float targets with the same (batch, 1) shape as the logits.
        targets = batch_labels.to(device).unsqueeze(1).float()

        # Standard optimisation step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Bookkeeping for the per-epoch loss and accuracy report.
        running_loss += batch_loss.item()
        predicted = torch.sigmoid(logits) > 0.5
        hits += (predicted == targets).sum().item()
        seen += targets.size(0)

    # Loss is averaged over batches; accuracy over individual samples.
    train_acc = hits / seen
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}, Train Acc: {train_acc:.4f}")

# =========================================================
# 7️⃣ Validation Accuracy
# =========================================================
# Runs only when a validation loader was created and yields at least one batch.
if val_loader:
    # Evaluation mode freezes dropout / batch-norm behaviour.
    model.eval()
    hits, seen = 0, 0

    # No gradients are needed for scoring.
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images = batch_images.to(device)
            # Match the (batch, 1) float layout of the model output.
            targets = batch_labels.to(device).unsqueeze(1).float()

            logits = model(batch_images)
            predicted = torch.sigmoid(logits) > 0.5
            hits += (predicted == targets).sum().item()
            seen += targets.size(0)

    val_acc = hits / seen
    print(f"Validation Accuracy: {val_acc:.4f}")

# =========================================================
# 8️⃣ Save Model
# =========================================================
# Persist only the learned parameters (the state_dict), not the whole module.
checkpoint_path = "/kaggle/working/breast_cancer_cnn.pth"
torch.save(model.state_dict(), checkpoint_path)
print("Model saved successfully!")

# =========================================================
# 9️⃣ Grad-CAM Helper
# =========================================================
def generate_gradcam(model, image, target_layer, output_size=(224, 224)):
    """Compute a Grad-CAM heatmap for the first image in ``image``.

    The activations of ``target_layer`` are weighted by the spatial mean of
    the gradient of the model output with respect to them, summed over
    channels, passed through a ReLU, resized and scaled to ``[0, 1]``.

    Args:
        model: Network producing a single-element output (one logit for one
            image), so that ``output.backward()`` needs no explicit gradient.
        image: Input batch; only the first sample contributes to the map.
            It is moved to the device of the model's parameters.
        target_layer: Sub-module of ``model`` whose output is a 4-D
            ``(batch, channels, height, width)`` feature map. Assumed to be
            called once per forward pass.
        output_size: ``(height, width)`` of the returned heatmap. Defaults to
            the 224x224 size that used to be hard-coded.

    Returns:
        A 2-D ``float32`` NumPy array of shape ``output_size`` with values in
        ``[0, 1]``. A flat map (for example all-zero gradients) is returned as
        all zeros instead of NaN.

    Raises:
        RuntimeError: If ``target_layer`` never ran during the forward pass or
            no gradient reached its output.

    Side effects:
        Switches ``model`` to eval mode and resets / repopulates parameter
        gradients, exactly as the backward pass always did.
    """
    model.eval()

    # Follow the model's own device rather than a module-level global, so the
    # helper also works for models that live somewhere else.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image = image.to(first_param.device)

    activations = []
    gradients = []

    def forward_hook(module, inputs, output):
        activations.append(output)
        # A tensor hook on the activation captures d(output)/d(activation)
        # without the deprecated ``Module.register_backward_hook``.
        if output.requires_grad:
            output.register_hook(gradients.append)

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        # Force autograd on in case the caller is inside ``torch.no_grad()``.
        with torch.enable_grad():
            output = model(image)
            model.zero_grad()
            output.backward()
    finally:
        # Always detach the hook, even if the forward/backward pass failed.
        handle.remove()

    if not activations or not gradients:
        raise RuntimeError(
            "Grad-CAM could not capture activations/gradients for target_layer"
        )

    # Keep only the first sample: tensors of shape (channels, height, width).
    act = activations[0][0].detach().float()
    grad = gradients[0][0].detach().float()

    # Channel importance = spatial average of the gradient.
    weights = grad.mean(dim=(1, 2))
    cam = torch.relu((weights[:, None, None] * act).sum(dim=0))

    # Bilinear upsampling with torch replaces the former cv2.resize call,
    # whose module was never imported.
    cam = F.interpolate(
        cam[None, None],
        size=tuple(output_size),
        mode="bilinear",
        align_corners=False,
    )[0, 0]

    # Min-max normalise, guarding against a flat map (division by zero).
    cam = cam - cam.min()
    peak = cam.max()
    if peak > 0:
        cam = cam / peak

    return cam.cpu().numpy().astype(np.float32)

# =========================================================
# 10️⃣ Test Grad-CAM
# =========================================================
# Use the first training sample as a quick visual check of the heatmap.
sample_image, label = train_dataset[0]
# The model consumes batches, so prepend a batch axis of size one.
sample_image_input = sample_image[None]

# Explain the prediction using the last convolution of the second layer4 block.
target_layer = model.model.layer4[1].conv2
heatmap = generate_gradcam(model, sample_image_input, target_layer)

# Matplotlib wants channels last, so reorder (C, H, W) -> (H, W, C) before
# drawing the image and overlaying the semi-transparent heatmap.
image_hwc = sample_image.permute(1, 2, 0)
plt.imshow(image_hwc)
plt.imshow(heatmap, cmap="jet", alpha=0.5)
plt.title(f"True label: {label}")
plt.axis("off")
plt.show()