# 3D DenseNet for Cancer Stage Classification

This notebook implements a 3D DenseNet model for cancer stage classification from CT scans.

In [1]:
# Import necessary libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm
import random
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR

# Import the DenseNet model
from densenet_3d import DenseNet121_3D, DenseNet169_3D, DenseNet201_3D

# Import the dataset
from direct_dataset import DirectCTScanDataset

# Import custom loss functions
from focal_loss import FocalLoss, CombinedLoss

# Import training functions
from train_3d_densenet import train_model, evaluate_model, plot_training_history, plot_confusion_matrix

In [2]:
# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [3]:
# Configuration parameters
csv_path = "E:/cancer stage/lung_csv.csv"  # CSV file with patient IDs and cancer stage labels
patch_size = (64, 64, 64)  # Size of patches to extract from CT volumes
batch_size = 8  # Batch size for training
num_epochs = 30  # Maximum number of epochs (increased for better convergence)
patience = 15  # Patience for early stopping (increased to allow more exploration)
learning_rate = 0.0005  # Initial learning rate (reduced for more stable training)
weight_decay = 5e-4  # Weight decay for regularization (increased for better regularization)
model_save_path = "densenet121_3d_cancer_stage.pth"  # Path to save the trained model
base_dir = "E:/cancer stage/NSCLC-Radiomics"  # Base directory containing patient data
target_spacing = (1.0, 1.0, 1.0)  # Target voxel spacing in mm
target_shape = (128, 256, 256)  # Target shape for preprocessing
use_augmentation = True  # Whether to use augmentation for training
model_type = "densenet121"  # Model type: densenet121, densenet169, densenet201
scheduler_type = "cosine"  # Scheduler type: plateau, cosine
progressive_unfreezing = True  # Whether to use progressive unfreezing (enabled for better transfer learning)

In [4]:
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


## Create Datasets and DataLoaders

We'll create datasets that load and preprocess CT scans on-the-fly.

In [5]:
# Create datasets
print("Creating datasets with direct processing...")
train_dataset = DirectCTScanDataset(
    base_dir=base_dir,
    csv_path=csv_path,
    patch_size=patch_size,
    target_spacing=target_spacing,
    target_shape=target_shape,
    mode='train',
    use_augmentation=use_augmentation
)

val_dataset = DirectCTScanDataset(
    base_dir=base_dir,
    csv_path=csv_path,
    patch_size=patch_size,
    target_spacing=target_spacing,
    target_shape=target_shape,
    mode='val',
    use_augmentation=False
)

test_dataset = DirectCTScanDataset(
    base_dir=base_dir,
    csv_path=csv_path,
    patch_size=patch_size,
    target_spacing=target_spacing,
    target_shape=target_shape,
    mode='test',
    use_augmentation=False
)

Creating datasets with direct processing...
Unique values in Overall.Stage: ['IIIb' 'I' 'II' 'IIIa' nan]
Dropping 1 rows with NaN values in Overall.Stage
No CT scan directory found for patient LUNG1-001
No CT scan directory found for patient LUNG1-004
No study directory found for patient LUNG1-007
No study directory found for patient LUNG1-036
No study directory found for patient LUNG1-050
No CT scan directory found for patient LUNG1-051
No study directory found for patient LUNG1-058
No CT scan directory found for patient LUNG1-065
No study directory found for patient LUNG1-067
No CT scan directory found for patient LUNG1-077
No CT scan directory found for patient LUNG1-082
No CT scan directory found for patient LUNG1-083
No CT scan directory found for patient LUNG1-086
No CT scan directory found for patient LUNG1-093
No CT scan directory found for patient LUNG1-094
No CT scan directory found for patient LUNG1-096
No CT scan directory found for patient LUNG1-097
No CT scan directory fo

In [6]:
# Create dataloaders
# Get sample weights for weighted random sampling
sample_weights = train_dataset.get_sample_weights()
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

# Use WeightedRandomSampler for training loader
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=sampler,  # Use sampler instead of shuffle
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

## Create and Train the DenseNet Model

In [7]:
# Create model based on specified type
print(f"Creating {model_type} model...")
if model_type == "densenet121":
    model = DenseNet121_3D(num_classes=4)
elif model_type == "densenet169":
    model = DenseNet169_3D(num_classes=4)
elif model_type == "densenet201":
    model = DenseNet201_3D(num_classes=4)
else:
    raise ValueError(f"Unknown model type: {model_type}")

model = model.to(device)

Creating densenet121 model...


In [8]:
# Define loss function with class weights and focal loss
# Use combined loss (weighted cross-entropy + focal loss)
criterion = CombinedLoss(
    weight=train_dataset.class_weights.to(device),
    gamma=2.0,  # Focal loss gamma parameter
    alpha=0.5   # Weight between CE and focal loss
)

# Define optimizer with gradient clipping
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Define learning rate scheduler
if scheduler_type == "plateau":
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
elif scheduler_type == "cosine":
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
else:
    raise ValueError(f"Unknown scheduler type: {scheduler_type}")

In [9]:
# Create a new checkpoint directory for the modified model
checkpoint_dir = f"{model_type}_checkpoints_balanced"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
    print(f"Created checkpoint directory: {checkpoint_dir}")

# Check if there's a checkpoint to resume from
resume_from = None
best_model_path = os.path.join(checkpoint_dir, "best_model.pth")
if os.path.exists(best_model_path):
    resume_from = best_model_path
    print(f"Found best model checkpoint: {resume_from}")
    
# Optionally, you can try to load weights from the previous model
# This will use our new partial loading logic in train_model
if not os.path.exists(best_model_path):
    old_checkpoint_dir = f"{model_type}_checkpoints"
    old_best_model_path = os.path.join(old_checkpoint_dir, "best_model.pth")
    if os.path.exists(old_best_model_path):
        resume_from = old_best_model_path
        print(f"No checkpoint found in new directory. Will try to load compatible weights from: {resume_from}")

Found best model checkpoint: densenet121_checkpoints_balanced\best_model.pth


In [10]:
# Train model
print("Starting training...")
model, history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    num_epochs=num_epochs,
    patience=patience,
    checkpoint_dir=checkpoint_dir,
    save_freq=2,  # Save checkpoint every 2 epochs
    resume_from=resume_from,
    progressive_unfreezing=progressive_unfreezing
)

Starting training...
Loading checkpoint from densenet121_checkpoints_balanced\best_model.pth
Resuming from epoch 4 with best validation loss: 0.8486
Progressive unfreezing schedule: [4, 10, 17, 23]
Unfreezing denseblock4


Epoch 5/30 [Train]: 100%|██████████| 22/22 [25:45<00:00, 70.27s/it, loss=1.48, acc=0.294]   
Epoch 5/30 [Val]: 100%|██████████| 4/4 [03:56<00:00, 59.09s/it, loss=0.517, acc=0.08]  


Epoch 5/30 - Train Loss: 0.4759, Train Acc: 0.2941, Val Loss: 0.8686, Val Acc: 0.0800, LR: 0.000496
No improvement for 1 epochs


Epoch 6/30 [Train]: 100%|██████████| 22/22 [23:03<00:00, 62.87s/it, loss=0.262, acc=0.347]  
Epoch 6/30 [Val]: 100%|██████████| 4/4 [03:44<00:00, 56.08s/it, loss=0.911, acc=0.04]   



Class-wise validation performance:
Classes present in validation set: ['Stage I', 'Stage II', 'Stage III']
              precision    recall  f1-score   support

     Stage I       0.00      0.00      0.00         7
    Stage II       0.04      1.00      0.08         1
   Stage III       0.00      0.00      0.00        17

    accuracy                           0.04        25
   macro avg       0.01      0.33      0.03        25
weighted avg       0.00      0.04      0.00        25

Epoch 6/30 - Train Loss: 0.4035, Train Acc: 0.3471, Val Loss: 0.8498, Val Acc: 0.0400, LR: 0.000492
Saved checkpoint to densenet121_checkpoints_balanced\checkpoint_epoch_6.pth
No improvement for 2 epochs


Epoch 7/30 [Train]: 100%|██████████| 22/22 [24:34<00:00, 67.04s/it, loss=1.36, acc=0.341]  
Epoch 7/30 [Val]: 100%|██████████| 4/4 [03:36<00:00, 54.14s/it, loss=0.749, acc=0.04]   


Epoch 7/30 - Train Loss: 0.4725, Train Acc: 0.3412, Val Loss: 0.9145, Val Acc: 0.0400, LR: 0.000488
No improvement for 3 epochs


Epoch 8/30 [Train]: 100%|██████████| 22/22 [25:58<00:00, 70.85s/it, loss=0.84, acc=0.3]     
Epoch 8/30 [Val]: 100%|██████████| 4/4 [03:43<00:00, 55.99s/it, loss=0.698, acc=0.04]   


Epoch 8/30 - Train Loss: 0.4518, Train Acc: 0.3000, Val Loss: 0.7488, Val Acc: 0.0400, LR: 0.000482
Saved checkpoint to densenet121_checkpoints_balanced\checkpoint_epoch_8.pth
Validation loss improved to 0.7488
Saved best model to densenet121_checkpoints_balanced\best_model.pth


Epoch 9/30 [Train]: 100%|██████████| 22/22 [24:02<00:00, 65.56s/it, loss=1.22, acc=0.4]     
Epoch 9/30 [Val]: 100%|██████████| 4/4 [03:36<00:00, 54.14s/it, loss=0.96, acc=0.04]    


Epoch 9/30 - Train Loss: 0.3472, Train Acc: 0.4000, Val Loss: 0.9198, Val Acc: 0.0400, LR: 0.000476
No improvement for 1 epochs


Epoch 10/30 [Train]: 100%|██████████| 22/22 [22:08<00:00, 60.39s/it, loss=1.21, acc=0.371] 
Epoch 10/30 [Val]: 100%|██████████| 4/4 [03:36<00:00, 54.22s/it, loss=0.679, acc=0.04]   


Epoch 10/30 - Train Loss: 0.4248, Train Acc: 0.3706, Val Loss: 0.9028, Val Acc: 0.0400, LR: 0.000469
Saved checkpoint to densenet121_checkpoints_balanced\checkpoint_epoch_10.pth
No improvement for 2 epochs
Unfreezing denseblock3


Epoch 11/30 [Train]: 100%|██████████| 22/22 [23:12<00:00, 63.30s/it, loss=1.14, acc=0.318]   
Epoch 11/30 [Val]: 100%|██████████| 4/4 [03:38<00:00, 54.67s/it, loss=0.628, acc=0.04]   



Class-wise validation performance:
Classes present in validation set: ['Stage I', 'Stage II', 'Stage III']
              precision    recall  f1-score   support

     Stage I       0.00      0.00      0.00         7
    Stage II       0.04      1.00      0.08         1
   Stage III       0.00      0.00      0.00        17

    accuracy                           0.04        25
   macro avg       0.01      0.33      0.03        25
weighted avg       0.00      0.04      0.00        25

Epoch 11/30 - Train Loss: 0.4452, Train Acc: 0.3176, Val Loss: 0.7333, Val Acc: 0.0400, LR: 0.000461
Validation loss improved to 0.7333
Saved best model to densenet121_checkpoints_balanced\best_model.pth


Epoch 12/30 [Train]: 100%|██████████| 22/22 [22:13<00:00, 60.63s/it, loss=0.317, acc=0.294]  
Epoch 12/30 [Val]: 100%|██████████| 4/4 [03:45<00:00, 56.27s/it, loss=0.475, acc=0.12]  


Epoch 12/30 - Train Loss: 0.4917, Train Acc: 0.2941, Val Loss: 0.7479, Val Acc: 0.1200, LR: 0.000452
Saved checkpoint to densenet121_checkpoints_balanced\checkpoint_epoch_12.pth
No improvement for 1 epochs


Epoch 13/30 [Train]:  18%|█▊        | 4/22 [08:27<38:01, 126.76s/it, loss=0.326, acc=0.375]  


KeyboardInterrupt: 

In [11]:
# Save trained model
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to '{model_save_path}'")

Model saved to 'densenet121_3d_cancer_stage.pth'


## Evaluate the Model

In [12]:
# Plot training history
plot_training_history(history)

NameError: name 'history' is not defined

In [13]:
# Evaluate model on test set
print("Evaluating model on test set...")
test_loss, test_acc, all_preds, all_labels = evaluate_model(
    model=model,
    test_loader=test_loader,
    criterion=criterion,
    device=device
)

print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

Evaluating model on test set...


Evaluating: 100%|██████████| 7/7 [12:31<00:00, 107.36s/it]  

Test Loss: 0.9165, Test Accuracy: 0.0408





In [None]:
# Plot confusion matrix
class_names = ['Stage I', 'Stage II', 'Stage III', 'Stage IV']
plot_confusion_matrix(all_labels, all_preds, class_names)

In [None]:
# Print classification report
print("Classification Report:")
print(classification_report(all_labels, all_preds, target_names=class_names))

## Model Comparison with ResNet

After training both the DenseNet and ResNet models, we can compare their performance using ROC curves and other metrics. This helps us understand which model is better suited for cancer stage classification, especially with imbalanced data.

To run the comparison, use the `model_comparison.ipynb` notebook, which includes:
- ROC curves for each class and model
- Precision-recall curves
- AUC and average precision scores
- Confusion matrices

This comparison is particularly important for imbalanced datasets like ours, where some cancer stages have fewer samples than others.

In [14]:
# Run the model comparison notebook (uncomment to run)
!jupyter notebook model_comparison.ipynb

^C


## Load and Use a Trained Model

In [None]:
# Load a trained model
def load_trained_model(model_path, model_type="densenet121", device="cuda"):
    # Create model based on specified type
    if model_type == "densenet121":
        model = DenseNet121_3D(num_classes=4)
    elif model_type == "densenet169":
        model = DenseNet169_3D(num_classes=4)
    elif model_type == "densenet201":
        model = DenseNet201_3D(num_classes=4)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    
    # Load model weights
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    
    return model

# Example usage:
# trained_model = load_trained_model("densenet121_3d_cancer_stage.pth", model_type="densenet121", device=device)

In [None]:
# Function to predict cancer stage for a single CT scan
def predict_cancer_stage(model, ct_scan_path, seg_mask_path, device="cuda"):
    from ct_preprocessing import preprocess_ct_scan, load_dicom_series_safely
    
    # Load and preprocess CT scan
    ct_volume, _ = load_dicom_series_safely(ct_scan_path)
    preprocessed_volume = preprocess_ct_scan(ct_volume, target_spacing=(1.0, 1.0, 1.0), target_shape=(128, 256, 256))
    
    # Convert to tensor and add batch and channel dimensions
    volume_tensor = torch.from_numpy(preprocessed_volume).float().unsqueeze(0).unsqueeze(0).to(device)
    
    # Make prediction
    with torch.no_grad():
        outputs = model(volume_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
    
    # Map class index to stage name
    stage_names = ['Stage I', 'Stage II', 'Stage III', 'Stage IV']
    predicted_stage = stage_names[predicted_class]
    
    # Get probabilities for each class
    probs = probabilities.cpu().numpy()[0]
    
    return predicted_stage, probs

# Example usage:
# patient_id = "LUNG1-001"
# ct_path = f"E:/cancer stage/NSCLC-Radiomics/{patient_id}/..."
# seg_path = f"E:/cancer stage/NSCLC-Radiomics/{patient_id}/..."
# predicted_stage, probabilities = predict_cancer_stage(trained_model, ct_path, seg_path, device=device)
# print(f"Predicted cancer stage: {predicted_stage}")
# for stage, prob in zip(['Stage I', 'Stage II', 'Stage III', 'Stage IV'], probabilities):
#     print(f"{stage}: {prob:.4f}")