In [195]:
# Cell 1: Imports and device setup

import os
import numpy as np
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
from PIL import Image

import matplotlib.pyplot as plt

from tqdm import tqdm

import timm  # For Swin Transformer model

# Select the compute device once; later cells move the model and batches onto it.
# CUDA is chosen whenever torch reports it as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using device: {device}')


Using device: cuda


In [196]:
class MammographyDataset(Dataset):
    """Mammogram dataset backed by an Excel catalog.

    Every catalog row must provide the columns ``relative_image_path``,
    ``relative_mask_path``, ``label`` and ``BIRADS``. Each sample is returned
    as ``(image, mask, label, birads)``.

    Args:
        excel_path: Path of the Excel catalog read with ``pandas.read_excel``.
        transform: Optional callable applied to the grayscale PIL image.
        mask_transform: Optional callable applied to the grayscale PIL mask.
        image_root: Optional directory prepended to ``relative_image_path``.
            ``None`` (default) uses the catalog path as-is.
        mask_root: Optional directory prepended to ``relative_mask_path``.
            ``None`` (default) uses the catalog path as-is.
        debug: When true, print the resolved paths of samples 0-2 as they are
            loaded. Off by default so DataLoader workers do not flood the
            output on every epoch.
    """

    def __init__(self, excel_path, transform=None, mask_transform=None,
                 image_root=None, mask_root=None, debug=False):
        self.df = pd.read_excel(excel_path)
        self.transform = transform
        self.mask_transform = mask_transform
        self.image_root = image_root
        self.mask_root = mask_root
        self.debug = debug
        # Integer class index for each textual label found in the catalog.
        self.label_map = {'Benign': 0, 'Malignant': 1, 'Normal': 2, 'Suspicious': 3}

    def __len__(self):
        """Return the number of catalog rows."""
        return len(self.df)

    @staticmethod
    def _resolve(root, relative_path):
        """Join ``relative_path`` onto ``root``; with no root, use the path unchanged."""
        if root is None:
            return Path(relative_path)
        return Path(root) / relative_path

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            tuple: ``(image, mask, label, birads)`` where ``image`` and ``mask``
            are the (optionally transformed) grayscale images, ``label`` is the
            integer from ``label_map`` and ``birads`` is the catalog value, or
            ``np.nan`` when it is missing.

        Raises:
            KeyError: If the row's label is not a key of ``label_map``.
        """
        row = self.df.iloc[idx]

        image_path = self._resolve(self.image_root, row['relative_image_path'])
        mask_path = self._resolve(self.mask_root, row['relative_mask_path'])

        if self.debug and idx < 3:
            print(f"Sample {idx} loading image from: {image_path}")
            print(f"Sample {idx} loading mask from: {mask_path}")

        # convert() returns a fully loaded copy, so the underlying files can be
        # closed immediately instead of waiting for garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('L')
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert('L')

        raw_label = row['label']
        try:
            label = self.label_map[raw_label]
        except KeyError:
            raise KeyError(
                f"Unknown label {raw_label!r} in catalog row {idx}; "
                f"expected one of {sorted(self.label_map)}"
            ) from None

        # Normalise every missing-value marker (None, pd.NA, NaN) to np.nan.
        birads = row['BIRADS']
        if pd.isna(birads):
            birads = np.nan

        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask, label, birads


In [None]:
# Cell 3: Transforms and DataLoader setup

from torchvision.transforms import Compose, ToTensor, Normalize, Resize

# Image and mask transforms
# Local import: InterpolationMode is only needed for the mask resize below.
from torchvision.transforms import InterpolationMode

image_transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=[0.5], std=[0.5])  # Single-channel (grayscale) normalization
])

# Masks must be resized with nearest-neighbour interpolation: the default
# bilinear resize blends foreground and background along the contour and
# produces intermediate values, so the mask would no longer be binary.
# NOTE(review): assumes masks are stored as 0/255 images, which ToTensor maps to 0/1 - confirm.
mask_transform = Compose([
    Resize((224, 224), interpolation=InterpolationMode.NEAREST),
    ToTensor()
])

# Root folders of the images and masks. They are not passed to the dataset:
# the catalog paths are used exactly as stored in the Excel file.
image_root_dir = 'Subset/Preprocessed_Dataset'
mask_root_dir = 'Subset/Masks'

# Create dataset
dataset = MammographyDataset(
    excel_path='Subset/subset_catalog.xlsx',
    transform=image_transform,
    mask_transform=mask_transform
)


# Split dataset (train/val) for demonstration; adjust split as needed.
# A seeded generator keeps the split identical across kernel restarts, so
# validation metrics from different runs are comparable.
SPLIT_SEED = 42
total_len = len(dataset)
train_len = int(0.8 * total_len)
val_len = total_len - train_len
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_len, val_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# DataLoaders
batch_size = 4  # Adjust according to GPU memory
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

class MultiTaskSwin(nn.Module):
    """Swin-Tiny backbone shared by a classification head and a segmentation head.

    The backbone's own classifier is removed (``reset_classifier(0)``) and
    replaced by ``self.classifier``. A forward hook on the first backbone stage
    captures an intermediate feature map that feeds the segmentation head.

    Args:
        num_classes: Number of classification outputs.
        seg_out_channels: Number of channels of the predicted mask.
        pretrained: Forwarded to ``timm.create_model``; pass ``False`` to skip
            loading pretrained weights (e.g. before loading a checkpoint).
    """

    def __init__(self, num_classes=4, seg_out_channels=1, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(
            'swin_tiny_patch4_window7_224', pretrained=pretrained, in_chans=1
        )
        self.backbone.reset_classifier(0)

        self.classifier = nn.Linear(self.backbone.num_features, num_classes)

        # Most recent output of the hooked stage; refreshed on every forward pass.
        self.features = None

        def hook(module, input, output):
            self.features = output

        # Register hook on first stage to capture features with batch dim
        self.backbone.layers[0].register_forward_hook(hook)

        # The head ends in Sigmoid, so it emits probabilities and pairs with
        # nn.BCELoss (not BCEWithLogitsLoss).
        # NOTE(review): 96 input channels assumes the hooked stage emits 96-channel features - confirm for the installed timm version.
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
            nn.Conv2d(256, seg_out_channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Run both heads on a batch of images.

        Args:
            x: Input batch; its last two dimensions set the output mask size.

        Returns:
            tuple: ``(cls_logits, seg_mask)``. ``cls_logits`` is the output of
            ``self.classifier``; ``seg_mask`` is the sigmoid mask resized to
            ``x.shape[2:]``.

        Raises:
            RuntimeError: If the hook captured nothing or the captured tensor
                is not 4-D.
        """
        # Drop features of the previous call so a hook that failed to fire
        # cannot silently reuse stale activations.
        self.features = None

        pooled = self.backbone(x)
        # With the backbone classifier reset to 0 classes, the backbone output
        # is a feature vector; the task-specific linear layer turns it into logits.
        cls_logits = self.classifier(pooled)

        features = self.features
        if features is None:
            raise RuntimeError("Forward hook on backbone.layers[0] did not capture any features.")
        if features.dim() != 4:
            raise RuntimeError(
                f"Expected hooked features of shape [B, H, W, C], got {tuple(features.shape)}."
            )

        # Channels-last [B, H, W, C] -> channels-first [B, C, H, W] for Conv2d.
        feat_map = features.permute(0, 3, 1, 2).contiguous()
        seg_mask = self.segmentation_head(feat_map)
        seg_mask = F.interpolate(seg_mask, size=x.shape[2:], mode='bilinear', align_corners=False)

        return cls_logits, seg_mask




In [None]:
# Cell 5: Loss functions and optimizer setup

# Build the network first and move it to the selected device.
model = MultiTaskSwin(num_classes=4, seg_out_channels=1)
model = model.to(device)

# One criterion per task: cross-entropy on the class logits, and plain BCE on
# the mask because the segmentation head already applies a sigmoid.
classification_criterion = nn.CrossEntropyLoss()
segmentation_criterion = nn.BCELoss()

# AdamW over all model parameters, with a cosine-annealed learning rate.
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

print("Loss functions, optimizer, and scheduler set up.")


Loss functions, optimizer, and scheduler set up.


In [None]:
# Cell 6: Training and Validation Loop with Early Stopping and Checkpointing

def train_one_epoch(model, train_loader, criterion_cls, criterion_seg, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    Each batch yields ``(images, masks, labels, birads)``; ``birads`` is not
    used here. The optimised loss is ``classification + 0.5 * segmentation``.

    Returns:
        tuple: ``(epoch_loss, epoch_acc)`` - the sample-weighted mean of the
        combined loss and the classification accuracy over the epoch.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for images, masks, labels, _birads in tqdm(train_loader, desc="Train Epoch"):
        images, masks, labels = images.to(device), masks.to(device), labels.to(device)

        optimizer.zero_grad()
        cls_logits, seg_mask = model(images)

        # Combined objective; the segmentation term is down-weighted by 0.5.
        batch_loss = criterion_cls(cls_logits, labels) + 0.5 * criterion_seg(seg_mask, masks)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch mean by the batch size so a smaller final batch
        # does not skew the epoch average.
        loss_sum += batch_loss.item() * images.size(0)
        n_correct += (cls_logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

    return loss_sum / n_seen, n_correct / n_seen


def validate(model, val_loader, criterion_cls, criterion_seg, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Each batch yields ``(images, masks, labels, birads)``; ``birads`` is not
    used here. The reported loss is ``classification + 0.5 * segmentation``,
    the same combination used during training.

    Returns:
        tuple: ``(epoch_loss, epoch_acc)`` - the sample-weighted mean of the
        combined loss and the classification accuracy over the loader.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, masks, labels, birads in tqdm(val_loader, desc="Validation"):
            images = images.to(device)
            masks = masks.to(device)
            labels = labels.to(device)

            cls_logits, seg_mask = model(images)

            loss_cls = criterion_cls(cls_logits, labels)
            # Prediction and target are compared at the same shape, exactly as
            # in training. Squeezing the prediction's channel dimension here
            # would leave it one dimension short of the mask and make the
            # criterion reject the pair.
            loss_seg = criterion_seg(seg_mask, masks)
            loss = loss_cls + 0.5 * loss_seg

            running_loss += loss.item() * images.size(0)
            predicted = cls_logits.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


In [None]:
# Cell 7: Training Controller with Early Stopping, Checkpointing, and Plots

import matplotlib.pyplot as plt
import os

def train_model(model, train_loader, val_loader, criterion_cls, criterion_seg,
                optimizer, scheduler, device, num_epochs=20, patience=5,
                checkpoint_dir='./checkpoints'):
    """Run the full training loop with checkpointing and early stopping.

    Per epoch: train, validate, step the scheduler, record the metrics and
    write ``model_epoch_<n>.pth``. Whenever validation accuracy exceeds the
    best value seen so far, ``best_model.pth`` is overwritten as well.
    Training stops once ``patience`` consecutive epochs bring no improvement.
    Afterwards the loss and accuracy curves are saved as
    ``training_history.png`` in ``checkpoint_dir`` and displayed.

    Returns:
        dict: Lists of per-epoch values under the keys ``train_loss``,
        ``train_acc``, ``val_loss`` and ``val_acc``.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    history = {key: [] for key in ('train_loss', 'train_acc', 'val_loss', 'val_acc')}
    best_val_acc = 0
    stale_epochs = 0  # consecutive epochs without a new best validation accuracy

    for epoch_no in range(1, num_epochs + 1):
        print(f"Epoch {epoch_no}/{num_epochs}")

        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion_cls, criterion_seg, optimizer, device
        )
        val_loss, val_acc = validate(model, val_loader, criterion_cls, criterion_seg, device)

        scheduler.step()

        epoch_metrics = (
            ('train_loss', train_loss),
            ('train_acc', train_acc),
            ('val_loss', val_loss),
            ('val_acc', val_acc),
        )
        for key, value in epoch_metrics:
            history[key].append(value)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # A checkpoint is written after every epoch, improved or not.
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'model_epoch_{epoch_no}.pth'))

        if val_acc > best_val_acc:
            # New best: remember it, reset the patience counter, keep a copy.
            best_val_acc = val_acc
            stale_epochs = 0
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best_model.pth'))
        else:
            stale_epochs += 1

        if stale_epochs >= patience:
            print("Early stopping triggered!")
            break

    # Two side-by-side panels: loss on the left, accuracy on the right.
    plt.figure(figsize=(12,5))
    panels = (
        (1, 'loss', 'Loss', 'Loss curve'),
        (2, 'acc', 'Accuracy', 'Accuracy curve'),
    )
    for position, metric, pretty_name, panel_title in panels:
        plt.subplot(1, 2, position)
        plt.plot(history[f'train_{metric}'], label=f'Train {pretty_name}')
        plt.plot(history[f'val_{metric}'], label=f'Val {pretty_name}')
        plt.legend()
        plt.title(panel_title)
    plt.savefig(os.path.join(checkpoint_dir, 'training_history.png'))
    plt.show()

    return history


In [None]:
# Cell 8: Inference Function for Unseen Images (JPG/JPEG/DICOM) with Preprocessing

import pydicom
import cv2

def load_and_preprocess_image(path, transform=None):
    """Load a JPG/JPEG/PNG or DICOM file as a batched grayscale tensor.

    Args:
        path: ``str`` or ``pathlib.Path`` of the file. The format is chosen
            from the file extension, case-insensitively.
        transform: Optional callable mapping the grayscale PIL image to a
            tensor. When omitted, the image is converted to a float tensor
            scaled to [0, 1] with a leading channel dimension; it is neither
            resized nor normalised.

    Returns:
        The transformed image with a batch dimension added at position 0.

    Raises:
        ValueError: If the extension is not .jpg, .jpeg, .png or .dcm.
    """
    path = Path(path)  # accept plain strings as well as Path objects
    ext = path.suffix.lower()
    if ext in ('.jpg', '.jpeg', '.png'):
        # convert() returns a loaded copy, so the file can be closed right away.
        with Image.open(path) as raw:
            img = raw.convert('L')
    elif ext == '.dcm':
        dicom_data = pydicom.dcmread(str(path))
        pixels = np.asarray(dicom_data.pixel_array, dtype=np.float64)
        lowest, highest = pixels.min(), pixels.max()
        if highest > lowest:
            # Min-max scale to 0-255 so the data fits an 8-bit PIL image.
            scaled = (pixels - lowest) / (highest - lowest) * 255
        else:
            # Constant image: min-max scaling would divide by zero.
            scaled = np.zeros_like(pixels)
        # NOTE(review): the DICOM PhotometricInterpretation tag is not inspected here,
        # so inverted (MONOCHROME1) images are not corrected - confirm this matches the training data.
        img = Image.fromarray(scaled.astype(np.uint8)).convert('L')
    else:
        raise ValueError(f"Unsupported image format: {ext}")

    if transform is not None:
        img = transform(img)
    else:
        # A PIL image has no unsqueeze(); fall back to a [1, H, W] float tensor in [0, 1].
        img = torch.from_numpy(np.asarray(img, dtype=np.float32) / 255.0).unsqueeze(0)

    # Add batch dimension
    img = img.unsqueeze(0)
    return img

def inference_unseen_image(model, image_path, transform, device):
    """Classify and segment a single image file.

    The model is put in eval mode and run without gradient tracking on the
    image loaded by ``load_and_preprocess_image``.

    Returns:
        tuple: ``(class_name, probs, seg_mask)`` - the predicted class name,
        the softmax probabilities as a NumPy array, and the squeezed
        segmentation output as a NumPy array.
    """
    class_map = {0: 'Benign', 1: 'Malignant', 2: 'Normal', 3: 'Suspicious'}

    model.eval()
    with torch.no_grad():
        batch = load_and_preprocess_image(Path(image_path), transform).to(device)
        cls_logits, seg_mask = model(batch)

        probs = torch.softmax(cls_logits, dim=1)
        pred_label = probs.argmax(dim=1).item()

        # Drop all singleton dimensions and hand the mask back as NumPy.
        seg_mask = seg_mask.squeeze().cpu().numpy()

    return class_map[pred_label], probs.cpu().numpy(), seg_mask


In [None]:
# label, probs, mask = inference_unseen_image(model, 'path_to_image.dcm', image_transform, device)

In [None]:
# Cell 9: Visualization of Segmentation Mask Overlaid on Image

import matplotlib.pyplot as plt
import numpy as np

def visualize_segmentation(image_path, mask, alpha=0.5):
    """Show a grayscale image with a segmentation mask drawn on top of it.

    Args:
        image_path (str or Path): Path to the input image.
        mask (numpy.ndarray): Segmentation mask of shape (H, W).
        alpha (float): Opacity of the mask overlay.
    """
    # PIL sizes are (width, height), i.e. the reverse of the mask's (H, W) shape.
    background = Image.open(image_path).convert('L').resize(mask.shape[::-1])

    plt.figure(figsize=(8,8))
    plt.imshow(background, cmap='gray')
    plt.imshow(mask, cmap='jet', alpha=alpha)  # mask drawn over the image
    plt.title('Segmentation Mask Overlay')
    plt.axis('off')
    plt.show()

# Example usage after inference:
# visualize_segmentation('path_to_image.jpg', mask)


In [None]:
# Train the multi-task model

# Run settings, kept as notebook-level names so later cells can reuse them.
num_epochs = 20
patience = 5
checkpoint_dir = './checkpoints'

# Model, loaders, criteria and optimisation objects are passed positionally in
# the order of train_model's signature; the run settings are passed by keyword.
history = train_model(
    model, train_loader, val_loader,
    classification_criterion, segmentation_criterion,
    optimizer, scheduler, device,
    num_epochs=num_epochs, patience=patience, checkpoint_dir=checkpoint_dir,
)


Epoch 1/20


Train Epoch:   0%|          | 0/3998 [00:00<?, ?it/s]


ValueError: Using a target size (torch.Size([4, 1, 224, 224])) that is different to the input size (torch.Size([4, 1, 1, 224, 224])) is deprecated. Please ensure they have the same size.