# Lung Nodule Detection with 3D CNN (Improved)

This notebook implements an improved 3D CNN for lung nodule detection with:
- Data augmentation (random flips, Gaussian noise)
- Deeper model with dropout
- AdamW optimizer
- CosineAnnealingWarmRestarts scheduler
- Focal Loss for class imbalance

In [None]:
# Install packages (uncomment for cloud)
# !pip install torch torchvision SimpleITK matplotlib diskcache tqdm numpy scikit-learn

import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
import random
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
import torch.nn.functional as F

# Make runs reproducible: seed the PyTorch, NumPy and stdlib RNGs with one value.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Configuration
class Config:
    """Notebook-wide settings: file locations, synthetic-data size, training knobs."""

    # Filesystem locations (relative to the working directory).
    DATA_DIR = './synthetic_luna'
    CACHE_DIR = './cache'

    # Synthetic data: each scan is a SCAN_DIM x SCAN_DIM x SCAN_DIM cube.
    SCAN_DIM = 64
    NUM_SCANS = 20

    # Training hyper-parameters.
    BATCH_SIZE = 4
    NUM_EPOCHS = 5
    LEARNING_RATE = 1e-3
    CONV_CHANNELS = 16  # width of the first conv block; later blocks multiply it


# Make sure both working directories exist before anything writes to them.
os.makedirs(Config.DATA_DIR, exist_ok=True)
os.makedirs(Config.CACHE_DIR, exist_ok=True)

In [None]:
# Create synthetic dataset
def create_synthetic_luna_dataset():
    """Write a small synthetic LUNA-style dataset under ``Config.DATA_DIR``.

    For each of ``Config.NUM_SCANS`` scans, writes to ``<DATA_DIR>/subset0/``:
      * ``scan_<i>.mhd`` -- a MetaImage header describing the volume, and
      * ``scan_<i>.raw`` -- ``SCAN_DIM**3`` float32 voxels of Gaussian noise
        (mean 0, std 100) plus 0-2 bright spheres ("nodules", radius 3-7).

    Then writes ``<DATA_DIR>/candidates.csv`` with columns
    ``seriesuid,coordX,coordY,coordZ,class``. The labels agree with the voxel
    data:
      * every inserted nodule yields a positive candidate (class 1) at its centre;
      * each scan also gets 2-5 negative candidates (class 0) at random points
        more than 2 voxels outside every nodule.

    Coordinates are voxel indices into the array as written, i.e.
    ``ct_data[coordX, coordY, coordZ]``.
    """
    dim = Config.SCAN_DIM
    subset_dir = os.path.join(Config.DATA_DIR, "subset0")
    os.makedirs(subset_dir, exist_ok=True)

    # Index grids broadcastable to (dim, dim, dim); shared by every sphere mask.
    grid_x, grid_y, grid_z = np.ogrid[:dim, :dim, :dim]
    candidate_rows = []  # (scan_id, x, y, z, label)

    for i in range(Config.NUM_SCANS):
        scan_id = f"scan_{i}"

        # Create .mhd file
        mhd_content = f"""ObjectType = Image
NDims = 3
BinaryData = True
BinaryDataByteOrderMSB = False
CompressedData = False
TransformMatrix = 1 0 0 0 1 0 0 0 1
Offset = 0 0 0
CenterOfRotation = 0 0 0
AnatomicalOrientation = RAI
ElementSpacing = 1 1 1
DimSize = {Config.SCAN_DIM} {Config.SCAN_DIM} {Config.SCAN_DIM}
ElementType = MET_FLOAT
ElementDataFile = {scan_id}.raw
"""

        with open(os.path.join(subset_dir, f"{scan_id}.mhd"), 'w') as f:
            f.write(mhd_content)

        # Create CT data with nodules
        ct_data = np.random.normal(0, 100, (dim, dim, dim)).astype(np.float32)

        nodules = []  # (x, y, z, radius) of each inserted sphere
        for _ in range(np.random.randint(0, 3)):
            x, y, z = (int(v) for v in np.random.randint(10, dim - 10, size=3))
            radius = int(np.random.randint(3, 8))
            inside = (grid_x - x) ** 2 + (grid_y - y) ** 2 + (grid_z - z) ** 2 <= radius * radius
            # One independent intensity sample per voxel inside the sphere.
            ct_data[inside] += np.random.normal(200, 50, size=int(inside.sum())).astype(np.float32)
            nodules.append((x, y, z, radius))
            candidate_rows.append((scan_id, float(x), float(y), float(z), 1))

        ct_data.tofile(os.path.join(subset_dir, f"{scan_id}.raw"))

        # Negative candidates: random points clear of every nodule (bounded retries).
        for _ in range(np.random.randint(2, 6)):
            for _attempt in range(100):
                point = np.random.uniform(10, dim - 10, size=3)
                if all(np.linalg.norm(point - (nx, ny, nz)) > r + 2
                       for nx, ny, nz, r in nodules):
                    candidate_rows.append((scan_id, *(float(c) for c in point), 0))
                    break

    # Create candidates.csv
    with open(os.path.join(Config.DATA_DIR, "candidates.csv"), 'w') as f:
        f.write("seriesuid,coordX,coordY,coordZ,class\n")
        for scan_id, x, y, z, label in candidate_rows:
            f.write(f"{scan_id},{x},{y},{z},{label}\n")


create_synthetic_luna_dataset()
print("Dataset created!")

In [None]:
# Data augmentation in dataset
class SyntheticLunaDataset(Dataset):
    """Candidate-level dataset over the synthetic scans.

    Each item is ``(ct, label, scan_id)``:
      * ``ct`` -- float32 tensor of shape ``(1, scan_dim, scan_dim, scan_dim)``.
        It is the scan clipped to [-1000, 1000], min-max normalised to [0, 1],
        then cropped to a ``scan_dim``-cube centred on the candidate's
        ``(coordX, coordY, coordZ)`` voxel. Voxels outside the scan are
        zero-padded. Centring on the candidate is what lets different
        candidates of the same scan produce different inputs.
      * ``label`` -- one-hot float32 tensor ``[1 - is_nodule, is_nodule]``.
      * ``scan_id`` -- the candidate's ``seriesuid`` string.

    Args:
        data_dir: directory holding ``candidates.csv`` and ``subset0/<uid>.raw``.
        scan_dim: edge length of the stored cubic scans. It is also used as the
            crop size, so the tensor shape matches a whole-scan input.
        augment: if True, apply random per-axis flips and additive Gaussian
            noise (std 0.05), then re-clip to [0, 1].
    """

    def __init__(self, data_dir=Config.DATA_DIR, scan_dim=Config.SCAN_DIM, augment=False):
        self.data_dir = data_dir
        self.scan_dim = scan_dim
        self.samples = []
        self.augment = augment

        candidates_path = os.path.join(data_dir, "candidates.csv")
        with open(candidates_path, 'r') as f:
            next(f, None)  # skip the header row
            for line in f:
                parts = line.strip().split(',')
                if len(parts) < 5:
                    continue  # tolerate blank / trailing lines
                self.samples.append({
                    'series_uid': parts[0],
                    'coord_x': float(parts[1]),
                    'coord_y': float(parts[2]),
                    'coord_z': float(parts[3]),
                    'is_nodule': int(parts[4])
                })

    def __len__(self):
        return len(self.samples)

    def _crop_around(self, ct, center_xyz):
        """Return a ``scan_dim``-cube of ``ct`` centred on ``center_xyz`` (zero-padded)."""
        crop = self.scan_dim
        pad_lo = crop // 2
        padded = np.pad(ct, [(pad_lo, crop - pad_lo)] * 3, mode='constant', constant_values=0.0)
        # Voxel c of ``ct`` sits at c + pad_lo in ``padded``, so the window
        # [c, c + crop) places it at index pad_lo (the centre) of the crop.
        starts = [int(np.clip(int(round(c)), 0, n - 1)) for c, n in zip(center_xyz, ct.shape)]
        sx, sy, sz = starts
        return padded[sx:sx + crop, sy:sy + crop, sz:sz + crop]

    def __getitem__(self, idx):
        sample = self.samples[idx]
        scan_id = sample['series_uid']

        raw_path = os.path.join(self.data_dir, "subset0", f"{scan_id}.raw")
        ct = np.fromfile(raw_path, dtype=np.float32).reshape(self.scan_dim, self.scan_dim, self.scan_dim)

        # Clip to a fixed window, then min-max normalise this scan to [0, 1].
        ct = np.clip(ct, -1000, 1000)
        ct = (ct - ct.min()) / (ct.max() - ct.min() + 1e-8)

        ct = self._crop_around(ct, (sample['coord_x'], sample['coord_y'], sample['coord_z']))

        # Data augmentation
        if self.augment:
            for axis in range(3):
                if np.random.rand() > 0.5:
                    ct = np.flip(ct, axis=axis)
            ct = ct + np.random.normal(0, 0.05, ct.shape)
            ct = np.clip(ct, 0, 1)

        # ascontiguousarray: slicing/flipping yields strided views.
        ct = torch.tensor(np.ascontiguousarray(ct), dtype=torch.float32).unsqueeze(0)
        label = torch.tensor([1 - sample['is_nodule'], sample['is_nodule']], dtype=torch.float32)

        return ct, label, scan_id

# Create datasets and loaders.
# The split is made per scan (series_uid), not per candidate. Every candidate
# of a scan is cut from the same voxel data, so a candidate-level split would
# put the same scan in both train and validation and inflate validation scores.
full_dataset = SyntheticLunaDataset()
# sorted() first: set iteration order of strings can vary between interpreter runs.
scan_uids = sorted({sample['series_uid'] for sample in full_dataset.samples})
np.random.shuffle(scan_uids)
n_train_scans = int(0.8 * len(scan_uids))
if len(scan_uids) > 1:
    # Keep at least one scan on each side of the split.
    n_train_scans = min(max(n_train_scans, 1), len(scan_uids) - 1)
train_scan_uids = set(scan_uids[:n_train_scans])

train_indices = [i for i, s in enumerate(full_dataset.samples) if s['series_uid'] in train_scan_uids]
val_indices = [i for i, s in enumerate(full_dataset.samples) if s['series_uid'] not in train_scan_uids]
train_size, val_size = len(train_indices), len(val_indices)

train_dataset = torch.utils.data.Subset(SyntheticLunaDataset(augment=True), train_indices)
val_dataset = torch.utils.data.Subset(SyntheticLunaDataset(augment=False), val_indices)

train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

In [None]:
# Improved model with dropout and more channels
class LunaBlock(nn.Module):
    """Conv-ReLU-Conv-ReLU, then 2x max-pool and channel dropout.

    Both convolutions are 3x3x3 with padding 1, so they keep the spatial size.
    ``MaxPool3d(2, 2)`` then halves every spatial axis. The output has
    ``conv_channels`` channels. ``Dropout3d(dropout_p)`` is active only in
    training mode.
    """

    def __init__(self, in_channels, conv_channels, dropout_p=0.2):
        super().__init__()
        # Construction order and attribute names are kept stable: they fix the
        # parameter order and the state_dict keys of saved checkpoints.
        self.conv1 = nn.Conv3d(in_channels, conv_channels, 3, padding=1, bias=True)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(conv_channels, conv_channels, 3, padding=1, bias=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(2, 2)
        self.dropout = nn.Dropout3d(dropout_p)

    def forward(self, input_batch):
        out = input_batch
        for stage in (self.conv1, self.relu1, self.conv2, self.relu2, self.maxpool, self.dropout):
            out = stage(out)
        return out

class LunaModel(nn.Module):
    """3D CNN nodule classifier: input BatchNorm, four ``LunaBlock`` stages, linear head.

    Args:
        in_channels: number of channels in the input volume.
        conv_channels: width of the first block; blocks 2-4 use 2x, 4x and 8x it.
        scan_dim: edge length of the cubic input volume. Each block halves
            every spatial axis with ``MaxPool3d(2, 2)``, so the head receives
            ``(scan_dim // 16) ** 3`` positions per channel. The default
            (``Config.SCAN_DIM``, 64) gives 4x4x4.

    ``forward`` returns ``(logits, softmax_probabilities)``, both of shape ``(N, 2)``.
    """

    def __init__(self, in_channels=1, conv_channels=Config.CONV_CHANNELS, scan_dim=Config.SCAN_DIM):
        super().__init__()
        # Normalise the raw input; one BatchNorm channel per input channel.
        self.tail_batchnorm = nn.BatchNorm3d(in_channels)
        self.block1 = LunaBlock(in_channels, conv_channels)
        self.block2 = LunaBlock(conv_channels, conv_channels * 2)
        self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
        self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)

        spatial = scan_dim // 16  # four successive floor-halvings
        if spatial < 1:
            raise ValueError(f"scan_dim must be at least 16, got {scan_dim}")
        final_size = spatial ** 3 * (conv_channels * 8)
        self.head_linear = nn.Linear(final_size, 2)
        self.head_softmax = nn.Softmax(dim=1)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights (fan_out, ReLU); biases uniform in [-1/sqrt(fan_out), 1/sqrt(fan_out)]."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    _, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_out)
                    # uniform_(a, b) draws from [a, b]. The previous normal_(mean, std)
                    # call used mean=-bound, std=bound, shifting every bias negative.
                    nn.init.uniform_(m.bias, -bound, bound)

    def forward(self, input_batch):
        bn_output = self.tail_batchnorm(input_batch)
        block_out = self.block1(bn_output)
        block_out = self.block2(block_out)
        block_out = self.block3(block_out)
        block_out = self.block4(block_out)
        conv_flat = block_out.flatten(start_dim=1)
        linear_output = self.head_linear(conv_flat)
        return linear_output, self.head_softmax(linear_output)

# Build the network with its default configuration and move it to the chosen device.
model = LunaModel()
model = model.to(device)
print(f"Model parameters: {sum(param.numel() for param in model.parameters()):,}")

In [None]:
# Focal Loss
class FocalLoss(nn.Module):
    """Focal loss (Lin et al., 2017) for class-index targets.

    Per sample: ``alpha_t * (1 - p_t) ** gamma * CE``. Here ``CE`` is the
    cross-entropy of that sample and ``p_t = exp(-CE)`` is the predicted
    probability of its true class.

    Args:
        alpha: either a scalar weight for every sample (default 1), or a
            per-class sequence/tensor. In the latter case
            ``alpha_t = alpha[target]``, which lets the loss up-weight a rare class.
        gamma: focusing parameter; ``gamma=0`` reduces to alpha-weighted CE.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``.

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # inputs/targets use the layout F.cross_entropy expects for class
        # indices (in this notebook: (N, 2) logits and (N,) int targets).
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        alpha = torch.as_tensor(self.alpha, dtype=ce_loss.dtype, device=ce_loss.device)
        alpha_t = alpha if alpha.ndim == 0 else alpha[targets]
        focal_loss = alpha_t * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

In [None]:
# Training with AdamW, CosineAnnealing, and Focal Loss
criterion = FocalLoss(alpha=1, gamma=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=2, T_mult=2, eta_min=1e-6)

train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []


def _run_epoch(loader, desc, train):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With ``train=True`` the model is in training mode and is updated by
    ``optimizer``. Otherwise it is in eval mode with gradients disabled.
    The loss is averaged per sample, weighting each batch by its size, so a
    smaller final batch does not skew the mean. An empty loader returns
    ``(nan, nan)`` instead of dividing by zero.
    """
    model.train(train)
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(train):
        for ct, labels, _ in tqdm(loader, desc=desc):
            ct, labels = ct.to(device), labels.to(device)
            targets = labels.argmax(dim=1)  # one-hot -> class index
            outputs, _ = model(ct)
            loss = criterion(outputs, targets)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = targets.size(0)
            total_loss += loss.item() * batch_size  # criterion returns a batch mean
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            seen += batch_size
    if seen == 0:
        return float('nan'), float('nan')
    return total_loss / seen, correct / seen


for epoch in range(Config.NUM_EPOCHS):
    avg_train_loss, train_accuracy = _run_epoch(
        train_loader, f"Epoch {epoch+1}/{Config.NUM_EPOCHS} [Train]", train=True)
    avg_val_loss, val_accuracy = _run_epoch(
        val_loader, f"Epoch {epoch+1}/{Config.NUM_EPOCHS} [Val]", train=False)

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    # One scheduler step per epoch: T_0 / T_mult are counted in epochs.
    scheduler.step()

    print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.3f}")
    print(f"          Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.3f}")

print("Training completed!")

In [None]:
# Plot training curves
# x-axis uses the same 1-based epoch numbers that the training log prints.
epoch_axis = range(1, len(train_losses) + 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

ax1.plot(epoch_axis, train_losses, label='Train Loss', marker='o')
ax1.plot(epoch_axis, val_losses, label='Validation Loss', marker='s')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True)

ax2.plot(epoch_axis, train_accuracies, label='Train Accuracy', marker='o')
ax2.plot(epoch_axis, val_accuracies, label='Validation Accuracy', marker='s')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Training and Validation Accuracy')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

# Guard: with zero epochs the lists are empty and [-1] / max() would raise.
if val_accuracies:
    print(f"Final Validation Accuracy: {val_accuracies[-1]:.3f}")
    print(f"Best Validation Accuracy: {max(val_accuracies):.3f}")

In [None]:
# Evaluation
# Collect validation predictions, then report metrics and a confusion matrix.
model.eval()
all_predictions = []
all_labels = []
all_probabilities = []

with torch.no_grad():
    for ct, labels, _ in val_loader:
        ct, labels = ct.to(device), labels.to(device)
        outputs, probabilities = model(ct)

        all_predictions.extend(outputs.argmax(dim=1).cpu().numpy())
        all_labels.extend(labels.argmax(dim=1).cpu().numpy())  # one-hot -> class index
        all_probabilities.extend(probabilities[:, 1].cpu().numpy())  # softmax score of class 1

all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels)
all_probabilities = np.array(all_probabilities)

accuracy = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions, zero_division=0)
recall = recall_score(all_labels, all_predictions, zero_division=0)
f1 = f1_score(all_labels, all_predictions, zero_division=0)
# roc_auc_score raises ValueError unless both classes occur in y_true. A small
# validation split can violate that, so report NaN instead of crashing.
if np.unique(all_labels).size == 2:
    auc = roc_auc_score(all_labels, all_probabilities)
else:
    auc = float('nan')
# Pin the label set so the matrix is always 2x2 and lines up with the tick labels.
cm = confusion_matrix(all_labels, all_predictions, labels=[0, 1])

print("\nModel Evaluation Results:")
print(f"Accuracy: {accuracy:.3f}")
print(f"Precision: {precision:.3f}")
print(f"Recall: {recall:.3f}")
print(f"F1-Score: {f1:.3f}")
print(f"AUC-ROC: {auc:.3f}")
if np.isnan(auc):
    print("(AUC-ROC undefined: validation labels contain a single class)")

# Plot confusion matrix
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.xticks([0, 1], ['Negative', 'Positive'])
plt.yticks([0, 1], ['Negative', 'Positive'])

thresh = cm.max() / 2.
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(j, i, format(cm[i, j], 'd'),
                 ha="center", va="center",
                 color="white" if cm[i, j] > thresh else "black")

plt.tight_layout()
plt.show()

In [None]:
# Save improved model
# The scheduler state is stored next to the optimizer state. Without it, a
# resumed run would restart the cosine warm-restart schedule from scratch.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'epoch': Config.NUM_EPOCHS,
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_accuracies': train_accuracies,
    'val_accuracies': val_accuracies,
    'config': {
        'scan_dim': Config.SCAN_DIM,
        'conv_channels': Config.CONV_CHANNELS,
        'batch_size': Config.BATCH_SIZE,
        'learning_rate': Config.LEARNING_RATE
    }
}, 'lung_nodule_model_improved.pth')

print("Improved model saved as 'lung_nodule_model_improved.pth'")