In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Lung Nodule Segmentation with EfficientUNet

This script trains a deep learning model to segment lung nodules in CT scans
using the LIDC-IDRI dataset.
"""

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn import init

import kagglehub
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm.notebook import tqdm  # Specifically optimized for Jupyter notebooks

# -----------------------------------------------------------------------------
# Dataset Classes
# -----------------------------------------------------------------------------

class LIDCDataset(Dataset):
    """Map-style dataset over LIDC-IDRI nodule slices and their annotator masks.

    The loader scans ``<base_path>/LIDC-IDRI-slices/<patient>/<nodule>/`` for an
    ``images`` directory and up to four ``mask-0`` .. ``mask-3`` directories.
    Every ``.png`` in ``images`` that has a same-named file in at least one
    mask directory becomes one sample. Patients (not individual slices) are
    then partitioned into train/val/test so that a patient never appears in
    more than one split.
    """

    def __init__(self, base_path, task='detection', transform=None, split='train',
                 val_ratio=0.15, test_ratio=0.15, seed=42):
        """
        LIDC-IDRI dataset loader

        Args:
            base_path: Path to the dataset root (the directory that contains
                ``LIDC-IDRI-slices``)
            task: 'detection' uses the first available annotator mask;
                'classification' uses a consensus mask (>= 2 annotators agree)
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` returning a dict with
                'image' and 'mask' keys
            split: 'train', 'val', or 'test'. Any other value keeps every
                collected sample (no filtering is applied)
            val_ratio: Fraction of patients for validation
            test_ratio: Fraction of patients for testing
            seed: Random seed for reproducibility of the patient split
        """
        self.base_path = base_path
        self.task = task
        self.transform = transform
        self.split = split

        # Collect all patients and nodules
        self.samples = []
        self._collect_dataset()

        # Split into train/val/test
        self._create_splits(val_ratio, test_ratio, seed)

    @staticmethod
    def _load_array(path):
        """Read an image file into a NumPy array and close the file right away."""
        with Image.open(path) as img:
            return np.array(img)

    def _collect_dataset(self):
        """Collect all image and mask pairs from the dataset.

        Appends one dict per usable slice to ``self.samples`` and finally sorts
        the list, so the sample order does not depend on the order in which
        the filesystem happens to list directory entries.
        """
        slices_dir = os.path.join(self.base_path, "LIDC-IDRI-slices")

        # Loop through all patients
        for patient_id in os.listdir(slices_dir):
            patient_path = os.path.join(slices_dir, patient_id)
            if not os.path.isdir(patient_path):
                continue

            # Loop through nodules for this patient
            for nodule_id in os.listdir(patient_path):
                nodule_path = os.path.join(patient_path, nodule_id)
                if not os.path.isdir(nodule_path):
                    continue

                images_path = os.path.join(nodule_path, "images")
                if not os.path.exists(images_path):
                    continue

                # Up to 4 annotator mask directories, kept in annotator order
                candidate_dirs = (
                    os.path.join(nodule_path, f"mask-{i}") for i in range(4)
                )
                mask_dirs = [d for d in candidate_dirs if os.path.exists(d)]
                if not mask_dirs:
                    continue

                for img_file in os.listdir(images_path):
                    if not img_file.endswith('.png'):
                        continue

                    # Extract slice number from filename (e.g., "slice-0.png" -> 0)
                    slice_num = int(img_file.split('-')[1].split('.')[0])

                    # A mask matches when it has the same filename as the image
                    masks = [
                        mask_file
                        for mask_file in (os.path.join(d, img_file) for d in mask_dirs)
                        if os.path.exists(mask_file)
                    ]
                    if not masks:
                        continue

                    self.samples.append({
                        'patient_id': patient_id,
                        'nodule_id': nodule_id,
                        'slice_num': slice_num,
                        'image_path': os.path.join(images_path, img_file),
                        'mask_paths': masks
                    })

        # os.listdir order is arbitrary; sort for a stable, numeric slice order.
        self.samples.sort(
            key=lambda s: (s['patient_id'], s['nodule_id'], s['slice_num'], s['image_path'])
        )

        print(f"Collected {len(self.samples)} valid image-mask pairs")

    def _create_splits(self, val_ratio, test_ratio, seed):
        """Create train/val/test splits based on patient IDs to avoid data leakage.

        ``self.samples`` is replaced by the subset belonging to ``self.split``.
        """
        # Sorted (not just de-duplicated): set iteration order of strings varies
        # between interpreter runs, which would make the seeded split differ
        # from run to run and let patients migrate between splits.
        patient_ids = sorted({s['patient_id'] for s in self.samples})

        # First carve out the test patients, then split the rest into train/val.
        train_ids, test_ids = train_test_split(
            patient_ids, test_size=test_ratio, random_state=seed
        )

        # val_ratio is relative to the full set, so rescale it to the remainder.
        train_ids, val_ids = train_test_split(
            train_ids, test_size=val_ratio/(1-test_ratio), random_state=seed
        )

        ids_by_split = {'train': train_ids, 'val': val_ids, 'test': test_ids}
        if self.split in ids_by_split:
            keep = set(ids_by_split[self.split])  # O(1) membership per sample
            self.samples = [s for s in self.samples if s['patient_id'] in keep]

        print(f"Split: {self.split}, Samples: {len(self.samples)}")

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load and return a sample.

        Returns:
            dict with keys 'image', 'mask' (tensors) and 'metadata' (dict with
            'patient_id', 'nodule_id', 'slice_num'). Without a transform, 2-D
            arrays get a leading channel axis and are returned as float32;
            with a transform that already returns tensors, its output is
            passed through unchanged.
        """
        sample = self.samples[idx]

        image = self._load_array(sample['image_path'])

        if self.task == 'classification':
            # Sum the per-annotator masks (each scaled by 1/255) and keep the
            # pixels that at least 2 annotators marked.
            # NOTE: for 8-bit masks a nodule with a single available annotator
            # mask can never reach the threshold, so its consensus is empty.
            all_masks = [self._load_array(p) / 255.0 for p in sample['mask_paths']]
            consensus_mask = np.zeros_like(all_masks[0])
            for m in all_masks:
                consensus_mask += m
            mask = (consensus_mask >= 2).astype(np.float32)
        else:
            # Default: first available annotator only
            mask = self._load_array(sample['mask_paths'][0])
            # Ensure mask is normalized to [0,1] range
            if mask.dtype == np.uint8:
                mask = mask / 255.0

        # Apply transformations if provided
        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        # Add channel dimension if needed and convert to tensor if not already
        if not torch.is_tensor(image):
            if image.ndim == 2:
                image = image[np.newaxis, ...]
            image = torch.from_numpy(image.astype(np.float32))

        if not torch.is_tensor(mask):
            if mask.ndim == 2:
                mask = mask[np.newaxis, ...]
            mask = torch.from_numpy(mask.astype(np.float32))

        metadata = {
            'patient_id': sample['patient_id'],
            'nodule_id': sample['nodule_id'],
            'slice_num': sample['slice_num']
        }

        return {
            'image': image,
            'mask': mask,
            'metadata': metadata
        }

    def get_stats(self):
        """Return the number of samples, distinct patients and distinct nodules."""
        num_patients = len({s['patient_id'] for s in self.samples})
        num_nodules = len({(s['patient_id'], s['nodule_id']) for s in self.samples})

        return {
            'num_samples': len(self.samples),
            'num_patients': num_patients,
            'num_nodules': num_nodules
        }


# -----------------------------------------------------------------------------
# Augmentation Functions
# -----------------------------------------------------------------------------

def get_training_transforms(p=0.5):
    """Build the Albumentations pipeline applied to training samples.

    Args:
        p: Probability used for each geometric step (90-degree rotation,
            horizontal/vertical flip, shift-scale-rotate). The elastic, noise
            and brightness/contrast steps use their own fixed probabilities.

    Returns:
        An ``A.Compose`` that augments, normalizes with mean 0.5 / std 0.5,
        and converts the result to tensors.
    """
    # Geometric changes, all driven by the caller-supplied probability
    geometric = [
        A.RandomRotate90(p=p),
        A.HorizontalFlip(p=p),
        A.VerticalFlip(p=p),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=p),
    ]

    # Gentle elastic warp, applied only occasionally
    elastic = A.ElasticTransform(alpha=1.0, sigma=50, p=0.2)

    # Low-amplitude intensity perturbations, each applied rarely
    intensity = [
        A.GaussNoise(std_range=(0.01, 0.05), mean_range=(0.0, 0.01), p=0.1),
        A.RandomBrightnessContrast(brightness_limit=0.02, contrast_limit=0.02, p=0.1),
    ]

    # Normalization and tensor conversion always run last
    finalize = [A.Normalize(mean=0.5, std=0.5), ToTensorV2()]

    return A.Compose([*geometric, elastic, *intensity, *finalize])

def get_validation_transforms():
    """Build the evaluation pipeline: normalize (mean 0.5, std 0.5) and convert to tensor.

    No augmentation steps are included.
    """
    steps = [A.Normalize(mean=0.5, std=0.5), ToTensorV2()]
    return A.Compose(steps)

class LIDCDatasetWithAugmentation(LIDCDataset):
    """LIDCDataset variant that chooses its transform pipeline from the split."""

    def __init__(self, base_path, task='detection', split='train',
                 val_ratio=0.15, test_ratio=0.15, seed=42):
        """
        LIDC-IDRI dataset with automatic augmentation

        The 'train' split gets the augmenting pipeline from
        ``get_training_transforms(p=0.5)``; every other split gets the
        normalization-only pipeline from ``get_validation_transforms()``.
        All remaining arguments are forwarded to ``LIDCDataset`` unchanged.
        """
        is_training = split == 'train'
        pipeline = get_training_transforms(p=0.5) if is_training else get_validation_transforms()

        # Hand everything to the parent, which collects and splits the samples
        super().__init__(
            base_path,
            task=task,
            transform=pipeline,
            split=split,
            val_ratio=val_ratio,
            test_ratio=test_ratio,
            seed=seed,
        )

def visualize_augmentations(dataset, idx=0, num_samples=5):
    """Plot one sample next to several randomly augmented versions of it.

    The first row shows the untouched image and its first annotator mask; each
    following row shows the result of one pass through the training pipeline
    with its geometric probability forced to 1.0.

    Args:
        dataset: Object exposing ``samples``, a sequence of dicts with
            'image_path' and 'mask_paths' entries
        idx: Index into ``dataset.samples`` of the sample to display
        num_samples: Number of rows to draw (original + ``num_samples - 1``
            augmented versions)
    """
    def _as_displayable(value):
        # Tensors from the pipeline are (C, H, W); matplotlib wants (H, W, C).
        if not torch.is_tensor(value):
            return value
        array = value.numpy()
        if array.ndim == 3:
            array = np.transpose(array, (1, 2, 0))
        return array

    sample = dataset.samples[idx]
    image = np.array(Image.open(sample['image_path']))
    mask = np.array(Image.open(sample['mask_paths'][0]))

    # squeeze=False keeps `axes` two-dimensional even for a single row, so the
    # axes[row, col] indexing below also works when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 2, figsize=(10, num_samples*3), squeeze=False)

    # The pipeline itself is loop-invariant; each call draws new random params.
    transform = get_training_transforms(p=1.0)  # Force augmentation

    for i in range(num_samples):
        if i == 0:
            # First row shows original
            aug_image, aug_mask = image, mask
            axes[i, 0].set_title("Original Image")
            axes[i, 1].set_title("Original Mask")
        else:
            augmented = transform(image=image, mask=mask)
            aug_image = _as_displayable(augmented['image'])
            aug_mask = _as_displayable(augmented['mask'])
            axes[i, 0].set_title(f"Augmented Image {i}")
            axes[i, 1].set_title(f"Augmented Mask {i}")

        # squeeze() drops a singleton channel axis so imshow gets a 2-D array
        axes[i, 0].imshow(aug_image.squeeze(), cmap='gray')
        axes[i, 0].axis('off')
        axes[i, 1].imshow(aug_mask.squeeze(), cmap='gray')
        axes[i, 1].axis('off')

    plt.tight_layout()
    plt.show()

# -----------------------------------------------------------------------------
# Model Architectures
# -----------------------------------------------------------------------------

class ConvBlock(nn.Module):
    """Bias-free 2-D convolution followed by batch normalization and ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # The conv carries no bias; BatchNorm2d right after it has its own shift
        conv_options = dict(kernel_size=kernel_size, stride=stride,
                            padding=padding, bias=False)
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x``."""
        features = self.conv(x)
        normalized = self.bn(features)
        return F.relu(normalized, inplace=True)

class DepthwiseSeparableConv(nn.Module):
    """Per-channel (depthwise) conv then 1x1 (pointwise) conv, each with BN + ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # groups=in_channels gives every input channel its own spatial filter
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(in_channels)
        # The 1x1 conv mixes channels and sets the output width
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Run the depthwise stage, then the pointwise stage."""
        spatial = self.bn1(self.depthwise(x))
        spatial = F.relu(spatial, inplace=True)
        mixed = self.bn2(self.pointwise(spatial))
        return F.relu(mixed, inplace=True)

class EfficientUNet(nn.Module):
    """U-Net style encoder/decoder built mostly from depthwise-separable convs.

    The encoder has four stages (32, 64, 128, 256 channels), each followed by
    2x max pooling, and a 512-channel bottleneck. The decoder mirrors it with
    transposed convolutions and concatenated skip connections. The output is
    passed through a sigmoid, so values lie in (0, 1).

    Input height and width do not have to be multiples of 16: when a pooled
    size was odd, the upsampled map is zero-padded to the size of its skip
    connection before concatenation. Each side must still be at least 16
    pixels so that four 2x poolings leave a non-empty feature map.

    Args:
        in_channels: Number of channels of the input image
        out_channels: Number of channels of the predicted map
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # Encoder
        self.enc1 = ConvBlock(in_channels, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DepthwiseSeparableConv(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DepthwiseSeparableConv(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = DepthwiseSeparableConv(128, 256)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = DepthwiseSeparableConv(256, 512)

        # Decoder: each dec* block sees upsampled + skip channels concatenated
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec4 = DepthwiseSeparableConv(512, 256)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = DepthwiseSeparableConv(256, 128)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = DepthwiseSeparableConv(128, 64)
        self.upconv1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = DepthwiseSeparableConv(64, 32)

        # Output layer
        self.outconv = nn.Conv2d(32, out_channels, kernel_size=1)

        # Initialize weights
        self._initialize_weights()

    @staticmethod
    def _pad_to_match(upsampled, skip):
        """Zero-pad ``upsampled`` so its height/width equal those of ``skip``.

        Max pooling floors odd sizes, so the 2x upsampled map can be one pixel
        short of the encoder map it is concatenated with. Returns the input
        untouched when the sizes already agree.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad order for the last two dims: (left, right, top, bottom)
        return F.pad(upsampled, [diff_w // 2, diff_w - diff_w // 2,
                                 diff_h // 2, diff_h - diff_h // 2])

    def _decode(self, upconv, block, x, skip):
        """Upsample ``x``, merge it with ``skip`` along channels, apply ``block``."""
        x = self._pad_to_match(upconv(x), skip)
        return block(torch.cat((x, skip), dim=1))

    def forward(self, x):
        """Return the sigmoid-activated segmentation map for batch ``x``."""
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        enc3 = self.enc3(self.pool2(enc2))
        enc4 = self.enc4(self.pool3(enc3))

        # Bottleneck
        bottleneck = self.bottleneck(self.pool4(enc4))

        # Decoder with skip connections, deepest stage first
        dec4 = self._decode(self.upconv4, self.dec4, bottleneck, enc4)
        dec3 = self._decode(self.upconv3, self.dec3, dec4, enc3)
        dec2 = self._decode(self.upconv2, self.dec2, dec3, enc2)
        dec1 = self._decode(self.upconv1, self.dec1, dec2, enc1)

        # Output segmentation map
        return torch.sigmoid(self.outconv(dec1))

    def _initialize_weights(self):
        """Kaiming-normal init for (transposed) convs; unit scale / zero shift for BN."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

# -----------------------------------------------------------------------------
# Loss Functions
# -----------------------------------------------------------------------------

class DiceBCELoss(nn.Module):
    """Weighted sum of binary cross-entropy and soft Dice loss.

    ``loss = weight * BCE + (1 - weight) * (1 - Dice)``, where Dice is computed
    over all elements of the batch at once.

    Args:
        weight: Share of the BCE term in the combined loss
        smooth: Small constant added to the Dice numerator and denominator so
            the ratio stays defined when both prediction and target are empty
    """

    def __init__(self, weight=0.5, smooth=1e-5):
        super().__init__()
        self.weight = weight
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Compute the combined loss.

        Args:
            inputs: Predicted probabilities in [0, 1] (not logits), as required
                by ``F.binary_cross_entropy``
            targets: Ground-truth mask; a 3-D tensor gets a channel axis
                inserted at dim 1 before comparison with ``inputs``

        Returns:
            Scalar loss tensor.
        """
        # Ensure targets are float type
        targets = targets.float()

        # Ensure same shape between inputs and targets
        if targets.dim() == 3:
            targets = targets.unsqueeze(1)  # Add channel dimension

        # Ensure all values are in [0, 1] range
        targets = torch.clamp(targets, 0, 1)

        # BCE Loss
        bce = F.binary_cross_entropy(inputs, targets, reduction='mean')

        # Dice Loss. reshape (unlike view) also accepts non-contiguous tensors,
        # e.g. ones produced by a transpose/permute upstream.
        inputs_flat = inputs.reshape(-1)
        targets_flat = targets.reshape(-1)

        intersection = (inputs_flat * targets_flat).sum()
        dice = (2. * intersection + self.smooth) / (
            inputs_flat.sum() + targets_flat.sum() + self.smooth
        )
        dice_loss = 1 - dice

        # Combine losses
        return bce * self.weight + dice_loss * (1 - self.weight)

# -----------------------------------------------------------------------------
# Training Function
# -----------------------------------------------------------------------------

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                num_epochs=50, device='cuda', checkpoint_path='best_nodule_detector.pth'):
    """Train the segmentation model and checkpoint the best validation epoch.

    Args:
        model: Network to train; moved to ``device`` before training
        train_loader: Iterable of batches, each a mapping with 'image' and
            'mask' tensors; must support ``len()``
        val_loader: Same batch format, used for validation after each epoch
        criterion: Callable ``criterion(outputs, masks)`` returning a scalar loss
        optimizer: Optimizer over the model parameters
        scheduler: Learning-rate scheduler, or None. A ``ReduceLROnPlateau``
            is stepped with the validation loss; any other scheduler is
            stepped without arguments
        num_epochs: Number of passes over ``train_loader``
        device: Device string or object the model and batches are moved to
        checkpoint_path: File the best state dict is written to

    Returns:
        ``(model, history)`` where ``model`` holds the weights of the LAST
        epoch (the best ones are only on disk) and ``history`` maps
        'train_loss' / 'val_loss' to per-epoch mean batch losses.

    Note:
        Per-epoch means divide by ``len(loader)``, so an empty loader raises
        ``ZeroDivisionError``.
    """
    # Move model to device
    model = model.to(device)

    best_val_loss = float('inf')
    history = {"train_loss": [], "val_loss": []}

    # Start from the current LR so the change check is well-defined in epoch 0.
    prev_lr = optimizer.param_groups[0]['lr']
    # Only plateau schedulers take the monitored metric in step().
    steps_on_metric = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

    print("Starting training...")

    # Main epoch loop with tqdm
    epoch_bar = tqdm(range(num_epochs), desc="Epochs", leave=True)

    for epoch in epoch_bar:
        # Training phase
        model.train()
        train_loss = 0.0

        # miniters=1 refreshes the bar on every batch
        train_bar = tqdm(train_loader, desc=f"Train {epoch+1}/{num_epochs}",
                      leave=False, miniters=1)

        for batch in train_bar:
            images = batch['image'].to(device)
            masks = batch['mask'].float().to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update running loss and progress bar description
            batch_loss = loss.item()
            train_loss += batch_loss
            train_bar.set_postfix({"loss": f"{batch_loss:.4f}"})

        train_loss /= len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0.0

        val_bar = tqdm(val_loader, desc=f"Valid {epoch+1}/{num_epochs}",
                      leave=False, miniters=1)

        with torch.no_grad():
            for batch in val_bar:
                images = batch['image'].to(device)
                masks = batch['mask'].float().to(device)

                outputs = model(images)
                batch_loss = criterion(outputs, masks).item()

                val_loss += batch_loss
                val_bar.set_postfix({"loss": f"{batch_loss:.4f}"})

        val_loss /= len(val_loader)

        # Update scheduler
        if scheduler is not None:
            if steps_on_metric:
                scheduler.step(val_loss)
            else:
                scheduler.step()

        # Report learning-rate changes made by the scheduler
        last_lr = optimizer.param_groups[0]['lr']
        if last_lr != prev_lr:
            print(f'Learning rate adjusted to {last_lr}')
        prev_lr = last_lr

        # Save history
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

        # Update the main progress bar with epoch results
        epoch_bar.set_postfix({
            "train_loss": f"{train_loss:.4f}",
            "val_loss": f"{val_loss:.4f}"
        })

        # Log epoch results
        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print("Saved best model checkpoint.")

    return model, history

# -----------------------------------------------------------------------------
# Main Execution
# -----------------------------------------------------------------------------

def main():
    """Run the whole pipeline: fetch data, build loaders, train, plot the losses."""
    print("Starting lung nodule segmentation pipeline...")

    # Fetch the dataset through kagglehub
    print("Downloading dataset...")
    dataset_root = kagglehub.dataset_download("zhangweiled/lidcidri")
    print("Dataset downloaded to:", dataset_root)

    # Hyperparameters
    batch_size, num_epochs, learning_rate = 8, 50, 3e-4
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # One dataset per split; the class picks the matching transform pipeline
    print("Creating datasets...")
    datasets = {
        split: LIDCDatasetWithAugmentation(dataset_root, split=split)
        for split in ('train', 'val')
    }

    # Only the training loader shuffles
    loaders = {
        split: DataLoader(data, batch_size=batch_size, shuffle=(split == 'train'), num_workers=4)
        for split, data in datasets.items()
    }

    print(f"Training samples: {len(datasets['train'])}")
    print(f"Validation samples: {len(datasets['val'])}")

    # Model
    model = EfficientUNet(in_channels=1, out_channels=1)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count}")

    # Loss, optimizer, and a scheduler that halves the LR after 5 stalled epochs
    criterion = DiceBCELoss(weight=0.5)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    # Train
    model, history = train_model(
        model, loaders['train'], loaders['val'], criterion, optimizer, scheduler,
        num_epochs=num_epochs, device=device
    )

    # Loss curves for both phases on a single axes
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 1, 1)
    for key, label in (("train_loss", "Train Loss"), ("val_loss", "Val Loss")):
        plt.plot(history[key], label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss")
    plt.legend()
    plt.savefig("training_history.png")
    plt.show()

    for message in (
        "Training complete!",
        "Best model saved to: best_nodule_detector.pth",
        "Training history plot saved to: training_history.png",
    ):
        print(message)

# Run the pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()