In [None]:
# Import necessary libraries
import os
import random
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import timm  # for Swin Transformer

from tqdm import tqdm
import matplotlib.pyplot as plt

# Set device to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def set_seed(seed=42, deterministic=False):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device).

    Args:
        seed (int): Seed applied to every RNG.
        deterministic (bool): If True, also make cuDNN choose deterministic
            kernels and disable its autotuner. Seeding alone does not make GPU
            convolutions repeatable; this trades some speed for repeatability.
            Off by default so existing calls behave exactly as before.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


set_seed()
print("Setup complete.")


In [None]:
class MammoDataset(Dataset):
    """Grayscale mammogram dataset yielding ``(image, mask, label)`` triples.

    Files are discovered from this layout; an image and its mask share the same
    class folder, source folder and file name::

        <root_dir>/Preprocessed_Dataset/<class>/<source>/<file>
        <root_dir>/Masks/<class>/<source>/<file>

    Images without a matching mask file are reported and skipped. Hidden
    entries (names starting with '.') are ignored. Discovery is sorted, so
    sample ``i`` refers to the same file on every machine and index-based
    train/val splits are reproducible.
    """

    def __init__(self, root_dir, class_names=('Benign', 'Malignant', 'Normal', 'Suspicious'),
                 transforms=None, mask_transforms=None):
        """
        Args:
            root_dir (str or Path): Root directory containing 'Preprocessed_Dataset' and 'Masks'.
            class_names (sequence of str): Class subfolder names; a class's position
                in this sequence is its numeric label.
            transforms: Optional callable applied to each PIL image.
            mask_transforms: Optional callable applied to each PIL mask.

        Raises:
            FileNotFoundError: if a ``Preprocessed_Dataset/<class>`` folder is missing.
        """
        self.root_dir = Path(root_dir)
        # Fresh list: never alias the caller's list (or a shared default).
        self.class_names = list(class_names)
        self.transforms = transforms
        self.mask_transforms = mask_transforms

        self.image_paths = []
        self.mask_paths = []
        self.labels = []

        # Map from class name to numeric label for classification
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.class_names)}

        # Gather all image and mask paths with labels
        for cls in self.class_names:
            image_base = self.root_dir / 'Preprocessed_Dataset' / cls
            mask_base = self.root_dir / 'Masks' / cls

            # One sub-folder per source dataset
            for source_dir in self._sorted_children(image_base, want_dirs=True):
                for img_path in self._sorted_children(source_dir, want_dirs=False):
                    mask_path = mask_base / source_dir.name / img_path.name
                    if mask_path.is_file():
                        self.image_paths.append(img_path)
                        self.mask_paths.append(mask_path)
                        self.labels.append(self.class_to_idx[cls])
                    else:
                        print(f"Mask not found for {img_path}, skipping.")

    @staticmethod
    def _sorted_children(folder, want_dirs):
        """Return the non-hidden sub-directories (or regular files) of ``folder``, sorted."""
        return sorted(
            entry for entry in folder.iterdir()
            if not entry.name.startswith('.')
            and (entry.is_dir() if want_dirs else entry.is_file())
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, mask, label)`` with the transforms applied."""
        # Load image and mask as single-channel 'L'; the context managers close
        # the underlying files once the converted copies exist.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('L')
        with Image.open(self.mask_paths[idx]) as msk:
            mask = msk.convert('L')

        if self.transforms:
            image = self.transforms(image)
        if self.mask_transforms:
            mask = self.mask_transforms(mask)

        return image, mask, self.labels[idx]


In [None]:
from torchvision.transforms import (Compose, InterpolationMode, Normalize,
                                    RandomHorizontalFlip, Resize, ToTensor)
import copy
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

IMAGE_SIZE = (224, 224)  # matches the 224x224 input of swin_tiny_patch4_window7_224
BATCH_SIZE = 4           # small batch to fit RTX 3050 memory
NUM_WORKERS = 4
VAL_FRACTION = 0.2
SPLIT_SEED = 42

# Grayscale normalization values (mean and std for single channel)
mean = [0.5]
std = [0.5]

# Training image transformations (with augmentation)
train_transforms = Compose([
    Resize(IMAGE_SIZE),
    RandomHorizontalFlip(),          # Data augmentation: horizontal flip
    ToTensor(),
    Normalize(mean=mean, std=std),
])

# Validation transformations: same preprocessing but NO random augmentation,
# so validation metrics are deterministic and comparable across epochs.
val_transforms = Compose([
    Resize(IMAGE_SIZE),
    ToTensor(),
    Normalize(mean=mean, std=std),
])

# Mask transforms (no normalization). NEAREST resizing only reuses pixel values
# already present in the mask instead of blending neighbours at region edges.
# NOTE(review): RandomHorizontalFlip above is applied to images only, so training
# masks are not flipped together with their images. train_one_epoch ignores
# masks today; switch to a joint image+mask transform before using them.
mask_transforms = Compose([
    Resize(IMAGE_SIZE, interpolation=InterpolationMode.NEAREST),
    ToTensor(),
])

# Dataset root directory
dataset_root = './Dataset'  # Adjust if different

# Training view of the dataset
full_dataset = MammoDataset(root_dir=dataset_root,
                            transforms=train_transforms,
                            mask_transforms=mask_transforms)

# Validation view of the same samples: a shallow copy shares the scanned
# path/label lists (no second directory scan) but carries its own transforms.
val_view = copy.copy(full_dataset)
val_view.transforms = val_transforms

# Stratified train/validation split over sample indices
indices = list(range(len(full_dataset)))
train_indices, val_indices = train_test_split(
    indices, test_size=VAL_FRACTION, random_state=SPLIT_SEED, stratify=full_dataset.labels)

train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(val_view, val_indices)

# DataLoaders for training and validation
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Sanity check: fetch a batch and print shapes
images, masks, labels = next(iter(train_loader))
print(f'Image batch shape: {images.shape}')   # Expected: [4, 1, 224, 224]
print(f'Mask batch shape: {masks.shape}')    # Expected: [4, 1, 224, 224]
print(f'Label batch shape: {labels.shape}')  # Expected: [4]


In [None]:
# Load a smaller Swin Transformer variant to fit RTX 3050 memory
model_name = 'swin_tiny_patch4_window7_224'   # Smaller variant
NUM_CLASSES = 4

# Pretrained weights, 4-class head, single-channel (grayscale) input.
# in_chans=1 makes timm adapt the pretrained RGB patch-embedding kernel by
# SUMMING it over the colour axis. Summing (not averaging) means the patch
# embedding of a grayscale image equals what the pretrained conv produces for
# that image replicated into R, G and B. Averaging would shrink it ~3x.
model = timm.create_model(model_name, pretrained=True, num_classes=NUM_CLASSES, in_chans=1)
assert model.patch_embed.proj.in_channels == 1, "patch embedding was not adapted to 1 channel"

# Move model to GPU or CPU
model = model.to(device)

# Concise summary instead of dumping the full module tree
n_params = sum(p.numel() for p in model.parameters())
print(f"{model_name}: {n_params / 1e6:.1f}M parameters, "
      f"{NUM_CLASSES} classes, 1 input channel, on {device}")


In [None]:
# Loss function: CrossEntropy for multi-class classification
criterion = nn.CrossEntropyLoss()

# Optimizer: AdamW (decoupled weight decay), popular for transformers
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Cosine annealing over the whole run. train_model steps the scheduler once per
# epoch, so T_max must equal the number of training epochs. With T_max=10 and
# 20 epochs, the LR hit ~0 at epoch 10 and then rose again along the cosine.
# Keep num_epochs in sync with the value passed to train_model.
num_epochs = 20
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Training loop
def _run_epoch(model, dataloader, criterion, device, optimizer=None, desc="Training"):
    """Run one pass over ``dataloader``; update weights only when ``optimizer`` is given.

    Batches are ``(images, masks, labels)``; masks are ignored.

    Returns:
        (mean_loss, accuracy) weighted by the number of samples.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) == eval(): disables dropout etc.
    running_loss = 0.0
    correct = 0
    total = 0

    progress = tqdm(dataloader, desc=desc)
    # Build autograd graphs only when we will backpropagate.
    with torch.set_grad_enabled(training):
        for images, _masks, labels in progress:
            images = images.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size  # criterion returns the batch mean
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

            progress.set_postfix(loss=running_loss / total, accuracy=correct / total)

    if total == 0:
        raise ValueError(f"{desc}: dataloader yielded no samples")
    return running_loss / total, correct / total


def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one epoch; returns ``(epoch_loss, epoch_acc)``."""
    return _run_epoch(model, dataloader, criterion, device, optimizer=optimizer, desc="Training")


def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` without gradient tracking; returns ``(val_loss, val_acc)``.

    Called by ``train_model`` once per epoch.
    """
    return _run_epoch(model, dataloader, criterion, device, optimizer=None, desc="Validation")


In [None]:
import matplotlib.pyplot as plt

def plot_and_save(history, metric, filename):
    """Draw the train and validation curves for ``metric`` and save them to ``filename``.

    ``history`` must hold per-epoch lists under ``'train_' + metric`` and
    ``'val_' + metric``. The figure is closed after saving, so repeated calls
    do not accumulate open figures.
    """
    train_key = f'train_{metric}'
    val_key = f'val_{metric}'

    fig, ax = plt.subplots()
    ax.plot(history[train_key], label=f'Train {metric}')
    ax.plot(history[val_key], label=f'Val {metric}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric)
    ax.set_title(f'Train vs Val {metric}')
    ax.legend()
    fig.savefig(filename)
    plt.close(fig)

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device,
                num_epochs=20, patience=5, checkpoint_dir='./checkpoints',
                restore_best=False):
    """Train with per-epoch validation, checkpointing and early stopping.

    Each epoch calls ``train_one_epoch`` and ``validate``. Both are looked up
    when this function runs, and each must return ``(loss, accuracy)``. The
    epoch then steps ``scheduler`` once and saves ``model_epoch_<n>.pth``.
    Whenever validation accuracy improves, it also saves ``best_model.pth``
    in ``checkpoint_dir``. Training stops early after ``patience`` consecutive
    epochs without improvement. Loss/accuracy curves are written as PNGs to
    ``checkpoint_dir``.

    Args:
        num_epochs (int): Maximum number of epochs.
        patience (int): Epochs without val-accuracy improvement before stopping.
        checkpoint_dir (str): Directory for checkpoints and plots (created if missing).
        restore_best (bool): If True, reload ``best_model.pth`` into ``model``
            before returning. By default the model keeps its last-epoch weights.

    Returns:
        dict: per-epoch lists under 'train_loss', 'val_loss', 'train_acc', 'val_acc'.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_model_path = os.path.join(checkpoint_dir, 'best_model.pth')

    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': []
    }

    # -inf (not 0.0) makes the first epoch always count as "best", so
    # best_model.pth exists even if validation accuracy stays at 0.
    best_val_acc = float('-inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        scheduler.step()

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val   Loss: {val_loss:.4f}, Val   Acc: {val_acc:.4f}")

        # Append history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        # Save checkpoint for every epoch
        checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Saved checkpoint: {checkpoint_path}")

        # Early stopping bookkeeping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), best_model_path)
            print(f"New best model saved: {best_model_path}")
        else:
            patience_counter += 1
            print(f"Patience counter: {patience_counter}/{patience}")
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1} with no improvement for {patience} epochs")
                break

    if restore_best and os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        print(f"Restored best model weights from {best_model_path}")

    # Save loss and accuracy plots
    plot_and_save(history, 'loss', os.path.join(checkpoint_dir, 'loss_plot.png'))
    plot_and_save(history, 'acc', os.path.join(checkpoint_dir, 'accuracy_plot.png'))

    return history


In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

def evaluate_model(model, dataloader, device,
                   class_names=('Benign', 'Malignant', 'Normal', 'Suspicious'),
                   save_path='test_confusion_matrix.png'):
    """Print a classification report and plot/save a confusion matrix.

    Args:
        model: Classifier returning one logit per class.
        dataloader: Yields ``(images, masks, labels)`` batches; masks are ignored.
        device: Device the model lives on.
        class_names (sequence of str): Display name for label ``i`` at position ``i``.
        save_path (str): Where the confusion-matrix PNG is written.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, _masks, labels in tqdm(dataloader, desc="Testing"):
            outputs = model(images.to(device))
            all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
            # Labels are only needed for the metrics; no device round-trip.
            all_labels.extend(labels.tolist())

    # Passing labels= keeps the report and the matrix aligned with class_names
    # even when some class never occurs in this split (neither true nor predicted).
    label_ids = list(range(len(class_names)))
    names = list(class_names)

    print("Classification Report:")
    print(classification_report(all_labels, all_preds, labels=label_ids, target_names=names))

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=names, yticklabels=names, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True Label')
    ax.set_title('Confusion Matrix')
    fig.savefig(save_path)
    plt.show()

# Usage example:
# Load best model weights
# model.load_state_dict(torch.load('./checkpoints/best_model.pth'))
# model.to(device)

# Prepare test_loader (similar to val_loader, from test subset)
# evaluate_model(model, test_loader, device)


In [None]:
# Training budget: at most `num_epochs`, stopping earlier once validation
# accuracy has not improved for `patience` consecutive epochs.
num_epochs = 20
patience = 5
CHECKPOINT_DIR = './checkpoints'  # checkpoints and history plots go here

history = train_model(
    model, train_loader, val_loader, criterion, optimizer, scheduler, device,
    num_epochs=num_epochs,
    patience=patience,
    checkpoint_dir=CHECKPOINT_DIR,
)

print("Training complete!")


In [16]:
# Preview the BIRADS-annotated benchmark table
birads_bench_csv = Path('Subset') / 'mammo-bench_BIRADS.csv'
df = pd.read_csv(birads_bench_csv)
df.head()

Unnamed: 0,source_dataset,laterality,view,preprocessed_image_path,classification,density,BIRADS,abnormality,molecular_subtype,raw_image_path,mask_path,ROI_path,x,y,radius,subject_age,source_subjectID,original_source_path
0,inbreast,R,CC,Preprocessed_Dataset/inbreast/inbreast_0.jpg,Normal,D,1.0,,,Original_Dataset/inbreast/inbreast_0.jpg,Masks/inbreast/inbreast_0.jpg,,,,,,22678622,INbreast/AllDICOMs/22678622_61b13c59bcba149e_M...
1,inbreast,L,CC,Preprocessed_Dataset/inbreast/inbreast_1.jpg,Benign,D,3.0,,,Original_Dataset/inbreast/inbreast_1.jpg,Masks/inbreast/inbreast_1.jpg,,,,,,22678646,INbreast/AllDICOMs/22678646_61b13c59bcba149e_M...
2,inbreast,R,MLO,Preprocessed_Dataset/inbreast/inbreast_2.jpg,Normal,D,1.0,,,Original_Dataset/inbreast/inbreast_2.jpg,Masks/inbreast/inbreast_2.jpg,,,,,,22678670,INbreast/AllDICOMs/22678670_61b13c59bcba149e_M...
3,inbreast,L,MLO,Preprocessed_Dataset/inbreast/inbreast_3.jpg,Benign,D,3.0,,,Original_Dataset/inbreast/inbreast_3.jpg,Masks/inbreast/inbreast_3.jpg,,,,,,22678694,INbreast/AllDICOMs/22678694_61b13c59bcba149e_M...
4,inbreast,R,CC,Preprocessed_Dataset/inbreast/inbreast_4.jpg,Malignant,B,5.0,,,Original_Dataset/inbreast/inbreast_4.jpg,Masks/inbreast/inbreast_4.jpg,,,,,,22614074,INbreast/AllDICOMs/22614074_6bd24a0a42c19ce1_M...


In [28]:
# birads_df is otherwise created only in a later cell; load it here so this
# cell also works on a fresh kernel (Restart & Run All).
birads_df = pd.read_csv('Subset/mammo-bench_BIRADS.csv')
print(birads_df.columns)


Index(['source_dataset', 'laterality', 'view', 'preprocessed_image_path',
       'classification', 'density', 'BIRADS', 'abnormality',
       'molecular_subtype', 'raw_image_path', 'mask_path', 'ROI_path', 'x',
       'y', 'radius', 'subject_age', 'source_subjectID',
       'original_source_path'],
      dtype='object')


In [29]:
import pandas as pd

# Paths to your files
subset_path = 'Subset/subset_catalog.xlsx'
birads_csv_path = 'Subset/mammo-bench_BIRADS.csv'

# Load subset catalog Excel and birads CSV
subset_df = pd.read_excel(subset_path)
birads_df = pd.read_csv(birads_csv_path)

# Extract only filename from Preprocessed_image_path in birads
birads_df['birads_image_name'] = birads_df['preprocessed_image_path'].apply(lambda x: str(x).split('/')[-1])

# Create lookup dictionary from birads filenames to BIRADS scores
birads_lookup = dict(zip(birads_df['birads_image_name'], birads_df['BIRADS']))

# Map BIRADS score to subset based on matching image names
subset_df['BIRADS'] = subset_df['image_name'].map(birads_lookup)

# Save back to the same Excel with BIRADS column added
subset_df.to_excel(subset_path, index=False)

print("Updated subset catalog with BIRADS scores where filenames matched!")


Updated subset catalog with BIRADS scores where filenames matched!


In [41]:
# Inspect the columns of the (BIRADS-augmented) subset catalog
subset_catalog_xlsx = Path('Subset') / 'subset_catalog.xlsx'
df = pd.read_excel(subset_catalog_xlsx)
df.columns

Index(['image_name', 'label', 'relative_image_path', 'relative_mask_path',
       'BIRADS'],
      dtype='object')

In [39]:
# low_memory=False parses the file in one pass, so pandas infers one dtype per
# column. The chunked default raised a warning on this line in the saved output
# (presumably the mixed-types DtypeWarning; confirm on re-run).
df = pd.read_csv('Subset/mammo-bench.csv', low_memory=False)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 71844 entries, 0 to 71843
Data columns (total 18 columns):
 #   Column                   Non-Null Count  Dtype  
---  ------                   --------------  -----  
 0   source_dataset           71844 non-null  object 
 1   laterality               71844 non-null  object 
 2   view                     71824 non-null  object 
 3   preprocessed_image_path  71844 non-null  object 
 4   classification           43425 non-null  object 
 5   density                  41319 non-null  object 
 6   BIRADS                   30383 non-null  float64
 7   abnormality              5712 non-null   object 
 8   molecular_subtype        2956 non-null   object 
 9   raw_image_path           71844 non-null  object 
 10  mask_path                71844 non-null  object 
 11  ROI_path                 3099 non-null   object 
 12  x                        308 non-null    float64
 13  y                        308 non-null    float64
 14  radius                

  df = pd.read_csv('Subset/mammo-bench.csv')


In [46]:
import numpy as np
import pandas as pd

subset_path = 'Subset/subset_catalog.xlsx'

# Load the subset catalog
df = pd.read_excel(subset_path)

# Normalise missing BIRADS scores to NaN. An element-wise "NaN if isnull" is
# effectively a no-op on the NaNs read_excel already produces. to_numeric with
# errors='coerce' also turns non-numeric placeholders (stray text) into NaN and
# guarantees a numeric column.
df['BIRADS'] = pd.to_numeric(df['BIRADS'], errors='coerce')

# Save back to the same Excel file (to_excel writes NaN as an empty cell)
df.to_excel(subset_path, index=False)

print("Null BIRADS values replaced with NaN and saved in the subset catalog.")
df.head(20)


Null BIRADS values replaced with NaN and saved in the subset catalog.


Unnamed: 0,image_name,label,relative_image_path,relative_mask_path,BIRADS
0,cdd-cesm_1.jpg,Benign,Subset/Preprocessed_Dataset/Benign/cdd-cesm/cd...,Subset/Masks/Benign/cdd-cesm/cdd-cesm_1.jpg,3.0
1,cdd-cesm_1000.jpg,Benign,Subset/Preprocessed_Dataset/Benign/cdd-cesm/cd...,Subset/Masks/Benign/cdd-cesm/cdd-cesm_1000.jpg,1.0
2,cdd-cesm_103.jpg,Benign,Subset/Preprocessed_Dataset/Benign/cdd-cesm/cd...,Subset/Masks/Benign/cdd-cesm/cdd-cesm_103.jpg,3.0
3,cdd-cesm_104.jpg,Benign,Subset/Preprocessed_Dataset/Benign/cdd-cesm/cd...,Subset/Masks/Benign/cdd-cesm/cdd-cesm_104.jpg,3.0
4,cdd-cesm_105.jpg,Benign,Subset/Preprocessed_Dataset/Benign/cdd-cesm/cd...,Subset/Masks/Benign/cdd-cesm/cdd-cesm_105.jpg,2.0
5,cdd-cesm_106.jpg,Benign,Subset/Preprocessed_Dataset/Benign/cdd-cesm/cd...,Subset/Masks/Benign/cdd-cesm/cdd-cesm_106.jpg,2.0
6,cdd-cesm_111.jpg,Benign,Subset/Preprocessed_Dataset/Benign/cdd-cesm/cd...,Subset/Masks/Benign/cdd-cesm/cdd-cesm_111.jpg,2.0
7,cdd-cesm_112.jpg,Benign,Subset/Preprocessed_Dataset/Benign/cdd-cesm/cd...,Subset/Masks/Benign/cdd-cesm/cdd-cesm_112.jpg,2.0
8,cdd-cesm_113.jpg,Benign,Subset/Preprocessed_Dataset/Benign/cdd-cesm/cd...,Subset/Masks/Benign/cdd-cesm/cdd-cesm_113.jpg,2.0
9,cdd-cesm_114.jpg,Benign,Subset/Preprocessed_Dataset/Benign/cdd-cesm/cd...,Subset/Masks/Benign/cdd-cesm/cdd-cesm_114.jpg,3.0
